Expr filtering: VSIM FILTER first draft.

This commit is contained in:
antirez 2025-02-22 17:10:33 +01:00
parent 025790fc50
commit 5304318335
5 changed files with 112 additions and 28 deletions

View File

@ -55,7 +55,7 @@ all: vset.so
vset.xo: redismodule.h
vset.so: vset.xo hnsw.xo
vset.so: vset.xo hnsw.xo cJSON.xo
$(CC) -o $@ $^ $(SHOBJ_LDFLAGS) $(LIBS) -lc
# Example sources / objects

36
expr.c
View File

@ -12,6 +12,12 @@
#include <math.h>
#include "cJSON.h"
#ifdef TEST_MAIN
#define RedisModule_Alloc malloc
#define RedisModule_Realloc realloc
#define RedisModule_Free free
#endif
#define EXPR_TOKEN_EOF 0
#define EXPR_TOKEN_NUM 1
#define EXPR_TOKEN_STR 2
@ -121,8 +127,8 @@ struct {
/* ================================ Expr token ============================== */
void exprFreeToken(exprtoken *t) {
if (t == NULL) return;
if (t->heapstr != NULL) free(t->heapstr);
free(t);
if (t->heapstr != NULL) RedisModule_Free(t->heapstr);
RedisModule_Free(t);
}
/* ============================== Stack handling ============================ */
@ -134,7 +140,7 @@ void exprFreeToken(exprtoken *t) {
/* Initialize a new expression stack. */
void exprStackInit(exprstack *stack) {
stack->items = malloc(sizeof(exprtoken*) * EXPR_STACK_INITIAL_SIZE);
stack->items = RedisModule_Alloc(sizeof(exprtoken*) * EXPR_STACK_INITIAL_SIZE);
stack->numitems = 0;
stack->allocsize = EXPR_STACK_INITIAL_SIZE;
}
@ -146,7 +152,7 @@ int exprStackPush(exprstack *stack, exprtoken *token) {
if (stack->numitems == stack->allocsize) {
size_t newsize = stack->allocsize * 2;
exprtoken **newitems =
realloc(stack->items, sizeof(exprtoken*) * newsize);
RedisModule_Realloc(stack->items, sizeof(exprtoken*) * newsize);
if (newitems == NULL) return 0;
stack->items = newitems;
stack->allocsize = newsize;
@ -175,7 +181,7 @@ exprtoken *exprStackPeek(exprstack *stack) {
void exprStackFree(exprstack *stack) {
for (int j = 0; j < stack->numitems; j++)
exprFreeToken(stack->items[j]);
free(stack->items);
RedisModule_Free(stack->items);
}
/* Just reset the stack removing all the items, but leaving it in a state
@ -284,7 +290,7 @@ void exprFree(exprstate *es) {
if (es == NULL) return;
/* Free the original expression string. */
if (es->expr) free(es->expr);
if (es->expr) RedisModule_Free(es->expr);
/* Free all stacks. */
exprStackFree(&es->values_stack);
@ -293,7 +299,7 @@ void exprFree(exprstate *es) {
exprStackFree(&es->program);
/* Free the state object itself. */
free(es);
RedisModule_Free(es);
}
/* Split the provided expression into a stack of tokens. Returns
@ -344,7 +350,7 @@ int exprTokenize(exprstate *es, int *errpos) {
}
/* Allocate and copy current token to tokens stack */
exprtoken *token = malloc(sizeof(exprtoken));
exprtoken *token = RedisModule_Alloc(sizeof(exprtoken));
if (!token) return 1; // OOM.
*token = es->current; /* Copy the entire structure. */
@ -445,12 +451,12 @@ int exprProcessOperator(exprstate *es, exprtoken *op, int *stack_items, int *err
* expression is returned by reference. */
exprstate *exprCompile(char *expr, int *errpos) {
/* Initialize expression state. */
exprstate *es = malloc(sizeof(exprstate));
exprstate *es = RedisModule_Alloc(sizeof(exprstate));
if (!es) return NULL;
es->expr = strdup(expr);
if (!es->expr) {
free(es);
RedisModule_Free(es);
return NULL;
}
es->p = es->expr;
@ -484,7 +490,7 @@ exprstate *exprCompile(char *expr, int *errpos) {
token->token_type == EXPR_TOKEN_STR ||
token->token_type == EXPR_TOKEN_SELECTOR)
{
exprtoken *value_token = malloc(sizeof(exprtoken));
exprtoken *value_token = RedisModule_Alloc(sizeof(exprtoken));
if (!value_token) {
if (errpos) *errpos = token->offset;
exprFree(es);
@ -503,7 +509,7 @@ exprstate *exprCompile(char *expr, int *errpos) {
/* Handle operators. */
if (token->token_type == EXPR_TOKEN_OP) {
exprtoken *op_token = malloc(sizeof(exprtoken));
exprtoken *op_token = RedisModule_Alloc(sizeof(exprtoken));
if (!op_token) {
if (errpos) *errpos = token->offset;
exprFree(es);
@ -620,7 +626,7 @@ int exprRun(exprstate *es, char *json, size_t json_len) {
// Handle selectors by calling the callback.
if (t->token_type == EXPR_TOKEN_SELECTOR) {
exprtoken *result = malloc(sizeof(exprtoken));
exprtoken *result = RedisModule_Alloc(sizeof(exprtoken));
if (result != NULL && json != NULL) {
cJSON *attrib = NULL;
if (parsed_json == NULL)
@ -665,14 +671,14 @@ int exprRun(exprstate *es, char *json, size_t json_len) {
// Push non-operator values directly onto the stack.
if (t->token_type != EXPR_TOKEN_OP) {
exprtoken *nt = malloc(sizeof(exprtoken));
exprtoken *nt = RedisModule_Alloc(sizeof(exprtoken));
*nt = *t;
exprStackPush(&es->values_stack, nt);
continue;
}
// Handle operators.
exprtoken *result = malloc(sizeof(exprtoken));
exprtoken *result = RedisModule_Alloc(sizeof(exprtoken));
result->token_type = EXPR_TOKEN_NUM;
// Pop operands - we know we have enough from compile-time checks.

31
hnsw.c
View File

@ -780,10 +780,21 @@ void hnsw_free_tmp_node(hnswNode *node, const float *vector) {
* arrays must have space for at least 'k' items.
* norm_query should be set to 1 if the query vector is already
* normalized, otherwise, if 0, the function will copy the vector,
* L2-normalize the copy and search using the normalized version. */
int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
* L2-normalize the copy and search using the normalized version.
*
* If the filter_privdata callback is passed, only elements passing the
* specified filter (invoked with privdata and the value associated
* to the node as arguments) are returned. In such case, if max_candidates
* is not NULL, it represents the maximum number of nodes to explore, since
* the search may be otherwise unbound if few or no elements pass the
* filter. */
int hnsw_search_with_filter
(HNSW *index, const float *query_vector, uint32_t k,
hnswNode **neighbors, float *distances, uint32_t slot,
int query_vector_is_normalized)
int query_vector_is_normalized,
int (*filter_callback)(void *value, void *privdata),
void *filter_privdata, uint32_t max_candidates)
{
if (!index || !query_vector || !neighbors || k == 0) return -1;
if (!index->enter_point) return 0; // Empty index.
@ -811,7 +822,9 @@ int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
}
/* Search bottom layer (the most densely populated) with ef = k */
pqueue *results = search_layer(index, &query, curr_ep, k, 0, slot);
pqueue *results = search_layer_with_filter(
index, &query, curr_ep, k, 0, slot, filter_callback,
filter_privdata, max_candidates);
if (!results) {
hnsw_free_tmp_node(&query, query_vector);
return -1;
@ -831,6 +844,16 @@ int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
return found;
}
/* Wrapper to hnsw_search_with_filter() when no filter is needed. */
int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
hnswNode **neighbors, float *distances, uint32_t slot,
int query_vector_is_normalized)
{
return hnsw_search_with_filter(index,query_vector,k,neighbors,
distances,slot,query_vector_is_normalized,
NULL,NULL,0);
}
/* Rescan a node and update the wortst neighbor index.
* The followinng two functions are variants of this function to be used
* when links are added or removed: they may do less work than a full scan. */

6
hnsw.h
View File

@ -119,6 +119,12 @@ hnswNode *hnsw_insert(HNSW *index, const float *vector, const int8_t *qvector,
int hnsw_search(HNSW *index, const float *query, uint32_t k,
hnswNode **neighbors, float *distances, uint32_t slot,
int query_vector_is_normalized);
int hnsw_search_with_filter
(HNSW *index, const float *query_vector, uint32_t k,
hnswNode **neighbors, float *distances, uint32_t slot,
int query_vector_is_normalized,
int (*filter_callback)(void *value, void *privdata),
void *filter_privdata, uint32_t max_candidates);
void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec);
void hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value));

65
vset.c
View File

@ -20,6 +20,10 @@
#include <pthread.h>
#include "hnsw.h"
// We inline directly the expression implementation here so that building
// the module is trivial.
#include "expr.c"
static RedisModuleType *VectorSetType;
static uint64_t VectorSetTypeNextId = 0;
@ -561,6 +565,22 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
}
}
/* HNSW callback to filter items according to a predicate function
* (our FILTER expression in this case). */
int vectorSetFilterCallback(void *value, void *privdata) {
exprstate *expr = privdata;
struct vsetNodeVal *nv = value;
if (nv->attrib == NULL) return 0; // No attributes? No match.
size_t json_len;
char *json = (char*)RedisModule_StringPtrLen(nv->attrib,&json_len);
#if 0
int res = exprRun(expr,json,json_len);
printf("%s %d\n", json, res);
return res;
#endif
return exprRun(expr,json,json_len);
}
/* Common path for the execution of the VSIM command both threaded and
* not threaded. Note that 'ctx' may be normal context of a thread safe
* context obtained from a blocked client. The locking that is specific
@ -568,7 +588,7 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
* handles the HNSW locking explicitly. */
void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
float *vec, unsigned long count, float epsilon, unsigned long withscores,
unsigned long ef)
unsigned long ef, exprstate *filter_expr, unsigned long filter_ef)
{
/* In our scan, we can't just collect 'count' elements as
* if count is small we would explore the graph in an insufficient
@ -585,7 +605,12 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
hnswNode **neighbors = RedisModule_Alloc(sizeof(hnswNode*)*ef);
float *distances = RedisModule_Alloc(sizeof(float)*ef);
int slot = hnsw_acquire_read_slot(vset->hnsw);
unsigned int found = hnsw_search(vset->hnsw, vec, ef, neighbors, distances, slot, 0);
unsigned int found;
if (filter_expr == NULL) {
found = hnsw_search(vset->hnsw, vec, ef, neighbors, distances, slot, 0);
} else {
found = hnsw_search_with_filter(vset->hnsw, vec, ef, neighbors, distances, slot, 0, vectorSetFilterCallback, filter_expr, filter_ef);
}
hnsw_release_read_slot(vset->hnsw,slot);
RedisModule_Free(vec);
@ -598,7 +623,8 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
for (unsigned int i = 0; i < found && i < count; i++) {
if (distances[i] > epsilon) break;
RedisModule_ReplyWithString(ctx, neighbors[i]->value);
struct vsetNodeVal *nv = neighbors[i]->value;
RedisModule_ReplyWithString(ctx, nv->item);
arraylen++;
if (withscores) {
/* The similarity score is provided in a 0-1 range. */
@ -613,6 +639,7 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
RedisModule_Free(neighbors);
RedisModule_Free(distances);
if (filter_expr) exprFree(filter_expr);
}
/* VSIM thread handling the blocked client request. */
@ -628,6 +655,8 @@ void *VSIM_thread(void *arg) {
float epsilon = *((float*)targ[4]);
unsigned long withscores = (unsigned long)targ[5];
unsigned long ef = (unsigned long)targ[6];
exprstate *filter_expr = targ[7];
unsigned long filter_ef = (unsigned long)targ[8];
RedisModule_Free(targ[4]);
RedisModule_Free(targ);
@ -635,7 +664,7 @@ void *VSIM_thread(void *arg) {
RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(bc);
// Run the query.
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef);
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef);
// Cleanup.
RedisModule_FreeThreadSafeContext(ctx);
@ -644,7 +673,7 @@ void *VSIM_thread(void *arg) {
return NULL;
}
/* VSIM key [ELE|FP32|VALUES] <vector or ele> [WITHSCORES] [COUNT num] [EPSILON eps] [EF exploration-factor] */
/* VSIM key [ELE|FP32|VALUES] <vector or ele> [WITHSCORES] [COUNT num] [EPSILON eps] [EF exploration-factor] [FILTER expression] [FILTER-EF exploration-factor] */
int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
RedisModule_AutoMemory(ctx);
@ -658,6 +687,10 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
long long ef = 0; /* Exploration factor (see HNSW paper) */
double epsilon = 2.0; /* Max cosine distance */
/* Things computed later. */
long long filter_ef = 0;
exprstate *filter_expr = NULL;
/* Get key and vector type */
RedisModuleString *key = argv[1];
const char *vectorType = RedisModule_StringPtrLen(argv[2], NULL);
@ -766,6 +799,19 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
return RedisModule_ReplyWithError(ctx, "ERR invalid EF");
}
j += 2;
} else if (!strcasecmp(opt, "FILTER") && j+1 < argc) {
RedisModuleString *exprarg = argv[j+1];
size_t exprlen;
char *exprstr = (char*)RedisModule_StringPtrLen(exprarg,&exprlen);
int errpos;
filter_expr = exprCompile(exprstr,&errpos);
if (filter_expr == NULL) {
if ((size_t)errpos >= exprlen) errpos = 0;
return RedisModule_ReplyWithErrorFormat(ctx,
"ERR syntax error in FILTER expression near: %s",
exprstr+errpos);
}
j += 2;
} else {
RedisModule_Free(vec);
return RedisModule_ReplyWithError(ctx,
@ -774,6 +820,7 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
}
int threaded_request = 1; // Run on a thread, by default.
if (filter_ef == 0) filter_ef = count * 100; // Max filter visited nodes.
// Disable threaded for MULTI/EXEC and Lua.
if (RedisModule_GetContextFlags(ctx) &
@ -799,7 +846,7 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,NULL,NULL,NULL,0);
pthread_t tid;
void **targ = RedisModule_Alloc(sizeof(void*)*7);
void **targ = RedisModule_Alloc(sizeof(void*)*9);
targ[0] = bc;
targ[1] = vset;
targ[2] = vec;
@ -808,16 +855,18 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
*((float*)targ[4]) = epsilon;
targ[5] = (void*)(unsigned long)withscores;
targ[6] = (void*)(unsigned long)ef;
targ[7] = (void*)filter_expr;
targ[8] = (void*)(unsigned long)filter_ef;
if (pthread_create(&tid,NULL,VSIM_thread,targ) != 0) {
pthread_rwlock_unlock(&vset->in_use_lock);
RedisModule_AbortBlock(bc);
RedisModule_Free(vec);
RedisModule_Free(targ[4]);
RedisModule_Free(targ);
return RedisModule_ReplyWithError(ctx,"-ERR Can't start thread");
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef);
}
} else {
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef);
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef);
}
return REDISMODULE_OK;