diff --git a/Makefile b/Makefile index 4478ac5d2..de73b3abf 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,7 @@ all: vset.so vset.xo: redismodule.h -vset.so: vset.xo hnsw.xo +vset.so: vset.xo hnsw.xo cJSON.xo $(CC) -o $@ $^ $(SHOBJ_LDFLAGS) $(LIBS) -lc # Example sources / objects diff --git a/expr.c b/expr.c index 07ec1ec0a..fc271f802 100644 --- a/expr.c +++ b/expr.c @@ -12,6 +12,12 @@ #include #include "cJSON.h" +#ifdef TEST_MAIN +#define RedisModule_Alloc malloc +#define RedisModule_Realloc realloc +#define RedisModule_Free free +#endif + #define EXPR_TOKEN_EOF 0 #define EXPR_TOKEN_NUM 1 #define EXPR_TOKEN_STR 2 @@ -121,8 +127,8 @@ struct { /* ================================ Expr token ============================== */ void exprFreeToken(exprtoken *t) { if (t == NULL) return; - if (t->heapstr != NULL) free(t->heapstr); - free(t); + if (t->heapstr != NULL) RedisModule_Free(t->heapstr); + RedisModule_Free(t); } /* ============================== Stack handling ============================ */ @@ -134,7 +140,7 @@ void exprFreeToken(exprtoken *t) { /* Initialize a new expression stack. 
*/ void exprStackInit(exprstack *stack) { - stack->items = malloc(sizeof(exprtoken*) * EXPR_STACK_INITIAL_SIZE); + stack->items = RedisModule_Alloc(sizeof(exprtoken*) * EXPR_STACK_INITIAL_SIZE); stack->numitems = 0; stack->allocsize = EXPR_STACK_INITIAL_SIZE; } @@ -146,7 +152,7 @@ int exprStackPush(exprstack *stack, exprtoken *token) { if (stack->numitems == stack->allocsize) { size_t newsize = stack->allocsize * 2; exprtoken **newitems = - realloc(stack->items, sizeof(exprtoken*) * newsize); + RedisModule_Realloc(stack->items, sizeof(exprtoken*) * newsize); if (newitems == NULL) return 0; stack->items = newitems; stack->allocsize = newsize; @@ -175,7 +181,7 @@ exprtoken *exprStackPeek(exprstack *stack) { void exprStackFree(exprstack *stack) { for (int j = 0; j < stack->numitems; j++) exprFreeToken(stack->items[j]); - free(stack->items); + RedisModule_Free(stack->items); } /* Just reset the stack removing all the items, but leaving it in a state @@ -284,7 +290,7 @@ void exprFree(exprstate *es) { if (es == NULL) return; /* Free the original expression string. */ - if (es->expr) free(es->expr); + if (es->expr) RedisModule_Free(es->expr); /* Free all stacks. */ exprStackFree(&es->values_stack); @@ -293,7 +299,7 @@ void exprFree(exprstate *es) { exprStackFree(&es->program); /* Free the state object itself. */ - free(es); + RedisModule_Free(es); } /* Split the provided expression into a stack of tokens. Returns @@ -344,7 +350,7 @@ int exprTokenize(exprstate *es, int *errpos) { } /* Allocate and copy current token to tokens stack */ - exprtoken *token = malloc(sizeof(exprtoken)); + exprtoken *token = RedisModule_Alloc(sizeof(exprtoken)); if (!token) return 1; // OOM. *token = es->current; /* Copy the entire structure. */ @@ -445,12 +451,15 @@ int exprProcessOperator(exprstate *es, exprtoken *op, int *stack_items, int *err * expression is returned by reference. */ exprstate *exprCompile(char *expr, int *errpos) { /* Initialize expression state.
*/ - exprstate *es = malloc(sizeof(exprstate)); + exprstate *es = RedisModule_Alloc(sizeof(exprstate)); if (!es) return NULL; - es->expr = strdup(expr); + /* exprFree() uses RedisModule_Free(): allocate the copy accordingly. */ + size_t exprsize = strlen(expr)+1; + es->expr = RedisModule_Alloc(exprsize); if (!es->expr) { - free(es); + RedisModule_Free(es); return NULL; } + memcpy(es->expr,expr,exprsize); es->p = es->expr; @@ -484,7 +493,7 @@ exprstate *exprCompile(char *expr, int *errpos) { token->token_type == EXPR_TOKEN_STR || token->token_type == EXPR_TOKEN_SELECTOR) { - exprtoken *value_token = malloc(sizeof(exprtoken)); + exprtoken *value_token = RedisModule_Alloc(sizeof(exprtoken)); if (!value_token) { if (errpos) *errpos = token->offset; exprFree(es); @@ -503,7 +512,7 @@ exprstate *exprCompile(char *expr, int *errpos) { /* Handle operators. */ if (token->token_type == EXPR_TOKEN_OP) { - exprtoken *op_token = malloc(sizeof(exprtoken)); + exprtoken *op_token = RedisModule_Alloc(sizeof(exprtoken)); if (!op_token) { if (errpos) *errpos = token->offset; exprFree(es); @@ -620,7 +629,7 @@ int exprRun(exprstate *es, char *json, size_t json_len) { // Handle selectors by calling the callback. if (t->token_type == EXPR_TOKEN_SELECTOR) { - exprtoken *result = malloc(sizeof(exprtoken)); + exprtoken *result = RedisModule_Alloc(sizeof(exprtoken)); if (result != NULL && json != NULL) { cJSON *attrib = NULL; if (parsed_json == NULL) @@ -665,14 +674,14 @@ int exprRun(exprstate *es, char *json, size_t json_len) { // Push non-operator values directly onto the stack. if (t->token_type != EXPR_TOKEN_OP) { - exprtoken *nt = malloc(sizeof(exprtoken)); + exprtoken *nt = RedisModule_Alloc(sizeof(exprtoken)); *nt = *t; exprStackPush(&es->values_stack, nt); continue; } // Handle operators. - exprtoken *result = malloc(sizeof(exprtoken)); + exprtoken *result = RedisModule_Alloc(sizeof(exprtoken)); result->token_type = EXPR_TOKEN_NUM; // Pop operands - we know we have enough from compile-time checks.
diff --git a/hnsw.c b/hnsw.c index 0d48c0dce..fd284e635 100644 --- a/hnsw.c +++ b/hnsw.c @@ -780,10 +780,21 @@ void hnsw_free_tmp_node(hnswNode *node, const float *vector) { * arrays must have space for at least 'k' items. * norm_query should be set to 1 if the query vector is already * normalized, otherwise, if 0, the function will copy the vector, - * L2-normalize the copy and search using the normalized version. */ -int hnsw_search(HNSW *index, const float *query_vector, uint32_t k, + * L2-normalize the copy and search using the normalized version. + * + * If the filter_callback callback is passed, only elements passing the + * specified filter (invoked with the value associated to the node and + * filter_privdata as arguments) are returned. In such case, if max_candidates + * is not zero, it represents the maximum number of nodes to explore, since + * the search may be otherwise unbounded if few or no elements pass the + * filter. */ +int hnsw_search_with_filter + (HNSW *index, const float *query_vector, uint32_t k, hnswNode **neighbors, float *distances, uint32_t slot, - int query_vector_is_normalized) + int query_vector_is_normalized, + int (*filter_callback)(void *value, void *privdata), + void *filter_privdata, uint32_t max_candidates) + { if (!index || !query_vector || !neighbors || k == 0) return -1; if (!index->enter_point) return 0; // Empty index. @@ -811,7 +822,9 @@ int hnsw_search(HNSW *index, const float *query_vector, uint32_t k, } /* Search bottom layer (the most densely populated) with ef = k */ - pqueue *results = search_layer(index, &query, curr_ep, k, 0, slot); + pqueue *results = search_layer_with_filter( + index, &query, curr_ep, k, 0, slot, filter_callback, + filter_privdata, max_candidates); if (!results) { hnsw_free_tmp_node(&query, query_vector); return -1; @@ -831,6 +844,16 @@ int hnsw_search(HNSW *index, const float *query_vector, uint32_t k, return found; } +/* Wrapper to hnsw_search_with_filter() when no filter is needed.
*/ +int hnsw_search(HNSW *index, const float *query_vector, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized) +{ + return hnsw_search_with_filter(index,query_vector,k,neighbors, + distances,slot,query_vector_is_normalized, + NULL,NULL,0); +} + /* Rescan a node and update the wortst neighbor index. * The followinng two functions are variants of this function to be used * when links are added or removed: they may do less work than a full scan. */ diff --git a/hnsw.h b/hnsw.h index 3d104cc5e..5cc1b1cd2 100644 --- a/hnsw.h +++ b/hnsw.h @@ -119,6 +119,12 @@ hnswNode *hnsw_insert(HNSW *index, const float *vector, const int8_t *qvector, int hnsw_search(HNSW *index, const float *query, uint32_t k, hnswNode **neighbors, float *distances, uint32_t slot, int query_vector_is_normalized); +int hnsw_search_with_filter + (HNSW *index, const float *query_vector, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized, + int (*filter_callback)(void *value, void *privdata), + void *filter_privdata, uint32_t max_candidates); void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec); void hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value)); diff --git a/vset.c b/vset.c index 619a9dc0f..37ecd61fd 100644 --- a/vset.c +++ b/vset.c @@ -20,6 +20,10 @@ #include #include "hnsw.h" +// We inline directly the expression implementation here so that building +// the module is trivial. +#include "expr.c" + static RedisModuleType *VectorSetType; static uint64_t VectorSetTypeNextId = 0; @@ -561,6 +565,22 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { } } +/* HNSW callback to filter items according to a predicate function + * (our FILTER expression in this case). 
*/ +int vectorSetFilterCallback(void *value, void *privdata) { + exprstate *expr = privdata; + struct vsetNodeVal *nv = value; + if (nv->attrib == NULL) return 0; // No attributes? No match. + size_t json_len; + char *json = (char*)RedisModule_StringPtrLen(nv->attrib,&json_len); + #if 0 + int res = exprRun(expr,json,json_len); + printf("%s %d\n", json, res); + return res; + #endif + return exprRun(expr,json,json_len); +} + /* Common path for the execution of the VSIM command both threaded and * not threaded. Note that 'ctx' may be normal context of a thread safe * context obtained from a blocked client. The locking that is specific @@ -568,7 +588,7 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { * handles the HNSW locking explicitly. */ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset, float *vec, unsigned long count, float epsilon, unsigned long withscores, - unsigned long ef) + unsigned long ef, exprstate *filter_expr, unsigned long filter_ef) { /* In our scan, we can't just collect 'count' elements as * if count is small we would explore the graph in an insufficient @@ -585,7 +605,12 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset, hnswNode **neighbors = RedisModule_Alloc(sizeof(hnswNode*)*ef); float *distances = RedisModule_Alloc(sizeof(float)*ef); int slot = hnsw_acquire_read_slot(vset->hnsw); - unsigned int found = hnsw_search(vset->hnsw, vec, ef, neighbors, distances, slot, 0); + unsigned int found; + if (filter_expr == NULL) { + found = hnsw_search(vset->hnsw, vec, ef, neighbors, distances, slot, 0); + } else { + found = hnsw_search_with_filter(vset->hnsw, vec, ef, neighbors, distances, slot, 0, vectorSetFilterCallback, filter_expr, filter_ef); + } hnsw_release_read_slot(vset->hnsw,slot); RedisModule_Free(vec); @@ -598,7 +623,8 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset, for (unsigned int i = 0; i < found && i < count; i++) { if (distances[i] > epsilon) break; - 
RedisModule_ReplyWithString(ctx, neighbors[i]->value); + struct vsetNodeVal *nv = neighbors[i]->value; + RedisModule_ReplyWithString(ctx, nv->item); arraylen++; if (withscores) { /* The similarity score is provided in a 0-1 range. */ @@ -613,6 +639,7 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset, RedisModule_Free(neighbors); RedisModule_Free(distances); + if (filter_expr) exprFree(filter_expr); } /* VSIM thread handling the blocked client request. */ @@ -628,6 +655,8 @@ void *VSIM_thread(void *arg) { float epsilon = *((float*)targ[4]); unsigned long withscores = (unsigned long)targ[5]; unsigned long ef = (unsigned long)targ[6]; + exprstate *filter_expr = targ[7]; + unsigned long filter_ef = (unsigned long)targ[8]; RedisModule_Free(targ[4]); RedisModule_Free(targ); @@ -635,7 +664,7 @@ void *VSIM_thread(void *arg) { RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(bc); // Run the query. - VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef); + VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef); // Cleanup. RedisModule_FreeThreadSafeContext(ctx); @@ -644,7 +673,7 @@ void *VSIM_thread(void *arg) { return NULL; } -/* VSIM key [ELE|FP32|VALUES] [WITHSCORES] [COUNT num] [EPSILON eps] [EF exploration-factor] */ +/* VSIM key [ELE|FP32|VALUES] [WITHSCORES] [COUNT num] [EPSILON eps] [EF exploration-factor] [FILTER expression] [FILTER-EF exploration-factor] */ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { RedisModule_AutoMemory(ctx); @@ -658,6 +687,10 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { long long ef = 0; /* Exploration factor (see HNSW paper) */ double epsilon = 2.0; /* Max cosine distance */ + /* Things computed later. 
*/ + long long filter_ef = 0; + exprstate *filter_expr = NULL; + /* Get key and vector type */ RedisModuleString *key = argv[1]; const char *vectorType = RedisModule_StringPtrLen(argv[2], NULL); @@ -766,6 +799,20 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { return RedisModule_ReplyWithError(ctx, "ERR invalid EF"); } j += 2; + } else if (!strcasecmp(opt, "FILTER") && j+1 < argc) { + RedisModuleString *exprarg = argv[j+1]; + size_t exprlen; + char *exprstr = (char*)RedisModule_StringPtrLen(exprarg,&exprlen); + int errpos = 0; /* exprCompile() may fail before setting it. */ + filter_expr = exprCompile(exprstr,&errpos); + if (filter_expr == NULL) { + if ((size_t)errpos >= exprlen) errpos = 0; + RedisModule_Free(vec); + return RedisModule_ReplyWithErrorFormat(ctx, + "ERR syntax error in FILTER expression near: %s", + exprstr+errpos); + } + j += 2; } else { RedisModule_Free(vec); return RedisModule_ReplyWithError(ctx, @@ -774,6 +821,7 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { } int threaded_request = 1; // Run on a thread, by default. + if (filter_ef == 0) filter_ef = count * 100; // Max filter visited nodes. // Disable threaded for MULTI/EXEC and Lua.
if (RedisModule_GetContextFlags(ctx) & @@ -799,7 +847,7 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,NULL,NULL,NULL,0); pthread_t tid; - void **targ = RedisModule_Alloc(sizeof(void*)*7); + void **targ = RedisModule_Alloc(sizeof(void*)*9); targ[0] = bc; targ[1] = vset; targ[2] = vec; @@ -808,16 +856,18 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { *((float*)targ[4]) = epsilon; targ[5] = (void*)(unsigned long)withscores; targ[6] = (void*)(unsigned long)ef; + targ[7] = (void*)filter_expr; + targ[8] = (void*)(unsigned long)filter_ef; if (pthread_create(&tid,NULL,VSIM_thread,targ) != 0) { pthread_rwlock_unlock(&vset->in_use_lock); RedisModule_AbortBlock(bc); - RedisModule_Free(vec); RedisModule_Free(targ[4]); RedisModule_Free(targ); - return RedisModule_ReplyWithError(ctx,"-ERR Can't start thread"); + /* Fallback: VSIM_execute() frees both 'vec' and 'filter_expr'. */ + VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef); } } else { - VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef); + VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef); } return REDISMODULE_OK;