Expr filtering: VSIM FILTER first draft.

This commit is contained in:
antirez 2025-02-22 17:10:33 +01:00
parent 025790fc50
commit 5304318335
5 changed files with 112 additions and 28 deletions

View File

@ -55,7 +55,7 @@ all: vset.so
vset.xo: redismodule.h vset.xo: redismodule.h
vset.so: vset.xo hnsw.xo vset.so: vset.xo hnsw.xo cJSON.xo
$(CC) -o $@ $^ $(SHOBJ_LDFLAGS) $(LIBS) -lc $(CC) -o $@ $^ $(SHOBJ_LDFLAGS) $(LIBS) -lc
# Example sources / objects # Example sources / objects

36
expr.c
View File

@ -12,6 +12,12 @@
#include <math.h> #include <math.h>
#include "cJSON.h" #include "cJSON.h"
#ifdef TEST_MAIN
#define RedisModule_Alloc malloc
#define RedisModule_Realloc realloc
#define RedisModule_Free free
#endif
#define EXPR_TOKEN_EOF 0 #define EXPR_TOKEN_EOF 0
#define EXPR_TOKEN_NUM 1 #define EXPR_TOKEN_NUM 1
#define EXPR_TOKEN_STR 2 #define EXPR_TOKEN_STR 2
@ -121,8 +127,8 @@ struct {
/* ================================ Expr token ============================== */ /* ================================ Expr token ============================== */
void exprFreeToken(exprtoken *t) { void exprFreeToken(exprtoken *t) {
if (t == NULL) return; if (t == NULL) return;
if (t->heapstr != NULL) free(t->heapstr); if (t->heapstr != NULL) RedisModule_Free(t->heapstr);
free(t); RedisModule_Free(t);
} }
/* ============================== Stack handling ============================ */ /* ============================== Stack handling ============================ */
@ -134,7 +140,7 @@ void exprFreeToken(exprtoken *t) {
/* Initialize a new expression stack. */ /* Initialize a new expression stack. */
void exprStackInit(exprstack *stack) { void exprStackInit(exprstack *stack) {
stack->items = malloc(sizeof(exprtoken*) * EXPR_STACK_INITIAL_SIZE); stack->items = RedisModule_Alloc(sizeof(exprtoken*) * EXPR_STACK_INITIAL_SIZE);
stack->numitems = 0; stack->numitems = 0;
stack->allocsize = EXPR_STACK_INITIAL_SIZE; stack->allocsize = EXPR_STACK_INITIAL_SIZE;
} }
@ -146,7 +152,7 @@ int exprStackPush(exprstack *stack, exprtoken *token) {
if (stack->numitems == stack->allocsize) { if (stack->numitems == stack->allocsize) {
size_t newsize = stack->allocsize * 2; size_t newsize = stack->allocsize * 2;
exprtoken **newitems = exprtoken **newitems =
realloc(stack->items, sizeof(exprtoken*) * newsize); RedisModule_Realloc(stack->items, sizeof(exprtoken*) * newsize);
if (newitems == NULL) return 0; if (newitems == NULL) return 0;
stack->items = newitems; stack->items = newitems;
stack->allocsize = newsize; stack->allocsize = newsize;
@ -175,7 +181,7 @@ exprtoken *exprStackPeek(exprstack *stack) {
void exprStackFree(exprstack *stack) { void exprStackFree(exprstack *stack) {
for (int j = 0; j < stack->numitems; j++) for (int j = 0; j < stack->numitems; j++)
exprFreeToken(stack->items[j]); exprFreeToken(stack->items[j]);
free(stack->items); RedisModule_Free(stack->items);
} }
/* Just reset the stack removing all the items, but leaving it in a state /* Just reset the stack removing all the items, but leaving it in a state
@ -284,7 +290,7 @@ void exprFree(exprstate *es) {
if (es == NULL) return; if (es == NULL) return;
/* Free the original expression string. */ /* Free the original expression string. */
if (es->expr) free(es->expr); if (es->expr) RedisModule_Free(es->expr);
/* Free all stacks. */ /* Free all stacks. */
exprStackFree(&es->values_stack); exprStackFree(&es->values_stack);
@ -293,7 +299,7 @@ void exprFree(exprstate *es) {
exprStackFree(&es->program); exprStackFree(&es->program);
/* Free the state object itself. */ /* Free the state object itself. */
free(es); RedisModule_Free(es);
} }
/* Split the provided expression into a stack of tokens. Returns /* Split the provided expression into a stack of tokens. Returns
@ -344,7 +350,7 @@ int exprTokenize(exprstate *es, int *errpos) {
} }
/* Allocate and copy current token to tokens stack */ /* Allocate and copy current token to tokens stack */
exprtoken *token = malloc(sizeof(exprtoken)); exprtoken *token = RedisModule_Alloc(sizeof(exprtoken));
if (!token) return 1; // OOM. if (!token) return 1; // OOM.
*token = es->current; /* Copy the entire structure. */ *token = es->current; /* Copy the entire structure. */
@ -445,12 +451,12 @@ int exprProcessOperator(exprstate *es, exprtoken *op, int *stack_items, int *err
* expression is returned by reference. */ * expression is returned by reference. */
exprstate *exprCompile(char *expr, int *errpos) { exprstate *exprCompile(char *expr, int *errpos) {
/* Initialize expression state. */ /* Initialize expression state. */
exprstate *es = malloc(sizeof(exprstate)); exprstate *es = RedisModule_Alloc(sizeof(exprstate));
if (!es) return NULL; if (!es) return NULL;
es->expr = strdup(expr); es->expr = strdup(expr);
if (!es->expr) { if (!es->expr) {
free(es); RedisModule_Free(es);
return NULL; return NULL;
} }
es->p = es->expr; es->p = es->expr;
@ -484,7 +490,7 @@ exprstate *exprCompile(char *expr, int *errpos) {
token->token_type == EXPR_TOKEN_STR || token->token_type == EXPR_TOKEN_STR ||
token->token_type == EXPR_TOKEN_SELECTOR) token->token_type == EXPR_TOKEN_SELECTOR)
{ {
exprtoken *value_token = malloc(sizeof(exprtoken)); exprtoken *value_token = RedisModule_Alloc(sizeof(exprtoken));
if (!value_token) { if (!value_token) {
if (errpos) *errpos = token->offset; if (errpos) *errpos = token->offset;
exprFree(es); exprFree(es);
@ -503,7 +509,7 @@ exprstate *exprCompile(char *expr, int *errpos) {
/* Handle operators. */ /* Handle operators. */
if (token->token_type == EXPR_TOKEN_OP) { if (token->token_type == EXPR_TOKEN_OP) {
exprtoken *op_token = malloc(sizeof(exprtoken)); exprtoken *op_token = RedisModule_Alloc(sizeof(exprtoken));
if (!op_token) { if (!op_token) {
if (errpos) *errpos = token->offset; if (errpos) *errpos = token->offset;
exprFree(es); exprFree(es);
@ -620,7 +626,7 @@ int exprRun(exprstate *es, char *json, size_t json_len) {
// Handle selectors by calling the callback. // Handle selectors by calling the callback.
if (t->token_type == EXPR_TOKEN_SELECTOR) { if (t->token_type == EXPR_TOKEN_SELECTOR) {
exprtoken *result = malloc(sizeof(exprtoken)); exprtoken *result = RedisModule_Alloc(sizeof(exprtoken));
if (result != NULL && json != NULL) { if (result != NULL && json != NULL) {
cJSON *attrib = NULL; cJSON *attrib = NULL;
if (parsed_json == NULL) if (parsed_json == NULL)
@ -665,14 +671,14 @@ int exprRun(exprstate *es, char *json, size_t json_len) {
// Push non-operator values directly onto the stack. // Push non-operator values directly onto the stack.
if (t->token_type != EXPR_TOKEN_OP) { if (t->token_type != EXPR_TOKEN_OP) {
exprtoken *nt = malloc(sizeof(exprtoken)); exprtoken *nt = RedisModule_Alloc(sizeof(exprtoken));
*nt = *t; *nt = *t;
exprStackPush(&es->values_stack, nt); exprStackPush(&es->values_stack, nt);
continue; continue;
} }
// Handle operators. // Handle operators.
exprtoken *result = malloc(sizeof(exprtoken)); exprtoken *result = RedisModule_Alloc(sizeof(exprtoken));
result->token_type = EXPR_TOKEN_NUM; result->token_type = EXPR_TOKEN_NUM;
// Pop operands - we know we have enough from compile-time checks. // Pop operands - we know we have enough from compile-time checks.

31
hnsw.c
View File

@ -780,10 +780,21 @@ void hnsw_free_tmp_node(hnswNode *node, const float *vector) {
* arrays must have space for at least 'k' items. * arrays must have space for at least 'k' items.
* norm_query should be set to 1 if the query vector is already * norm_query should be set to 1 if the query vector is already
* normalized, otherwise, if 0, the function will copy the vector, * normalized, otherwise, if 0, the function will copy the vector,
* L2-normalize the copy and search using the normalized version. */ * L2-normalize the copy and search using the normalized version.
int hnsw_search(HNSW *index, const float *query_vector, uint32_t k, *
* If the filter_privdata callback is passed, only elements passing the
* specified filter (invoked with privdata and the value associated
* to the node as arguments) are returned. In such case, if max_candidates
* is not NULL, it represents the maximum number of nodes to explore, since
* the search may be otherwise unbound if few or no elements pass the
* filter. */
int hnsw_search_with_filter
(HNSW *index, const float *query_vector, uint32_t k,
hnswNode **neighbors, float *distances, uint32_t slot, hnswNode **neighbors, float *distances, uint32_t slot,
int query_vector_is_normalized) int query_vector_is_normalized,
int (*filter_callback)(void *value, void *privdata),
void *filter_privdata, uint32_t max_candidates)
{ {
if (!index || !query_vector || !neighbors || k == 0) return -1; if (!index || !query_vector || !neighbors || k == 0) return -1;
if (!index->enter_point) return 0; // Empty index. if (!index->enter_point) return 0; // Empty index.
@ -811,7 +822,9 @@ int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
} }
/* Search bottom layer (the most densely populated) with ef = k */ /* Search bottom layer (the most densely populated) with ef = k */
pqueue *results = search_layer(index, &query, curr_ep, k, 0, slot); pqueue *results = search_layer_with_filter(
index, &query, curr_ep, k, 0, slot, filter_callback,
filter_privdata, max_candidates);
if (!results) { if (!results) {
hnsw_free_tmp_node(&query, query_vector); hnsw_free_tmp_node(&query, query_vector);
return -1; return -1;
@ -831,6 +844,16 @@ int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
return found; return found;
} }
/* Wrapper to hnsw_search_with_filter() when no filter is needed. */
int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
hnswNode **neighbors, float *distances, uint32_t slot,
int query_vector_is_normalized)
{
return hnsw_search_with_filter(index,query_vector,k,neighbors,
distances,slot,query_vector_is_normalized,
NULL,NULL,0);
}
/* Rescan a node and update the wortst neighbor index. /* Rescan a node and update the wortst neighbor index.
* The followinng two functions are variants of this function to be used * The followinng two functions are variants of this function to be used
* when links are added or removed: they may do less work than a full scan. */ * when links are added or removed: they may do less work than a full scan. */

6
hnsw.h
View File

@ -119,6 +119,12 @@ hnswNode *hnsw_insert(HNSW *index, const float *vector, const int8_t *qvector,
int hnsw_search(HNSW *index, const float *query, uint32_t k, int hnsw_search(HNSW *index, const float *query, uint32_t k,
hnswNode **neighbors, float *distances, uint32_t slot, hnswNode **neighbors, float *distances, uint32_t slot,
int query_vector_is_normalized); int query_vector_is_normalized);
int hnsw_search_with_filter
(HNSW *index, const float *query_vector, uint32_t k,
hnswNode **neighbors, float *distances, uint32_t slot,
int query_vector_is_normalized,
int (*filter_callback)(void *value, void *privdata),
void *filter_privdata, uint32_t max_candidates);
void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec); void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec);
void hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value)); void hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value));

65
vset.c
View File

@ -20,6 +20,10 @@
#include <pthread.h> #include <pthread.h>
#include "hnsw.h" #include "hnsw.h"
// We inline directly the expression implementation here so that building
// the module is trivial.
#include "expr.c"
static RedisModuleType *VectorSetType; static RedisModuleType *VectorSetType;
static uint64_t VectorSetTypeNextId = 0; static uint64_t VectorSetTypeNextId = 0;
@ -561,6 +565,22 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
} }
} }
/* HNSW callback to filter items according to a predicate function
* (our FILTER expression in this case). */
int vectorSetFilterCallback(void *value, void *privdata) {
exprstate *expr = privdata;
struct vsetNodeVal *nv = value;
if (nv->attrib == NULL) return 0; // No attributes? No match.
size_t json_len;
char *json = (char*)RedisModule_StringPtrLen(nv->attrib,&json_len);
#if 0
int res = exprRun(expr,json,json_len);
printf("%s %d\n", json, res);
return res;
#endif
return exprRun(expr,json,json_len);
}
/* Common path for the execution of the VSIM command both threaded and /* Common path for the execution of the VSIM command both threaded and
* not threaded. Note that 'ctx' may be normal context of a thread safe * not threaded. Note that 'ctx' may be normal context of a thread safe
* context obtained from a blocked client. The locking that is specific * context obtained from a blocked client. The locking that is specific
@ -568,7 +588,7 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
* handles the HNSW locking explicitly. */ * handles the HNSW locking explicitly. */
void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset, void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
float *vec, unsigned long count, float epsilon, unsigned long withscores, float *vec, unsigned long count, float epsilon, unsigned long withscores,
unsigned long ef) unsigned long ef, exprstate *filter_expr, unsigned long filter_ef)
{ {
/* In our scan, we can't just collect 'count' elements as /* In our scan, we can't just collect 'count' elements as
* if count is small we would explore the graph in an insufficient * if count is small we would explore the graph in an insufficient
@ -585,7 +605,12 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
hnswNode **neighbors = RedisModule_Alloc(sizeof(hnswNode*)*ef); hnswNode **neighbors = RedisModule_Alloc(sizeof(hnswNode*)*ef);
float *distances = RedisModule_Alloc(sizeof(float)*ef); float *distances = RedisModule_Alloc(sizeof(float)*ef);
int slot = hnsw_acquire_read_slot(vset->hnsw); int slot = hnsw_acquire_read_slot(vset->hnsw);
unsigned int found = hnsw_search(vset->hnsw, vec, ef, neighbors, distances, slot, 0); unsigned int found;
if (filter_expr == NULL) {
found = hnsw_search(vset->hnsw, vec, ef, neighbors, distances, slot, 0);
} else {
found = hnsw_search_with_filter(vset->hnsw, vec, ef, neighbors, distances, slot, 0, vectorSetFilterCallback, filter_expr, filter_ef);
}
hnsw_release_read_slot(vset->hnsw,slot); hnsw_release_read_slot(vset->hnsw,slot);
RedisModule_Free(vec); RedisModule_Free(vec);
@ -598,7 +623,8 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
for (unsigned int i = 0; i < found && i < count; i++) { for (unsigned int i = 0; i < found && i < count; i++) {
if (distances[i] > epsilon) break; if (distances[i] > epsilon) break;
RedisModule_ReplyWithString(ctx, neighbors[i]->value); struct vsetNodeVal *nv = neighbors[i]->value;
RedisModule_ReplyWithString(ctx, nv->item);
arraylen++; arraylen++;
if (withscores) { if (withscores) {
/* The similarity score is provided in a 0-1 range. */ /* The similarity score is provided in a 0-1 range. */
@ -613,6 +639,7 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
RedisModule_Free(neighbors); RedisModule_Free(neighbors);
RedisModule_Free(distances); RedisModule_Free(distances);
if (filter_expr) exprFree(filter_expr);
} }
/* VSIM thread handling the blocked client request. */ /* VSIM thread handling the blocked client request. */
@ -628,6 +655,8 @@ void *VSIM_thread(void *arg) {
float epsilon = *((float*)targ[4]); float epsilon = *((float*)targ[4]);
unsigned long withscores = (unsigned long)targ[5]; unsigned long withscores = (unsigned long)targ[5];
unsigned long ef = (unsigned long)targ[6]; unsigned long ef = (unsigned long)targ[6];
exprstate *filter_expr = targ[7];
unsigned long filter_ef = (unsigned long)targ[8];
RedisModule_Free(targ[4]); RedisModule_Free(targ[4]);
RedisModule_Free(targ); RedisModule_Free(targ);
@ -635,7 +664,7 @@ void *VSIM_thread(void *arg) {
RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(bc); RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(bc);
// Run the query. // Run the query.
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef); VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef);
// Cleanup. // Cleanup.
RedisModule_FreeThreadSafeContext(ctx); RedisModule_FreeThreadSafeContext(ctx);
@ -644,7 +673,7 @@ void *VSIM_thread(void *arg) {
return NULL; return NULL;
} }
/* VSIM key [ELE|FP32|VALUES] <vector or ele> [WITHSCORES] [COUNT num] [EPSILON eps] [EF exploration-factor] */ /* VSIM key [ELE|FP32|VALUES] <vector or ele> [WITHSCORES] [COUNT num] [EPSILON eps] [EF exploration-factor] [FILTER expression] [FILTER-EF exploration-factor] */
int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
RedisModule_AutoMemory(ctx); RedisModule_AutoMemory(ctx);
@ -658,6 +687,10 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
long long ef = 0; /* Exploration factor (see HNSW paper) */ long long ef = 0; /* Exploration factor (see HNSW paper) */
double epsilon = 2.0; /* Max cosine distance */ double epsilon = 2.0; /* Max cosine distance */
/* Things computed later. */
long long filter_ef = 0;
exprstate *filter_expr = NULL;
/* Get key and vector type */ /* Get key and vector type */
RedisModuleString *key = argv[1]; RedisModuleString *key = argv[1];
const char *vectorType = RedisModule_StringPtrLen(argv[2], NULL); const char *vectorType = RedisModule_StringPtrLen(argv[2], NULL);
@ -766,6 +799,19 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
return RedisModule_ReplyWithError(ctx, "ERR invalid EF"); return RedisModule_ReplyWithError(ctx, "ERR invalid EF");
} }
j += 2; j += 2;
} else if (!strcasecmp(opt, "FILTER") && j+1 < argc) {
RedisModuleString *exprarg = argv[j+1];
size_t exprlen;
char *exprstr = (char*)RedisModule_StringPtrLen(exprarg,&exprlen);
int errpos;
filter_expr = exprCompile(exprstr,&errpos);
if (filter_expr == NULL) {
if ((size_t)errpos >= exprlen) errpos = 0;
return RedisModule_ReplyWithErrorFormat(ctx,
"ERR syntax error in FILTER expression near: %s",
exprstr+errpos);
}
j += 2;
} else { } else {
RedisModule_Free(vec); RedisModule_Free(vec);
return RedisModule_ReplyWithError(ctx, return RedisModule_ReplyWithError(ctx,
@ -774,6 +820,7 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
} }
int threaded_request = 1; // Run on a thread, by default. int threaded_request = 1; // Run on a thread, by default.
if (filter_ef == 0) filter_ef = count * 100; // Max filter visited nodes.
// Disable threaded for MULTI/EXEC and Lua. // Disable threaded for MULTI/EXEC and Lua.
if (RedisModule_GetContextFlags(ctx) & if (RedisModule_GetContextFlags(ctx) &
@ -799,7 +846,7 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,NULL,NULL,NULL,0); RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,NULL,NULL,NULL,0);
pthread_t tid; pthread_t tid;
void **targ = RedisModule_Alloc(sizeof(void*)*7); void **targ = RedisModule_Alloc(sizeof(void*)*9);
targ[0] = bc; targ[0] = bc;
targ[1] = vset; targ[1] = vset;
targ[2] = vec; targ[2] = vec;
@ -808,16 +855,18 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
*((float*)targ[4]) = epsilon; *((float*)targ[4]) = epsilon;
targ[5] = (void*)(unsigned long)withscores; targ[5] = (void*)(unsigned long)withscores;
targ[6] = (void*)(unsigned long)ef; targ[6] = (void*)(unsigned long)ef;
targ[7] = (void*)filter_expr;
targ[8] = (void*)(unsigned long)filter_ef;
if (pthread_create(&tid,NULL,VSIM_thread,targ) != 0) { if (pthread_create(&tid,NULL,VSIM_thread,targ) != 0) {
pthread_rwlock_unlock(&vset->in_use_lock); pthread_rwlock_unlock(&vset->in_use_lock);
RedisModule_AbortBlock(bc); RedisModule_AbortBlock(bc);
RedisModule_Free(vec); RedisModule_Free(vec);
RedisModule_Free(targ[4]); RedisModule_Free(targ[4]);
RedisModule_Free(targ); RedisModule_Free(targ);
return RedisModule_ReplyWithError(ctx,"-ERR Can't start thread"); VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef);
} }
} else { } else {
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef); VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef);
} }
return REDISMODULE_OK; return REDISMODULE_OK;