From 33d653e24f03adadf469dc1627e15e76580f68b5 Mon Sep 17 00:00:00 2001 From: antirez Date: Mon, 27 Jan 2025 17:24:02 +0100 Subject: [PATCH] First internal release. --- .gitignore | 10 + LICENSE | 2 + Makefile | 77 + README.md | 175 +++ hnsw.c | 2482 ++++++++++++++++++++++++++++++ hnsw.h | 158 ++ redismodule.h | 1704 ++++++++++++++++++++ test.py | 189 +++ tests/basic_commands.py | 21 + tests/basic_similarity.py | 35 + tests/concurrent_vsim_and_del.py | 48 + tests/deletion.py | 173 +++ tests/evict_empty.py | 27 + tests/large_scale.py | 56 + tests/node_update.py | 85 + tests/persistence.py | 83 + tests/reduce.py | 71 + tests/vadd_cas.py | 98 ++ tests/vemb.py | 41 + vset.c | 1208 +++++++++++++++ w2v.c | 315 ++++ 21 files changed, 7058 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 README.md create mode 100644 hnsw.c create mode 100644 hnsw.h create mode 100644 redismodule.h create mode 100755 test.py create mode 100644 tests/basic_commands.py create mode 100644 tests/basic_similarity.py create mode 100644 tests/concurrent_vsim_and_del.py create mode 100644 tests/deletion.py create mode 100644 tests/evict_empty.py create mode 100644 tests/large_scale.py create mode 100644 tests/node_update.py create mode 100644 tests/persistence.py create mode 100644 tests/reduce.py create mode 100644 tests/vadd_cas.py create mode 100644 tests/vemb.py create mode 100644 vset.c create mode 100644 w2v.c diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..16a0848d7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +__pycache__ +misc +*.so +*.xo +*.o +.DS_Store +w2v +word2vec.bin +TODO +*.txt diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..df7a7a7cd --- /dev/null +++ b/LICENSE @@ -0,0 +1,2 @@ +This code is Copyright (C) 2024-2025 Salvatore Sanfilippo. +All Rights Reserved. diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..5c3a66fcb --- /dev/null +++ b/Makefile @@ -0,0 +1,77 @@ +# Compiler settings +CC = gcc + +ifdef SANITIZER +ifeq ($(SANITIZER),address) + SAN=-fsanitize=address +else +ifeq ($(SANITIZER),undefined) + SAN=-fsanitize=undefined +else +ifeq ($(SANITIZER),thread) + SAN=-fsanitize=thread +else + $(error "unknown sanitizer=${SANITIZER}") +endif +endif +endif +endif + +CFLAGS = -O2 -Wall -Wextra -g -ffast-math $(SAN) +LDFLAGS = -lm $(SAN) + +# Detect OS +uname_S := $(shell sh -c 'uname -s 2>/dev/null || echo not') + +# Shared library compile flags for linux / osx +ifeq ($(uname_S),Linux) + SHOBJ_CFLAGS ?= -W -Wall -fno-common -g -ggdb -std=c99 -O2 + SHOBJ_LDFLAGS ?= -shared +else + SHOBJ_CFLAGS ?= -W -Wall -dynamic -fno-common -g -ggdb -std=c99 -Ofast -ffast-math + SHOBJ_LDFLAGS ?= -bundle -undefined dynamic_lookup +endif + +# OS X 11.x doesn't have /usr/lib/libSystem.dylib and needs an explicit setting. +ifeq ($(uname_S),Darwin) +ifeq ("$(wildcard /usr/lib/libSystem.dylib)","") +LIBS = -L /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lsystem +endif +endif + +.SUFFIXES: .c .so .xo .o + +all: vset.so + +.c.xo: + $(CC) -I. $(CFLAGS) $(SHOBJ_CFLAGS) -fPIC -c $< -o $@ + +vset.xo: redismodule.h + +vset.so: vset.xo hnsw.xo + $(LD) -o $@ $^ $(SHOBJ_LDFLAGS) $(LIBS) -lc + +# Example sources / objects +SRCS = hnsw.c w2v.c +OBJS = $(SRCS:.c=.o) + +TARGET = w2v +MODULE = vset.so + +# Default target +all: $(TARGET) $(MODULE) + +# Example linking rule +$(TARGET): $(OBJS) + $(CC) $(OBJS) $(LDFLAGS) -o $(TARGET) + +# Compilation rule for object files +%.o: %.c + $(CC) $(CFLAGS) -c $< -o $@ + +# Clean rule +clean: + rm -f $(TARGET) $(OBJS) *.xo *.so + +# Declare phony targets +.PHONY: all clean diff --git a/README.md b/README.md new file mode 100644 index 000000000..910023a95 --- /dev/null +++ b/README.md @@ -0,0 +1,175 @@ +This module implements vector sets for Redis, a new Redis data type similar +to sorted sets but having a vector instead of a score. It is possible to +add items and then get them back by similiarity to either a user-provided +vector or a vector of an element already inserted. + +## Installation + + make + +Then load the module with the following command line, or by inserting the needed directives in the `redis.conf` file. + + ./redis-server --loadmodule vset.so + +To run tests, I suggest using this: + + ./redis-server --save "" --enable-debug-command yes + +The execute the tests with: + + ./test.py + +## Commands + +**VADD: add items into a vector set** + + VADD key [REDUCE dim] FP32|VALUES vector element [CAS] [NOQUANT] [BIN] + +Add a new element into the vector set specified by the key. +The vector can be provided as FP32 blob of values, or as floating point +numbers as strings, prefixed by the number of elements (3 in the example): + + VADD mykey VALUES 3 0.1 1.2 0.5 my-element + +The `REDUCE` option implements random projection, in order to reduce the +dimensionality of the vector. The projection matrix is saved and reloaded +along with the vector set. + +The `CAS` option performs the operation partially using threads, in a +check-and-set style. The neighbor candidates collection, which is slow, is +performed in the background, while the command is executed in the main thread. + +The `NOQUANT` option forces the vector to be created (in the first VADD call to a given key) without integer 8 quantization, which is otherwise the default. + +The `BIN` option forces the vector to use binary quantization instead of int8. This is much faster and uses less memory, but has impacts on the recall quality. + +**VSIM: return elements by vector similarity** + + VSIM key [ELE|FP32|VALUES] [WITHSCORES] [COUNT num] [EF exploration-factor] + +The command returns similar vectors, in the example instead of providing a vector using FP32 or VALUES (like in `VADD`), we will ask for elements associated with a vector similar to a given element already in the sorted set: + + > VSIM word_embeddings ELE apple + 1) "apple" + 2) "apples" + 3) "pear" + 4) "fruit" + 5) "berry" + 6) "pears" + 7) "strawberry" + 8) "peach" + 9) "potato" + 10) "grape" + +It is possible to specify a `COUNT` and also to get the similarity score (from 1 to 0, where 1 is identical, 0 is opposite vector) between the query and the returned items. + + > VSIM word_embeddings ELE apple WITHSCORES COUNT 3 + 1) "apple" + 2) "0.9998867657923256" + 3) "apples" + 4) "0.8598527610301971" + 5) "pear" + 6) "0.8226882219314575" + +The `EF` argument is the exploration factor: the higher it is, the slower the command becomes, but the better the index is explored to find nodes that are near to our query. Sensible values are from 50 to 1000. + +**VDIM: return the dimension of the vectors inside the vector set** + + VDIM keyname + +Example: + + > VDIM word_embeddings + (integer) 300 + +Note that in the case of vectors that were populated using the `REDUCE` +option, for random projection, the vector set will report the size of +the projected (reduced) dimension. Yet the user should perform all the +queries using full-size vectors. + +**VCARD: return the number of elements in a vector set** + + VCARD key + +Example: + + > VCARD word_embeddings + (integer) 3000000 + + +**VREM: remove elements from vector set** + + VREM key element + +Example: + + > VADD vset VALUES 3 1 0 1 bar + (integer) 1 + > VREM vset bar + (integer) 1 + > VREM vset bar + (integer) 0 + +VREM does not perform thumstone / logical deletion, but will actually reclaim +the memory from the vector set, so it is save to add and remove elements +in a vector set in the context of long running applications that continuously +update the same index. + +**VEMB: return the approximated vector of an element** + + VEMB key element + +Example: + + > VEMB word_embeddings SQL + 1) "0.18208661675453186" + 2) "0.08535309880971909" + 3) "0.1365649551153183" + 4) "-0.16501599550247192" + 5) "0.14225517213344574" + ... 295 more elements ... + +Because vector sets perform insertion time normalization and optional +quantization, the returned vector could be approximated. `VEMB` will take +care to de-quantized and de-normalize the vector before returning it. + +**VLINKS: introspection command that shows neighbors for a node** + + VLINKS key element [WITHSCORES] + +The command reports the neighbors for each level. + +**VINFO: introspection command that shows info about a vector set** + + VINFO key + +Example: + + > VINFO word_embeddings + 1) quant-type + 2) int8 + 3) vector-dim + 4) (integer) 300 + 5) size + 6) (integer) 3000000 + 7) max-level + 8) (integer) 12 + 9) vset-uid + 10) (integer) 1 + 11) hnsw-max-node-uid + 12) (integer) 3000000 + +## Known bugs + +* When VADD with REDUCE is replicated, we should probably send the replicas the random matrix, in order for VEMB to read the same things. This is not critical, because the behavior of VADD / VSIM should be transparent if you don't look at the transformed vectors, but still probably worth doing. +* Replication code is pretty much untested, and very vanilla (replicating the commands verbatim). + +## Implementation details + +Vector sets are based on the `hnsw.c` implementation of the HNSW data structure with extensions for speed and functionality. + +The main features are: + +* Proper nodes deletion with relinking. +* 8 bits quantization. +* Threaded queries. diff --git a/hnsw.c b/hnsw.c new file mode 100644 index 000000000..0d8ade987 --- /dev/null +++ b/hnsw.c @@ -0,0 +1,2482 @@ +/* HNSW (Hierarchical Navigable Small World) Implementation. + * + * Based on the paper by Yu. A. Malkov, D. A. Yashunin. + * + * Many details of this implementation, not covered in the paper, were + * obtained simulating different workloads and checking the connection + * quality of the graph. + * + * Notably, this implementation: + * + * 1. Only uses bi-directional links, implementing strategies in order to + * link new nodes even when candidates are full, and our new node would + * be not close enough to replace old links in candidate. + * + * 2. We normalize on-insert, making cosine similarity and dot product the + * same. This means we can't use euclidian distance or alike here. + * Together with quantization, this provides an important speedup that + * makes HNSW more practical. + * + * 3. The quantization used is int8. And it is performed per-vector, so the + * "range" (max abs value) is also stored alongside with the quantized data. + * + * 4. This library implements true elements deletion, not just marking the + * element as deleted, but removing it (we can do it since our links are + * bidirectional), and reliking the nodes orphaned of one link among + * them. + * + * Copyright(C) 2024-2025 Salvatore Sanfilippo. All Rights Reserved. + */ + +#define _DEFAULT_SOURCE +#define _POSIX_C_SOURCE 200809L + +#include +#include +#include +#include +#include +#include /* for INFINITY if not in math.h */ +#include +#include "hnsw.h" + +#if 0 +#define debugmsg printf +#else +#define debugmsg if(0) printf +#endif + +#ifndef INFINITY +#define INFINITY (1.0/0.0) +#endif + +#define MIN(a,b) ((a) < (b) ? (a) : (b)) + +/* Algorithm parameters. */ + +#define HNSW_M 16 /* Number of max connections per node. Note that + * layer zero has twice as many. Also note that + * when a new node is added, we will populate + * even layer 0 links to just HNSW_M neighbors, so + * initially half layer 0 slots will be empty. */ +#define HNSW_M0 (HNSW_M*2) /* Maximum number of connections for layer 0 */ +#define HNSW_P 0.25 /* Probability of level increase. */ +#define HNSW_MAX_LEVEL 16 /* Max level nodes can reach. */ +#define HNSW_EF_C 200 /* Default size of dynamic candidate list while + * inserting a new node, in case 0 is passed to + * the 'ef' argument while inserting. This is also + * used when deleting nodes for the search step + * needed sometimes to reconnect nodes that remain + * orphaned of one link. */ + + +void (*hfree)(void *p) = free; +void *(*hmalloc)(size_t s) = malloc; +void *(*hrealloc)(void *old, size_t s) = realloc; + +void hnsw_set_allocator(void (*free_ptr)(void*), void *(*malloc_ptr)(size_t), + void *(*realloc_ptr)(void*, size_t)) +{ + hfree = free_ptr; + hmalloc = malloc_ptr; + hrealloc = realloc_ptr; +} + +// Get a warning if you use the libc allocator functions for mistake. +#define malloc use_hmalloc_instead +#define realloc use_hrealloc_instead +#define free use_hfree_instead + +/* ============================== Prototypes ================================ */ +void hnsw_cursor_element_deleted(HNSW *index, hnswNode *deleted); + +/* ============================ Priority queue ================================ + * We need a priority queue to take an ordered list of candidates. Right now + * it is implemented as a linear array, since it is relatively small. + * + * You may find it to be odd that we take the best element (smaller distance) + * at the end of the array, but this way popping from the pqueue is O(1), as + * we need to just decrement the count, and this is a very used operation + * in a critical code path. This makes the priority queue implementation a + * bit more complex in the insertion, but for good reasons. */ + +/* Maximum number of candidates we'll ever need (cit. Bill Gates). */ +#define HNSW_MAX_CANDIDATES 256 + +typedef struct { + hnswNode *node; + float distance; +} pqitem; + +typedef struct { + pqitem *items; /* Array of items. */ + uint32_t count; /* Current number of items. */ + uint32_t cap; /* Maximum capacity. */ +} pqueue; + +/* The HNSW algorithms access the pqueue conceptually from nearest (index 0) + * to farest (larger indexes) node, so the following macros are used to + * access the pqueue in this fashion, even if the internal order is + * actually reversed. */ +#define pq_get_node(q,i) ((q)->items[(q)->count-(i+1)].node) +#define pq_get_distance(q,i) ((q)->items[(q)->count-(i+1)].distance) + +/* Create a new priority queue with given capacity. Adding to the + * pqueue only retains 'capacity' elements with the shortest distance. */ +pqueue *pq_new(uint32_t capacity) { + pqueue *pq = hmalloc(sizeof(*pq)); + if (!pq) return NULL; + + pq->items = hmalloc(sizeof(pqitem) * capacity); + if (!pq->items) { + hfree(pq); + return NULL; + } + + pq->count = 0; + pq->cap = capacity; + return pq; +} + +/* Free a priority queue. */ +void pq_free(pqueue *pq) { + if (!pq) return; + hfree(pq->items); + hfree(pq); +} + +/* Insert maintaining distance order (higher distances first). */ +void pq_push(pqueue *pq, hnswNode *node, float distance) { + if (pq->count < pq->cap) { + /* Queue not full: shift right from high distances to make room. */ + uint32_t i = pq->count; + while (i > 0 && pq->items[i-1].distance < distance) { + pq->items[i] = pq->items[i-1]; + i--; + } + pq->items[i].node = node; + pq->items[i].distance = distance; + pq->count++; + } else { + /* Queue full: if new item is worse than worst, ignore it. */ + if (distance >= pq->items[0].distance) return; + + /* Otherwise shift left from low distances to drop worst. */ + uint32_t i = 0; + while (i < pq->cap-1 && pq->items[i+1].distance > distance) { + pq->items[i] = pq->items[i+1]; + i++; + } + pq->items[i].node = node; + pq->items[i].distance = distance; + } +} + +/* Remove and return the top (closest) element, which is at count-1 + * since we store elements with higher distances first. + * Runs in constant time. */ +hnswNode *pq_pop(pqueue *pq, float *distance) { + if (pq->count == 0) return NULL; + pq->count--; + *distance = pq->items[pq->count].distance; + return pq->items[pq->count].node; +} + +/* Get distance of the furthest element. + * An empty priority queue has infinite distance as its furthest element, + * note that this behavior is needed by the algorithms below. */ +float pq_max_distance(pqueue *pq) { + if (pq->count == 0) return INFINITY; + return pq->items[0].distance; +} + +/* ============================ HNSW algorithm ============================== */ + +/* Dot product: our vectors are already normalized. + * Version for not quantized vectors of floats. */ +float vectors_distance_float(const float *x, const float *y, uint32_t dim) { + /* Use two accumulators to reduce dependencies among multiplications. + * This provides a clear speed boost in Apple silicon, but should be + * help in general. */ + float dot0 = 0.0f, dot1 = 0.0f; + uint32_t i; + + // Process 8 elements per iteration, 50/50 with the two accumulators. + for (i = 0; i + 7 < dim; i += 8) { + dot0 += x[i] * y[i] + + x[i+1] * y[i+1] + + x[i+2] * y[i+2] + + x[i+3] * y[i+3]; + + dot1 += x[i+4] * y[i+4] + + x[i+5] * y[i+5] + + x[i+6] * y[i+6] + + x[i+7] * y[i+7]; + } + + /* Handle the remaining elements. These are a minority in the case + * of a smal vector, don't optimze this part. */ + for (; i < dim; i++) dot0 += x[i] * y[i]; + + /* The following line may be counter intuitive. The dot product of + * normalized vectors is equivalent to their cosine similarity. The + * cosine will be from -1 (vectors facing opposite directions in the + * N-dim space) to 1 (vectors are facing in the same direction). + * + * We kinda want a "score" of distance from 0 to 2 (this is a distance + * function and we want minimize the distance for K-NN searches), so we + * can't just add 1: that would return a number in the 0-2 range, with + * 0 meaning opposite vectors and 2 identical vectors, so this is + * similarity, not distance. + * + * Returning instead (1 - dotprod) inverts the meaning: 0 is identical + * and 2 is opposite, hence it is their distance. + * + * Why don't normalize the similarity right now, and return from 0 to + * 1? Because division is costly. */ + return 1.0f - (dot0 + dot1); +} + +/* Q8 quants dotproduct. We do integer math and later fix it by range. */ +float vectors_distance_q8(const int8_t *x, const int8_t *y, uint32_t dim, + float range_a, float range_b) { + // Handle zero vectors special case. + if (range_a == 0 || range_b == 0) { + /* Zero vector distance from anything is 1.0 + * (since 1.0 - dot_product where dot_product = 0). */ + return 1.0f; + } + + /* Each vector is quantized from [-max_abs, +max_abs] to [-127, 127] + * where range = 2*max_abs. */ + const float scale_product = (range_a/127) * (range_b/127); + + int32_t dot0 = 0, dot1 = 0; + uint32_t i; + + // Process 8 elements at a time for better pipeline utilization. + for (i = 0; i + 7 < dim; i += 8) { + dot0 += ((int32_t)x[i]) * ((int32_t)y[i]) + + ((int32_t)x[i+1]) * ((int32_t)y[i+1]) + + ((int32_t)x[i+2]) * ((int32_t)y[i+2]) + + ((int32_t)x[i+3]) * ((int32_t)y[i+3]); + + dot1 += ((int32_t)x[i+4]) * ((int32_t)y[i+4]) + + ((int32_t)x[i+5]) * ((int32_t)y[i+5]) + + ((int32_t)x[i+6]) * ((int32_t)y[i+6]) + + ((int32_t)x[i+7]) * ((int32_t)y[i+7]); + } + + // Handle remaining elements. + for (; i < dim; i++) dot0 += ((int32_t)x[i]) * ((int32_t)y[i]); + + // Convert to original range. + float dotf = (dot0 + dot1) * scale_product; + float distance = 1.0f - dotf; + + // Clamp distance to [0, 2]. + if (distance < 0) distance = 0; + else if (distance > 2) distance = 2; + return distance; +} + +static inline int popcount64(uint64_t x) { + x = (x & 0x5555555555555555) + ((x >> 1) & 0x5555555555555555); + x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333); + x = (x & 0x0F0F0F0F0F0F0F0F) + ((x >> 4) & 0x0F0F0F0F0F0F0F0F); + x = (x & 0x00FF00FF00FF00FF) + ((x >> 8) & 0x00FF00FF00FF00FF); + x = (x & 0x0000FFFF0000FFFF) + ((x >> 16) & 0x0000FFFF0000FFFF); + x = (x & 0x00000000FFFFFFFF) + ((x >> 32) & 0x00000000FFFFFFFF); + return x; +} + +/* Binary vectors distance. */ +float vectors_distance_bin(const uint64_t *x, const uint64_t *y, uint32_t dim) { + uint32_t len = (dim+63)/64; + uint32_t opposite = 0; + for (uint32_t j = 0; j < len; j++) { + int64_t xor = x[j]^y[j]; + opposite += popcount64(xor); + } + return (float)opposite*2/dim; +} + +/* Dot product between nodes. Will call the right version depending on the + * quantization used. */ +float hnsw_distance(HNSW *index, hnswNode *a, hnswNode *b) { + switch(index->quant_type) { + case HNSW_QUANT_NONE: + return vectors_distance_float(a->vector,b->vector,index->vector_dim); + case HNSW_QUANT_Q8: + return vectors_distance_q8(a->vector,b->vector,index->vector_dim,a->quants_range,b->quants_range); + case HNSW_QUANT_BIN: + return vectors_distance_bin(a->vector,b->vector,index->vector_dim); + default: + assert(1 != 1); + return 0; + } +} + +/* This do Q8 'range' quantization. + * For people looking at this code thinking: Oh, I could use min/max + * quants instead! Well: I tried with min/max normalization but the dot + * product needs to accumulate the sum for later correction, and it's slower. */ +void quantize_to_q8(float *src, int8_t *dst, uint32_t dim, float *rangeptr) { + float max_abs = 0; + for (uint32_t j = 0; j < dim; j++) { + if (src[j] > max_abs) max_abs = src[j]; + if (-src[j] > max_abs) max_abs = -src[j]; + } + + if (max_abs == 0) { + if (rangeptr) *rangeptr = 0; + memset(dst, 0, dim); + return; + } + + const float scale = 127.0f / max_abs; // Scale to map to [-127, 127]. + + for (uint32_t j = 0; j < dim; j++) { + dst[j] = (int8_t)roundf(src[j] * scale); + } + if (rangeptr) *rangeptr = max_abs; // Return max_abs instead of 2*max_abs. +} + +/* Binary quantization of vector 'src' to 'dst'. We use full words of + * 64 bit as smallest unit, we will just set all the unused bits to 0 + * so that they'll be the same in all the vectors, and when xor+popcount + * is used to compute the distance, such bits are not considered. This + * allows to go faster. */ +void quantize_to_bin(float *src, uint64_t *dst, uint32_t dim) { + memset(dst,0,(dim+63)/64*sizeof(uint64_t)); + for (uint32_t j = 0; j < dim; j++) { + uint32_t word = j/64; + uint32_t bit = j&63; + /* Since cosine similarity checks the vector direction and + * not magnitudo, we do likewise in the binary quantization and + * just remember if the component is positive or negative. */ + if (src[j] > 0) dst[word] |= 1ULL<quant_type = quant_type; + index->enter_point = NULL; + index->max_level = 0; + index->vector_dim = vector_dim; + index->node_count = 0; + index->last_id = 0; + index->head = NULL; + index->cursors = NULL; + + /* Initialize epochs array. */ + for (int i = 0; i < HNSW_MAX_THREADS; i++) + index->current_epoch[i] = 0; + + /* Initialize locks. */ + if (pthread_rwlock_init(&index->global_lock, NULL) != 0) { + hfree(index); + return NULL; + } + + for (int i = 0; i < HNSW_MAX_THREADS; i++) { + if (pthread_mutex_init(&index->slot_locks[i], NULL) != 0) { + /* Clean up previously initialized mutexes. */ + for (int j = 0; j < i; j++) + pthread_mutex_destroy(&index->slot_locks[j]); + pthread_rwlock_destroy(&index->global_lock); + hfree(index); + return NULL; + } + } + + /* Initialize atomic variables. */ + index->next_slot = 0; + index->version = 0; + return index; +} + +/* Fill 'vec' with the node vector, de-normalizing and de-quantizing it + * as needed. Note that this function will return an approximated version + * of the original vector. */ +void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec) { + if (index->quant_type == HNSW_QUANT_NONE) { + memcpy(vec,node->vector,index->vector_dim*sizeof(float)); + } else if (index->quant_type == HNSW_QUANT_Q8) { + int8_t *quants = node->vector; + for (uint32_t j = 0; j < index->vector_dim; j++) + vec[j] = (quants[j]*node->quants_range)/127; + } else if (index->quant_type == HNSW_QUANT_BIN) { + uint64_t *bits = node->vector; + for (uint32_t j = 0; j < index->vector_dim; j++) { + uint32_t word = j/64; + uint32_t bit = j&63; + vec[j] = (bits[word] & (1ULL<vector_dim; j++) + vec[j] *= node->l2; +} + +/* Return the number of bytes needed to represent a vector in the index, + * that is function of the dimension of the vectors and the quantization + * type used. */ +uint32_t hnsw_quants_bytes(HNSW *index) { + switch(index->quant_type) { + case HNSW_QUANT_NONE: return index->vector_dim * sizeof(float); + case HNSW_QUANT_Q8: return index->vector_dim; + case HNSW_QUANT_BIN: return (index->vector_dim+63)/64*8; + default: assert(0 && "Quantization type not supported."); + } +} + +/* Create new node. Returns NULL on out of memory. + * It is possible to pass the vector as floats or, in case this index + * was already stored on disk and is being loaded, or serialized and + * transmitted in any form, the already quantized version in + * 'qvector'. + * + * Only vector or qvector should be non-NULL. The reason why passing + * a quantized vector is useful, is that because re-normalizing and + * re-quantizing several times the same vector may accumulate rounding + * errors. So if you work with quantized indexes, you should save + * the quantized indexes. + * + * Note that, together with qvector, the quantization range is needed, + * since this library uses per-vector quantization. In case of quantized + * vectors the l2 is considered to be '1', so if you want to restore + * the right l2 (to use the API that returns an approximation of the + * original vector) make sure to save the l2 on disk and set it back + * after the node creation (see later for the serialization API that + * handles this and more). */ +hnswNode *hnsw_node_new(HNSW *index, uint64_t id, const float *vector, const int8_t *qvector, float qrange, uint32_t level) { + hnswNode *node = hmalloc(sizeof(hnswNode)+(sizeof(hnswNodeLayer)*(level+1))); + if (!node) return NULL; + + if (id == 0) id = ++index->last_id; + node->level = level; + node->id = id; + node->next = NULL; + node->vector = NULL; + node->l2 = 1; // Default in case of already quantized vectors. It is + // up to the caller to fill this later, if needed. + + /* Initialize visited epoch array. */ + for (int i = 0; i < HNSW_MAX_THREADS; i++) + node->visited_epoch[i] = 0; + + if (qvector == NULL) { + /* Copy input vector. */ + node->vector = hmalloc(sizeof(float) * index->vector_dim); + if (!node->vector) { + hfree(node); + return NULL; + } + memcpy(node->vector, vector, sizeof(float) * index->vector_dim); + hnsw_normalize_vector(node->vector,&node->l2,index->vector_dim); + + /* Handle quantization. */ + if (index->quant_type != HNSW_QUANT_NONE) { + void *quants = hmalloc(hnsw_quants_bytes(index)); + if (quants == NULL) { + hfree(node->vector); + hfree(node); + return NULL; + } + + // Quantize. + switch(index->quant_type) { + case HNSW_QUANT_Q8: + quantize_to_q8(node->vector,quants,index->vector_dim,&node->quants_range); + break; + case HNSW_QUANT_BIN: + quantize_to_bin(node->vector,quants,index->vector_dim); + break; + default: + assert(0 && "Quantization type not handled."); + break; + } + + // Discard the full precision vector. + hfree(node->vector); + node->vector = quants; + } + } else { + // We got the already quantized vector. Just copy it. + assert(index->quant_type != HNSW_QUANT_NONE); + uint32_t vector_bytes = hnsw_quants_bytes(index); + node->vector = hmalloc(vector_bytes); + node->quants_range = qrange; + if (node->vector == NULL) { + hfree(node); + return NULL; + } + memcpy(node->vector,qvector,vector_bytes); + } + + /* Initialize each layer. */ + for (uint32_t i = 0; i <= level; i++) { + uint32_t max_links = (i == 0) ? HNSW_M0 : HNSW_M; + node->layers[i].max_links = max_links; + node->layers[i].num_links = 0; + node->layers[i].worst_distance = 0; + node->layers[i].worst_idx = 0; + node->layers[i].links = hmalloc(sizeof(hnswNode*) * max_links); + if (!node->layers[i].links) { + for (uint32_t j = 0; j < i; j++) hfree(node->layers[j].links); + hfree(node->layers); + hfree(node->vector); + hfree(node); + return NULL; + } + } + + return node; +} + +/* Free a node. */ +void hnsw_node_free(hnswNode *node) { + if (!node) return; + + for (uint32_t i = 0; i <= node->level; i++) + hfree(node->layers[i].links); + + hfree(node->vector); + hfree(node); +} + +/* Free the entire index. */ +void hnsw_free(HNSW *index,void(*free_value)(void*value)) { + if (!index) return; + + hnswNode *current = index->head; + while (current) { + hnswNode *next = current->next; + if (free_value) free_value(current->value); + hnsw_node_free(current); + current = next; + } + + /* Destroy locks */ + pthread_rwlock_destroy(&index->global_lock); + for (int i = 0; i < HNSW_MAX_THREADS; i++) { + pthread_mutex_destroy(&index->slot_locks[i]); + } + + hfree(index); +} + +/* Add node to linked list of nodes. We may need to scan the whole + * HNSW graph for several reasons. The list is doubly linked since we + * also need the ability to remove a node without scanning the whole thing. */ +void hnsw_add_node(HNSW *index, hnswNode *node) { + node->next = index->head; + node->prev = NULL; + if (index->head) + index->head->prev = node; + index->head = node; + index->node_count++; +} + +/* Search the specified layer starting from the specified entry point + * to collect 'ef' nodes that are near to 'query'. */ +pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point, + uint32_t ef, uint32_t layer, uint32_t slot) +{ + // Mark visited nodes with a never seen epoch. + index->current_epoch[slot]++; + + pqueue *candidates = pq_new(HNSW_MAX_CANDIDATES); + pqueue *results = pq_new(ef); + if (!candidates || !results) { + if (candidates) pq_free(candidates); + if (results) pq_free(results); + return NULL; + } + + // Add entry point. + float dist = hnsw_distance(index, query, entry_point); + pq_push(candidates, entry_point, dist); + pq_push(results, entry_point, dist); + entry_point->visited_epoch[slot] = index->current_epoch[slot]; + + // Process candidates. + while (candidates->count > 0) { + // Pop closest element and use its saved distance. + float cur_dist; + hnswNode *current = pq_pop(candidates, &cur_dist); + + /* Stop if we can't get better results. Note that this can + * be true only if we already collected 'ef' elements in + * the priority queue. */ + float furthest = pq_max_distance(results); + if (cur_dist > furthest) break; + + /* Check neighbors. */ + for (uint32_t i = 0; i < current->layers[layer].num_links; i++) { + hnswNode *neighbor = current->layers[layer].links[i]; + + if (neighbor->visited_epoch[slot] == index->current_epoch[slot]) + continue; // Already visited during this scan. + + neighbor->visited_epoch[slot] = index->current_epoch[slot]; + float neighbor_dist = hnsw_distance(index, query, neighbor); + // Add to results if better than current max or results not full. + furthest = pq_max_distance(results); + if (neighbor_dist < furthest || results->count < ef) { + pq_push(candidates, neighbor, neighbor_dist); + pq_push(results, neighbor, neighbor_dist); + } + } + } + + pq_free(candidates); + return results; +} + +/* This function is used in order to initialize a node allocated in the + * function stack with the specified vector. The idea is that we can + * easily use hnsw_distance() from a vector and the HNSW nodes this way: + * + * hnswNode myQuery; + * hnsw_init_tmp_node(myIndex,&myQuery,0,some_vector); + * hnsw_distance(&myQuery, some_hnsw_node); + * + * Make sure to later free the node with: + * + * hnsw_free_tmp_node(&myQuery,some_vector); + * You have to pass the vector to the free function, because sometimes + * hnsw_init_tmp_node() may just avoid allocating a vector at all, + * just reusing 'some_vector' pointer. + * + * Return 0 on out of memory, 1 on success. + */ +int hnsw_init_tmp_node(HNSW *index, hnswNode *node, int is_normalized, const float *vector) { + node->vector = NULL; + + /* Work on a normalized query vector if the input vector is + * not normalized. */ + if (!is_normalized) { + node->vector = hmalloc(sizeof(float)*index->vector_dim); + if (node->vector == NULL) return 0; + memcpy(node->vector,vector,sizeof(float)*index->vector_dim); + hnsw_normalize_vector(node->vector,NULL,index->vector_dim); + } else { + node->vector = (float*)vector; + } + + /* If quantization is enabled, our query fake node should be + * quantized as well. */ + if (index->quant_type != HNSW_QUANT_NONE) { + void *quants = hmalloc(hnsw_quants_bytes(index)); + if (quants == NULL) { + if (node->vector != vector) hfree(node->vector); + return 0; + } + switch(index->quant_type) { + case HNSW_QUANT_Q8: + quantize_to_q8(node->vector, quants, index->vector_dim, &node->quants_range); + break; + case HNSW_QUANT_BIN: + quantize_to_bin(node->vector, quants, index->vector_dim); + } + if (node->vector != vector) hfree(node->vector); + node->vector = quants; + } + return 1; +} + +/* Free the stack allocated node initialized by hnsw_init_tmp_node(). */ +void hnsw_free_tmp_node(hnswNode *node, const float *vector) { + if (node->vector != vector) hfree(node->vector); +} + +/* Return approximated K-NN items. Note that neighbors and distances + * arrays must have space for at least 'k' items. + * norm_query should be set to 1 if the query vector is already + * normalized, otherwise, if 0, the function will copy the vector, + * L2-normalize the copy and search using the normalized version. */ +int hnsw_search(HNSW *index, const float *query_vector, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized) +{ + if (!index || !query_vector || !neighbors || k == 0) return -1; + if (!index->enter_point) return 0; // Empty index. + + /* Use a fake node that holds the query vector, this way we can + * use our normal node to node distance functions when checking + * the distance between query and graph nodes. */ + hnswNode query; + if (hnsw_init_tmp_node(index,&query,query_vector_is_normalized,query_vector) == 0) return -1; + + // Start searching from the entry point. + hnswNode *curr_ep = index->enter_point; + + /* Start from higher layer to layer 1 (layer 0 is handled later) + * in the next section. Descend to the most similar node found + * so far. */ + for (int lc = index->max_level; lc > 0; lc--) { + pqueue *results = search_layer(index, &query, curr_ep, 1, lc, slot); + if (!results) continue; + + if (results->count > 0) { + curr_ep = pq_get_node(results,0); + } + pq_free(results); + } + + /* Search bottom layer (the most densely populated) with ef = k */ + pqueue *results = search_layer(index, &query, curr_ep, k, 0, slot); + if (!results) { + hnsw_free_tmp_node(&query, query_vector); + return -1; + } + + /* Copy results. */ + uint32_t found = MIN(k, results->count); + for (uint32_t i = 0; i < found; i++) { + neighbors[i] = pq_get_node(results,i); + if (distances) { + distances[i] = pq_get_distance(results,i); + } + } + + pq_free(results); + hnsw_free_tmp_node(&query, query_vector); + return found; +} + +/* Rescan a node and update the wortst neighbor index. + * The followinng two functions are variants of this function to be used + * when links are added or removed: they may do less work than a full scan. */ +void hnsw_update_worst_neighbor(HNSW *index, hnswNode *node, uint32_t layer) { + float worst_dist = 0; + uint32_t worst_idx = 0; + for (uint32_t i = 0; i < node->layers[layer].num_links; i++) { + float dist = hnsw_distance(index, node, node->layers[layer].links[i]); + if (dist > worst_dist) { + worst_dist = dist; + worst_idx = i; + } + } + node->layers[layer].worst_distance = worst_dist; + node->layers[layer].worst_idx = worst_idx; +} + +/* Update node worst neighbor distance information when a new neighbor + * is added. */ +void hnsw_update_worst_neighbor_on_add(HNSW *index, hnswNode *node, uint32_t layer, uint32_t added_index, float distance) { + (void) index; // Unused but here for API symmetry. + if (node->layers[layer].num_links == 1 || // First neighbor? + distance > node->layers[layer].worst_distance) // New worst? + { + node->layers[layer].worst_distance = distance; + node->layers[layer].worst_idx = added_index; + } +} + +/* Update node worst neighbor distance information when a linked neighbor + * is removed. */ +void hnsw_update_worst_neighbor_on_remove(HNSW *index, hnswNode *node, uint32_t layer, uint32_t removed_idx) +{ + if (node->layers[layer].num_links == 0) { + node->layers[layer].worst_distance = 0; + node->layers[layer].worst_idx = 0; + } else if (removed_idx == node->layers[layer].worst_idx) { + hnsw_update_worst_neighbor(index,node,layer); + } else if (removed_idx < node->layers[layer].worst_idx) { + // Just update index if we removed element before worst. + node->layers[layer].worst_idx--; + } +} + +/* We have a list of candidate nodes to link to the new node, when iserting + * one. This function selects which nodes to link and performs the linking. + * + * Parameters: + * + * - 'candidates' is the priority queue of potential good nodes to link to the + * new node 'new_node'. + * - 'required_links' is as many links we would like our new_node to get + * at the specified layer. + * - 'aggressive' changes the startegy used to find good neighbors as follows: + * + * This function is called with aggressive=0 for all the layers, including + * layer 0. When called like that, it will use the diversity of links and + * quality of links checks before linking our new node with some candidate. + * + * However if the insert function finds that at layer 0, with aggressive=0, + * few connections were made, it calls this function again with agressiveness + * levels greater up to 2. + * + * At aggressive=1, the diversity checks are disabled, and the candidate + * node for linking is accepted even if it is nearest to an already accepted + * neighbor than it is to the new node. + * + * When we link our new node by replacing the link of a candidate neighbor + * that already has the max number of links, inevitably some other node loses + * a connection (to make space for our new node link). In this case: + * + * 1. If such "dropped" node would remain with too little links, we try with + * some different neighbor instead, however as the 'aggressive' paramter + * has incremental values (0, 1, 2) we are more and more willing to leave + * the dropped node with fever connections. + * 2. If aggressive=2, we will scan the candidate neighbor node links to + * find a different linked-node to replace, one better connected even if + * its distance is not the worse. + * + * Note: this function is also called during deletion of nodes in order to + * provide certain nodes with additional links. + */ +void select_neighbors(HNSW *index, pqueue *candidates, hnswNode *new_node, + uint32_t layer, uint32_t required_links, int aggressive) +{ + uint32_t max_links = (layer == 0) ? HNSW_M0 : HNSW_M; + + for (uint32_t i = 0; i < candidates->count; i++) { + hnswNode *neighbor = pq_get_node(candidates,i); + if (neighbor == new_node) continue; // Don't link node with itself. + + /* Use our cached distance among the new node and the candidate. */ + float dist = pq_get_distance(candidates,i); + + /* First of all, since our links are all bidirectional, if the + * new node for any reason has no longer room, or if it accumulated + * the required number of links, return ASAP. */ + if (new_node->layers[layer].num_links >= new_node->layers[layer].max_links || + new_node->layers[layer].num_links >= required_links) return; + + /* If aggressive is true, it is possible that the new node + * already got some link among the candidates (see the top comment, + * this function gets re-called in case of too few links). + * So we need to check if this candidate is already linked to + * the new node. */ + if (aggressive) { + int duplicated = 0; + for (uint32_t j = 0; j < new_node->layers[layer].num_links; j++) { + if (new_node->layers[layer].links[j] == neighbor) { + duplicated = 1; + break; + } + } + if (duplicated) continue; + } + + /* Diversity check. We accept new candidates + * only if there is no element already accepted that is nearest + * to the candidate than the new element itself. + * However this check is disabled if we have pressure to find + * new links (aggressive != 0) */ + if (!aggressive) { + int diversity_failed = 0; + for (uint32_t j = 0; j < new_node->layers[layer].num_links; j++) { + float link_dist = hnsw_distance(index, neighbor, + new_node->layers[layer].links[j]); + if (link_dist < dist) { + diversity_failed = 1; + break; + } + } + if (diversity_failed) continue; + } + + /* If potential neighbor node has space, simply add the new link. + * We will have space as well. */ + uint32_t n = neighbor->layers[layer].num_links; + if (n < max_links) { + /* Link candidate to new node. */ + neighbor->layers[layer].links[n] = new_node; + neighbor->layers[layer].num_links++; + + /* Update candidate worst link info. */ + hnsw_update_worst_neighbor_on_add(index,neighbor,layer,n,dist); + + /* Link new node to candidate. */ + uint32_t new_links = new_node->layers[layer].num_links; + new_node->layers[layer].links[new_links] = neighbor; + new_node->layers[layer].num_links++; + + /* Update new node worst link info. */ + hnsw_update_worst_neighbor_on_add(index,new_node,layer,new_links,dist); + continue; + } + + /* ==================================================================== + * Replacing existing candidate neighbor link step. + * ================================================================== */ + + /* If we are here, our accepted candidate for linking is full. + * + * If new node is more distant to candidate than its current worst link + * then we skip it: we would not be able to establish a bidirectional + * connection without compromising link quality of candidate. + * + * At aggressiveness > 0 we don't care about this check. */ + if (!aggressive && dist >= neighbor->layers[layer].worst_distance) + continue; + + /* We can add it: we are ready to replace the candidate neighbor worst + * link with the new node, assuming certain conditions are met. */ + hnswNode *worst_node = neighbor->layers[layer].links[neighbor->layers[layer].worst_idx]; + + /* The worst node linked to our candidate may remain too disconnected + * if we remove the candidate node as its link. Let's check if + * this is the case: */ + if (aggressive == 0 && + worst_node->layers[layer].num_links <= HNSW_M/2) + continue; + + /* Aggressive level = 1. It's ok if the node remains with just + * HNSW_M/4 links. */ + else if (aggressive == 1 && + worst_node->layers[layer].num_links <= HNSW_M/4) + continue; + + /* If aggressive is set to 2, then the new node we are adding failed + * to find enough neighbors. We can't insert an almost orphaned new + * node, so let's see if the target node has some other link + * that is well connected in the graph: we could drop it instead + * of the worst link. */ + if (aggressive == 2 && worst_node->layers[layer].num_links <= HNSW_M/4) + { + /* Let's see if we can find at least a candidate link that + * would remain with a few connections. Track the one + * that is the farest away (worst distance) from our candidate + * neighbor (in order to remove the less interesting link). */ + worst_node = NULL; + uint32_t worst_idx = 0; + float max_dist = 0; + for (uint32_t j = 0; j < neighbor->layers[layer].num_links; j++) { + hnswNode *to_drop = neighbor->layers[layer].links[j]; + + /* Skip this if it would remain too disconnected as well. + * + * NOTE about HNSW_M/4 min connections requirement: + * + * It is not too strict, since leaving a node with just a + * single link does not just leave it too weakly connected, but + * also sometimes creates cycles with few disconnected + * nodes linked among them. */ + if (to_drop->layers[layer].num_links <= HNSW_M/4) continue; + + float link_dist = hnsw_distance(index, neighbor, to_drop); + if (worst_node == NULL || link_dist > max_dist) { + worst_node = to_drop; + max_dist = link_dist; + worst_idx = j; + } + } + + if (worst_node != NULL) { + /* We found a node that we can drop. Let's pretend this is + * the worst node of the candidate to unify the following + * code path. Later we will fix the worst node info anyway. */ + neighbor->layers[layer].worst_distance = max_dist; + neighbor->layers[layer].worst_idx = worst_idx; + } else { + /* Otherwise we have no other option than reallocating + * the max number of links for this target node, and + * ensure at least a few connections for our new node. + * + * XXX: Implement this part. */ + debugmsg("Node overbooking needed: allocate more\n"); + continue; + } + } + + // Remove backlink from the worst node of our candidate. + for (uint64_t j = 0; j < worst_node->layers[layer].num_links; j++) { + if (worst_node->layers[layer].links[j] == neighbor) { + memmove(&worst_node->layers[layer].links[j], + &worst_node->layers[layer].links[j+1], + (worst_node->layers[layer].num_links - j - 1) * sizeof(hnswNode*)); + worst_node->layers[layer].num_links--; + hnsw_update_worst_neighbor_on_remove(index,worst_node,layer,j); + break; + } + } + + /* Replace worst link with the new node. */ + neighbor->layers[layer].links[neighbor->layers[layer].worst_idx] = new_node; + + /* Update the worst link in the target node, at this point + * the link that we replaced may no longer be the worst. */ + hnsw_update_worst_neighbor(index,neighbor,layer); + + // Add new node -> candidate link. + uint32_t new_links = new_node->layers[layer].num_links; + new_node->layers[layer].links[new_links] = neighbor; + new_node->layers[layer].num_links++; + + // Update new node worst link. + hnsw_update_worst_neighbor_on_add(index,new_node,layer,new_links,dist); + } +} + +/* This function implements node reconnection after a node deletion in HNSW. + * When a node is deleted, other nodes at the specified layer lose one + * connection (all the neighbors of the deleted node). This function attempts + * to pair such nodes together in a way that maximizes connection quality + * among the M nodes that were former neighbors of our deleted node. + * + * The algorithm works by first building a distance matrix among the nodes: + * + * N0 N1 N2 N3 + * N0 0 1.2 0.4 0.9 + * N1 1.2 0 0.8 0.5 + * N2 0.4 0.8 0 1.1 + * N3 0.9 0.5 1.1 0 + * + * For each potential pairing (i,j) we compute a score that combines: + * 1. The direct cosine distance between the two nodes + * 2. The average distance to other nodes that would no longer be + * available for pairing if we select this pair + * + * We want to balance local node-to-node requirements and global requirements. + * For instance sometimes connecting A with B, while optimal, would leave + * C and D to be connected without other choices, and this could be a very + * bad connection. Maybe instead A and C and B and D are both relatively high + * quality connections. + * + * The formula used to calculate the score of each connection is: + * + * score[i,j] = W1*(2-distance[i,j]) + W2*((new_avg_i + new_avg_j)/2) + * where new_avg_x is the average of distances in row x excluding distance[i,j] + * + * So the score is directly proportional to the SIMILARITY of the two nodes + * and also directly proportional to the DISTANCE of the potential other + * connections that we lost by pairign i,j. So we have a cost for missed + * opportunities, or better, in this case, a reward if the missing + * opportunities are not so good (big average distance). + * + * W1 and W2 are weights (defaults: 0.7 and 0.3) that determine the relative + * importance of immediate connection quality vs future pairing potential. + * + * After the initial pairing phase, any nodes that couldn't be paired + * (due to odd count or existing connections) are handled by searching + * the broader graph using the standard HNSW neighbor selection logic. + */ +void hnsw_reconnect_nodes(HNSW *index, hnswNode **nodes, int count, uint32_t layer) { + if (count <= 0) return; + debugmsg("Reconnecting %d nodes\n", count); + + /* Step 1: Build the distance matrix between all nodes. + * Since distance(i,j) = distance(j,i), we only compute the upper triangle + * and mirror it to the lower triangle. */ + float *distances = hmalloc(count * count * sizeof(float)); + if (!distances) return; + + for (int i = 0; i < count; i++) { + distances[i*count + i] = 0; // Distance to self is 0 + for (int j = i+1; j < count; j++) { + float dist = hnsw_distance(index, nodes[i], nodes[j]); + distances[i*count + j] = dist; // Upper triangle. + distances[j*count + i] = dist; // Lower triangle. + } + } + + /* Step 2: Calculate row averages (will be used in scoring): + * please note that we just calculate row averages and not + * colums averages since the matrix is symmetrical, so those + * are the same: check the image in the top comment if you have any + * doubt about this. */ + float *row_avgs = hmalloc(count * sizeof(float)); + if (!row_avgs) { + hfree(distances); + return; + } + + for (int i = 0; i < count; i++) { + float sum = 0; + int valid_count = 0; + for (int j = 0; j < count; j++) { + if (i != j) { + sum += distances[i*count + j]; + valid_count++; + } + } + row_avgs[i] = valid_count ? sum / valid_count : 0; + } + + /* Step 3: Build scoring matrix. What we do here is to combine how + * good is a given i,j nodes connection, with how badly connecting + * i,j will affect the remaining quality of connections left to + * pair the other nodes. */ + float *scores = hmalloc(count * count * sizeof(float)); + if (!scores) { + hfree(distances); + hfree(row_avgs); + return; + } + + /* Those weights were obtained manually... No guarantee that they + * are optimal. However with these values the algorithm is certain + * better than its greedy version that just attempts to pick the + * best pair each time (verified experimentally). */ + const float W1 = 0.7; // Weight for immediate distance. + const float W2 = 0.3; // Weight for future potential. + + for (int i = 0; i < count; i++) { + for (int j = 0; j < count; j++) { + if (i == j) { + scores[i*count + j] = -1; // Invalid pairing. + continue; + } + + // Check for existing connection between i and j. + int already_linked = 0; + for (uint32_t k = 0; k < nodes[i]->layers[layer].num_links; k++) + { + if (nodes[i]->layers[layer].links[k] == nodes[j]) { + scores[i*count + j] = -1; // Already linked. + already_linked = 1; + break; + } + } + if (already_linked) continue; + + float dist = distances[i*count + j]; + + /* Calculate new averages excluding this pair. + * Handle edge case where we might have too few elements. + * Note that it would be not very smart to recompute the average + * each time scanning the row, we can remove the element + * and adjust the average without it. */ + float new_avg_i = 0, new_avg_j = 0; + if (count > 2) { + new_avg_i = (row_avgs[i] * (count-1) - dist) / (count-2); + new_avg_j = (row_avgs[j] * (count-1) - dist) / (count-2); + } + + /* Final weighted score: the more similar i,j, the better + * the score. The more distant are the pairs we lose by + * connecting i,j, the better the score. */ + scores[i*count + j] = W1*(2-dist) + W2*((new_avg_i + new_avg_j)/2); + } + } + + // Step 5: Pair nodes greedily based on scores. + int *used = calloc(count, sizeof(int)); + if (!used) { + hfree(distances); + hfree(row_avgs); + hfree(scores); + return; + } + + /* Scan the matrix looking each time for the potential + * link with the best score. */ + while(1) { + float max_score = -1; + int best_j = -1, best_i = -1; + + // Seek best score i,j values. + for (int i = 0; i < count; i++) { + if (used[i]) continue; // Already connected. + + /* No space left? Not possible after a node deletion but makes + * this function more future-proof. */ + if (nodes[i]->layers[layer].num_links >= + nodes[i]->layers[layer].max_links) continue; + + for (int j = 0; j < count; j++) { + if (i == j) continue; // Same node, skip. + if (used[j]) continue; // Already connected. + float score = scores[i*count + j]; + if (score < 0) continue; // Invalid link. + + /* If the target node has space, and its score is better + * than any other seen so far... remember it is the best. */ + if (score > max_score && + nodes[j]->layers[layer].num_links < + nodes[j]->layers[layer].max_links) + { + // Track the best connection found so far. + max_score = score; + best_j = j; + best_i = i; + } + } + } + + // Possible link found? Connect i and j. + if (best_j != -1) { + debugmsg("[%d] linking %d with %d: %f\n", layer, (int)best_i, (int)best_j, max_score); + // Link i -> j. + int link_idx = nodes[best_i]->layers[layer].num_links; + nodes[best_i]->layers[layer].links[link_idx] = nodes[best_j]; + nodes[best_i]->layers[layer].num_links++; + + // Update worst distance if needed. + float dist = distances[best_i*count + best_j]; + hnsw_update_worst_neighbor_on_add(index,nodes[best_i],layer,link_idx,dist); + + // Link j -> i. + link_idx = nodes[best_j]->layers[layer].num_links; + nodes[best_j]->layers[layer].links[link_idx] = nodes[best_i]; + nodes[best_j]->layers[layer].num_links++; + + // Update worst distance if needed. + hnsw_update_worst_neighbor_on_add(index,nodes[best_j],layer,link_idx,dist); + + // Mark connection as used. + used[best_i] = used[best_j] = 1; + } else { + break; // No more valid connections available. + } + } + + /* Step 6: Handle remaining unpaired nodes using the standard HNSW + * neighbor selection. */ + for (int i = 0; i < count; i++) { + if (used[i]) continue; + + // Skip if node is already at max connections. + if (nodes[i]->layers[layer].num_links >= + nodes[i]->layers[layer].max_links) + continue; + + debugmsg("[%d] Force linking %d\n", layer, i); + + /* First, try with local nodes as candidates. + * Some candidate may have space. */ + pqueue *candidates = pq_new(count); + if (!candidates) continue; + + /* Add all the local nodes having some space as candidates + * to be linked with this node. */ + for (int j = 0; j < count; j++) { + if (i != j && // Must not be itself. + nodes[j]->layers[layer].num_links < // Must not be full. + nodes[j]->layers[layer].max_links) + { + float dist = distances[i*count + j]; + pq_push(candidates, nodes[j], dist); + } + } + + /* Try local candidates first with aggressive = 1. + * So we will link only if there is space. + * We want one link more than the links we already have. */ + uint32_t wanted_links = nodes[i]->layers[layer].num_links+1; + if (candidates->count > 0) { + select_neighbors(index, candidates, nodes[i], layer, + wanted_links, 1); + debugmsg("Final links after attempt with local nodes: %d (wanted: %d)\n", (int)nodes[i]->layers[layer].num_links, wanted_links); + } + + // If still no connection, search the broader graph. + if (nodes[i]->layers[layer].num_links != wanted_links) { + debugmsg("No force linking possible with local candidats\n"); + pq_free(candidates); + + // Find entry point for target layer by descending through levels. + hnswNode *curr_ep = index->enter_point; + for (uint32_t lc = index->max_level; lc > layer; lc--) { + pqueue *results = search_layer(index, nodes[i], curr_ep, 1, lc, 0); + if (results) { + if (results->count > 0) { + curr_ep = pq_get_node(results,0); + } + pq_free(results); + } + } + + if (curr_ep) { + /* Search this layer for candidates. + * Use the defalt EF_C in this case, since it's not an + * "insert" operation, and we don't know the user + * specified "EF". */ + candidates = search_layer(index, nodes[i], curr_ep, HNSW_EF_C, layer, 0); + if (candidates) { + /* Try to connect with aggressiveness proportional to the + * node linking condition. */ + int aggressiveness = + nodes[i]->layers[layer].num_links > HNSW_M / 2 ? 1 : 2; + select_neighbors(index, candidates, nodes[i], layer, wanted_links, aggressiveness); + debugmsg("Final links with broader search: %d (wanted: %d)\n", (int)nodes[i]->layers[layer].num_links, wanted_links); + pq_free(candidates); + } + } + } else { + pq_free(candidates); + } + } + + // Cleanup. + hfree(distances); + hfree(row_avgs); + hfree(scores); + hfree(used); +} + +/* This is an helper function in order to support node deletion. + * It's goal is just to: + * + * 1. Remove the node from the bidirectional links of neighbors in the graph. + * 2. Remove the node from the linked list of nodes. + * 3. Fix the entry point in the graph. We just select one of the neighbors + * of the deleted node at a lower level. If none is found, we do + * a full scan. + * 4. The node itself amd its aux value field are NOT freed. It's up to the + * caller to do it, by using hnsw_node_free(). + * 5. The node associated value (node->value) is NOT freed. + * + * Why this function will not free the node? Because in node updates it + * could be a good idea to reuse the node allocation for different reasons + * (currently not implemented). + * In general it is more future-proof to be able to reuse the node if + * needed. Right now this library reuses the node only when links are + * not touched (see hnsw_update() for more information). */ +void hnsw_unlink_node(HNSW *index, hnswNode *node) { + if (!index || !node) return; + + index->version++; // This node may be missing in an already compiled list + // of neighbors. Make optimistic concurrent inserts fail. + + /* Remove all bidirectional links at each level. + * Note that in this implementation all the + * links are guaranteed to be bedirectional. */ + + /* For each level of the deleted node... */ + for (uint32_t level = 0; level <= node->level; level++) { + /* For each linked node of the deleted node... */ + for (uint32_t i = 0; i < node->layers[level].num_links; i++) { + hnswNode *linked = node->layers[level].links[i]; + /* Find and remove the backlink in the linked node */ + for (uint32_t j = 0; j < linked->layers[level].num_links; j++) { + if (linked->layers[level].links[j] == node) { + /* Remove by shifting remaining links left */ + memmove(&linked->layers[level].links[j], + &linked->layers[level].links[j + 1], + (linked->layers[level].num_links - j - 1) * sizeof(hnswNode*)); + linked->layers[level].num_links--; + hnsw_update_worst_neighbor_on_remove(index,linked,level,j); + break; + } + } + } + } + + /* Update cursors pointing at this element. */ + if (index->cursors) hnsw_cursor_element_deleted(index,node); + + /* Update the previous node's next pointer. */ + if (node->prev) { + node->prev->next = node->next; + } else { + /* If there's no previous node, this is the head. */ + index->head = node->next; + } + + /* Update the next node's prev pointer. */ + if (node->next) node->next->prev = node->prev; + + /* Update node count. */ + index->node_count--; + + /* If this node was the enter_point, we need to update it. */ + if (node == index->enter_point) { + /* Reset entry point - we'll find a new one (unless the HNSW is + * now empty) */ + index->enter_point = NULL; + index->max_level = 0; + + /* Step 1: Try to find a replacement by scanning levels + * from top to bottom. Under normal conditions, if there is + * any other node at the same level, we have a link. Anyway + * we descend levels to find any neighbor at the higher level + * possible. */ + for (int level = node->level; level >= 0; level--) { + if (node->layers[level].num_links > 0) { + index->enter_point = node->layers[level].links[0]; + break; + } + } + + /* Step 2: If no links were found at any level, do a full scan. + * This should never happen in practice if the HNSW is not + * empty. */ + if (!index->enter_point) { + uint32_t new_max_level = 0; + hnswNode *current = index->head; + + while (current) { + if (current != node && current->level >= new_max_level) { + new_max_level = current->level; + index->enter_point = current; + } + current = current->next; + } + } + + /* Update max_level. */ + if (index->enter_point) + index->max_level = index->enter_point->level; + } + + /* Clear the node's links but don't free the node itself */ + node->prev = node->next = NULL; +} + +/* Higher level API for hnsw_unlink_node() + hnsw_reconnect_nodes() actual work. + * This will get the write lock, will delete the node, free it, + * reconnect the node neighbors among themselves, and unlock again. + * If free_value function pointer is not NULL, then the function provided is + * used to free node->value. */ +void hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value)) { + pthread_rwlock_wrlock(&index->global_lock); + hnsw_unlink_node(index,node); + if (free_value && node->value) free_value(node->value); + + /* Relink all the nodes orphaned of this node link. + * Do it for all the levels. */ + for (unsigned int j = 0; j <= node->level; j++) { + hnsw_reconnect_nodes(index, node->layers[j].links, + node->layers[j].num_links, j); + } + hnsw_node_free(node); + pthread_rwlock_unlock(&index->global_lock); +} + +/* ============================ Threaded API ================================ + * Concurent readers should use the following API to get a slot assigned + * (and a lock, too), do their read-only call, and unlock the slot. + * + * There is a reason why read operations don't implement opaque transparent + * locking directly on behalf of the user: when we return a result set + * with hnsw_search(), we report a set of nodes. The caller will do something + * with the nodes and the associated values, so the unlocking of the + * slot should happen AFTER the result was already used, otherwise we may + * have changes to the HNSW nodes as the result is being accessed. */ + +/* Try to acquire a read slot. Returns the slot number (0 to HNSW_MAX_THREADS-1) + * on success, -1 on error (pthread mutex errors). */ +int hnsw_acquire_read_slot(HNSW *index) { + /* First try a non-blocking approach on all slots. */ + for (uint32_t i = 0; i < HNSW_MAX_THREADS; i++) { + if (pthread_mutex_trylock(&index->slot_locks[i]) == 0) { + if (pthread_rwlock_rdlock(&index->global_lock) != 0) { + pthread_mutex_unlock(&index->slot_locks[i]); + return -1; + } + return i; + } + } + + /* All trylock attempts failed, use atomic increment to select slot. */ + uint32_t slot = index->next_slot++ % HNSW_MAX_THREADS; + + /* Try to lock the selected slot. */ + if (pthread_mutex_lock(&index->slot_locks[slot]) != 0) return -1; + + /* Get read lock. */ + if (pthread_rwlock_rdlock(&index->global_lock) != 0) { + pthread_mutex_unlock(&index->slot_locks[slot]); + return -1; + } + + return slot; +} + +/* Release a previously acquired read slot: note that it is important that + * nodes returned by hnsw_search() are accessed while the read lock is + * still active, to be sure that nodes are not freed. */ +void hnsw_release_read_slot(HNSW *index, int slot) { + if (slot < 0 || slot >= HNSW_MAX_THREADS) return; + pthread_rwlock_unlock(&index->global_lock); + pthread_mutex_unlock(&index->slot_locks[slot]); +} + +/* ============================ Nodes insertion ============================= + * We have an optimistic API separating the read-only candidates search + * and the write side (actual node insertion). We internally also use + * this API to provide the plain hnsw_insert() function for code unification. */ + +struct InsertContext { + pqueue *level_queues[HNSW_MAX_LEVEL]; /* Candidates for each level. */ + hnswNode *node; /* Pre-allocated node ready for insertion */ + uint64_t version; /* Index version at preparation time. This is used + * for CAS-like locking during change commit. */ +}; + +/* Optimistic insertion API. + * + * WARNING: Note that this is an internal function: users should call + * hnsw_prepare_insert() instead. + * + * This is how it works: you use hnsw_prepare_insert() and it will return + * a context where good candidate neighbors are already pre-selected. + * This step only uses read locks. + * + * Then finally you try to actually commit the new node with + * hnsw_try_commit_insert(): this time we will require a write lock, but + * for less time than it would be otherwise needed if using directly + * hnsw_insert(). When you try to commit the write, if no node was deleted in + * the meantime, your operation will succeed, otherwise it will fail, and + * you should try to just use the hnsw_insert() API, since there is + * contention. + * + * See hnsw_node_new() for information about 'vector' and 'qvector' + * arguments, and which one to pass. */ +InsertContext *hnsw_prepare_insert_nolock(HNSW *index, const float *vector, + const int8_t *qvector, float qrange, uint64_t id, void *value, + int slot, int ef) +{ + InsertContext *ctx = hmalloc(sizeof(*ctx)); + if (!ctx) return NULL; + + memset(ctx, 0, sizeof(*ctx)); + ctx->version = index->version; + + /* Crete a new node that we may be able to insert into the + * graph later, when calling the commit function. */ + uint32_t level = random_level(); + ctx->node = hnsw_node_new(index, id, vector, qvector, qrange, level); + if (!ctx->node) { + hfree(ctx); + return NULL; + } + ctx->node->value = value; + + hnswNode *curr_ep = index->enter_point; + + /* Empty graph, no need to collect candidates. */ + if (curr_ep == NULL) return ctx; + + /* Phase 1: Find good entry point on the highest level of the new + * node we are going to insert. */ + for (unsigned int lc = index->max_level; lc > level; lc--) { + pqueue *results = search_layer(index, ctx->node, curr_ep, 1, lc, slot); + + if (results) { + if (results->count > 0) curr_ep = pq_get_node(results,0); + pq_free(results); + } + } + + /* Phase 2: Collect a set of potential connections for each layer of + * the new node. */ + for (int lc = MIN(level, index->max_level); lc >= 0; lc--) { + pqueue *candidates = + search_layer(index, ctx->node, curr_ep, ef, lc, slot); + + if (!candidates) continue; + curr_ep = (candidates->count > 0) ? pq_get_node(candidates,0) : curr_ep; + ctx->level_queues[lc] = candidates; + } + + return ctx; +} + +/* External API for hnsw_prepare_insert_nolock(), handling locking. */ +InsertContext *hnsw_prepare_insert(HNSW *index, const float *vector, + const int8_t *qvector, float qrange, uint64_t id, void *value, + int ef) +{ + InsertContext *ctx; + int slot = hnsw_acquire_read_slot(index); + ctx = hnsw_prepare_insert_nolock(index,vector,qvector,qrange,id,value,slot,ef); + hnsw_release_read_slot(index,slot); + return ctx; +} + +/* Free an insert context and all its resources. */ +void hnsw_free_insert_context(InsertContext *ctx) { + if (!ctx) return; + for (uint32_t i = 0; i < HNSW_MAX_LEVEL; i++) { + if (ctx->level_queues[i]) pq_free(ctx->level_queues[i]); + } + if (ctx->node) hnsw_node_free(ctx->node); + hfree(ctx); +} + +/* Commit a prepared insert operation. This function is a low level API that + * should not be called by the user. See instead hnsw_try_commit_insert(), that + * will perform the CAS check and acquire the write lock. + * + * See the top comment in hnsw_prepare_insert() for more information + * on the optimistic insertion API. + * + * This function can't fail and always returns the pointer to the + * just inserted node. Out of memory is not possible since no critical + * allocation is never performed in this code path: we populate links + * on already allocated nodes. */ +hnswNode *hnsw_commit_insert_nolock(HNSW *index, InsertContext *ctx) { + hnswNode *node = ctx->node; + + /* Handle first node case. */ + if (index->enter_point == NULL) { + index->version++; // First node, make concurrent inserts fail. + index->enter_point = node; + index->max_level = node->level; + hnsw_add_node(index, node); + ctx->node = NULL; // So hnsw_free_insert_context() will not free it. + hnsw_free_insert_context(ctx); + return node; + } + + /* Connect the node with near neighbors at each level. */ + for (int lc = MIN(node->level,index->max_level); lc >= 0; lc--) { + if (ctx->level_queues[lc] == NULL) continue; + + /* Try to provide HNSW_M connections to our node. The call + * is not guaranteed to be able to provide all the links we would + * like to have for the new node: they must be bi-directional, obey + * certain quality checks, and so forth, so later there are further + * calls to force the hand a bit if needed. + * + * Let's start with aggressiveness = 0. */ + select_neighbors(index, ctx->level_queues[lc], node, lc, HNSW_M, 0); + + /* Layer 0 and too few connections? Let's be more aggressive. */ + if (lc == 0 && node->layers[0].num_links < HNSW_M/2) { + select_neighbors(index, ctx->level_queues[lc], node, lc, + HNSW_M, 1); + + /* Still too few connections? Let's go to + * aggressiveness level '2' in linking strategy. */ + if (node->layers[0].num_links < HNSW_M/4) { + select_neighbors(index, ctx->level_queues[lc], node, lc, + HNSW_M/4, 2); + } + } + } + + /* If new node level is higher than current max, update entry point. */ + if (node->level > index->max_level) { + index->version++; // Entry point changed, make concurrent inserts fail. + index->enter_point = node; + index->max_level = node->level; + } + + /* Add node to the linked list. */ + hnsw_add_node(index, node); + ctx->node = NULL; // So hnsw_free_insert_context() will not free the node. + hnsw_free_insert_context(ctx); + return node; +} + +/* If the context obtained with hnsw_prepare_insert() is still valid + * (nodes not deleted in the meantime) then add the new node to the HNSW + * index and return its pointer. Otherwise NULL is returned and the operation + * should be either performed with the blocking API hnsw_insert() or attempted + * again. */ +hnswNode *hnsw_try_commit_insert(HNSW *index, InsertContext *ctx) { + /* Check if the version changed since preparation. Note that we + * should access index->version under the write lock in order to + * be sure we can safely commit the write: this is just a fast-path + * in order to return ASAP without acquiring the write lock in case + * the version changed. */ + if (ctx->version != index->version) { + hnsw_free_insert_context(ctx); + return NULL; + } + + /* Try to acquire write lock. */ + if (pthread_rwlock_wrlock(&index->global_lock) != 0) { + hnsw_free_insert_context(ctx); + return NULL; + } + + /* Check version again under write lock. */ + if (ctx->version != index->version) { + pthread_rwlock_unlock(&index->global_lock); + hnsw_free_insert_context(ctx); + return NULL; + } + + /* Commit the change: note that it's up to hnsw_commit_insert_nolock() + * to free the insertion context. */ + hnswNode *node = hnsw_commit_insert_nolock(index, ctx); + + /* Release the write lock. */ + pthread_rwlock_unlock(&index->global_lock); + return node; +} + +/* Insert a new element into the graph. + * See hnsw_node_new() for information about 'vector' and 'qvector' + * arguments, and which one to pass. + * + * Return NULL on out of memory during insert. Otherwise the newly + * inserted node pointer is returned. */ +hnswNode *hnsw_insert(HNSW *index, const float *vector, const int8_t *qvector, float qrange, uint64_t id, void *value, int ef) { + /* Write lock. We acquire the write lock even for the prepare() + * operation (that is a read-only operation) since we want this function + * to don't fail in the check-and-set stage of commit(). + * + * Basically here we are using the optimistic API in a non-optimistinc + * way in order to have a single insertion code in the implementation. */ + if (pthread_rwlock_wrlock(&index->global_lock) != 0) return NULL; + + // Prepare the insertion - note we pass slot 0 since we're single threaded. + InsertContext *ctx = hnsw_prepare_insert_nolock(index, vector, qvector, + qrange, id, value, 0, ef); + if (!ctx) { + pthread_rwlock_unlock(&index->global_lock); + return NULL; + } + + // Commit the prepared insertion without version checking. + hnswNode *node = hnsw_commit_insert_nolock(index, ctx); + + // Release write lock and return our node pointer. + pthread_rwlock_unlock(&index->global_lock); + return node; +} + +/* Helper function for qsort call in hnsw_should_reuse_node(). */ +static int compare_floats(const float *a, const float *b) { + if (*a < *b) return 1; + if (*a > *b) return -1; + return 0; +} + +/* This function determines if a node can be reused with a new vector by: + * + * 1. Computing average of worst 25% of current distances. + * 2. Checking if at least 50% of new distances stay below this threshold. + * 3. Requiring a minimum number of links for the check to be meaningful. + * + * This check is useful when we want to just update a node that already + * exists in the graph. Often the new vector is a learned embedding generated + * by some model, and the embedding represents some document that perhaps + * changed just slightly compared to the past, so the new embedding will + * be very nearby. We need to find a way do determine if the current node + * neighbors (practically speaking its location in the grapb) are good + * enough even with the new vector. + * + * XXX: this function needs improvements: successive updates to the same + * node with more and more distant vectors will make the node drift away + * from its neighbors. One of the additional metrics used could be + * neighbor-to-neighbor distance, that represents a more absolute check + * of fit for the new vector. */ +int hnsw_should_reuse_node(HNSW *index, hnswNode *node, int is_normalized, const float *new_vector) { + /* Step 1: Not enough links? Advice to avoid reuse. */ + const uint32_t min_links_for_reuse = 4; + uint32_t layer0_connections = node->layers[0].num_links; + if (layer0_connections < min_links_for_reuse) return 0; + + /* Step2: get all current distances and run our heuristic. */ + float *old_distances = hmalloc(sizeof(float) * layer0_connections); + if (!old_distances) return 0; + + // Temporary node with the new vector, to simplify the next logic. + hnswNode tmp_node; + if (hnsw_init_tmp_node(index,&tmp_node,is_normalized,new_vector) == 0) { + hfree(old_distances); + return 0; + } + + /* Get old dinstances and sort them to access the 25% worst + * (bigger) ones. */ + for (uint32_t i = 0; i < layer0_connections; i++) { + old_distances[i] = hnsw_distance(index, node, node->layers[0].links[i]); + } + qsort(old_distances, layer0_connections, sizeof(float), + (int (*)(const void*, const void*))(&compare_floats)); + + uint32_t count = (layer0_connections+3)/4; // 25% approx to larger int. + if (count > layer0_connections) count = layer0_connections; // Futureproof. + float worst_avg = 0; + + // Compute average of 25% worst dinstances. + for (uint32_t i = 0; i < count; i++) worst_avg += old_distances[i]; + worst_avg /= count; + hfree(old_distances); + + // Count how many new distances stay below the threshold. + uint32_t good_distances = 0; + for (uint32_t i = 0; i < layer0_connections; i++) { + float new_dist = hnsw_distance(index, &tmp_node, node->layers[0].links[i]); + if (new_dist <= worst_avg) good_distances++; + } + hnsw_free_tmp_node(&tmp_node,new_vector); + + /* At least 50% of the nodes should pass our quality test, for the + * node to be reused. */ + return good_distances >= layer0_connections/2; +} + +/* ============================= Serialization ============================== + * + * TO SERIALIZE + * ============ + * + * To serialize on disk, you need to persist the vector dimension, number + * of elements, and the quantization type index->quant_type. These are + * global values for the whole index. + * + * Then, to serialize each node: + * + * call hnsw_serialize_node() with each node you find in the linked list + * of nodes, starting at index->head (each node has a next pointer). + * The function will return an hnswSerNode structure, you will need + * to store the following on disk (for each node): + * + * - The sernode->vector data, that is sernode->vector_size bytes. + * - The sernode->params array, that points to an array of uint64_t + * integers. There are sernode->params_count total items. These + * parameters contain everything there is to need about your node: how + * many levels it has, its ID, the list of neighbors for each level (as node + * IDs), and so forth. + * + * You need to to save your own node->value in some way as well, but it already + * belongs to the user of the API, since, for this library, it's just a pointer, + * so the user should know how to serialized its private data. + * + * RELOADING FROM DISK / NET + * ========================= + * + * When reloading nodes, you first load the index vector dimension and + * quantization type, and create the index with: + * + * HNSW *hnsw_new(uint32_t vector_dim, uint32_t quant_type); + * + * Then you load back, for each node (you stored how many nodes you had) + * the vector and the params array / count. + * You also load the value associated with your node. + * + * At this point you add back the loaded elements into the index with: + * + * hnsw_insert_serialized(HNSW *index, void *vector, uint64_t params, + * uint32_t params_len, void *value); + * + * Once you added all the nodes back, you need to resolve the pointers + * (since so far they are added just with the node IDs as reference), so + * you call: + * + * hnsw_deserialize_index(index); + * + * The index is now ready to be used like if it has been always in memory. + * + * DESIGN NOTES + * ============ + * + * Why this API does not just give you a binary blob to save? Because in + * many systems (and in Redis itself) to save integers / floats can have + * more interesting encodings that just storing a 64 bit value. Many vector + * indexes will be small, and their IDs will be small numbers, so the storage + * system can exploit that and use less disk space, less network bandwidth + * and so forth. + * + * How is the data stored in these arrays of numbers? Oh well, we have + * things that are obviously numbers like node ID, number of levels for the + * node and so forth. Also each of our nodes have an unique incremental ID, + * so we can store a node set of links in terms of linked node IDs. This + * data is put directly in the loaded node pointer space! We just cast the + * integer to the pointer (so THIS IS NOT SAFE for 32 bit systems). Then + * we want to translate such IDs into pointers. To do that, we build an + * hash table, then scan all the nodes again and fix all the links converting + * the ID to the pointer. */ + +/* Return the serialized node information as specified in the top comment + * above. Note that the returned information is true as long as the node + * provided is not deleted or modified, so this function should be called + * when there are no concurrent writes. + * + * The function hnsw_serialize_node() should be called in order to + * free the result of this function. */ +hnswSerNode *hnsw_serialize_node(HNSW *index, hnswNode *node) { + /* The first step is calculating the number of uint64_t parameters + * that we need in order to serialize the node. */ + uint32_t num_params = 0; + num_params += 2; // node ID, number of layers. + for (uint32_t i = 0; i <= node->level; i++) { + num_params += 2; // max_links and num_links info for this layer. + num_params += node->layers[i].num_links; // The IDs of linked nodes. + } + + /* We use another 64bit value to store two floats that are about + * the vector: l2 and quantization range (that is only used if the + * vector is quantized). */ + num_params++; + + /* Allocate the return object and the parameters array. */ + hnswSerNode *sn = hmalloc(sizeof(hnswSerNode)); + if (sn == NULL) return NULL; + sn->params = hmalloc(sizeof(uint64_t)*num_params); + if (sn->params == NULL) { + hfree(sn); + return NULL; + } + + /* Fill data. */ + sn->params_count = num_params; + sn->vector = node->vector; + sn->vector_size = hnsw_quants_bytes(index); + + uint32_t param_idx = 0; + sn->params[param_idx++] = node->id; + sn->params[param_idx++] = node->level; + for (uint32_t i = 0; i <= node->level; i++) { + sn->params[param_idx++] = node->layers[i].num_links; + sn->params[param_idx++] = node->layers[i].max_links; + for (uint32_t j = 0; j < node->layers[i].num_links; j++) { + sn->params[param_idx++] = node->layers[i].links[j]->id; + } + } + uint64_t l2_and_range = 0; + unsigned char *aux = (unsigned char*)&l2_and_range; + memcpy(aux,&node->l2,sizeof(float)); + memcpy(aux+4,&node->quants_range,sizeof(float)); + sn->params[param_idx++] = l2_and_range; + + /* Better safe than sorry: */ + assert(param_idx == num_params); + return sn; +} + +/* This is needed in order to free hnsw_serialize_node() returned + * structure. */ +void hnsw_free_serialized_node(hnswSerNode *sn) { + hfree(sn->params); + hfree(sn); +} + +/* Load a serialized node. See the top comment in this section of code + * for the documentation about how to use this. + * + * The function returns NULL both on out of memory and if the remaining + * parameters length does not match the number of links or other items + * to load. */ +hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, uint32_t params_len, void *value) +{ + if (params_len < 2) return NULL; + + uint64_t id = params[0]; + uint32_t level = params[1]; + + /* Keep track of maximum ID seen while loading. */ + if (id >= index->last_id) index->last_id = id; + + /* Create node, passing vector data directly based on quantization type. */ + hnswNode *node; + if (index->quant_type == HNSW_QUANT_Q8) { + node = hnsw_node_new(index, id, NULL, vector, 0, level); + } else { + node = hnsw_node_new(index, id, vector, NULL, 0, level); + } + if (!node) return NULL; + + /* Load params array into the node. */ + uint32_t param_idx = 2; + for (uint32_t i = 0; i <= level; i++) { + /* Sanity check. */ + if (param_idx + 2 > params_len) { + hnsw_node_free(node); + return NULL; + } + + uint32_t num_links = params[param_idx++]; + uint32_t max_links = params[param_idx++]; + + /* If max_links is larger than current allocation, reallocate. + * It could happen in select_neighbors() that we over-allocate the + * node under very unlikely to happen conditions. */ + if (max_links > node->layers[i].max_links) { + hnswNode **new_links = hrealloc(node->layers[i].links, + sizeof(hnswNode*) * max_links); + if (!new_links) { + hnsw_node_free(node); + return NULL; + } + node->layers[i].links = new_links; + node->layers[i].max_links = max_links; + } + node->layers[i].num_links = num_links; + + /* Sanity check. */ + if (param_idx + num_links > params_len) { + hnsw_node_free(node); + return NULL; + } + + /* Fill links for this layer with the IDs. Note that this + * is going to not work in 32 bit systems. Deleting / adding-back + * nodes can produce IDs larger than 2^32-1 even if we can't never + * fit more than 2^32 nodes in a 32 bit system. */ + for (uint32_t j = 0; j < num_links; j++) + node->layers[i].links[j] = (hnswNode*)params[param_idx++]; + } + + /* Get l2 and quantization range. */ + if (param_idx >= params_len) { + hnsw_node_free(node); + return NULL; + } + uint64_t l2_and_range = params[param_idx]; + unsigned char *aux = (unsigned char*)&l2_and_range; + memcpy(&node->l2, aux, sizeof(float)); + memcpy(&node->quants_range, aux+4, sizeof(float)); + + node->value = value; + hnsw_add_node(index, node); + + /* Keep track of higher node level and set the entry point to the + * greatest level node seen so far: thanks to this check we don't + * need to remember what our entry point was during serialization. */ + if (index->enter_point == NULL || level > index->max_level) { + index->max_level = level; + index->enter_point = node; + } + return node; +} + +/* Integer hashing, used by hnsw_deserialize_index(). + * MurmurHash3's 64-bit finalizer function. */ +uint64_t hnsw_hash_node_id(uint64_t id) { + id ^= id >> 33; + id *= 0xff51afd7ed558ccd; + id ^= id >> 33; + id *= 0xc4ceb9fe1a85ec53; + id ^= id >> 33; + return id; +} + +/* Fix pointers of neighbors nodes: after loading the serialized nodes, the + * neighbors links are just IDs (casted to pointers), instead of the actual + * pointers. We need to resolve IDs into pointers. + * + * Return 0 on error (out of memory or some ID that can't be resolved), 1 on + * success. */ +int hnsw_deserialize_index(HNSW *index) { + /* We will use simple linear probing, so over-allocating is a good + * idea: anyway this flat array of pointers will consume a fraction + * of the memory of the loaded index. */ + uint64_t min_size = index->node_count*2; + uint64_t table_size = 1; + while(table_size < min_size) table_size <<= 1; + + hnswNode **table = hmalloc(sizeof(hnswNode*) * table_size); + if (table == NULL) return 0; + memset(table,0,sizeof(hnswNode*) * table_size); + + /* First pass: populate the ID -> pointer hash table. */ + hnswNode *node = index->head; + while(node) { + uint64_t bucket = hnsw_hash_node_id(node->id) & (table_size-1); + for (uint64_t j = 0; j < table_size; j++) { + if (table[bucket] == NULL) { + table[bucket] = node; + break; + } + bucket = (bucket+1) & (table_size-1); + } + node = node->next; + } + + /* Second pass: fix pointers of all the neighbors links. */ + node = index->head; // Rewind. + while(node) { + for (uint32_t i = 0; i <= node->level; i++) { + for (uint32_t j = 0; j < node->layers[i].num_links; j++) { + uint64_t linked_id = (uint64_t) node->layers[i].links[j]; + uint64_t bucket = hnsw_hash_node_id(linked_id) & (table_size-1); + hnswNode *neighbor = NULL; + for (uint64_t k = 0; k < table_size; k++) { + if (table[bucket] && table[bucket]->id == linked_id) { + neighbor = table[bucket]; + break; + } + bucket = (bucket+1) & (table_size-1); + } + if (neighbor == NULL) { + /* Unresolved link! Either a bug in this code + * or broken serialization data. */ + hfree(table); + return 0; + } + node->layers[i].links[j] = neighbor; + } + } + node = node->next; + } + hfree(table); + return 1; +} + +/* ================================ Iterator ================================ */ + +/* Get a cursor that can be used as argument of hnsw_cursor_next() to iterate + * all the elements that remain there from the start to the end of the + * iteration, excluding newly added elements. + * + * The function returns NULL on out of memory. */ +hnswCursor *hnsw_cursor_init(HNSW *index) { + hnswCursor *cursor = hmalloc(sizeof(*cursor)); + if (cursor == NULL) return NULL; + cursor->next = index->cursors; + cursor->current = index->head; + index->cursors = cursor; + return cursor; +} + +/* Free the cursor. Can be called both at the end of the iteration, when + * hnsw_cursor_next() returned NULL, or before. */ +void hnsw_cursor_free(HNSW *index, hnswCursor *cursor) { + hnswCursor *x = index->cursors; + hnswCursor *prev = NULL; + while(x) { + if (x == cursor) { + if (prev) + prev->next = cursor->next; + else + index->cursors = cursor->next; + hfree(cursor); + return; + } + x = x->next; + } +} + +/* Return the next element of the HNSW. See hnsw_cursor_init() for + * the guarantees of the function. */ +hnswNode *hnsw_cursor_next(HNSW *index, hnswCursor *cursor) { + (void)index; // Unused but future proof to have. + hnswNode *ret = cursor->current; + if (ret) cursor->current = ret->next; + return ret; +} + +/* Called by hnsw_unlink_node() if there is at least an active cursor. + * Will scan the cursors to see if any cursor is going to yeld this + * one, and in this case, updates the current element to the next. */ +void hnsw_cursor_element_deleted(HNSW *index, hnswNode *deleted) { + hnswCursor *x = index->cursors; + while(x) { + if (x->current == deleted) x->current = deleted->next; + x = x->next; + } +} + +/* ============================ Debugging stuff ============================= */ + +/* Show stats about nodes connections. */ +void hnsw_print_stats(HNSW *index) { + if (!index || !index->head) { + printf("Empty index or NULL pointer passed\n"); + return; + } + + long long total_links = 0; + int min_links = -1; // We'll set this to first node's count. + int isolated_nodes = 0; + uint32_t node_count = 0; + + // Iterate through all nodes using the linked list. + hnswNode *current = index->head; + while (current) { + // Count total links for this node across all layers. + int node_total_links = 0; + for (uint32_t layer = 0; layer <= current->level; layer++) + node_total_links += current->layers[layer].num_links; + + // Update statistics. + total_links += node_total_links; + + // Initialize or update minimum links. + if (min_links == -1 || node_total_links < min_links) { + min_links = node_total_links; + } + + // Check if node is isolated (no links at all). + if (node_total_links == 0) isolated_nodes++; + + node_count++; + current = current->next; + } + + // Print statistics + printf("HNSW Graph Statistics:\n"); + printf("----------------------\n"); + printf("Total nodes: %u\n", node_count); + if (node_count > 0) { + printf("Average links per node: %.2f\n", + (float)total_links / node_count); + printf("Minimum links in a single node: %d\n", min_links); + printf("Number of isolated nodes: %d (%.1f%%)\n", + isolated_nodes, + (float)isolated_nodes * 100 / node_count); + } +} + +/* Validate graph connectivity and link reciprocity. Takes pointers to store results: + * - connected_nodes: will contain number of reachable nodes from entry point. + * - reciprocal_links: will contain 1 if all links are reciprocal, 0 otherwise. + * Returns 0 on success, -1 on error (NULL parameters and such). + */ +int hnsw_validate_graph(HNSW *index, uint64_t *connected_nodes, int *reciprocal_links) { + if (!index || !connected_nodes || !reciprocal_links) return -1; + if (!index->enter_point) { + *connected_nodes = 0; + *reciprocal_links = 1; // Empty graph is valid. + return 0; + } + + // Initialize connectivity check. + index->current_epoch[0]++; + *connected_nodes = 0; + *reciprocal_links = 1; + + // Initialize node stack. + uint64_t stack_size = index->node_count; + hnswNode **stack = hmalloc(sizeof(hnswNode*) * stack_size); + if (!stack) return -1; + uint64_t stack_top = 0; + + // Start from entry point. + index->enter_point->visited_epoch[0] = index->current_epoch[0]; + (*connected_nodes)++; + stack[stack_top++] = index->enter_point; + + // Process all reachable nodes. + while (stack_top > 0) { + hnswNode *current = stack[--stack_top]; + + // Explore all neighbors at each level. + for (uint32_t level = 0; level <= current->level; level++) { + for (uint64_t i = 0; i < current->layers[level].num_links; i++) { + hnswNode *neighbor = current->layers[level].links[i]; + + // Check reciprocity. + int found_backlink = 0; + for (uint64_t j = 0; j < neighbor->layers[level].num_links; j++) { + if (neighbor->layers[level].links[j] == current) { + found_backlink = 1; + break; + } + } + if (!found_backlink) { + *reciprocal_links = 0; + } + + // If we haven't visited this neighbor yet. + if (neighbor->visited_epoch[0] != index->current_epoch[0]) { + neighbor->visited_epoch[0] = index->current_epoch[0]; + (*connected_nodes)++; + if (stack_top < stack_size) { + stack[stack_top++] = neighbor; + } else { + // This should never happen in a valid graph. + hfree(stack); + return -1; + } + } + } + } + } + + hfree(stack); + + // Now scan for unreachable nodes and print debug info. + printf("\nUnreachable nodes debug information:\n"); + printf("=====================================\n"); + + hnswNode *current = index->head; + while (current) { + if (current->visited_epoch[0] != index->current_epoch[0]) { + printf("\nUnreachable node found:\n"); + printf("- Node pointer: %p\n", (void*)current); + printf("- Node ID: %llu\n", (unsigned long long)current->id); + printf("- Node level: %u\n", current->level); + + // Print info about all its links at each level. + for (uint32_t level = 0; level <= current->level; level++) { + printf(" Level %u links (%u):\n", level, + current->layers[level].num_links); + for (uint64_t i = 0; i < current->layers[level].num_links; i++) { + hnswNode *neighbor = current->layers[level].links[i]; + // Check reciprocity for this specific link + int found_backlink = 0; + for (uint64_t j = 0; j < neighbor->layers[level].num_links; j++) { + if (neighbor->layers[level].links[j] == current) { + found_backlink = 1; + break; + } + } + printf(" - Link %llu: pointer=%p, id=%llu, visited=%s,recpr=%s\n", + (unsigned long long)i, (void*)neighbor, + (unsigned long long)neighbor->id, + neighbor->visited_epoch[0] == index->current_epoch[0] ? + "yes" : "no", + found_backlink ? "yes" : "no"); + } + } + } + current = current->next; + } + + printf("Total connected nodes: %llu\n", (unsigned long long)*connected_nodes); + printf("All links are bi-directiona? %s\n", (*reciprocal_links)?"yes":"no"); + return 0; +} + +/* Test graph recall ability by verifying each node can be found searching + * for its own vector. This helps validate that the majority of nodes are + * properly connected and easily reachable in the graph structure. Every + * unreachable node is reported. + * + * Normally only a small percentage of nodes will be not reachable when + * visited. This is expected and part of the statistical properties + * of HNSW. This happens especially with entries that have an ambiguous + * meaning in the represented space, and are across two or multiple clusters + * of items. + * + * The function works by: + * 1. Iterating through all nodes in the linked list + * 2. Using each node's vector to perform a search with specified EF + * 3. Verifying the node can find itself as nearest neighbor + * 4. Collecting and reporting statistics about reachability + * + * This is just a debugging function that reports stuff in the standard + * output, part of the implementation because this kind of functions + * provide some visiblity on what happens inside the HNSW. + */ +void hnsw_test_graph_recall(HNSW *index, int test_ef, int verbose) { + // Stats + uint32_t total_nodes = 0; + uint32_t unreachable_nodes = 0; + uint32_t perfectly_reachable = 0; // Node finds itself as first result + + // For storing search results + hnswNode **neighbors = hmalloc(sizeof(hnswNode*) * test_ef); + float *distances = hmalloc(sizeof(float) * test_ef); + float *test_vector = hmalloc(sizeof(float) * index->vector_dim); + if (!neighbors || !distances || !test_vector) { + hfree(neighbors); + hfree(distances); + hfree(test_vector); + return; + } + + // Get a read slot for searching (even if it's highly unlikely that + // this test will be run threaded...). + int slot = hnsw_acquire_read_slot(index); + if (slot < 0) { + hfree(neighbors); + hfree(distances); + return; + } + + printf("\nTesting graph recall\n"); + printf("====================\n"); + + // Process one node at a time using the linked list + hnswNode *current = index->head; + while (current) { + total_nodes++; + + // If using quantization, we need to reconstruct the normalized vector + if (index->quant_type == HNSW_QUANT_Q8) { + int8_t *quants = current->vector; + // Reconstruct normalized vector from quantized data + for (uint32_t j = 0; j < index->vector_dim; j++) { + test_vector[j] = (quants[j] * current->quants_range) / 127; + } + } else if (index->quant_type == HNSW_QUANT_NONE) { + memcpy(test_vector,current->vector,sizeof(float)*index->vector_dim); + } else { + assert(0 && "Quantization type not supported."); + } + + // Search using the node's own vector with high ef + int found = hnsw_search(index, test_vector, test_ef, neighbors, + distances, slot, 1); + + if (found == 0) continue; // Empty HNSW? + + // Look for the node itself in the results + int found_self = 0; + int self_position = -1; + for (int i = 0; i < found; i++) { + if (neighbors[i] == current) { + found_self = 1; + self_position = i; + break; + } + } + + if (!found_self || self_position != 0) { + unreachable_nodes++; + if (verbose) { + if (!found_self) + printf("\nNode %s cannot find itself:\n", (char*)current->value); + else + printf("\nNode %s is not top result:\n", (char*)current->value); + printf("- Node ID: %llu\n", (unsigned long long)current->id); + printf("- Node level: %u\n", current->level); + printf("- Found %d neighbors but self not among them\n", found); + printf("- Closest neighbor distance: %f\n", distances[0]); + printf("- Neighbors: "); + for (uint32_t i = 0; i < current->layers[0].num_links; i++) { + printf("%s ", (char*)current->layers[0].links[i]->value); + } + printf("\n"); + printf("\nFound instead: "); + for (int j = 0; j < found && j < 10; j++) { + printf("%s ", (char*)neighbors[j]->value); + } + printf("\n"); + } + } else { + perfectly_reachable++; + } + current = current->next; + } + + // Release read slot + hnsw_release_read_slot(index, slot); + + // Free resources + hfree(neighbors); + hfree(distances); + hfree(test_vector); + + // Print final statistics + printf("Total nodes tested: %u\n", total_nodes); + printf("Perfectly reachable nodes: %u (%.1f%%)\n", + perfectly_reachable, + total_nodes ? (float)perfectly_reachable * 100 / total_nodes : 0); + printf("Unreachable/suboptimal nodes: %u (%.1f%%)\n", + unreachable_nodes, + total_nodes ? (float)unreachable_nodes * 100 / total_nodes : 0); +} diff --git a/hnsw.h b/hnsw.h new file mode 100644 index 000000000..3d104cc5e --- /dev/null +++ b/hnsw.h @@ -0,0 +1,158 @@ +/* + * HNSW (Hierarchical Navigable Small World) Implementation + * Based on the paper by Yu. A. Malkov, D. A. Yashunin + * + * Copyright(C) 2024 Salvatore Sanfilippo. All Rights Reserved. + */ + +#ifndef HNSW_H +#define HNSW_H + +#include +#include + +#define HNSW_MAX_THREADS 32 /* Maximum number of concurrent threads */ + +/* Quantization types you can enable at creation time in hnsw_new() */ +#define HNSW_QUANT_NONE 0 // No quantization. +#define HNSW_QUANT_Q8 1 // Q8 quantization. +#define HNSW_QUANT_BIN 2 // Binary quantization. + +/* Layer structure for HNSW nodes. Each node will have from one to a few + * of this depending on its level. */ +typedef struct { + struct hnswNode **links; /* Array of neighbors for this layer */ + uint32_t num_links; /* Number of used links */ + uint32_t max_links; /* Maximum links for this layer. We may + * reallocate the node in very particular + * conditions in order to allow linking of + * new inserted nodes, so this may change + * dynamically for a small set of nodes. */ + float worst_distance; /* Distance to the worst neighbor */ + uint32_t worst_idx; /* Index of the worst neighbor */ +} hnswNodeLayer; + +/* Node structure for HNSW graph */ +typedef struct hnswNode { + uint32_t level; /* Node's maximum level */ + uint64_t id; /* Unique identifier, may be useful in order to + * have a bitmap of visited notes to use as + * alternative to epoch / visited_epoch. + * Also used in serialization in order to retain + * links specifying IDs. */ + void *vector; /* The vector, quantized or not. */ + float quants_range; /* Quantization range for this vector: + * min/max values will be in the range + * -quants_range, +quants_range */ + float l2; /* L2 before normalization. */ + + /* Last time (epoch) this node was visited. We need one per thread. + * This avoids having a different data structure where we track + * visited nodes, but costs memory per node. */ + uint64_t visited_epoch[HNSW_MAX_THREADS]; + + void *value; /* Associated value */ + struct hnswNode *prev, *next; /* Prev/Next node in the list starting at + * HNSW->head. */ + + /* Links (and links info) per each layer. Note that this is part + * of the node allocation to be more cache friendly: reliable 3% speedup + * on Apple silicon, and does not make anything more complex. */ + hnswNodeLayer layers[]; +} hnswNode; + +/* It is possible to navigate an HNSW with a cursor that guarantees + * visiting all the elements that remain in the HNSW from the start to the + * end of the process (but not the new ones, so that the process will + * eventually finish). Check hnsw_cursor_init(), hnsw_cursor_next() and + * hnsw_cursor_free(). */ +typedef struct hnswCursor { + hnswNode *current; // Element to report when hnsw_cursor_next() is called. + struct hnswCursor *next; // Next cursor active. +} hnswCursor; + +/* Main HNSW index structure */ +typedef struct HNSW { + hnswNode *enter_point; /* Entry point for the graph */ + uint32_t max_level; /* Current maximum level in the graph */ + uint32_t vector_dim; /* Dimensionality of stored vectors */ + uint64_t node_count; /* Total number of nodes */ + _Atomic uint64_t last_id; /* Last node ID used */ + uint64_t current_epoch[HNSW_MAX_THREADS]; /* Current epoch for visit tracking */ + hnswNode *head; /* Linked list of nodes. Last first */ + + /* We have two locks here: + * 1. A global_lock that is used to perform write operations blocking all + * the readers. + * 2. One mutex per epoch slot, in order for read operations to acquire + * a lock on a specific slot to use epochs tracking of visited nodes. */ + pthread_rwlock_t global_lock; /* Global read-write lock */ + pthread_mutex_t slot_locks[HNSW_MAX_THREADS]; /* Per-slot locks */ + + _Atomic uint32_t next_slot; /* Next thread slot to try */ + _Atomic uint64_t version; /* Version for optimistic concurrency, this is + * incremented on deletions and entry point + * updates. */ + uint32_t quant_type; /* Quantization used. HNSW_QUANT_... */ + hnswCursor *cursors; +} HNSW; + +/* Serialized node. This structure is used as return value of + * hnsw_serialize_node(). */ +typedef struct hnswSerNode { + void *vector; + uint32_t vector_size; + uint64_t *params; + uint32_t params_count; +} hnswSerNode; + +/* Insert preparation context */ +typedef struct InsertContext InsertContext; + +/* Core HNSW functions */ +HNSW *hnsw_new(uint32_t vector_dim, uint32_t quant_type); +void hnsw_free(HNSW *index,void(*free_value)(void*value)); +void hnsw_node_free(hnswNode *node); +void hnsw_print_stats(HNSW *index); +hnswNode *hnsw_insert(HNSW *index, const float *vector, const int8_t *qvector, + float qrange, uint64_t id, void *value, int ef); +int hnsw_search(HNSW *index, const float *query, uint32_t k, + hnswNode **neighbors, float *distances, uint32_t slot, + int query_vector_is_normalized); +void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec); +void hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value)); + +/* Thread safety functions. */ +int hnsw_acquire_read_slot(HNSW *index); +void hnsw_release_read_slot(HNSW *index, int slot); + +/* Optimistic insertion API. */ +InsertContext *hnsw_prepare_insert(HNSW *index, const float *vector, const int8_t *qvector, float qrange, uint64_t id, void *value, int ef); +hnswNode *hnsw_try_commit_insert(HNSW *index, InsertContext *ctx); +void hnsw_free_insert_context(InsertContext *ctx); + +/* Serialization. */ +hnswSerNode *hnsw_serialize_node(HNSW *index, hnswNode *node); +void hnsw_free_serialized_node(hnswSerNode *sn); +hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, uint32_t params_len, void *value); +int hnsw_deserialize_index(HNSW *index); + +// Helper function in case the user wants to directly copy +// the vector bytes. +uint32_t hnsw_quants_bytes(HNSW *index); + +/* Cursors. */ +hnswCursor *hnsw_cursor_init(HNSW *index); +void hnsw_cursor_free(HNSW *index, hnswCursor *cursor); +hnswNode *hnsw_cursor_next(HNSW *index, hnswCursor *cursor); + +/* Allocator selection. */ +void hnsw_set_allocator(void (*free_ptr)(void*), void *(*malloc_ptr)(size_t), + void *(*realloc_ptr)(void*, size_t)); + +/* Testing. */ +int hnsw_validate_graph(HNSW *index, uint64_t *connected_nodes, int *reciprocal_links); +void hnsw_test_graph_recall(HNSW *index, int test_ef, int verbose); +float hnsw_distance(HNSW *index, hnswNode *a, hnswNode *b); + +#endif /* HNSW_H */ diff --git a/redismodule.h b/redismodule.h new file mode 100644 index 000000000..b84913b1e --- /dev/null +++ b/redismodule.h @@ -0,0 +1,1704 @@ +#ifndef REDISMODULE_H +#define REDISMODULE_H + +#include +#include +#include +#include + + +typedef struct RedisModuleString RedisModuleString; +typedef struct RedisModuleKey RedisModuleKey; + +/* -------------- Defines NOT common between core and modules ------------- */ + +#if defined REDISMODULE_CORE +/* Things only defined for the modules core (server), not exported to modules + * that include this file. */ + +#define RedisModuleString robj + +#endif /* defined REDISMODULE_CORE */ + +#if !defined REDISMODULE_CORE && !defined REDISMODULE_CORE_MODULE +/* Things defined for modules, but not for core-modules. */ + +typedef long long mstime_t; +typedef long long ustime_t; + +#endif /* !defined REDISMODULE_CORE && !defined REDISMODULE_CORE_MODULE */ + +/* ---------------- Defines common between core and modules --------------- */ + +/* Error status return values. */ +#define REDISMODULE_OK 0 +#define REDISMODULE_ERR 1 + +/* Module Based Authentication status return values. */ +#define REDISMODULE_AUTH_HANDLED 0 +#define REDISMODULE_AUTH_NOT_HANDLED 1 + +/* API versions. */ +#define REDISMODULE_APIVER_1 1 + +/* Version of the RedisModuleTypeMethods structure. Once the RedisModuleTypeMethods + * structure is changed, this version number needs to be changed synchronistically. */ +#define REDISMODULE_TYPE_METHOD_VERSION 5 + +/* API flags and constants */ +#define REDISMODULE_READ (1<<0) +#define REDISMODULE_WRITE (1<<1) + +/* RedisModule_OpenKey extra flags for the 'mode' argument. + * Avoid touching the LRU/LFU of the key when opened. */ +#define REDISMODULE_OPEN_KEY_NOTOUCH (1<<16) +/* Don't trigger keyspace event on key misses. */ +#define REDISMODULE_OPEN_KEY_NONOTIFY (1<<17) +/* Don't update keyspace hits/misses counters. */ +#define REDISMODULE_OPEN_KEY_NOSTATS (1<<18) +/* Avoid deleting lazy expired keys. */ +#define REDISMODULE_OPEN_KEY_NOEXPIRE (1<<19) +/* Avoid any effects from fetching the key */ +#define REDISMODULE_OPEN_KEY_NOEFFECTS (1<<20) +/* Allow access expired key that haven't deleted yet */ +#define REDISMODULE_OPEN_KEY_ACCESS_EXPIRED (1<<21) + +/* Mask of all REDISMODULE_OPEN_KEY_* values. Any new mode should be added to this list. + * Should not be used directly by the module, use RM_GetOpenKeyModesAll instead. + * Located here so when we will add new modes we will not forget to update it. */ +#define _REDISMODULE_OPEN_KEY_ALL REDISMODULE_READ | REDISMODULE_WRITE | REDISMODULE_OPEN_KEY_NOTOUCH | REDISMODULE_OPEN_KEY_NONOTIFY | REDISMODULE_OPEN_KEY_NOSTATS | REDISMODULE_OPEN_KEY_NOEXPIRE | REDISMODULE_OPEN_KEY_NOEFFECTS | REDISMODULE_OPEN_KEY_ACCESS_EXPIRED + +/* List push and pop */ +#define REDISMODULE_LIST_HEAD 0 +#define REDISMODULE_LIST_TAIL 1 + +/* Key types. */ +#define REDISMODULE_KEYTYPE_EMPTY 0 +#define REDISMODULE_KEYTYPE_STRING 1 +#define REDISMODULE_KEYTYPE_LIST 2 +#define REDISMODULE_KEYTYPE_HASH 3 +#define REDISMODULE_KEYTYPE_SET 4 +#define REDISMODULE_KEYTYPE_ZSET 5 +#define REDISMODULE_KEYTYPE_MODULE 6 +#define REDISMODULE_KEYTYPE_STREAM 7 + +/* Reply types. */ +#define REDISMODULE_REPLY_UNKNOWN -1 +#define REDISMODULE_REPLY_STRING 0 +#define REDISMODULE_REPLY_ERROR 1 +#define REDISMODULE_REPLY_INTEGER 2 +#define REDISMODULE_REPLY_ARRAY 3 +#define REDISMODULE_REPLY_NULL 4 +#define REDISMODULE_REPLY_MAP 5 +#define REDISMODULE_REPLY_SET 6 +#define REDISMODULE_REPLY_BOOL 7 +#define REDISMODULE_REPLY_DOUBLE 8 +#define REDISMODULE_REPLY_BIG_NUMBER 9 +#define REDISMODULE_REPLY_VERBATIM_STRING 10 +#define REDISMODULE_REPLY_ATTRIBUTE 11 +#define REDISMODULE_REPLY_PROMISE 12 + +/* Postponed array length. */ +#define REDISMODULE_POSTPONED_ARRAY_LEN -1 /* Deprecated, please use REDISMODULE_POSTPONED_LEN */ +#define REDISMODULE_POSTPONED_LEN -1 + +/* Expire */ +#define REDISMODULE_NO_EXPIRE -1 + +/* Sorted set API flags. */ +#define REDISMODULE_ZADD_XX (1<<0) +#define REDISMODULE_ZADD_NX (1<<1) +#define REDISMODULE_ZADD_ADDED (1<<2) +#define REDISMODULE_ZADD_UPDATED (1<<3) +#define REDISMODULE_ZADD_NOP (1<<4) +#define REDISMODULE_ZADD_GT (1<<5) +#define REDISMODULE_ZADD_LT (1<<6) + +/* Hash API flags. */ +#define REDISMODULE_HASH_NONE 0 +#define REDISMODULE_HASH_NX (1<<0) +#define REDISMODULE_HASH_XX (1<<1) +#define REDISMODULE_HASH_CFIELDS (1<<2) +#define REDISMODULE_HASH_EXISTS (1<<3) +#define REDISMODULE_HASH_COUNT_ALL (1<<4) + +#define REDISMODULE_CONFIG_DEFAULT 0 /* This is the default for a module config. */ +#define REDISMODULE_CONFIG_IMMUTABLE (1ULL<<0) /* Can this value only be set at startup? */ +#define REDISMODULE_CONFIG_SENSITIVE (1ULL<<1) /* Does this value contain sensitive information */ +#define REDISMODULE_CONFIG_HIDDEN (1ULL<<4) /* This config is hidden in `config get ` (used for tests/debugging) */ +#define REDISMODULE_CONFIG_PROTECTED (1ULL<<5) /* Becomes immutable if enable-protected-configs is enabled. */ +#define REDISMODULE_CONFIG_DENY_LOADING (1ULL<<6) /* This config is forbidden during loading. */ + +#define REDISMODULE_CONFIG_MEMORY (1ULL<<7) /* Indicates if this value can be set as a memory value */ +#define REDISMODULE_CONFIG_BITFLAGS (1ULL<<8) /* Indicates if this value can be set as a multiple enum values */ + +/* StreamID type. */ +typedef struct RedisModuleStreamID { + uint64_t ms; + uint64_t seq; +} RedisModuleStreamID; + +/* StreamAdd() flags. */ +#define REDISMODULE_STREAM_ADD_AUTOID (1<<0) +/* StreamIteratorStart() flags. */ +#define REDISMODULE_STREAM_ITERATOR_EXCLUSIVE (1<<0) +#define REDISMODULE_STREAM_ITERATOR_REVERSE (1<<1) +/* StreamIteratorTrim*() flags. */ +#define REDISMODULE_STREAM_TRIM_APPROX (1<<0) + +/* Context Flags: Info about the current context returned by + * RM_GetContextFlags(). */ + +/* The command is running in the context of a Lua script */ +#define REDISMODULE_CTX_FLAGS_LUA (1<<0) +/* The command is running inside a Redis transaction */ +#define REDISMODULE_CTX_FLAGS_MULTI (1<<1) +/* The instance is a master */ +#define REDISMODULE_CTX_FLAGS_MASTER (1<<2) +/* The instance is a slave */ +#define REDISMODULE_CTX_FLAGS_SLAVE (1<<3) +/* The instance is read-only (usually meaning it's a slave as well) */ +#define REDISMODULE_CTX_FLAGS_READONLY (1<<4) +/* The instance is running in cluster mode */ +#define REDISMODULE_CTX_FLAGS_CLUSTER (1<<5) +/* The instance has AOF enabled */ +#define REDISMODULE_CTX_FLAGS_AOF (1<<6) +/* The instance has RDB enabled */ +#define REDISMODULE_CTX_FLAGS_RDB (1<<7) +/* The instance has Maxmemory set */ +#define REDISMODULE_CTX_FLAGS_MAXMEMORY (1<<8) +/* Maxmemory is set and has an eviction policy that may delete keys */ +#define REDISMODULE_CTX_FLAGS_EVICT (1<<9) +/* Redis is out of memory according to the maxmemory flag. */ +#define REDISMODULE_CTX_FLAGS_OOM (1<<10) +/* Less than 25% of memory available according to maxmemory. */ +#define REDISMODULE_CTX_FLAGS_OOM_WARNING (1<<11) +/* The command was sent over the replication link. */ +#define REDISMODULE_CTX_FLAGS_REPLICATED (1<<12) +/* Redis is currently loading either from AOF or RDB. */ +#define REDISMODULE_CTX_FLAGS_LOADING (1<<13) +/* The replica has no link with its master, note that + * there is the inverse flag as well: + * + * REDISMODULE_CTX_FLAGS_REPLICA_IS_ONLINE + * + * The two flags are exclusive, one or the other can be set. */ +#define REDISMODULE_CTX_FLAGS_REPLICA_IS_STALE (1<<14) +/* The replica is trying to connect with the master. + * (REPL_STATE_CONNECT and REPL_STATE_CONNECTING states) */ +#define REDISMODULE_CTX_FLAGS_REPLICA_IS_CONNECTING (1<<15) +/* THe replica is receiving an RDB file from its master. */ +#define REDISMODULE_CTX_FLAGS_REPLICA_IS_TRANSFERRING (1<<16) +/* The replica is online, receiving updates from its master. */ +#define REDISMODULE_CTX_FLAGS_REPLICA_IS_ONLINE (1<<17) +/* There is currently some background process active. */ +#define REDISMODULE_CTX_FLAGS_ACTIVE_CHILD (1<<18) +/* The next EXEC will fail due to dirty CAS (touched keys). */ +#define REDISMODULE_CTX_FLAGS_MULTI_DIRTY (1<<19) +/* Redis is currently running inside background child process. */ +#define REDISMODULE_CTX_FLAGS_IS_CHILD (1<<20) +/* The current client does not allow blocking, either called from + * within multi, lua, or from another module using RM_Call */ +#define REDISMODULE_CTX_FLAGS_DENY_BLOCKING (1<<21) +/* The current client uses RESP3 protocol */ +#define REDISMODULE_CTX_FLAGS_RESP3 (1<<22) +/* Redis is currently async loading database for diskless replication. */ +#define REDISMODULE_CTX_FLAGS_ASYNC_LOADING (1<<23) +/* Redis is starting. */ +#define REDISMODULE_CTX_FLAGS_SERVER_STARTUP (1<<24) + +/* Next context flag, must be updated when adding new flags above! +This flag should not be used directly by the module. + * Use RedisModule_GetContextFlagsAll instead. */ +#define _REDISMODULE_CTX_FLAGS_NEXT (1<<25) + +/* Keyspace changes notification classes. Every class is associated with a + * character for configuration purposes. + * NOTE: These have to be in sync with NOTIFY_* in server.h */ +#define REDISMODULE_NOTIFY_KEYSPACE (1<<0) /* K */ +#define REDISMODULE_NOTIFY_KEYEVENT (1<<1) /* E */ +#define REDISMODULE_NOTIFY_GENERIC (1<<2) /* g */ +#define REDISMODULE_NOTIFY_STRING (1<<3) /* $ */ +#define REDISMODULE_NOTIFY_LIST (1<<4) /* l */ +#define REDISMODULE_NOTIFY_SET (1<<5) /* s */ +#define REDISMODULE_NOTIFY_HASH (1<<6) /* h */ +#define REDISMODULE_NOTIFY_ZSET (1<<7) /* z */ +#define REDISMODULE_NOTIFY_EXPIRED (1<<8) /* x */ +#define REDISMODULE_NOTIFY_EVICTED (1<<9) /* e */ +#define REDISMODULE_NOTIFY_STREAM (1<<10) /* t */ +#define REDISMODULE_NOTIFY_KEY_MISS (1<<11) /* m (Note: This one is excluded from REDISMODULE_NOTIFY_ALL on purpose) */ +#define REDISMODULE_NOTIFY_LOADED (1<<12) /* module only key space notification, indicate a key loaded from rdb */ +#define REDISMODULE_NOTIFY_MODULE (1<<13) /* d, module key space notification */ +#define REDISMODULE_NOTIFY_NEW (1<<14) /* n, new key notification */ + +/* Next notification flag, must be updated when adding new flags above! +This flag should not be used directly by the module. + * Use RedisModule_GetKeyspaceNotificationFlagsAll instead. */ +#define _REDISMODULE_NOTIFY_NEXT (1<<15) + +#define REDISMODULE_NOTIFY_ALL (REDISMODULE_NOTIFY_GENERIC | REDISMODULE_NOTIFY_STRING | REDISMODULE_NOTIFY_LIST | REDISMODULE_NOTIFY_SET | REDISMODULE_NOTIFY_HASH | REDISMODULE_NOTIFY_ZSET | REDISMODULE_NOTIFY_EXPIRED | REDISMODULE_NOTIFY_EVICTED | REDISMODULE_NOTIFY_STREAM | REDISMODULE_NOTIFY_MODULE) /* A */ + +/* A special pointer that we can use between the core and the module to signal + * field deletion, and that is impossible to be a valid pointer. */ +#define REDISMODULE_HASH_DELETE ((RedisModuleString*)(long)1) + +/* Error messages. */ +#define REDISMODULE_ERRORMSG_WRONGTYPE "WRONGTYPE Operation against a key holding the wrong kind of value" + +#define REDISMODULE_POSITIVE_INFINITE (1.0/0.0) +#define REDISMODULE_NEGATIVE_INFINITE (-1.0/0.0) + +/* Cluster API defines. */ +#define REDISMODULE_NODE_ID_LEN 40 +#define REDISMODULE_NODE_MYSELF (1<<0) +#define REDISMODULE_NODE_MASTER (1<<1) +#define REDISMODULE_NODE_SLAVE (1<<2) +#define REDISMODULE_NODE_PFAIL (1<<3) +#define REDISMODULE_NODE_FAIL (1<<4) +#define REDISMODULE_NODE_NOFAILOVER (1<<5) + +#define REDISMODULE_CLUSTER_FLAG_NONE 0 +#define REDISMODULE_CLUSTER_FLAG_NO_FAILOVER (1<<1) +#define REDISMODULE_CLUSTER_FLAG_NO_REDIRECTION (1<<2) + +#define REDISMODULE_NOT_USED(V) ((void) V) + +/* Logging level strings */ +#define REDISMODULE_LOGLEVEL_DEBUG "debug" +#define REDISMODULE_LOGLEVEL_VERBOSE "verbose" +#define REDISMODULE_LOGLEVEL_NOTICE "notice" +#define REDISMODULE_LOGLEVEL_WARNING "warning" + +/* Bit flags for aux_save_triggers and the aux_load and aux_save callbacks */ +#define REDISMODULE_AUX_BEFORE_RDB (1<<0) +#define REDISMODULE_AUX_AFTER_RDB (1<<1) + +/* RM_Yield flags */ +#define REDISMODULE_YIELD_FLAG_NONE (1<<0) +#define REDISMODULE_YIELD_FLAG_CLIENTS (1<<1) + +/* RM_BlockClientOnKeysWithFlags flags */ +#define REDISMODULE_BLOCK_UNBLOCK_DEFAULT (0) +#define REDISMODULE_BLOCK_UNBLOCK_DELETED (1<<0) + +/* This type represents a timer handle, and is returned when a timer is + * registered and used in order to invalidate a timer. It's just a 64 bit + * number, because this is how each timer is represented inside the radix tree + * of timers that are going to expire, sorted by expire time. */ +typedef uint64_t RedisModuleTimerID; + +/* CommandFilter Flags */ + +/* Do filter RedisModule_Call() commands initiated by module itself. */ +#define REDISMODULE_CMDFILTER_NOSELF (1<<0) + +/* Declare that the module can handle errors with RedisModule_SetModuleOptions. */ +#define REDISMODULE_OPTIONS_HANDLE_IO_ERRORS (1<<0) + +/* When set, Redis will not call RedisModule_SignalModifiedKey(), implicitly in + * RedisModule_CloseKey, and the module needs to do that when manually when keys + * are modified from the user's perspective, to invalidate WATCH. */ +#define REDISMODULE_OPTION_NO_IMPLICIT_SIGNAL_MODIFIED (1<<1) + +/* Declare that the module can handle diskless async replication with RedisModule_SetModuleOptions. */ +#define REDISMODULE_OPTIONS_HANDLE_REPL_ASYNC_LOAD (1<<2) + +/* Declare that the module want to get nested key space notifications. + * If enabled, the module is responsible to break endless loop. */ +#define REDISMODULE_OPTIONS_ALLOW_NESTED_KEYSPACE_NOTIFICATIONS (1<<3) + +/* Next option flag, must be updated when adding new module flags above! + * This flag should not be used directly by the module. + * Use RedisModule_GetModuleOptionsAll instead. */ +#define _REDISMODULE_OPTIONS_FLAGS_NEXT (1<<4) + +/* Definitions for RedisModule_SetCommandInfo. */ + +typedef enum { + REDISMODULE_ARG_TYPE_STRING, + REDISMODULE_ARG_TYPE_INTEGER, + REDISMODULE_ARG_TYPE_DOUBLE, + REDISMODULE_ARG_TYPE_KEY, /* A string, but represents a keyname */ + REDISMODULE_ARG_TYPE_PATTERN, + REDISMODULE_ARG_TYPE_UNIX_TIME, + REDISMODULE_ARG_TYPE_PURE_TOKEN, + REDISMODULE_ARG_TYPE_ONEOF, /* Must have sub-arguments */ + REDISMODULE_ARG_TYPE_BLOCK /* Must have sub-arguments */ +} RedisModuleCommandArgType; + +#define REDISMODULE_CMD_ARG_NONE (0) +#define REDISMODULE_CMD_ARG_OPTIONAL (1<<0) /* The argument is optional (like GET in SET command) */ +#define REDISMODULE_CMD_ARG_MULTIPLE (1<<1) /* The argument may repeat itself (like key in DEL) */ +#define REDISMODULE_CMD_ARG_MULTIPLE_TOKEN (1<<2) /* The argument may repeat itself, and so does its token (like `GET pattern` in SORT) */ +#define _REDISMODULE_CMD_ARG_NEXT (1<<3) + +typedef enum { + REDISMODULE_KSPEC_BS_INVALID = 0, /* Must be zero. An implicitly value of + * zero is provided when the field is + * absent in a struct literal. */ + REDISMODULE_KSPEC_BS_UNKNOWN, + REDISMODULE_KSPEC_BS_INDEX, + REDISMODULE_KSPEC_BS_KEYWORD +} RedisModuleKeySpecBeginSearchType; + +typedef enum { + REDISMODULE_KSPEC_FK_OMITTED = 0, /* Used when the field is absent in a + * struct literal. Don't use this value + * explicitly. */ + REDISMODULE_KSPEC_FK_UNKNOWN, + REDISMODULE_KSPEC_FK_RANGE, + REDISMODULE_KSPEC_FK_KEYNUM +} RedisModuleKeySpecFindKeysType; + +/* Key-spec flags. For details, see the documentation of + * RedisModule_SetCommandInfo and the key-spec flags in server.h. */ +#define REDISMODULE_CMD_KEY_RO (1ULL<<0) +#define REDISMODULE_CMD_KEY_RW (1ULL<<1) +#define REDISMODULE_CMD_KEY_OW (1ULL<<2) +#define REDISMODULE_CMD_KEY_RM (1ULL<<3) +#define REDISMODULE_CMD_KEY_ACCESS (1ULL<<4) +#define REDISMODULE_CMD_KEY_UPDATE (1ULL<<5) +#define REDISMODULE_CMD_KEY_INSERT (1ULL<<6) +#define REDISMODULE_CMD_KEY_DELETE (1ULL<<7) +#define REDISMODULE_CMD_KEY_NOT_KEY (1ULL<<8) +#define REDISMODULE_CMD_KEY_INCOMPLETE (1ULL<<9) +#define REDISMODULE_CMD_KEY_VARIABLE_FLAGS (1ULL<<10) + +/* Channel flags, for details see the documentation of + * RedisModule_ChannelAtPosWithFlags. */ +#define REDISMODULE_CMD_CHANNEL_PATTERN (1ULL<<0) +#define REDISMODULE_CMD_CHANNEL_PUBLISH (1ULL<<1) +#define REDISMODULE_CMD_CHANNEL_SUBSCRIBE (1ULL<<2) +#define REDISMODULE_CMD_CHANNEL_UNSUBSCRIBE (1ULL<<3) + +typedef struct RedisModuleCommandArg { + const char *name; + RedisModuleCommandArgType type; + int key_spec_index; /* If type is KEY, this is a zero-based index of + * the key_spec in the command. For other types, + * you may specify -1. */ + const char *token; /* If type is PURE_TOKEN, this is the token. */ + const char *summary; + const char *since; + int flags; /* The REDISMODULE_CMD_ARG_* macros. */ + const char *deprecated_since; + struct RedisModuleCommandArg *subargs; + const char *display_text; +} RedisModuleCommandArg; + +typedef struct { + const char *since; + const char *changes; +} RedisModuleCommandHistoryEntry; + +typedef struct { + const char *notes; + uint64_t flags; /* REDISMODULE_CMD_KEY_* macros. */ + RedisModuleKeySpecBeginSearchType begin_search_type; + union { + struct { + /* The index from which we start the search for keys */ + int pos; + } index; + struct { + /* The keyword that indicates the beginning of key args */ + const char *keyword; + /* An index in argv from which to start searching. + * Can be negative, which means start search from the end, in reverse + * (Example: -2 means to start in reverse from the penultimate arg) */ + int startfrom; + } keyword; + } bs; + RedisModuleKeySpecFindKeysType find_keys_type; + union { + struct { + /* Index of the last key relative to the result of the begin search + * step. Can be negative, in which case it's not relative. -1 + * indicating till the last argument, -2 one before the last and so + * on. */ + int lastkey; + /* How many args should we skip after finding a key, in order to + * find the next one. */ + int keystep; + /* If lastkey is -1, we use limit to stop the search by a factor. 0 + * and 1 mean no limit. 2 means 1/2 of the remaining args, 3 means + * 1/3, and so on. */ + int limit; + } range; + struct { + /* Index of the argument containing the number of keys to come + * relative to the result of the begin search step */ + int keynumidx; + /* Index of the fist key. (Usually it's just after keynumidx, in + * which case it should be set to keynumidx + 1.) */ + int firstkey; + /* How many args should we skip after finding a key, in order to + * find the next one, relative to the result of the begin search + * step. */ + int keystep; + } keynum; + } fk; +} RedisModuleCommandKeySpec; + +typedef struct { + int version; + size_t sizeof_historyentry; + size_t sizeof_keyspec; + size_t sizeof_arg; +} RedisModuleCommandInfoVersion; + +static const RedisModuleCommandInfoVersion RedisModule_CurrentCommandInfoVersion = { + .version = 1, + .sizeof_historyentry = sizeof(RedisModuleCommandHistoryEntry), + .sizeof_keyspec = sizeof(RedisModuleCommandKeySpec), + .sizeof_arg = sizeof(RedisModuleCommandArg) +}; + +#define REDISMODULE_COMMAND_INFO_VERSION (&RedisModule_CurrentCommandInfoVersion) + +typedef struct { + /* Always set version to REDISMODULE_COMMAND_INFO_VERSION */ + const RedisModuleCommandInfoVersion *version; + /* Version 1 fields (added in Redis 7.0.0) */ + const char *summary; /* Summary of the command */ + const char *complexity; /* Complexity description */ + const char *since; /* Debut module version of the command */ + RedisModuleCommandHistoryEntry *history; /* History */ + /* A string of space-separated tips meant for clients/proxies regarding this + * command */ + const char *tips; + /* Number of arguments, it is possible to use -N to say >= N */ + int arity; + RedisModuleCommandKeySpec *key_specs; + RedisModuleCommandArg *args; +} RedisModuleCommandInfo; + +/* Eventloop definitions. */ +#define REDISMODULE_EVENTLOOP_READABLE 1 +#define REDISMODULE_EVENTLOOP_WRITABLE 2 +typedef void (*RedisModuleEventLoopFunc)(int fd, void *user_data, int mask); +typedef void (*RedisModuleEventLoopOneShotFunc)(void *user_data); + +/* Server events definitions. + * Those flags should not be used directly by the module, instead + * the module should use RedisModuleEvent_* variables. + * Note: This must be synced with moduleEventVersions */ +#define REDISMODULE_EVENT_REPLICATION_ROLE_CHANGED 0 +#define REDISMODULE_EVENT_PERSISTENCE 1 +#define REDISMODULE_EVENT_FLUSHDB 2 +#define REDISMODULE_EVENT_LOADING 3 +#define REDISMODULE_EVENT_CLIENT_CHANGE 4 +#define REDISMODULE_EVENT_SHUTDOWN 5 +#define REDISMODULE_EVENT_REPLICA_CHANGE 6 +#define REDISMODULE_EVENT_MASTER_LINK_CHANGE 7 +#define REDISMODULE_EVENT_CRON_LOOP 8 +#define REDISMODULE_EVENT_MODULE_CHANGE 9 +#define REDISMODULE_EVENT_LOADING_PROGRESS 10 +#define REDISMODULE_EVENT_SWAPDB 11 +#define REDISMODULE_EVENT_REPL_BACKUP 12 /* Deprecated since Redis 7.0, not used anymore. */ +#define REDISMODULE_EVENT_FORK_CHILD 13 +#define REDISMODULE_EVENT_REPL_ASYNC_LOAD 14 +#define REDISMODULE_EVENT_EVENTLOOP 15 +#define REDISMODULE_EVENT_CONFIG 16 +#define REDISMODULE_EVENT_KEY 17 +#define _REDISMODULE_EVENT_NEXT 18 /* Next event flag, should be updated if a new event added. */ + +typedef struct RedisModuleEvent { + uint64_t id; /* REDISMODULE_EVENT_... defines. */ + uint64_t dataver; /* Version of the structure we pass as 'data'. */ +} RedisModuleEvent; + +struct RedisModuleCtx; +struct RedisModuleDefragCtx; +typedef void (*RedisModuleEventCallback)(struct RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t subevent, void *data); + +/* IMPORTANT: When adding a new version of one of below structures that contain + * event data (RedisModuleFlushInfoV1 for example) we have to avoid renaming the + * old RedisModuleEvent structure. + * For example, if we want to add RedisModuleFlushInfoV2, the RedisModuleEvent + * structures should be: + * RedisModuleEvent_FlushDB = { + * REDISMODULE_EVENT_FLUSHDB, + * 1 + * }, + * RedisModuleEvent_FlushDBV2 = { + * REDISMODULE_EVENT_FLUSHDB, + * 2 + * } + * and NOT: + * RedisModuleEvent_FlushDBV1 = { + * REDISMODULE_EVENT_FLUSHDB, + * 1 + * }, + * RedisModuleEvent_FlushDB = { + * REDISMODULE_EVENT_FLUSHDB, + * 2 + * } + * The reason for that is forward-compatibility: We want that module that + * compiled with a new redismodule.h to be able to work with a old server, + * unless the author explicitly decided to use the newer event type. + */ +static const RedisModuleEvent + RedisModuleEvent_ReplicationRoleChanged = { + REDISMODULE_EVENT_REPLICATION_ROLE_CHANGED, + 1 + }, + RedisModuleEvent_Persistence = { + REDISMODULE_EVENT_PERSISTENCE, + 1 + }, + RedisModuleEvent_FlushDB = { + REDISMODULE_EVENT_FLUSHDB, + 1 + }, + RedisModuleEvent_Loading = { + REDISMODULE_EVENT_LOADING, + 1 + }, + RedisModuleEvent_ClientChange = { + REDISMODULE_EVENT_CLIENT_CHANGE, + 1 + }, + RedisModuleEvent_Shutdown = { + REDISMODULE_EVENT_SHUTDOWN, + 1 + }, + RedisModuleEvent_ReplicaChange = { + REDISMODULE_EVENT_REPLICA_CHANGE, + 1 + }, + RedisModuleEvent_CronLoop = { + REDISMODULE_EVENT_CRON_LOOP, + 1 + }, + RedisModuleEvent_MasterLinkChange = { + REDISMODULE_EVENT_MASTER_LINK_CHANGE, + 1 + }, + RedisModuleEvent_ModuleChange = { + REDISMODULE_EVENT_MODULE_CHANGE, + 1 + }, + RedisModuleEvent_LoadingProgress = { + REDISMODULE_EVENT_LOADING_PROGRESS, + 1 + }, + RedisModuleEvent_SwapDB = { + REDISMODULE_EVENT_SWAPDB, + 1 + }, + /* Deprecated since Redis 7.0, not used anymore. */ + __attribute__ ((deprecated)) + RedisModuleEvent_ReplBackup = { + REDISMODULE_EVENT_REPL_BACKUP, + 1 + }, + RedisModuleEvent_ReplAsyncLoad = { + REDISMODULE_EVENT_REPL_ASYNC_LOAD, + 1 + }, + RedisModuleEvent_ForkChild = { + REDISMODULE_EVENT_FORK_CHILD, + 1 + }, + RedisModuleEvent_EventLoop = { + REDISMODULE_EVENT_EVENTLOOP, + 1 + }, + RedisModuleEvent_Config = { + REDISMODULE_EVENT_CONFIG, + 1 + }, + RedisModuleEvent_Key = { + REDISMODULE_EVENT_KEY, + 1 + }; + +/* Those are values that are used for the 'subevent' callback argument. */ +#define REDISMODULE_SUBEVENT_PERSISTENCE_RDB_START 0 +#define REDISMODULE_SUBEVENT_PERSISTENCE_AOF_START 1 +#define REDISMODULE_SUBEVENT_PERSISTENCE_SYNC_RDB_START 2 +#define REDISMODULE_SUBEVENT_PERSISTENCE_ENDED 3 +#define REDISMODULE_SUBEVENT_PERSISTENCE_FAILED 4 +#define REDISMODULE_SUBEVENT_PERSISTENCE_SYNC_AOF_START 5 +#define _REDISMODULE_SUBEVENT_PERSISTENCE_NEXT 6 + +#define REDISMODULE_SUBEVENT_LOADING_RDB_START 0 +#define REDISMODULE_SUBEVENT_LOADING_AOF_START 1 +#define REDISMODULE_SUBEVENT_LOADING_REPL_START 2 +#define REDISMODULE_SUBEVENT_LOADING_ENDED 3 +#define REDISMODULE_SUBEVENT_LOADING_FAILED 4 +#define _REDISMODULE_SUBEVENT_LOADING_NEXT 5 + +#define REDISMODULE_SUBEVENT_CLIENT_CHANGE_CONNECTED 0 +#define REDISMODULE_SUBEVENT_CLIENT_CHANGE_DISCONNECTED 1 +#define _REDISMODULE_SUBEVENT_CLIENT_CHANGE_NEXT 2 + +#define REDISMODULE_SUBEVENT_MASTER_LINK_UP 0 +#define REDISMODULE_SUBEVENT_MASTER_LINK_DOWN 1 +#define _REDISMODULE_SUBEVENT_MASTER_NEXT 2 + +#define REDISMODULE_SUBEVENT_REPLICA_CHANGE_ONLINE 0 +#define REDISMODULE_SUBEVENT_REPLICA_CHANGE_OFFLINE 1 +#define _REDISMODULE_SUBEVENT_REPLICA_CHANGE_NEXT 2 + +#define REDISMODULE_EVENT_REPLROLECHANGED_NOW_MASTER 0 +#define REDISMODULE_EVENT_REPLROLECHANGED_NOW_REPLICA 1 +#define _REDISMODULE_EVENT_REPLROLECHANGED_NEXT 2 + +#define REDISMODULE_SUBEVENT_FLUSHDB_START 0 +#define REDISMODULE_SUBEVENT_FLUSHDB_END 1 +#define _REDISMODULE_SUBEVENT_FLUSHDB_NEXT 2 + +#define REDISMODULE_SUBEVENT_MODULE_LOADED 0 +#define REDISMODULE_SUBEVENT_MODULE_UNLOADED 1 +#define _REDISMODULE_SUBEVENT_MODULE_NEXT 2 + +#define REDISMODULE_SUBEVENT_CONFIG_CHANGE 0 +#define _REDISMODULE_SUBEVENT_CONFIG_NEXT 1 + +#define REDISMODULE_SUBEVENT_LOADING_PROGRESS_RDB 0 +#define REDISMODULE_SUBEVENT_LOADING_PROGRESS_AOF 1 +#define _REDISMODULE_SUBEVENT_LOADING_PROGRESS_NEXT 2 + +/* Replication Backup events are deprecated since Redis 7.0 and are never fired. */ +#define REDISMODULE_SUBEVENT_REPL_BACKUP_CREATE 0 +#define REDISMODULE_SUBEVENT_REPL_BACKUP_RESTORE 1 +#define REDISMODULE_SUBEVENT_REPL_BACKUP_DISCARD 2 +#define _REDISMODULE_SUBEVENT_REPL_BACKUP_NEXT 3 + +#define REDISMODULE_SUBEVENT_REPL_ASYNC_LOAD_STARTED 0 +#define REDISMODULE_SUBEVENT_REPL_ASYNC_LOAD_ABORTED 1 +#define REDISMODULE_SUBEVENT_REPL_ASYNC_LOAD_COMPLETED 2 +#define _REDISMODULE_SUBEVENT_REPL_ASYNC_LOAD_NEXT 3 + +#define REDISMODULE_SUBEVENT_FORK_CHILD_BORN 0 +#define REDISMODULE_SUBEVENT_FORK_CHILD_DIED 1 +#define _REDISMODULE_SUBEVENT_FORK_CHILD_NEXT 2 + +#define REDISMODULE_SUBEVENT_EVENTLOOP_BEFORE_SLEEP 0 +#define REDISMODULE_SUBEVENT_EVENTLOOP_AFTER_SLEEP 1 +#define _REDISMODULE_SUBEVENT_EVENTLOOP_NEXT 2 + +#define REDISMODULE_SUBEVENT_KEY_DELETED 0 +#define REDISMODULE_SUBEVENT_KEY_EXPIRED 1 +#define REDISMODULE_SUBEVENT_KEY_EVICTED 2 +#define REDISMODULE_SUBEVENT_KEY_OVERWRITTEN 3 +#define _REDISMODULE_SUBEVENT_KEY_NEXT 4 + +#define _REDISMODULE_SUBEVENT_SHUTDOWN_NEXT 0 +#define _REDISMODULE_SUBEVENT_CRON_LOOP_NEXT 0 +#define _REDISMODULE_SUBEVENT_SWAPDB_NEXT 0 + +/* RedisModuleClientInfo flags. */ +#define REDISMODULE_CLIENTINFO_FLAG_SSL (1<<0) +#define REDISMODULE_CLIENTINFO_FLAG_PUBSUB (1<<1) +#define REDISMODULE_CLIENTINFO_FLAG_BLOCKED (1<<2) +#define REDISMODULE_CLIENTINFO_FLAG_TRACKING (1<<3) +#define REDISMODULE_CLIENTINFO_FLAG_UNIXSOCKET (1<<4) +#define REDISMODULE_CLIENTINFO_FLAG_MULTI (1<<5) + +/* Here we take all the structures that the module pass to the core + * and the other way around. Notably the list here contains the structures + * used by the hooks API RedisModule_RegisterToServerEvent(). + * + * The structures always start with a 'version' field. This is useful + * when we want to pass a reference to the structure to the core APIs, + * for the APIs to fill the structure. In that case, the structure 'version' + * field is initialized before passing it to the core, so that the core is + * able to cast the pointer to the appropriate structure version. In this + * way we obtain ABI compatibility. + * + * Here we'll list all the structure versions in case they evolve over time, + * however using a define, we'll make sure to use the last version as the + * public name for the module to use. */ + +#define REDISMODULE_CLIENTINFO_VERSION 1 +typedef struct RedisModuleClientInfo { + uint64_t version; /* Version of this structure for ABI compat. */ + uint64_t flags; /* REDISMODULE_CLIENTINFO_FLAG_* */ + uint64_t id; /* Client ID. */ + char addr[46]; /* IPv4 or IPv6 address. */ + uint16_t port; /* TCP port. */ + uint16_t db; /* Selected DB. */ +} RedisModuleClientInfoV1; + +#define RedisModuleClientInfo RedisModuleClientInfoV1 + +#define REDISMODULE_CLIENTINFO_INITIALIZER_V1 { .version = 1 } + +#define REDISMODULE_REPLICATIONINFO_VERSION 1 +typedef struct RedisModuleReplicationInfo { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + int master; /* true if master, false if replica */ + char *masterhost; /* master instance hostname for NOW_REPLICA */ + int masterport; /* master instance port for NOW_REPLICA */ + char *replid1; /* Main replication ID */ + char *replid2; /* Secondary replication ID */ + uint64_t repl1_offset; /* Main replication offset */ + uint64_t repl2_offset; /* Offset of replid2 validity */ +} RedisModuleReplicationInfoV1; + +#define RedisModuleReplicationInfo RedisModuleReplicationInfoV1 + +#define REDISMODULE_FLUSHINFO_VERSION 1 +typedef struct RedisModuleFlushInfo { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + int32_t sync; /* Synchronous or threaded flush?. */ + int32_t dbnum; /* Flushed database number, -1 for ALL. */ +} RedisModuleFlushInfoV1; + +#define RedisModuleFlushInfo RedisModuleFlushInfoV1 + +#define REDISMODULE_MODULE_CHANGE_VERSION 1 +typedef struct RedisModuleModuleChange { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + const char* module_name;/* Name of module loaded or unloaded. */ + int32_t module_version; /* Module version. */ +} RedisModuleModuleChangeV1; + +#define RedisModuleModuleChange RedisModuleModuleChangeV1 + +#define REDISMODULE_CONFIGCHANGE_VERSION 1 +typedef struct RedisModuleConfigChange { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + uint32_t num_changes; /* how many redis config options were changed */ + const char **config_names; /* the config names that were changed */ +} RedisModuleConfigChangeV1; + +#define RedisModuleConfigChange RedisModuleConfigChangeV1 + +#define REDISMODULE_CRON_LOOP_VERSION 1 +typedef struct RedisModuleCronLoopInfo { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + int32_t hz; /* Approximate number of events per second. */ +} RedisModuleCronLoopV1; + +#define RedisModuleCronLoop RedisModuleCronLoopV1 + +#define REDISMODULE_LOADING_PROGRESS_VERSION 1 +typedef struct RedisModuleLoadingProgressInfo { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + int32_t hz; /* Approximate number of events per second. */ + int32_t progress; /* Approximate progress between 0 and 1024, or -1 + * if unknown. */ +} RedisModuleLoadingProgressV1; + +#define RedisModuleLoadingProgress RedisModuleLoadingProgressV1 + +#define REDISMODULE_SWAPDBINFO_VERSION 1 +typedef struct RedisModuleSwapDbInfo { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + int32_t dbnum_first; /* Swap Db first dbnum */ + int32_t dbnum_second; /* Swap Db second dbnum */ +} RedisModuleSwapDbInfoV1; + +#define RedisModuleSwapDbInfo RedisModuleSwapDbInfoV1 + +#define REDISMODULE_KEYINFO_VERSION 1 +typedef struct RedisModuleKeyInfo { + uint64_t version; /* Not used since this structure is never passed + from the module to the core right now. Here + for future compatibility. */ + RedisModuleKey *key; /* Opened key. */ +} RedisModuleKeyInfoV1; + +#define RedisModuleKeyInfo RedisModuleKeyInfoV1 + +typedef enum { + REDISMODULE_ACL_LOG_AUTH = 0, /* Authentication failure */ + REDISMODULE_ACL_LOG_CMD, /* Command authorization failure */ + REDISMODULE_ACL_LOG_KEY, /* Key authorization failure */ + REDISMODULE_ACL_LOG_CHANNEL /* Channel authorization failure */ +} RedisModuleACLLogEntryReason; + +/* Incomplete structures needed by both the core and modules. */ +typedef struct RedisModuleIO RedisModuleIO; +typedef struct RedisModuleDigest RedisModuleDigest; +typedef struct RedisModuleInfoCtx RedisModuleInfoCtx; +typedef struct RedisModuleDefragCtx RedisModuleDefragCtx; + +/* Function pointers needed by both the core and modules, these needs to be + * exposed since you can't cast a function pointer to (void *). */ +typedef void (*RedisModuleInfoFunc)(RedisModuleInfoCtx *ctx, int for_crash_report); +typedef void (*RedisModuleDefragFunc)(RedisModuleDefragCtx *ctx); +typedef void (*RedisModuleUserChangedFunc) (uint64_t client_id, void *privdata); + +/* ------------------------- End of common defines ------------------------ */ + +/* ----------- The rest of the defines are only for modules ----------------- */ +#if !defined REDISMODULE_CORE || defined REDISMODULE_CORE_MODULE +/* Things defined for modules and core-modules. */ + +/* Macro definitions specific to individual compilers */ +#ifndef REDISMODULE_ATTR_UNUSED +# ifdef __GNUC__ +# define REDISMODULE_ATTR_UNUSED __attribute__((unused)) +# else +# define REDISMODULE_ATTR_UNUSED +# endif +#endif + +#ifndef REDISMODULE_ATTR_PRINTF +# ifdef __GNUC__ +# define REDISMODULE_ATTR_PRINTF(idx,cnt) __attribute__((format(printf,idx,cnt))) +# else +# define REDISMODULE_ATTR_PRINTF(idx,cnt) +# endif +#endif + +#ifndef REDISMODULE_ATTR_COMMON +# if defined(__GNUC__) && !(defined(__clang__) && defined(__cplusplus)) +# define REDISMODULE_ATTR_COMMON __attribute__((__common__)) +# else +# define REDISMODULE_ATTR_COMMON +# endif +#endif + +/* Incomplete structures for compiler checks but opaque access. */ +typedef struct RedisModuleCtx RedisModuleCtx; +typedef struct RedisModuleCommand RedisModuleCommand; +typedef struct RedisModuleCallReply RedisModuleCallReply; +typedef struct RedisModuleType RedisModuleType; +typedef struct RedisModuleBlockedClient RedisModuleBlockedClient; +typedef struct RedisModuleClusterInfo RedisModuleClusterInfo; +typedef struct RedisModuleDict RedisModuleDict; +typedef struct RedisModuleDictIter RedisModuleDictIter; +typedef struct RedisModuleCommandFilterCtx RedisModuleCommandFilterCtx; +typedef struct RedisModuleCommandFilter RedisModuleCommandFilter; +typedef struct RedisModuleServerInfoData RedisModuleServerInfoData; +typedef struct RedisModuleScanCursor RedisModuleScanCursor; +typedef struct RedisModuleUser RedisModuleUser; +typedef struct RedisModuleKeyOptCtx RedisModuleKeyOptCtx; +typedef struct RedisModuleRdbStream RedisModuleRdbStream; + +typedef int (*RedisModuleCmdFunc)(RedisModuleCtx *ctx, RedisModuleString **argv, int argc); +typedef void (*RedisModuleDisconnectFunc)(RedisModuleCtx *ctx, RedisModuleBlockedClient *bc); +typedef int (*RedisModuleNotificationFunc)(RedisModuleCtx *ctx, int type, const char *event, RedisModuleString *key); +typedef void (*RedisModulePostNotificationJobFunc) (RedisModuleCtx *ctx, void *pd); +typedef void *(*RedisModuleTypeLoadFunc)(RedisModuleIO *rdb, int encver); +typedef void (*RedisModuleTypeSaveFunc)(RedisModuleIO *rdb, void *value); +typedef int (*RedisModuleTypeAuxLoadFunc)(RedisModuleIO *rdb, int encver, int when); +typedef void (*RedisModuleTypeAuxSaveFunc)(RedisModuleIO *rdb, int when); +typedef void (*RedisModuleTypeRewriteFunc)(RedisModuleIO *aof, RedisModuleString *key, void *value); +typedef size_t (*RedisModuleTypeMemUsageFunc)(const void *value); +typedef size_t (*RedisModuleTypeMemUsageFunc2)(RedisModuleKeyOptCtx *ctx, const void *value, size_t sample_size); +typedef void (*RedisModuleTypeDigestFunc)(RedisModuleDigest *digest, void *value); +typedef void (*RedisModuleTypeFreeFunc)(void *value); +typedef size_t (*RedisModuleTypeFreeEffortFunc)(RedisModuleString *key, const void *value); +typedef size_t (*RedisModuleTypeFreeEffortFunc2)(RedisModuleKeyOptCtx *ctx, const void *value); +typedef void (*RedisModuleTypeUnlinkFunc)(RedisModuleString *key, const void *value); +typedef void (*RedisModuleTypeUnlinkFunc2)(RedisModuleKeyOptCtx *ctx, const void *value); +typedef void *(*RedisModuleTypeCopyFunc)(RedisModuleString *fromkey, RedisModuleString *tokey, const void *value); +typedef void *(*RedisModuleTypeCopyFunc2)(RedisModuleKeyOptCtx *ctx, const void *value); +typedef int (*RedisModuleTypeDefragFunc)(RedisModuleDefragCtx *ctx, RedisModuleString *key, void **value); +typedef void (*RedisModuleClusterMessageReceiver)(RedisModuleCtx *ctx, const char *sender_id, uint8_t type, const unsigned char *payload, uint32_t len); +typedef void (*RedisModuleTimerProc)(RedisModuleCtx *ctx, void *data); +typedef void (*RedisModuleCommandFilterFunc) (RedisModuleCommandFilterCtx *filter); +typedef void (*RedisModuleForkDoneHandler) (int exitcode, int bysignal, void *user_data); +typedef void (*RedisModuleScanCB)(RedisModuleCtx *ctx, RedisModuleString *keyname, RedisModuleKey *key, void *privdata); +typedef void (*RedisModuleScanKeyCB)(RedisModuleKey *key, RedisModuleString *field, RedisModuleString *value, void *privdata); +typedef RedisModuleString * (*RedisModuleConfigGetStringFunc)(const char *name, void *privdata); +typedef long long (*RedisModuleConfigGetNumericFunc)(const char *name, void *privdata); +typedef int (*RedisModuleConfigGetBoolFunc)(const char *name, void *privdata); +typedef int (*RedisModuleConfigGetEnumFunc)(const char *name, void *privdata); +typedef int (*RedisModuleConfigSetStringFunc)(const char *name, RedisModuleString *val, void *privdata, RedisModuleString **err); +typedef int (*RedisModuleConfigSetNumericFunc)(const char *name, long long val, void *privdata, RedisModuleString **err); +typedef int (*RedisModuleConfigSetBoolFunc)(const char *name, int val, void *privdata, RedisModuleString **err); +typedef int (*RedisModuleConfigSetEnumFunc)(const char *name, int val, void *privdata, RedisModuleString **err); +typedef int (*RedisModuleConfigApplyFunc)(RedisModuleCtx *ctx, void *privdata, RedisModuleString **err); +typedef void (*RedisModuleOnUnblocked)(RedisModuleCtx *ctx, RedisModuleCallReply *reply, void *private_data); +typedef int (*RedisModuleAuthCallback)(RedisModuleCtx *ctx, RedisModuleString *username, RedisModuleString *password, RedisModuleString **err); + +typedef struct RedisModuleTypeMethods { + uint64_t version; + RedisModuleTypeLoadFunc rdb_load; + RedisModuleTypeSaveFunc rdb_save; + RedisModuleTypeRewriteFunc aof_rewrite; + RedisModuleTypeMemUsageFunc mem_usage; + RedisModuleTypeDigestFunc digest; + RedisModuleTypeFreeFunc free; + RedisModuleTypeAuxLoadFunc aux_load; + RedisModuleTypeAuxSaveFunc aux_save; + int aux_save_triggers; + RedisModuleTypeFreeEffortFunc free_effort; + RedisModuleTypeUnlinkFunc unlink; + RedisModuleTypeCopyFunc copy; + RedisModuleTypeDefragFunc defrag; + RedisModuleTypeMemUsageFunc2 mem_usage2; + RedisModuleTypeFreeEffortFunc2 free_effort2; + RedisModuleTypeUnlinkFunc2 unlink2; + RedisModuleTypeCopyFunc2 copy2; + RedisModuleTypeAuxSaveFunc aux_save2; +} RedisModuleTypeMethods; + +#define REDISMODULE_GET_API(name) \ + RedisModule_GetApi("RedisModule_" #name, ((void **)&RedisModule_ ## name)) + +/* Default API declaration prefix (not 'extern' for backwards compatibility) */ +#ifndef REDISMODULE_API +#define REDISMODULE_API +#endif + +/* Default API declaration suffix (compiler attributes) */ +#ifndef REDISMODULE_ATTR +#define REDISMODULE_ATTR REDISMODULE_ATTR_COMMON +#endif + +REDISMODULE_API void * (*RedisModule_Alloc)(size_t bytes) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_TryAlloc)(size_t bytes) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_Realloc)(void *ptr, size_t bytes) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_TryRealloc)(void *ptr, size_t bytes) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_Free)(void *ptr) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_Calloc)(size_t nmemb, size_t size) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_TryCalloc)(size_t nmemb, size_t size) REDISMODULE_ATTR; +REDISMODULE_API char * (*RedisModule_Strdup)(const char *str) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetApi)(const char *, void *) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CreateCommand)(RedisModuleCtx *ctx, const char *name, RedisModuleCmdFunc cmdfunc, const char *strflags, int firstkey, int lastkey, int keystep) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCommand *(*RedisModule_GetCommand)(RedisModuleCtx *ctx, const char *name) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CreateSubcommand)(RedisModuleCommand *parent, const char *name, RedisModuleCmdFunc cmdfunc, const char *strflags, int firstkey, int lastkey, int keystep) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetCommandInfo)(RedisModuleCommand *command, const RedisModuleCommandInfo *info) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetCommandACLCategories)(RedisModuleCommand *command, const char *ctgrsflags) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_AddACLCategory)(RedisModuleCtx *ctx, const char *name) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SetModuleAttribs)(RedisModuleCtx *ctx, const char *name, int ver, int apiver) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_IsModuleNameBusy)(const char *name) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_WrongArity)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithLongLong)(RedisModuleCtx *ctx, long long ll) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetSelectedDb)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SelectDb)(RedisModuleCtx *ctx, int newid) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_KeyExists)(RedisModuleCtx *ctx, RedisModuleString *keyname) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleKey * (*RedisModule_OpenKey)(RedisModuleCtx *ctx, RedisModuleString *keyname, int mode) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetOpenKeyModesAll)(void) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_CloseKey)(RedisModuleKey *kp) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_KeyType)(RedisModuleKey *kp) REDISMODULE_ATTR; +REDISMODULE_API size_t (*RedisModule_ValueLength)(RedisModuleKey *kp) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ListPush)(RedisModuleKey *kp, int where, RedisModuleString *ele) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_ListPop)(RedisModuleKey *key, int where) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_ListGet)(RedisModuleKey *key, long index) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ListSet)(RedisModuleKey *key, long index, RedisModuleString *value) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ListInsert)(RedisModuleKey *key, long index, RedisModuleString *value) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ListDelete)(RedisModuleKey *key, long index) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCallReply * (*RedisModule_Call)(RedisModuleCtx *ctx, const char *cmdname, const char *fmt, ...) REDISMODULE_ATTR; +REDISMODULE_API const char * (*RedisModule_CallReplyProto)(RedisModuleCallReply *reply, size_t *len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_FreeCallReply)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CallReplyType)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API long long (*RedisModule_CallReplyInteger)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API double (*RedisModule_CallReplyDouble)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CallReplyBool)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API const char* (*RedisModule_CallReplyBigNumber)(RedisModuleCallReply *reply, size_t *len) REDISMODULE_ATTR; +REDISMODULE_API const char* (*RedisModule_CallReplyVerbatim)(RedisModuleCallReply *reply, size_t *len, const char **format) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCallReply * (*RedisModule_CallReplySetElement)(RedisModuleCallReply *reply, size_t idx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CallReplyMapElement)(RedisModuleCallReply *reply, size_t idx, RedisModuleCallReply **key, RedisModuleCallReply **val) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CallReplyAttributeElement)(RedisModuleCallReply *reply, size_t idx, RedisModuleCallReply **key, RedisModuleCallReply **val) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_CallReplyPromiseSetUnblockHandler)(RedisModuleCallReply *reply, RedisModuleOnUnblocked on_unblock, void *private_data) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CallReplyPromiseAbort)(RedisModuleCallReply *reply, void **private_data) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCallReply * (*RedisModule_CallReplyAttribute)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API size_t (*RedisModule_CallReplyLength)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCallReply * (*RedisModule_CallReplyArrayElement)(RedisModuleCallReply *reply, size_t idx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateString)(RedisModuleCtx *ctx, const char *ptr, size_t len) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringFromLongLong)(RedisModuleCtx *ctx, long long ll) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringFromULongLong)(RedisModuleCtx *ctx, unsigned long long ull) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringFromDouble)(RedisModuleCtx *ctx, double d) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringFromLongDouble)(RedisModuleCtx *ctx, long double ld, int humanfriendly) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringFromString)(RedisModuleCtx *ctx, const RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringFromStreamID)(RedisModuleCtx *ctx, const RedisModuleStreamID *id) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringPrintf)(RedisModuleCtx *ctx, const char *fmt, ...) REDISMODULE_ATTR_PRINTF(2,3) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_FreeString)(RedisModuleCtx *ctx, RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API const char * (*RedisModule_StringPtrLen)(const RedisModuleString *str, size_t *len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithError)(RedisModuleCtx *ctx, const char *err) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithErrorFormat)(RedisModuleCtx *ctx, const char *fmt, ...) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithSimpleString)(RedisModuleCtx *ctx, const char *msg) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithArray)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithMap)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithSet)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithAttribute)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithNullArray)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithEmptyArray)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ReplySetArrayLength)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ReplySetMapLength)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ReplySetSetLength)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ReplySetAttributeLength)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ReplySetPushLength)(RedisModuleCtx *ctx, long len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithStringBuffer)(RedisModuleCtx *ctx, const char *buf, size_t len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithCString)(RedisModuleCtx *ctx, const char *buf) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithString)(RedisModuleCtx *ctx, RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithEmptyString)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithVerbatimString)(RedisModuleCtx *ctx, const char *buf, size_t len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithVerbatimStringType)(RedisModuleCtx *ctx, const char *buf, size_t len, const char *ext) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithNull)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithBool)(RedisModuleCtx *ctx, int b) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithLongDouble)(RedisModuleCtx *ctx, long double d) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithDouble)(RedisModuleCtx *ctx, double d) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithBigNumber)(RedisModuleCtx *ctx, const char *bignum, size_t len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplyWithCallReply)(RedisModuleCtx *ctx, RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringToLongLong)(const RedisModuleString *str, long long *ll) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringToULongLong)(const RedisModuleString *str, unsigned long long *ull) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringToDouble)(const RedisModuleString *str, double *d) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringToLongDouble)(const RedisModuleString *str, long double *d) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringToStreamID)(const RedisModuleString *str, RedisModuleStreamID *id) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_AutoMemory)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_Replicate)(RedisModuleCtx *ctx, const char *cmdname, const char *fmt, ...) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ReplicateVerbatim)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API const char * (*RedisModule_CallReplyStringPtr)(RedisModuleCallReply *reply, size_t *len) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CreateStringFromCallReply)(RedisModuleCallReply *reply) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DeleteKey)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_UnlinkKey)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringSet)(RedisModuleKey *key, RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API char * (*RedisModule_StringDMA)(RedisModuleKey *key, size_t *len, int mode) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringTruncate)(RedisModuleKey *key, size_t newlen) REDISMODULE_ATTR; +REDISMODULE_API mstime_t (*RedisModule_GetExpire)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetExpire)(RedisModuleKey *key, mstime_t expire) REDISMODULE_ATTR; +REDISMODULE_API mstime_t (*RedisModule_GetAbsExpire)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetAbsExpire)(RedisModuleKey *key, mstime_t expire) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ResetDataset)(int restart_aof, int async) REDISMODULE_ATTR; +REDISMODULE_API unsigned long long (*RedisModule_DbSize)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_RandomKey)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetAdd)(RedisModuleKey *key, double score, RedisModuleString *ele, int *flagsptr) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetIncrby)(RedisModuleKey *key, double score, RedisModuleString *ele, int *flagsptr, double *newscore) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetScore)(RedisModuleKey *key, RedisModuleString *ele, double *score) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetRem)(RedisModuleKey *key, RedisModuleString *ele, int *deleted) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ZsetRangeStop)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetFirstInScoreRange)(RedisModuleKey *key, double min, double max, int minex, int maxex) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetLastInScoreRange)(RedisModuleKey *key, double min, double max, int minex, int maxex) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetFirstInLexRange)(RedisModuleKey *key, RedisModuleString *min, RedisModuleString *max) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetLastInLexRange)(RedisModuleKey *key, RedisModuleString *min, RedisModuleString *max) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_ZsetRangeCurrentElement)(RedisModuleKey *key, double *score) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetRangeNext)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetRangePrev)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ZsetRangeEndReached)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_HashSet)(RedisModuleKey *key, int flags, ...) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_HashGet)(RedisModuleKey *key, int flags, ...) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StreamAdd)(RedisModuleKey *key, int flags, RedisModuleStreamID *id, RedisModuleString **argv, int64_t numfields) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StreamDelete)(RedisModuleKey *key, RedisModuleStreamID *id) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StreamIteratorStart)(RedisModuleKey *key, int flags, RedisModuleStreamID *startid, RedisModuleStreamID *endid) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StreamIteratorStop)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StreamIteratorNextID)(RedisModuleKey *key, RedisModuleStreamID *id, long *numfields) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StreamIteratorNextField)(RedisModuleKey *key, RedisModuleString **field_ptr, RedisModuleString **value_ptr) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StreamIteratorDelete)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API long long (*RedisModule_StreamTrimByLength)(RedisModuleKey *key, int flags, long long length) REDISMODULE_ATTR; +REDISMODULE_API long long (*RedisModule_StreamTrimByID)(RedisModuleKey *key, int flags, RedisModuleStreamID *id) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_IsKeysPositionRequest)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_KeyAtPos)(RedisModuleCtx *ctx, int pos) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_KeyAtPosWithFlags)(RedisModuleCtx *ctx, int pos, int flags) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_IsChannelsPositionRequest)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ChannelAtPosWithFlags)(RedisModuleCtx *ctx, int pos, int flags) REDISMODULE_ATTR; +REDISMODULE_API unsigned long long (*RedisModule_GetClientId)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_GetClientUserNameById)(RedisModuleCtx *ctx, uint64_t id) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetClientInfoById)(void *ci, uint64_t id) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_GetClientNameById)(RedisModuleCtx *ctx, uint64_t id) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetClientNameById)(uint64_t id, RedisModuleString *name) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_PublishMessage)(RedisModuleCtx *ctx, RedisModuleString *channel, RedisModuleString *message) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_PublishMessageShard)(RedisModuleCtx *ctx, RedisModuleString *channel, RedisModuleString *message) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetContextFlags)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_AvoidReplicaTraffic)(void) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_PoolAlloc)(RedisModuleCtx *ctx, size_t bytes) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleType * (*RedisModule_CreateDataType)(RedisModuleCtx *ctx, const char *name, int encver, RedisModuleTypeMethods *typemethods) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ModuleTypeSetValue)(RedisModuleKey *key, RedisModuleType *mt, void *value) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ModuleTypeReplaceValue)(RedisModuleKey *key, RedisModuleType *mt, void *new_value, void **old_value) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleType * (*RedisModule_ModuleTypeGetType)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_ModuleTypeGetValue)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_IsIOError)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SetModuleOptions)(RedisModuleCtx *ctx, int options) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SignalModifiedKey)(RedisModuleCtx *ctx, RedisModuleString *keyname) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SaveUnsigned)(RedisModuleIO *io, uint64_t value) REDISMODULE_ATTR; +REDISMODULE_API uint64_t (*RedisModule_LoadUnsigned)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SaveSigned)(RedisModuleIO *io, int64_t value) REDISMODULE_ATTR; +REDISMODULE_API int64_t (*RedisModule_LoadSigned)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_EmitAOF)(RedisModuleIO *io, const char *cmdname, const char *fmt, ...) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SaveString)(RedisModuleIO *io, RedisModuleString *s) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SaveStringBuffer)(RedisModuleIO *io, const char *str, size_t len) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_LoadString)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API char * (*RedisModule_LoadStringBuffer)(RedisModuleIO *io, size_t *lenptr) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SaveDouble)(RedisModuleIO *io, double value) REDISMODULE_ATTR; +REDISMODULE_API double (*RedisModule_LoadDouble)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SaveFloat)(RedisModuleIO *io, float value) REDISMODULE_ATTR; +REDISMODULE_API float (*RedisModule_LoadFloat)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SaveLongDouble)(RedisModuleIO *io, long double value) REDISMODULE_ATTR; +REDISMODULE_API long double (*RedisModule_LoadLongDouble)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_LoadDataTypeFromString)(const RedisModuleString *str, const RedisModuleType *mt) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_LoadDataTypeFromStringEncver)(const RedisModuleString *str, const RedisModuleType *mt, int encver) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_SaveDataTypeToString)(RedisModuleCtx *ctx, void *data, const RedisModuleType *mt) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_Log)(RedisModuleCtx *ctx, const char *level, const char *fmt, ...) REDISMODULE_ATTR REDISMODULE_ATTR_PRINTF(3,4); +REDISMODULE_API void (*RedisModule_LogIOError)(RedisModuleIO *io, const char *levelstr, const char *fmt, ...) REDISMODULE_ATTR REDISMODULE_ATTR_PRINTF(3,4); +REDISMODULE_API void (*RedisModule__Assert)(const char *estr, const char *file, int line) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_LatencyAddSample)(const char *event, mstime_t latency) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringAppendBuffer)(RedisModuleCtx *ctx, RedisModuleString *str, const char *buf, size_t len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_TrimStringAllocation)(RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_RetainString)(RedisModuleCtx *ctx, RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_HoldString)(RedisModuleCtx *ctx, RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StringCompare)(const RedisModuleString *a, const RedisModuleString *b) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCtx * (*RedisModule_GetContextFromIO)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API const RedisModuleString * (*RedisModule_GetKeyNameFromIO)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API const RedisModuleString * (*RedisModule_GetKeyNameFromModuleKey)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetDbIdFromModuleKey)(RedisModuleKey *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetDbIdFromIO)(RedisModuleIO *io) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetDbIdFromOptCtx)(RedisModuleKeyOptCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetToDbIdFromOptCtx)(RedisModuleKeyOptCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API const RedisModuleString * (*RedisModule_GetKeyNameFromOptCtx)(RedisModuleKeyOptCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API const RedisModuleString * (*RedisModule_GetToKeyNameFromOptCtx)(RedisModuleKeyOptCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API mstime_t (*RedisModule_Milliseconds)(void) REDISMODULE_ATTR; +REDISMODULE_API uint64_t (*RedisModule_MonotonicMicroseconds)(void) REDISMODULE_ATTR; +REDISMODULE_API ustime_t (*RedisModule_Microseconds)(void) REDISMODULE_ATTR; +REDISMODULE_API ustime_t (*RedisModule_CachedMicroseconds)(void) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_DigestAddStringBuffer)(RedisModuleDigest *md, const char *ele, size_t len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_DigestAddLongLong)(RedisModuleDigest *md, long long ele) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_DigestEndSequence)(RedisModuleDigest *md) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetDbIdFromDigest)(RedisModuleDigest *dig) REDISMODULE_ATTR; +REDISMODULE_API const RedisModuleString * (*RedisModule_GetKeyNameFromDigest)(RedisModuleDigest *dig) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleDict * (*RedisModule_CreateDict)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_FreeDict)(RedisModuleCtx *ctx, RedisModuleDict *d) REDISMODULE_ATTR; +REDISMODULE_API uint64_t (*RedisModule_DictSize)(RedisModuleDict *d) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictSetC)(RedisModuleDict *d, void *key, size_t keylen, void *ptr) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictReplaceC)(RedisModuleDict *d, void *key, size_t keylen, void *ptr) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictSet)(RedisModuleDict *d, RedisModuleString *key, void *ptr) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictReplace)(RedisModuleDict *d, RedisModuleString *key, void *ptr) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_DictGetC)(RedisModuleDict *d, void *key, size_t keylen, int *nokey) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_DictGet)(RedisModuleDict *d, RedisModuleString *key, int *nokey) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictDelC)(RedisModuleDict *d, void *key, size_t keylen, void *oldval) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictDel)(RedisModuleDict *d, RedisModuleString *key, void *oldval) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleDictIter * (*RedisModule_DictIteratorStartC)(RedisModuleDict *d, const char *op, void *key, size_t keylen) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleDictIter * (*RedisModule_DictIteratorStart)(RedisModuleDict *d, const char *op, RedisModuleString *key) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_DictIteratorStop)(RedisModuleDictIter *di) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictIteratorReseekC)(RedisModuleDictIter *di, const char *op, void *key, size_t keylen) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictIteratorReseek)(RedisModuleDictIter *di, const char *op, RedisModuleString *key) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_DictNextC)(RedisModuleDictIter *di, size_t *keylen, void **dataptr) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_DictPrevC)(RedisModuleDictIter *di, size_t *keylen, void **dataptr) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_DictNext)(RedisModuleCtx *ctx, RedisModuleDictIter *di, void **dataptr) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_DictPrev)(RedisModuleCtx *ctx, RedisModuleDictIter *di, void **dataptr) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictCompareC)(RedisModuleDictIter *di, const char *op, void *key, size_t keylen) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DictCompare)(RedisModuleDictIter *di, const char *op, RedisModuleString *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RegisterInfoFunc)(RedisModuleCtx *ctx, RedisModuleInfoFunc cb) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_RegisterAuthCallback)(RedisModuleCtx *ctx, RedisModuleAuthCallback cb) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoAddSection)(RedisModuleInfoCtx *ctx, const char *name) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoBeginDictField)(RedisModuleInfoCtx *ctx, const char *name) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoEndDictField)(RedisModuleInfoCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoAddFieldString)(RedisModuleInfoCtx *ctx, const char *field, RedisModuleString *value) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoAddFieldCString)(RedisModuleInfoCtx *ctx, const char *field,const char *value) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoAddFieldDouble)(RedisModuleInfoCtx *ctx, const char *field, double value) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoAddFieldLongLong)(RedisModuleInfoCtx *ctx, const char *field, long long value) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_InfoAddFieldULongLong)(RedisModuleInfoCtx *ctx, const char *field, unsigned long long value) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleServerInfoData * (*RedisModule_GetServerInfo)(RedisModuleCtx *ctx, const char *section) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_FreeServerInfo)(RedisModuleCtx *ctx, RedisModuleServerInfoData *data) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_ServerInfoGetField)(RedisModuleCtx *ctx, RedisModuleServerInfoData *data, const char* field) REDISMODULE_ATTR; +REDISMODULE_API const char * (*RedisModule_ServerInfoGetFieldC)(RedisModuleServerInfoData *data, const char* field) REDISMODULE_ATTR; +REDISMODULE_API long long (*RedisModule_ServerInfoGetFieldSigned)(RedisModuleServerInfoData *data, const char* field, int *out_err) REDISMODULE_ATTR; +REDISMODULE_API unsigned long long (*RedisModule_ServerInfoGetFieldUnsigned)(RedisModuleServerInfoData *data, const char* field, int *out_err) REDISMODULE_ATTR; +REDISMODULE_API double (*RedisModule_ServerInfoGetFieldDouble)(RedisModuleServerInfoData *data, const char* field, int *out_err) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SubscribeToServerEvent)(RedisModuleCtx *ctx, RedisModuleEvent event, RedisModuleEventCallback callback) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetLRU)(RedisModuleKey *key, mstime_t lru_idle) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetLRU)(RedisModuleKey *key, mstime_t *lru_idle) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetLFU)(RedisModuleKey *key, long long lfu_freq) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetLFU)(RedisModuleKey *key, long long *lfu_freq) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleBlockedClient * (*RedisModule_BlockClientOnKeys)(RedisModuleCtx *ctx, RedisModuleCmdFunc reply_callback, RedisModuleCmdFunc timeout_callback, void (*free_privdata)(RedisModuleCtx*,void*), long long timeout_ms, RedisModuleString **keys, int numkeys, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleBlockedClient * (*RedisModule_BlockClientOnKeysWithFlags)(RedisModuleCtx *ctx, RedisModuleCmdFunc reply_callback, RedisModuleCmdFunc timeout_callback, void (*free_privdata)(RedisModuleCtx*,void*), long long timeout_ms, RedisModuleString **keys, int numkeys, void *privdata, int flags) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SignalKeyAsReady)(RedisModuleCtx *ctx, RedisModuleString *key) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_GetBlockedClientReadyKey)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleScanCursor * (*RedisModule_ScanCursorCreate)(void) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ScanCursorRestart)(RedisModuleScanCursor *cursor) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ScanCursorDestroy)(RedisModuleScanCursor *cursor) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_Scan)(RedisModuleCtx *ctx, RedisModuleScanCursor *cursor, RedisModuleScanCB fn, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ScanKey)(RedisModuleKey *key, RedisModuleScanCursor *cursor, RedisModuleScanKeyCB fn, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetContextFlagsAll)(void) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetModuleOptionsAll)(void) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetKeyspaceNotificationFlagsAll)(void) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_IsSubEventSupported)(RedisModuleEvent event, uint64_t subevent) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetServerVersion)(void) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetTypeMethodVersion)(void) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_Yield)(RedisModuleCtx *ctx, int flags, const char *busy_reply) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleBlockedClient * (*RedisModule_BlockClient)(RedisModuleCtx *ctx, RedisModuleCmdFunc reply_callback, RedisModuleCmdFunc timeout_callback, void (*free_privdata)(RedisModuleCtx*,void*), long long timeout_ms) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_BlockClientGetPrivateData)(RedisModuleBlockedClient *blocked_client) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_BlockClientSetPrivateData)(RedisModuleBlockedClient *blocked_client, void *private_data) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleBlockedClient * (*RedisModule_BlockClientOnAuth)(RedisModuleCtx *ctx, RedisModuleAuthCallback reply_callback, void (*free_privdata)(RedisModuleCtx*,void*)) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_UnblockClient)(RedisModuleBlockedClient *bc, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_IsBlockedReplyRequest)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_IsBlockedTimeoutRequest)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_GetBlockedClientPrivateData)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleBlockedClient * (*RedisModule_GetBlockedClientHandle)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_AbortBlock)(RedisModuleBlockedClient *bc) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_BlockedClientMeasureTimeStart)(RedisModuleBlockedClient *bc) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_BlockedClientMeasureTimeEnd)(RedisModuleBlockedClient *bc) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCtx * (*RedisModule_GetThreadSafeContext)(RedisModuleBlockedClient *bc) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCtx * (*RedisModule_GetDetachedThreadSafeContext)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_FreeThreadSafeContext)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ThreadSafeContextLock)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ThreadSafeContextTryLock)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ThreadSafeContextUnlock)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SubscribeToKeyspaceEvents)(RedisModuleCtx *ctx, int types, RedisModuleNotificationFunc cb) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_AddPostNotificationJob)(RedisModuleCtx *ctx, RedisModulePostNotificationJobFunc callback, void *pd, void (*free_pd)(void*)) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_NotifyKeyspaceEvent)(RedisModuleCtx *ctx, int type, const char *event, RedisModuleString *key) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetNotifyKeyspaceEvents)(void) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_BlockedClientDisconnected)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_RegisterClusterMessageReceiver)(RedisModuleCtx *ctx, uint8_t type, RedisModuleClusterMessageReceiver callback) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SendClusterMessage)(RedisModuleCtx *ctx, const char *target_id, uint8_t type, const char *msg, uint32_t len) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetClusterNodeInfo)(RedisModuleCtx *ctx, const char *id, char *ip, char *master_id, int *port, int *flags) REDISMODULE_ATTR; +REDISMODULE_API char ** (*RedisModule_GetClusterNodesList)(RedisModuleCtx *ctx, size_t *numnodes) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_FreeClusterNodesList)(char **ids) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleTimerID (*RedisModule_CreateTimer)(RedisModuleCtx *ctx, mstime_t period, RedisModuleTimerProc callback, void *data) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_StopTimer)(RedisModuleCtx *ctx, RedisModuleTimerID id, void **data) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetTimerInfo)(RedisModuleCtx *ctx, RedisModuleTimerID id, uint64_t *remaining, void **data) REDISMODULE_ATTR; +REDISMODULE_API const char * (*RedisModule_GetMyClusterID)(void) REDISMODULE_ATTR; +REDISMODULE_API size_t (*RedisModule_GetClusterSize)(void) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_GetRandomBytes)(unsigned char *dst, size_t len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_GetRandomHexChars)(char *dst, size_t len) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SetDisconnectCallback)(RedisModuleBlockedClient *bc, RedisModuleDisconnectFunc callback) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SetClusterFlags)(RedisModuleCtx *ctx, uint64_t flags) REDISMODULE_ATTR; +REDISMODULE_API unsigned int (*RedisModule_ClusterKeySlot)(RedisModuleString *key) REDISMODULE_ATTR; +REDISMODULE_API const char *(*RedisModule_ClusterCanonicalKeyNameInSlot)(unsigned int slot) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ExportSharedAPI)(RedisModuleCtx *ctx, const char *apiname, void *func) REDISMODULE_ATTR; +REDISMODULE_API void * (*RedisModule_GetSharedAPI)(RedisModuleCtx *ctx, const char *apiname) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleCommandFilter * (*RedisModule_RegisterCommandFilter)(RedisModuleCtx *ctx, RedisModuleCommandFilterFunc cb, int flags) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_UnregisterCommandFilter)(RedisModuleCtx *ctx, RedisModuleCommandFilter *filter) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CommandFilterArgsCount)(RedisModuleCommandFilterCtx *fctx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_CommandFilterArgGet)(RedisModuleCommandFilterCtx *fctx, int pos) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CommandFilterArgInsert)(RedisModuleCommandFilterCtx *fctx, int pos, RedisModuleString *arg) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CommandFilterArgReplace)(RedisModuleCommandFilterCtx *fctx, int pos, RedisModuleString *arg) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_CommandFilterArgDelete)(RedisModuleCommandFilterCtx *fctx, int pos) REDISMODULE_ATTR; +REDISMODULE_API unsigned long long (*RedisModule_CommandFilterGetClientId)(RedisModuleCommandFilterCtx *fctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_Fork)(RedisModuleForkDoneHandler cb, void *user_data) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SendChildHeartbeat)(double progress) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ExitFromChild)(int retcode) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_KillForkChild)(int child_pid) REDISMODULE_ATTR; +REDISMODULE_API float (*RedisModule_GetUsedMemoryRatio)(void) REDISMODULE_ATTR; +REDISMODULE_API size_t (*RedisModule_MallocSize)(void* ptr) REDISMODULE_ATTR; +REDISMODULE_API size_t (*RedisModule_MallocUsableSize)(void *ptr) REDISMODULE_ATTR; +REDISMODULE_API size_t (*RedisModule_MallocSizeString)(RedisModuleString* str) REDISMODULE_ATTR; +REDISMODULE_API size_t (*RedisModule_MallocSizeDict)(RedisModuleDict* dict) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleUser * (*RedisModule_CreateModuleUser)(const char *name) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_FreeModuleUser)(RedisModuleUser *user) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SetContextUser)(RedisModuleCtx *ctx, const RedisModuleUser *user) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetModuleUserACL)(RedisModuleUser *user, const char* acl) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetModuleUserACLString)(RedisModuleCtx * ctx, RedisModuleUser *user, const char* acl, RedisModuleString **error) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_GetModuleUserACLString)(RedisModuleUser *user) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_GetCurrentUserName)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleUser * (*RedisModule_GetModuleUserFromUserName)(RedisModuleString *name) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ACLCheckCommandPermissions)(RedisModuleUser *user, RedisModuleString **argv, int argc) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ACLCheckKeyPermissions)(RedisModuleUser *user, RedisModuleString *key, int flags) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_ACLCheckChannelPermissions)(RedisModuleUser *user, RedisModuleString *ch, int literal) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ACLAddLogEntry)(RedisModuleCtx *ctx, RedisModuleUser *user, RedisModuleString *object, RedisModuleACLLogEntryReason reason) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_ACLAddLogEntryByUserName)(RedisModuleCtx *ctx, RedisModuleString *user, RedisModuleString *object, RedisModuleACLLogEntryReason reason) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_AuthenticateClientWithACLUser)(RedisModuleCtx *ctx, const char *name, size_t len, RedisModuleUserChangedFunc callback, void *privdata, uint64_t *client_id) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_AuthenticateClientWithUser)(RedisModuleCtx *ctx, RedisModuleUser *user, RedisModuleUserChangedFunc callback, void *privdata, uint64_t *client_id) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DeauthenticateAndCloseClient)(RedisModuleCtx *ctx, uint64_t client_id) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RedactClientCommandArgument)(RedisModuleCtx *ctx, int pos) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_GetClientCertificate)(RedisModuleCtx *ctx, uint64_t id) REDISMODULE_ATTR; +REDISMODULE_API int *(*RedisModule_GetCommandKeys)(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, int *num_keys) REDISMODULE_ATTR; +REDISMODULE_API int *(*RedisModule_GetCommandKeysWithFlags)(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, int *num_keys, int **out_flags) REDISMODULE_ATTR; +REDISMODULE_API const char *(*RedisModule_GetCurrentCommandName)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RegisterDefragFunc)(RedisModuleCtx *ctx, RedisModuleDefragFunc func) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RegisterDefragCallbacks)(RedisModuleCtx *ctx, RedisModuleDefragFunc start, RedisModuleDefragFunc end) REDISMODULE_ATTR; +REDISMODULE_API void *(*RedisModule_DefragAlloc)(RedisModuleDefragCtx *ctx, void *ptr) REDISMODULE_ATTR; +REDISMODULE_API void *(*RedisModule_DefragAllocRaw)(RedisModuleDefragCtx *ctx, size_t size) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_DefragFreeRaw)(RedisModuleDefragCtx *ctx, void *ptr) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString *(*RedisModule_DefragRedisModuleString)(RedisModuleDefragCtx *ctx, RedisModuleString *str) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DefragShouldStop)(RedisModuleDefragCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DefragCursorSet)(RedisModuleDefragCtx *ctx, unsigned long cursor) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_DefragCursorGet)(RedisModuleDefragCtx *ctx, unsigned long *cursor) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetDbIdFromDefragCtx)(RedisModuleDefragCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API const RedisModuleString * (*RedisModule_GetKeyNameFromDefragCtx)(RedisModuleDefragCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_EventLoopAdd)(int fd, int mask, RedisModuleEventLoopFunc func, void *user_data) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_EventLoopDel)(int fd, int mask) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_EventLoopAddOneShot)(RedisModuleEventLoopOneShotFunc func, void *user_data) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RegisterBoolConfig)(RedisModuleCtx *ctx, const char *name, int default_val, unsigned int flags, RedisModuleConfigGetBoolFunc getfn, RedisModuleConfigSetBoolFunc setfn, RedisModuleConfigApplyFunc applyfn, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RegisterNumericConfig)(RedisModuleCtx *ctx, const char *name, long long default_val, unsigned int flags, long long min, long long max, RedisModuleConfigGetNumericFunc getfn, RedisModuleConfigSetNumericFunc setfn, RedisModuleConfigApplyFunc applyfn, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RegisterStringConfig)(RedisModuleCtx *ctx, const char *name, const char *default_val, unsigned int flags, RedisModuleConfigGetStringFunc getfn, RedisModuleConfigSetStringFunc setfn, RedisModuleConfigApplyFunc applyfn, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RegisterEnumConfig)(RedisModuleCtx *ctx, const char *name, int default_val, unsigned int flags, const char **enum_values, const int *int_values, int num_enum_vals, RedisModuleConfigGetEnumFunc getfn, RedisModuleConfigSetEnumFunc setfn, RedisModuleConfigApplyFunc applyfn, void *privdata) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_LoadConfigs)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleRdbStream *(*RedisModule_RdbStreamCreateFromFile)(const char *filename) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_RdbStreamFree)(RedisModuleRdbStream *stream) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RdbLoad)(RedisModuleCtx *ctx, RedisModuleRdbStream *stream, int flags) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_RdbSave)(RedisModuleCtx *ctx, RedisModuleRdbStream *stream, int flags) REDISMODULE_ATTR; + +#define RedisModule_IsAOFClient(id) ((id) == UINT64_MAX) + +/* This is included inline inside each Redis module. */ +static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int apiver) REDISMODULE_ATTR_UNUSED; +static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int apiver) { + void *getapifuncptr = ((void**)ctx)[0]; + RedisModule_GetApi = (int (*)(const char *, void *)) (unsigned long)getapifuncptr; + REDISMODULE_GET_API(Alloc); + REDISMODULE_GET_API(TryAlloc); + REDISMODULE_GET_API(Calloc); + REDISMODULE_GET_API(TryCalloc); + REDISMODULE_GET_API(Free); + REDISMODULE_GET_API(Realloc); + REDISMODULE_GET_API(TryRealloc); + REDISMODULE_GET_API(Strdup); + REDISMODULE_GET_API(CreateCommand); + REDISMODULE_GET_API(GetCommand); + REDISMODULE_GET_API(CreateSubcommand); + REDISMODULE_GET_API(SetCommandInfo); + REDISMODULE_GET_API(SetCommandACLCategories); + REDISMODULE_GET_API(AddACLCategory); + REDISMODULE_GET_API(SetModuleAttribs); + REDISMODULE_GET_API(IsModuleNameBusy); + REDISMODULE_GET_API(WrongArity); + REDISMODULE_GET_API(ReplyWithLongLong); + REDISMODULE_GET_API(ReplyWithError); + REDISMODULE_GET_API(ReplyWithErrorFormat); + REDISMODULE_GET_API(ReplyWithSimpleString); + REDISMODULE_GET_API(ReplyWithArray); + REDISMODULE_GET_API(ReplyWithMap); + REDISMODULE_GET_API(ReplyWithSet); + REDISMODULE_GET_API(ReplyWithAttribute); + REDISMODULE_GET_API(ReplyWithNullArray); + REDISMODULE_GET_API(ReplyWithEmptyArray); + REDISMODULE_GET_API(ReplySetArrayLength); + REDISMODULE_GET_API(ReplySetMapLength); + REDISMODULE_GET_API(ReplySetSetLength); + REDISMODULE_GET_API(ReplySetAttributeLength); + REDISMODULE_GET_API(ReplySetPushLength); + REDISMODULE_GET_API(ReplyWithStringBuffer); + REDISMODULE_GET_API(ReplyWithCString); + REDISMODULE_GET_API(ReplyWithString); + REDISMODULE_GET_API(ReplyWithEmptyString); + REDISMODULE_GET_API(ReplyWithVerbatimString); + REDISMODULE_GET_API(ReplyWithVerbatimStringType); + REDISMODULE_GET_API(ReplyWithNull); + REDISMODULE_GET_API(ReplyWithBool); + REDISMODULE_GET_API(ReplyWithCallReply); + REDISMODULE_GET_API(ReplyWithDouble); + REDISMODULE_GET_API(ReplyWithBigNumber); + REDISMODULE_GET_API(ReplyWithLongDouble); + REDISMODULE_GET_API(GetSelectedDb); + REDISMODULE_GET_API(SelectDb); + REDISMODULE_GET_API(KeyExists); + REDISMODULE_GET_API(OpenKey); + REDISMODULE_GET_API(GetOpenKeyModesAll); + REDISMODULE_GET_API(CloseKey); + REDISMODULE_GET_API(KeyType); + REDISMODULE_GET_API(ValueLength); + REDISMODULE_GET_API(ListPush); + REDISMODULE_GET_API(ListPop); + REDISMODULE_GET_API(ListGet); + REDISMODULE_GET_API(ListSet); + REDISMODULE_GET_API(ListInsert); + REDISMODULE_GET_API(ListDelete); + REDISMODULE_GET_API(StringToLongLong); + REDISMODULE_GET_API(StringToULongLong); + REDISMODULE_GET_API(StringToDouble); + REDISMODULE_GET_API(StringToLongDouble); + REDISMODULE_GET_API(StringToStreamID); + REDISMODULE_GET_API(Call); + REDISMODULE_GET_API(CallReplyProto); + REDISMODULE_GET_API(FreeCallReply); + REDISMODULE_GET_API(CallReplyInteger); + REDISMODULE_GET_API(CallReplyDouble); + REDISMODULE_GET_API(CallReplyBool); + REDISMODULE_GET_API(CallReplyBigNumber); + REDISMODULE_GET_API(CallReplyVerbatim); + REDISMODULE_GET_API(CallReplySetElement); + REDISMODULE_GET_API(CallReplyMapElement); + REDISMODULE_GET_API(CallReplyAttributeElement); + REDISMODULE_GET_API(CallReplyPromiseSetUnblockHandler); + REDISMODULE_GET_API(CallReplyPromiseAbort); + REDISMODULE_GET_API(CallReplyAttribute); + REDISMODULE_GET_API(CallReplyType); + REDISMODULE_GET_API(CallReplyLength); + REDISMODULE_GET_API(CallReplyArrayElement); + REDISMODULE_GET_API(CallReplyStringPtr); + REDISMODULE_GET_API(CreateStringFromCallReply); + REDISMODULE_GET_API(CreateString); + REDISMODULE_GET_API(CreateStringFromLongLong); + REDISMODULE_GET_API(CreateStringFromULongLong); + REDISMODULE_GET_API(CreateStringFromDouble); + REDISMODULE_GET_API(CreateStringFromLongDouble); + REDISMODULE_GET_API(CreateStringFromString); + REDISMODULE_GET_API(CreateStringFromStreamID); + REDISMODULE_GET_API(CreateStringPrintf); + REDISMODULE_GET_API(FreeString); + REDISMODULE_GET_API(StringPtrLen); + REDISMODULE_GET_API(AutoMemory); + REDISMODULE_GET_API(Replicate); + REDISMODULE_GET_API(ReplicateVerbatim); + REDISMODULE_GET_API(DeleteKey); + REDISMODULE_GET_API(UnlinkKey); + REDISMODULE_GET_API(StringSet); + REDISMODULE_GET_API(StringDMA); + REDISMODULE_GET_API(StringTruncate); + REDISMODULE_GET_API(GetExpire); + REDISMODULE_GET_API(SetExpire); + REDISMODULE_GET_API(GetAbsExpire); + REDISMODULE_GET_API(SetAbsExpire); + REDISMODULE_GET_API(ResetDataset); + REDISMODULE_GET_API(DbSize); + REDISMODULE_GET_API(RandomKey); + REDISMODULE_GET_API(ZsetAdd); + REDISMODULE_GET_API(ZsetIncrby); + REDISMODULE_GET_API(ZsetScore); + REDISMODULE_GET_API(ZsetRem); + REDISMODULE_GET_API(ZsetRangeStop); + REDISMODULE_GET_API(ZsetFirstInScoreRange); + REDISMODULE_GET_API(ZsetLastInScoreRange); + REDISMODULE_GET_API(ZsetFirstInLexRange); + REDISMODULE_GET_API(ZsetLastInLexRange); + REDISMODULE_GET_API(ZsetRangeCurrentElement); + REDISMODULE_GET_API(ZsetRangeNext); + REDISMODULE_GET_API(ZsetRangePrev); + REDISMODULE_GET_API(ZsetRangeEndReached); + REDISMODULE_GET_API(HashSet); + REDISMODULE_GET_API(HashGet); + REDISMODULE_GET_API(StreamAdd); + REDISMODULE_GET_API(StreamDelete); + REDISMODULE_GET_API(StreamIteratorStart); + REDISMODULE_GET_API(StreamIteratorStop); + REDISMODULE_GET_API(StreamIteratorNextID); + REDISMODULE_GET_API(StreamIteratorNextField); + REDISMODULE_GET_API(StreamIteratorDelete); + REDISMODULE_GET_API(StreamTrimByLength); + REDISMODULE_GET_API(StreamTrimByID); + REDISMODULE_GET_API(IsKeysPositionRequest); + REDISMODULE_GET_API(KeyAtPos); + REDISMODULE_GET_API(KeyAtPosWithFlags); + REDISMODULE_GET_API(IsChannelsPositionRequest); + REDISMODULE_GET_API(ChannelAtPosWithFlags); + REDISMODULE_GET_API(GetClientId); + REDISMODULE_GET_API(GetClientUserNameById); + REDISMODULE_GET_API(GetContextFlags); + REDISMODULE_GET_API(AvoidReplicaTraffic); + REDISMODULE_GET_API(PoolAlloc); + REDISMODULE_GET_API(CreateDataType); + REDISMODULE_GET_API(ModuleTypeSetValue); + REDISMODULE_GET_API(ModuleTypeReplaceValue); + REDISMODULE_GET_API(ModuleTypeGetType); + REDISMODULE_GET_API(ModuleTypeGetValue); + REDISMODULE_GET_API(IsIOError); + REDISMODULE_GET_API(SetModuleOptions); + REDISMODULE_GET_API(SignalModifiedKey); + REDISMODULE_GET_API(SaveUnsigned); + REDISMODULE_GET_API(LoadUnsigned); + REDISMODULE_GET_API(SaveSigned); + REDISMODULE_GET_API(LoadSigned); + REDISMODULE_GET_API(SaveString); + REDISMODULE_GET_API(SaveStringBuffer); + REDISMODULE_GET_API(LoadString); + REDISMODULE_GET_API(LoadStringBuffer); + REDISMODULE_GET_API(SaveDouble); + REDISMODULE_GET_API(LoadDouble); + REDISMODULE_GET_API(SaveFloat); + REDISMODULE_GET_API(LoadFloat); + REDISMODULE_GET_API(SaveLongDouble); + REDISMODULE_GET_API(LoadLongDouble); + REDISMODULE_GET_API(SaveDataTypeToString); + REDISMODULE_GET_API(LoadDataTypeFromString); + REDISMODULE_GET_API(LoadDataTypeFromStringEncver); + REDISMODULE_GET_API(EmitAOF); + REDISMODULE_GET_API(Log); + REDISMODULE_GET_API(LogIOError); + REDISMODULE_GET_API(_Assert); + REDISMODULE_GET_API(LatencyAddSample); + REDISMODULE_GET_API(StringAppendBuffer); + REDISMODULE_GET_API(TrimStringAllocation); + REDISMODULE_GET_API(RetainString); + REDISMODULE_GET_API(HoldString); + REDISMODULE_GET_API(StringCompare); + REDISMODULE_GET_API(GetContextFromIO); + REDISMODULE_GET_API(GetKeyNameFromIO); + REDISMODULE_GET_API(GetKeyNameFromModuleKey); + REDISMODULE_GET_API(GetDbIdFromModuleKey); + REDISMODULE_GET_API(GetDbIdFromIO); + REDISMODULE_GET_API(GetKeyNameFromOptCtx); + REDISMODULE_GET_API(GetToKeyNameFromOptCtx); + REDISMODULE_GET_API(GetDbIdFromOptCtx); + REDISMODULE_GET_API(GetToDbIdFromOptCtx); + REDISMODULE_GET_API(Milliseconds); + REDISMODULE_GET_API(MonotonicMicroseconds); + REDISMODULE_GET_API(Microseconds); + REDISMODULE_GET_API(CachedMicroseconds); + REDISMODULE_GET_API(DigestAddStringBuffer); + REDISMODULE_GET_API(DigestAddLongLong); + REDISMODULE_GET_API(DigestEndSequence); + REDISMODULE_GET_API(GetKeyNameFromDigest); + REDISMODULE_GET_API(GetDbIdFromDigest); + REDISMODULE_GET_API(CreateDict); + REDISMODULE_GET_API(FreeDict); + REDISMODULE_GET_API(DictSize); + REDISMODULE_GET_API(DictSetC); + REDISMODULE_GET_API(DictReplaceC); + REDISMODULE_GET_API(DictSet); + REDISMODULE_GET_API(DictReplace); + REDISMODULE_GET_API(DictGetC); + REDISMODULE_GET_API(DictGet); + REDISMODULE_GET_API(DictDelC); + REDISMODULE_GET_API(DictDel); + REDISMODULE_GET_API(DictIteratorStartC); + REDISMODULE_GET_API(DictIteratorStart); + REDISMODULE_GET_API(DictIteratorStop); + REDISMODULE_GET_API(DictIteratorReseekC); + REDISMODULE_GET_API(DictIteratorReseek); + REDISMODULE_GET_API(DictNextC); + REDISMODULE_GET_API(DictPrevC); + REDISMODULE_GET_API(DictNext); + REDISMODULE_GET_API(DictPrev); + REDISMODULE_GET_API(DictCompare); + REDISMODULE_GET_API(DictCompareC); + REDISMODULE_GET_API(RegisterInfoFunc); + REDISMODULE_GET_API(RegisterAuthCallback); + REDISMODULE_GET_API(InfoAddSection); + REDISMODULE_GET_API(InfoBeginDictField); + REDISMODULE_GET_API(InfoEndDictField); + REDISMODULE_GET_API(InfoAddFieldString); + REDISMODULE_GET_API(InfoAddFieldCString); + REDISMODULE_GET_API(InfoAddFieldDouble); + REDISMODULE_GET_API(InfoAddFieldLongLong); + REDISMODULE_GET_API(InfoAddFieldULongLong); + REDISMODULE_GET_API(GetServerInfo); + REDISMODULE_GET_API(FreeServerInfo); + REDISMODULE_GET_API(ServerInfoGetField); + REDISMODULE_GET_API(ServerInfoGetFieldC); + REDISMODULE_GET_API(ServerInfoGetFieldSigned); + REDISMODULE_GET_API(ServerInfoGetFieldUnsigned); + REDISMODULE_GET_API(ServerInfoGetFieldDouble); + REDISMODULE_GET_API(GetClientInfoById); + REDISMODULE_GET_API(GetClientNameById); + REDISMODULE_GET_API(SetClientNameById); + REDISMODULE_GET_API(PublishMessage); + REDISMODULE_GET_API(PublishMessageShard); + REDISMODULE_GET_API(SubscribeToServerEvent); + REDISMODULE_GET_API(SetLRU); + REDISMODULE_GET_API(GetLRU); + REDISMODULE_GET_API(SetLFU); + REDISMODULE_GET_API(GetLFU); + REDISMODULE_GET_API(BlockClientOnKeys); + REDISMODULE_GET_API(BlockClientOnKeysWithFlags); + REDISMODULE_GET_API(SignalKeyAsReady); + REDISMODULE_GET_API(GetBlockedClientReadyKey); + REDISMODULE_GET_API(ScanCursorCreate); + REDISMODULE_GET_API(ScanCursorRestart); + REDISMODULE_GET_API(ScanCursorDestroy); + REDISMODULE_GET_API(Scan); + REDISMODULE_GET_API(ScanKey); + REDISMODULE_GET_API(GetContextFlagsAll); + REDISMODULE_GET_API(GetModuleOptionsAll); + REDISMODULE_GET_API(GetKeyspaceNotificationFlagsAll); + REDISMODULE_GET_API(IsSubEventSupported); + REDISMODULE_GET_API(GetServerVersion); + REDISMODULE_GET_API(GetTypeMethodVersion); + REDISMODULE_GET_API(Yield); + REDISMODULE_GET_API(GetThreadSafeContext); + REDISMODULE_GET_API(GetDetachedThreadSafeContext); + REDISMODULE_GET_API(FreeThreadSafeContext); + REDISMODULE_GET_API(ThreadSafeContextLock); + REDISMODULE_GET_API(ThreadSafeContextTryLock); + REDISMODULE_GET_API(ThreadSafeContextUnlock); + REDISMODULE_GET_API(BlockClient); + REDISMODULE_GET_API(BlockClientGetPrivateData); + REDISMODULE_GET_API(BlockClientSetPrivateData); + REDISMODULE_GET_API(BlockClientOnAuth); + REDISMODULE_GET_API(UnblockClient); + REDISMODULE_GET_API(IsBlockedReplyRequest); + REDISMODULE_GET_API(IsBlockedTimeoutRequest); + REDISMODULE_GET_API(GetBlockedClientPrivateData); + REDISMODULE_GET_API(GetBlockedClientHandle); + REDISMODULE_GET_API(AbortBlock); + REDISMODULE_GET_API(BlockedClientMeasureTimeStart); + REDISMODULE_GET_API(BlockedClientMeasureTimeEnd); + REDISMODULE_GET_API(SetDisconnectCallback); + REDISMODULE_GET_API(SubscribeToKeyspaceEvents); + REDISMODULE_GET_API(AddPostNotificationJob); + REDISMODULE_GET_API(NotifyKeyspaceEvent); + REDISMODULE_GET_API(GetNotifyKeyspaceEvents); + REDISMODULE_GET_API(BlockedClientDisconnected); + REDISMODULE_GET_API(RegisterClusterMessageReceiver); + REDISMODULE_GET_API(SendClusterMessage); + REDISMODULE_GET_API(GetClusterNodeInfo); + REDISMODULE_GET_API(GetClusterNodesList); + REDISMODULE_GET_API(FreeClusterNodesList); + REDISMODULE_GET_API(CreateTimer); + REDISMODULE_GET_API(StopTimer); + REDISMODULE_GET_API(GetTimerInfo); + REDISMODULE_GET_API(GetMyClusterID); + REDISMODULE_GET_API(GetClusterSize); + REDISMODULE_GET_API(GetRandomBytes); + REDISMODULE_GET_API(GetRandomHexChars); + REDISMODULE_GET_API(SetClusterFlags); + REDISMODULE_GET_API(ClusterKeySlot); + REDISMODULE_GET_API(ClusterCanonicalKeyNameInSlot); + REDISMODULE_GET_API(ExportSharedAPI); + REDISMODULE_GET_API(GetSharedAPI); + REDISMODULE_GET_API(RegisterCommandFilter); + REDISMODULE_GET_API(UnregisterCommandFilter); + REDISMODULE_GET_API(CommandFilterArgsCount); + REDISMODULE_GET_API(CommandFilterArgGet); + REDISMODULE_GET_API(CommandFilterArgInsert); + REDISMODULE_GET_API(CommandFilterArgReplace); + REDISMODULE_GET_API(CommandFilterArgDelete); + REDISMODULE_GET_API(CommandFilterGetClientId); + REDISMODULE_GET_API(Fork); + REDISMODULE_GET_API(SendChildHeartbeat); + REDISMODULE_GET_API(ExitFromChild); + REDISMODULE_GET_API(KillForkChild); + REDISMODULE_GET_API(GetUsedMemoryRatio); + REDISMODULE_GET_API(MallocSize); + REDISMODULE_GET_API(MallocUsableSize); + REDISMODULE_GET_API(MallocSizeString); + REDISMODULE_GET_API(MallocSizeDict); + REDISMODULE_GET_API(CreateModuleUser); + REDISMODULE_GET_API(FreeModuleUser); + REDISMODULE_GET_API(SetContextUser); + REDISMODULE_GET_API(SetModuleUserACL); + REDISMODULE_GET_API(SetModuleUserACLString); + REDISMODULE_GET_API(GetModuleUserACLString); + REDISMODULE_GET_API(GetCurrentUserName); + REDISMODULE_GET_API(GetModuleUserFromUserName); + REDISMODULE_GET_API(ACLCheckCommandPermissions); + REDISMODULE_GET_API(ACLCheckKeyPermissions); + REDISMODULE_GET_API(ACLCheckChannelPermissions); + REDISMODULE_GET_API(ACLAddLogEntry); + REDISMODULE_GET_API(ACLAddLogEntryByUserName); + REDISMODULE_GET_API(DeauthenticateAndCloseClient); + REDISMODULE_GET_API(AuthenticateClientWithACLUser); + REDISMODULE_GET_API(AuthenticateClientWithUser); + REDISMODULE_GET_API(RedactClientCommandArgument); + REDISMODULE_GET_API(GetClientCertificate); + REDISMODULE_GET_API(GetCommandKeys); + REDISMODULE_GET_API(GetCommandKeysWithFlags); + REDISMODULE_GET_API(GetCurrentCommandName); + REDISMODULE_GET_API(RegisterDefragFunc); + REDISMODULE_GET_API(RegisterDefragCallbacks); + REDISMODULE_GET_API(DefragAlloc); + REDISMODULE_GET_API(DefragAllocRaw); + REDISMODULE_GET_API(DefragFreeRaw); + REDISMODULE_GET_API(DefragRedisModuleString); + REDISMODULE_GET_API(DefragShouldStop); + REDISMODULE_GET_API(DefragCursorSet); + REDISMODULE_GET_API(DefragCursorGet); + REDISMODULE_GET_API(GetKeyNameFromDefragCtx); + REDISMODULE_GET_API(GetDbIdFromDefragCtx); + REDISMODULE_GET_API(EventLoopAdd); + REDISMODULE_GET_API(EventLoopDel); + REDISMODULE_GET_API(EventLoopAddOneShot); + REDISMODULE_GET_API(RegisterBoolConfig); + REDISMODULE_GET_API(RegisterNumericConfig); + REDISMODULE_GET_API(RegisterStringConfig); + REDISMODULE_GET_API(RegisterEnumConfig); + REDISMODULE_GET_API(LoadConfigs); + REDISMODULE_GET_API(RdbStreamCreateFromFile); + REDISMODULE_GET_API(RdbStreamFree); + REDISMODULE_GET_API(RdbLoad); + REDISMODULE_GET_API(RdbSave); + + if (RedisModule_IsModuleNameBusy && RedisModule_IsModuleNameBusy(name)) return REDISMODULE_ERR; + RedisModule_SetModuleAttribs(ctx,name,ver,apiver); + return REDISMODULE_OK; +} + +#define RedisModule_Assert(_e) ((_e)?(void)0 : (RedisModule__Assert(#_e,__FILE__,__LINE__),exit(1))) + +#define RMAPI_FUNC_SUPPORTED(func) (func != NULL) + +#endif /* REDISMODULE_CORE */ +#endif /* REDISMODULE_H */ diff --git a/test.py b/test.py new file mode 100755 index 000000000..afc5d7fc7 --- /dev/null +++ b/test.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +# +# Vector set tests. +# A Redis instance should be running in the default port. +# Copyright(C) 2024-2025 Salvatore Sanfilippo. +# All Rights Reserved. + +#!/usr/bin/env python3 +import redis +import random +import struct +import math +import time +import sys +import os +import importlib +import inspect +from typing import List, Tuple, Optional +from dataclasses import dataclass + +def colored(text: str, color: str) -> str: + colors = { + 'red': '\033[91m', + 'green': '\033[92m' + } + reset = '\033[0m' + return f"{colors.get(color, '')}{text}{reset}" + +@dataclass +class VectorData: + vectors: List[List[float]] + names: List[str] + + def find_k_nearest(self, query_vector: List[float], k: int) -> List[Tuple[str, float]]: + """Find k-nearest neighbors using the same scoring as Redis VSIM WITHSCORES.""" + similarities = [] + query_norm = math.sqrt(sum(x*x for x in query_vector)) + if query_norm == 0: + return [] + + for i, vec in enumerate(self.vectors): + vec_norm = math.sqrt(sum(x*x for x in vec)) + if vec_norm == 0: + continue + + dot_product = sum(a*b for a,b in zip(query_vector, vec)) + cosine_sim = dot_product / (query_norm * vec_norm) + distance = 1.0 - cosine_sim + redis_similarity = 1.0 - (distance/2.0) + similarities.append((self.names[i], redis_similarity)) + + similarities.sort(key=lambda x: x[1], reverse=True) + return similarities[:k] + +def generate_random_vector(dim: int) -> List[float]: + """Generate a random normalized vector.""" + vec = [random.gauss(0, 1) for _ in range(dim)] + norm = math.sqrt(sum(x*x for x in vec)) + return [x/norm for x in vec] + +def fill_redis_with_vectors(r: redis.Redis, key: str, count: int, dim: int, + with_reduce: Optional[int] = None) -> VectorData: + """Fill Redis with random vectors and return a VectorData object for verification.""" + vectors = [] + names = [] + + r.delete(key) + for i in range(count): + vec = generate_random_vector(dim) + name = f"{key}:item:{i}" + vectors.append(vec) + names.append(name) + + vec_bytes = struct.pack(f'{dim}f', *vec) + args = [key] + if with_reduce: + args.extend(['REDUCE', with_reduce]) + args.extend(['FP32', vec_bytes, name]) + r.execute_command('VADD', *args) + + return VectorData(vectors=vectors, names=names) + +class TestCase: + def __init__(self): + self.error_msg = None + self.error_details = None + self.test_key = f"test:{self.__class__.__name__.lower()}" + self.redis = redis.Redis() + + def setup(self): + self.redis.delete(self.test_key) + + def teardown(self): + self.redis.delete(self.test_key) + + def test(self): + raise NotImplementedError("Subclasses must implement test method") + + def run(self): + try: + self.setup() + self.test() + return True + except AssertionError as e: + self.error_msg = str(e) + import traceback + self.error_details = traceback.format_exc() + return False + except Exception as e: + self.error_msg = f"Unexpected error: {str(e)}" + import traceback + self.error_details = traceback.format_exc() + return False + finally: + self.teardown() + + def getname(self): + """Each test class should override this to provide its name""" + return self.__class__.__name__ + + def estimated_runtime(self): + """"Each test class should override this if it takes a significant amount of time to run. Default is 100ms""" + return 0.1 + +def find_test_classes(): + test_classes = [] + tests_dir = 'tests' + + if not os.path.exists(tests_dir): + return [] + + for file in os.listdir(tests_dir): + if file.endswith('.py'): + module_name = f"tests.{file[:-3]}" + try: + module = importlib.import_module(module_name) + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and obj.__name__ != 'TestCase' and hasattr(obj, 'test'): + test_classes.append(obj()) + except Exception as e: + print(f"Error loading {file}: {e}") + + return test_classes + +def run_tests(): + print("================================================\n"+ + "Make sure to have Redis running in the localhost\n"+ + "with --enable-debug-command yes\n"+ + "================================================\n") + + tests = find_test_classes() + if not tests: + print("No tests found!") + return + + # Sort tests by estimated runtime + tests.sort(key=lambda t: t.estimated_runtime()) + + passed = 0 + total = len(tests) + + for test in tests: + print(f"{test.getname()}: ", end="") + sys.stdout.flush() + + start_time = time.time() + success = test.run() + duration = time.time() - start_time + + if success: + print(colored("OK", "green"), f"({duration:.2f}s)") + passed += 1 + else: + print(colored("ERR", "red"), f"({duration:.2f}s)") + print(f"Error: {test.error_msg}") + if test.error_details: + print("\nTraceback:") + print(test.error_details) + + print("\n" + "="*50) + print(f"\nTest Summary: {passed}/{total} tests passed") + + if passed == total: + print(colored("\nALL TESTS PASSED!", "green")) + else: + print(colored(f"\n{total-passed} TESTS FAILED!", "red")) + +if __name__ == "__main__": + run_tests() diff --git a/tests/basic_commands.py b/tests/basic_commands.py new file mode 100644 index 000000000..8481a3668 --- /dev/null +++ b/tests/basic_commands.py @@ -0,0 +1,21 @@ +from test import TestCase, generate_random_vector +import struct + +class BasicCommands(TestCase): + def getname(self): + return "VADD, VDIM, VCARD basic usage" + + def test(self): + # Test VADD + vec = generate_random_vector(4) + vec_bytes = struct.pack('4f', *vec) + result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1') + assert result == 1, "VADD should return 1 for first item" + + # Test VDIM + dim = self.redis.execute_command('VDIM', self.test_key) + assert dim == 4, f"VDIM should return 4, got {dim}" + + # Test VCARD + card = self.redis.execute_command('VCARD', self.test_key) + assert card == 1, f"VCARD should return 1, got {card}" diff --git a/tests/basic_similarity.py b/tests/basic_similarity.py new file mode 100644 index 000000000..11c3c9b17 --- /dev/null +++ b/tests/basic_similarity.py @@ -0,0 +1,35 @@ +from test import TestCase + +class BasicSimilarity(TestCase): + def getname(self): + return "VSIM reported distance makes sense with 4D vectors" + + def test(self): + # Add two very similar vectors, one different + vec1 = [1, 0, 0, 0] + vec2 = [0.99, 0.01, 0, 0] + vec3 = [0.1, 1, -1, 0.5] + + # Add vectors using VALUES format + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], f'{self.test_key}:item:1') + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec2], f'{self.test_key}:item:2') + self.redis.execute_command('VADD', self.test_key, 'VALUES', 4, + *[str(x) for x in vec3], f'{self.test_key}:item:3') + + # Query similarity with vec1 + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4, + *[str(x) for x in vec1], 'WITHSCORES') + + # Convert results to dictionary + results_dict = {} + for i in range(0, len(result), 2): + key = result[i].decode() + score = float(result[i+1]) + results_dict[key] = score + + # Verify results + assert results_dict[f'{self.test_key}:item:1'] > 0.99, "Self-similarity should be very high" + assert results_dict[f'{self.test_key}:item:2'] > 0.99, "Similar vector should have high similarity" + assert results_dict[f'{self.test_key}:item:3'] < 0.8, "Not very similar vector should have low similarity" diff --git a/tests/concurrent_vsim_and_del.py b/tests/concurrent_vsim_and_del.py new file mode 100644 index 000000000..9bbf01116 --- /dev/null +++ b/tests/concurrent_vsim_and_del.py @@ -0,0 +1,48 @@ +from test import TestCase, fill_redis_with_vectors, generate_random_vector +import threading, time + +class ConcurrentVSIMAndDEL(TestCase): + def getname(self): + return "Concurrent VSIM and DEL operations" + + def estimated_runtime(self): + return 2 + + def test(self): + # Fill the key with 5000 random vectors + dim = 128 + count = 5000 + fill_redis_with_vectors(self.redis, self.test_key, count, dim) + + # List to store results from threads + thread_results = [] + + def vsim_thread(): + """Thread function to perform VSIM operations until the key is deleted""" + while True: + query_vec = generate_random_vector(dim) + result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in query_vec], 'COUNT', 10) + if not result: + # Empty array detected, key is deleted + thread_results.append(True) + break + + # Start multiple threads to perform VSIM operations + threads = [] + for _ in range(4): # Start 4 threads + t = threading.Thread(target=vsim_thread) + t.start() + threads.append(t) + + # Delete the key while threads are still running + time.sleep(1) + self.redis.delete(self.test_key) + + # Wait for all threads to finish (they will exit once they detect the key is deleted) + for t in threads: + t.join() + + # Verify that all threads detected an empty array or error + assert len(thread_results) == len(threads), "Not all threads detected the key deletion" + assert all(thread_results), "Some threads did not detect an empty array or error after DEL" diff --git a/tests/deletion.py b/tests/deletion.py new file mode 100644 index 000000000..cb919591b --- /dev/null +++ b/tests/deletion.py @@ -0,0 +1,173 @@ +from test import TestCase, fill_redis_with_vectors, generate_random_vector +import random + +""" +A note about this test: +It was experimentally tried to modify hnsw.c in order to +avoid calling hnsw_reconnect_nodes(). In this case, the test +fails very often with EF set to 250, while it hardly +fails at all with the same parameters if hnsw_reconnect_nodes() +is called. + +Note that for the nature of the test (it is very strict) it can +still fail from time to time, without this signaling any +actual bug. +""" + +class VREM(TestCase): + def getname(self): + return "Deletion and graph state after deletion" + + def estimated_runtime(self): + return 2.0 + + def format_neighbors_with_scores(self, links_result, old_links=None, items_to_remove=None): + """Format neighbors with their similarity scores and status indicators""" + if not links_result: + return "No neighbors" + + output = [] + for level, neighbors in enumerate(links_result): + level_num = len(links_result) - level - 1 + output.append(f"Level {level_num}:") + + # Get neighbors and scores + neighbors_with_scores = [] + for i in range(0, len(neighbors), 2): + neighbor = neighbors[i].decode() if isinstance(neighbors[i], bytes) else neighbors[i] + score = float(neighbors[i+1]) if i+1 < len(neighbors) else None + status = "" + + # For old links, mark deleted ones + if items_to_remove and neighbor in items_to_remove: + status = " [lost]" + # For new links, mark newly added ones + elif old_links is not None: + # Check if this neighbor was in the old links at this level + was_present = False + if old_links and level < len(old_links): + old_neighbors = [n.decode() if isinstance(n, bytes) else n + for n in old_links[level]] + was_present = neighbor in old_neighbors + if not was_present: + status = " [gained]" + + if score is not None: + neighbors_with_scores.append(f"{len(neighbors_with_scores)+1}. {neighbor} ({score:.6f}){status}") + else: + neighbors_with_scores.append(f"{len(neighbors_with_scores)+1}. {neighbor}{status}") + + output.extend([" " + n for n in neighbors_with_scores]) + return "\n".join(output) + + def test(self): + # 1. Fill server with random elements + dim = 128 + count = 5000 + data = fill_redis_with_vectors(self.redis, self.test_key, count, dim) + + # 2. Do VSIM to get 200 items + query_vec = generate_random_vector(dim) + results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in query_vec], + 'COUNT', 200, 'WITHSCORES') + + # Convert results to list of (item, score) pairs, sorted by score + items = [] + for i in range(0, len(results), 2): + item = results[i].decode() + score = float(results[i+1]) + items.append((item, score)) + items.sort(key=lambda x: x[1], reverse=True) # Sort by similarity + + # Store the graph structure for all items before deletion + neighbors_before = {} + for item, _ in items: + links = self.redis.execute_command('VLINKS', self.test_key, item, 'WITHSCORES') + if links: # Some items might not have links + neighbors_before[item] = links + + # 3. Remove 100 random items + items_to_remove = set(item for item, _ in random.sample(items, 100)) + # Keep track of top 10 non-removed items + top_remaining = [] + for item, score in items: + if item not in items_to_remove: + top_remaining.append((item, score)) + if len(top_remaining) == 10: + break + + # Remove the items + for item in items_to_remove: + result = self.redis.execute_command('VREM', self.test_key, item) + assert result == 1, f"VREM failed to remove {item}" + + # 4. Do VSIM again with same vector + new_results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in query_vec], + 'COUNT', 200, 'WITHSCORES', + 'EF', 500) + + # Convert new results to dict of item -> score + new_scores = {} + for i in range(0, len(new_results), 2): + item = new_results[i].decode() + score = float(new_results[i+1]) + new_scores[item] = score + + failure = False + failed_item = None + failed_reason = None + # 5. Verify all top 10 non-removed items are still found with similar scores + for item, old_score in top_remaining: + if item not in new_scores: + failure = True + failed_item = item + failed_reason = "missing" + break + new_score = new_scores[item] + if abs(new_score - old_score) >= 0.01: + failure = True + failed_item = item + failed_reason = f"score changed: {old_score:.6f} -> {new_score:.6f}" + break + + if failure: + print("\nTest failed!") + print(f"Problem with item: {failed_item} ({failed_reason})") + + print("\nOriginal neighbors (with similarity scores):") + if failed_item in neighbors_before: + print(self.format_neighbors_with_scores( + neighbors_before[failed_item], + items_to_remove=items_to_remove)) + else: + print("No neighbors found in original graph") + + print("\nCurrent neighbors (with similarity scores):") + current_links = self.redis.execute_command('VLINKS', self.test_key, + failed_item, 'WITHSCORES') + if current_links: + print(self.format_neighbors_with_scores( + current_links, + old_links=neighbors_before.get(failed_item))) + else: + print("No neighbors in current graph") + + print("\nOriginal results (top 20):") + for item, score in items[:20]: + deleted = "[deleted]" if item in items_to_remove else "" + print(f"{item}: {score:.6f} {deleted}") + + print("\nNew results after removal (top 20):") + new_items = [] + for i in range(0, len(new_results), 2): + item = new_results[i].decode() + score = float(new_results[i+1]) + new_items.append((item, score)) + new_items.sort(key=lambda x: x[1], reverse=True) + for item, score in new_items[:20]: + print(f"{item}: {score:.6f}") + + raise AssertionError(f"Test failed: Problem with item {failed_item} ({failed_reason}). *** IMPORTANT *** This test may fail from time to time without indicating that there is a bug. However normally it should pass. The fact is that it's a quite extreme test where we destroy 50% of nodes of top results and still expect perfect recall, with vectors that are very hostile because of the distribution used.") + diff --git a/tests/evict_empty.py b/tests/evict_empty.py new file mode 100644 index 000000000..6c78c825d --- /dev/null +++ b/tests/evict_empty.py @@ -0,0 +1,27 @@ +from test import TestCase, generate_random_vector +import struct + +class VREM_LastItemDeletesKey(TestCase): + def getname(self): + return "VREM last item deletes key" + + def test(self): + # Generate a random vector + vec = generate_random_vector(4) + vec_bytes = struct.pack('4f', *vec) + + # Add the vector to the key + result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1') + assert result == 1, "VADD should return 1 for first item" + + # Verify the key exists + exists = self.redis.exists(self.test_key) + assert exists == 1, "Key should exist after VADD" + + # Remove the item + result = self.redis.execute_command('VREM', self.test_key, f'{self.test_key}:item:1') + assert result == 1, "VREM should return 1 for successful removal" + + # Verify the key no longer exists + exists = self.redis.exists(self.test_key) + assert exists == 0, "Key should no longer exist after VREM of last item" diff --git a/tests/large_scale.py b/tests/large_scale.py new file mode 100644 index 000000000..eac5dca52 --- /dev/null +++ b/tests/large_scale.py @@ -0,0 +1,56 @@ +from test import TestCase, fill_redis_with_vectors, generate_random_vector +import random + +class LargeScale(TestCase): + def getname(self): + return "Large Scale Comparison" + + def estimated_runtime(self): + return 10 + + def test(self): + dim = 300 + count = 20000 + k = 50 + + # Fill Redis and get reference data for comparison + random.seed(42) # Make test deterministic + data = fill_redis_with_vectors(self.redis, self.test_key, count, dim) + + # Generate query vector + query_vec = generate_random_vector(dim) + + # Get results from Redis with good exploration factor + redis_raw = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in query_vec], + 'COUNT', k, 'WITHSCORES', 'EF', 500) + + # Convert Redis results to dict + redis_results = {} + for i in range(0, len(redis_raw), 2): + key = redis_raw[i].decode() + score = float(redis_raw[i+1]) + redis_results[key] = score + + # Get results from linear scan + linear_results = data.find_k_nearest(query_vec, k) + linear_items = {name: score for name, score in linear_results} + + # Compare overlap + redis_set = set(redis_results.keys()) + linear_set = set(linear_items.keys()) + overlap = len(redis_set & linear_set) + + # If test fails, print comparison for debugging + if overlap < k * 0.7: + data.print_comparison({'items': redis_results, 'query_vector': query_vec}, k) + + assert overlap >= k * 0.7, \ + f"Expected at least 70% overlap in top {k} results, got {overlap/k*100:.1f}%" + + # Verify scores for common items + for item in redis_set & linear_set: + redis_score = redis_results[item] + linear_score = linear_items[item] + assert abs(redis_score - linear_score) < 0.01, \ + f"Score mismatch for {item}: Redis={redis_score:.3f} Linear={linear_score:.3f}" diff --git a/tests/node_update.py b/tests/node_update.py new file mode 100644 index 000000000..53aa2dd56 --- /dev/null +++ b/tests/node_update.py @@ -0,0 +1,85 @@ +from test import TestCase, generate_random_vector +import struct +import math +import random + +class VectorUpdateAndClusters(TestCase): + def getname(self): + return "VADD vector update with cluster relocation" + + def estimated_runtime(self): + return 2.0 # Should take around 2 seconds + + def generate_cluster_vector(self, base_vec, noise=0.1): + """Generate a vector that's similar to base_vec with some noise.""" + vec = [x + random.gauss(0, noise) for x in base_vec] + # Normalize + norm = math.sqrt(sum(x*x for x in vec)) + return [x/norm for x in vec] + + def test(self): + dim = 128 + vectors_per_cluster = 5000 + + # Create two very different base vectors for our clusters + cluster1_base = generate_random_vector(dim) + cluster2_base = [-x for x in cluster1_base] # Opposite direction + + # Add vectors from first cluster + for i in range(vectors_per_cluster): + vec = self.generate_cluster_vector(cluster1_base) + vec_bytes = struct.pack(f'{dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, + f'{self.test_key}:cluster1:{i}') + + # Add vectors from second cluster + for i in range(vectors_per_cluster): + vec = self.generate_cluster_vector(cluster2_base) + vec_bytes = struct.pack(f'{dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, + f'{self.test_key}:cluster2:{i}') + + # Pick a test vector from cluster1 + test_key = f'{self.test_key}:cluster1:0' + + # Verify it's in cluster1 using VSIM + initial_vec = self.generate_cluster_vector(cluster1_base) + results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in initial_vec], + 'COUNT', 100, 'WITHSCORES') + + # Count how many cluster1 items are in top results + cluster1_count = sum(1 for i in range(0, len(results), 2) + if b'cluster1' in results[i]) + assert cluster1_count > 80, "Initial clustering check failed" + + # Now update the test vector to be in cluster2 + new_vec = self.generate_cluster_vector(cluster2_base, noise=0.05) + vec_bytes = struct.pack(f'{dim}f', *new_vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, test_key) + + # Verify the embedding was actually updated using VEMB + emb_result = self.redis.execute_command('VEMB', self.test_key, test_key) + updated_vec = [float(x) for x in emb_result] + + # Verify updated vector matches what we inserted + dot_product = sum(a*b for a,b in zip(updated_vec, new_vec)) + similarity = dot_product / (math.sqrt(sum(x*x for x in updated_vec)) * + math.sqrt(sum(x*x for x in new_vec))) + assert similarity > 0.9, "Vector was not properly updated" + + # Verify it's now in cluster2 using VSIM + results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim, + *[str(x) for x in cluster2_base], + 'COUNT', 100, 'WITHSCORES') + + # Verify our updated vector is among top results + found = False + for i in range(0, len(results), 2): + if results[i].decode() == test_key: + found = True + similarity = float(results[i+1]) + assert similarity > 0.80, f"Updated vector has low similarity: {similarity}" + break + + assert found, "Updated vector not found in cluster2 proximity" diff --git a/tests/persistence.py b/tests/persistence.py new file mode 100644 index 000000000..021c8b6e3 --- /dev/null +++ b/tests/persistence.py @@ -0,0 +1,83 @@ +from test import TestCase, fill_redis_with_vectors, generate_random_vector +import random + +class HNSWPersistence(TestCase): + def getname(self): + return "HNSW Persistence" + + def estimated_runtime(self): + return 30 + + def _verify_results(self, key, dim, query_vec, reduced_dim=None): + """Run a query and return results dict""" + k = 10 + args = ['VSIM', key] + + if reduced_dim: + args.extend(['VALUES', dim]) + args.extend([str(x) for x in query_vec]) + else: + args.extend(['VALUES', dim]) + args.extend([str(x) for x in query_vec]) + + args.extend(['COUNT', k, 'WITHSCORES']) + results = self.redis.execute_command(*args) + + results_dict = {} + for i in range(0, len(results), 2): + key = results[i].decode() + score = float(results[i+1]) + results_dict[key] = score + return results_dict + + def test(self): + # Setup dimensions + dim = 128 + reduced_dim = 32 + count = 5000 + random.seed(42) + + # Create two datasets - one normal and one with dimension reduction + normal_data = fill_redis_with_vectors(self.redis, f"{self.test_key}:normal", count, dim) + projected_data = fill_redis_with_vectors(self.redis, f"{self.test_key}:projected", + count, dim, reduced_dim) + + # Generate query vectors we'll use before and after reload + query_vec_normal = generate_random_vector(dim) + query_vec_projected = generate_random_vector(dim) + + # Get initial results for both sets + initial_normal = self._verify_results(f"{self.test_key}:normal", + dim, query_vec_normal) + initial_projected = self._verify_results(f"{self.test_key}:projected", + dim, query_vec_projected, reduced_dim) + + # Force Redis to save and reload the dataset + self.redis.execute_command('DEBUG', 'RELOAD') + + # Verify results after reload + reloaded_normal = self._verify_results(f"{self.test_key}:normal", + dim, query_vec_normal) + reloaded_projected = self._verify_results(f"{self.test_key}:projected", + dim, query_vec_projected, reduced_dim) + + # Verify normal vectors results + assert len(initial_normal) == len(reloaded_normal), \ + "Normal vectors: Result count mismatch before/after reload" + + for key in initial_normal: + assert key in reloaded_normal, f"Normal vectors: Missing item after reload: {key}" + assert abs(initial_normal[key] - reloaded_normal[key]) < 0.0001, \ + f"Normal vectors: Score mismatch for {key}: " + \ + f"before={initial_normal[key]:.6f}, after={reloaded_normal[key]:.6f}" + + # Verify projected vectors results + assert len(initial_projected) == len(reloaded_projected), \ + "Projected vectors: Result count mismatch before/after reload" + + for key in initial_projected: + assert key in reloaded_projected, \ + f"Projected vectors: Missing item after reload: {key}" + assert abs(initial_projected[key] - reloaded_projected[key]) < 0.0001, \ + f"Projected vectors: Score mismatch for {key}: " + \ + f"before={initial_projected[key]:.6f}, after={reloaded_projected[key]:.6f}" diff --git a/tests/reduce.py b/tests/reduce.py new file mode 100644 index 000000000..e39164f3b --- /dev/null +++ b/tests/reduce.py @@ -0,0 +1,71 @@ +from test import TestCase, fill_redis_with_vectors, generate_random_vector + +class Reduce(TestCase): + def getname(self): + return "Dimension Reduction" + + def estimated_runtime(self): + return 0.2 + + def test(self): + original_dim = 100 + reduced_dim = 80 + count = 1000 + k = 50 # Number of nearest neighbors to check + + # Fill Redis with vectors using REDUCE and get reference data + data = fill_redis_with_vectors(self.redis, self.test_key, count, original_dim, reduced_dim) + + # Verify dimension is reduced + dim = self.redis.execute_command('VDIM', self.test_key) + assert dim == reduced_dim, f"Expected dimension {reduced_dim}, got {dim}" + + # Generate query vector and get nearest neighbors using Redis + query_vec = generate_random_vector(original_dim) + redis_raw = self.redis.execute_command('VSIM', self.test_key, 'VALUES', + original_dim, *[str(x) for x in query_vec], + 'COUNT', k, 'WITHSCORES') + + # Convert Redis results to dict + redis_results = {} + for i in range(0, len(redis_raw), 2): + key = redis_raw[i].decode() + score = float(redis_raw[i+1]) + redis_results[key] = score + + # Get results from linear scan with original vectors + linear_results = data.find_k_nearest(query_vec, k) + linear_items = {name: score for name, score in linear_results} + + # Compare overlap between reduced and non-reduced results + redis_set = set(redis_results.keys()) + linear_set = set(linear_items.keys()) + overlap = len(redis_set & linear_set) + overlap_ratio = overlap / k + + # With random projection, we expect some loss of accuracy but should + # maintain at least some similarity structure. + # Note that gaussian distribution is the worse with this test, so + # in real world practice, things will be better. + min_expected_overlap = 0.1 # At least 10% overlap in top-k + assert overlap_ratio >= min_expected_overlap, \ + f"Dimension reduction lost too much structure. Only {overlap_ratio*100:.1f}% overlap in top {k}" + + # For items that appear in both results, scores should be reasonably correlated + common_items = redis_set & linear_set + for item in common_items: + redis_score = redis_results[item] + linear_score = linear_items[item] + # Allow for some deviation due to dimensionality reduction + assert abs(redis_score - linear_score) < 0.2, \ + f"Score mismatch too high for {item}: Redis={redis_score:.3f} Linear={linear_score:.3f}" + + # If test fails, print comparison for debugging + if overlap_ratio < min_expected_overlap: + print("\nLow overlap in results. Details:") + print("\nTop results from linear scan (original vectors):") + for name, score in linear_results: + print(f"{name}: {score:.3f}") + print("\nTop results from Redis (reduced vectors):") + for item, score in sorted(redis_results.items(), key=lambda x: x[1], reverse=True): + print(f"{item}: {score:.3f}") diff --git a/tests/vadd_cas.py b/tests/vadd_cas.py new file mode 100644 index 000000000..3cb3508e5 --- /dev/null +++ b/tests/vadd_cas.py @@ -0,0 +1,98 @@ +from test import TestCase, generate_random_vector +import threading +import struct +import math +import time +import random +from typing import List, Dict + +class ConcurrentCASTest(TestCase): + def getname(self): + return "Concurrent VADD with CAS" + + def estimated_runtime(self): + return 1.5 + + def worker(self, vectors: List[List[float]], start_idx: int, end_idx: int, + dim: int, results: Dict[str, bool]): + """Worker thread that adds a subset of vectors using VADD CAS""" + for i in range(start_idx, end_idx): + vec = vectors[i] + name = f"{self.test_key}:item:{i}" + vec_bytes = struct.pack(f'{dim}f', *vec) + + # Try to add the vector with CAS + try: + result = self.redis.execute_command('VADD', self.test_key, 'FP32', + vec_bytes, name, 'CAS') + results[name] = (result == 1) # Store if it was actually added + except Exception as e: + results[name] = False + print(f"Error adding {name}: {e}") + + def verify_vector_similarity(self, vec1: List[float], vec2: List[float]) -> float: + """Calculate cosine similarity between two vectors""" + dot_product = sum(a*b for a,b in zip(vec1, vec2)) + norm1 = math.sqrt(sum(x*x for x in vec1)) + norm2 = math.sqrt(sum(x*x for x in vec2)) + return dot_product / (norm1 * norm2) if norm1 > 0 and norm2 > 0 else 0 + + def test(self): + # Test parameters + dim = 128 + total_vectors = 5000 + num_threads = 8 + vectors_per_thread = total_vectors // num_threads + + # Generate all vectors upfront + random.seed(42) # For reproducibility + vectors = [generate_random_vector(dim) for _ in range(total_vectors)] + + # Prepare threads and results dictionary + threads = [] + results = {} # Will store success/failure for each vector + + # Launch threads + for i in range(num_threads): + start_idx = i * vectors_per_thread + end_idx = start_idx + vectors_per_thread if i < num_threads-1 else total_vectors + thread = threading.Thread(target=self.worker, + args=(vectors, start_idx, end_idx, dim, results)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify cardinality + card = self.redis.execute_command('VCARD', self.test_key) + assert card == total_vectors, \ + f"Expected {total_vectors} elements, but found {card}" + + # Verify each vector + num_verified = 0 + for i in range(total_vectors): + name = f"{self.test_key}:item:{i}" + + # Verify the item was successfully added + assert results[name], f"Vector {name} was not successfully added" + + # Get the stored vector + stored_vec_raw = self.redis.execute_command('VEMB', self.test_key, name) + stored_vec = [float(x) for x in stored_vec_raw] + + # Verify vector dimensions + assert len(stored_vec) == dim, \ + f"Stored vector dimension mismatch for {name}: {len(stored_vec)} != {dim}" + + # Calculate similarity with original vector + similarity = self.verify_vector_similarity(vectors[i], stored_vec) + assert similarity > 0.99, \ + f"Low similarity ({similarity}) for {name}" + + num_verified += 1 + + # Final verification + assert num_verified == total_vectors, \ + f"Only verified {num_verified} out of {total_vectors} vectors" diff --git a/tests/vemb.py b/tests/vemb.py new file mode 100644 index 000000000..0f4cf77a7 --- /dev/null +++ b/tests/vemb.py @@ -0,0 +1,41 @@ +from test import TestCase +import struct +import math + +class VEMB(TestCase): + def getname(self): + return "VEMB Command" + + def test(self): + dim = 4 + + # Add same vector in both formats + vec = [1, 0, 0, 0] + norm = math.sqrt(sum(x*x for x in vec)) + vec = [x/norm for x in vec] # Normalize the vector + + # Add using FP32 + vec_bytes = struct.pack(f'{dim}f', *vec) + self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1') + + # Add using VALUES + self.redis.execute_command('VADD', self.test_key, 'VALUES', dim, + *[str(x) for x in vec], f'{self.test_key}:item:2') + + # Get both back with VEMB + result1 = self.redis.execute_command('VEMB', self.test_key, f'{self.test_key}:item:1') + result2 = self.redis.execute_command('VEMB', self.test_key, f'{self.test_key}:item:2') + + retrieved_vec1 = [float(x) for x in result1] + retrieved_vec2 = [float(x) for x in result2] + + # Compare both vectors with original (allow for small quantization errors) + for i in range(dim): + assert abs(vec[i] - retrieved_vec1[i]) < 0.01, \ + f"FP32 vector component {i} mismatch: expected {vec[i]}, got {retrieved_vec1[i]}" + assert abs(vec[i] - retrieved_vec2[i]) < 0.01, \ + f"VALUES vector component {i} mismatch: expected {vec[i]}, got {retrieved_vec2[i]}" + + # Test non-existent item + result = self.redis.execute_command('VEMB', self.test_key, 'nonexistent') + assert result is None, "Non-existent item should return nil" diff --git a/vset.c b/vset.c new file mode 100644 index 000000000..03646af1a --- /dev/null +++ b/vset.c @@ -0,0 +1,1208 @@ +/* Redis implementation for vector sets. The data structure itself + * is implemented in hnsw.c. + * + * Copyright(C) 2024 Salvatore Sanfilippo. + * All Rights Reserved. + */ + +#define _DEFAULT_SOURCE +#define _USE_MATH_DEFINES +#define _POSIX_C_SOURCE 200809L + +#include "redismodule.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "hnsw.h" + +static RedisModuleType *VectorSetType; +static uint64_t VectorSetTypeNextId = 0; + +#define VSET_DEFAULT_C_EF 200 // Default EF value if not specified. +#define VSET_DEFAULT_COUNT 10 // Default num elements returned by VSIM. + +/* ========================== Internal data structure ====================== */ + +/* Our abstract data type needs a dual representation similar to Redis + * sorted set: the proximity graph, and also a element -> graph-node map + * that will allow us to perform deletions and other operations that have + * as input the element itself. */ +struct vsetObject { + HNSW *hnsw; // Proximity graph. + RedisModuleDict *dict; // Element -> node mapping. + float *proj_matrix; // Random projection matrix, NULL if no projection + uint32_t proj_input_size; // Input dimension after projection. + // Output dimension is implicit in + // hnsw->vector_dim. + pthread_rwlock_t in_use_lock; // Lock needed to destroy the object safely. + uint64_t id; // Unique ID used by threaded VADD to know the + // object is still the same. +}; + +/* Create a random projection matrix for dimensionality reduction. + * Returns NULL on allocation failure. Matrix is scaled by 1/sqrt(input_dim). */ +float *createProjectionMatrix(uint32_t input_dim, uint32_t output_dim) { + float *matrix = RedisModule_Alloc(sizeof(float) * input_dim * output_dim); + if (!matrix) return NULL; + + const float scale = 1.0f / sqrt(input_dim); + for (uint32_t i = 0; i < input_dim * output_dim; i++) { + /* Box-Muller transform for normal distribution */ + float u1 = (float)rand() / RAND_MAX; + float u2 = (float)rand() / RAND_MAX; + float r = sqrt(-2.0f * log(u1)); + float theta = 2.0f * M_PI * u2; + matrix[i] = r * cos(theta) * scale; + } + return matrix; +} + +/* Apply random projection to input vector. Returns new allocated vector or NULL. */ +float *applyProjection(const float *input, const float *proj_matrix, + uint32_t input_dim, uint32_t output_dim) +{ + float *output = RedisModule_Alloc(sizeof(float) * output_dim); + if (!output) return NULL; + + for (uint32_t i = 0; i < output_dim; i++) { + const float *row = &proj_matrix[i * input_dim]; + float sum = 0.0f; + for (uint32_t j = 0; j < input_dim; j++) { + sum += row[j] * input[j]; + } + output[i] = sum; + } + return output; +} + +/* Create the vector as HNSW+Dictionary combined data structure. */ +struct vsetObject *createVectorSetObject(unsigned int dim, uint32_t quant_type) { + struct vsetObject *o; + o = RedisModule_Alloc(sizeof(*o)); + if (!o) return NULL; + + o->id = VectorSetTypeNextId++; + o->hnsw = hnsw_new(dim,quant_type); + if (!o->hnsw) { + RedisModule_Free(o); + return NULL; + } + + o->dict = RedisModule_CreateDict(NULL); + if (!o->dict) { + hnsw_free(o->hnsw,NULL); + RedisModule_Free(o); + return NULL; + } + + o->proj_matrix = NULL; + o->proj_input_size = 0; + pthread_rwlock_init(&o->in_use_lock,NULL); + + return o; +} + +void vectorSetReleaseNodeValue(void *v) { + RedisModule_FreeString(NULL,v); +} + +/* Free the vector set object. */ +void vectorSetReleaseObject(struct vsetObject *o) { + if (!o) return; + if (o->hnsw) hnsw_free(o->hnsw,vectorSetReleaseNodeValue); + if (o->dict) RedisModule_FreeDict(NULL,o->dict); + if (o->proj_matrix) RedisModule_Free(o->proj_matrix); + pthread_rwlock_destroy(&o->in_use_lock); + RedisModule_Free(o); +} + +/* Insert the specified element into the Vector Set. + * If update is '1', the existing node will be updated. + * + * Returns 1 if the element was added, or 0 if the element was already there + * and was just updated. */ +int vectorSetInsert(struct vsetObject *o, float *vec, int8_t *qvec, float qrange, RedisModuleString *val, int update, int ef) +{ + hnswNode *node = RedisModule_DictGet(o->dict,val,NULL); + if (node != NULL) { + if (update) { + void *old_val = node->value; + /* Pass NULL as value-free function. We want to reuse + * the old value. */ + hnsw_delete_node(o->hnsw, node, NULL); + node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,old_val,ef); + RedisModule_DictReplace(o->dict,val,node); + } + return 0; + } + + node = hnsw_insert(o->hnsw,vec,qvec,qrange,0,val,ef); + if (!node) return 0; + RedisModule_DictSet(o->dict,val,node); + return 1; +} + +/* Parse vector from FP32 blob or VALUES format, with optional REDUCE. + * Format: [REDUCE dim] FP32|VALUES ... + * Returns allocated vector and sets dimension in *dim. + * If reduce_dim is not NULL, sets it to the requested reduction dimension. + * Returns NULL on parsing error. */ +float *parseVector(RedisModuleString **argv, int argc, int start_idx, + size_t *dim, uint32_t *reduce_dim, int *consumed_args) +{ + int consumed = 0; // Argumnets consumed + + /* Check for REDUCE option first */ + if (reduce_dim) *reduce_dim = 0; + if (reduce_dim && argc > start_idx + 2 && + !strcasecmp(RedisModule_StringPtrLen(argv[start_idx],NULL),"REDUCE")) + { + long long rdim; + if (RedisModule_StringToLongLong(argv[start_idx+1],&rdim) != REDISMODULE_OK + || rdim <= 0) return NULL; + if (reduce_dim) *reduce_dim = rdim; + start_idx += 2; // Skip REDUCE and its argument + consumed += 2; + } + + /* Now parse the vector format as before */ + float *vec = NULL; + + if (!strcasecmp(RedisModule_StringPtrLen(argv[start_idx],NULL),"FP32")) { + if (argc < start_idx + 2) return NULL; // Need FP32 + vector + value + size_t vec_raw_len; + const char *blob = RedisModule_StringPtrLen(argv[start_idx+1],&vec_raw_len); + if (vec_raw_len % 4 || vec_raw_len < 4) return NULL; + *dim = vec_raw_len/4; + vec = RedisModule_Alloc(vec_raw_len); + if (!vec) return NULL; + memcpy(vec,blob,vec_raw_len); + consumed += 2; + } else if (!strcasecmp(RedisModule_StringPtrLen(argv[start_idx],NULL),"VALUES")) { + if (argc < start_idx + 2) return NULL; // Need at least dimension. + long long vdim; + if (RedisModule_StringToLongLong(argv[start_idx+1],&vdim) != REDISMODULE_OK + || vdim < 1) return NULL; + + // Check that all the arguments are available. + if (argc < start_idx + 2 + vdim) return NULL; + + *dim = vdim; + vec = RedisModule_Alloc(sizeof(float) * vdim); + if (!vec) return NULL; + + for (int j = 0; j < vdim; j++) { + double val; + if (RedisModule_StringToDouble(argv[start_idx+2+j],&val) != REDISMODULE_OK) { + RedisModule_Free(vec); + return NULL; + } + vec[j] = val; + } + consumed += vdim + 2; + } else { + return NULL; // Unknown format + } + + if (consumed_args) *consumed_args = consumed; + return vec; +} + +/* ========================== Commands implementation ======================= */ + +/* VADD thread handling the "CAS" version of the command, that is + * performed blocking the client, accumulating here, in the thread, the + * set of potential candidates, and later inserting the element in the + * key (if it still exists, and if it is still the *same* vector set) + * in the Reply callback. */ +void *VADD_thread(void *arg) { + pthread_detach(pthread_self()); + + void **targ = (void**)arg; + RedisModuleBlockedClient *bc = targ[0]; + struct vsetObject *vset = targ[1]; + float *vec = targ[3]; + RedisModuleString *val = targ[4]; + int ef = (uint64_t)targ[6]; + + /* Look for candidates... */ + InsertContext *ic = hnsw_prepare_insert(vset->hnsw, vec, NULL, 0, 0, val, ef); + targ[5] = ic; // Pass the context to the reply callback. + + /* Unblock the client so that our read reply will be invoked. */ + pthread_rwlock_unlock(&vset->in_use_lock); + RedisModule_UnblockClient(bc,targ); // Use targ as privdata. + return NULL; +} + +/* Reply callback for CAS variant of VADD. */ +int VADD_CASReply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + (void)argc; + RedisModule_AutoMemory(ctx); /* Use automatic memory management. */ + + int retval = REDISMODULE_OK; + void **targ = (void**)RedisModule_GetBlockedClientPrivateData(ctx); + uint64_t vset_id = (unsigned long) targ[2]; + float *vec = targ[3]; + RedisModuleString *val = targ[4]; + InsertContext *ic = targ[5]; + int ef = (uint64_t)targ[6]; + RedisModule_Free(targ); + + /* Open the key: there are no guarantees it still exists, or contains + * a vector set, or even the SAME vector set. */ + RedisModuleKey *key = RedisModule_OpenKey(ctx,argv[1], + REDISMODULE_READ|REDISMODULE_WRITE); + int type = RedisModule_KeyType(key); + struct vsetObject *vset = NULL; + + if (type != REDISMODULE_KEYTYPE_EMPTY && + RedisModule_ModuleTypeGetType(key) == VectorSetType) + { + vset = RedisModule_ModuleTypeGetValue(key); + // Same vector set? + if (vset->id != vset_id) vset = NULL; + + /* Also, if the element was already inserted, we just pretend + * the other insert won. We don't even start a threaded VADD + * if this was an udpate, since the deletion of the element itself + * in order to perform the update would invalidate the CAS state. */ + if (RedisModule_DictGet(vset->dict,val,NULL) != NULL) vset = NULL; + } + + if (vset == NULL) { + /* If the object does not match the start of the operation, we + * just pretend the VADD was performed BEFORE the key was deleted + * or replaced. We return success but don't do anything. */ + hnsw_free_insert_context(ic); + } else { + /* Otherwise try to insert the new element with the neighbors + * collected in background. If we fail, do it synchronously again + * from scratch. */ + hnswNode *newnode; + if ((newnode = hnsw_try_commit_insert(vset->hnsw, ic)) == NULL) { + newnode = hnsw_insert(vset->hnsw, vec, NULL, 0, 0, val, ef); + } + RedisModule_DictSet(vset->dict,val,newnode); + val = NULL; // Don't free it later. + } + + // Whatever happens is a success... :D + RedisModule_ReplyWithLongLong(ctx,1); + + if (val) RedisModule_FreeString(ctx,val); // Not added? Free it. + RedisModule_Free(vec); + return retval; +} + +/* VADD key [REDUCE dim] FP32|VALUES vector value [CAS] [NOQUANT] */ +int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); /* Use automatic memory management. */ + + if (argc < 5) return RedisModule_WrongArity(ctx); + + /* Parse vector with optional REDUCE */ + size_t dim = 0; + uint32_t reduce_dim = 0; + int consumed_args; + int cas = 0; // Threaded check-and-set style insert. + long long ef = VSET_DEFAULT_C_EF; // HNSW creation time EF for new nodes. + float *vec = parseVector(argv, argc, 2, &dim, &reduce_dim, &consumed_args); + if (!vec) + return RedisModule_ReplyWithError(ctx,"ERR invalid vector specification"); + + /* Missing element string at the end? */ + if (argc-2-consumed_args < 1) return RedisModule_WrongArity(ctx); + + /* Parse options after the element string. */ + uint32_t quant_type = HNSW_QUANT_Q8; // Default quantization type. + + for (int j = 2 + consumed_args + 1; j < argc; j++) { + const char *opt = RedisModule_StringPtrLen(argv[j], NULL); + if (!strcasecmp(opt, "CAS")) { + cas = 1; + } else if (!strcasecmp(opt, "EF") && j+1 < argc) { + if (RedisModule_StringToLongLong(argv[j+1], &ef) + != REDISMODULE_OK || ef <= 0 || ef > 1000000) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid EF"); + } + j++; // skip EF argument. + } else if (!strcasecmp(opt, "NOQUANT")) { + quant_type = HNSW_QUANT_NONE; + } else if (!strcasecmp(opt, "BIN")) { + quant_type = HNSW_QUANT_BIN; + } else { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx,"ERR invalid option after element"); + } + } + + /* Open/create key */ + RedisModuleKey *key = RedisModule_OpenKey(ctx,argv[1], + REDISMODULE_READ|REDISMODULE_WRITE); + int type = RedisModule_KeyType(key); + if (type != REDISMODULE_KEYTYPE_EMPTY && + RedisModule_ModuleTypeGetType(key) != VectorSetType) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx,REDISMODULE_ERRORMSG_WRONGTYPE); + } + + /* Get the correct value argument based on format and REDUCE */ + RedisModuleString *val = argv[2 + consumed_args]; + + /* Create or get existing vector set */ + struct vsetObject *vset; + if (type == REDISMODULE_KEYTYPE_EMPTY) { + cas = 0; /* Do synchronous insert at creation, otherwise the + * key would be left empty until the threaded part + * does not return. It's also pointless to try try + * doing threaded first elemetn insertion. */ + vset = createVectorSetObject(reduce_dim ? reduce_dim : dim, quant_type); + + /* Initialize projection if requested */ + if (reduce_dim) { + vset->proj_matrix = createProjectionMatrix(dim, reduce_dim); + vset->proj_input_size = dim; + + /* Project the vector */ + float *projected = applyProjection(vec, vset->proj_matrix, + dim, reduce_dim); + RedisModule_Free(vec); + vec = projected; + } + RedisModule_ModuleTypeSetValue(key,VectorSetType,vset); + } else { + vset = RedisModule_ModuleTypeGetValue(key); + + if (vset->hnsw->quant_type != quant_type) { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR use the same quantization of the existing vector set"); + } + + if ((vset->proj_matrix == NULL && vset->hnsw->vector_dim != dim) || + (vset->proj_matrix && vset->hnsw->vector_dim != reduce_dim)) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithErrorFormat(ctx, + "ERR Vector dimension mismatch - got %d but set has %d", + (int)dim, (int)vset->hnsw->vector_dim); + } + + /* Check REDUCE compatibility */ + if (reduce_dim) { + if (!vset->proj_matrix) { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR cannot add projection to existing set without projection"); + } + if (reduce_dim != vset->hnsw->vector_dim) { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR projection dimension mismatch with existing set"); + } + } + + /* Apply projection if needed */ + if (vset->proj_matrix) { + float *projected = applyProjection(vec, vset->proj_matrix, + vset->proj_input_size, dim); + RedisModule_Free(vec); + vec = projected; + } + } + + /* Don't do CAS updates. For how things work now, the CAS state would + * be invalidated by the detetion before adding back. */ + if (cas && RedisModule_DictGet(vset->dict,val,NULL) != NULL) + cas = 0; + + /* Here depending on the CAS option we directly insert in a blocking + * way, or use a therad to do candidate neighbors selection and only + * later, in the reply callback, actually add the element. */ + + if (!cas) { + /* Insert vector synchronously. */ + int added = vectorSetInsert(vset,vec,NULL,0,val,1,ef); + if (added) RedisModule_RetainString(ctx,val); + RedisModule_Free(vec); + + RedisModule_ReplyWithLongLong(ctx,added); + if (added) RedisModule_ReplicateVerbatim(ctx); + return REDISMODULE_OK; + } else { + /* Make sure the key does not get deleted during the background + * operation. See VSIM implementation for more information. */ + pthread_rwlock_rdlock(&vset->in_use_lock); + + RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,VADD_CASReply,NULL,NULL,0); + pthread_t tid; + void **targ = RedisModule_Alloc(sizeof(void*)*7); + targ[0] = bc; + targ[1] = vset; + targ[2] = (void*)(unsigned long)vset->id; + targ[3] = vec; + targ[4] = val; + targ[5] = NULL; // Used later for insertion context. + targ[6] = (void*)(unsigned long)ef; + RedisModule_RetainString(ctx,val); + if (pthread_create(&tid,NULL,VADD_thread,targ) != 0) { + pthread_rwlock_unlock(&vset->in_use_lock); + RedisModule_AbortBlock(bc); + RedisModule_FreeString(ctx, val); + RedisModule_Free(vec); + RedisModule_Free(targ); + return RedisModule_ReplyWithError(ctx,"-ERR Can't start thread"); + } + return REDISMODULE_OK; + } +} + +/* VSIM thread handling the blocked client request. */ +void *VSIM_thread(void *arg) { + pthread_detach(pthread_self()); + + void **targ = (void**)arg; + RedisModuleBlockedClient *bc = targ[0]; + struct vsetObject *vset = targ[1]; + float *vec = targ[2]; + unsigned long count = (unsigned long)targ[3]; + float epsilon = *((float*)targ[4]); + unsigned long withscores = (unsigned long)targ[5]; + unsigned long ef = (unsigned long)targ[6]; + + RedisModule_Free(targ[4]); + RedisModule_Free(targ); + + /* In our scan, we can't just collect 'count' elements as + * if count is small we would explore the graph in an insufficient + * way to provide enough recall. + * + * If the user didn't asked for a specific exploration, we use + * 50 as minimum, or we match count if count is greater than + * that. Otherwise the minumim will be the specified EF argument. */ + + if (ef == 0) ef = 100; // This is a decent default to go fast but avoid + // obvious local minima along the path. + if (count > ef) ef = count; + + /* Perform search */ + hnswNode **neighbors = RedisModule_Alloc(sizeof(hnswNode*)*ef); + float *distances = RedisModule_Alloc(sizeof(float)*ef); + int slot = hnsw_acquire_read_slot(vset->hnsw); + unsigned int found = hnsw_search(vset->hnsw, vec, ef, neighbors, distances, slot, 0); + hnsw_release_read_slot(vset->hnsw,slot); + RedisModule_Free(vec); + + /* Return results */ + RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(bc); + if (withscores) + RedisModule_ReplyWithMap(ctx, REDISMODULE_POSTPONED_LEN); + else + RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_LEN); + long long arraylen = 0; + + for (unsigned int i = 0; i < found && i < count; i++) { + if (distances[i] > epsilon) break; + RedisModule_ReplyWithString(ctx, neighbors[i]->value); + arraylen++; + if (withscores) { + /* The similarity score is provided in a 0-1 range. */ + RedisModule_ReplyWithDouble(ctx, 1.0 - distances[i]/2.0); + } + } + + if (withscores) + RedisModule_ReplySetMapLength(ctx, arraylen); + else + RedisModule_ReplySetArrayLength(ctx, arraylen); + + RedisModule_FreeThreadSafeContext(ctx); + pthread_rwlock_unlock(&vset->in_use_lock); + RedisModule_UnblockClient(bc,NULL); + RedisModule_Free(neighbors); + RedisModule_Free(distances); + return NULL; +} + +/* VSIM key [ELE|FP32|VALUES] [WITHSCORES] [COUNT num] [EPSILON eps] [EF exploration-factor] */ +int VSIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + /* Basic argument check: need at least key and vector specification + * method. */ + if (argc < 4) return RedisModule_WrongArity(ctx); + + /* Defaults */ + int withscores = 0; + long long count = VSET_DEFAULT_COUNT; /* New default value */ + long long ef = 0; /* Exploration factor (see HNSW paper) */ + double epsilon = 2.0; /* Max cosine distance */ + + /* Get key and vector type */ + RedisModuleString *key = argv[1]; + const char *vectorType = RedisModule_StringPtrLen(argv[2], NULL); + + /* Get vector set */ + RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ); + int type = RedisModule_KeyType(keyptr); + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithEmptyArray(ctx); + + if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); + + /* Vector parsing stage */ + float *vec = NULL; + size_t dim = 0; + int vector_args = 0; /* Number of args consumed by vector specification */ + + if (!strcasecmp(vectorType, "ELE")) { + /* Get vector from existing element */ + RedisModuleString *ele = argv[3]; + hnswNode *node = RedisModule_DictGet(vset->dict, ele, NULL); + if (!node) { + return RedisModule_ReplyWithError(ctx, "ERR element not found in set"); + } + vec = RedisModule_Alloc(sizeof(float) * vset->hnsw->vector_dim); + hnsw_get_node_vector(vset->hnsw,node,vec); + dim = vset->hnsw->vector_dim; + vector_args = 2; /* ELE + element name */ + } else { + /* Parse vector. */ + int consumed_args; + + vec = parseVector(argv, argc, 2, &dim, NULL, &consumed_args); + if (!vec) { + return RedisModule_ReplyWithError(ctx, + "ERR invalid vector specification"); + } + vector_args = consumed_args; + + /* Apply projection if the set uses it, with the exception + * of ELE type, that will already have the right dimension. + * XXX: check explicitly that ELE was passed, not just size. */ + if (vset->proj_matrix && dim != vset->hnsw->vector_dim) { + float *projected = applyProjection(vec, vset->proj_matrix, + vset->proj_input_size, dim); + RedisModule_Free(vec); + vec = projected; + dim = vset->hnsw->vector_dim; + } + + /* Count consumed arguments */ + if (!strcasecmp(vectorType, "FP32")) { + vector_args = 2; /* FP32 + vector blob */ + } else if (!strcasecmp(vectorType, "VALUES")) { + long long vdim; + if (RedisModule_StringToLongLong(argv[3], &vdim) != REDISMODULE_OK) { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid vector dimension"); + } + vector_args = 2 + vdim; /* VALUES + dim + values */ + } else { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR vector type must be ELE, FP32 or VALUES"); + } + } + + /* Check vector dimension matches set */ + if (dim != vset->hnsw->vector_dim) { + RedisModule_Free(vec); + return RedisModule_ReplyWithErrorFormat(ctx, + "ERR Vector dimension mismatch - got %d but set has %d", + (int)dim, (int)vset->hnsw->vector_dim); + } + + /* Parse optional arguments - start after vector specification */ + int j = 2 + vector_args; + while (j < argc) { + const char *opt = RedisModule_StringPtrLen(argv[j], NULL); + if (!strcasecmp(opt, "WITHSCORES")) { + withscores = 1; + j++; + } else if (!strcasecmp(opt, "COUNT") && j+1 < argc) { + if (RedisModule_StringToLongLong(argv[j+1], &count) + != REDISMODULE_OK || count <= 0) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid COUNT"); + } + j += 2; + } else if (!strcasecmp(opt, "EPSILON") && j+1 < argc) { + if (RedisModule_StringToDouble(argv[j+1], &epsilon) != + REDISMODULE_OK || epsilon <= 0) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid EPSILON"); + } + j += 2; + } else if (!strcasecmp(opt, "EF") && j+1 < argc) { + if (RedisModule_StringToLongLong(argv[j+1], &ef) != + REDISMODULE_OK || ef <= 0) + { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, "ERR invalid EF"); + } + j += 2; + } else { + RedisModule_Free(vec); + return RedisModule_ReplyWithError(ctx, + "ERR syntax error in VSIM command"); + } + } + + /* Spawn the thread serving the request: + * Acquire the lock here so that the object will not be + * destroyed while we work with it in the thread. + * + * This lock should never block, since: + * 1. If we are in the main thread, the key exists (we looked it up) + * and so there is no deletion in progress. + * 2. If the write lock is taken while destroying the object, another + * command or operation (expire?) from the main thread acquired + * it to delete the object, so *it* will block if there are still + * operations in progress on this key. */ + pthread_rwlock_rdlock(&vset->in_use_lock); + + RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx,NULL,NULL,NULL,0); + pthread_t tid; + void **targ = RedisModule_Alloc(sizeof(void*)*7); + targ[0] = bc; + targ[1] = vset; + targ[2] = vec; + targ[3] = (void*)count; + targ[4] = RedisModule_Alloc(sizeof(float)); + *((float*)targ[4]) = epsilon; + targ[5] = (void*)(unsigned long)withscores; + targ[6] = (void*)(unsigned long)ef; + if (pthread_create(&tid,NULL,VSIM_thread,targ) != 0) { + pthread_rwlock_unlock(&vset->in_use_lock); + RedisModule_AbortBlock(bc); + RedisModule_Free(vec); + RedisModule_Free(targ[4]); + RedisModule_Free(targ); + return RedisModule_ReplyWithError(ctx,"-ERR Can't start thread"); + } + + return REDISMODULE_OK; +} + +/* VDIM : return the dimension of vectors in the vector set. */ +int VDIM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 2) return RedisModule_WrongArity(ctx); + + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithError(ctx, "ERR key does not exist"); + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + return RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim); +} + +/* VCARD : return cardinality (num of elements) of the vector set. */ +int VCARD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 2) return RedisModule_WrongArity(ctx); + + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithLongLong(ctx, 0); + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + return RedisModule_ReplyWithLongLong(ctx, vset->hnsw->node_count); +} + +/* VREM key element + * Remove an element from a vector set. + * Returns 1 if the element was found and removed, 0 if not found. */ +int VREM_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); /* Use automatic memory management. */ + + if (argc != 3) return RedisModule_WrongArity(ctx); + + /* Get key and value */ + RedisModuleString *key = argv[1]; + RedisModuleString *element = argv[2]; + + /* Open key */ + RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, + REDISMODULE_READ|REDISMODULE_WRITE); + int type = RedisModule_KeyType(keyptr); + + /* Handle non-existing key or wrong type */ + if (type == REDISMODULE_KEYTYPE_EMPTY) { + return RedisModule_ReplyWithLongLong(ctx, 0); + } + if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) { + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + } + + /* Get vector set from key */ + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); + + /* Find the node for this element */ + hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL); + if (!node) { + return RedisModule_ReplyWithLongLong(ctx, 0); + } + + /* Remove from dictionary */ + RedisModule_DictDel(vset->dict, element, NULL); + + /* Remove from HNSW graph using the high-level API that handles + * locking and cleanup. We pass RedisModule_FreeString as the value + * free function since the strings were retained at insertion time. */ + hnsw_delete_node(vset->hnsw, node, vectorSetReleaseNodeValue); + + /* Destroy empty vector set. */ + if (RedisModule_DictSize(vset->dict) == 0) { + RedisModule_DeleteKey(keyptr); + } + + /* Reply and propagate the command */ + RedisModule_ReplyWithLongLong(ctx, 1); + RedisModule_ReplicateVerbatim(ctx); + return REDISMODULE_OK; +} + +/* VEMB key element + * Returns the embedding vector associated with an element, or NIL if not + * found. The vector is returned in the same format it was added, but the + * return value will have some lack of precision due to quantization and + * normalization of vectors. Also, if items were added using REDUCE, the + * reduced vector is returned instead. */ +int VEMB_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 3) return RedisModule_WrongArity(ctx); + + /* Get key and element. */ + RedisModuleString *key = argv[1]; + RedisModuleString *element = argv[2]; + + /* Open key. */ + RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ); + int type = RedisModule_KeyType(keyptr); + + /* Handle non-existing key and key of wrong type. */ + if (type == REDISMODULE_KEYTYPE_EMPTY) { + return RedisModule_ReplyWithNull(ctx); + } else if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) { + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + } + + /* Lookup the node about the specified element. */ + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); + hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL); + if (!node) { + return RedisModule_ReplyWithNull(ctx); + } + + /* Get the vector associated with the node. */ + float *vec = RedisModule_Alloc(sizeof(float) * vset->hnsw->vector_dim); + hnsw_get_node_vector(vset->hnsw, node, vec); // May dequantize/denorm. + + /* Return as array of doubles. */ + RedisModule_ReplyWithArray(ctx, vset->hnsw->vector_dim); + for (uint32_t i = 0; i < vset->hnsw->vector_dim; i++) + RedisModule_ReplyWithDouble(ctx, vec[i]); + + RedisModule_Free(vec); + return REDISMODULE_OK; +} + +/* ============================== Reflection ================================ */ + +/* VLINKS key element [WITHSCORES] + * Returns the neighbors of an element at each layer in the HNSW graph. + * Reply is an array of arrays, where each nested array represents one level + * of neighbors, from highest level to level 0. If WITHSCORES is specified, + * each neighbor is followed by its distance from the element. */ +int VLINKS_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc < 3 || argc > 4) return RedisModule_WrongArity(ctx); + + RedisModuleString *key = argv[1]; + RedisModuleString *element = argv[2]; + + /* Parse WITHSCORES option. */ + int withscores = 0; + if (argc == 4) { + const char *opt = RedisModule_StringPtrLen(argv[3], NULL); + if (strcasecmp(opt, "WITHSCORES") != 0) { + return RedisModule_WrongArity(ctx); + } + withscores = 1; + } + + RedisModuleKey *keyptr = RedisModule_OpenKey(ctx, key, REDISMODULE_READ); + int type = RedisModule_KeyType(keyptr); + + /* Handle non-existing key or wrong type. */ + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithNull(ctx); + + if (RedisModule_ModuleTypeGetType(keyptr) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + /* Find the node for this element. */ + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(keyptr); + hnswNode *node = RedisModule_DictGet(vset->dict, element, NULL); + if (!node) + return RedisModule_ReplyWithNull(ctx); + + /* Reply with array of arrays, one per level. */ + RedisModule_ReplyWithArray(ctx, node->level + 1); + + /* For each level, from highest to lowest: */ + for (int i = node->level; i >= 0; i--) { + /* Reply with array of neighbors at this level. */ + if (withscores) + RedisModule_ReplyWithMap(ctx,node->layers[i].num_links); + else + RedisModule_ReplyWithArray(ctx,node->layers[i].num_links); + + /* Add each neighbor's element value to the array. */ + for (uint32_t j = 0; j < node->layers[i].num_links; j++) { + RedisModule_ReplyWithString(ctx, node->layers[i].links[j]->value); + if (withscores) { + float distance = hnsw_distance(vset->hnsw, node, node->layers[i].links[j]); + /* Convert distance to similarity score to match + * VSIM behavior.*/ + float similarity = 1.0 - distance/2.0; + RedisModule_ReplyWithDouble(ctx, similarity); + } + } + } + return REDISMODULE_OK; +} + +/* VINFO key + * Returns information about a vector set, both visible and hidden + * features of the HNSW data structure. */ +int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + RedisModule_AutoMemory(ctx); + + if (argc != 2) return RedisModule_WrongArity(ctx); + + RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_READ); + int type = RedisModule_KeyType(key); + + if (type == REDISMODULE_KEYTYPE_EMPTY) + return RedisModule_ReplyWithNullArray(ctx); + + if (RedisModule_ModuleTypeGetType(key) != VectorSetType) + return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); + + struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key); + + /* Reply with hash */ + RedisModule_ReplyWithMap(ctx, 6); + + /* Quantization type */ + RedisModule_ReplyWithSimpleString(ctx, "quant-type"); + if (vset->hnsw->quant_type == HNSW_QUANT_NONE) { + RedisModule_ReplyWithSimpleString(ctx, "f32"); + } else if (vset->hnsw->quant_type == HNSW_QUANT_Q8) { + RedisModule_ReplyWithSimpleString(ctx, "int8"); + } else { + RedisModule_ReplyWithSimpleString(ctx, "unknown"); + } + + /* Vector dimensionality. */ + RedisModule_ReplyWithSimpleString(ctx, "vector-dim"); + RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim); + + /* Number of elements. */ + RedisModule_ReplyWithSimpleString(ctx, "size"); + RedisModule_ReplyWithLongLong(ctx, vset->hnsw->node_count); + + /* Max level of HNSW. */ + RedisModule_ReplyWithSimpleString(ctx, "max-level"); + RedisModule_ReplyWithLongLong(ctx, vset->hnsw->max_level); + + /* Vector set ID. */ + RedisModule_ReplyWithSimpleString(ctx, "vset-uid"); + RedisModule_ReplyWithLongLong(ctx, vset->id); + + /* HNSW max node ID. */ + RedisModule_ReplyWithSimpleString(ctx, "hnsw-max-node-uid"); + RedisModule_ReplyWithLongLong(ctx, vset->hnsw->last_id); + + return REDISMODULE_OK; +} + +/* ============================== vset type methods ========================= */ + +/* Save object to RDB */ +void VectorSetRdbSave(RedisModuleIO *rdb, void *value) { + struct vsetObject *vset = value; + RedisModule_SaveUnsigned(rdb, vset->hnsw->vector_dim); + RedisModule_SaveUnsigned(rdb, vset->hnsw->node_count); + RedisModule_SaveUnsigned(rdb, vset->hnsw->quant_type); + + /* Save projection matrix if present */ + if (vset->proj_matrix) { + RedisModule_SaveUnsigned(rdb, 1); // has projection + uint32_t input_dim = vset->proj_input_size; + uint32_t output_dim = vset->hnsw->vector_dim; + RedisModule_SaveUnsigned(rdb, input_dim); + // Output dim is the same as the first value saved + // above, so we don't save it. + + // Save projection matrix as binary blob + size_t matrix_size = sizeof(float) * input_dim * output_dim; + RedisModule_SaveStringBuffer(rdb, (const char *)vset->proj_matrix, matrix_size); + } else { + RedisModule_SaveUnsigned(rdb, 0); // no projection + } + + hnswNode *node = vset->hnsw->head; + while(node) { + RedisModule_SaveString(rdb, node->value); + hnswSerNode *sn = hnsw_serialize_node(vset->hnsw,node); + RedisModule_SaveStringBuffer(rdb, (const char *)sn->vector, sn->vector_size); + RedisModule_SaveUnsigned(rdb, sn->params_count); + for (uint32_t j = 0; j < sn->params_count; j++) + RedisModule_SaveUnsigned(rdb, sn->params[j]); + hnsw_free_serialized_node(sn); + node = node->next; + } +} + +/* Load object from RDB */ +void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) { + if (encver != 0) return NULL; // Invalid version + + uint32_t dim = RedisModule_LoadUnsigned(rdb); + uint64_t elements = RedisModule_LoadUnsigned(rdb); + uint32_t quant_type = RedisModule_LoadUnsigned(rdb); + + struct vsetObject *vset = createVectorSetObject(dim,quant_type); + if (!vset) return NULL; + + /* Load projection matrix if present */ + uint32_t has_projection = RedisModule_LoadUnsigned(rdb); + if (has_projection) { + uint32_t input_dim = RedisModule_LoadUnsigned(rdb); + uint32_t output_dim = dim; + size_t matrix_size = sizeof(float) * input_dim * output_dim; + + vset->proj_matrix = RedisModule_Alloc(matrix_size); + if (!vset->proj_matrix) { + vectorSetReleaseObject(vset); + return NULL; + } + vset->proj_input_size = input_dim; + + // Load projection matrix as a binary blob + char *matrix_blob = RedisModule_LoadStringBuffer(rdb, NULL); + memcpy(vset->proj_matrix, matrix_blob, matrix_size); + RedisModule_Free(matrix_blob); + } + + while(elements--) { + // Load associated string element. + RedisModuleString *ele = RedisModule_LoadString(rdb); + size_t vector_len; + void *vector = RedisModule_LoadStringBuffer(rdb, &vector_len); + uint32_t vector_bytes = dim * (quant_type == HNSW_QUANT_Q8 ? 1 : 4); + if (vector_len != vector_bytes) { + RedisModule_LogIOError(rdb,"warning", + "Mismatching vector dimension"); + return NULL; // Loading error. + } + + // Load node parameters back. + uint32_t params_count = RedisModule_LoadUnsigned(rdb); + uint64_t *params = RedisModule_Alloc(params_count*sizeof(uint64_t)); + for (uint32_t j = 0; j < params_count; j++) + params[j] = RedisModule_LoadUnsigned(rdb); + + hnswNode *node = hnsw_insert_serialized(vset->hnsw, vector, params, params_count, ele); + if (node == NULL) { + RedisModule_LogIOError(rdb,"warning", + "Vector set node index loading error"); + return NULL; // Loading error. + } + RedisModule_DictSet(vset->dict,ele,node); + RedisModule_Free(vector); + RedisModule_Free(params); + } + hnsw_deserialize_index(vset->hnsw); + return vset; +} + +/* Calculate memory usage */ +size_t VectorSetMemUsage(const void *value) { + const struct vsetObject *vset = value; + size_t size = sizeof(*vset); + + /* Account for HNSW index base structure */ + size += sizeof(HNSW); + + /* Account for projection matrix if present */ + if (vset->proj_matrix) { + /* For the matrix size, we need the input dimension. We can get it + * from the first node if the set is not empty. */ + uint32_t input_dim = vset->proj_input_size; + uint32_t output_dim = vset->hnsw->vector_dim; + size += sizeof(float) * input_dim * output_dim; + } + + /* Account for each node's memory usage. */ + hnswNode *node = vset->hnsw->head; + if (node == NULL) return size; + + /* Base node structure. */ + size += sizeof(*node) * vset->hnsw->node_count; + + /* Vector storage. */ + uint64_t vec_storage = vset->hnsw->vector_dim; + if (vset->hnsw->quant_type == HNSW_QUANT_NONE) vec_storage *= 4; + size += vec_storage * vset->hnsw->node_count; + + /* Layers array. We use 1.33 as average nodes layers count. */ + uint64_t layers_storage = sizeof(hnswNodeLayer) * vset->hnsw->node_count; + layers_storage = layers_storage * 4 / 3; // 1.33 times. + size += layers_storage; + + /* All the nodes have layer 0 links. */ + uint64_t level0_links = node->layers[0].max_links; + uint64_t other_levels_links = level0_links/2; + size += sizeof(hnswNode*) * level0_links * vset->hnsw->node_count; + + /* Add the 0.33 remaining part, but upper layers have less links. */ + size += (sizeof(hnswNode*) * other_levels_links * vset->hnsw->node_count)/3; + + /* Associated string value - use Redis Module API to get string size, and + * guess that all the elements have similar size. */ + size += RedisModule_MallocSizeString(node->value) * vset->hnsw->node_count; + + /* Account for dictionary overhead - this is an approximation. */ + size += RedisModule_DictSize(vset->dict) * (sizeof(void*) * 2); + + return size; +} + +/* Free the entire data structure */ +void VectorSetFree(void *value) { + struct vsetObject *vset = value; + + // Wait for all the threads performing operations on this + // index to terminate their work (locking for write will + // wait for all the other threads). + pthread_rwlock_wrlock(&vset->in_use_lock); + + // This lock is managed only in the main thread, so we can + // unlock it now, to be able to destroy the mutex later + // in vectorSetReleaseObject(). + pthread_rwlock_unlock(&vset->in_use_lock); + vectorSetReleaseObject(value); +} + +/* Add object digest to the digest context */ +void VectorSetDigest(RedisModuleDigest *md, void *value) { + struct vsetObject *vset = value; + + /* Add consistent order-independent hash of all vectors */ + hnswNode *node = vset->hnsw->head; + while(node) { + /* Hash the vector dimension */ + RedisModule_DigestAddLongLong(md, vset->hnsw->vector_dim); + /* Hash each vector component */ + RedisModule_DigestAddStringBuffer(md, node->vector, hnsw_quants_bytes(vset->hnsw)); + /* Hash the associated value */ + size_t len; + const char *str = RedisModule_StringPtrLen(node->value, &len); + RedisModule_DigestAddStringBuffer(md, (char*)str, len); + node = node->next; + } + RedisModule_DigestEndSequence(md); +} + +/* This function must be present on each Redis module. It is used in order to + * register the commands into the Redis server. */ +int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + REDISMODULE_NOT_USED(argv); + REDISMODULE_NOT_USED(argc); + + if (RedisModule_Init(ctx,"vectorset",1,REDISMODULE_APIVER_1) + == REDISMODULE_ERR) return REDISMODULE_ERR; + + RedisModuleTypeMethods tm = { + .version = REDISMODULE_TYPE_METHOD_VERSION, + .rdb_load = VectorSetRdbLoad, + .rdb_save = VectorSetRdbSave, + .aof_rewrite = NULL, + .mem_usage = VectorSetMemUsage, + .free = VectorSetFree, + .digest = VectorSetDigest + }; + + VectorSetType = RedisModule_CreateDataType(ctx,"vectorset",0,&tm); + if (VectorSetType == NULL) return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx,"VADD", + VADD_RedisCommand,"write deny-oom",1,1,1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx,"VREM", + VREM_RedisCommand,"write",1,1,1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx,"VSIM", + VSIM_RedisCommand,"readonly",1,1,1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx, "VDIM", + VDIM_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx, "VCARD", + VCARD_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx, "VEMB", + VEMB_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx, "VLINKS", + VLINKS_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx, "VINFO", + VINFO_RedisCommand, "readonly fast", 1, 1, 1) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + hnsw_set_allocator(RedisModule_Free, RedisModule_Alloc, + RedisModule_Realloc); + + return REDISMODULE_OK; +} diff --git a/w2v.c b/w2v.c new file mode 100644 index 000000000..012dfd95c --- /dev/null +++ b/w2v.c @@ -0,0 +1,315 @@ +/* + * HNSW (Hierarchical Navigable Small World) Implementation + * Based on the paper by Yu. A. Malkov, D. A. Yashunin + * + * Copyright(C) 2024 Salvatore Sanfilippo. All Rights Reserved. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hnsw.h" + +/* Get current time in milliseconds */ +uint64_t ms_time(void) { + struct timeval tv; + gettimeofday(&tv, NULL); + return (uint64_t)tv.tv_sec * 1000 + (tv.tv_usec / 1000); +} + +/* Example usage in main() */ +int w2v_single_thread(int quantization, uint64_t numele, int massdel, int recall) { + /* Create index */ + HNSW *index = hnsw_new(300, quantization); + float v[300]; + uint16_t wlen; + + FILE *fp = fopen("word2vec.bin","rb"); + if (fp == NULL) { + perror("word2vec.bin file missing"); + exit(1); + } + unsigned char header[8]; + fread(header,8,1,fp); // Skip header + + uint64_t id = 0; + uint64_t start_time = ms_time(); + char *word = NULL; + hnswNode *search_node = NULL; + + while(id < numele) { + if (fread(&wlen,2,1,fp) == 0) break; + word = malloc(wlen+1); + fread(word,wlen,1,fp); + word[wlen] = 0; + fread(v,300*sizeof(float),1,fp); + + // Plain API that acquires a write lock for the whole time. + hnswNode *added = hnsw_insert(index, v, NULL, 0, id++, word, 200); + + if (!strcmp(word,"banana")) search_node = added; + if (!(id % 10000)) printf("%llu added\n", (unsigned long long)id); + } + uint64_t elapsed = ms_time() - start_time; + fclose(fp); + + printf("%llu words added (%llu words/sec), last word: %s\n", + (unsigned long long)index->node_count, + (unsigned long long)id*1000/elapsed, word); + + /* Search query */ + if (search_node == NULL) search_node = index->head; + hnsw_get_node_vector(index,search_node,v); + hnswNode *neighbors[10]; + float distances[10]; + + int found, j; + start_time = ms_time(); + for (j = 0; j < 20000; j++) + found = hnsw_search(index, v, 10, neighbors, distances, 0, 0); + elapsed = ms_time() - start_time; + printf("%d searches performed (%llu searches/sec), nodes found: %d\n", + j, (unsigned long long)j*1000/elapsed, found); + + if (found > 0) { + printf("Found %d neighbors:\n", found); + for (int i = 0; i < found; i++) { + printf("Node ID: %llu, distance: %f, word: %s\n", + (unsigned long long)neighbors[i]->id, + distances[i], (char*)neighbors[i]->value); + } + } + + // Recall test (slow). + if (recall) { + hnsw_print_stats(index); + hnsw_test_graph_recall(index,200,0); + } + + uint64_t connected_nodes; + int reciprocal_links; + hnsw_validate_graph(index, &connected_nodes, &reciprocal_links); + + if (massdel) { + int remove_perc = 95; + printf("\nRemoving %d%% of nodes...\n", remove_perc); + uint64_t initial_nodes = index->node_count; + + hnswNode *current = index->head; + while (current && index->node_count > initial_nodes*(100-remove_perc)/100) { + hnswNode *next = current->next; + hnsw_delete_node(index,current,free); + current = next; + // In order to don't remove only contiguous nodes, from time + // skip a node. + if (current && !(random() % remove_perc)) current = current->next; + } + printf("%llu nodes left\n", (unsigned long long)index->node_count); + + // Test again. + hnsw_validate_graph(index, &connected_nodes, &reciprocal_links); + hnsw_test_graph_recall(index,200,0); + } + + hnsw_free(index,free); + return 0; +} + +struct threadContext { + pthread_mutex_t FileAccessMutex; + uint64_t numele; + _Atomic uint64_t SearchesDone; + _Atomic uint64_t id; + FILE *fp; + HNSW *index; + float *search_vector; +}; + +// Note that in practical terms inserting with many concurrent threads +// may be *slower* and not faster, because there is a lot of +// contention. So this is more a robustness test than anything else. +// +// The optimistic commit API goal is actually to exploit the ability to +// add faster when there are many concurrent reads. +void *threaded_insert(void *ctxptr) { + struct threadContext *ctx = ctxptr; + char *word; + float v[300]; + uint16_t wlen; + + while(1) { + pthread_mutex_lock(&ctx->FileAccessMutex); + if (fread(&wlen,2,1,ctx->fp) == 0) break; + pthread_mutex_unlock(&ctx->FileAccessMutex); + word = malloc(wlen+1); + fread(word,wlen,1,ctx->fp); + word[wlen] = 0; + fread(v,300*sizeof(float),1,ctx->fp); + + // Check-and-set API that performs the costly scan for similar + // nodes concurrently with other read threads, and finally + // applies the check if the graph wasn't modified. + InsertContext *ic; + uint64_t next_id = ctx->id++; + ic = hnsw_prepare_insert(ctx->index, v, NULL, 0, next_id, word, 200); + if (hnsw_try_commit_insert(ctx->index, ic) == NULL) { + // This time try locking since the start. + hnsw_insert(ctx->index, v, NULL, 0, next_id, word, 200); + } + + if (next_id >= ctx->numele) break; + if (!((next_id+1) % 10000)) + printf("%llu added\n", (unsigned long long)next_id+1); + } + return NULL; +} + +void *threaded_search(void *ctxptr) { + struct threadContext *ctx = ctxptr; + + /* Search query */ + hnswNode *neighbors[10]; + float distances[10]; + int found = 0; + uint64_t last_id = 0; + + while(ctx->id < 1000000) { + int slot = hnsw_acquire_read_slot(ctx->index); + found = hnsw_search(ctx->index, ctx->search_vector, 10, neighbors, distances, slot, 0); + hnsw_release_read_slot(ctx->index,slot); + last_id = ++ctx->id; + } + + if (found > 0 && last_id == 1000000) { + printf("Found %d neighbors:\n", found); + for (int i = 0; i < found; i++) { + printf("Node ID: %llu, distance: %f, word: %s\n", + (unsigned long long)neighbors[i]->id, + distances[i], (char*)neighbors[i]->value); + } + } + return NULL; +} + +int w2v_multi_thread(int numthreads, int quantization, uint64_t numele) { + /* Create index */ + struct threadContext ctx; + + ctx.index = hnsw_new(300,quantization); + + ctx.fp = fopen("word2vec.bin","rb"); + if (ctx.fp == NULL) { + perror("word2vec.bin file missing"); + exit(1); + } + + unsigned char header[8]; + fread(header,8,1,ctx.fp); // Skip header + pthread_mutex_init(&ctx.FileAccessMutex,NULL); + + uint64_t start_time = ms_time(); + ctx.id = 0; + ctx.numele = numele; + pthread_t threads[numthreads]; + for (int j = 0; j < numthreads; j++) + pthread_create(&threads[j], NULL, threaded_insert, &ctx); + + // Wait for all the threads to terminate adding items. + for (int j = 0; j < numthreads; j++) + pthread_join(threads[j],NULL); + + uint64_t elapsed = ms_time() - start_time; + fclose(ctx.fp); + + // Obtain the last word. + hnswNode *node = ctx.index->head; + char *word = node->value; + + // We will search this last inserted word in the next test. + // Let's save its embedding. + ctx.search_vector = malloc(sizeof(float)*300); + hnsw_get_node_vector(ctx.index,node,ctx.search_vector); + + printf("%llu words added (%llu words/sec), last word: %s\n", + (unsigned long long)ctx.index->node_count, + (unsigned long long)ctx.id*1000/elapsed, word); + + /* Search query */ + start_time = ms_time(); + ctx.id = 0; // We will use this atomic field to stop at N queries done. + + for (int j = 0; j < numthreads; j++) + pthread_create(&threads[j], NULL, threaded_search, &ctx); + + // Wait for all the threads to terminate searching. + for (int j = 0; j < numthreads; j++) + pthread_join(threads[j],NULL); + + elapsed = ms_time() - start_time; + printf("%llu searches performed (%llu searches/sec)\n", + (unsigned long long)ctx.id, + (unsigned long long)ctx.id*1000/elapsed); + + hnsw_print_stats(ctx.index); + uint64_t connected_nodes; + int reciprocal_links; + hnsw_validate_graph(ctx.index, &connected_nodes, &reciprocal_links); + printf("%llu connected nodes. Links all reciprocal: %d\n", + (unsigned long long)connected_nodes, reciprocal_links); + hnsw_free(ctx.index,free); + return 0; +} + +int main(int argc, char **argv) { + int quantization = HNSW_QUANT_NONE; + int numthreads = 0; + uint64_t numele = 20000; + + /* This you can enable in single thread mode for testing: */ + int massdel = 0; // If true, does the mass deletion test. + int recall = 0; // If true, does the recall test. + + for (int j = 1; j < argc; j++) { + int moreargs = argc-j-1; + + if (!strcasecmp(argv[j],"--quant")) { + quantization = HNSW_QUANT_Q8; + } else if (!strcasecmp(argv[j],"--bin")) { + quantization = HNSW_QUANT_BIN; + } else if (!strcasecmp(argv[j],"--mass-del")) { + massdel = 1; + } else if (!strcasecmp(argv[j],"--recall")) { + recall = 1; + } else if (moreargs >= 1 && !strcasecmp(argv[j],"--threads")) { + numthreads = atoi(argv[j+1]); + j++; + } else if (moreargs >= 1 && !strcasecmp(argv[j],"--numele")) { + numele = strtoll(argv[j+1],NULL,0); + j++; + if (numele < 1) numele = 1; + } else if (!strcasecmp(argv[j],"--help")) { + printf("%s [--quant] [--bin] [--thread ] [--numele ] [--mass-del] [--recall]\n", argv[0]); + exit(0); + } else { + printf("Unrecognized option: %s\n", argv[j]); + exit(1); + } + } + + if (quantization == HNSW_QUANT_NONE) { + printf("You can enable quantization with --quant\n"); + } + + if (numthreads > 0) { + w2v_multi_thread(numthreads,quantization,numele); + } else { + printf("Single thread execution. Use --threads 4 for concurrent API\n"); + w2v_single_thread(quantization,numele,massdel,recall); + } +}