From f829d46535137b6cd08923c1bcaf4013546b6f9a Mon Sep 17 00:00:00 2001 From: antirez Date: Sat, 8 Mar 2025 16:15:15 +0100 Subject: [PATCH] HNSW: creation time M parameter VS hardcoded. --- hnsw.c | 49 ++++++++++++++++++++++++++----------------------- hnsw.h | 8 ++++++-- vset.c | 2 +- w2v.c | 4 ++-- 4 files changed, 35 insertions(+), 28 deletions(-) diff --git a/hnsw.c b/hnsw.c index 14135bb2b..779245327 100644 --- a/hnsw.c +++ b/hnsw.c @@ -54,12 +54,6 @@ /* Algorithm parameters. */ -#define HNSW_M 16 /* Number of max connections per node. Note that - * layer zero has twice as many. Also note that - * when a new node is added, we will populate - * even layer 0 links to just HNSW_M neighbors, so - * initially half layer 0 slots will be empty. */ -#define HNSW_M0 (HNSW_M*2) /* Maximum number of connections for layer 0 */ #define HNSW_P 0.25 /* Probability of level increase. */ #define HNSW_MAX_LEVEL 16 /* Max level nodes can reach. */ #define HNSW_EF_C 200 /* Default size of dynamic candidate list while @@ -68,7 +62,8 @@ * used when deleting nodes for the search step * needed sometimes to reconnect nodes that remain * orphaned of one link. */ - +#define HNSW_DEFAULT_M 16 /* Useful if 0 is given at creation time. */ +#define HNSW_MAX_M 1024 /* Hard limit for M. */ static void (*hfree)(void *p) = free; static void *(*hmalloc)(size_t s) = malloc; @@ -391,10 +386,15 @@ uint32_t random_level() { } /* Create new HNSW index, quantized or not. */ -HNSW *hnsw_new(uint32_t vector_dim, uint32_t quant_type) { +HNSW *hnsw_new(uint32_t vector_dim, uint32_t quant_type, uint32_t m) { HNSW *index = hmalloc(sizeof(HNSW)); if (!index) return NULL; + /* M parameter sanity check. */ + if (m == 0) m = HNSW_DEFAULT_M; + else if (m > HNSW_MAX_M) m = HNSW_MAX_M; + + index->M = m; index->quant_type = quant_type; index->enter_point = NULL; index->max_level = 0; @@ -556,7 +556,7 @@ hnswNode *hnsw_node_new(HNSW *index, uint64_t id, const float *vector, const int /* Initialize each layer. */ for (uint32_t i = 0; i <= level; i++) { - uint32_t max_links = (i == 0) ? HNSW_M0 : HNSW_M; + uint32_t max_links = (i == 0) ? index->M*2 : index->M; node->layers[i].max_links = max_links; node->layers[i].num_links = 0; node->layers[i].worst_distance = 0; @@ -939,7 +939,7 @@ void hnsw_update_worst_neighbor_on_remove(HNSW *index, hnswNode *node, uint32_t void select_neighbors(HNSW *index, pqueue *candidates, hnswNode *new_node, uint32_t layer, uint32_t required_links, int aggressive) { - uint32_t max_links = (layer == 0) ? HNSW_M0 : HNSW_M; + uint32_t max_links = (layer == 0) ? index->M*2 : index->M; for (uint32_t i = 0; i < candidates->count; i++) { hnswNode *neighbor = pq_get_node(candidates,i); @@ -1031,13 +1031,13 @@ void select_neighbors(HNSW *index, pqueue *candidates, hnswNode *new_node, * if we remove the candidate node as its link. Let's check if * this is the case: */ if (aggressive == 0 && - worst_node->layers[layer].num_links <= HNSW_M/2) + worst_node->layers[layer].num_links <= index->M/2) continue; /* Aggressive level = 1. It's ok if the node remains with just * HNSW_M/4 links. */ else if (aggressive == 1 && - worst_node->layers[layer].num_links <= HNSW_M/4) + worst_node->layers[layer].num_links <= index->M/4) continue; /* If aggressive is set to 2, then the new node we are adding failed @@ -1045,7 +1045,8 @@ void select_neighbors(HNSW *index, pqueue *candidates, hnswNode *new_node, * node, so let's see if the target node has some other link * that is well connected in the graph: we could drop it instead * of the worst link. */ - if (aggressive == 2 && worst_node->layers[layer].num_links <= HNSW_M/4) + if (aggressive == 2 && worst_node->layers[layer].num_links <= + index->M/4) { /* Let's see if we can find at least a candidate link that * would remain with a few connections. Track the one @@ -1059,13 +1060,13 @@ void select_neighbors(HNSW *index, pqueue *candidates, hnswNode *new_node, /* Skip this if it would remain too disconnected as well. * - * NOTE about HNSW_M/4 min connections requirement: + * NOTE about index->M/4 min connections requirement: * * It is not too strict, since leaving a node with just a * single link does not just leave it too weakly connected, but * also sometimes creates cycles with few disconnected * nodes linked among them. */ - if (to_drop->layers[layer].num_links <= HNSW_M/4) continue; + if (to_drop->layers[layer].num_links <= index->M/4) continue; float link_dist = hnsw_distance(index, neighbor, to_drop); if (worst_node == NULL || link_dist > max_dist) { @@ -1401,8 +1402,10 @@ void hnsw_reconnect_nodes(HNSW *index, hnswNode **nodes, int count, uint32_t lay /* Try to connect with aggressiveness proportional to the * node linking condition. */ int aggressiveness = - nodes[i]->layers[layer].num_links > HNSW_M / 2 ? 1 : 2; - select_neighbors(index, candidates, nodes[i], layer, wanted_links, aggressiveness); + (nodes[i]->layers[layer].num_links > index->M / 2) + ? 1 : 2; + select_neighbors(index, candidates, nodes[i], layer, + wanted_links, aggressiveness); debugmsg("Final links with broader search: %d (wanted: %d)\n", (int)nodes[i]->layers[layer].num_links, wanted_links); pq_free(candidates); } @@ -1729,25 +1732,25 @@ hnswNode *hnsw_commit_insert_nolock(HNSW *index, InsertContext *ctx) { for (int lc = MIN(node->level,index->max_level); lc >= 0; lc--) { if (ctx->level_queues[lc] == NULL) continue; - /* Try to provide HNSW_M connections to our node. The call + /* Try to provide index->M connections to our node. The call * is not guaranteed to be able to provide all the links we would * like to have for the new node: they must be bi-directional, obey * certain quality checks, and so forth, so later there are further * calls to force the hand a bit if needed. * * Let's start with aggressiveness = 0. */ - select_neighbors(index, ctx->level_queues[lc], node, lc, HNSW_M, 0); + select_neighbors(index, ctx->level_queues[lc], node, lc, index->M, 0); /* Layer 0 and too few connections? Let's be more aggressive. */ - if (lc == 0 && node->layers[0].num_links < HNSW_M/2) { + if (lc == 0 && node->layers[0].num_links < index->M/2) { select_neighbors(index, ctx->level_queues[lc], node, lc, - HNSW_M, 1); + index->M, 1); /* Still too few connections? Let's go to * aggressiveness level '2' in linking strategy. */ - if (node->layers[0].num_links < HNSW_M/4) { + if (node->layers[0].num_links < index->M/4) { select_neighbors(index, ctx->level_queues[lc], node, lc, - HNSW_M/4, 2); + index->M/4, 2); } } } diff --git a/hnsw.h b/hnsw.h index 7f49a2b2b..ee6186785 100644 --- a/hnsw.h +++ b/hnsw.h @@ -27,7 +27,8 @@ typedef struct { * reallocate the node in very particular * conditions in order to allow linking of * new inserted nodes, so this may change - * dynamically for a small set of nodes. */ + * dynamically and be > M*2 for a small set of + * nodes. */ float worst_distance; /* Distance to the worst neighbor */ uint32_t worst_idx; /* Index of the worst neighbor */ } hnswNodeLayer; @@ -74,6 +75,9 @@ typedef struct hnswCursor { /* Main HNSW index structure */ typedef struct HNSW { hnswNode *enter_point; /* Entry point for the graph */ + uint32_t M; /* M as in the paper: layer 0 has M*2 max + neighbors (M populated at insertion time) + while all the other layers have M neighbors. */ uint32_t max_level; /* Current maximum level in the graph */ uint32_t vector_dim; /* Dimensionality of stored vectors */ uint64_t node_count; /* Total number of nodes */ @@ -110,7 +114,7 @@ typedef struct hnswSerNode { typedef struct InsertContext InsertContext; /* Core HNSW functions */ -HNSW *hnsw_new(uint32_t vector_dim, uint32_t quant_type); +HNSW *hnsw_new(uint32_t vector_dim, uint32_t quant_type, uint32_t m); void hnsw_free(HNSW *index,void(*free_value)(void*value)); void hnsw_node_free(hnswNode *node); void hnsw_print_stats(HNSW *index); diff --git a/vset.c b/vset.c index 034d8fb97..310ef462e 100644 --- a/vset.c +++ b/vset.c @@ -106,7 +106,7 @@ struct vsetObject *createVectorSetObject(unsigned int dim, uint32_t quant_type) if (!o) return NULL; o->id = VectorSetTypeNextId++; - o->hnsw = hnsw_new(dim,quant_type); + o->hnsw = hnsw_new(dim,quant_type,0); if (!o->hnsw) { RedisModule_Free(o); return NULL; diff --git a/w2v.c b/w2v.c index 012dfd95c..f4caa4ba5 100644 --- a/w2v.c +++ b/w2v.c @@ -26,7 +26,7 @@ uint64_t ms_time(void) { /* Example usage in main() */ int w2v_single_thread(int quantization, uint64_t numele, int massdel, int recall) { /* Create index */ - HNSW *index = hnsw_new(300, quantization); + HNSW *index = hnsw_new(300, quantization, 0); float v[300]; uint16_t wlen; @@ -201,7 +201,7 @@ int w2v_multi_thread(int numthreads, int quantization, uint64_t numele) { /* Create index */ struct threadContext ctx; - ctx.index = hnsw_new(300,quantization); + ctx.index = hnsw_new(300,quantization,0); ctx.fp = fopen("word2vec.bin","rb"); if (ctx.fp == NULL) {