VRANDMEMBER command implemented.

This commit is contained in:
antirez 2025-03-17 23:52:15 +01:00
parent 706721f8c8
commit 22ce9f3fad
1 changed file with 144 additions and 0 deletions

144
vset.c
View File

@ -1281,6 +1281,146 @@ int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
return REDISMODULE_OK; return REDISMODULE_OK;
} }
/* VRANDMEMBER key [count]
* Return random members from a vector set.
*
* Without count: returns a single random member.
* With positive count: N unique random members (no duplicates).
* With negative count: N random members (with possible duplicates).
*
* If the key doesn't exist, returns NULL if count is not given, or
* an empty array if a count was given. */
int VRANDMEMBER_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
RedisModule_AutoMemory(ctx); /* Use automatic memory management. */
/* Check arguments. */
if (argc != 2 && argc != 3) return RedisModule_WrongArity(ctx);
/* Parse optional count argument. */
long long count = 1; /* Default is to return a single element. */
int with_count = (argc == 3);
if (with_count) {
if (RedisModule_StringToLongLong(argv[2], &count) != REDISMODULE_OK) {
return RedisModule_ReplyWithError(ctx,
"ERR COUNT value is not an integer");
}
/* Count = 0 is a special case, return empty array */
if (count == 0) {
return RedisModule_ReplyWithEmptyArray(ctx);
}
}
/* Open key. */
RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
int type = RedisModule_KeyType(key);
/* Handle non-existing key. */
if (type == REDISMODULE_KEYTYPE_EMPTY) {
if (!with_count) {
return RedisModule_ReplyWithNull(ctx);
} else {
return RedisModule_ReplyWithEmptyArray(ctx);
}
}
/* Check key type. */
if (RedisModule_ModuleTypeGetType(key) != VectorSetType) {
return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
}
/* Get vector set from key. */
struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
uint64_t set_size = vset->hnsw->node_count;
/* No elements in the set? */
if (set_size == 0) {
if (!with_count) {
return RedisModule_ReplyWithNull(ctx);
} else {
return RedisModule_ReplyWithEmptyArray(ctx);
}
}
/* Case 1: No count specified: return a single element. */
if (!with_count) {
hnswNode *random_node = hnsw_random_node(vset->hnsw, 0);
if (random_node) {
struct vsetNodeVal *nv = random_node->value;
return RedisModule_ReplyWithString(ctx, nv->item);
} else {
return RedisModule_ReplyWithNull(ctx);
}
}
/* Case 2: COUNT option given, return an array of elements. */
int allow_duplicates = (count < 0);
long long abs_count = (count < 0) ? -count : count;
/* Cap the count to the set size if we are not allowing duplicates. */
if (!allow_duplicates && abs_count > (long long)set_size)
abs_count = set_size;
/* Prepare reply. */
RedisModule_ReplyWithArray(ctx, abs_count);
if (allow_duplicates) {
/* Simple case: With duplicates, just pick random nodes
* abs_count times. */
for (long long i = 0; i < abs_count; i++) {
hnswNode *random_node = hnsw_random_node(vset->hnsw,0);
struct vsetNodeVal *nv = random_node->value;
RedisModule_ReplyWithString(ctx, nv->item);
}
} else {
/* Case where count is positive: we need unique elements.
* But, if the user asked for many elements, selecting so
* many (> 20%) random nodes may be too expansive: we just start
* from a random element and follow the next link.
*
* Otherwisem for the <= 20% case, a dictionary is used to
* reject duplicates. */
int use_dict = (abs_count <= set_size * 0.2);
if (use_dict) {
RedisModuleDict *returned = RedisModule_CreateDict(ctx);
long long returned_count = 0;
while (returned_count < abs_count) {
hnswNode *random_node = hnsw_random_node(vset->hnsw, 0);
struct vsetNodeVal *nv = random_node->value;
/* Check if we've already returned this element. */
if (RedisModule_DictGet(returned, nv->item, NULL) == NULL) {
/* Mark as returned and add to results. */
RedisModule_DictSet(returned, nv->item, (void*)1);
RedisModule_ReplyWithString(ctx, nv->item);
returned_count++;
}
}
RedisModule_FreeDict(ctx, returned);
} else {
/* For large samples, get a random starting node and walk
* the list. */
hnswNode *start_node = hnsw_random_node(vset->hnsw, 0);
hnswNode *current = start_node;
long long returned_count = 0;
while (returned_count < abs_count) {
if (current == NULL) {
/* Restart from head if we hit the end. */
current = vset->hnsw->head;
}
struct vsetNodeVal *nv = current->value;
RedisModule_ReplyWithString(ctx, nv->item);
returned_count++;
current = current->next;
}
}
}
return REDISMODULE_OK;
}
/* ============================== vset type methods ========================= */ /* ============================== vset type methods ========================= */
#define SAVE_FLAG_HAS_PROJMATRIX (1<<0) #define SAVE_FLAG_HAS_PROJMATRIX (1<<0)
@ -1602,6 +1742,10 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
VGETATTR_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) VGETATTR_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
return REDISMODULE_ERR; return REDISMODULE_ERR;
if (RedisModule_CreateCommand(ctx, "VRANDMEMBER",
VRANDMEMBER_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR)
return REDISMODULE_ERR;
hnsw_set_allocator(RedisModule_Free, RedisModule_Alloc, hnsw_set_allocator(RedisModule_Free, RedisModule_Alloc,
RedisModule_Realloc); RedisModule_Realloc);