From 68d30671256918c83d40c31c5c94eaed23d558ef Mon Sep 17 00:00:00 2001 From: antirez Date: Sat, 15 Mar 2025 10:24:20 +0100 Subject: [PATCH] w2v test: fix recall EF usage. --- w2v.c | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/w2v.c b/w2v.c index 339c74bbb..a8b21dcf4 100644 --- a/w2v.c +++ b/w2v.c @@ -29,6 +29,7 @@ uint64_t ms_time(void) { void test_recall(HNSW *index, int ef) { const int num_test_vectors = 10000; const int k = 100; // Number of nearest neighbors to find. + if (ef < k) ef = k; // Add recall distribution counters (2% bins from 0-100%). int recall_bins[50] = {0}; @@ -81,10 +82,10 @@ void test_recall(HNSW *index, int ef) { } // Allocate memory for results. - hnswNode **hnsw_results = malloc(sizeof(hnswNode*) * k); - hnswNode **linear_results = malloc(sizeof(hnswNode*) * k); - float *hnsw_distances = malloc(sizeof(float) * k); - float *linear_distances = malloc(sizeof(float) * k); + hnswNode **hnsw_results = malloc(sizeof(hnswNode*) * ef); + hnswNode **linear_results = malloc(sizeof(hnswNode*) * ef); + float *hnsw_distances = malloc(sizeof(float) * ef); + float *linear_distances = malloc(sizeof(float) * ef); if (!hnsw_results || !linear_results || !hnsw_distances || !linear_distances) { printf("Failed to allocate memory for results\n"); @@ -133,13 +134,15 @@ void test_recall(HNSW *index, int ef) { // Perform HNSW search with the specified EF parameter. int slot = hnsw_acquire_read_slot(index); - int hnsw_found = hnsw_search_with_filter(index, test_vector, k, hnsw_results, hnsw_distances, slot, 0, NULL, NULL, ef); + int hnsw_found = hnsw_search(index, test_vector, ef, hnsw_results, hnsw_distances, slot, 0); // Perform linear search (ground truth). - int linear_found = hnsw_ground_truth_with_filter(index, test_vector, k, linear_results, linear_distances, slot, 0, NULL, NULL); + int linear_found = hnsw_ground_truth_with_filter(index, test_vector, ef, linear_results, linear_distances, slot, 0, NULL, NULL); hnsw_release_read_slot(index, slot); // Calculate recall for this query (intersection size / k). + if (hnsw_found > k) hnsw_found = k; + if (linear_found > k) linear_found = k; int intersection_count = 0; for (int i = 0; i < linear_found; i++) { for (int j = 0; j < hnsw_found; j++) {