First internal release.

antirez 2025-01-27 17:24:02 +01:00
commit 33d653e24f
21 changed files with 7058 additions and 0 deletions

10
.gitignore vendored Normal file

@@ -0,0 +1,10 @@
__pycache__
misc
*.so
*.xo
*.o
.DS_Store
w2v
word2vec.bin
TODO
*.txt

2
LICENSE Normal file

@@ -0,0 +1,2 @@
This code is Copyright (C) 2024-2025 Salvatore Sanfilippo.
All Rights Reserved.

77
Makefile Normal file

@@ -0,0 +1,77 @@
# Compiler settings
CC = gcc
ifdef SANITIZER
ifeq ($(SANITIZER),address)
SAN=-fsanitize=address
else
ifeq ($(SANITIZER),undefined)
SAN=-fsanitize=undefined
else
ifeq ($(SANITIZER),thread)
SAN=-fsanitize=thread
else
$(error "unknown sanitizer=${SANITIZER}")
endif
endif
endif
endif
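# Sanitizer builds are optional; illustrative invocation:
#   make SANITIZER=address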
CFLAGS = -O2 -Wall -Wextra -g -ffast-math $(SAN)
LDFLAGS = -lm $(SAN)
# Detect OS
uname_S := $(shell sh -c 'uname -s 2>/dev/null || echo not')
# Shared library compile flags for linux / osx
ifeq ($(uname_S),Linux)
SHOBJ_CFLAGS ?= -W -Wall -fno-common -g -ggdb -std=c99 -O2
SHOBJ_LDFLAGS ?= -shared
else
SHOBJ_CFLAGS ?= -W -Wall -dynamic -fno-common -g -ggdb -std=c99 -Ofast -ffast-math
SHOBJ_LDFLAGS ?= -bundle -undefined dynamic_lookup
endif
# OS X 11.x doesn't have /usr/lib/libSystem.dylib and needs an explicit setting.
ifeq ($(uname_S),Darwin)
ifeq ("$(wildcard /usr/lib/libSystem.dylib)","")
LIBS = -L /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lsystem
endif
endif
.SUFFIXES: .c .so .xo .o
all: vset.so
.c.xo:
$(CC) -I. $(CFLAGS) $(SHOBJ_CFLAGS) -fPIC -c $< -o $@
vset.xo: redismodule.h
vset.so: vset.xo hnsw.xo
$(LD) -o $@ $^ $(SHOBJ_LDFLAGS) $(LIBS) -lc
# Example sources / objects
SRCS = hnsw.c w2v.c
OBJS = $(SRCS:.c=.o)
TARGET = w2v
MODULE = vset.so
# Default target
all: $(TARGET) $(MODULE)
# Example linking rule
$(TARGET): $(OBJS)
$(CC) $(OBJS) $(LDFLAGS) -o $(TARGET)
# Compilation rule for object files
%.o: %.c
$(CC) $(CFLAGS) -c $< -o $@
# Clean rule
clean:
rm -f $(TARGET) $(OBJS) *.xo *.so
# Declare phony targets
.PHONY: all clean

175
README.md Normal file

@@ -0,0 +1,175 @@
This module implements vector sets for Redis, a new Redis data type similar
to sorted sets, but where each element is associated with a vector instead of
a score. It is possible to add items and then get them back by similarity to
either a user-provided vector or the vector of an element already inserted.
## Installation
make
Then load the module with the following command line, or by inserting the needed directives in the `redis.conf` file.
./redis-server --loadmodule vset.so
To run tests, I suggest using this:
./redis-server --save "" --enable-debug-command yes
Then execute the tests with:
./test.py
## Commands
**VADD: add items into a vector set**
VADD key [REDUCE dim] FP32|VALUES vector element [CAS] [NOQUANT] [BIN]
Add a new element into the vector set specified by the key.
The vector can be provided as an FP32 blob of values, or as floating point
numbers as strings, prefixed by the number of elements (3 in the example):
VADD mykey VALUES 3 0.1 1.2 0.5 my-element
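When using the FP32 format, the blob is just the raw float32 bytes of the vector. A minimal sketch with the redis-py client (key and element names are illustrative; the packing matches what the bundled test suite does):
import struct
import redis
r = redis.Redis()
# Pack three float32 values (native byte order) into a binary blob.
blob = struct.pack('3f', 0.1, 1.2, 0.5)
r.execute_command('VADD', 'mykey', 'FP32', blob, 'my-element')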
The `REDUCE` option implements random projection, in order to reduce the
dimensionality of the vector. The projection matrix is saved and reloaded
along with the vector set.
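For example, to store 300-dimensional vectors while indexing them in a randomly projected 100-dimensional space (key and element names are illustrative):
VADD mykey REDUCE 100 VALUES 300 <... 300 floating point values ...> my-element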
The `CAS` option performs the operation partially using threads, in a
check-and-set style. The neighbor candidates collection, which is slow, is
performed in the background, while the command is executed in the main thread.
The `NOQUANT` option forces the vector to be created (in the first VADD call to a given key) without int8 quantization, which is otherwise the default.
The `BIN` option forces the vector to use binary quantization instead of int8. This is much faster and uses less memory, but impacts the recall quality.
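For example (illustrative invocations; as in the synopsis above, these options follow the element name):
VADD mykey VALUES 3 0.1 1.2 0.5 my-element CAS
VADD mykey VALUES 3 0.1 1.2 0.5 my-element BIN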
**VSIM: return elements by vector similarity**
VSIM key [ELE|FP32|VALUES] <vector or element> [WITHSCORES] [COUNT num] [EF exploration-factor]
The command returns similar vectors. In the example below, instead of providing a vector using FP32 or VALUES (like in `VADD`), we ask for elements associated with a vector similar to that of a given element already in the vector set:
> VSIM word_embeddings ELE apple
1) "apple"
2) "apples"
3) "pear"
4) "fruit"
5) "berry"
6) "pears"
7) "strawberry"
8) "peach"
9) "potato"
10) "grape"
It is possible to specify a `COUNT`, and also to get the similarity score (from 1 to 0, where 1 means identical vectors and 0 means opposite vectors) between the query and the returned items. The score maps the cosine similarity from the [-1, 1] range into [0, 1] as (1 + cosine) / 2, which matches the scoring the bundled test suite uses for verification.
> VSIM word_embeddings ELE apple WITHSCORES COUNT 3
1) "apple"
2) "0.9998867657923256"
3) "apples"
4) "0.8598527610301971"
5) "pear"
6) "0.8226882219314575"
The `EF` argument is the exploration factor: the higher it is, the slower the command becomes, but the more the index is explored to find nodes near our query. Sensible values range from 50 to 1000.
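For example, to explore the graph more thoroughly at the cost of speed (illustrative invocation, output omitted):
> VSIM word_embeddings ELE apple COUNT 10 EF 500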
**VDIM: return the dimension of the vectors inside the vector set**
VDIM keyname
Example:
> VDIM word_embeddings
(integer) 300
Note that in the case of vectors that were populated using the `REDUCE`
option, for random projection, the vector set will report the projected
(reduced) dimension. Yet the user should perform all the queries using
full-size vectors.
**VCARD: return the number of elements in a vector set**
VCARD key
Example:
> VCARD word_embeddings
(integer) 3000000
**VREM: remove elements from vector set**
VREM key element
Example:
> VADD vset VALUES 3 1 0 1 bar
(integer) 1
> VREM vset bar
(integer) 1
> VREM vset bar
(integer) 0
VREM does not perform tombstone / logical deletion: it actually reclaims
the memory from the vector set, so it is safe to add and remove elements
in a vector set in the context of long running applications that continuously
update the same index.
**VEMB: return the approximated vector of an element**
VEMB key element
Example:
> VEMB word_embeddings SQL
1) "0.18208661675453186"
2) "0.08535309880971909"
3) "0.1365649551153183"
4) "-0.16501599550247192"
5) "0.14225517213344574"
... 295 more elements ...
Because vector sets perform insertion-time normalization and optional
quantization, the returned vector may only approximate the original. `VEMB`
takes care of de-quantizing and de-normalizing the vector before returning it.
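As a quick sketch of what this means in practice, using redis-py (key and element names are illustrative), the vector returned by `VEMB` should match the inserted one within a small quantization error:
import redis
r = redis.Redis()
vec = [1.0, 2.0, 2.0]
r.execute_command('VADD', 'mykey', 'VALUES', 3, *[str(x) for x in vec], 'elem')
out = [float(x) for x in r.execute_command('VEMB', 'mykey', 'elem')]
# With the default int8 quantization each component carries a small error;
# 0.05 is an assumed, generous tolerance for this sketch.
assert all(abs(a - b) < 0.05 for a, b in zip(out, vec))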
**VLINKS: introspection command that shows neighbors for a node**
VLINKS key element [WITHSCORES]
The command reports the neighbors for each level.
**VINFO: introspection command that shows info about a vector set**
VINFO key
Example:
> VINFO word_embeddings
1) quant-type
2) int8
3) vector-dim
4) (integer) 300
5) size
6) (integer) 3000000
7) max-level
8) (integer) 12
9) vset-uid
10) (integer) 1
11) hnsw-max-node-uid
12) (integer) 3000000
## Known bugs
* When VADD with REDUCE is replicated, we should probably send the replicas the random matrix, in order for VEMB to read the same things. This is not critical, because the behavior of VADD / VSIM should be transparent if you don't look at the transformed vectors, but still probably worth doing.
* Replication code is pretty much untested, and very vanilla (replicating the commands verbatim).
## Implementation details
Vector sets are based on the `hnsw.c` implementation of the HNSW data structure with extensions for speed and functionality.
The main features are:
* Proper node deletion with relinking.
* 8-bit quantization.
* Threaded queries.

2482
hnsw.c Normal file

File diff suppressed because it is too large

158
hnsw.h Normal file

@@ -0,0 +1,158 @@
/*
* HNSW (Hierarchical Navigable Small World) Implementation
* Based on the paper by Yu. A. Malkov, D. A. Yashunin
*
* Copyright(C) 2024 Salvatore Sanfilippo. All Rights Reserved.
*/
#ifndef HNSW_H
#define HNSW_H
#include <pthread.h>
#include <stdatomic.h>
#define HNSW_MAX_THREADS 32 /* Maximum number of concurrent threads */
/* Quantization types you can enable at creation time in hnsw_new() */
#define HNSW_QUANT_NONE 0 // No quantization.
#define HNSW_QUANT_Q8 1 // Q8 quantization.
#define HNSW_QUANT_BIN 2 // Binary quantization.
/* Layer structure for HNSW nodes. Each node will have from one to a few
* of these, depending on its level. */
typedef struct {
struct hnswNode **links; /* Array of neighbors for this layer */
uint32_t num_links; /* Number of used links */
uint32_t max_links; /* Maximum links for this layer. We may
* reallocate the node in very particular
* conditions in order to allow linking of
* new inserted nodes, so this may change
* dynamically for a small set of nodes. */
float worst_distance; /* Distance to the worst neighbor */
uint32_t worst_idx; /* Index of the worst neighbor */
} hnswNodeLayer;
/* Node structure for HNSW graph */
typedef struct hnswNode {
uint32_t level; /* Node's maximum level */
uint64_t id; /* Unique identifier, may be useful in order to
* have a bitmap of visited nodes to use as
* alternative to epoch / visited_epoch.
* Also used in serialization in order to retain
* links specifying IDs. */
void *vector; /* The vector, quantized or not. */
float quants_range; /* Quantization range for this vector:
* min/max values will be in the range
* -quants_range, +quants_range */
float l2; /* L2 before normalization. */
/* Last time (epoch) this node was visited. We need one per thread.
* This avoids having a different data structure where we track
* visited nodes, but costs memory per node. */
uint64_t visited_epoch[HNSW_MAX_THREADS];
void *value; /* Associated value */
struct hnswNode *prev, *next; /* Prev/Next node in the list starting at
* HNSW->head. */
/* Links (and links info) per each layer. Note that this is part
* of the node allocation to be more cache friendly: reliable 3% speedup
* on Apple silicon, and does not make anything more complex. */
hnswNodeLayer layers[];
} hnswNode;
/* It is possible to navigate an HNSW with a cursor that guarantees
* visiting all the elements that remain in the HNSW from the start to the
* end of the process (but not the new ones, so that the process will
* eventually finish). Check hnsw_cursor_init(), hnsw_cursor_next() and
* hnsw_cursor_free(). */
typedef struct hnswCursor {
hnswNode *current; // Element to report when hnsw_cursor_next() is called.
struct hnswCursor *next; // Next cursor active.
} hnswCursor;
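/* Usage sketch (illustrative, based on the declarations below):
*
*   hnswCursor *c = hnsw_cursor_init(index);
*   hnswNode *node;
*   while ((node = hnsw_cursor_next(index, c)) != NULL) {
*       // ... process node ...
*   }
*   hnsw_cursor_free(index, c);
*/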
/* Main HNSW index structure */
typedef struct HNSW {
hnswNode *enter_point; /* Entry point for the graph */
uint32_t max_level; /* Current maximum level in the graph */
uint32_t vector_dim; /* Dimensionality of stored vectors */
uint64_t node_count; /* Total number of nodes */
_Atomic uint64_t last_id; /* Last node ID used */
uint64_t current_epoch[HNSW_MAX_THREADS]; /* Current epoch for visit tracking */
hnswNode *head; /* Linked list of nodes. Last first */
/* We have two locks here:
* 1. A global_lock that is used to perform write operations blocking all
* the readers.
* 2. One mutex per epoch slot, in order for read operations to acquire
* a lock on a specific slot to use epochs tracking of visited nodes. */
pthread_rwlock_t global_lock; /* Global read-write lock */
pthread_mutex_t slot_locks[HNSW_MAX_THREADS]; /* Per-slot locks */
_Atomic uint32_t next_slot; /* Next thread slot to try */
_Atomic uint64_t version; /* Version for optimistic concurrency, this is
* incremented on deletions and entry point
* updates. */
uint32_t quant_type; /* Quantization used. HNSW_QUANT_... */
hnswCursor *cursors;
} HNSW;
/* Serialized node. This structure is used as return value of
* hnsw_serialize_node(). */
typedef struct hnswSerNode {
void *vector;
uint32_t vector_size;
uint64_t *params;
uint32_t params_count;
} hnswSerNode;
/* Insert preparation context */
typedef struct InsertContext InsertContext;
/* Core HNSW functions */
HNSW *hnsw_new(uint32_t vector_dim, uint32_t quant_type);
void hnsw_free(HNSW *index,void(*free_value)(void*value));
void hnsw_node_free(hnswNode *node);
void hnsw_print_stats(HNSW *index);
hnswNode *hnsw_insert(HNSW *index, const float *vector, const int8_t *qvector,
float qrange, uint64_t id, void *value, int ef);
int hnsw_search(HNSW *index, const float *query, uint32_t k,
hnswNode **neighbors, float *distances, uint32_t slot,
int query_vector_is_normalized);
void hnsw_get_node_vector(HNSW *index, hnswNode *node, float *vec);
void hnsw_delete_node(HNSW *index, hnswNode *node, void(*free_value)(void*value));
/* Thread safety functions. */
int hnsw_acquire_read_slot(HNSW *index);
void hnsw_release_read_slot(HNSW *index, int slot);
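/* Usage sketch for readers (the same pattern w2v.c uses in its
* threaded search):
*
*   int slot = hnsw_acquire_read_slot(index);
*   int found = hnsw_search(index, query, 10, neighbors, distances, slot, 0);
*   hnsw_release_read_slot(index, slot);
*/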
/* Optimistic insertion API. */
InsertContext *hnsw_prepare_insert(HNSW *index, const float *vector, const int8_t *qvector, float qrange, uint64_t id, void *value, int ef);
hnswNode *hnsw_try_commit_insert(HNSW *index, InsertContext *ctx);
void hnsw_free_insert_context(InsertContext *ctx);
/* Serialization. */
hnswSerNode *hnsw_serialize_node(HNSW *index, hnswNode *node);
void hnsw_free_serialized_node(hnswSerNode *sn);
hnswNode *hnsw_insert_serialized(HNSW *index, void *vector, uint64_t *params, uint32_t params_len, void *value);
int hnsw_deserialize_index(HNSW *index);
// Helper function in case the user wants to directly copy
// the vector bytes.
uint32_t hnsw_quants_bytes(HNSW *index);
/* Cursors. */
hnswCursor *hnsw_cursor_init(HNSW *index);
void hnsw_cursor_free(HNSW *index, hnswCursor *cursor);
hnswNode *hnsw_cursor_next(HNSW *index, hnswCursor *cursor);
/* Allocator selection. */
void hnsw_set_allocator(void (*free_ptr)(void*), void *(*malloc_ptr)(size_t),
void *(*realloc_ptr)(void*, size_t));
/* Testing. */
int hnsw_validate_graph(HNSW *index, uint64_t *connected_nodes, int *reciprocal_links);
void hnsw_test_graph_recall(HNSW *index, int test_ef, int verbose);
float hnsw_distance(HNSW *index, hnswNode *a, hnswNode *b);
#endif /* HNSW_H */

1704
redismodule.h Normal file

File diff suppressed because it is too large

189
test.py Executable file

@@ -0,0 +1,189 @@
#!/usr/bin/env python3
#
# Vector set tests.
# A Redis instance should be running in the default port.
# Copyright(C) 2024-2025 Salvatore Sanfilippo.
# All Rights Reserved.
import redis
import random
import struct
import math
import time
import sys
import os
import importlib
import inspect
from typing import List, Tuple, Optional
from dataclasses import dataclass
def colored(text: str, color: str) -> str:
colors = {
'red': '\033[91m',
'green': '\033[92m'
}
reset = '\033[0m'
return f"{colors.get(color, '')}{text}{reset}"
@dataclass
class VectorData:
vectors: List[List[float]]
names: List[str]
def find_k_nearest(self, query_vector: List[float], k: int) -> List[Tuple[str, float]]:
"""Find k-nearest neighbors using the same scoring as Redis VSIM WITHSCORES."""
similarities = []
query_norm = math.sqrt(sum(x*x for x in query_vector))
if query_norm == 0:
return []
for i, vec in enumerate(self.vectors):
vec_norm = math.sqrt(sum(x*x for x in vec))
if vec_norm == 0:
continue
dot_product = sum(a*b for a,b in zip(query_vector, vec))
cosine_sim = dot_product / (query_norm * vec_norm)
distance = 1.0 - cosine_sim
redis_similarity = 1.0 - (distance/2.0)
similarities.append((self.names[i], redis_similarity))
similarities.sort(key=lambda x: x[1], reverse=True)
return similarities[:k]
def generate_random_vector(dim: int) -> List[float]:
"""Generate a random normalized vector."""
vec = [random.gauss(0, 1) for _ in range(dim)]
norm = math.sqrt(sum(x*x for x in vec))
return [x/norm for x in vec]
def fill_redis_with_vectors(r: redis.Redis, key: str, count: int, dim: int,
with_reduce: Optional[int] = None) -> VectorData:
"""Fill Redis with random vectors and return a VectorData object for verification."""
vectors = []
names = []
r.delete(key)
for i in range(count):
vec = generate_random_vector(dim)
name = f"{key}:item:{i}"
vectors.append(vec)
names.append(name)
vec_bytes = struct.pack(f'{dim}f', *vec)
args = [key]
if with_reduce:
args.extend(['REDUCE', with_reduce])
args.extend(['FP32', vec_bytes, name])
r.execute_command('VADD', *args)
return VectorData(vectors=vectors, names=names)
class TestCase:
def __init__(self):
self.error_msg = None
self.error_details = None
self.test_key = f"test:{self.__class__.__name__.lower()}"
self.redis = redis.Redis()
def setup(self):
self.redis.delete(self.test_key)
def teardown(self):
self.redis.delete(self.test_key)
def test(self):
raise NotImplementedError("Subclasses must implement test method")
def run(self):
try:
self.setup()
self.test()
return True
except AssertionError as e:
self.error_msg = str(e)
import traceback
self.error_details = traceback.format_exc()
return False
except Exception as e:
self.error_msg = f"Unexpected error: {str(e)}"
import traceback
self.error_details = traceback.format_exc()
return False
finally:
self.teardown()
def getname(self):
"""Each test class should override this to provide its name"""
return self.__class__.__name__
def estimated_runtime(self):
""""Each test class should override this if it takes a significant amount of time to run. Default is 100ms"""
return 0.1
def find_test_classes():
test_classes = []
tests_dir = 'tests'
if not os.path.exists(tests_dir):
return []
for file in os.listdir(tests_dir):
if file.endswith('.py'):
module_name = f"tests.{file[:-3]}"
try:
module = importlib.import_module(module_name)
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and obj.__name__ != 'TestCase' and hasattr(obj, 'test'):
test_classes.append(obj())
except Exception as e:
print(f"Error loading {file}: {e}")
return test_classes
def run_tests():
print("================================================\n"+
"Make sure to have Redis running in the localhost\n"+
"with --enable-debug-command yes\n"+
"================================================\n")
tests = find_test_classes()
if not tests:
print("No tests found!")
return
# Sort tests by estimated runtime
tests.sort(key=lambda t: t.estimated_runtime())
passed = 0
total = len(tests)
for test in tests:
print(f"{test.getname()}: ", end="")
sys.stdout.flush()
start_time = time.time()
success = test.run()
duration = time.time() - start_time
if success:
print(colored("OK", "green"), f"({duration:.2f}s)")
passed += 1
else:
print(colored("ERR", "red"), f"({duration:.2f}s)")
print(f"Error: {test.error_msg}")
if test.error_details:
print("\nTraceback:")
print(test.error_details)
print("\n" + "="*50)
print(f"\nTest Summary: {passed}/{total} tests passed")
if passed == total:
print(colored("\nALL TESTS PASSED!", "green"))
else:
print(colored(f"\n{total-passed} TESTS FAILED!", "red"))
if __name__ == "__main__":
run_tests()

21
tests/basic_commands.py Normal file

@@ -0,0 +1,21 @@
from test import TestCase, generate_random_vector
import struct
class BasicCommands(TestCase):
def getname(self):
return "VADD, VDIM, VCARD basic usage"
def test(self):
# Test VADD
vec = generate_random_vector(4)
vec_bytes = struct.pack('4f', *vec)
result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1')
assert result == 1, "VADD should return 1 for first item"
# Test VDIM
dim = self.redis.execute_command('VDIM', self.test_key)
assert dim == 4, f"VDIM should return 4, got {dim}"
# Test VCARD
card = self.redis.execute_command('VCARD', self.test_key)
assert card == 1, f"VCARD should return 1, got {card}"

35
tests/basic_similarity.py Normal file

@@ -0,0 +1,35 @@
from test import TestCase
class BasicSimilarity(TestCase):
def getname(self):
return "VSIM reported distance makes sense with 4D vectors"
def test(self):
# Add two very similar vectors, one different
vec1 = [1, 0, 0, 0]
vec2 = [0.99, 0.01, 0, 0]
vec3 = [0.1, 1, -1, 0.5]
# Add vectors using VALUES format
self.redis.execute_command('VADD', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1], f'{self.test_key}:item:1')
self.redis.execute_command('VADD', self.test_key, 'VALUES', 4,
*[str(x) for x in vec2], f'{self.test_key}:item:2')
self.redis.execute_command('VADD', self.test_key, 'VALUES', 4,
*[str(x) for x in vec3], f'{self.test_key}:item:3')
# Query similarity with vec1
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', 4,
*[str(x) for x in vec1], 'WITHSCORES')
# Convert results to dictionary
results_dict = {}
for i in range(0, len(result), 2):
key = result[i].decode()
score = float(result[i+1])
results_dict[key] = score
# Verify results
assert results_dict[f'{self.test_key}:item:1'] > 0.99, "Self-similarity should be very high"
assert results_dict[f'{self.test_key}:item:2'] > 0.99, "Similar vector should have high similarity"
assert results_dict[f'{self.test_key}:item:3'] < 0.8, "Not very similar vector should have low similarity"


@@ -0,0 +1,48 @@
from test import TestCase, fill_redis_with_vectors, generate_random_vector
import threading, time
class ConcurrentVSIMAndDEL(TestCase):
def getname(self):
return "Concurrent VSIM and DEL operations"
def estimated_runtime(self):
return 2
def test(self):
# Fill the key with 5000 random vectors
dim = 128
count = 5000
fill_redis_with_vectors(self.redis, self.test_key, count, dim)
# List to store results from threads
thread_results = []
def vsim_thread():
"""Thread function to perform VSIM operations until the key is deleted"""
while True:
query_vec = generate_random_vector(dim)
result = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim,
*[str(x) for x in query_vec], 'COUNT', 10)
if not result:
# Empty array detected, key is deleted
thread_results.append(True)
break
# Start multiple threads to perform VSIM operations
threads = []
for _ in range(4): # Start 4 threads
t = threading.Thread(target=vsim_thread)
t.start()
threads.append(t)
# Delete the key while threads are still running
time.sleep(1)
self.redis.delete(self.test_key)
# Wait for all threads to finish (they will exit once they detect the key is deleted)
for t in threads:
t.join()
# Verify that all threads detected an empty array or error
assert len(thread_results) == len(threads), "Not all threads detected the key deletion"
assert all(thread_results), "Some threads did not detect an empty array or error after DEL"

173
tests/deletion.py Normal file

@@ -0,0 +1,173 @@
from test import TestCase, fill_redis_with_vectors, generate_random_vector
import random
"""
A note about this test:
It was experimentally tried to modify hnsw.c in order to
avoid calling hnsw_reconnect_nodes(). In this case, the test
fails very often with EF set to 250, while it hardly ever
fails with the same parameters if hnsw_reconnect_nodes()
is called.
Note that given the nature of the test (it is very strict) it can
still fail from time to time, without this signaling any
actual bug.
"""
class VREM(TestCase):
def getname(self):
return "Deletion and graph state after deletion"
def estimated_runtime(self):
return 2.0
def format_neighbors_with_scores(self, links_result, old_links=None, items_to_remove=None):
"""Format neighbors with their similarity scores and status indicators"""
if not links_result:
return "No neighbors"
output = []
for level, neighbors in enumerate(links_result):
level_num = len(links_result) - level - 1
output.append(f"Level {level_num}:")
# Get neighbors and scores
neighbors_with_scores = []
for i in range(0, len(neighbors), 2):
neighbor = neighbors[i].decode() if isinstance(neighbors[i], bytes) else neighbors[i]
score = float(neighbors[i+1]) if i+1 < len(neighbors) else None
status = ""
# For old links, mark deleted ones
if items_to_remove and neighbor in items_to_remove:
status = " [lost]"
# For new links, mark newly added ones
elif old_links is not None:
# Check if this neighbor was in the old links at this level
was_present = False
if old_links and level < len(old_links):
old_neighbors = [n.decode() if isinstance(n, bytes) else n
for n in old_links[level]]
was_present = neighbor in old_neighbors
if not was_present:
status = " [gained]"
if score is not None:
neighbors_with_scores.append(f"{len(neighbors_with_scores)+1}. {neighbor} ({score:.6f}){status}")
else:
neighbors_with_scores.append(f"{len(neighbors_with_scores)+1}. {neighbor}{status}")
output.extend([" " + n for n in neighbors_with_scores])
return "\n".join(output)
def test(self):
# 1. Fill server with random elements
dim = 128
count = 5000
data = fill_redis_with_vectors(self.redis, self.test_key, count, dim)
# 2. Do VSIM to get 200 items
query_vec = generate_random_vector(dim)
results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim,
*[str(x) for x in query_vec],
'COUNT', 200, 'WITHSCORES')
# Convert results to list of (item, score) pairs, sorted by score
items = []
for i in range(0, len(results), 2):
item = results[i].decode()
score = float(results[i+1])
items.append((item, score))
items.sort(key=lambda x: x[1], reverse=True) # Sort by similarity
# Store the graph structure for all items before deletion
neighbors_before = {}
for item, _ in items:
links = self.redis.execute_command('VLINKS', self.test_key, item, 'WITHSCORES')
if links: # Some items might not have links
neighbors_before[item] = links
# 3. Remove 100 random items
items_to_remove = set(item for item, _ in random.sample(items, 100))
# Keep track of top 10 non-removed items
top_remaining = []
for item, score in items:
if item not in items_to_remove:
top_remaining.append((item, score))
if len(top_remaining) == 10:
break
# Remove the items
for item in items_to_remove:
result = self.redis.execute_command('VREM', self.test_key, item)
assert result == 1, f"VREM failed to remove {item}"
# 4. Do VSIM again with same vector
new_results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim,
*[str(x) for x in query_vec],
'COUNT', 200, 'WITHSCORES',
'EF', 500)
# Convert new results to dict of item -> score
new_scores = {}
for i in range(0, len(new_results), 2):
item = new_results[i].decode()
score = float(new_results[i+1])
new_scores[item] = score
failure = False
failed_item = None
failed_reason = None
# 5. Verify all top 10 non-removed items are still found with similar scores
for item, old_score in top_remaining:
if item not in new_scores:
failure = True
failed_item = item
failed_reason = "missing"
break
new_score = new_scores[item]
if abs(new_score - old_score) >= 0.01:
failure = True
failed_item = item
failed_reason = f"score changed: {old_score:.6f} -> {new_score:.6f}"
break
if failure:
print("\nTest failed!")
print(f"Problem with item: {failed_item} ({failed_reason})")
print("\nOriginal neighbors (with similarity scores):")
if failed_item in neighbors_before:
print(self.format_neighbors_with_scores(
neighbors_before[failed_item],
items_to_remove=items_to_remove))
else:
print("No neighbors found in original graph")
print("\nCurrent neighbors (with similarity scores):")
current_links = self.redis.execute_command('VLINKS', self.test_key,
failed_item, 'WITHSCORES')
if current_links:
print(self.format_neighbors_with_scores(
current_links,
old_links=neighbors_before.get(failed_item)))
else:
print("No neighbors in current graph")
print("\nOriginal results (top 20):")
for item, score in items[:20]:
deleted = "[deleted]" if item in items_to_remove else ""
print(f"{item}: {score:.6f} {deleted}")
print("\nNew results after removal (top 20):")
new_items = []
for i in range(0, len(new_results), 2):
item = new_results[i].decode()
score = float(new_results[i+1])
new_items.append((item, score))
new_items.sort(key=lambda x: x[1], reverse=True)
for item, score in new_items[:20]:
print(f"{item}: {score:.6f}")
raise AssertionError(f"Test failed: Problem with item {failed_item} ({failed_reason}). *** IMPORTANT *** This test may fail from time to time without indicating that there is a bug. However normally it should pass. The fact is that it's a quite extreme test where we destroy 50% of nodes of top results and still expect perfect recall, with vectors that are very hostile because of the distribution used.")

27
tests/evict_empty.py Normal file

@@ -0,0 +1,27 @@
from test import TestCase, generate_random_vector
import struct
class VREM_LastItemDeletesKey(TestCase):
def getname(self):
return "VREM last item deletes key"
def test(self):
# Generate a random vector
vec = generate_random_vector(4)
vec_bytes = struct.pack('4f', *vec)
# Add the vector to the key
result = self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1')
assert result == 1, "VADD should return 1 for first item"
# Verify the key exists
exists = self.redis.exists(self.test_key)
assert exists == 1, "Key should exist after VADD"
# Remove the item
result = self.redis.execute_command('VREM', self.test_key, f'{self.test_key}:item:1')
assert result == 1, "VREM should return 1 for successful removal"
# Verify the key no longer exists
exists = self.redis.exists(self.test_key)
assert exists == 0, "Key should no longer exist after VREM of last item"

56
tests/large_scale.py Normal file

@@ -0,0 +1,56 @@
from test import TestCase, fill_redis_with_vectors, generate_random_vector
import random
class LargeScale(TestCase):
def getname(self):
return "Large Scale Comparison"
def estimated_runtime(self):
return 10
def test(self):
dim = 300
count = 20000
k = 50
# Fill Redis and get reference data for comparison
random.seed(42) # Make test deterministic
data = fill_redis_with_vectors(self.redis, self.test_key, count, dim)
# Generate query vector
query_vec = generate_random_vector(dim)
# Get results from Redis with good exploration factor
redis_raw = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim,
*[str(x) for x in query_vec],
'COUNT', k, 'WITHSCORES', 'EF', 500)
# Convert Redis results to dict
redis_results = {}
for i in range(0, len(redis_raw), 2):
key = redis_raw[i].decode()
score = float(redis_raw[i+1])
redis_results[key] = score
# Get results from linear scan
linear_results = data.find_k_nearest(query_vec, k)
linear_items = {name: score for name, score in linear_results}
# Compare overlap
redis_set = set(redis_results.keys())
linear_set = set(linear_items.keys())
overlap = len(redis_set & linear_set)
# If test fails, print comparison for debugging
if overlap < k * 0.7:
data.print_comparison({'items': redis_results, 'query_vector': query_vec}, k)
assert overlap >= k * 0.7, \
f"Expected at least 70% overlap in top {k} results, got {overlap/k*100:.1f}%"
# Verify scores for common items
for item in redis_set & linear_set:
redis_score = redis_results[item]
linear_score = linear_items[item]
assert abs(redis_score - linear_score) < 0.01, \
f"Score mismatch for {item}: Redis={redis_score:.3f} Linear={linear_score:.3f}"

85
tests/node_update.py Normal file

@@ -0,0 +1,85 @@
from test import TestCase, generate_random_vector
import struct
import math
import random
class VectorUpdateAndClusters(TestCase):
def getname(self):
return "VADD vector update with cluster relocation"
def estimated_runtime(self):
return 2.0 # Should take around 2 seconds
def generate_cluster_vector(self, base_vec, noise=0.1):
"""Generate a vector that's similar to base_vec with some noise."""
vec = [x + random.gauss(0, noise) for x in base_vec]
# Normalize
norm = math.sqrt(sum(x*x for x in vec))
return [x/norm for x in vec]
def test(self):
dim = 128
vectors_per_cluster = 5000
# Create two very different base vectors for our clusters
cluster1_base = generate_random_vector(dim)
cluster2_base = [-x for x in cluster1_base] # Opposite direction
# Add vectors from first cluster
for i in range(vectors_per_cluster):
vec = self.generate_cluster_vector(cluster1_base)
vec_bytes = struct.pack(f'{dim}f', *vec)
self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes,
f'{self.test_key}:cluster1:{i}')
# Add vectors from second cluster
for i in range(vectors_per_cluster):
vec = self.generate_cluster_vector(cluster2_base)
vec_bytes = struct.pack(f'{dim}f', *vec)
self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes,
f'{self.test_key}:cluster2:{i}')
# Pick a test vector from cluster1
test_key = f'{self.test_key}:cluster1:0'
# Verify it's in cluster1 using VSIM
initial_vec = self.generate_cluster_vector(cluster1_base)
results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim,
*[str(x) for x in initial_vec],
'COUNT', 100, 'WITHSCORES')
# Count how many cluster1 items are in top results
cluster1_count = sum(1 for i in range(0, len(results), 2)
if b'cluster1' in results[i])
assert cluster1_count > 80, "Initial clustering check failed"
# Now update the test vector to be in cluster2
new_vec = self.generate_cluster_vector(cluster2_base, noise=0.05)
vec_bytes = struct.pack(f'{dim}f', *new_vec)
self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, test_key)
# Verify the embedding was actually updated using VEMB
emb_result = self.redis.execute_command('VEMB', self.test_key, test_key)
updated_vec = [float(x) for x in emb_result]
# Verify updated vector matches what we inserted
dot_product = sum(a*b for a,b in zip(updated_vec, new_vec))
similarity = dot_product / (math.sqrt(sum(x*x for x in updated_vec)) *
math.sqrt(sum(x*x for x in new_vec)))
assert similarity > 0.9, "Vector was not properly updated"
# Verify it's now in cluster2 using VSIM
results = self.redis.execute_command('VSIM', self.test_key, 'VALUES', dim,
*[str(x) for x in cluster2_base],
'COUNT', 100, 'WITHSCORES')
# Verify our updated vector is among top results
found = False
for i in range(0, len(results), 2):
if results[i].decode() == test_key:
found = True
similarity = float(results[i+1])
assert similarity > 0.80, f"Updated vector has low similarity: {similarity}"
break
assert found, "Updated vector not found in cluster2 proximity"

83
tests/persistence.py Normal file

@@ -0,0 +1,83 @@
from test import TestCase, fill_redis_with_vectors, generate_random_vector
import random
class HNSWPersistence(TestCase):
def getname(self):
return "HNSW Persistence"
def estimated_runtime(self):
return 30
def _verify_results(self, key, dim, query_vec, reduced_dim=None):
"""Run a query and return results dict"""
k = 10
# Queries always use full-size vectors, also for keys created with
# REDUCE, so the command is built the same way in both cases.
args = ['VSIM', key, 'VALUES', dim]
args.extend([str(x) for x in query_vec])
args.extend(['COUNT', k, 'WITHSCORES'])
results = self.redis.execute_command(*args)
results_dict = {}
for i in range(0, len(results), 2):
key = results[i].decode()
score = float(results[i+1])
results_dict[key] = score
return results_dict
def test(self):
# Setup dimensions
dim = 128
reduced_dim = 32
count = 5000
random.seed(42)
# Create two datasets - one normal and one with dimension reduction
normal_data = fill_redis_with_vectors(self.redis, f"{self.test_key}:normal", count, dim)
projected_data = fill_redis_with_vectors(self.redis, f"{self.test_key}:projected",
count, dim, reduced_dim)
# Generate query vectors we'll use before and after reload
query_vec_normal = generate_random_vector(dim)
query_vec_projected = generate_random_vector(dim)
# Get initial results for both sets
initial_normal = self._verify_results(f"{self.test_key}:normal",
dim, query_vec_normal)
initial_projected = self._verify_results(f"{self.test_key}:projected",
dim, query_vec_projected, reduced_dim)
# Force Redis to save and reload the dataset
self.redis.execute_command('DEBUG', 'RELOAD')
# Verify results after reload
reloaded_normal = self._verify_results(f"{self.test_key}:normal",
dim, query_vec_normal)
reloaded_projected = self._verify_results(f"{self.test_key}:projected",
dim, query_vec_projected, reduced_dim)
# Verify normal vectors results
assert len(initial_normal) == len(reloaded_normal), \
"Normal vectors: Result count mismatch before/after reload"
for key in initial_normal:
assert key in reloaded_normal, f"Normal vectors: Missing item after reload: {key}"
assert abs(initial_normal[key] - reloaded_normal[key]) < 0.0001, \
f"Normal vectors: Score mismatch for {key}: " + \
f"before={initial_normal[key]:.6f}, after={reloaded_normal[key]:.6f}"
# Verify projected vectors results
assert len(initial_projected) == len(reloaded_projected), \
"Projected vectors: Result count mismatch before/after reload"
for key in initial_projected:
assert key in reloaded_projected, \
f"Projected vectors: Missing item after reload: {key}"
assert abs(initial_projected[key] - reloaded_projected[key]) < 0.0001, \
f"Projected vectors: Score mismatch for {key}: " + \
f"before={initial_projected[key]:.6f}, after={reloaded_projected[key]:.6f}"

71
tests/reduce.py Normal file

@@ -0,0 +1,71 @@
from test import TestCase, fill_redis_with_vectors, generate_random_vector
class Reduce(TestCase):
def getname(self):
return "Dimension Reduction"
def estimated_runtime(self):
return 0.2
def test(self):
original_dim = 100
reduced_dim = 80
count = 1000
k = 50 # Number of nearest neighbors to check
# Fill Redis with vectors using REDUCE and get reference data
data = fill_redis_with_vectors(self.redis, self.test_key, count, original_dim, reduced_dim)
# Verify dimension is reduced
dim = self.redis.execute_command('VDIM', self.test_key)
assert dim == reduced_dim, f"Expected dimension {reduced_dim}, got {dim}"
# Generate query vector and get nearest neighbors using Redis
query_vec = generate_random_vector(original_dim)
redis_raw = self.redis.execute_command('VSIM', self.test_key, 'VALUES',
original_dim, *[str(x) for x in query_vec],
'COUNT', k, 'WITHSCORES')
# Convert Redis results to dict
redis_results = {}
for i in range(0, len(redis_raw), 2):
key = redis_raw[i].decode()
score = float(redis_raw[i+1])
redis_results[key] = score
# Get results from linear scan with original vectors
linear_results = data.find_k_nearest(query_vec, k)
linear_items = {name: score for name, score in linear_results}
# Compare overlap between reduced and non-reduced results
redis_set = set(redis_results.keys())
linear_set = set(linear_items.keys())
overlap = len(redis_set & linear_set)
overlap_ratio = overlap / k
# With random projection, we expect some loss of accuracy but should
# maintain at least some similarity structure.
# Note that the gaussian distribution is the worst case for this test, so
# in real world practice, things will be better.
min_expected_overlap = 0.1 # At least 10% overlap in top-k
# Print a comparison for debugging before asserting, so the details
# are actually visible on failure.
if overlap_ratio < min_expected_overlap:
print("\nLow overlap in results. Details:")
print("\nTop results from linear scan (original vectors):")
for name, score in linear_results:
print(f"{name}: {score:.3f}")
print("\nTop results from Redis (reduced vectors):")
for item, score in sorted(redis_results.items(), key=lambda x: x[1], reverse=True):
print(f"{item}: {score:.3f}")
assert overlap_ratio >= min_expected_overlap, \
f"Dimension reduction lost too much structure. Only {overlap_ratio*100:.1f}% overlap in top {k}"
# For items that appear in both results, scores should be reasonably correlated
common_items = redis_set & linear_set
for item in common_items:
redis_score = redis_results[item]
linear_score = linear_items[item]
# Allow for some deviation due to dimensionality reduction
assert abs(redis_score - linear_score) < 0.2, \
f"Score mismatch too high for {item}: Redis={redis_score:.3f} Linear={linear_score:.3f}"

98
tests/vadd_cas.py Normal file

@@ -0,0 +1,98 @@
from test import TestCase, generate_random_vector
import threading
import struct
import math
import time
import random
from typing import List, Dict
class ConcurrentCASTest(TestCase):
def getname(self):
return "Concurrent VADD with CAS"
def estimated_runtime(self):
return 1.5
def worker(self, vectors: List[List[float]], start_idx: int, end_idx: int,
dim: int, results: Dict[str, bool]):
"""Worker thread that adds a subset of vectors using VADD CAS"""
for i in range(start_idx, end_idx):
vec = vectors[i]
name = f"{self.test_key}:item:{i}"
vec_bytes = struct.pack(f'{dim}f', *vec)
# Try to add the vector with CAS
try:
result = self.redis.execute_command('VADD', self.test_key, 'FP32',
vec_bytes, name, 'CAS')
results[name] = (result == 1) # Store if it was actually added
except Exception as e:
results[name] = False
print(f"Error adding {name}: {e}")
def verify_vector_similarity(self, vec1: List[float], vec2: List[float]) -> float:
"""Calculate cosine similarity between two vectors"""
dot_product = sum(a*b for a,b in zip(vec1, vec2))
norm1 = math.sqrt(sum(x*x for x in vec1))
norm2 = math.sqrt(sum(x*x for x in vec2))
return dot_product / (norm1 * norm2) if norm1 > 0 and norm2 > 0 else 0
def test(self):
# Test parameters
dim = 128
total_vectors = 5000
num_threads = 8
vectors_per_thread = total_vectors // num_threads
# Generate all vectors upfront
random.seed(42) # For reproducibility
vectors = [generate_random_vector(dim) for _ in range(total_vectors)]
# Prepare threads and results dictionary
threads = []
results = {} # Will store success/failure for each vector
# Launch threads
for i in range(num_threads):
start_idx = i * vectors_per_thread
end_idx = start_idx + vectors_per_thread if i < num_threads-1 else total_vectors
thread = threading.Thread(target=self.worker,
args=(vectors, start_idx, end_idx, dim, results))
threads.append(thread)
thread.start()
# Wait for all threads to complete
for thread in threads:
thread.join()
# Verify cardinality
card = self.redis.execute_command('VCARD', self.test_key)
assert card == total_vectors, \
f"Expected {total_vectors} elements, but found {card}"
# Verify each vector
num_verified = 0
for i in range(total_vectors):
name = f"{self.test_key}:item:{i}"
# Verify the item was successfully added
assert results[name], f"Vector {name} was not successfully added"
# Get the stored vector
stored_vec_raw = self.redis.execute_command('VEMB', self.test_key, name)
stored_vec = [float(x) for x in stored_vec_raw]
# Verify vector dimensions
assert len(stored_vec) == dim, \
f"Stored vector dimension mismatch for {name}: {len(stored_vec)} != {dim}"
# Calculate similarity with original vector
similarity = self.verify_vector_similarity(vectors[i], stored_vec)
assert similarity > 0.99, \
f"Low similarity ({similarity}) for {name}"
num_verified += 1
# Final verification
assert num_verified == total_vectors, \
f"Only verified {num_verified} out of {total_vectors} vectors"

41
tests/vemb.py Normal file

@@ -0,0 +1,41 @@
from test import TestCase
import struct
import math
class VEMB(TestCase):
def getname(self):
return "VEMB Command"
def test(self):
dim = 4
# Add same vector in both formats
vec = [1, 0, 0, 0]
norm = math.sqrt(sum(x*x for x in vec))
vec = [x/norm for x in vec] # Normalize the vector
# Add using FP32
vec_bytes = struct.pack(f'{dim}f', *vec)
self.redis.execute_command('VADD', self.test_key, 'FP32', vec_bytes, f'{self.test_key}:item:1')
# Add using VALUES
self.redis.execute_command('VADD', self.test_key, 'VALUES', dim,
*[str(x) for x in vec], f'{self.test_key}:item:2')
# Get both back with VEMB
result1 = self.redis.execute_command('VEMB', self.test_key, f'{self.test_key}:item:1')
result2 = self.redis.execute_command('VEMB', self.test_key, f'{self.test_key}:item:2')
retrieved_vec1 = [float(x) for x in result1]
retrieved_vec2 = [float(x) for x in result2]
# Compare both vectors with original (allow for small quantization errors)
for i in range(dim):
assert abs(vec[i] - retrieved_vec1[i]) < 0.01, \
f"FP32 vector component {i} mismatch: expected {vec[i]}, got {retrieved_vec1[i]}"
assert abs(vec[i] - retrieved_vec2[i]) < 0.01, \
f"VALUES vector component {i} mismatch: expected {vec[i]}, got {retrieved_vec2[i]}"
# Test non-existent item
result = self.redis.execute_command('VEMB', self.test_key, 'nonexistent')
assert result is None, "Non-existent item should return nil"

1208
vset.c Normal file

File diff suppressed because it is too large

315
w2v.c Normal file

@@ -0,0 +1,315 @@
/*
* w2v: example and benchmark program for the HNSW (Hierarchical Navigable
* Small World) implementation, based on the paper by Yu. A. Malkov,
* D. A. Yashunin
*
* Copyright(C) 2024 Salvatore Sanfilippo. All Rights Reserved.
*/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include <time.h>
#include <stdint.h>
#include <pthread.h>
#include <stdatomic.h>
#include "hnsw.h"
/* Get current time in milliseconds */
uint64_t ms_time(void) {
struct timeval tv;
gettimeofday(&tv, NULL);
return (uint64_t)tv.tv_sec * 1000 + (tv.tv_usec / 1000);
}
/* Example usage in main() */
int w2v_single_thread(int quantization, uint64_t numele, int massdel, int recall) {
/* Create index */
HNSW *index = hnsw_new(300, quantization);
float v[300];
uint16_t wlen;
FILE *fp = fopen("word2vec.bin","rb");
if (fp == NULL) {
perror("word2vec.bin file missing");
exit(1);
}
unsigned char header[8];
fread(header,8,1,fp); // Skip header
uint64_t id = 0;
uint64_t start_time = ms_time();
char *word = NULL;
hnswNode *search_node = NULL;
while(id < numele) {
if (fread(&wlen,2,1,fp) == 0) break;
word = malloc(wlen+1);
fread(word,wlen,1,fp);
word[wlen] = 0;
fread(v,300*sizeof(float),1,fp);
// Plain API that acquires a write lock for the whole time.
hnswNode *added = hnsw_insert(index, v, NULL, 0, id++, word, 200);
if (!strcmp(word,"banana")) search_node = added;
if (!(id % 10000)) printf("%llu added\n", (unsigned long long)id);
}
uint64_t elapsed = ms_time() - start_time;
fclose(fp);
printf("%llu words added (%llu words/sec), last word: %s\n",
(unsigned long long)index->node_count,
(unsigned long long)id*1000/elapsed, word);
/* Search query */
if (search_node == NULL) search_node = index->head;
hnsw_get_node_vector(index,search_node,v);
hnswNode *neighbors[10];
float distances[10];
int found, j;
start_time = ms_time();
for (j = 0; j < 20000; j++)
found = hnsw_search(index, v, 10, neighbors, distances, 0, 0);
elapsed = ms_time() - start_time;
printf("%d searches performed (%llu searches/sec), nodes found: %d\n",
j, (unsigned long long)j*1000/elapsed, found);
if (found > 0) {
printf("Found %d neighbors:\n", found);
for (int i = 0; i < found; i++) {
printf("Node ID: %llu, distance: %f, word: %s\n",
(unsigned long long)neighbors[i]->id,
distances[i], (char*)neighbors[i]->value);
}
}
// Recall test (slow).
if (recall) {
hnsw_print_stats(index);
hnsw_test_graph_recall(index,200,0);
}
uint64_t connected_nodes;
int reciprocal_links;
hnsw_validate_graph(index, &connected_nodes, &reciprocal_links);
if (massdel) {
int remove_perc = 95;
printf("\nRemoving %d%% of nodes...\n", remove_perc);
uint64_t initial_nodes = index->node_count;
hnswNode *current = index->head;
while (current && index->node_count > initial_nodes*(100-remove_perc)/100) {
hnswNode *next = current->next;
hnsw_delete_node(index,current,free);
current = next;
// In order to avoid removing only contiguous nodes, from time
// to time skip a node.
if (current && !(random() % remove_perc)) current = current->next;
}
printf("%llu nodes left\n", (unsigned long long)index->node_count);
// Test again.
hnsw_validate_graph(index, &connected_nodes, &reciprocal_links);
hnsw_test_graph_recall(index,200,0);
}
hnsw_free(index,free);
return 0;
}
struct threadContext {
pthread_mutex_t FileAccessMutex;
uint64_t numele;
_Atomic uint64_t SearchesDone;
_Atomic uint64_t id;
FILE *fp;
HNSW *index;
float *search_vector;
};
// Note that in practical terms inserting with many concurrent threads
// may be *slower* and not faster, because there is a lot of
// contention. So this is more a robustness test than anything else.
//
// The optimistic commit API goal is actually to exploit the ability to
// add faster when there are many concurrent reads.
void *threaded_insert(void *ctxptr) {
struct threadContext *ctx = ctxptr;
char *word;
float v[300];
uint16_t wlen;
while(1) {
pthread_mutex_lock(&ctx->FileAccessMutex);
if (fread(&wlen,2,1,ctx->fp) == 0) {
pthread_mutex_unlock(&ctx->FileAccessMutex);
break;
}
word = malloc(wlen+1);
fread(word,wlen,1,ctx->fp);
word[wlen] = 0;
fread(v,300*sizeof(float),1,ctx->fp);
// Release the lock only after the whole record (length, word and
// vector) has been read, so concurrent threads cannot interleave reads.
pthread_mutex_unlock(&ctx->FileAccessMutex);
// Check-and-set API that performs the costly scan for similar
// nodes concurrently with other read threads, and finally
// applies the check if the graph wasn't modified.
InsertContext *ic;
uint64_t next_id = ctx->id++;
ic = hnsw_prepare_insert(ctx->index, v, NULL, 0, next_id, word, 200);
if (hnsw_try_commit_insert(ctx->index, ic) == NULL) {
// This time try locking since the start.
hnsw_insert(ctx->index, v, NULL, 0, next_id, word, 200);
}
if (next_id >= ctx->numele) break;
if (!((next_id+1) % 10000))
printf("%llu added\n", (unsigned long long)next_id+1);
}
return NULL;
}
void *threaded_search(void *ctxptr) {
struct threadContext *ctx = ctxptr;
/* Search query */
hnswNode *neighbors[10];
float distances[10];
int found = 0;
uint64_t last_id = 0;
while(ctx->id < 1000000) {
int slot = hnsw_acquire_read_slot(ctx->index);
found = hnsw_search(ctx->index, ctx->search_vector, 10, neighbors, distances, slot, 0);
hnsw_release_read_slot(ctx->index,slot);
last_id = ++ctx->id;
}
if (found > 0 && last_id == 1000000) {
printf("Found %d neighbors:\n", found);
for (int i = 0; i < found; i++) {
printf("Node ID: %llu, distance: %f, word: %s\n",
(unsigned long long)neighbors[i]->id,
distances[i], (char*)neighbors[i]->value);
}
}
return NULL;
}
int w2v_multi_thread(int numthreads, int quantization, uint64_t numele) {
/* Create index */
struct threadContext ctx;
ctx.index = hnsw_new(300,quantization);
ctx.fp = fopen("word2vec.bin","rb");
if (ctx.fp == NULL) {
perror("word2vec.bin file missing");
exit(1);
}
unsigned char header[8];
fread(header,8,1,ctx.fp); // Skip header
pthread_mutex_init(&ctx.FileAccessMutex,NULL);
uint64_t start_time = ms_time();
ctx.id = 0;
ctx.numele = numele;
pthread_t threads[numthreads];
for (int j = 0; j < numthreads; j++)
pthread_create(&threads[j], NULL, threaded_insert, &ctx);
// Wait for all the threads to terminate adding items.
for (int j = 0; j < numthreads; j++)
pthread_join(threads[j],NULL);
uint64_t elapsed = ms_time() - start_time;
fclose(ctx.fp);
// Obtain the last word.
hnswNode *node = ctx.index->head;
char *word = node->value;
// We will search this last inserted word in the next test.
// Let's save its embedding.
ctx.search_vector = malloc(sizeof(float)*300);
hnsw_get_node_vector(ctx.index,node,ctx.search_vector);
printf("%llu words added (%llu words/sec), last word: %s\n",
(unsigned long long)ctx.index->node_count,
(unsigned long long)ctx.id*1000/elapsed, word);
/* Search query */
start_time = ms_time();
ctx.id = 0; // We will use this atomic field to stop at N queries done.
for (int j = 0; j < numthreads; j++)
pthread_create(&threads[j], NULL, threaded_search, &ctx);
// Wait for all the threads to terminate searching.
for (int j = 0; j < numthreads; j++)
pthread_join(threads[j],NULL);
elapsed = ms_time() - start_time;
printf("%llu searches performed (%llu searches/sec)\n",
(unsigned long long)ctx.id,
(unsigned long long)ctx.id*1000/elapsed);
hnsw_print_stats(ctx.index);
uint64_t connected_nodes;
int reciprocal_links;
hnsw_validate_graph(ctx.index, &connected_nodes, &reciprocal_links);
printf("%llu connected nodes. Links all reciprocal: %d\n",
(unsigned long long)connected_nodes, reciprocal_links);
hnsw_free(ctx.index,free);
return 0;
}
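/* Example invocations (assuming word2vec.bin is present in the working
* directory; see the option parsing in main() below):
*
*   ./w2v --numele 100000
*   ./w2v --threads 4 --quant
*/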
int main(int argc, char **argv) {
int quantization = HNSW_QUANT_NONE;
int numthreads = 0;
uint64_t numele = 20000;
/* These options can be enabled in single thread mode for testing: */
int massdel = 0; // If true, does the mass deletion test.
int recall = 0; // If true, does the recall test.
for (int j = 1; j < argc; j++) {
int moreargs = argc-j-1;
if (!strcasecmp(argv[j],"--quant")) {
quantization = HNSW_QUANT_Q8;
} else if (!strcasecmp(argv[j],"--bin")) {
quantization = HNSW_QUANT_BIN;
} else if (!strcasecmp(argv[j],"--mass-del")) {
massdel = 1;
} else if (!strcasecmp(argv[j],"--recall")) {
recall = 1;
} else if (moreargs >= 1 && !strcasecmp(argv[j],"--threads")) {
numthreads = atoi(argv[j+1]);
j++;
} else if (moreargs >= 1 && !strcasecmp(argv[j],"--numele")) {
numele = strtoll(argv[j+1],NULL,0);
j++;
if (numele < 1) numele = 1;
} else if (!strcasecmp(argv[j],"--help")) {
printf("%s [--quant] [--bin] [--thread <count>] [--numele <count>] [--mass-del] [--recall]\n", argv[0]);
exit(0);
} else {
printf("Unrecognized option: %s\n", argv[j]);
exit(1);
}
}
if (quantization == HNSW_QUANT_NONE) {
printf("You can enable quantization with --quant\n");
}
if (numthreads > 0) {
w2v_multi_thread(numthreads,quantization,numele);
} else {
printf("Single thread execution. Use --threads 4 for concurrent API\n");
w2v_single_thread(quantization,numele,massdel,recall);
}
}