redis/w2v.c

316 lines
9.9 KiB
C

/*
* HNSW (Hierarchical Navigable Small World) Implementation
* Based on the paper by Yu. A. Malkov, D. A. Yashunin
*
* Copyright(C) 2024 Salvatore Sanfilippo. All Rights Reserved.
*/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include <time.h>
#include <stdint.h>
#include <pthread.h>
#include <stdatomic.h>
#include "hnsw.h"
/* Get current time in milliseconds */
uint64_t ms_time(void) {
struct timeval tv;
gettimeofday(&tv, NULL);
return (uint64_t)tv.tv_sec * 1000 + (tv.tv_usec / 1000);
}
/* Example usage in main() */
int w2v_single_thread(int quantization, uint64_t numele, int massdel, int recall) {
/* Create index */
HNSW *index = hnsw_new(300, quantization);
float v[300];
uint16_t wlen;
FILE *fp = fopen("word2vec.bin","rb");
if (fp == NULL) {
perror("word2vec.bin file missing");
exit(1);
}
unsigned char header[8];
fread(header,8,1,fp); // Skip header
uint64_t id = 0;
uint64_t start_time = ms_time();
char *word = NULL;
hnswNode *search_node = NULL;
while(id < numele) {
if (fread(&wlen,2,1,fp) == 0) break;
word = malloc(wlen+1);
fread(word,wlen,1,fp);
word[wlen] = 0;
fread(v,300*sizeof(float),1,fp);
// Plain API that acquires a write lock for the whole time.
hnswNode *added = hnsw_insert(index, v, NULL, 0, id++, word, 200);
if (!strcmp(word,"banana")) search_node = added;
if (!(id % 10000)) printf("%llu added\n", (unsigned long long)id);
}
uint64_t elapsed = ms_time() - start_time;
fclose(fp);
printf("%llu words added (%llu words/sec), last word: %s\n",
(unsigned long long)index->node_count,
(unsigned long long)id*1000/elapsed, word);
/* Search query */
if (search_node == NULL) search_node = index->head;
hnsw_get_node_vector(index,search_node,v);
hnswNode *neighbors[10];
float distances[10];
int found, j;
start_time = ms_time();
for (j = 0; j < 20000; j++)
found = hnsw_search(index, v, 10, neighbors, distances, 0, 0);
elapsed = ms_time() - start_time;
printf("%d searches performed (%llu searches/sec), nodes found: %d\n",
j, (unsigned long long)j*1000/elapsed, found);
if (found > 0) {
printf("Found %d neighbors:\n", found);
for (int i = 0; i < found; i++) {
printf("Node ID: %llu, distance: %f, word: %s\n",
(unsigned long long)neighbors[i]->id,
distances[i], (char*)neighbors[i]->value);
}
}
// Recall test (slow).
if (recall) {
hnsw_print_stats(index);
hnsw_test_graph_recall(index,200,0);
}
uint64_t connected_nodes;
int reciprocal_links;
hnsw_validate_graph(index, &connected_nodes, &reciprocal_links);
if (massdel) {
int remove_perc = 95;
printf("\nRemoving %d%% of nodes...\n", remove_perc);
uint64_t initial_nodes = index->node_count;
hnswNode *current = index->head;
while (current && index->node_count > initial_nodes*(100-remove_perc)/100) {
hnswNode *next = current->next;
hnsw_delete_node(index,current,free);
current = next;
// In order to don't remove only contiguous nodes, from time
// skip a node.
if (current && !(random() % remove_perc)) current = current->next;
}
printf("%llu nodes left\n", (unsigned long long)index->node_count);
// Test again.
hnsw_validate_graph(index, &connected_nodes, &reciprocal_links);
hnsw_test_graph_recall(index,200,0);
}
hnsw_free(index,free);
return 0;
}
struct threadContext {
pthread_mutex_t FileAccessMutex;
uint64_t numele;
_Atomic uint64_t SearchesDone;
_Atomic uint64_t id;
FILE *fp;
HNSW *index;
float *search_vector;
};
// Note that in practical terms inserting with many concurrent threads
// may be *slower* and not faster, because there is a lot of
// contention. So this is more a robustness test than anything else.
//
// The optimistic commit API goal is actually to exploit the ability to
// add faster when there are many concurrent reads.
void *threaded_insert(void *ctxptr) {
struct threadContext *ctx = ctxptr;
char *word;
float v[300];
uint16_t wlen;
while(1) {
pthread_mutex_lock(&ctx->FileAccessMutex);
if (fread(&wlen,2,1,ctx->fp) == 0) break;
pthread_mutex_unlock(&ctx->FileAccessMutex);
word = malloc(wlen+1);
fread(word,wlen,1,ctx->fp);
word[wlen] = 0;
fread(v,300*sizeof(float),1,ctx->fp);
// Check-and-set API that performs the costly scan for similar
// nodes concurrently with other read threads, and finally
// applies the check if the graph wasn't modified.
InsertContext *ic;
uint64_t next_id = ctx->id++;
ic = hnsw_prepare_insert(ctx->index, v, NULL, 0, next_id, word, 200);
if (hnsw_try_commit_insert(ctx->index, ic) == NULL) {
// This time try locking since the start.
hnsw_insert(ctx->index, v, NULL, 0, next_id, word, 200);
}
if (next_id >= ctx->numele) break;
if (!((next_id+1) % 10000))
printf("%llu added\n", (unsigned long long)next_id+1);
}
return NULL;
}
void *threaded_search(void *ctxptr) {
struct threadContext *ctx = ctxptr;
/* Search query */
hnswNode *neighbors[10];
float distances[10];
int found = 0;
uint64_t last_id = 0;
while(ctx->id < 1000000) {
int slot = hnsw_acquire_read_slot(ctx->index);
found = hnsw_search(ctx->index, ctx->search_vector, 10, neighbors, distances, slot, 0);
hnsw_release_read_slot(ctx->index,slot);
last_id = ++ctx->id;
}
if (found > 0 && last_id == 1000000) {
printf("Found %d neighbors:\n", found);
for (int i = 0; i < found; i++) {
printf("Node ID: %llu, distance: %f, word: %s\n",
(unsigned long long)neighbors[i]->id,
distances[i], (char*)neighbors[i]->value);
}
}
return NULL;
}
int w2v_multi_thread(int numthreads, int quantization, uint64_t numele) {
/* Create index */
struct threadContext ctx;
ctx.index = hnsw_new(300,quantization);
ctx.fp = fopen("word2vec.bin","rb");
if (ctx.fp == NULL) {
perror("word2vec.bin file missing");
exit(1);
}
unsigned char header[8];
fread(header,8,1,ctx.fp); // Skip header
pthread_mutex_init(&ctx.FileAccessMutex,NULL);
uint64_t start_time = ms_time();
ctx.id = 0;
ctx.numele = numele;
pthread_t threads[numthreads];
for (int j = 0; j < numthreads; j++)
pthread_create(&threads[j], NULL, threaded_insert, &ctx);
// Wait for all the threads to terminate adding items.
for (int j = 0; j < numthreads; j++)
pthread_join(threads[j],NULL);
uint64_t elapsed = ms_time() - start_time;
fclose(ctx.fp);
// Obtain the last word.
hnswNode *node = ctx.index->head;
char *word = node->value;
// We will search this last inserted word in the next test.
// Let's save its embedding.
ctx.search_vector = malloc(sizeof(float)*300);
hnsw_get_node_vector(ctx.index,node,ctx.search_vector);
printf("%llu words added (%llu words/sec), last word: %s\n",
(unsigned long long)ctx.index->node_count,
(unsigned long long)ctx.id*1000/elapsed, word);
/* Search query */
start_time = ms_time();
ctx.id = 0; // We will use this atomic field to stop at N queries done.
for (int j = 0; j < numthreads; j++)
pthread_create(&threads[j], NULL, threaded_search, &ctx);
// Wait for all the threads to terminate searching.
for (int j = 0; j < numthreads; j++)
pthread_join(threads[j],NULL);
elapsed = ms_time() - start_time;
printf("%llu searches performed (%llu searches/sec)\n",
(unsigned long long)ctx.id,
(unsigned long long)ctx.id*1000/elapsed);
hnsw_print_stats(ctx.index);
uint64_t connected_nodes;
int reciprocal_links;
hnsw_validate_graph(ctx.index, &connected_nodes, &reciprocal_links);
printf("%llu connected nodes. Links all reciprocal: %d\n",
(unsigned long long)connected_nodes, reciprocal_links);
hnsw_free(ctx.index,free);
return 0;
}
int main(int argc, char **argv) {
int quantization = HNSW_QUANT_NONE;
int numthreads = 0;
uint64_t numele = 20000;
/* This you can enable in single thread mode for testing: */
int massdel = 0; // If true, does the mass deletion test.
int recall = 0; // If true, does the recall test.
for (int j = 1; j < argc; j++) {
int moreargs = argc-j-1;
if (!strcasecmp(argv[j],"--quant")) {
quantization = HNSW_QUANT_Q8;
} else if (!strcasecmp(argv[j],"--bin")) {
quantization = HNSW_QUANT_BIN;
} else if (!strcasecmp(argv[j],"--mass-del")) {
massdel = 1;
} else if (!strcasecmp(argv[j],"--recall")) {
recall = 1;
} else if (moreargs >= 1 && !strcasecmp(argv[j],"--threads")) {
numthreads = atoi(argv[j+1]);
j++;
} else if (moreargs >= 1 && !strcasecmp(argv[j],"--numele")) {
numele = strtoll(argv[j+1],NULL,0);
j++;
if (numele < 1) numele = 1;
} else if (!strcasecmp(argv[j],"--help")) {
printf("%s [--quant] [--bin] [--thread <count>] [--numele <count>] [--mass-del] [--recall]\n", argv[0]);
exit(0);
} else {
printf("Unrecognized option: %s\n", argv[j]);
exit(1);
}
}
if (quantization == HNSW_QUANT_NONE) {
printf("You can enable quantization with --quant\n");
}
if (numthreads > 0) {
w2v_multi_thread(numthreads,quantization,numele);
} else {
printf("Single thread execution. Use --threads 4 for concurrent API\n");
w2v_single_thread(quantization,numele,massdel,recall);
}
}