From 22ce9f3fad6c54bb24b8e7b948503a371237eb36 Mon Sep 17 00:00:00 2001
From: antirez
Date: Mon, 17 Mar 2025 23:52:15 +0100
Subject: [PATCH] VRANDMEMBER command implemented.

---
 vset.c | 156 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 156 insertions(+)

diff --git a/vset.c b/vset.c
index 4a0e8f4de..3be8e2f9e 100644
--- a/vset.c
+++ b/vset.c
@@ -1281,6 +1281,158 @@ int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
     return REDISMODULE_OK;
 }
 
+/* VRANDMEMBER key [count]
+ * Return random members from a vector set.
+ *
+ * Without count: returns a single random member.
+ * With positive count: N unique random members (no duplicates), capped
+ * to the number of elements in the set.
+ * With negative count: N random members (with possible duplicates).
+ *
+ * If the key doesn't exist, returns NULL if count is not given, or
+ * an empty array if a count was given. A count of zero always returns
+ * an empty array, without even looking up the key.
+ *
+ * Errors: wrong arity, a count that is not an integer or that can't be
+ * negated (the minimum long long value), a key of a different type. */
+int VRANDMEMBER_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
+    RedisModule_AutoMemory(ctx); /* Use automatic memory management. */
+
+    /* Check arguments. */
+    if (argc != 2 && argc != 3) return RedisModule_WrongArity(ctx);
+
+    /* Parse optional count argument. */
+    long long count = 1; /* Default is to return a single element. */
+    int with_count = (argc == 3);
+
+    if (with_count) {
+        if (RedisModule_StringToLongLong(argv[2], &count) != REDISMODULE_OK) {
+            return RedisModule_ReplyWithError(ctx,
+                "ERR COUNT value is not an integer");
+        }
+        /* Count = 0 is a special case, return empty array */
+        if (count == 0) {
+            return RedisModule_ReplyWithEmptyArray(ctx);
+        }
+        /* The most negative value can't be negated without a signed
+         * overflow (undefined behavior): refuse it here, before the
+         * absolute count is computed. */
+        if (count < -INT64_MAX) {
+            return RedisModule_ReplyWithError(ctx,
+                "ERR COUNT value is out of range");
+        }
+    }
+
+    /* Open key. */
+    RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
+    int type = RedisModule_KeyType(key);
+
+    /* Handle non-existing key. */
+    if (type == REDISMODULE_KEYTYPE_EMPTY) {
+        if (!with_count) {
+            return RedisModule_ReplyWithNull(ctx);
+        } else {
+            return RedisModule_ReplyWithEmptyArray(ctx);
+        }
+    }
+
+    /* Check key type. */
+    if (RedisModule_ModuleTypeGetType(key) != VectorSetType) {
+        return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
+    }
+
+    /* Get vector set from key. */
+    struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
+    uint64_t set_size = vset->hnsw->node_count;
+
+    /* No elements in the set? */
+    if (set_size == 0) {
+        if (!with_count) {
+            return RedisModule_ReplyWithNull(ctx);
+        } else {
+            return RedisModule_ReplyWithEmptyArray(ctx);
+        }
+    }
+
+    /* Case 1: No count specified: return a single element. */
+    if (!with_count) {
+        hnswNode *random_node = hnsw_random_node(vset->hnsw, 0);
+        if (random_node) {
+            struct vsetNodeVal *nv = random_node->value;
+            return RedisModule_ReplyWithString(ctx, nv->item);
+        } else {
+            return RedisModule_ReplyWithNull(ctx);
+        }
+    }
+
+    /* Case 2: COUNT option given, return an array of elements. */
+    int allow_duplicates = (count < 0);
+    long long abs_count = (count < 0) ? -count : count;
+
+    /* Cap the count to the set size if we are not allowing duplicates. */
+    if (!allow_duplicates && abs_count > (long long)set_size)
+        abs_count = set_size;
+
+    /* Prepare reply. */
+    RedisModule_ReplyWithArray(ctx, abs_count);
+
+    if (allow_duplicates) {
+        /* Simple case: With duplicates, just pick random nodes
+         * abs_count times. */
+        for (long long i = 0; i < abs_count; i++) {
+            hnswNode *random_node = hnsw_random_node(vset->hnsw,0);
+            struct vsetNodeVal *nv = random_node->value;
+            RedisModule_ReplyWithString(ctx, nv->item);
+        }
+    } else {
+        /* Case where count is positive: we need unique elements.
+         * But, if the user asked for many elements, selecting so
+         * many (> 20%) random nodes may be too expensive: we just start
+         * from a random element and follow the next link.
+         *
+         * Otherwise, for the <= 20% case, a dictionary is used to
+         * reject duplicates. */
+        int use_dict = (abs_count <= set_size * 0.2);
+
+        if (use_dict) {
+            RedisModuleDict *returned = RedisModule_CreateDict(ctx);
+
+            long long returned_count = 0;
+            while (returned_count < abs_count) {
+                hnswNode *random_node = hnsw_random_node(vset->hnsw, 0);
+                struct vsetNodeVal *nv = random_node->value;
+
+                /* Check if we've already returned this element. */
+                if (RedisModule_DictGet(returned, nv->item, NULL) == NULL) {
+                    /* Mark as returned and add to results. */
+                    RedisModule_DictSet(returned, nv->item, (void*)1);
+                    RedisModule_ReplyWithString(ctx, nv->item);
+                    returned_count++;
+                }
+            }
+            RedisModule_FreeDict(ctx, returned);
+        } else {
+            /* For large samples, get a random starting node and walk
+             * the list. */
+            hnswNode *start_node = hnsw_random_node(vset->hnsw, 0);
+            hnswNode *current = start_node;
+
+            long long returned_count = 0;
+            while (returned_count < abs_count) {
+                if (current == NULL) {
+                    /* Restart from head if we hit the end. */
+                    current = vset->hnsw->head;
+                }
+                struct vsetNodeVal *nv = current->value;
+                RedisModule_ReplyWithString(ctx, nv->item);
+                returned_count++;
+                current = current->next;
+            }
+        }
+    }
+    return REDISMODULE_OK;
+}
+
 /* ============================== vset type methods ========================= */
 
 #define SAVE_FLAG_HAS_PROJMATRIX (1<<0)
@@ -1602,6 +1754,10 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
         VGETATTR_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
         return REDISMODULE_ERR;
 
+    if (RedisModule_CreateCommand(ctx, "VRANDMEMBER",
+        VRANDMEMBER_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR)
+        return REDISMODULE_ERR;
+
     hnsw_set_allocator(RedisModule_Free, RedisModule_Alloc,
                        RedisModule_Realloc);
 