Improve HNSW filtered search speed through new heuristic (#126876)

Apache Lucene 10.2 exposes a new search strategy for executing filtered searches over HNSW graphs.

This PR switches to utilizing that strategy by default as it generally provides a much better recall/latency pareto frontier than our regular hnsw fanout search.

Additionally, a new tech-preview setting is provided to potentially revert to the old fanout behavior if issues arise.
This commit is contained in:
Benjamin Trent 2025-05-06 13:41:16 -04:00 committed by GitHub
parent 22a52a9c64
commit 8bb7dc4058
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 412 additions and 51 deletions

View File

@ -0,0 +1,5 @@
pr: 126876
summary: Improve HNSW filtered search speed through new heuristic
area: Vector Search
type: enhancement
issues: []

View File

@ -249,6 +249,12 @@ $$$index-final-pipeline$$$
$$$index-hidden$$$ `index.hidden`
: Indicates whether the index should be hidden by default. Hidden indices are not returned by default when using a wildcard expression. This behavior is controlled per request through the use of the `expand_wildcards` parameter. Possible values are `true` and `false` (default).
$$$index-dense-vector-hnsw-filter-heuristic$$$ `index.dense_vector.hnsw_filter_heuristic`
: The heuristic to utilize when executing a filtered search against vectors in an HNSW graph. This setting is in technical preview may be changed or removed in a future release. It can be set to:
* `acorn` (default) - Only vectors that match the filter criteria are searched. This is the fastest option, and generally provides faster searches at similar recall to `fanout`, but `num_candidates` might need to be increased for exceptionally high recall requirements.
* `fanout` - All vectors are compared with the query vector, but only those passing the criteria are added to the search results. Can be slower than `acorn`, but may yield higher recall.
$$$index-esql-stored-fields-sequential-proportion$$$
`index.esql.stored_fields_sequential_proportion`

View File

@ -36,6 +36,7 @@ import org.elasticsearch.index.mapper.FieldMapper;
import org.elasticsearch.index.mapper.IgnoredSourceFieldMapper;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.similarity.SimilarityService;
import org.elasticsearch.index.store.FsDirectoryFactory;
import org.elasticsearch.index.store.Store;
@ -157,6 +158,7 @@ public final class IndexScopedSettings extends AbstractScopedSettings {
IndexSettings.INDEX_TRANSLOG_RETENTION_AGE_SETTING,
IndexSettings.INDEX_TRANSLOG_RETENTION_SIZE_SETTING,
IndexSettings.INDEX_SEARCH_IDLE_AFTER,
DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC,
IndexFieldDataService.INDEX_FIELDDATA_CACHE_KEY,
IndexSettings.IGNORE_ABOVE_SETTING,
FieldMapper.IGNORE_MALFORMED_SETTING,

View File

@ -29,6 +29,7 @@ import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.mapper.IgnoredSourceFieldMapper;
import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.index.mapper.SourceFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.translog.Translog;
import org.elasticsearch.indices.recovery.RecoverySettings;
import org.elasticsearch.ingest.IngestService;
@ -896,6 +897,7 @@ public final class IndexSettings {
private volatile int maxTokenCount;
private volatile int maxNgramDiff;
private volatile int maxShingleDiff;
private volatile DenseVectorFieldMapper.FilterHeuristic hnswFilterHeuristic;
private volatile TimeValue searchIdleAfter;
private volatile int maxAnalyzedOffset;
private volatile boolean weightMatchesEnabled;
@ -1091,6 +1093,7 @@ public final class IndexSettings {
logsdbAddHostNameField = scopedSettings.get(LOGSDB_ADD_HOST_NAME_FIELD);
skipIgnoredSourceWrite = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_WRITE_SETTING);
skipIgnoredSourceRead = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING);
hnswFilterHeuristic = scopedSettings.get(DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC);
indexMappingSourceMode = scopedSettings.get(INDEX_MAPPER_SOURCE_MODE_SETTING);
recoverySourceEnabled = RecoverySettings.INDICES_RECOVERY_SOURCE_ENABLED_SETTING.get(nodeSettings);
recoverySourceSyntheticEnabled = DiscoveryNode.isStateless(nodeSettings) == false
@ -1203,6 +1206,7 @@ public final class IndexSettings {
this::setSkipIgnoredSourceWrite
);
scopedSettings.addSettingsUpdateConsumer(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING, this::setSkipIgnoredSourceRead);
scopedSettings.addSettingsUpdateConsumer(DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC, this::setHnswFilterHeuristic);
}
private void setSearchIdleAfter(TimeValue searchIdleAfter) {
@ -1821,4 +1825,16 @@ public final class IndexSettings {
public IndexRouting getIndexRouting() {
return indexRouting;
}
/**
* The heuristic to utilize when executing filtered search on vectors indexed
* in HNSW format.
*/
public DenseVectorFieldMapper.FilterHeuristic getHnswFilterHeuristic() {
return this.hnswFilterHeuristic;
}
private void setHnswFilterHeuristic(DenseVectorFieldMapper.FilterHeuristic heuristic) {
this.hnswFilterHeuristic = heuristic;
}
}

View File

@ -166,6 +166,7 @@ public class IndexVersions {
public static final IndexVersion UPGRADE_TO_LUCENE_10_2_1 = def(9_023_00_0, Version.LUCENE_10_2_1);
public static final IndexVersion DEFAULT_OVERSAMPLE_VALUE_FOR_BBQ = def(9_024_0_00, Version.LUCENE_10_2_1);
public static final IndexVersion SEMANTIC_TEXT_DEFAULTS_TO_BBQ = def(9_025_0_00, Version.LUCENE_10_2_1);
public static final IndexVersion DEFAULT_TO_ACORN_HNSW_FILTER_HEURISTIC = def(9_026_0_00, Version.LUCENE_10_2_1);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

View File

@ -33,10 +33,12 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.IndexVersion;
@ -93,6 +95,7 @@ import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;
import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_INDEX_VERSION_CREATED;
import static org.elasticsearch.common.Strings.format;
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
@ -108,6 +111,51 @@ public class DenseVectorFieldMapper extends FieldMapper {
return Math.abs(magnitude - 1.0f) > EPS;
}
/**
* The heuristic to utilize when executing a filtered search against vectors indexed in an HNSW graph.
*/
public enum FilterHeuristic {
/**
* This heuristic searches the entire graph, doing vector comparisons in all immediate neighbors
* but only collects vectors that match the filtering criteria.
*/
FANOUT {
static final KnnSearchStrategy FANOUT_STRATEGY = new KnnSearchStrategy.Hnsw(0);
@Override
public KnnSearchStrategy getKnnSearchStrategy() {
return FANOUT_STRATEGY;
}
},
/**
* This heuristic will only compare vectors that match the filtering criteria.
*/
ACORN {
static final KnnSearchStrategy ACORN_STRATEGY = new KnnSearchStrategy.Hnsw(60);
@Override
public KnnSearchStrategy getKnnSearchStrategy() {
return ACORN_STRATEGY;
}
};
public abstract KnnSearchStrategy getKnnSearchStrategy();
}
public static final Setting<FilterHeuristic> HNSW_FILTER_HEURISTIC = Setting.enumSetting(FilterHeuristic.class, s -> {
IndexVersion version = SETTING_INDEX_VERSION_CREATED.get(s);
if (version.onOrAfter(IndexVersions.DEFAULT_TO_ACORN_HNSW_FILTER_HEURISTIC)) {
return FilterHeuristic.ACORN.toString();
}
return FilterHeuristic.FANOUT.toString();
},
"index.dense_vector.hnsw_filter_heuristic",
fh -> {},
Setting.Property.IndexScope,
Setting.Property.ServerlessPublic,
Setting.Property.Dynamic
);
private static boolean hasRescoreIndexVersion(IndexVersion version) {
return version.onOrAfter(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)
|| version.between(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS_BACKPORT_8_X, IndexVersions.UPGRADE_TO_LUCENE_10_0_0);
@ -2210,15 +2258,25 @@ public class DenseVectorFieldMapper extends FieldMapper {
Float oversample,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
BitSetProducer parentFilter,
DenseVectorFieldMapper.FilterHeuristic heuristic
) {
if (isIndexed() == false) {
throw new IllegalArgumentException(
"to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
);
}
KnnSearchStrategy knnSearchStrategy = heuristic.getKnnSearchStrategy();
return switch (getElementType()) {
case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter);
case BYTE -> createKnnByteQuery(
queryVector.asByteVector(),
k,
numCands,
filter,
similarityThreshold,
parentFilter,
knnSearchStrategy
);
case FLOAT -> createKnnFloatQuery(
queryVector.asFloatVector(),
k,
@ -2226,9 +2284,18 @@ public class DenseVectorFieldMapper extends FieldMapper {
oversample,
filter,
similarityThreshold,
parentFilter
parentFilter,
knnSearchStrategy
);
case BIT -> createKnnBitQuery(
queryVector.asByteVector(),
k,
numCands,
filter,
similarityThreshold,
parentFilter,
knnSearchStrategy
);
case BIT -> createKnnBitQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter);
};
}
@ -2246,12 +2313,13 @@ public class DenseVectorFieldMapper extends FieldMapper {
int numCands,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
BitSetProducer parentFilter,
KnnSearchStrategy searchStrategy
) {
elementType.checkDimensions(dims, queryVector.length);
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter);
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
@ -2268,7 +2336,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
int numCands,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
BitSetProducer parentFilter,
KnnSearchStrategy searchStrategy
) {
elementType.checkDimensions(dims, queryVector.length);
@ -2277,8 +2346,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
}
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter);
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
@ -2296,7 +2365,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
Float queryOversample,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
BitSetProducer parentFilter,
KnnSearchStrategy knnSearchStrategy
) {
elementType.checkDimensions(dims, queryVector.length);
elementType.checkVectorBounds(queryVector);
@ -2330,8 +2400,16 @@ public class DenseVectorFieldMapper extends FieldMapper {
numCands = Math.max(adjustedK, numCands);
}
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, numCands, parentFilter)
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter);
? new ESDiversifyingChildrenFloatKnnVectorQuery(
name(),
queryVector,
filter,
adjustedK,
numCands,
parentFilter,
knnSearchStrategy
)
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy);
if (rescore) {
knnQuery = new RescoreKnnVectorQuery(
name(),

View File

@ -13,6 +13,7 @@ import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;
public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements QueryProfilerProvider {
@ -25,9 +26,10 @@ public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildr
Query childFilter,
Integer k,
int numCands,
BitSetProducer parentsFilter
BitSetProducer parentsFilter,
KnnSearchStrategy strategy
) {
super(field, query, childFilter, numCands, parentsFilter);
super(field, query, childFilter, numCands, parentsFilter, strategy);
this.kParam = k;
}
@ -42,4 +44,8 @@ public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildr
public void profile(QueryProfiler queryProfiler) {
queryProfiler.addVectorOpsCount(vectorOpsCount);
}
public KnnSearchStrategy getStrategy() {
return searchStrategy;
}
}

View File

@ -13,6 +13,7 @@ import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;
public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery implements QueryProfilerProvider {
@ -25,9 +26,10 @@ public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChild
Query childFilter,
Integer k,
int numCands,
BitSetProducer parentsFilter
BitSetProducer parentsFilter,
KnnSearchStrategy strategy
) {
super(field, query, childFilter, numCands, parentsFilter);
super(field, query, childFilter, numCands, parentsFilter, strategy);
this.kParam = k;
}
@ -42,4 +44,8 @@ public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChild
public void profile(QueryProfiler queryProfiler) {
queryProfiler.addVectorOpsCount(vectorOpsCount);
}
public KnnSearchStrategy getStrategy() {
return searchStrategy;
}
}

View File

@ -12,14 +12,15 @@ package org.elasticsearch.search.vectors;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;
public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryProfilerProvider {
private final Integer kParam;
private long vectorOpsCount;
public ESKnnByteVectorQuery(String field, byte[] target, Integer k, int numCands, Query filter) {
super(field, target, numCands, filter);
public ESKnnByteVectorQuery(String field, byte[] target, Integer k, int numCands, Query filter, KnnSearchStrategy strategy) {
super(field, target, numCands, filter, strategy);
this.kParam = k;
}
@ -39,4 +40,8 @@ public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryPro
public Integer kParam() {
return kParam;
}
public KnnSearchStrategy getStrategy() {
return searchStrategy;
}
}

View File

@ -12,14 +12,15 @@ package org.elasticsearch.search.vectors;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;
public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryProfilerProvider {
private final Integer kParam;
private long vectorOpsCount;
public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Query filter) {
super(field, target, numCands, filter);
public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Query filter, KnnSearchStrategy strategy) {
super(field, target, numCands, filter, strategy);
this.kParam = k;
}
@ -39,4 +40,8 @@ public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryP
public Integer kParam() {
return kParam;
}
public KnnSearchStrategy getStrategy() {
return searchStrategy;
}
}

View File

@ -552,8 +552,17 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet);
}
}
return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, oversample, filterQuery, vectorSimilarity, parentBitSet);
DenseVectorFieldMapper.FilterHeuristic heuristic = context.getIndexSettings().getHnswFilterHeuristic();
return vectorFieldType.createKnnQuery(
queryVector,
k,
adjustedNumCands,
oversample,
filterQuery,
vectorSimilarity,
parentBitSet,
heuristic
);
}
@Override

View File

@ -1881,7 +1881,16 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
Exception e = expectThrows(
IllegalArgumentException.class,
() -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null, null)
() -> denseVectorFieldType.createKnnQuery(
VectorData.fromFloats(new float[] { 128, 0, 0 }),
3,
3,
null,
null,
null,
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(
e.getMessage(),
@ -1897,7 +1906,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
null,
null,
null,
null
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(
@ -1907,7 +1917,16 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
e = expectThrows(
IllegalArgumentException.class,
() -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null, null)
() -> denseVectorFieldType.createKnnQuery(
VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }),
3,
3,
null,
null,
null,
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(
e.getMessage(),
@ -1916,7 +1935,16 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
e = expectThrows(
IllegalArgumentException.class,
() -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null, null)
() -> denseVectorFieldType.createKnnQuery(
VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }),
3,
3,
null,
null,
null,
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(
e.getMessage(),
@ -1932,7 +1960,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
null,
null,
null,
null
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(e.getMessage(), containsString("element_type [byte] vectors do not support NaN values but found [NaN] at dim [0];"));
@ -1946,7 +1975,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
null,
null,
null,
null
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(
@ -1963,7 +1993,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
null,
null,
null,
null
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(
@ -1997,7 +2028,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
null,
null,
null,
null
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(e.getMessage(), containsString("element_type [float] vectors do not support NaN values but found [NaN] at dim [0];"));
@ -2011,7 +2043,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
null,
null,
null,
null
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(
@ -2028,7 +2061,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
null,
null,
null,
null
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(

View File

@ -15,6 +15,8 @@ import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.fielddata.FieldDataContext;
import org.elasticsearch.index.mapper.FieldTypeTestCase;
@ -32,8 +34,10 @@ import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BIT;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.FLOAT;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT;
@ -216,7 +220,16 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
for (int i = 0; i < dims; i++) {
queryVector[i] = randomFloat();
}
Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, producer);
Query query = field.createKnnQuery(
VectorData.fromFloats(queryVector),
10,
10,
null,
null,
null,
producer,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
);
if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) {
query = rescoreKnnVectorQuery.innerQuery();
}
@ -240,11 +253,29 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
floatQueryVector[i] = queryVector[i];
}
VectorData vectorData = new VectorData(null, queryVector);
Query query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer);
Query query = field.createKnnQuery(
vectorData,
10,
10,
null,
null,
null,
producer,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
);
assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
vectorData = new VectorData(floatQueryVector, null);
query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer);
query = field.createKnnQuery(
vectorData,
10,
10,
null,
null,
null,
producer,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
);
assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
}
}
@ -312,7 +343,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
null,
null,
null,
null
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]"));
@ -333,7 +365,16 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
}
e = expectThrows(
IllegalArgumentException.class,
() -> dotProductField.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null)
() -> dotProductField.createKnnQuery(
VectorData.fromFloats(queryVector),
10,
10,
null,
null,
null,
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors."));
@ -349,7 +390,16 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
);
e = expectThrows(
IllegalArgumentException.class,
() -> cosineField.createKnnQuery(VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, null, null, null, null)
() -> cosineField.createKnnQuery(
VectorData.fromFloats(new float[BBQ_MIN_DIMS]),
10,
10,
null,
null,
null,
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
}
@ -370,7 +420,16 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
for (int i = 0; i < 4096; i++) {
queryVector[i] = randomFloat();
}
Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null);
Query query = fieldWith4096dims.createKnnQuery(
VectorData.fromFloats(queryVector),
10,
10,
null,
null,
null,
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
);
if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) {
query = rescoreKnnVectorQuery.innerQuery();
}
@ -393,7 +452,16 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
queryVector[i] = randomByte();
}
VectorData vectorData = new VectorData(null, queryVector);
Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, null, null, null, null);
Query query = fieldWith4096dims.createKnnQuery(
vectorData,
10,
10,
null,
null,
null,
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
);
assertThat(query, instanceOf(KnnByteVectorQuery.class));
}
}
@ -411,7 +479,16 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
);
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, null, null, null, null)
() -> unindexedField.createKnnQuery(
VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }),
10,
10,
null,
null,
null,
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]"));
@ -427,13 +504,31 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
);
e = expectThrows(
IllegalArgumentException.class,
() -> cosineField.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, null, null, null, null)
() -> cosineField.createKnnQuery(
VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }),
10,
10,
null,
null,
null,
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
e = expectThrows(
IllegalArgumentException.class,
() -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, null, null, null, null)
() -> cosineField.createKnnQuery(
new VectorData(null, new byte[] { 0, 0, 0 }),
10,
10,
null,
null,
null,
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
)
);
assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
}
@ -458,7 +553,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
randomFloatBetween(1.0F, 10.0F, false),
null,
null,
null
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
);
if (elementType == BYTE) {
@ -504,7 +600,16 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
randomIndexOptionsHnswQuantized(new DenseVectorFieldMapper.RescoreVector(randomFloatBetween(1.1f, 9.9f, false))),
Collections.emptyMap()
);
Query query = fieldType.createKnnQuery(VectorData.fromFloats(new float[] { 1, 4, 10 }), 10, 100, 0f, null, null, null);
Query query = fieldType.createKnnQuery(
VectorData.fromFloats(new float[] { 1, 4, 10 }),
10,
100,
0f,
null,
null,
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
);
assertTrue(query instanceof ESKnnFloatVectorQuery);
// verify we can override a `0` to a positive number
@ -518,7 +623,16 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
randomIndexOptionsHnswQuantized(new DenseVectorFieldMapper.RescoreVector(0)),
Collections.emptyMap()
);
query = fieldType.createKnnQuery(VectorData.fromFloats(new float[] { 1, 4, 10 }), 10, 100, 2f, null, null, null);
query = fieldType.createKnnQuery(
VectorData.fromFloats(new float[] { 1, 4, 10 }),
10,
100,
2f,
null,
null,
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
);
assertTrue(query instanceof RescoreKnnVectorQuery);
assertThat(((RescoreKnnVectorQuery) query).k(), equalTo(10));
ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) ((RescoreKnnVectorQuery) query).innerQuery();
@ -526,6 +640,55 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
}
public void testFilterSearchThreshold() {
List<Tuple<DenseVectorFieldMapper.ElementType, Function<Query, KnnSearchStrategy>>> cases = List.of(
Tuple.tuple(FLOAT, q -> ((ESKnnFloatVectorQuery) q).getStrategy()),
Tuple.tuple(BYTE, q -> ((ESKnnByteVectorQuery) q).getStrategy()),
Tuple.tuple(BIT, q -> ((ESKnnByteVectorQuery) q).getStrategy())
);
for (var tuple : cases) {
DenseVectorFieldType fieldType = new DenseVectorFieldType(
"f",
IndexVersion.current(),
tuple.v1(),
tuple.v1() == BIT ? 3 * 8 : 3,
true,
VectorSimilarity.COSINE,
randomIndexOptionsHnswQuantized(),
Collections.emptyMap()
);
// Test with a filter search threshold
Query query = fieldType.createKnnQuery(
VectorData.fromFloats(new float[] { 1, 4, 10 }),
10,
100,
0f,
null,
null,
null,
DenseVectorFieldMapper.FilterHeuristic.FANOUT
);
KnnSearchStrategy strategy = tuple.v2().apply(query);
assertTrue(strategy instanceof KnnSearchStrategy.Hnsw);
assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(0));
query = fieldType.createKnnQuery(
VectorData.fromFloats(new float[] { 1, 4, 10 }),
10,
100,
0f,
null,
null,
null,
DenseVectorFieldMapper.FilterHeuristic.ACORN
);
strategy = tuple.v2().apply(query);
assertTrue(strategy instanceof KnnSearchStrategy.Hnsw);
assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(60));
}
}
private static void checkRescoreQueryParameters(
DenseVectorFieldType fieldType,
int k,
@ -542,7 +705,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
oversample,
null,
null,
null
null,
randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
);
RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query;
ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery();

View File

@ -13,6 +13,7 @@ import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.support.PlainActionFuture;
@ -21,6 +22,7 @@ import org.elasticsearch.common.compress.CompressedXContent;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.index.IndexVersions;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.query.InnerHitsRewriteContext;
@ -216,9 +218,29 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
numCands = Math.max(numCands, k);
}
final KnnSearchStrategy expectedStrategy = context.getIndexSettings()
.getIndexVersionCreated()
.onOrAfter(IndexVersions.DEFAULT_TO_ACORN_HNSW_FILTER_HEURISTIC)
? DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
: DenseVectorFieldMapper.FilterHeuristic.FANOUT.getKnnSearchStrategy();
Query knnVectorQueryBuilt = switch (elementType()) {
case BYTE, BIT -> new ESKnnByteVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asByteVector(), k, numCands, filterQuery);
case FLOAT -> new ESKnnFloatVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asFloatVector(), k, numCands, filterQuery);
case BYTE, BIT -> new ESKnnByteVectorQuery(
VECTOR_FIELD,
queryBuilder.queryVector().asByteVector(),
k,
numCands,
filterQuery,
expectedStrategy
);
case FLOAT -> new ESKnnFloatVectorQuery(
VECTOR_FIELD,
queryBuilder.queryVector().asFloatVector(),
k,
numCands,
filterQuery,
expectedStrategy
);
};
if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) {
query = vectorSimilarityQuery.getInnerKnnQuery();

View File

@ -38,6 +38,7 @@ import org.elasticsearch.index.SearchSlowLog;
import org.elasticsearch.index.cache.bitset.BitsetFilterCache;
import org.elasticsearch.index.engine.EngineConfig;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.indices.IndicesRequestCache;
import org.elasticsearch.indices.IndicesService;
@ -529,7 +530,8 @@ public class TransportResumeFollowAction extends AcknowledgedTransportMasterNode
EngineConfig.INDEX_CODEC_SETTING,
DataTier.TIER_PREFERENCE_SETTING,
IndexSettings.BLOOM_FILTER_ID_FIELD_ENABLED_SETTING,
MetadataIndexStateService.VERIFIED_READ_ONLY_SETTING
MetadataIndexStateService.VERIFIED_READ_ONLY_SETTING,
DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC
);
public static Settings filter(Settings originalSettings) {