redis/modules/vector-sets/w2v.c

540 lines
18 KiB
C

/*
* HNSW (Hierarchical Navigable Small World) Implementation
* Based on the paper by Yu. A. Malkov, D. A. Yashunin
*
* Copyright (c) 2009-Present, Redis Ltd.
* All rights reserved.
*
* Licensed under your choice of (a) the Redis Source Available License 2.0
* (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
* GNU Affero General Public License v3 (AGPLv3).
* Originally authored by: Salvatore Sanfilippo
*/
#define _DEFAULT_SOURCE
#define _USE_MATH_DEFINES
#define _POSIX_C_SOURCE 200809L
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <sys/time.h>
#include <time.h>
#include <stdint.h>
#include <pthread.h>
#include <stdatomic.h>
#include <math.h>
#include "hnsw.h"
/* Return the current time of day, as reported by gettimeofday(),
 * converted to milliseconds. */
uint64_t ms_time(void) {
    struct timeval now;
    gettimeofday(&now, NULL);

    /* Whole seconds scaled to milliseconds, plus the sub-second part. */
    uint64_t seconds_as_ms = (uint64_t)now.tv_sec * 1000;
    return seconds_as_ms + (now.tv_usec / 1000);
}
/* Recall test with random query vectors.
 *
 * Up to 1000 vectors are copied from the first nodes of the index node
 * list. Each of the 10000 queries is a random weighted mix (weights are
 * non-negative and sum to 1) of three of those vectors. For every query
 * the output of hnsw_search() is compared against the output of
 * hnsw_ground_truth_with_filter(), both truncated to the first k=100
 * entries: the per-query recall is the fraction of ground-truth nodes that
 * also appear in the HNSW result. The running average, the final average
 * and a histogram in 2% bins are printed to stdout.
 *
 * 'ef' is passed as the requested result count to both searches and is
 * raised to k when smaller.
 *
 * Nothing is returned. If an allocation fails, or the index has no nodes,
 * a message is printed and the test is skipped. */
void test_recall(HNSW *index, int ef) {
    enum {
        VEC_DIM = 300,   /* Vector dimension used throughout this program. */
        MIX_COUNT = 3,   /* Source vectors blended into each query. */
        NUM_BINS = 50    /* Histogram bins, 2% of recall each. */
    };
    const int num_test_vectors = 10000;
    const int max_sources = 1000; // Enough, since we mix them.
    const int k = 100; // Number of nearest neighbors to compare.
    if (ef < k) ef = k;

    // Recall distribution counters (2% bins from 0-100%).
    int recall_bins[NUM_BINS] = {0};
    double total_recall = 0.0;

    /* Every resource is declared up front and released in a single place
     * (the 'cleanup' label), so no error path can miss one of them. */
    int num_source_vectors = 0; // Source vectors actually populated.
    float **source_vectors = NULL;
    float *test_vector = NULL;
    hnswNode **hnsw_results = NULL;
    hnswNode **linear_results = NULL;
    float *hnsw_distances = NULL;
    float *linear_distances = NULL;

    source_vectors = malloc(sizeof(*source_vectors) * max_sources);
    if (!source_vectors) {
        printf("Failed to allocate memory for source vectors\n");
        goto cleanup;
    }
    /* Start from all-NULL slots: cleanup can then free the whole array
     * regardless of how many buffers were really allocated. */
    for (int i = 0; i < max_sources; i++) source_vectors[i] = NULL;

    for (int i = 0; i < max_sources; i++) {
        source_vectors[i] = malloc(sizeof(float) * VEC_DIM);
        if (!source_vectors[i]) {
            printf("Failed to allocate memory for source vector %d\n", i);
            goto cleanup;
        }
    }

    /* Populate source vectors from the index, we just scan the
     * first N items. */
    hnswNode *current = index->head;
    while (current && num_source_vectors < max_sources) {
        hnsw_get_node_vector(index, current,
                             source_vectors[num_source_vectors]);
        num_source_vectors++;
        current = current->next;
    }
    if (num_source_vectors < max_sources) {
        printf("Warning: Only found %d nodes for source vectors\n",
               num_source_vectors);
    }
    if (num_source_vectors == 0) {
        /* Without at least one vector there is nothing to mix (and the
         * modulo used to pick a source would divide by zero). */
        printf("Index is empty, skipping recall test\n");
        goto cleanup;
    }

    test_vector = malloc(sizeof(float) * VEC_DIM);
    if (!test_vector) {
        printf("Failed to allocate memory for test vector\n");
        goto cleanup;
    }

    hnsw_results = malloc(sizeof(*hnsw_results) * (size_t)ef);
    linear_results = malloc(sizeof(*linear_results) * (size_t)ef);
    hnsw_distances = malloc(sizeof(*hnsw_distances) * (size_t)ef);
    linear_distances = malloc(sizeof(*linear_distances) * (size_t)ef);
    if (!hnsw_results || !linear_results ||
        !hnsw_distances || !linear_distances)
    {
        printf("Failed to allocate memory for results\n");
        goto cleanup;
    }

    // Initialize random seed.
    srand((unsigned)time(NULL));

    printf("\nPerforming recall test with EF=%d on %d random vectors...\n",
           ef, num_test_vectors);

    for (int t = 0; t < num_test_vectors; t++) {
        // Create a random vector by mixing MIX_COUNT existing vectors.
        float weights[MIX_COUNT] = {0.0f};
        int src_indices[MIX_COUNT] = {0};
        float weight_sum = 0.0f;
        for (int i = 0; i < MIX_COUNT; i++) {
            weights[i] = (float)rand() / RAND_MAX;
            weight_sum += weights[i];
            src_indices[i] = rand() % num_source_vectors;
        }

        /* Normalize weights. If every draw was zero fall back to an even
         * mix instead of dividing by zero. */
        for (int i = 0; i < MIX_COUNT; i++) {
            weights[i] = (weight_sum > 0.0f) ? weights[i] / weight_sum
                                             : 1.0f / MIX_COUNT;
        }

        // Mix vectors.
        memset(test_vector, 0, sizeof(float) * VEC_DIM);
        for (int i = 0; i < MIX_COUNT; i++) {
            const float *src = source_vectors[src_indices[i]];
            for (int j = 0; j < VEC_DIM; j++)
                test_vector[j] += weights[i] * src[j];
        }

        /* Run the HNSW search and the linear scan (ground truth) inside
         * the same read slot. */
        int slot = hnsw_acquire_read_slot(index);
        int hnsw_found = hnsw_search(index, test_vector, ef,
                                     hnsw_results, hnsw_distances, slot, 0);
        int linear_found = hnsw_ground_truth_with_filter(index, test_vector,
                                     ef, linear_results, linear_distances,
                                     slot, 0, NULL, NULL);
        hnsw_release_read_slot(index, slot);

        /* Recall for this query: intersection size divided by the number
         * of ground-truth results, both lists capped at k. */
        if (hnsw_found > k) hnsw_found = k;
        if (linear_found > k) linear_found = k;
        int intersection_count = 0;
        for (int i = 0; i < linear_found; i++) {
            for (int j = 0; j < hnsw_found; j++) {
                if (linear_results[i] == hnsw_results[j]) {
                    intersection_count++;
                    break;
                }
            }
        }
        /* No ground truth at all counts as zero recall rather than
         * producing a NaN from 0/0. */
        double recall = (linear_found > 0)
                        ? (double)intersection_count / linear_found
                        : 0.0;
        total_recall += recall;

        // Add to distribution bins (2% steps).
        int bin_index = (int)(recall * NUM_BINS);
        if (bin_index >= NUM_BINS) bin_index = NUM_BINS - 1; // 100% recall.
        recall_bins[bin_index]++;

        // Show progress.
        if ((t+1) % 1000 == 0 || t == num_test_vectors-1) {
            printf("Processed %d/%d queries, current avg recall: %.2f%%\n",
                   t+1, num_test_vectors, (total_recall / (t+1)) * 100);
        }
    }

    // Calculate and print final average recall.
    double avg_recall = (total_recall / num_test_vectors) * 100;
    printf("\nRecall Test Results:\n");
    printf("Average recall@%d (EF=%d): %.2f%%\n", k, ef, avg_recall);

    // Print recall distribution histogram.
    printf("\nRecall Distribution (2%% bins):\n");
    printf("================================\n");

    // Find the maximum bin count for scaling.
    int max_count = 0;
    for (int i = 0; i < NUM_BINS; i++) {
        if (recall_bins[i] > max_count) max_count = recall_bins[i];
    }

    // Scale factor for histogram (max 50 chars wide).
    const int max_bars = 50;
    double scale = (max_count > max_bars) ? (double)max_bars / max_count : 1.0;

    for (int i = 0; i < NUM_BINS; i++) {
        int bar_len = (int)(recall_bins[i] * scale);
        printf("%3d%%-%-3d%% | %-6d |", i*2, (i+1)*2, recall_bins[i]);
        for (int j = 0; j < bar_len; j++) printf("#");
        printf("\n");
    }

cleanup:
    free(hnsw_results);
    free(linear_results);
    free(hnsw_distances);
    free(linear_distances);
    free(test_vector);
    if (source_vectors) {
        /* Free every allocated buffer, not just the populated ones. */
        for (int i = 0; i < max_sources; i++) free(source_vectors[i]);
        free(source_vectors);
    }
}
/* Single-threaded test driver.
 *
 * Loads up to 'numele' (word, 300-float vector) records from the file
 * "word2vec.bin" in the current directory into a new HNSW index, using the
 * plain hnsw_insert() API, then:
 *
 * - runs 20000 searches for the vector of the word "banana" (or of the
 *   head node when that word was not loaded) and prints the neighbors;
 * - if 'self_recall' is true, prints the index stats and runs
 *   hnsw_test_graph_recall();
 * - if 'recall_ef' is greater than zero, runs test_recall() with it;
 * - validates the graph;
 * - if 'massdel' is true, deletes about 95% of the nodes, then validates
 *   and tests the graph again.
 *
 * 'm_param' and 'quantization' are forwarded to hnsw_new().
 *
 * Returns 0 on success, 1 if no record could be loaded. A missing or
 * truncated input file, or an allocation failure, terminates the process
 * with exit(1). */
int w2v_single_thread(int m_param, int quantization, uint64_t numele, int massdel, int self_recall, int recall_ef) {
    enum { VEC_DIM = 300 };

    /* Create index */
    HNSW *index = hnsw_new(VEC_DIM, quantization, m_param);
    if (index == NULL) {
        fprintf(stderr, "Failed to create the HNSW index\n");
        exit(1);
    }

    FILE *fp = fopen("word2vec.bin","rb");
    if (fp == NULL) {
        perror("word2vec.bin file missing");
        exit(1);
    }

    unsigned char header[8];
    if (fread(header, sizeof(header), 1, fp) != 1) { // Skip header
        perror("Unexpected EOF");
        exit(1);
    }

    float v[VEC_DIM];
    uint64_t id = 0;
    char *word = NULL;
    hnswNode *search_node = NULL;
    uint64_t start_time = ms_time();

    while (id < numele) {
        /* Each record is: 16 bit word length, the word bytes, the vector.
         * Failing to read the length is the regular end of the input. */
        uint16_t wlen;
        if (fread(&wlen, sizeof(wlen), 1, fp) != 1) break;

        word = malloc((size_t)wlen + 1);
        if (word == NULL) {
            perror("Out of memory");
            exit(1);
        }
        /* fread() returns 0 for a zero-sized item, so only treat a short
         * read as an error when there are bytes to read. */
        if (wlen > 0 && fread(word, wlen, 1, fp) != 1) {
            perror("unexpected EOF");
            exit(1);
        }
        word[wlen] = 0;
        if (fread(v, sizeof(v), 1, fp) != 1) {
            perror("unexpected EOF");
            exit(1);
        }

        // Plain API that acquires a write lock for the whole time.
        // The word is stored as the node value; it is released with
        // free() by the hnsw_delete_node() / hnsw_free() calls below.
        hnswNode *added = hnsw_insert(index, v, NULL, 0, id++, word, 200);
        if (!strcmp(word,"banana")) search_node = added;
        if (!(id % 10000)) printf("%llu added\n", (unsigned long long)id);
    }

    uint64_t elapsed = ms_time() - start_time;
    if (elapsed == 0) elapsed = 1; // Avoid a division by zero on tiny runs.
    fclose(fp);
    printf("%llu words added (%llu words/sec), last word: %s\n",
           (unsigned long long)index->node_count,
           (unsigned long long)id*1000/elapsed,
           word ? word : "(none)");

    if (index->head == NULL) {
        /* Nothing was loaded: there is no vector to query with. */
        printf("No words loaded, nothing to search\n");
        hnsw_free(index,free);
        return 1;
    }

    /* Search query */
    if (search_node == NULL) search_node = index->head;
    hnsw_get_node_vector(index,search_node,v);

    hnswNode *neighbors[10];
    float distances[10];
    int found = 0;
    int j;
    start_time = ms_time();
    for (j = 0; j < 20000; j++)
        found = hnsw_search(index, v, 10, neighbors, distances, 0, 0);
    elapsed = ms_time() - start_time;
    if (elapsed == 0) elapsed = 1;
    printf("%d searches performed (%llu searches/sec), nodes found: %d\n",
           j, (unsigned long long)j*1000/elapsed, found);

    if (found > 0) {
        printf("Found %d neighbors:\n", found);
        for (int i = 0; i < found; i++) {
            printf("Node ID: %llu, distance: %f, word: %s\n",
                   (unsigned long long)neighbors[i]->id,
                   distances[i], (char*)neighbors[i]->value);
        }
    }

    // Self-recall test (ability to find the node by its own vector).
    if (self_recall) {
        hnsw_print_stats(index);
        hnsw_test_graph_recall(index,200,0);
    }

    // Recall test with random vectors.
    if (recall_ef > 0) {
        test_recall(index, recall_ef);
    }

    uint64_t connected_nodes;
    int reciprocal_links;
    hnsw_validate_graph(index, &connected_nodes, &reciprocal_links);

    if (massdel) {
        int remove_perc = 95;
        printf("\nRemoving %d%% of nodes...\n", remove_perc);
        uint64_t initial_nodes = index->node_count;

        hnswNode *current = index->head;
        while (current && index->node_count > initial_nodes*(100-remove_perc)/100) {
            /* Save the successor first: 'current' is gone after the
             * delete call. */
            hnswNode *next = current->next;
            hnsw_delete_node(index,current,free);
            current = next;
            // To avoid removing only contiguous nodes, skip a node from
            // time to time.
            if (current && !(random() % remove_perc)) current = current->next;
        }
        printf("%llu nodes left\n", (unsigned long long)index->node_count);

        // Test again.
        hnsw_validate_graph(index, &connected_nodes, &reciprocal_links);
        hnsw_test_graph_recall(index,200,0);
    }

    hnsw_free(index,free);
    return 0;
}
/* State shared by the worker threads of the multi-threaded test: a single
 * instance is created by w2v_multi_thread() and its address is passed to
 * every threaded_insert() and threaded_search() thread. */
struct threadContext {
/* Taken by the insert threads before reading from 'fp'. */
pthread_mutex_t FileAccessMutex;
/* Number of elements to load; the insert threads compare the running id
 * against it to decide when to stop. */
uint64_t numele;
/* NOTE(review): not referenced by any function visible in this file. */
_Atomic uint64_t SearchesDone;
/* Shared atomic counter. During the insert phase it provides the next
 * node id; w2v_multi_thread() then resets it to zero and the search
 * threads use it to count the queries performed. */
_Atomic uint64_t id;
/* Input file the insert threads read the records from. */
FILE *fp;
/* The index all the threads operate on. */
HNSW *index;
/* Query vector used by the search threads (set by w2v_multi_thread()). */
float *search_vector;
};
// Insert-thread entry point for the multi-threaded test. 'ctxptr' points
// to the shared struct threadContext. Always returns NULL.
//
// Note that in practical terms inserting with many concurrent threads
// may be *slower* and not faster, because there is a lot of
// contention. So this is more a robustness test than anything else.
//
// The optimistic commit API goal is actually to exploit the ability to
// add faster when there are many concurrent reads.
//
// Each iteration reads one whole record (16 bit word length, word bytes,
// vector) and claims the next id while holding FileAccessMutex, so that
// records read by different threads never interleave and at most
// ctx->numele records are inserted in total. The insertion itself is
// performed outside the lock. The thread stops at the end of the input or
// once ctx->numele ids were handed out. A truncated record or an
// allocation failure terminates the process with exit(1).
void *threaded_insert(void *ctxptr) {
    enum { VEC_DIM = 300 };
    struct threadContext *ctx = ctxptr;
    float v[VEC_DIM];

    while(1) {
        uint16_t wlen;
        char *word;
        uint64_t next_id;

        pthread_mutex_lock(&ctx->FileAccessMutex);

        /* Stop when enough items were claimed or at end of input. The
         * mutex must be released on this path too, otherwise the other
         * insert threads would block forever on it. */
        if (ctx->id >= ctx->numele ||
            fread(&wlen, sizeof(wlen), 1, ctx->fp) != 1)
        {
            pthread_mutex_unlock(&ctx->FileAccessMutex);
            break;
        }

        word = malloc((size_t)wlen + 1);
        if (word == NULL) {
            perror("Out of memory");
            exit(1);
        }
        /* fread() returns 0 for a zero-sized item, so only treat a short
         * read as an error when there are bytes to read. */
        if (wlen > 0 && fread(word, wlen, 1, ctx->fp) != 1) {
            perror("Unexpected EOF");
            exit(1);
        }
        word[wlen] = 0;
        if (fread(v, sizeof(v), 1, ctx->fp) != 1) {
            perror("Unexpected EOF");
            exit(1);
        }
        next_id = ctx->id++;

        pthread_mutex_unlock(&ctx->FileAccessMutex);

        // Check-and-set API that performs the costly scan for similar
        // nodes concurrently with other read threads, and finally
        // applies the check if the graph wasn't modified.
        InsertContext *ic;
        ic = hnsw_prepare_insert(ctx->index, v, NULL, 0, next_id, 200);
        if (hnsw_try_commit_insert(ctx->index, ic, word) == NULL) {
            // This time try locking since the start.
            hnsw_insert(ctx->index, v, NULL, 0, next_id, word, 200);
        }

        if (!((next_id+1) % 10000))
            printf("%llu added\n", (unsigned long long)next_id+1);
    }
    return NULL;
}
/* Search-thread entry point for the multi-threaded test. 'ctxptr' points
 * to the shared struct threadContext.
 *
 * The thread keeps searching for the 10 nearest neighbors of
 * ctx->search_vector, bumping the shared atomic counter ctx->id after
 * every query, until the counter reaches one million. The thread whose
 * increment produced exactly that value prints the neighbors found by its
 * last query. Always returns NULL. */
void *threaded_search(void *ctxptr) {
    struct threadContext *shared = ctxptr;
    const uint64_t query_budget = 1000000;

    hnswNode *hits[10];
    float hit_distances[10];
    int nhits = 0;
    uint64_t my_last_count = 0; /* Counter value after our latest query. */

    for (;;) {
        if (shared->id >= query_budget) break;

        int slot = hnsw_acquire_read_slot(shared->index);
        nhits = hnsw_search(shared->index, shared->search_vector, 10,
                            hits, hit_distances, slot, 0);
        hnsw_release_read_slot(shared->index,slot);

        my_last_count = ++shared->id;
    }

    /* Only the thread that completed the final query reports. */
    if (nhits <= 0 || my_last_count != query_budget) return NULL;

    printf("Found %d neighbors:\n", nhits);
    for (int n = 0; n < nhits; n++) {
        hnswNode *hit = hits[n];
        printf("Node ID: %llu, distance: %f, word: %s\n",
               (unsigned long long)hit->id,
               hit_distances[n], (char*)hit->value);
    }
    return NULL;
}
/* Multi-threaded test driver.
 *
 * Creates a new HNSW index and starts 'numthreads' threaded_insert()
 * threads that load records from the file "word2vec.bin" in the current
 * directory, with 'numele' as the element limit. Once they are all done,
 * the vector of the head node of the index is used as the query for
 * 'numthreads' threaded_search() threads. Finally the index stats and the
 * graph validation result are printed.
 *
 * 'm_param' and 'quantization' are forwarded to hnsw_new(). A 'numthreads'
 * value smaller than 1 is treated as 1.
 *
 * Returns 0 on success, 1 if no record could be loaded. A missing or
 * truncated input file, an allocation failure or a thread creation failure
 * terminates the process with exit(1). */
int w2v_multi_thread(int m_param, int numthreads, int quantization, uint64_t numele) {
    enum { VEC_DIM = 300 };
    if (numthreads < 1) numthreads = 1;

    /* Create index */
    struct threadContext ctx;
    ctx.index = hnsw_new(VEC_DIM, quantization, m_param);
    if (ctx.index == NULL) {
        fprintf(stderr, "Failed to create the HNSW index\n");
        exit(1);
    }

    ctx.fp = fopen("word2vec.bin","rb");
    if (ctx.fp == NULL) {
        perror("word2vec.bin file missing");
        exit(1);
    }

    unsigned char header[8];
    if (fread(header, sizeof(header), 1, ctx.fp) != 1) { // Skip header
        perror("Unexpected EOF");
        exit(1);
    }

    pthread_t *threads = malloc(sizeof(*threads) * (size_t)numthreads);
    if (threads == NULL) {
        perror("Out of memory");
        exit(1);
    }

    pthread_mutex_init(&ctx.FileAccessMutex,NULL);
    ctx.id = 0;
    ctx.SearchesDone = 0;
    ctx.numele = numele;
    ctx.search_vector = NULL;

    uint64_t start_time = ms_time();
    for (int j = 0; j < numthreads; j++) {
        /* pthread_create() reports failures with its return value. Joining
         * a thread that was never created is undefined, so give up. */
        int rc = pthread_create(&threads[j], NULL, threaded_insert, &ctx);
        if (rc != 0) {
            fprintf(stderr, "pthread_create: %s\n", strerror(rc));
            exit(1);
        }
    }

    // Wait for all the threads to terminate adding items.
    for (int j = 0; j < numthreads; j++)
        pthread_join(threads[j],NULL);

    uint64_t elapsed = ms_time() - start_time;
    if (elapsed == 0) elapsed = 1; // Avoid a division by zero on tiny runs.
    fclose(ctx.fp);

    // Use the head node of the index as the word to search for.
    hnswNode *node = ctx.index->head;
    if (node == NULL) {
        /* Nothing was loaded: there is no vector to query with. */
        printf("No words loaded, nothing to search\n");
        pthread_mutex_destroy(&ctx.FileAccessMutex);
        free(threads);
        hnsw_free(ctx.index,free);
        return 1;
    }
    char *word = node->value;

    // We will search this word in the next test.
    // Let's save its embedding.
    ctx.search_vector = malloc(sizeof(float)*VEC_DIM);
    if (ctx.search_vector == NULL) {
        perror("Out of memory");
        exit(1);
    }
    hnsw_get_node_vector(ctx.index,node,ctx.search_vector);

    printf("%llu words added (%llu words/sec), last word: %s\n",
           (unsigned long long)ctx.index->node_count,
           (unsigned long long)ctx.id*1000/elapsed, word);

    /* Search query */
    start_time = ms_time();
    ctx.id = 0; // We will use this atomic field to stop at N queries done.
    for (int j = 0; j < numthreads; j++) {
        int rc = pthread_create(&threads[j], NULL, threaded_search, &ctx);
        if (rc != 0) {
            fprintf(stderr, "pthread_create: %s\n", strerror(rc));
            exit(1);
        }
    }

    // Wait for all the threads to terminate searching.
    for (int j = 0; j < numthreads; j++)
        pthread_join(threads[j],NULL);

    elapsed = ms_time() - start_time;
    if (elapsed == 0) elapsed = 1;
    printf("%llu searches performed (%llu searches/sec)\n",
           (unsigned long long)ctx.id,
           (unsigned long long)ctx.id*1000/elapsed);

    hnsw_print_stats(ctx.index);

    uint64_t connected_nodes;
    int reciprocal_links;
    hnsw_validate_graph(ctx.index, &connected_nodes, &reciprocal_links);
    printf("%llu connected nodes. Links all reciprocal: %d\n",
           (unsigned long long)connected_nodes, reciprocal_links);

    // Release everything this function created.
    free(ctx.search_vector);
    free(threads);
    pthread_mutex_destroy(&ctx.FileAccessMutex);
    hnsw_free(ctx.index,free);
    return 0;
}
/* Program entry point: parses the command line and runs either the
 * multi-threaded test (--threads with a count greater than zero) or the
 * single-threaded one.
 *
 * Options:
 *   --quant            Use HNSW_QUANT_Q8 quantization.
 *   --bin              Use HNSW_QUANT_BIN quantization.
 *   --threads <count>  Number of threads for the concurrent test.
 *   --numele <count>   Number of elements to load (default 20000, values
 *                      smaller than 1 are turned into 1).
 *   --m <count>        M parameter passed to the index (0 = default).
 *   --mass-del         Single thread only: mass deletion test.
 *   --self-recall      Single thread only: self-recall test.
 *   --recall <ef>      Single thread only: random vectors recall test.
 *   --help             Print the usage and exit.
 *
 * Returns the value returned by the selected test function. Unknown
 * options, or options missing their argument, terminate with exit(1). */
int main(int argc, char **argv) {
    int quantization = HNSW_QUANT_NONE;
    int numthreads = 0;
    uint64_t numele = 20000;
    int m_param = 0; // Default value (0 means use HNSW_DEFAULT_M)

    /* This you can enable in single thread mode for testing: */
    int massdel = 0; // If true, does the mass deletion test.
    int self_recall = 0; // If true, does the self-recall test.
    int recall_ef = 0; // If not 0, does the recall test with this EF value.

    for (int j = 1; j < argc; j++) {
        int moreargs = argc-j-1;

        if (!strcasecmp(argv[j],"--quant")) {
            quantization = HNSW_QUANT_Q8;
        } else if (!strcasecmp(argv[j],"--bin")) {
            quantization = HNSW_QUANT_BIN;
        } else if (!strcasecmp(argv[j],"--mass-del")) {
            massdel = 1;
        } else if (!strcasecmp(argv[j],"--self-recall")) {
            self_recall = 1;
        } else if (moreargs >= 1 && !strcasecmp(argv[j],"--recall")) {
            recall_ef = atoi(argv[j+1]);
            j++;
        } else if (moreargs >= 1 && !strcasecmp(argv[j],"--threads")) {
            numthreads = atoi(argv[j+1]);
            j++;
        } else if (moreargs >= 1 && !strcasecmp(argv[j],"--numele")) {
            /* Clamp while the value is still signed: once stored into the
             * unsigned 'numele' a negative count would look like a huge
             * positive one. */
            long long requested = strtoll(argv[j+1],NULL,0);
            numele = (requested < 1) ? 1 : (uint64_t)requested;
            j++;
        } else if (moreargs >= 1 && !strcasecmp(argv[j],"--m")) {
            m_param = atoi(argv[j+1]);
            j++;
        } else if (!strcasecmp(argv[j],"--help")) {
            printf("%s [--quant] [--bin] [--threads <count>] [--numele <count>] [--m <count>] [--mass-del] [--self-recall] [--recall <ef>]\n", argv[0]);
            exit(0);
        } else {
            printf("Unrecognized option or wrong number of arguments: %s\n", argv[j]);
            exit(1);
        }
    }

    if (quantization == HNSW_QUANT_NONE) {
        printf("You can enable quantization with --quant\n");
    }

    if (numthreads > 0) {
        return w2v_multi_thread(m_param, numthreads, quantization, numele);
    }

    printf("Single thread execution. Use --threads 4 for concurrent API\n");
    return w2v_single_thread(m_param, quantization, numele, massdel, self_recall, recall_ef);
}