Add filtering for kNN vector indexer test scenarios (#130751)
* Add filtering for kNN vector indexer test scenarios * [CI] Auto commit changes from spotless --------- Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
This commit is contained in:
parent
2f7527722a
commit
c7c5d4af74
|
@ -46,6 +46,8 @@ record CmdLineArgs(
|
|||
int indexThreads,
|
||||
boolean reindex,
|
||||
boolean forceMerge,
|
||||
float filterSelectivity,
|
||||
long seed,
|
||||
VectorSimilarityFunction vectorSpace,
|
||||
int quantizeBits,
|
||||
VectorEncoding vectorEncoding,
|
||||
|
@ -75,6 +77,8 @@ record CmdLineArgs(
|
|||
static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding");
|
||||
static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions");
|
||||
static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination");
|
||||
static final ParseField FILTER_SELECTIVITY_FIELD = new ParseField("filter_selectivity");
|
||||
static final ParseField SEED_FIELD = new ParseField("seed");
|
||||
|
||||
static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
|
||||
Builder builder = PARSER.apply(parser, null);
|
||||
|
@ -106,6 +110,8 @@ record CmdLineArgs(
|
|||
PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD);
|
||||
PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD);
|
||||
PARSER.declareBoolean(Builder::setEarlyTermination, EARLY_TERMINATION_FIELD);
|
||||
PARSER.declareFloat(Builder::setFilterSelectivity, FILTER_SELECTIVITY_FIELD);
|
||||
PARSER.declareLong(Builder::setSeed, SEED_FIELD);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -136,6 +142,9 @@ record CmdLineArgs(
|
|||
builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits);
|
||||
builder.field(VECTOR_ENCODING_FIELD.getPreferredName(), vectorEncoding.name().toLowerCase(Locale.ROOT));
|
||||
builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions);
|
||||
builder.field(EARLY_TERMINATION_FIELD.getPreferredName(), earlyTermination);
|
||||
builder.field(FILTER_SELECTIVITY_FIELD.getPreferredName(), filterSelectivity);
|
||||
builder.field(SEED_FIELD.getPreferredName(), seed);
|
||||
return builder.endObject();
|
||||
}
|
||||
|
||||
|
@ -167,6 +176,8 @@ record CmdLineArgs(
|
|||
private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
|
||||
private int dimensions;
|
||||
private boolean earlyTermination;
|
||||
private float filterSelectivity = 1f;
|
||||
private long seed = 1751900822751L;
|
||||
|
||||
public Builder setDocVectors(String docVectors) {
|
||||
this.docVectors = PathUtils.get(docVectors);
|
||||
|
@ -278,6 +289,16 @@ record CmdLineArgs(
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setFilterSelectivity(float filterSelectivity) {
|
||||
this.filterSelectivity = filterSelectivity;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setSeed(long seed) {
|
||||
this.seed = seed;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CmdLineArgs build() {
|
||||
if (docVectors == null) {
|
||||
throw new IllegalArgumentException("Document vectors path must be provided");
|
||||
|
@ -305,6 +326,8 @@ record CmdLineArgs(
|
|||
indexThreads,
|
||||
reindex,
|
||||
forceMerge,
|
||||
filterSelectivity,
|
||||
seed,
|
||||
vectorSpace,
|
||||
quantizeBits,
|
||||
vectorEncoding,
|
||||
|
|
|
@ -178,10 +178,20 @@ public class KnnIndexTester {
|
|||
? cmdLineArgs.nProbes()
|
||||
: new int[] { 0 };
|
||||
String indexType = cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT);
|
||||
Results indexResults = new Results(cmdLineArgs.docVectors().getFileName().toString(), indexType, cmdLineArgs.numDocs());
|
||||
Results indexResults = new Results(
|
||||
cmdLineArgs.docVectors().getFileName().toString(),
|
||||
indexType,
|
||||
cmdLineArgs.numDocs(),
|
||||
cmdLineArgs.filterSelectivity()
|
||||
);
|
||||
Results[] results = new Results[nProbes.length];
|
||||
for (int i = 0; i < nProbes.length; i++) {
|
||||
results[i] = new Results(cmdLineArgs.docVectors().getFileName().toString(), indexType, cmdLineArgs.numDocs());
|
||||
results[i] = new Results(
|
||||
cmdLineArgs.docVectors().getFileName().toString(),
|
||||
indexType,
|
||||
cmdLineArgs.numDocs(),
|
||||
cmdLineArgs.filterSelectivity()
|
||||
);
|
||||
}
|
||||
logger.info("Running KNN index tester with arguments: " + cmdLineArgs);
|
||||
Codec codec = createCodec(cmdLineArgs);
|
||||
|
@ -244,7 +254,8 @@ public class KnnIndexTester {
|
|||
"avg_cpu_count",
|
||||
"QPS",
|
||||
"recall",
|
||||
"visited" };
|
||||
"visited",
|
||||
"filter_selectivity" };
|
||||
|
||||
// Calculate appropriate column widths based on headers and data
|
||||
|
||||
|
@ -274,7 +285,8 @@ public class KnnIndexTester {
|
|||
String.format(Locale.ROOT, "%.2f", queryResult.avgCpuCount),
|
||||
String.format(Locale.ROOT, "%.2f", queryResult.qps),
|
||||
String.format(Locale.ROOT, "%.2f", queryResult.avgRecall),
|
||||
String.format(Locale.ROOT, "%.2f", queryResult.averageVisited) };
|
||||
String.format(Locale.ROOT, "%.2f", queryResult.averageVisited),
|
||||
String.format(Locale.ROOT, "%.2f", queryResult.filterSelectivity), };
|
||||
}
|
||||
|
||||
printBlock(sb, searchHeaders, queryResultsArray);
|
||||
|
@ -339,6 +351,7 @@ public class KnnIndexTester {
|
|||
static class Results {
|
||||
final String indexType, indexName;
|
||||
final int numDocs;
|
||||
final float filterSelectivity;
|
||||
long indexTimeMS;
|
||||
long forceMergeTimeMS;
|
||||
int numSegments;
|
||||
|
@ -350,10 +363,11 @@ public class KnnIndexTester {
|
|||
double netCpuTimeMS;
|
||||
double avgCpuCount;
|
||||
|
||||
Results(String indexName, String indexType, int numDocs) {
|
||||
Results(String indexName, String indexType, int numDocs, float filterSelectivity) {
|
||||
this.indexName = indexName;
|
||||
this.indexType = indexType;
|
||||
this.numDocs = numDocs;
|
||||
this.filterSelectivity = filterSelectivity;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ package org.elasticsearch.test.knn;
|
|||
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.StoredFields;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
@ -32,17 +33,29 @@ import org.apache.lucene.queries.function.valuesource.ConstKnnByteVectorValueSou
|
|||
import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource;
|
||||
import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource;
|
||||
import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
|
||||
import org.apache.lucene.search.BooleanClause;
|
||||
import org.apache.lucene.search.BooleanQuery;
|
||||
import org.apache.lucene.search.ConstantScoreScorer;
|
||||
import org.apache.lucene.search.ConstantScoreWeight;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.KnnByteVectorQuery;
|
||||
import org.apache.lucene.search.KnnFloatVectorQuery;
|
||||
import org.apache.lucene.search.PatienceKnnVectorQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.QueryVisitor;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.ScoreMode;
|
||||
import org.apache.lucene.search.Scorer;
|
||||
import org.apache.lucene.search.ScorerSupplier;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.apache.lucene.search.Weight;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.store.FSDirectory;
|
||||
import org.apache.lucene.store.MMapDirectory;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.BitSetIterator;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
import org.elasticsearch.common.io.Channels;
|
||||
import org.elasticsearch.core.PathUtils;
|
||||
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||
|
@ -64,9 +77,11 @@ import java.nio.file.Files;
|
|||
import java.nio.file.Path;
|
||||
import java.nio.file.attribute.FileTime;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
|
@ -88,8 +103,8 @@ class KnnSearcher {
|
|||
private final Path queryPath;
|
||||
private final int numDocs;
|
||||
private final int numQueryVectors;
|
||||
private final long randomSeed = 42;
|
||||
private final float selectivity = 1f;
|
||||
private final long randomSeed;
|
||||
private final float selectivity;
|
||||
private final int topK;
|
||||
private final int efSearch;
|
||||
private final int nProbe;
|
||||
|
@ -120,9 +135,12 @@ class KnnSearcher {
|
|||
this.indexType = cmdLineArgs.indexType();
|
||||
this.searchThreads = cmdLineArgs.searchThreads();
|
||||
this.numSearchers = cmdLineArgs.numSearchers();
|
||||
this.randomSeed = cmdLineArgs.seed();
|
||||
this.selectivity = cmdLineArgs.filterSelectivity();
|
||||
}
|
||||
|
||||
void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException {
|
||||
Query filterQuery = this.selectivity < 1f ? generateRandomQuery(new Random(randomSeed), indexPath, numDocs, selectivity) : null;
|
||||
TopDocs[] results = new TopDocs[numQueryVectors];
|
||||
int[][] resultIds = new int[numQueryVectors][];
|
||||
long elapsed, totalCpuTimeMS, totalVisited = 0;
|
||||
|
@ -164,10 +182,10 @@ class KnnSearcher {
|
|||
for (int i = 0; i < numQueryVectors; i++) {
|
||||
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
|
||||
targetReader.next(targetBytes);
|
||||
doVectorQuery(targetBytes, searcher, earlyTermination);
|
||||
doVectorQuery(targetBytes, searcher, filterQuery, earlyTermination);
|
||||
} else {
|
||||
targetReader.next(target);
|
||||
doVectorQuery(target, searcher, earlyTermination);
|
||||
doVectorQuery(target, searcher, filterQuery, earlyTermination);
|
||||
}
|
||||
}
|
||||
targetReader.reset();
|
||||
|
@ -180,7 +198,7 @@ class KnnSearcher {
|
|||
for (int s = 0; s < numSearchers; s++) {
|
||||
queryConsumers[s] = i -> {
|
||||
try {
|
||||
results[i] = doVectorQuery(queries[i], searcher, earlyTermination);
|
||||
results[i] = doVectorQuery(queries[i], searcher, filterQuery, earlyTermination);
|
||||
} catch (IOException e) {
|
||||
throw new UncheckedIOException(e);
|
||||
}
|
||||
|
@ -194,7 +212,7 @@ class KnnSearcher {
|
|||
for (int s = 0; s < numSearchers; s++) {
|
||||
queryConsumers[s] = i -> {
|
||||
try {
|
||||
results[i] = doVectorQuery(queries[i], searcher, earlyTermination);
|
||||
results[i] = doVectorQuery(queries[i], searcher, filterQuery, earlyTermination);
|
||||
} catch (IOException e) {
|
||||
throw new UncheckedIOException(e);
|
||||
}
|
||||
|
@ -274,7 +292,7 @@ class KnnSearcher {
|
|||
}
|
||||
}
|
||||
logger.info("checking results");
|
||||
int[][] nn = getOrCalculateExactNN(offsetByteSize);
|
||||
int[][] nn = getOrCalculateExactNN(offsetByteSize, filterQuery);
|
||||
finalResults.nProbe = indexType == KnnIndexTester.IndexType.IVF ? nProbe : 0;
|
||||
finalResults.avgRecall = checkResults(resultIds, nn, topK);
|
||||
finalResults.qps = (1000f * numQueryVectors) / elapsed;
|
||||
|
@ -284,7 +302,34 @@ class KnnSearcher {
|
|||
finalResults.avgCpuCount = (double) totalCpuTimeMS / elapsed;
|
||||
}
|
||||
|
||||
private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes) throws IOException {
|
||||
private static Query generateRandomQuery(Random random, Path indexPath, int size, float selectivity) throws IOException {
|
||||
FixedBitSet bitSet = new FixedBitSet(size);
|
||||
for (int i = 0; i < size; i++) {
|
||||
if (random.nextFloat() < selectivity) {
|
||||
bitSet.set(i);
|
||||
} else {
|
||||
bitSet.clear(i);
|
||||
}
|
||||
}
|
||||
|
||||
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
|
||||
BitSet[] segmentDocs = new BitSet[reader.leaves().size()];
|
||||
for (var leafContext : reader.leaves()) {
|
||||
var storedFields = leafContext.reader().storedFields();
|
||||
FixedBitSet segmentBitSet = new FixedBitSet(reader.maxDoc());
|
||||
for (int d = 0; d < leafContext.reader().maxDoc(); d++) {
|
||||
int docID = Integer.parseInt(storedFields.document(d, Set.of(ID_FIELD)).get(ID_FIELD));
|
||||
if (bitSet.get(docID)) {
|
||||
segmentBitSet.set(d);
|
||||
}
|
||||
}
|
||||
segmentDocs[leafContext.ord] = segmentBitSet;
|
||||
}
|
||||
return new BitSetQuery(segmentDocs);
|
||||
}
|
||||
}
|
||||
|
||||
private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes, Query filterQuery) throws IOException {
|
||||
// look in working directory for cached nn file
|
||||
String hash = Integer.toString(
|
||||
Objects.hash(
|
||||
|
@ -312,9 +357,9 @@ class KnnSearcher {
|
|||
// checking low-precision recall
|
||||
int[][] nn;
|
||||
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
|
||||
nn = computeExactNNByte(queryPath, vectorFileOffsetBytes);
|
||||
nn = computeExactNNByte(queryPath, filterQuery, vectorFileOffsetBytes);
|
||||
} else {
|
||||
nn = computeExactNN(queryPath, vectorFileOffsetBytes);
|
||||
nn = computeExactNN(queryPath, filterQuery, vectorFileOffsetBytes);
|
||||
}
|
||||
writeExactNN(nn, nnPath);
|
||||
long elapsedMS = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); // ns -> ms
|
||||
|
@ -333,7 +378,7 @@ class KnnSearcher {
|
|||
return true;
|
||||
}
|
||||
|
||||
TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, boolean earlyTermination) throws IOException {
|
||||
TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, Query filterQuery, boolean earlyTermination) throws IOException {
|
||||
Query knnQuery;
|
||||
if (overSamplingFactor > 1f) {
|
||||
throw new IllegalArgumentException("oversampling factor > 1 is not supported for byte vectors");
|
||||
|
@ -346,7 +391,7 @@ class KnnSearcher {
|
|||
vector,
|
||||
topK,
|
||||
efSearch,
|
||||
null,
|
||||
filterQuery,
|
||||
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
|
||||
);
|
||||
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
|
||||
|
@ -360,7 +405,7 @@ class KnnSearcher {
|
|||
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
|
||||
}
|
||||
|
||||
TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, boolean earlyTermination) throws IOException {
|
||||
TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, Query filterQuery, boolean earlyTermination) throws IOException {
|
||||
Query knnQuery;
|
||||
int topK = this.topK;
|
||||
if (overSamplingFactor > 1f) {
|
||||
|
@ -369,14 +414,14 @@ class KnnSearcher {
|
|||
}
|
||||
int efSearch = Math.max(topK, this.efSearch);
|
||||
if (indexType == KnnIndexTester.IndexType.IVF) {
|
||||
knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, null, nProbe);
|
||||
knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, filterQuery, nProbe);
|
||||
} else {
|
||||
knnQuery = new ESKnnFloatVectorQuery(
|
||||
VECTOR_FIELD,
|
||||
vector,
|
||||
topK,
|
||||
efSearch,
|
||||
null,
|
||||
filterQuery,
|
||||
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
|
||||
);
|
||||
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
|
||||
|
@ -449,7 +494,7 @@ class KnnSearcher {
|
|||
}
|
||||
}
|
||||
|
||||
private int[][] computeExactNN(Path queryPath, int vectorFileOffsetBytes) throws IOException {
|
||||
private int[][] computeExactNN(Path queryPath, Query filterQuery, int vectorFileOffsetBytes) throws IOException {
|
||||
int[][] result = new int[numQueryVectors][];
|
||||
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
|
||||
List<Callable<Void>> tasks = new ArrayList<>();
|
||||
|
@ -463,7 +508,7 @@ class KnnSearcher {
|
|||
for (int i = 0; i < numQueryVectors; i++) {
|
||||
float[] queryVector = new float[dim];
|
||||
queryReader.next(queryVector);
|
||||
tasks.add(new ComputeNNFloatTask(i, topK, queryVector, result, reader, similarityFunction));
|
||||
tasks.add(new ComputeNNFloatTask(i, topK, queryVector, result, reader, filterQuery, similarityFunction));
|
||||
}
|
||||
ForkJoinPool.commonPool().invokeAll(tasks);
|
||||
}
|
||||
|
@ -471,7 +516,7 @@ class KnnSearcher {
|
|||
}
|
||||
}
|
||||
|
||||
private int[][] computeExactNNByte(Path queryPath, int vectorFileOffsetBytes) throws IOException {
|
||||
private int[][] computeExactNNByte(Path queryPath, Query filterQuery, int vectorFileOffsetBytes) throws IOException {
|
||||
int[][] result = new int[numQueryVectors][];
|
||||
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
|
||||
List<Callable<Void>> tasks = new ArrayList<>();
|
||||
|
@ -480,7 +525,7 @@ class KnnSearcher {
|
|||
for (int i = 0; i < numQueryVectors; i++) {
|
||||
byte[] queryVector = new byte[dim];
|
||||
queryReader.next(queryVector);
|
||||
tasks.add(new ComputeNNByteTask(i, queryVector, result, reader, similarityFunction));
|
||||
tasks.add(new ComputeNNByteTask(i, queryVector, result, reader, filterQuery, similarityFunction));
|
||||
}
|
||||
ForkJoinPool.commonPool().invokeAll(tasks);
|
||||
}
|
||||
|
@ -495,6 +540,7 @@ class KnnSearcher {
|
|||
private final int[][] result;
|
||||
private final IndexReader reader;
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
private final Query filterQuery;
|
||||
private final int topK;
|
||||
|
||||
ComputeNNFloatTask(
|
||||
|
@ -503,6 +549,7 @@ class KnnSearcher {
|
|||
float[] query,
|
||||
int[][] result,
|
||||
IndexReader reader,
|
||||
Query filterQuery,
|
||||
VectorSimilarityFunction similarityFunction
|
||||
) {
|
||||
this.queryOrd = queryOrd;
|
||||
|
@ -510,6 +557,7 @@ class KnnSearcher {
|
|||
this.result = result;
|
||||
this.reader = reader;
|
||||
this.similarityFunction = similarityFunction;
|
||||
this.filterQuery = filterQuery;
|
||||
this.topK = topK;
|
||||
}
|
||||
|
||||
|
@ -520,6 +568,11 @@ class KnnSearcher {
|
|||
var queryVector = new ConstKnnFloatValueSource(query);
|
||||
var docVectors = new FloatKnnVectorFieldSource(VECTOR_FIELD);
|
||||
Query query = new FunctionQuery(new FloatVectorSimilarityFunction(similarityFunction, queryVector, docVectors));
|
||||
if (filterQuery != null) {
|
||||
query = new BooleanQuery.Builder().add(query, BooleanClause.Occur.SHOULD)
|
||||
.add(filterQuery, BooleanClause.Occur.FILTER)
|
||||
.build();
|
||||
}
|
||||
var topDocs = searcher.search(query, topK);
|
||||
result[queryOrd] = getResultIds(topDocs, reader.storedFields());
|
||||
if ((queryOrd + 1) % 10 == 0) {
|
||||
|
@ -538,13 +591,22 @@ class KnnSearcher {
|
|||
private final byte[] query;
|
||||
private final int[][] result;
|
||||
private final IndexReader reader;
|
||||
private final Query filterQuery;
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
|
||||
ComputeNNByteTask(int queryOrd, byte[] query, int[][] result, IndexReader reader, VectorSimilarityFunction similarityFunction) {
|
||||
ComputeNNByteTask(
|
||||
int queryOrd,
|
||||
byte[] query,
|
||||
int[][] result,
|
||||
IndexReader reader,
|
||||
Query filterQuery,
|
||||
VectorSimilarityFunction similarityFunction
|
||||
) {
|
||||
this.queryOrd = queryOrd;
|
||||
this.query = query;
|
||||
this.result = result;
|
||||
this.reader = reader;
|
||||
this.filterQuery = filterQuery;
|
||||
this.similarityFunction = similarityFunction;
|
||||
}
|
||||
|
||||
|
@ -556,6 +618,11 @@ class KnnSearcher {
|
|||
var queryVector = new ConstKnnByteVectorValueSource(query);
|
||||
var docVectors = new ByteKnnVectorFieldSource(VECTOR_FIELD);
|
||||
Query query = new FunctionQuery(new ByteVectorSimilarityFunction(similarityFunction, queryVector, docVectors));
|
||||
if (filterQuery != null) {
|
||||
query = new BooleanQuery.Builder().add(query, BooleanClause.Occur.SHOULD)
|
||||
.add(filterQuery, BooleanClause.Occur.FILTER)
|
||||
.build();
|
||||
}
|
||||
var topDocs = searcher.search(query, topK);
|
||||
result[queryOrd] = getResultIds(topDocs, reader.storedFields());
|
||||
if ((queryOrd + 1) % 10 == 0) {
|
||||
|
@ -582,4 +649,57 @@ class KnnSearcher {
|
|||
return resultIds;
|
||||
}
|
||||
|
||||
private static class BitSetQuery extends Query {
|
||||
private final BitSet[] segmentDocs;
|
||||
|
||||
BitSetQuery(BitSet[] segmentDocs) {
|
||||
this.segmentDocs = segmentDocs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
|
||||
return new ConstantScoreWeight(this, boost) {
|
||||
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
|
||||
var bitSet = segmentDocs[context.ord];
|
||||
var cardinality = bitSet.cardinality();
|
||||
var scorer = new ConstantScoreScorer(score(), scoreMode, new BitSetIterator(bitSet, cardinality));
|
||||
return new ScorerSupplier() {
|
||||
@Override
|
||||
public Scorer get(long leadCost) throws IOException {
|
||||
return scorer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return cardinality;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isCacheable(LeafReaderContext ctx) {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(QueryVisitor visitor) {}
|
||||
|
||||
@Override
|
||||
public String toString(String field) {
|
||||
return "BitSetQuery";
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object other) {
|
||||
return sameClassAs(other) && Arrays.equals(segmentDocs, ((BitSetQuery) other).segmentDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return 31 * classHash() + Arrays.hashCode(segmentDocs);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue