Expr filtering: implement HNSW filter in search_layer().

This commit is contained in:
antirez 2025-02-21 14:52:23 +01:00
parent 438adc917b
commit 025790fc50
1 changed files with 59 additions and 17 deletions

76
hnsw.c
View File

@ -618,9 +618,19 @@ void hnsw_add_node(HNSW *index, hnswNode *node) {
}
/* Search the specified layer starting from the specified entry point
* to collect 'ef' nodes that are near to 'query'. */
pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
uint32_t ef, uint32_t layer, uint32_t slot)
* to collect 'ef' nodes that are near to 'query'.
*
* This function implements optional hybrid search, so that each node
* can be accepted or not based on its associated value. In this case
* a callback 'filter_callback' should be passed, together with a maximum
* effort for the search (number of candidates to evaluate), since even
* with a low "EF" value we risk that there are too few nodes that satisfy
* the provided filter, and we could trigger a full scan. */
pqueue *search_layer_with_filter(
HNSW *index, hnswNode *query, hnswNode *entry_point,
uint32_t ef, uint32_t layer, uint32_t slot,
int (*filter_callback)(void *value, void *privdata),
void *filter_privdata, uint32_t max_candidates)
{
// Mark visited nodes with a never seen epoch.
index->current_epoch[slot]++;
@ -633,27 +643,33 @@ pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
return NULL;
}
// Keep track of the total effort: only used when filtering via
// a callback, in order to have a bounded effort.
uint32_t evaluated_candidates = 1;
// Add entry point.
float dist = hnsw_distance(index, query, entry_point);
pq_push(candidates, entry_point, dist);
pq_push(results, entry_point, dist);
if (filter_callback == NULL ||
filter_callback(entry_point->value, filter_privdata))
{
pq_push(results, entry_point, dist);
}
entry_point->visited_epoch[slot] = index->current_epoch[slot];
// Process candidates.
while (candidates->count > 0) {
// Pop closest element and use its saved distance.
// Max effort. If zero, we keep scanning.
if (filter_callback &&
max_candidates &&
evaluated_candidates >= max_candidates) break;
float cur_dist;
hnswNode *current = pq_pop(candidates, &cur_dist);
evaluated_candidates++;
/* Stop if we can't get better results. Note that this can
* be true only if we already collected 'ef' elements in
* the priority queue. This is why: if we have less than EF
* elements, later in the for loop that checks the neighbors we
* add new elements BOTH in the results and candidates pqueue: this
* means that before accumulating EF elements, the worst candidate
* can be as bad as the worst result, but not worse. */
float furthest = pq_max_distance(results);
if (cur_dist > furthest) break;
if (results->count >= ef && cur_dist > furthest) break;
/* Check neighbors. */
for (uint32_t i = 0; i < current->layers[layer].num_links; i++) {
@ -664,11 +680,29 @@ pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
neighbor->visited_epoch[slot] = index->current_epoch[slot];
float neighbor_dist = hnsw_distance(index, query, neighbor);
// Add to results if better than current max or results not full.
furthest = pq_max_distance(results);
if (neighbor_dist < furthest || results->count < ef) {
pq_push(candidates, neighbor, neighbor_dist);
pq_push(results, neighbor, neighbor_dist);
if (filter_callback == NULL) {
/* Original HNSW logic when no filtering:
* Add to results if better than current max or
* results not full. */
if (neighbor_dist < furthest || results->count < ef) {
pq_push(candidates, neighbor, neighbor_dist);
pq_push(results, neighbor, neighbor_dist);
}
} else {
/* With filtering: we add candidates even if they don't match
* the filter, in order to continue exploring the graph. */
if (neighbor_dist < furthest || candidates->count < ef) {
pq_push(candidates, neighbor, neighbor_dist);
}
/* Add results only if passes filter. */
if (filter_callback(neighbor->value, filter_privdata)) {
if (neighbor_dist < furthest || results->count < ef) {
pq_push(results, neighbor, neighbor_dist);
}
}
}
}
}
@ -677,6 +711,14 @@ pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
return results;
}
/* Plain layer search, without hybrid filtering.
 *
 * Forwards every argument unchanged to search_layer_with_filter(), passing
 * a NULL filter callback, NULL callback private data and a zero candidates
 * limit, and returns whatever priority queue that call returns. */
pqueue *search_layer(HNSW *index, hnswNode *query, hnswNode *entry_point,
                     uint32_t ef, uint32_t layer, uint32_t slot)
{
    pqueue *found = search_layer_with_filter(index, query, entry_point,
                                             ef, layer, slot,
                                             NULL, /* filter_callback */
                                             NULL, /* filter_privdata */
                                             0);   /* max_candidates */
    return found;
}
/* This function is used in order to initialize a node allocated in the
* function stack with the specified vector. The idea is that we can
* easily use hnsw_distance() from a vector and the HNSW nodes this way: