mirror of https://mirror.osredm.com/root/redis.git
Expr filtering: VSIM FILTER first draft.
This commit is contained in:
parent
025790fc50
commit
5304318335
2
Makefile
2
Makefile
|
@ -55,7 +55,7 @@ all: vset.so
|
|||
|
||||
vset.xo: redismodule.h
|
||||
|
||||
vset.so: vset.xo hnsw.xo
|
||||
vset.so: vset.xo hnsw.xo cJSON.xo
|
||||
$(CC) -o $@ $^ $(SHOBJ_LDFLAGS) $(LIBS) -lc
|
||||
|
||||
# Example sources / objects
|
||||
|
|
36
expr.c
36
expr.c
|
@ -12,6 +12,12 @@
|
|||
#include <math.h>
|
||||
#include "cJSON.h"
|
||||
|
||||
#ifdef TEST_MAIN
|
||||
#define RedisModule_Alloc malloc
|
||||
#define RedisModule_Realloc realloc
|
||||
#define RedisModule_Free free
|
||||
#endif
|
||||
|
||||
#define EXPR_TOKEN_EOF 0
|
||||
#define EXPR_TOKEN_NUM 1
|
||||
#define EXPR_TOKEN_STR 2
|
||||
|
@ -121,8 +127,8 @@ struct {
|
|||
/* ================================ Expr token ============================== */
|
||||
void exprFreeToken(exprtoken *t) {
|
||||
if (t == NULL) return;
|
||||
if (t->heapstr != NULL) free(t->heapstr);
|
||||
free(t);
|
||||
if (t->heapstr != NULL) RedisModule_Free(t->heapstr);
|
||||
RedisModule_Free(t);
|
||||
}
|
||||
|
||||
/* ============================== Stack handling ============================ */
|
||||
|
@ -134,7 +140,7 @@ void exprFreeToken(exprtoken *t) {
|
|||
|
||||
/* Initialize a new expression stack. */
|
||||
void exprStackInit(exprstack *stack) {
|
||||
stack->items = malloc(sizeof(exprtoken*) * EXPR_STACK_INITIAL_SIZE);
|
||||
stack->items = RedisModule_Alloc(sizeof(exprtoken*) * EXPR_STACK_INITIAL_SIZE);
|
||||
stack->numitems = 0;
|
||||
stack->allocsize = EXPR_STACK_INITIAL_SIZE;
|
||||
}
|
||||
|
@ -146,7 +152,7 @@ int exprStackPush(exprstack *stack, exprtoken *token) {
|
|||
if (stack->numitems == stack->allocsize) {
|
||||
size_t newsize = stack->allocsize * 2;
|
||||
exprtoken **newitems =
|
||||
realloc(stack->items, sizeof(exprtoken*) * newsize);
|
||||
RedisModule_Realloc(stack->items, sizeof(exprtoken*) * newsize);
|
||||
if (newitems == NULL) return 0;
|
||||
stack->items = newitems;
|
||||
stack->allocsize = newsize;
|
||||
|
@ -175,7 +181,7 @@ exprtoken *exprStackPeek(exprstack *stack) {
|
|||
void exprStackFree(exprstack *stack) {
|
||||
for (int j = 0; j < stack->numitems; j++)
|
||||
exprFreeToken(stack->items[j]);
|
||||
free(stack->items);
|
||||
RedisModule_Free(stack->items);
|
||||
}
|
||||
|
||||
/* Just reset the stack removing all the items, but leaving it in a state
|
||||
|
@ -284,7 +290,7 @@ void exprFree(exprstate *es) {
|
|||
if (es == NULL) return;
|
||||
|
||||
/* Free the original expression string. */
|
||||
if (es->expr) free(es->expr);
|
||||
if (es->expr) RedisModule_Free(es->expr);
|
||||
|
||||
/* Free all stacks. */
|
||||
exprStackFree(&es->values_stack);
|
||||
|
@ -293,7 +299,7 @@ void exprFree(exprstate *es) {
|
|||
exprStackFree(&es->program);
|
||||
|
||||
/* Free the state object itself. */
|
||||
free(es);
|
||||
RedisModule_Free(es);
|
||||
}
|
||||
|
||||
/* Split the provided expression into a stack of tokens. Returns
|
||||
|
@ -344,7 +350,7 @@ int exprTokenize(exprstate *es, int *errpos) {
|
|||
}
|
||||
|
||||
/* Allocate and copy current token to tokens stack */
|
||||
exprtoken *token = malloc(sizeof(exprtoken));
|
||||
exprtoken *token = RedisModule_Alloc(sizeof(exprtoken));
|
||||
if (!token) return 1; // OOM.
|
||||
|
||||
*token = es->current; /* Copy the entire structure. */
|
||||
|
@ -445,12 +451,12 @@ int exprProcessOperator(exprstate *es, exprtoken *op, int *stack_items, int *err
|
|||
* expression is returned by reference. */
|
||||
exprstate *exprCompile(char *expr, int *errpos) {
|
||||
/* Initialize expression state. */
|
||||
exprstate *es = malloc(sizeof(exprstate));
|
||||
exprstate *es = RedisModule_Alloc(sizeof(exprstate));
|
||||
if (!es) return NULL;
|
||||
|
||||
es->expr = strdup(expr);
|
||||
if (!es->expr) {
|
||||
free(es);
|
||||
RedisModule_Free(es);
|
||||
return NULL;
|
||||
}
|
||||
es->p = es->expr;
|
||||
|
@ -484,7 +490,7 @@ exprstate *exprCompile(char *expr, int *errpos) {
|
|||
token->token_type == EXPR_TOKEN_STR ||
|
||||
token->token_type == EXPR_TOKEN_SELECTOR)
|
||||
{
|
||||
exprtoken *value_token = malloc(sizeof(exprtoken));
|
||||
exprtoken *value_token = RedisModule_Alloc(sizeof(exprtoken));
|
||||
if (!value_token) {
|
||||
if (errpos) *errpos = token->offset;
|
||||
exprFree(es);
|
||||
|
@ -503,7 +509,7 @@ exprstate *exprCompile(char *expr, int *errpos) {
|
|||
|
||||
/* Handle operators. */
|
||||
if (token->token_type == EXPR_TOKEN_OP) {
|
||||
exprtoken *op_token = malloc(sizeof(exprtoken));
|
||||
exprtoken *op_token = RedisModule_Alloc(sizeof(exprtoken));
|
||||
if (!op_token) {
|
||||
if (errpos) *errpos = token->offset;
|
||||
exprFree(es);
|
||||
|
@ -620,7 +626,7 @@ int exprRun(exprstate *es, char *json, size_t json_len) {
|
|||
|
||||
// Handle selectors by calling the callback.
|
||||
if (t->token_type == EXPR_TOKEN_SELECTOR) {
|
||||
exprtoken *result = malloc(sizeof(exprtoken));
|
||||
exprtoken *result = RedisModule_Alloc(sizeof(exprtoken));
|
||||
if (result != NULL && json != NULL) {
|
||||
cJSON *attrib = NULL;
|
||||
if (parsed_json == NULL)
|
||||
|
@ -665,14 +671,14 @@ int exprRun(exprstate *es, char *json, size_t json_len) {
|
|||
|
||||
// Push non-operator values directly onto the stack.
|
||||
if (t->token_type != EXPR_TOKEN_OP) {
|
||||
exprtoken *nt = malloc(sizeof(exprtoken));
|
||||
exprtoken *nt = RedisModule_Alloc(sizeof(exprtoken));
|
||||
*nt = *t;
|
||||
exprStackPush(&es->values_stack, nt);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle operators.
|
||||
exprtoken *result = malloc(sizeof(exprtoken));
|
||||
exprtoken *result = RedisModule_Alloc(sizeof(exprtoken));
|
||||
result->token_type = EXPR_TOKEN_NUM;
|
||||
|
||||
// Pop operands - we know we have enough from compile-time checks.
|
||||
|
|
31
hnsw.c
31
hnsw.c
|
@ -780,10 +780,21 @@ void hnsw_free_tmp_node(hnswNode *node, const float *vector) {
|
|||
* arrays must have space for at least 'k' items.
|
||||
* norm_query should be set to 1 if the query vector is already
|
||||
* normalized, otherwise, if 0, the function will copy the vector,
|
||||
* L2-normalize the copy and search using the normalized version. */
|
||||
int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
|
||||
* L2-normalize the copy and search using the normalized version.
|
||||
*
|
||||
* If the filter_callback callback is passed, only elements passing the
|
||||
* specified filter (invoked with privdata and the value associated
|
||||
* to the node as arguments) are returned. In such case, if max_candidates
|
||||
* is not zero, it represents the maximum number of nodes to explore, since
|
||||
* the search may be otherwise unbound if few or no elements pass the
|
||||
* filter. */
|
||||
int hnsw_search_with_filter
|
||||
(HNSW *index, const float *query_vector, uint32_t k,
|
||||
hnswNode **neighbors, float *distances, uint32_t slot,
|
||||
int query_vector_is_normalized)
|
||||
int query_vector_is_normalized,
|
||||
int (*filter_callback)(void *value, void *privdata),
|
||||
void *filter_privdata, uint32_t max_candidates)
|
||||
|
||||
{
|
||||
if (!index || !query_vector || !neighbors || k == 0) return -1;
|
||||
if (!index->enter_point) return 0; // Empty index.
|
||||
|
@ -811,7 +822,9 @@ int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
|
|||
}
|
||||
|
||||
/* Search bottom layer (the most densely populated) with ef = k */
|
||||
pqueue *results = search_layer(index, &query, curr_ep, k, 0, slot);
|
||||
pqueue *results = search_layer_with_filter(
|
||||
index, &query, curr_ep, k, 0, slot, filter_callback,
|
||||
filter_privdata, max_candidates);
|
||||
if (!results) {
|
||||
hnsw_free_tmp_node(&query, query_vector);
|
||||
return -1;
|
||||
|
@ -831,6 +844,16 @@ int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
|
|||
return found;
|
||||
}
|
||||
|
||||
/* Wrapper to hnsw_search_with_filter() when no filter is needed. */
|
||||
int hnsw_search(HNSW *index, const float *query_vector, uint32_t k,
|
||||
hnswNode **neighbors, float *distances, uint32_t slot,
|
||||
int query_vector_is_normalized)
|
||||
{
|
||||
return hnsw_search_with_filter(index,query_vector,k,neighbors,
|
||||
distances,slot,query_vector_is_normalized,
|
||||
NULL,NULL,0);
|
||||
}
|
||||
|
||||
/* Rescan a node and update the worst neighbor index.
|
||||
* The following two functions are variants of this function to be used
|
||||
* when links are added or removed: they may do less work than a full scan. */
|
||||
|
|
6
hnsw.h
6
hnsw.h
|
@ -119,6 +119,12 @@ hnswNode *hnsw_insert(HNSW *index, const float *vector, const int8_t *qvector,
|
|||
int hnsw_search(HNSW *index, const float *query, uint32_t k,
|
||||
hnswNode **neighbors, float *distances, uint32_t slot,
|
||||
int query_vector_is_normalized);
|
||||
int hnsw_search_with_filter
|
||||
(HNSW *index, const float *query_vector, uint32_t k,
|
||||
hnswNode **neighbors, float *distances, uint32_t slot,
|
||||
int query_vector_is_normalized,
|
||||
int (*filter_callback)(void *value, void *privdata),
|
||||
void *filter_privdata, uint32_t max_candidates);
|
||||
void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec);
|
||||
void hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value));
|
||||
|
||||
|
|
65
vset.c
65
vset.c
|
@ -20,6 +20,10 @@
|
|||
#include <pthread.h>
|
||||
#include "hnsw.h"
|
||||
|
||||
// We inline directly the expression implementation here so that building
|
||||
// the module is trivial.
|
||||
#include "expr.c"
|
||||
|
||||
static RedisModuleType *VectorSetType;
|
||||
static uint64_t VectorSetTypeNextId = 0;
|
||||
|
||||
|
@ -561,6 +565,22 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
|
|||
}
|
||||
}
|
||||
|
||||
/* HNSW callback to filter items according to a predicate function
|
||||
* (our FILTER expression in this case). */
|
||||
int vectorSetFilterCallback(void *value, void *privdata) {
|
||||
exprstate *expr = privdata;
|
||||
struct vsetNodeVal *nv = value;
|
||||
if (nv->attrib == NULL) return 0; // No attributes? No match.
|
||||
size_t json_len;
|
||||
char *json = (char*)RedisModule_StringPtrLen(nv->attrib,&json_len);
|
||||
#if 0
|
||||
int res = exprRun(expr,json,json_len);
|
||||
printf("%s %d\n", json, res);
|
||||
return res;
|
||||
#endif
|
||||
return exprRun(expr,json,json_len);
|
||||
}
|
||||
|
||||
/* Common path for the execution of the VSIM command both threaded and
|
||||
* not threaded. Note that 'ctx' may be a normal context or a thread safe
|
||||
* context obtained from a blocked client. The locking that is specific
|
||||
|
@ -568,7 +588,7 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
|
|||
* handles the HNSW locking explicitly. */
|
||||
void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
|
||||
float *vec, unsigned long count, float epsilon, unsigned long withscores,
|
||||
unsigned long ef)
|
||||
unsigned long ef, exprstate *filter_expr, unsigned long filter_ef)
|
||||
{
|
||||
/* In our scan, we can't just collect 'count' elements as
|
||||
* if count is small we would explore the graph in an insufficient
|
||||
|
@ -585,7 +605,12 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
|
|||
hnswNode **neighbors = RedisModule_Alloc(sizeof(hnswNode*)*ef);
|
||||
float *distances = RedisModule_Alloc(sizeof(float)*ef);
|
||||
int slot = hnsw_acquire_read_slot(vset->hnsw);
|
||||
unsigned int found = hnsw_search(vset->hnsw, vec, ef, neighbors, distances, slot, 0);
|
||||
unsigned int found;
|
||||
if (filter_expr == NULL) {
|
||||
found = hnsw_search(vset->hnsw, vec, ef, neighbors, distances, slot, 0);
|
||||
} else {
|
||||
found = hnsw_search_with_filter(vset->hnsw, vec, ef, neighbors, distances, slot, 0, vectorSetFilterCallback, filter_expr, filter_ef);
|
||||
}
|
||||
hnsw_release_read_slot(vset->hnsw,slot);
|
||||
RedisModule_Free(vec);
|
||||
|
||||
|
@ -598,7 +623,8 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
|
|||
|
||||
for (unsigned int i = 0; i < found && i < count; i++) {
|
||||
if (distances[i] > epsilon) break;
|
||||
RedisModule_ReplyWithString(ctx, neighbors[i]->value);
|
||||
struct vsetNodeVal *nv = neighbors[i]->value;
|
||||
RedisModule_ReplyWithString(ctx, nv->item);
|
||||
arraylen++;
|
||||
if (withscores) {
|
||||
/* The similarity score is provided in a 0-1 range. */
|
||||
|
@ -613,6 +639,7 @@ void VSIM_execute(RedisModuleCtx *ctx, struct vsetObject *vset,
|
|||
|
||||
RedisModule_Free(neighbors);
|
||||
RedisModule_Free(distances);
|
||||
if (filter_expr) exprFree(filter_expr);
|
||||
}
|
||||
|
||||
/* VSIM thread handling the blocked client request. */
|
||||
|
@ -628,6 +655,8 @@ void *VSIM_thread(void *arg) {
|
|||
float epsilon = *((float*)targ[4]);
|
||||
unsigned long withscores = (unsigned long)targ[5];
|
||||
unsigned long ef = (unsigned long)targ[6];
|
||||
exprstate *filter_expr = targ[7];
|
||||
unsigned long filter_ef = (unsigned long)targ[8];
|
||||
RedisModule_Free(targ[4]);
|
||||
RedisModule_Free(targ);
|
||||
|
||||
|
@ -635,7 +664,7 @@ void *VSIM_thread(void *arg) {
|
|||
RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(bc);
|
||||
|
||||
// Run the query.
|
||||
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef);
|
||||
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef);
|
||||
|
||||
// Cleanup.
|
||||
RedisModule_FreeThreadSafeContext(ctx);
|
||||
|
@ -644,7 +673,7 @@ void *VSIM_thread(void *arg) {
|
|||
return NULL;
|
||||
}
|
||||
|
||||
/* VSIM key [ELE|FP32|VALUES] <vector or ele> [WITHSCORES] [COUNT num] [EPSILON eps] [EF exploration-factor] */
|
||||
/* VSIM key [ELE|FP32|VALUES] <vector or ele> [WITHSCORES] [COUNT num] [EPSILON eps] [EF exploration-factor] [FILTER expression] [FILTER-EF exploration-factor] */
|
||||
int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
|
||||
RedisModule_AutoMemory(ctx);
|
||||
|
||||
|
@ -658,6 +687,10 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
|
|||
long long ef = 0; /* Exploration factor (see HNSW paper) */
|
||||
double epsilon = 2.0; /* Max cosine distance */
|
||||
|
||||
/* Things computed later. */
|
||||
long long filter_ef = 0;
|
||||
exprstate *filter_expr = NULL;
|
||||
|
||||
/* Get key and vector type */
|
||||
RedisModuleString *key = argv[1];
|
||||
const char *vectorType = RedisModule_StringPtrLen(argv[2], NULL);
|
||||
|
@ -766,6 +799,19 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
|
|||
return RedisModule_ReplyWithError(ctx, "ERR invalid EF");
|
||||
}
|
||||
j += 2;
|
||||
} else if (!strcasecmp(opt, "FILTER") && j+1 < argc) {
|
||||
RedisModuleString *exprarg = argv[j+1];
|
||||
size_t exprlen;
|
||||
char *exprstr = (char*)RedisModule_StringPtrLen(exprarg,&exprlen);
|
||||
int errpos;
|
||||
filter_expr = exprCompile(exprstr,&errpos);
|
||||
if (filter_expr == NULL) {
|
||||
if ((size_t)errpos >= exprlen) errpos = 0;
|
||||
return RedisModule_ReplyWithErrorFormat(ctx,
|
||||
"ERR syntax error in FILTER expression near: %s",
|
||||
exprstr+errpos);
|
||||
}
|
||||
j += 2;
|
||||
} else {
|
||||
RedisModule_Free(vec);
|
||||
return RedisModule_ReplyWithError(ctx,
|
||||
|
@ -774,6 +820,7 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
|
|||
}
|
||||
|
||||
int threaded_request = 1; // Run on a thread, by default.
|
||||
if (filter_ef == 0) filter_ef = count * 100; // Max filter visited nodes.
|
||||
|
||||
// Disable threaded for MULTI/EXEC and Lua.
|
||||
if (RedisModule_GetContextFlags(ctx) &
|
||||
|
@ -799,7 +846,7 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
|
|||
|
||||
RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,NULL,NULL,NULL,0);
|
||||
pthread_t tid;
|
||||
void **targ = RedisModule_Alloc(sizeof(void*)*7);
|
||||
void **targ = RedisModule_Alloc(sizeof(void*)*9);
|
||||
targ[0] = bc;
|
||||
targ[1] = vset;
|
||||
targ[2] = vec;
|
||||
|
@ -808,16 +855,18 @@ int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
|
|||
*((float*)targ[4]) = epsilon;
|
||||
targ[5] = (void*)(unsigned long)withscores;
|
||||
targ[6] = (void*)(unsigned long)ef;
|
||||
targ[7] = (void*)filter_expr;
|
||||
targ[8] = (void*)(unsigned long)filter_ef;
|
||||
if (pthread_create(&tid,NULL,VSIM_thread,targ) != 0) {
|
||||
pthread_rwlock_unlock(&vset->in_use_lock);
|
||||
RedisModule_AbortBlock(bc);
|
||||
RedisModule_Free(vec);
|
||||
RedisModule_Free(targ[4]);
|
||||
RedisModule_Free(targ);
|
||||
return RedisModule_ReplyWithError(ctx,"-ERR Can't start thread");
|
||||
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef);
|
||||
}
|
||||
} else {
|
||||
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef);
|
||||
VSIM_execute(ctx, vset, vec, count, epsilon, withscores, ef, filter_expr, filter_ef);
|
||||
}
|
||||
|
||||
return REDISMODULE_OK;
|
||||
|
|
Loading…
Reference in New Issue