Add filtering for kNN vector indexer test scenarios (#130751)

* Add filtering for kNN vector indexer test scenarios

* [CI] Auto commit changes from spotless

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
This commit is contained in:
Benjamin Trent 2025-07-08 10:27:42 -04:00 committed by GitHub
parent 2f7527722a
commit c7c5d4af74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 182 additions and 25 deletions

View File

@ -46,6 +46,8 @@ record CmdLineArgs(
int indexThreads,
boolean reindex,
boolean forceMerge,
float filterSelectivity,
long seed,
VectorSimilarityFunction vectorSpace,
int quantizeBits,
VectorEncoding vectorEncoding,
@ -75,6 +77,8 @@ record CmdLineArgs(
static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding");
static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions");
static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination");
static final ParseField FILTER_SELECTIVITY_FIELD = new ParseField("filter_selectivity");
static final ParseField SEED_FIELD = new ParseField("seed");
static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
Builder builder = PARSER.apply(parser, null);
@ -106,6 +110,8 @@ record CmdLineArgs(
PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD);
PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD);
PARSER.declareBoolean(Builder::setEarlyTermination, EARLY_TERMINATION_FIELD);
PARSER.declareFloat(Builder::setFilterSelectivity, FILTER_SELECTIVITY_FIELD);
PARSER.declareLong(Builder::setSeed, SEED_FIELD);
}
@Override
@ -136,6 +142,9 @@ record CmdLineArgs(
builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits);
builder.field(VECTOR_ENCODING_FIELD.getPreferredName(), vectorEncoding.name().toLowerCase(Locale.ROOT));
builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions);
builder.field(EARLY_TERMINATION_FIELD.getPreferredName(), earlyTermination);
builder.field(FILTER_SELECTIVITY_FIELD.getPreferredName(), filterSelectivity);
builder.field(SEED_FIELD.getPreferredName(), seed);
return builder.endObject();
}
@ -167,6 +176,8 @@ record CmdLineArgs(
private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
private int dimensions;
private boolean earlyTermination;
private float filterSelectivity = 1f;
private long seed = 1751900822751L;
public Builder setDocVectors(String docVectors) {
this.docVectors = PathUtils.get(docVectors);
@ -278,6 +289,16 @@ record CmdLineArgs(
return this;
}
public Builder setFilterSelectivity(float filterSelectivity) {
this.filterSelectivity = filterSelectivity;
return this;
}
public Builder setSeed(long seed) {
this.seed = seed;
return this;
}
public CmdLineArgs build() {
if (docVectors == null) {
throw new IllegalArgumentException("Document vectors path must be provided");
@ -305,6 +326,8 @@ record CmdLineArgs(
indexThreads,
reindex,
forceMerge,
filterSelectivity,
seed,
vectorSpace,
quantizeBits,
vectorEncoding,

View File

@ -178,10 +178,20 @@ public class KnnIndexTester {
? cmdLineArgs.nProbes()
: new int[] { 0 };
String indexType = cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT);
Results indexResults = new Results(cmdLineArgs.docVectors().getFileName().toString(), indexType, cmdLineArgs.numDocs());
Results indexResults = new Results(
cmdLineArgs.docVectors().getFileName().toString(),
indexType,
cmdLineArgs.numDocs(),
cmdLineArgs.filterSelectivity()
);
Results[] results = new Results[nProbes.length];
for (int i = 0; i < nProbes.length; i++) {
results[i] = new Results(cmdLineArgs.docVectors().getFileName().toString(), indexType, cmdLineArgs.numDocs());
results[i] = new Results(
cmdLineArgs.docVectors().getFileName().toString(),
indexType,
cmdLineArgs.numDocs(),
cmdLineArgs.filterSelectivity()
);
}
logger.info("Running KNN index tester with arguments: " + cmdLineArgs);
Codec codec = createCodec(cmdLineArgs);
@ -244,7 +254,8 @@ public class KnnIndexTester {
"avg_cpu_count",
"QPS",
"recall",
"visited" };
"visited",
"filter_selectivity" };
// Calculate appropriate column widths based on headers and data
@ -274,7 +285,8 @@ public class KnnIndexTester {
String.format(Locale.ROOT, "%.2f", queryResult.avgCpuCount),
String.format(Locale.ROOT, "%.2f", queryResult.qps),
String.format(Locale.ROOT, "%.2f", queryResult.avgRecall),
String.format(Locale.ROOT, "%.2f", queryResult.averageVisited) };
String.format(Locale.ROOT, "%.2f", queryResult.averageVisited),
String.format(Locale.ROOT, "%.2f", queryResult.filterSelectivity), };
}
printBlock(sb, searchHeaders, queryResultsArray);
@ -339,6 +351,7 @@ public class KnnIndexTester {
static class Results {
final String indexType, indexName;
final int numDocs;
final float filterSelectivity;
long indexTimeMS;
long forceMergeTimeMS;
int numSegments;
@ -350,10 +363,11 @@ public class KnnIndexTester {
double netCpuTimeMS;
double avgCpuCount;
Results(String indexName, String indexType, int numDocs) {
Results(String indexName, String indexType, int numDocs, float filterSelectivity) {
this.indexName = indexName;
this.indexType = indexType;
this.numDocs = numDocs;
this.filterSelectivity = filterSelectivity;
}
}

View File

@ -22,6 +22,7 @@ package org.elasticsearch.test.knn;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
@ -32,17 +33,29 @@ import org.apache.lucene.queries.function.valuesource.ConstKnnByteVectorValueSou
import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource;
import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource;
import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.PatienceKnnVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.FixedBitSet;
import org.elasticsearch.common.io.Channels;
import org.elasticsearch.core.PathUtils;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
@ -64,9 +77,11 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.FileTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
@ -88,8 +103,8 @@ class KnnSearcher {
private final Path queryPath;
private final int numDocs;
private final int numQueryVectors;
private final long randomSeed = 42;
private final float selectivity = 1f;
private final long randomSeed;
private final float selectivity;
private final int topK;
private final int efSearch;
private final int nProbe;
@ -120,9 +135,12 @@ class KnnSearcher {
this.indexType = cmdLineArgs.indexType();
this.searchThreads = cmdLineArgs.searchThreads();
this.numSearchers = cmdLineArgs.numSearchers();
this.randomSeed = cmdLineArgs.seed();
this.selectivity = cmdLineArgs.filterSelectivity();
}
void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException {
Query filterQuery = this.selectivity < 1f ? generateRandomQuery(new Random(randomSeed), indexPath, numDocs, selectivity) : null;
TopDocs[] results = new TopDocs[numQueryVectors];
int[][] resultIds = new int[numQueryVectors][];
long elapsed, totalCpuTimeMS, totalVisited = 0;
@ -164,10 +182,10 @@ class KnnSearcher {
for (int i = 0; i < numQueryVectors; i++) {
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
targetReader.next(targetBytes);
doVectorQuery(targetBytes, searcher, earlyTermination);
doVectorQuery(targetBytes, searcher, filterQuery, earlyTermination);
} else {
targetReader.next(target);
doVectorQuery(target, searcher, earlyTermination);
doVectorQuery(target, searcher, filterQuery, earlyTermination);
}
}
targetReader.reset();
@ -180,7 +198,7 @@ class KnnSearcher {
for (int s = 0; s < numSearchers; s++) {
queryConsumers[s] = i -> {
try {
results[i] = doVectorQuery(queries[i], searcher, earlyTermination);
results[i] = doVectorQuery(queries[i], searcher, filterQuery, earlyTermination);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
@ -194,7 +212,7 @@ class KnnSearcher {
for (int s = 0; s < numSearchers; s++) {
queryConsumers[s] = i -> {
try {
results[i] = doVectorQuery(queries[i], searcher, earlyTermination);
results[i] = doVectorQuery(queries[i], searcher, filterQuery, earlyTermination);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
@ -274,7 +292,7 @@ class KnnSearcher {
}
}
logger.info("checking results");
int[][] nn = getOrCalculateExactNN(offsetByteSize);
int[][] nn = getOrCalculateExactNN(offsetByteSize, filterQuery);
finalResults.nProbe = indexType == KnnIndexTester.IndexType.IVF ? nProbe : 0;
finalResults.avgRecall = checkResults(resultIds, nn, topK);
finalResults.qps = (1000f * numQueryVectors) / elapsed;
@ -284,7 +302,34 @@ class KnnSearcher {
finalResults.avgCpuCount = (double) totalCpuTimeMS / elapsed;
}
private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes) throws IOException {
private static Query generateRandomQuery(Random random, Path indexPath, int size, float selectivity) throws IOException {
FixedBitSet bitSet = new FixedBitSet(size);
for (int i = 0; i < size; i++) {
if (random.nextFloat() < selectivity) {
bitSet.set(i);
} else {
bitSet.clear(i);
}
}
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
BitSet[] segmentDocs = new BitSet[reader.leaves().size()];
for (var leafContext : reader.leaves()) {
var storedFields = leafContext.reader().storedFields();
FixedBitSet segmentBitSet = new FixedBitSet(reader.maxDoc());
for (int d = 0; d < leafContext.reader().maxDoc(); d++) {
int docID = Integer.parseInt(storedFields.document(d, Set.of(ID_FIELD)).get(ID_FIELD));
if (bitSet.get(docID)) {
segmentBitSet.set(d);
}
}
segmentDocs[leafContext.ord] = segmentBitSet;
}
return new BitSetQuery(segmentDocs);
}
}
private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes, Query filterQuery) throws IOException {
// look in working directory for cached nn file
String hash = Integer.toString(
Objects.hash(
@ -312,9 +357,9 @@ class KnnSearcher {
// checking low-precision recall
int[][] nn;
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
nn = computeExactNNByte(queryPath, vectorFileOffsetBytes);
nn = computeExactNNByte(queryPath, filterQuery, vectorFileOffsetBytes);
} else {
nn = computeExactNN(queryPath, vectorFileOffsetBytes);
nn = computeExactNN(queryPath, filterQuery, vectorFileOffsetBytes);
}
writeExactNN(nn, nnPath);
long elapsedMS = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); // ns -> ms
@ -333,7 +378,7 @@ class KnnSearcher {
return true;
}
TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, boolean earlyTermination) throws IOException {
TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, Query filterQuery, boolean earlyTermination) throws IOException {
Query knnQuery;
if (overSamplingFactor > 1f) {
throw new IllegalArgumentException("oversampling factor > 1 is not supported for byte vectors");
@ -346,7 +391,7 @@ class KnnSearcher {
vector,
topK,
efSearch,
null,
filterQuery,
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
);
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
@ -360,7 +405,7 @@ class KnnSearcher {
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
}
TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, boolean earlyTermination) throws IOException {
TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, Query filterQuery, boolean earlyTermination) throws IOException {
Query knnQuery;
int topK = this.topK;
if (overSamplingFactor > 1f) {
@ -369,14 +414,14 @@ class KnnSearcher {
}
int efSearch = Math.max(topK, this.efSearch);
if (indexType == KnnIndexTester.IndexType.IVF) {
knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, null, nProbe);
knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, filterQuery, nProbe);
} else {
knnQuery = new ESKnnFloatVectorQuery(
VECTOR_FIELD,
vector,
topK,
efSearch,
null,
filterQuery,
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
);
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
@ -449,7 +494,7 @@ class KnnSearcher {
}
}
private int[][] computeExactNN(Path queryPath, int vectorFileOffsetBytes) throws IOException {
private int[][] computeExactNN(Path queryPath, Query filterQuery, int vectorFileOffsetBytes) throws IOException {
int[][] result = new int[numQueryVectors][];
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
List<Callable<Void>> tasks = new ArrayList<>();
@ -463,7 +508,7 @@ class KnnSearcher {
for (int i = 0; i < numQueryVectors; i++) {
float[] queryVector = new float[dim];
queryReader.next(queryVector);
tasks.add(new ComputeNNFloatTask(i, topK, queryVector, result, reader, similarityFunction));
tasks.add(new ComputeNNFloatTask(i, topK, queryVector, result, reader, filterQuery, similarityFunction));
}
ForkJoinPool.commonPool().invokeAll(tasks);
}
@ -471,7 +516,7 @@ class KnnSearcher {
}
}
private int[][] computeExactNNByte(Path queryPath, int vectorFileOffsetBytes) throws IOException {
private int[][] computeExactNNByte(Path queryPath, Query filterQuery, int vectorFileOffsetBytes) throws IOException {
int[][] result = new int[numQueryVectors][];
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
List<Callable<Void>> tasks = new ArrayList<>();
@ -480,7 +525,7 @@ class KnnSearcher {
for (int i = 0; i < numQueryVectors; i++) {
byte[] queryVector = new byte[dim];
queryReader.next(queryVector);
tasks.add(new ComputeNNByteTask(i, queryVector, result, reader, similarityFunction));
tasks.add(new ComputeNNByteTask(i, queryVector, result, reader, filterQuery, similarityFunction));
}
ForkJoinPool.commonPool().invokeAll(tasks);
}
@ -495,6 +540,7 @@ class KnnSearcher {
private final int[][] result;
private final IndexReader reader;
private final VectorSimilarityFunction similarityFunction;
private final Query filterQuery;
private final int topK;
ComputeNNFloatTask(
@ -503,6 +549,7 @@ class KnnSearcher {
float[] query,
int[][] result,
IndexReader reader,
Query filterQuery,
VectorSimilarityFunction similarityFunction
) {
this.queryOrd = queryOrd;
@ -510,6 +557,7 @@ class KnnSearcher {
this.result = result;
this.reader = reader;
this.similarityFunction = similarityFunction;
this.filterQuery = filterQuery;
this.topK = topK;
}
@ -520,6 +568,11 @@ class KnnSearcher {
var queryVector = new ConstKnnFloatValueSource(query);
var docVectors = new FloatKnnVectorFieldSource(VECTOR_FIELD);
Query query = new FunctionQuery(new FloatVectorSimilarityFunction(similarityFunction, queryVector, docVectors));
if (filterQuery != null) {
query = new BooleanQuery.Builder().add(query, BooleanClause.Occur.SHOULD)
.add(filterQuery, BooleanClause.Occur.FILTER)
.build();
}
var topDocs = searcher.search(query, topK);
result[queryOrd] = getResultIds(topDocs, reader.storedFields());
if ((queryOrd + 1) % 10 == 0) {
@ -538,13 +591,22 @@ class KnnSearcher {
private final byte[] query;
private final int[][] result;
private final IndexReader reader;
private final Query filterQuery;
private final VectorSimilarityFunction similarityFunction;
ComputeNNByteTask(int queryOrd, byte[] query, int[][] result, IndexReader reader, VectorSimilarityFunction similarityFunction) {
ComputeNNByteTask(
int queryOrd,
byte[] query,
int[][] result,
IndexReader reader,
Query filterQuery,
VectorSimilarityFunction similarityFunction
) {
this.queryOrd = queryOrd;
this.query = query;
this.result = result;
this.reader = reader;
this.filterQuery = filterQuery;
this.similarityFunction = similarityFunction;
}
@ -556,6 +618,11 @@ class KnnSearcher {
var queryVector = new ConstKnnByteVectorValueSource(query);
var docVectors = new ByteKnnVectorFieldSource(VECTOR_FIELD);
Query query = new FunctionQuery(new ByteVectorSimilarityFunction(similarityFunction, queryVector, docVectors));
if (filterQuery != null) {
query = new BooleanQuery.Builder().add(query, BooleanClause.Occur.SHOULD)
.add(filterQuery, BooleanClause.Occur.FILTER)
.build();
}
var topDocs = searcher.search(query, topK);
result[queryOrd] = getResultIds(topDocs, reader.storedFields());
if ((queryOrd + 1) % 10 == 0) {
@ -582,4 +649,57 @@ class KnnSearcher {
return resultIds;
}
private static class BitSetQuery extends Query {
private final BitSet[] segmentDocs;
BitSetQuery(BitSet[] segmentDocs) {
this.segmentDocs = segmentDocs;
}
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
return new ConstantScoreWeight(this, boost) {
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
var bitSet = segmentDocs[context.ord];
var cardinality = bitSet.cardinality();
var scorer = new ConstantScoreScorer(score(), scoreMode, new BitSetIterator(bitSet, cardinality));
return new ScorerSupplier() {
@Override
public Scorer get(long leadCost) throws IOException {
return scorer;
}
@Override
public long cost() {
return cardinality;
}
};
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return false;
}
};
}
@Override
public void visit(QueryVisitor visitor) {}
@Override
public String toString(String field) {
return "BitSetQuery";
}
@Override
public boolean equals(Object other) {
return sameClassAs(other) && Arrays.equals(segmentDocs, ((BitSetQuery) other).segmentDocs);
}
@Override
public int hashCode() {
return 31 * classHash() + Arrays.hashCode(segmentDocs);
}
}
}