mirror of https://mirror.osredm.com/root/redis.git
VRANDMEMBER command implemented.
This commit is contained in:
parent
706721f8c8
commit
22ce9f3fad
144
vset.c
144
vset.c
|
@ -1281,6 +1281,146 @@ int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
|
||||||
return REDISMODULE_OK;
|
return REDISMODULE_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* VRANDMEMBER key [count]
|
||||||
|
* Return random members from a vector set.
|
||||||
|
*
|
||||||
|
* Without count: returns a single random member.
|
||||||
|
* With positive count: N unique random members (no duplicates).
|
||||||
|
* With negative count: N random members (with possible duplicates).
|
||||||
|
*
|
||||||
|
* If the key doesn't exist, returns NULL if count is not given, or
|
||||||
|
* an empty array if a count was given. */
|
||||||
|
int VRANDMEMBER_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
|
||||||
|
RedisModule_AutoMemory(ctx); /* Use automatic memory management. */
|
||||||
|
|
||||||
|
/* Check arguments. */
|
||||||
|
if (argc != 2 && argc != 3) return RedisModule_WrongArity(ctx);
|
||||||
|
|
||||||
|
/* Parse optional count argument. */
|
||||||
|
long long count = 1; /* Default is to return a single element. */
|
||||||
|
int with_count = (argc == 3);
|
||||||
|
|
||||||
|
if (with_count) {
|
||||||
|
if (RedisModule_StringToLongLong(argv[2], &count) != REDISMODULE_OK) {
|
||||||
|
return RedisModule_ReplyWithError(ctx,
|
||||||
|
"ERR COUNT value is not an integer");
|
||||||
|
}
|
||||||
|
/* Count = 0 is a special case, return empty array */
|
||||||
|
if (count == 0) {
|
||||||
|
return RedisModule_ReplyWithEmptyArray(ctx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Open key. */
|
||||||
|
RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ);
|
||||||
|
int type = RedisModule_KeyType(key);
|
||||||
|
|
||||||
|
/* Handle non-existing key. */
|
||||||
|
if (type == REDISMODULE_KEYTYPE_EMPTY) {
|
||||||
|
if (!with_count) {
|
||||||
|
return RedisModule_ReplyWithNull(ctx);
|
||||||
|
} else {
|
||||||
|
return RedisModule_ReplyWithEmptyArray(ctx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Check key type. */
|
||||||
|
if (RedisModule_ModuleTypeGetType(key) != VectorSetType) {
|
||||||
|
return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Get vector set from key. */
|
||||||
|
struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
|
||||||
|
uint64_t set_size = vset->hnsw->node_count;
|
||||||
|
|
||||||
|
/* No elements in the set? */
|
||||||
|
if (set_size == 0) {
|
||||||
|
if (!with_count) {
|
||||||
|
return RedisModule_ReplyWithNull(ctx);
|
||||||
|
} else {
|
||||||
|
return RedisModule_ReplyWithEmptyArray(ctx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Case 1: No count specified: return a single element. */
|
||||||
|
if (!with_count) {
|
||||||
|
hnswNode *random_node = hnsw_random_node(vset->hnsw, 0);
|
||||||
|
if (random_node) {
|
||||||
|
struct vsetNodeVal *nv = random_node->value;
|
||||||
|
return RedisModule_ReplyWithString(ctx, nv->item);
|
||||||
|
} else {
|
||||||
|
return RedisModule_ReplyWithNull(ctx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Case 2: COUNT option given, return an array of elements. */
|
||||||
|
int allow_duplicates = (count < 0);
|
||||||
|
long long abs_count = (count < 0) ? -count : count;
|
||||||
|
|
||||||
|
/* Cap the count to the set size if we are not allowing duplicates. */
|
||||||
|
if (!allow_duplicates && abs_count > (long long)set_size)
|
||||||
|
abs_count = set_size;
|
||||||
|
|
||||||
|
/* Prepare reply. */
|
||||||
|
RedisModule_ReplyWithArray(ctx, abs_count);
|
||||||
|
|
||||||
|
if (allow_duplicates) {
|
||||||
|
/* Simple case: With duplicates, just pick random nodes
|
||||||
|
* abs_count times. */
|
||||||
|
for (long long i = 0; i < abs_count; i++) {
|
||||||
|
hnswNode *random_node = hnsw_random_node(vset->hnsw,0);
|
||||||
|
struct vsetNodeVal *nv = random_node->value;
|
||||||
|
RedisModule_ReplyWithString(ctx, nv->item);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
/* Case where count is positive: we need unique elements.
|
||||||
|
* But, if the user asked for many elements, selecting so
|
||||||
|
* many (> 20%) random nodes may be too expansive: we just start
|
||||||
|
* from a random element and follow the next link.
|
||||||
|
*
|
||||||
|
* Otherwisem for the <= 20% case, a dictionary is used to
|
||||||
|
* reject duplicates. */
|
||||||
|
int use_dict = (abs_count <= set_size * 0.2);
|
||||||
|
|
||||||
|
if (use_dict) {
|
||||||
|
RedisModuleDict *returned = RedisModule_CreateDict(ctx);
|
||||||
|
|
||||||
|
long long returned_count = 0;
|
||||||
|
while (returned_count < abs_count) {
|
||||||
|
hnswNode *random_node = hnsw_random_node(vset->hnsw, 0);
|
||||||
|
struct vsetNodeVal *nv = random_node->value;
|
||||||
|
|
||||||
|
/* Check if we've already returned this element. */
|
||||||
|
if (RedisModule_DictGet(returned, nv->item, NULL) == NULL) {
|
||||||
|
/* Mark as returned and add to results. */
|
||||||
|
RedisModule_DictSet(returned, nv->item, (void*)1);
|
||||||
|
RedisModule_ReplyWithString(ctx, nv->item);
|
||||||
|
returned_count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RedisModule_FreeDict(ctx, returned);
|
||||||
|
} else {
|
||||||
|
/* For large samples, get a random starting node and walk
|
||||||
|
* the list. */
|
||||||
|
hnswNode *start_node = hnsw_random_node(vset->hnsw, 0);
|
||||||
|
hnswNode *current = start_node;
|
||||||
|
|
||||||
|
long long returned_count = 0;
|
||||||
|
while (returned_count < abs_count) {
|
||||||
|
if (current == NULL) {
|
||||||
|
/* Restart from head if we hit the end. */
|
||||||
|
current = vset->hnsw->head;
|
||||||
|
}
|
||||||
|
struct vsetNodeVal *nv = current->value;
|
||||||
|
RedisModule_ReplyWithString(ctx, nv->item);
|
||||||
|
returned_count++;
|
||||||
|
current = current->next;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return REDISMODULE_OK;
|
||||||
|
}
|
||||||
|
|
||||||
/* ============================== vset type methods ========================= */
|
/* ============================== vset type methods ========================= */
|
||||||
|
|
||||||
#define SAVE_FLAG_HAS_PROJMATRIX (1<<0)
|
#define SAVE_FLAG_HAS_PROJMATRIX (1<<0)
|
||||||
|
@ -1602,6 +1742,10 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
|
||||||
VGETATTR_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
|
VGETATTR_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR)
|
||||||
return REDISMODULE_ERR;
|
return REDISMODULE_ERR;
|
||||||
|
|
||||||
|
if (RedisModule_CreateCommand(ctx, "VRANDMEMBER",
|
||||||
|
VRANDMEMBER_RedisCommand, "readonly", 1, 1, 1) == REDISMODULE_ERR)
|
||||||
|
return REDISMODULE_ERR;
|
||||||
|
|
||||||
hnsw_set_allocator(RedisModule_Free, RedisModule_Alloc,
|
hnsw_set_allocator(RedisModule_Free, RedisModule_Alloc,
|
||||||
RedisModule_Realloc);
|
RedisModule_Realloc);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue