ES|QL - Add scoring for full text functions disjunctions (#121793)

Carlos Delgado 2025-03-11 15:29:15 +01:00 committed by GitHub
parent e11d89d76b
commit 2b40e73fe9
22 changed files with 1619 additions and 691 deletions


@ -0,0 +1,5 @@
pr: 121793
summary: "ES|QL - Add scoring for full text functions disjunctions"
area: ES|QL
type: enhancement
issues: []


@ -117,9 +117,7 @@ In addition, when [querying multiple indexes](docs-content://explore-analyze/que
## Full-text search [esql-limitations-full-text-search]
[preview] {{esql}}'s support for [full-text search](/reference/query-languages/esql/esql-functions-operators.md#esql-search-functions) is currently in Technical Preview. One limitation of full-text search is that it is necessary to use a search function, such as [`MATCH`](/reference/query-languages/esql/esql-functions-operators.md#esql-match), in a [`WHERE`](/reference/query-languages/esql/esql-commands.md#esql-where) command directly after the [`FROM`](/reference/query-languages/esql/esql-commands.md#esql-from) source command, or close enough to it. Otherwise, the query will fail with a validation error.
For example, this query is valid:
@ -136,27 +134,6 @@ FROM books
| WHERE MATCH(author, "Faulkner")
```
And this query that uses a disjunction will succeed:
```esql
FROM books
| WHERE MATCH(author, "Faulkner") OR QSTR("author: Hemingway")
```
Note that, because of [the way {{esql}} treats `text` values](#esql-limitations-text-fields), any queries on `text` fields that do not explicitly use the full-text functions, [`MATCH`](/reference/query-languages/esql/esql-functions-operators.md#esql-match), [`QSTR`](/reference/query-languages/esql/esql-functions-operators.md#esql-qstr) or [`KQL`](/reference/query-languages/esql/esql-functions-operators.md#esql-kql), will behave as if the fields are actually `keyword` fields: they are case-sensitive and need to match the full string.
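For example, assuming the same `books` index used above, a comparison like the following on the `text` field `title` only matches the exact, case-sensitive full string, just as it would for a `keyword` field:
```esql
FROM books
| WHERE title == "Lord of the Rings"
```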


@ -0,0 +1,403 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.lucene;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.elasticsearch.common.CheckedBiConsumer;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.Vector;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Consumer;
/**
* Base class for evaluating a Lucene query at the compute engine and providing a Block as a result.
* Subclasses can override methods to decide what type of {@link Block} should be returned, and how to add results to it
* based on whether documents in the Page match the query or not.
* See {@link LuceneQueryExpressionEvaluator} and {@link LuceneQueryScoreEvaluator} for examples of subclasses that provide
* different types of results for different ESQL constructs.
* It's much faster to push queries to the {@link LuceneSourceOperator} or the like, but sometimes this isn't possible. So
* this class is here to save the day.
*/
public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements Releasable {
public record ShardConfig(Query query, IndexSearcher searcher) {}
private final BlockFactory blockFactory;
private final ShardConfig[] shards;
private final List<ShardState> perShardState;
protected LuceneQueryEvaluator(BlockFactory blockFactory, ShardConfig[] shards) {
this.blockFactory = blockFactory;
this.shards = shards;
this.perShardState = new ArrayList<>(Collections.nCopies(shards.length, null));
}
public Block executeQuery(Page page) {
// Lucene based operators retrieve DocVectors as first block
Block block = page.getBlock(0);
assert block instanceof DocBlock : "LuceneQueryEvaluator expects DocBlock as input";
DocVector docs = (DocVector) block.asVector();
try {
if (docs.singleSegmentNonDecreasing()) {
return evalSingleSegmentNonDecreasing(docs).asBlock();
} else {
return evalSlow(docs).asBlock();
}
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
/**
* Evaluate {@link DocVector#singleSegmentNonDecreasing()} documents.
* <p>
* ESQL receives documents in DocVector, and they can be in one of two
* states. Either the DocVector contains documents from a single segment in
* non-decreasing order, or it doesn't. That first case is much more like
* how Lucene likes to process documents, and it's much more common. So we
* optimize for it.
* <p>
* Vectors that are {@link DocVector#singleSegmentNonDecreasing()}
* represent many documents from a single Lucene segment. In Elasticsearch
* terms that's a segment in a single shard. And the document ids are in
* non-decreasing order. Probably just {@code 0, 1, 2, 3, 4, 5...}.
* But maybe something like {@code 0, 5, 6, 10, 10, 10}. Both of those are
* very much like how Lucene "natively" processes documents and this optimizes
* those accesses.
* </p>
* <p>
* If the documents are literally {@code 0, 1, ... n} then we use
* {@link BulkScorer#score(LeafCollector, Bits, int, int)} which processes
* a whole range. This should be quite common in the case where we don't
* have deleted documents because that's the order that
* {@link LuceneSourceOperator} produces them.
* </p>
* <p>
* If there are gaps in the sequence we use {@link Scorer} calls to
* score the sequence. This'll be slower, but it isn't going to be particularly
* common.
* </p>
*/
private Vector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException {
ShardState shardState = shardState(docs.shards().getInt(0));
SegmentState segmentState = shardState.segmentState(docs.segments().getInt(0));
int min = docs.docs().getInt(0);
int max = docs.docs().getInt(docs.getPositionCount() - 1);
int length = max - min + 1;
try (T scoreBuilder = createVectorBuilder(blockFactory, length)) {
if (length == docs.getPositionCount() && length > 1) {
return segmentState.scoreDense(scoreBuilder, min, max);
}
return segmentState.scoreSparse(scoreBuilder, docs.docs());
}
}
/**
* Evaluate non-{@link DocVector#singleSegmentNonDecreasing()} documents. See
* {@link #evalSingleSegmentNonDecreasing} for the meaning of
* {@link DocVector#singleSegmentNonDecreasing()} and how we can efficiently
* evaluate those segments.
* <p>
* This processes the worst case blocks of documents. These can be from any
* number of shards and any number of segments and in any order. We do this
* by iterating the docs in {@code shard ASC, segment ASC, doc ASC} order.
* So, that's segment by segment, docs ascending. We build a boolean block
* out of that. Then we <strong>sort</strong> that to put the booleans in
* the order that the {@link DocVector} came in.
* </p>
*/
private Vector evalSlow(DocVector docs) throws IOException {
int[] map = docs.shardSegmentDocMapForwards();
// Clear any state flags from the previous run
int prevShard = -1;
int prevSegment = -1;
SegmentState segmentState = null;
try (T scoreBuilder = createVectorBuilder(blockFactory, docs.getPositionCount())) {
for (int i = 0; i < docs.getPositionCount(); i++) {
int shard = docs.shards().getInt(map[i]);
int segment = docs.segments().getInt(map[i]);
if (segmentState == null || prevShard != shard || prevSegment != segment) {
segmentState = shardState(shard).segmentState(segment);
segmentState.initScorer(docs.docs().getInt(map[i]));
prevShard = shard;
prevSegment = segment;
}
if (segmentState.noMatch) {
appendNoMatch(scoreBuilder);
} else {
segmentState.scoreSingleDocWithScorer(scoreBuilder, docs.docs().getInt(map[i]));
}
}
try (Vector outOfOrder = scoreBuilder.build()) {
return outOfOrder.filter(docs.shardSegmentDocMapBackwards());
}
}
}
@Override
public void close() {}
private ShardState shardState(int shard) throws IOException {
ShardState shardState = perShardState.get(shard);
if (shardState != null) {
return shardState;
}
shardState = new ShardState(shards[shard]);
perShardState.set(shard, shardState);
return shardState;
}
/**
* Contains shard search related information, like the weight and index searcher
*/
private class ShardState {
private final Weight weight;
private final IndexSearcher searcher;
private final List<SegmentState> perSegmentState;
ShardState(ShardConfig config) throws IOException {
weight = config.searcher.createWeight(config.query, scoreMode(), 1.0f);
searcher = config.searcher;
perSegmentState = new ArrayList<>(Collections.nCopies(searcher.getLeafContexts().size(), null));
}
SegmentState segmentState(int segment) throws IOException {
SegmentState segmentState = perSegmentState.get(segment);
if (segmentState != null) {
return segmentState;
}
segmentState = new SegmentState(weight, searcher.getLeafContexts().get(segment));
perSegmentState.set(segment, segmentState);
return segmentState;
}
}
/**
* Contains segment search related information, like the leaf reader context and bulk scorer
*/
private class SegmentState {
private final Weight weight;
private final LeafReaderContext ctx;
/**
* Lazily initialized {@link Scorer} for this segment. {@code null} here means uninitialized
* <strong>or</strong> that {@link #noMatch} is true.
*/
private Scorer scorer;
/**
* Thread that initialized the {@link #scorer}.
*/
private Thread scorerThread;
/**
* Lazily initialized {@link BulkScorer} for this segment. {@code null} here means uninitialized
* <strong>or</strong> that {@link #noMatch} is true.
*/
private BulkScorer bulkScorer;
/**
* Thread that initialized the {@link #bulkScorer}.
*/
private Thread bulkScorerThread;
/**
* Set to {@code true} if, in the process of building a {@link Scorer} or {@link BulkScorer},
* the {@link Weight} tells us there aren't any matches.
*/
private boolean noMatch;
private SegmentState(Weight weight, LeafReaderContext ctx) {
this.weight = weight;
this.ctx = ctx;
}
/**
* Score a range using the {@link BulkScorer}. This should be faster
* than using {@link #scoreSparse} for dense doc ids.
*/
Vector scoreDense(T scoreBuilder, int min, int max) throws IOException {
if (noMatch) {
return createNoMatchVector(blockFactory, max - min + 1);
}
if (bulkScorer == null || // The bulkScorer wasn't initialized
Thread.currentThread() != bulkScorerThread // The bulkScorer was initialized on a different thread
) {
bulkScorerThread = Thread.currentThread();
bulkScorer = weight.bulkScorer(ctx);
if (bulkScorer == null) {
noMatch = true;
return createNoMatchVector(blockFactory, max - min + 1);
}
}
try (
DenseCollector<T> collector = new DenseCollector<>(
min,
max,
scoreBuilder,
LuceneQueryEvaluator.this::appendNoMatch,
LuceneQueryEvaluator.this::appendMatch
)
) {
bulkScorer.score(collector, ctx.reader().getLiveDocs(), min, max + 1);
return collector.build();
}
}
/**
* Score a vector of doc ids using {@link Scorer}. If you have a dense range of
* doc ids it'd be faster to use {@link #scoreDense}.
*/
Vector scoreSparse(T scoreBuilder, IntVector docs) throws IOException {
initScorer(docs.getInt(0));
if (noMatch) {
return createNoMatchVector(blockFactory, docs.getPositionCount());
}
for (int i = 0; i < docs.getPositionCount(); i++) {
scoreSingleDocWithScorer(scoreBuilder, docs.getInt(i));
}
return scoreBuilder.build();
}
private void initScorer(int minDocId) throws IOException {
if (noMatch) {
return;
}
if (scorer == null || // Scorer not initialized
scorerThread != Thread.currentThread() || // Scorer initialized on a different thread
scorer.iterator().docID() > minDocId // The previous block came "after" this one
) {
scorerThread = Thread.currentThread();
scorer = weight.scorer(ctx);
if (scorer == null) {
noMatch = true;
}
}
}
private void scoreSingleDocWithScorer(T builder, int doc) throws IOException {
if (scorer.iterator().docID() == doc) {
appendMatch(builder, scorer);
} else if (scorer.iterator().docID() > doc) {
appendNoMatch(builder);
} else {
if (scorer.iterator().advance(doc) == doc) {
appendMatch(builder, scorer);
} else {
appendNoMatch(builder);
}
}
}
}
/**
* Collects matching information for dense range of doc ids. This assumes that
* doc ids are sent to {@link LeafCollector#collect(int)} in ascending order
* which isn't documented, but @jpountz swears is true.
*/
static class DenseCollector<U extends Vector.Builder> implements LeafCollector, Releasable {
private final U scoreBuilder;
private final int max;
private final Consumer<U> appendNoMatch;
private final CheckedBiConsumer<U, Scorable, IOException> appendMatch;
private Scorable scorer;
int next;
DenseCollector(
int min,
int max,
U scoreBuilder,
Consumer<U> appendNoMatch,
CheckedBiConsumer<U, Scorable, IOException> appendMatch
) {
this.scoreBuilder = scoreBuilder;
this.max = max;
next = min;
this.appendNoMatch = appendNoMatch;
this.appendMatch = appendMatch;
}
@Override
public void setScorer(Scorable scorable) {
this.scorer = scorable;
}
@Override
public void collect(int doc) throws IOException {
while (next++ < doc) {
appendNoMatch.accept(scoreBuilder);
}
appendMatch.accept(scoreBuilder, scorer);
}
public Vector build() {
return scoreBuilder.build();
}
@Override
public void finish() {
while (next++ <= max) {
appendNoMatch.accept(scoreBuilder);
}
}
@Override
public void close() {
Releasables.closeExpectNoException(scoreBuilder);
}
}
/**
* Returns the score mode to use on searches
*/
protected abstract ScoreMode scoreMode();
/**
* Creates a vector where all positions correspond to elements that don't match the query
*/
protected abstract Vector createNoMatchVector(BlockFactory blockFactory, int size);
/**
* Creates the corresponding vector builder to store the results of evaluating the query
*/
protected abstract T createVectorBuilder(BlockFactory blockFactory, int size);
/**
* Appends a matching result to a builder created by {@link #createVectorBuilder}
*/
protected abstract void appendMatch(T builder, Scorable scorer) throws IOException;
/**
* Appends a non-matching result to a builder created by {@link #createVectorBuilder}
*/
protected abstract void appendNoMatch(T builder);
}
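To make the subclass contract concrete, here is a condensed sketch of a match/no-match implementation. It mirrors the `LuceneQueryExpressionEvaluator` defined later in this commit; the class name here is illustrative only:

```java
// Sketch: a boolean match/no-match evaluator built on LuceneQueryEvaluator.
class BooleanMatchEvaluatorSketch extends LuceneQueryEvaluator<BooleanVector.Builder> {
    BooleanMatchEvaluatorSketch(BlockFactory blockFactory, ShardConfig[] shards) {
        super(blockFactory, shards);
    }

    @Override
    protected ScoreMode scoreMode() {
        // Matching only; no scores are needed.
        return ScoreMode.COMPLETE_NO_SCORES;
    }

    @Override
    protected Vector createNoMatchVector(BlockFactory blockFactory, int size) {
        // A constant `false` vector when the Weight reports no matches at all.
        return blockFactory.newConstantBooleanVector(false, size);
    }

    @Override
    protected BooleanVector.Builder createVectorBuilder(BlockFactory blockFactory, int size) {
        return blockFactory.newBooleanVectorFixedBuilder(size);
    }

    @Override
    protected void appendMatch(BooleanVector.Builder builder, Scorable scorer) {
        builder.appendBoolean(true);
    }

    @Override
    protected void appendNoMatch(BooleanVector.Builder builder) {
        builder.appendBoolean(false);
    }
}
```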


@ -7,350 +7,64 @@
package org.elasticsearch.compute.lucene;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BooleanVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.Vector;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import java.io.IOException;
/**
 * {@link EvalOperator.ExpressionEvaluator} to run a Lucene {@link Query} during
 * the compute engine's normal execution, yielding matches/does not match into
 * a {@link BooleanVector}.
 * @see LuceneQueryScoreEvaluator
 */
public class LuceneQueryExpressionEvaluator extends LuceneQueryEvaluator<BooleanVector.Builder>
    implements
        EvalOperator.ExpressionEvaluator {
    LuceneQueryExpressionEvaluator(BlockFactory blockFactory, ShardConfig[] shards) {
        super(blockFactory, shards);
    }
    @Override
    public Block eval(Page page) {
        return executeQuery(page);
    }
    @Override
    protected ScoreMode scoreMode() {
        return ScoreMode.COMPLETE_NO_SCORES;
    }
    @Override
    protected Vector createNoMatchVector(BlockFactory blockFactory, int size) {
        return blockFactory.newConstantBooleanVector(false, size);
    }
    @Override
    protected BooleanVector.Builder createVectorBuilder(BlockFactory blockFactory, int size) {
        return blockFactory.newBooleanVectorFixedBuilder(size);
    }
    @Override
    protected void appendNoMatch(BooleanVector.Builder builder) {
        builder.appendBoolean(false);
    }
    @Override
    protected void appendMatch(BooleanVector.Builder builder, Scorable scorer) throws IOException {
        builder.appendBoolean(true);
    }
    public record Factory(ShardConfig[] shardConfigs) implements EvalOperator.ExpressionEvaluator.Factory {
        @Override
        public EvalOperator.ExpressionEvaluator get(DriverContext context) {
            return new LuceneQueryExpressionEvaluator(context.blockFactory(), shardConfigs);
        }
    }
}

@ -0,0 +1,74 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.lucene;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.Vector;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.ScoreOperator;
import java.io.IOException;
/**
* {@link ScoreOperator.ExpressionScorer} to run a Lucene {@link Query} during
* the compute engine's normal execution, yielding the corresponding scores into
* a {@link DoubleVector}.
* Elements that don't match will have a score of {@link #NO_MATCH_SCORE}.
* @see LuceneQueryScoreEvaluator
*/
public class LuceneQueryScoreEvaluator extends LuceneQueryEvaluator<DoubleVector.Builder> implements ScoreOperator.ExpressionScorer {
public static final double NO_MATCH_SCORE = 0.0;
LuceneQueryScoreEvaluator(BlockFactory blockFactory, ShardConfig[] shards) {
super(blockFactory, shards);
}
@Override
public DoubleBlock score(Page page) {
return (DoubleBlock) executeQuery(page);
}
@Override
protected ScoreMode scoreMode() {
return ScoreMode.COMPLETE;
}
@Override
protected Vector createNoMatchVector(BlockFactory blockFactory, int size) {
return blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, size);
}
@Override
protected DoubleVector.Builder createVectorBuilder(BlockFactory blockFactory, int size) {
return blockFactory.newDoubleVectorFixedBuilder(size);
}
@Override
protected void appendNoMatch(DoubleVector.Builder builder) {
builder.appendDouble(NO_MATCH_SCORE);
}
@Override
protected void appendMatch(DoubleVector.Builder builder, Scorable scorer) throws IOException {
builder.appendDouble(scorer.score());
}
public record Factory(ShardConfig[] shardConfigs) implements ScoreOperator.ExpressionScorer.Factory {
@Override
public ScoreOperator.ExpressionScorer get(DriverContext context) {
return new LuceneQueryScoreEvaluator(context.blockFactory(), shardConfigs);
}
}
}


@ -0,0 +1,102 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.operator;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
/**
* Evaluates scores for an {@link ExpressionScorer}. The scores are added to the existing scores in the input page.
*/
public class ScoreOperator extends AbstractPageMappingOperator {
public record ScoreOperatorFactory(ExpressionScorer.Factory scorerFactory, int scoreBlockPosition) implements OperatorFactory {
@Override
public Operator get(DriverContext driverContext) {
return new ScoreOperator(driverContext.blockFactory(), scorerFactory.get(driverContext), scoreBlockPosition);
}
@Override
public String describe() {
return "ScoreOperator[scorer=" + scorerFactory + "]";
}
}
private final BlockFactory blockFactory;
private final ExpressionScorer scorer;
private final int scoreBlockPosition;
public ScoreOperator(BlockFactory blockFactory, ExpressionScorer scorer, int scoreBlockPosition) {
this.blockFactory = blockFactory;
this.scorer = scorer;
this.scoreBlockPosition = scoreBlockPosition;
}
@Override
protected Page process(Page page) {
assert page.getBlockCount() >= 2 : "Expected at least 2 blocks, got " + page.getBlockCount();
assert page.getBlock(0).asVector() instanceof DocVector : "Expected a DocVector, got " + page.getBlock(0).asVector();
assert page.getBlock(1).asVector() instanceof DoubleVector : "Expected a DoubleVector, got " + page.getBlock(1).asVector();
Block[] blocks = new Block[page.getBlockCount()];
for (int i = 0; i < page.getBlockCount(); i++) {
if (i == scoreBlockPosition) {
blocks[i] = calculateScoresBlock(page);
} else {
blocks[i] = page.getBlock(i);
}
}
return new Page(blocks);
}
private Block calculateScoresBlock(Page page) {
try (DoubleBlock evalScores = scorer.score(page); DoubleBlock existingScores = page.getBlock(scoreBlockPosition)) {
// TODO Optimize for constant scores?
int rowCount = page.getPositionCount();
try (DoubleVector.Builder builder = blockFactory.newDoubleVectorFixedBuilder(rowCount)) {
    for (int i = 0; i < rowCount; i++) {
        builder.appendDouble(existingScores.getDouble(i) + evalScores.getDouble(i));
    }
    return builder.build().asBlock();
}
}
}
@Override
public String toString() {
return getClass().getSimpleName() + "[scorer=" + scorer + "]";
}
@Override
public void close() {
Releasables.closeExpectNoException(scorer, super::close);
}
/**
* Evaluates the score of an expression one {@link Page} at a time.
*/
public interface ExpressionScorer extends Releasable {
/** A Factory for creating ExpressionScorers. */
interface Factory {
ExpressionScorer get(DriverContext context);
}
/**
* Scores the expression.
* @return the returned Block has its own reference and the caller is responsible for releasing it.
*/
DoubleBlock score(Page page);
}
}
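As a usage sketch (following `LuceneQueryScoreEvaluatorTests` later in this commit, and assuming `query`, `searcher`, and `blockFactory` are in scope), the score operator is wired with a `LuceneQueryScoreEvaluator` and the position of the score block it should add to:

```java
// One ShardConfig per shard: the rewritten query plus its searcher.
LuceneQueryEvaluator.ShardConfig[] shards = new LuceneQueryEvaluator.ShardConfig[] {
    new LuceneQueryEvaluator.ShardConfig(searcher.rewrite(query), searcher) };
// Block 1 holds the scores produced upstream (e.g. by a scoring LuceneSourceOperator);
// the evaluator's per-row scores are added to it.
Operator scoreOperator = new ScoreOperator(
    blockFactory,
    new LuceneQueryScoreEvaluator(blockFactory, shards),
    1 // position of the score block in the incoming Page
);
```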


@ -0,0 +1,302 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.lucene;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KeywordField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MultiTermQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.TermInSetQuery;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.store.BaseDirectoryWrapper;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.OperatorTests;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.Vector;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.ShuffleDocsOperator;
import org.elasticsearch.compute.test.ComputeTestCase;
import org.elasticsearch.compute.test.OperatorTestCase;
import org.elasticsearch.compute.test.TestDriverFactory;
import org.elasticsearch.compute.test.TestResultPageSinkOperator;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.index.mapper.BlockDocValuesReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import static org.elasticsearch.compute.test.OperatorTestCase.randomPageSize;
import static org.hamcrest.Matchers.equalTo;
/**
* Base class for testing Lucene query evaluators.
*/
public abstract class LuceneQueryEvaluatorTests<T extends Vector, U extends Vector.Builder> extends ComputeTestCase {
private static final String FIELD = "g";
@SuppressWarnings("unchecked")
public void testDenseCollectorSmall() throws IOException {
try (LuceneQueryEvaluator.DenseCollector<U> collector = createDenseCollector(0, 2)) {
collector.setScorer(getScorer());
collector.collect(0);
collector.collect(1);
collector.collect(2);
collector.finish();
try (T result = (T) collector.build()) {
for (int i = 0; i <= 2; i++) {
assertCollectedResultMatch(result, i, true);
}
}
}
}
@SuppressWarnings("unchecked")
public void testDenseCollectorSimple() throws IOException {
try (LuceneQueryEvaluator.DenseCollector<U> collector = createDenseCollector(0, 10)) {
collector.setScorer(getScorer());
collector.collect(2);
collector.collect(5);
collector.finish();
try (T result = (T) collector.build()) {
for (int i = 0; i < 11; i++) {
assertCollectedResultMatch(result, i, i == 2 || i == 5);
}
}
}
}
@SuppressWarnings("unchecked")
public void testDenseCollector() throws IOException {
int length = between(1, 10_000);
int min = between(0, Integer.MAX_VALUE - length - 1);
int max = min + length;
boolean[] expected = new boolean[length];
try (LuceneQueryEvaluator.DenseCollector<U> collector = createDenseCollector(min, max)) {
collector.setScorer(getScorer());
for (int i = 0; i < length; i++) {
expected[i] = randomBoolean();
if (expected[i]) {
collector.collect(min + i);
}
}
collector.finish();
try (T result = (T) collector.build()) {
for (int i = 0; i < length; i++) {
assertCollectedResultMatch(result, i, expected[i]);
}
}
}
}
/**
* Create a dense collector for the given range.
*/
protected abstract LuceneQueryEvaluator.DenseCollector<U> createDenseCollector(int min, int max);
/**
* Checks that the collected result at the given position corresponds to a match or no match
*/
protected abstract void assertCollectedResultMatch(T resultVector, int position, boolean isMatch);
public void testTermQuery() throws IOException {
Set<String> values = values();
String term = values.iterator().next();
List<Page> results = runQuery(values, new TermQuery(new Term(FIELD, term)), false);
assertTermsQuery(results, Set.of(term), 1);
}
public void testTermQueryShuffled() throws IOException {
Set<String> values = values();
String term = values.iterator().next();
List<Page> results = runQuery(values, new ConstantScoreQuery(new TermQuery(new Term(FIELD, term))), true);
assertTermsQuery(results, Set.of(term), 1);
}
public void testTermsQuery() throws IOException {
testTermsQuery(false);
}
public void testTermsQueryShuffled() throws IOException {
testTermsQuery(true);
}
private void testTermsQuery(boolean shuffleDocs) throws IOException {
Set<String> values = values();
Iterator<String> itr = values.iterator();
TreeSet<String> matching = new TreeSet<>();
TreeSet<BytesRef> matchingBytes = new TreeSet<>();
int expectedMatchCount = between(2, values.size());
for (int i = 0; i < expectedMatchCount; i++) {
String v = itr.next();
matching.add(v);
matchingBytes.add(new BytesRef(v));
}
List<Page> results = runQuery(values, new TermInSetQuery(MultiTermQuery.CONSTANT_SCORE_REWRITE, FIELD, matchingBytes), shuffleDocs);
assertTermsQuery(results, matching, expectedMatchCount);
}
protected void assertTermsQuery(List<Page> results, Set<String> matching, int expectedMatchCount) {
int matchCount = 0;
for (Page page : results) {
int initialBlockIndex = termsBlockIndex(page);
BytesRefVector terms = page.<BytesRefBlock>getBlock(initialBlockIndex).asVector();
@SuppressWarnings("unchecked")
T resultVector = (T) page.getBlock(resultsBlockIndex(page)).asVector();
for (int i = 0; i < page.getPositionCount(); i++) {
BytesRef termAtPosition = terms.getBytesRef(i, new BytesRef());
boolean isMatch = matching.contains(termAtPosition.utf8ToString());
assertTermResultMatch(resultVector, i, isMatch);
if (isMatch) {
matchCount++;
}
}
}
assertThat(matchCount, equalTo(expectedMatchCount));
}
/**
* Checks that the result at the given position corresponds to a term match or no match
*/
protected abstract void assertTermResultMatch(T resultVector, int position, boolean isMatch);
private List<Page> runQuery(Set<String> values, Query query, boolean shuffleDocs) throws IOException {
DriverContext driverContext = driverContext();
BlockFactory blockFactory = driverContext.blockFactory();
return withReader(values, reader -> {
IndexSearcher searcher = new IndexSearcher(reader);
List<Operator> operators = new ArrayList<>();
if (shuffleDocs) {
operators.add(new ShuffleDocsOperator(blockFactory));
}
operators.add(
new ValuesSourceReaderOperator(
blockFactory,
List.of(
new ValuesSourceReaderOperator.FieldInfo(
FIELD,
ElementType.BYTES_REF,
unused -> new BlockDocValuesReader.BytesRefsFromOrdsBlockLoader(FIELD)
)
),
List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> {
throw new UnsupportedOperationException();
})),
0
)
);
LuceneQueryEvaluator.ShardConfig[] shards = new LuceneQueryEvaluator.ShardConfig[] {
new LuceneQueryEvaluator.ShardConfig(searcher.rewrite(query), searcher) };
operators.add(createOperator(blockFactory, shards));
List<Page> results = new ArrayList<>();
Driver driver = TestDriverFactory.create(
driverContext,
LuceneQueryEvaluatorTests.luceneOperatorFactory(reader, new MatchAllDocsQuery(), usesScoring()).get(driverContext),
operators,
new TestResultPageSinkOperator(results::add)
);
OperatorTestCase.runDriver(driver);
OperatorTests.assertDriverContext(driverContext);
return results;
});
}
private <R> R withReader(Set<String> values, CheckedFunction<DirectoryReader, R, IOException> run) throws IOException {
try (BaseDirectoryWrapper dir = newDirectory(); RandomIndexWriter writer = new RandomIndexWriter(random(), dir)) {
for (String value : values) {
writer.addDocument(List.of(new KeywordField(FIELD, value, Field.Store.NO)));
}
writer.commit();
try (DirectoryReader reader = writer.getReader()) {
return run.apply(reader);
}
}
}
private Set<String> values() {
int maxNumDocs = between(10, 1_000);
int keyLength = randomIntBetween(1, 10);
Set<String> values = new HashSet<>();
for (int i = 0; i < maxNumDocs; i++) {
values.add(randomAlphaOfLength(keyLength));
}
return values;
}
/**
* A {@link DriverContext} with a non-breaking-BigArrays.
*/
private DriverContext driverContext() {
BlockFactory blockFactory = blockFactory();
return new DriverContext(blockFactory.bigArrays(), blockFactory);
}
// Returns the initial block index, ignoring the score block if scoring is enabled
protected int termsBlockIndex(Page page) {
assert page.getBlock(0) instanceof DocBlock : "expected doc block at index 0";
if (usesScoring()) {
assert page.getBlock(1) instanceof DoubleBlock : "expected double block at index 1";
return 2;
} else {
return 1;
}
}
private static LuceneOperator.Factory luceneOperatorFactory(IndexReader reader, Query query, boolean scoring) {
final ShardContext searchContext = new LuceneSourceOperatorTests.MockShardContext(reader, 0);
return new LuceneSourceOperator.Factory(
List.of(searchContext),
ctx -> query,
randomFrom(DataPartitioning.values()),
randomIntBetween(1, 10),
randomPageSize(),
LuceneOperator.NO_LIMIT,
scoring
);
}
// Returns the block index for the results to check
protected abstract int resultsBlockIndex(Page page);
/**
* Returns a test scorer to use for scoring docs. Can be null
*/
protected abstract Scorable getScorer();
/**
* Create the operator to test
*/
protected abstract Operator createOperator(BlockFactory blockFactory, LuceneQueryEvaluator.ShardConfig[] shards);
/**
* Should the test use scoring?
*/
protected abstract boolean usesScoring();
}


@ -7,275 +7,59 @@
package org.elasticsearch.compute.lucene;
import org.apache.lucene.search.Scorable;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BooleanVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.lucene.LuceneQueryEvaluator.DenseCollector;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.Operator;
import static org.hamcrest.Matchers.equalTo;
public class LuceneQueryExpressionEvaluatorTests extends LuceneQueryEvaluatorTests<BooleanVector, BooleanVector.Builder> {
    private final boolean useScoring = randomBoolean();
    @Override
    protected DenseCollector<BooleanVector.Builder> createDenseCollector(int min, int max) {
        return new LuceneQueryEvaluator.DenseCollector<>(
            min,
            max,
            blockFactory().newBooleanVectorFixedBuilder(max - min + 1),
            b -> b.appendBoolean(false),
            (b, s) -> b.appendBoolean(true)
        );
    }
    @Override
    protected Scorable getScorer() {
        return null;
    }
    @Override
    protected Operator createOperator(BlockFactory blockFactory, LuceneQueryEvaluator.ShardConfig[] shards) {
        return new EvalOperator(blockFactory, new LuceneQueryExpressionEvaluator(blockFactory, shards));
    }
    @Override
    protected boolean usesScoring() {
        // Be consistent for a single test execution
        return useScoring;
    }
    @Override
    protected int resultsBlockIndex(Page page) {
        return page.getBlockCount() - 1;
    }
    @Override
    protected void assertCollectedResultMatch(BooleanVector resultVector, int position, boolean isMatch) {
        assertThat(resultVector.getBoolean(position), equalTo(isMatch));
    }
    @Override
    protected void assertTermResultMatch(BooleanVector resultVector, int position, boolean isMatch) {
        assertThat(resultVector.getBoolean(position), equalTo(isMatch));
    }
}


@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.lucene;
import org.apache.lucene.search.Scorable;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.ScoreOperator;
import java.io.IOException;
import static org.elasticsearch.compute.lucene.LuceneQueryScoreEvaluator.NO_MATCH_SCORE;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
public class LuceneQueryScoreEvaluatorTests extends LuceneQueryEvaluatorTests<DoubleVector, DoubleVector.Builder> {
private static final float TEST_SCORE = 1.5f;
private static final Double DEFAULT_SCORE = 1.0;
@Override
protected LuceneQueryEvaluator.DenseCollector<DoubleVector.Builder> createDenseCollector(int min, int max) {
return new LuceneQueryEvaluator.DenseCollector<>(
min,
max,
blockFactory().newDoubleVectorFixedBuilder(max - min + 1),
b -> b.appendDouble(NO_MATCH_SCORE),
(b, s) -> b.appendDouble(s.score())
);
}
@Override
protected Scorable getScorer() {
return new Scorable() {
@Override
public float score() throws IOException {
return TEST_SCORE;
}
};
}
@Override
protected Operator createOperator(BlockFactory blockFactory, LuceneQueryEvaluator.ShardConfig[] shards) {
return new ScoreOperator(blockFactory, new LuceneQueryScoreEvaluator(blockFactory, shards), 1);
}
@Override
protected boolean usesScoring() {
return true;
}
@Override
protected int resultsBlockIndex(Page page) {
// Reuses the score block
return 1;
}
@Override
protected void assertCollectedResultMatch(DoubleVector resultVector, int position, boolean isMatch) {
if (isMatch) {
assertThat(resultVector.getDouble(position), equalTo((double) TEST_SCORE));
} else {
// Non-matching docs are collected with NO_MATCH_SCORE
assertThat(resultVector.getDouble(position), equalTo(NO_MATCH_SCORE));
}
}
@Override
protected void assertTermResultMatch(DoubleVector resultVector, int position, boolean isMatch) {
if (isMatch) {
assertThat(resultVector.getDouble(position), greaterThan(DEFAULT_SCORE));
} else {
// All docs have a default score coming from Lucene
assertThat(resultVector.getDouble(position), equalTo(DEFAULT_SCORE));
}
}
}
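
A note on the collector semantics exercised above: matching docs record the scorer's score, while gaps record NO_MATCH_SCORE. Below is a standalone sketch of that fill pattern (hypothetical, not the ES code; the 0.0 value for NO_MATCH_SCORE is an assumption made for illustration):

```java
// Hypothetical sketch of the dense-collector fill pattern used in these tests:
// matching positions take the scorer's score, gaps take NO_MATCH_SCORE.
public class CollectorSketch {
    static final double NO_MATCH_SCORE = 0.0; // assumed value for illustration

    public static void main(String[] args) {
        boolean[] matches = { true, false, true };
        float testScore = 1.5f; // plays the role of TEST_SCORE above
        double[] collected = new double[matches.length];
        for (int i = 0; i < matches.length; i++) {
            collected[i] = matches[i] ? testScore : NO_MATCH_SCORE;
        }
        for (double score : collected) {
            System.out.println(score); // 1.5, 0.0, 1.5
        }
    }
}
```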

View File

@ -357,3 +357,115 @@ _id:keyword
2
3
;
scoresNonPushableFunctions
required_capability: metadata_score
from books metadata _score
| where length(title) > 100
| keep book_no, _score
| sort _score desc, book_no asc
;
book_no:keyword | _score:double
2924 | 1.0
8678 | 1.0
;
scoresPushableFunctions
required_capability: metadata_score
from books metadata _score
| where year >= 2017
| keep book_no, _score
| sort _score desc, book_no asc
;
book_no:keyword | _score:double
6818 | 1.0
7400 | 1.0
8480 | 1.0
8534 | 1.0
8615 | 1.0
;
conjunctionScoresPushableNonPushableFunctions
required_capability: metadata_score
required_capability: match_function
from books metadata _score
| where match(title, "Lord") and length(title) > 20
| keep book_no, _score
| sort _score desc, book_no asc
;
book_no:keyword | _score:double
2675 | 2.5619282722473145
2714 | 1.9245924949645996
7140 | 1.746896743774414
4023 | 1.5062403678894043
;
conjunctionScoresPushableFunctions
required_capability: metadata_score
required_capability: match_function
from books metadata _score
| where match(title, "Lord") and ratings > 4.6
| keep book_no, _score
| sort _score desc, book_no asc
;
book_no:keyword | _score:double
7140 | 2.746896743774414
4023 | 2.5062403678894043
;
disjunctionScoresPushableNonPushableFunctions
required_capability: metadata_score
required_capability: match_operator_colon
required_capability: full_text_functions_disjunctions_score
from books metadata _score
| where match(title, "Lord") or length(title) > 100
| keep book_no, _score
| sort _score desc, book_no asc
;
book_no:keyword | _score:double
2675 | 3.5619282722473145
2714 | 2.9245924949645996
7140 | 2.746896743774414
4023 | 2.5062403678894043
2924 | 1.0
8678 | 1.0
;
disjunctionScoresMultipleClauses
required_capability: metadata_score
required_capability: match_operator_colon
required_capability: full_text_functions_disjunctions_score
from books metadata _score
| where (title: "Lord" and length(title) > 40) or (author: "Dostoevsky" and length(title) > 40)
| keep book_no, _score
| sort _score desc, book_no asc
;
book_no:keyword | _score:double
8086 | 2.786686897277832
9801 | 2.786686897277832
1937 | 2.1503653526306152
8534 | 2.1503653526306152
2714 | 1.9245924949645996
7140 | 1.746896743774414
4023 | 1.5062403678894043
2924 | 1.2732219696044922
;
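
A side note on the expected values: disjunction scores are additive across OR clauses. Book 2675, for example, scores exactly 1.0 more in disjunctionScoresPushableNonPushableFunctions than in the conjunction test above. A standalone sketch (not part of the commit) checking that arithmetic:

```java
// Hypothetical check of the additive relationship between the expected values
// above; the constants are copied from the conjunction and disjunction tests.
public class DisjunctionScoreCheck {
    public static void main(String[] args) {
        double conjunctionScore = 2.5619282722473145; // book 2675, conjunction test
        double disjunctionScore = 3.5619282722473145; // book 2675, disjunction test
        System.out.println(disjunctionScore - conjunctionScore); // 1.0
    }
}
```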

View File

@ -0,0 +1,257 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.plugin;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
import org.junit.Before;
import java.util.List;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getValuesList;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
//@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug")
public class ScoringIT extends AbstractEsqlIntegTestCase {
@Before
public void setupIndex() {
createAndPopulateIndex();
}
public void testDefaultScoring() {
var query = """
FROM test METADATA _score
| KEEP id, _score
| SORT _score DESC, id ASC
""";
try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "_score"));
assertColumnTypes(resp.columns(), List.of("integer", "double"));
List<List<Object>> values = getValuesList(resp);
assertThat(values.size(), equalTo(6));
for (int i = 0; i < 6; i++) {
assertThat(values.get(i).get(1), equalTo(1.0));
}
}
}
public void testScoringNonPushableFunctions() {
var query = """
FROM test METADATA _score
| WHERE length(content) < 20
| KEEP id, _score
| SORT _score DESC, id ASC
""";
try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "_score"));
assertColumnTypes(resp.columns(), List.of("integer", "double"));
List<List<Object>> values = getValuesList(resp);
assertThat(values.size(), equalTo(2));
assertThat(values.get(0).get(0), equalTo(1));
assertThat(values.get(1).get(0), equalTo(2));
assertThat((Double) values.get(0).get(1), is(1.0));
assertThat((Double) values.get(1).get(1), is(1.0));
}
}
public void testDisjunctionScoring() {
var query = """
FROM test METADATA _score
| WHERE match(content, "fox") OR length(content) < 20
| KEEP id, _score
| SORT _score DESC, id ASC
""";
try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "_score"));
assertColumnTypes(resp.columns(), List.of("integer", "double"));
List<List<Object>> values = getValuesList(resp);
assertThat(values.size(), equalTo(3));
assertThat(values.get(0).get(0), equalTo(1));
assertThat(values.get(1).get(0), equalTo(6));
assertThat(values.get(2).get(0), equalTo(2));
// Matches full text query and non pushable query
assertThat((Double) values.get(0).get(1), greaterThan(1.0));
assertThat((Double) values.get(1).get(1), greaterThan(1.0));
// Matches just non pushable query
assertThat((Double) values.get(2).get(1), equalTo(1.0));
}
}
public void testConjunctionPushableScoring() {
var query = """
FROM test METADATA _score
| WHERE match(content, "fox") AND id > 4
| KEEP id, _score
| SORT _score DESC, id ASC
""";
try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "_score"));
assertColumnTypes(resp.columns(), List.of("integer", "double"));
List<List<Object>> values = getValuesList(resp);
assertThat(values.size(), equalTo(1));
assertThat(values.get(0).get(0), equalTo(6));
// Matches full text query and pushable query
assertThat((Double) values.get(0).get(1), greaterThan(1.0));
}
}
public void testConjunctionNonPushableScoring() {
var query = """
FROM test METADATA _score
| WHERE match(content, "fox") AND length(content) < 20
| KEEP id, _score
| SORT _score DESC, id ASC
""";
try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "_score"));
assertColumnTypes(resp.columns(), List.of("integer", "double"));
List<List<Object>> values = getValuesList(resp);
assertThat(values.size(), equalTo(1));
assertThat(values.get(0).get(0), equalTo(1));
// Matches the full text query and the non pushable condition
assertThat((Double) values.get(0).get(1), greaterThan(1.0));
}
}
public void testDisjunctionScoringPushableFunctions() {
var query = """
FROM test METADATA _score
| WHERE match(content, "fox") OR match(content, "quick")
| KEEP id, _score
| SORT _score DESC, id ASC
""";
try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "_score"));
assertColumnTypes(resp.columns(), List.of("integer", "double"));
List<List<Object>> values = getValuesList(resp);
assertThat(values.size(), equalTo(2));
assertThat(values.get(0).get(0), equalTo(6));
assertThat(values.get(1).get(0), equalTo(1));
// Matches both conditions
assertThat((Double) values.get(0).get(1), greaterThan(2.0));
// Matches a single condition
assertThat((Double) values.get(1).get(1), greaterThan(1.0));
}
}
public void testDisjunctionScoringMultipleNonPushableFunctions() {
var query = """
FROM test METADATA _score
| WHERE match(content, "fox") OR length(content) < 20 AND id > 2
| KEEP id, _score
| SORT _score DESC
""";
try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "_score"));
assertColumnTypes(resp.columns(), List.of("integer", "double"));
List<List<Object>> values = getValuesList(resp);
assertThat(values.size(), equalTo(2));
assertThat(values.get(0).get(0), equalTo(1));
assertThat(values.get(1).get(0), equalTo(6));
// Matches the full text query and the two other conditions
assertThat((Double) values.get(0).get(1), greaterThan(2.0));
assertThat((Double) values.get(0).get(1), lessThan(3.0));
// Matches just the match function
assertThat((Double) values.get(1).get(1), lessThan(2.0));
assertThat((Double) values.get(1).get(1), greaterThan(1.0));
}
}
public void testDisjunctionScoringWithNot() {
var query = """
FROM test METADATA _score
| WHERE NOT(match(content, "dog")) OR length(content) > 50
| KEEP id, _score
| SORT _score DESC, id ASC
""";
try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "_score"));
assertColumnTypes(resp.columns(), List.of("integer", "double"));
List<List<Object>> values = getValuesList(resp);
assertThat(values.size(), equalTo(3));
assertThat(values.get(0).get(0), equalTo(1));
assertThat(values.get(1).get(0), equalTo(4));
assertThat(values.get(2).get(0), equalTo(5));
// The NOT clause contributes 0.0, so matching docs keep the default score of 1.0
assertThat((Double) values.get(0).get(1), equalTo(1.0));
assertThat((Double) values.get(1).get(1), equalTo(1.0));
assertThat((Double) values.get(2).get(1), equalTo(1.0));
}
}
public void testScoringWithNoFullTextFunction() {
var query = """
FROM test METADATA _score
| WHERE length(content) > 50
| KEEP id, _score
| SORT _score DESC, id ASC
""";
try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "_score"));
assertColumnTypes(resp.columns(), List.of("integer", "double"));
List<List<Object>> values = getValuesList(resp);
assertThat(values.size(), equalTo(1));
assertThat(values.get(0).get(0), equalTo(4));
// Non pushable query gets score of 0.0, summed with 1.0 coming from Lucene
assertThat((Double) values.get(0).get(1), equalTo(1.0));
}
}
private void createAndPopulateIndex() {
var indexName = "test";
var client = client().admin().indices();
var createRequest = client.prepareCreate(indexName)
.setSettings(Settings.builder().put("index.number_of_shards", 1))
.setMapping("id", "type=integer", "content", "type=text");
assertAcked(createRequest);
client().prepareBulk()
.add(new IndexRequest(indexName).id("1").source("id", 1, "content", "This is a brown fox"))
.add(new IndexRequest(indexName).id("2").source("id", 2, "content", "This is a brown dog"))
.add(new IndexRequest(indexName).id("3").source("id", 3, "content", "This dog is really brown"))
.add(new IndexRequest(indexName).id("4").source("id", 4, "content", "The dog is brown but this document is very very long"))
.add(new IndexRequest(indexName).id("5").source("id", 5, "content", "There is also a white cat"))
.add(new IndexRequest(indexName).id("6").source("id", 6, "content", "The quick brown fox jumps over the lazy dog"))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get();
ensureYellow(indexName);
}
}

View File

@ -865,7 +865,12 @@ public class EsqlCapabilities {
/**
* Support for RRF command
*/
RRF(Build.current().isSnapshot());
RRF(Build.current().isSnapshot()),
/**
* Full text functions can be scored when being part of a disjunction
*/
FULL_TEXT_FUNCTIONS_DISJUNCTIONS_SCORE;
private final boolean enabled;

View File

@ -177,7 +177,7 @@ public final class EvalMapper {
static class Nots extends ExpressionMapper<Not> {
@Override
public ExpressionEvaluator.Factory map(FoldContext foldCtx, Not not, Layout layout, List<ShardContext> shardContexts) {
var expEval = toEvaluator(foldCtx, not.field(), layout);
var expEval = toEvaluator(foldCtx, not.field(), layout, shardContexts);
return dvrCtx -> new org.elasticsearch.xpack.esql.evaluator.predicate.operator.logical.NotEvaluator(
not.source(),
expEval.get(dvrCtx),
@ -281,7 +281,7 @@ public final class EvalMapper {
@Override
public ExpressionEvaluator.Factory map(FoldContext foldCtx, IsNull isNull, Layout layout, List<ShardContext> shardContexts) {
var field = toEvaluator(foldCtx, isNull.field(), layout);
var field = toEvaluator(foldCtx, isNull.field(), layout, shardContexts);
return new IsNullEvaluatorFactory(field);
}
@ -329,7 +329,7 @@ public final class EvalMapper {
@Override
public ExpressionEvaluator.Factory map(FoldContext foldCtx, IsNotNull isNotNull, Layout layout, List<ShardContext> shardContexts) {
return new IsNotNullEvaluatorFactory(toEvaluator(foldCtx, isNotNull.field(), layout));
return new IsNotNullEvaluatorFactory(toEvaluator(foldCtx, isNotNull.field(), layout, shardContexts));
}
record IsNotNullEvaluatorFactory(EvalOperator.ExpressionEvaluator.Factory field) implements ExpressionEvaluator.Factory {

View File

@ -8,27 +8,26 @@
package org.elasticsearch.xpack.esql.expression.function.fulltext;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.compute.lucene.LuceneQueryEvaluator.ShardConfig;
import org.elasticsearch.compute.lucene.LuceneQueryExpressionEvaluator;
import org.elasticsearch.compute.lucene.LuceneQueryExpressionEvaluator.ShardConfig;
import org.elasticsearch.compute.lucene.LuceneQueryScoreEvaluator;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.ScoreOperator;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic;
import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
import org.elasticsearch.xpack.esql.expression.predicate.logical.Or;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
@ -39,6 +38,7 @@ import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders;
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
import org.elasticsearch.xpack.esql.querydsl.query.TranslationAwareExpressionQuery;
import org.elasticsearch.xpack.esql.score.ExpressionScoreMapper;
import java.util.List;
import java.util.Locale;
@ -56,7 +56,12 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isStr
* These functions need to be pushed down to Lucene queries to be executed - there's no Evaluator for them; they depend on
* {@link org.elasticsearch.xpack.esql.optimizer.LocalPhysicalPlanOptimizer} to rewrite them into Lucene queries.
*/
public abstract class FullTextFunction extends Function implements TranslationAware, PostAnalysisPlanVerificationAware, EvaluatorMapper {
public abstract class FullTextFunction extends Function
implements
TranslationAware,
PostAnalysisPlanVerificationAware,
EvaluatorMapper,
ExpressionScoreMapper {
private final Expression query;
private final QueryBuilder queryBuilder;
@ -204,13 +209,6 @@ public abstract class FullTextFunction extends Function implements TranslationAw
failures
);
checkFullTextFunctionsParents(condition, failures);
boolean usesScore = plan.output()
.stream()
.anyMatch(attr -> attr instanceof MetadataAttribute ma && ma.name().equals(MetadataAttribute.SCORE));
if (usesScore) {
checkFullTextSearchDisjunctions(condition, failures);
}
} else {
plan.forEachExpression(FullTextFunction.class, ftf -> {
failures.add(fail(ftf, "[{}] {} is only supported in WHERE commands", ftf.functionName(), ftf.functionType()));
@ -218,65 +216,6 @@ public abstract class FullTextFunction extends Function implements TranslationAw
}
}
/**
* Checks whether a condition contains a disjunction with a full text search.
* If it does, check that every element of the disjunction is a full text search or combinations (AND, OR, NOT) of them.
* If not, add a failure to the failures collection.
*
* @param condition condition to check for disjunctions of full text searches
* @param failures failures collection to add to
*/
private static void checkFullTextSearchDisjunctions(Expression condition, Failures failures) {
Holder<Boolean> isInvalid = new Holder<>(false);
condition.forEachDown(Or.class, or -> {
if (isInvalid.get()) {
// Exit early if we already have a failure
return;
}
if (checkDisjunctionPushable(or) == false) {
isInvalid.set(true);
failures.add(
fail(
or,
"Invalid condition when using METADATA _score [{}]. Full text functions can be used in an OR condition, "
+ "but only if just full text functions are used in the OR condition",
or.sourceText()
)
);
}
});
}
/**
* Checks if a disjunction is pushable from the point of view of FullTextFunctions. Either it has no FullTextFunctions or
* all it contains are FullTextFunctions.
*
* @param or disjunction to check
* @return true if the disjunction is pushable, false otherwise
*/
private static boolean checkDisjunctionPushable(Or or) {
boolean hasFullText = or.anyMatch(FullTextFunction.class::isInstance);
return hasFullText == false || onlyFullTextFunctionsInExpression(or);
}
/**
* Checks whether an expression contains just full text functions or negations (NOT) and combinations (AND, OR) of full text functions
*
* @param expression expression to check
* @return true if all children are full text functions or negations of full text functions, false otherwise
*/
private static boolean onlyFullTextFunctionsInExpression(Expression expression) {
if (expression instanceof FullTextFunction) {
return true;
} else if (expression instanceof Not) {
return onlyFullTextFunctionsInExpression(expression.children().get(0));
} else if (expression instanceof BinaryLogic binaryLogic) {
return onlyFullTextFunctionsInExpression(binaryLogic.left()) && onlyFullTextFunctionsInExpression(binaryLogic.right());
}
return false;
}
/**
* Checks all commands that exist before a specific type satisfy conditions.
*
@ -365,4 +304,15 @@ public abstract class FullTextFunction extends Function implements TranslationAw
}
return new LuceneQueryExpressionEvaluator.Factory(shardConfigs);
}
@Override
public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) {
List<EsPhysicalOperationProviders.ShardContext> shardContexts = toScorer.shardContexts();
ShardConfig[] shardConfigs = new ShardConfig[shardContexts.size()];
int i = 0;
for (EsPhysicalOperationProviders.ShardContext shardContext : shardContexts) {
shardConfigs[i++] = new ShardConfig(shardContext.toQuery(queryBuilder()), shardContext.searcher());
}
return new LuceneQueryScoreEvaluator.Factory(shardConfigs);
}
}

View File

@ -8,6 +8,11 @@ package org.elasticsearch.xpack.esql.expression.predicate.logical;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.ScoreOperator;
import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
@ -22,6 +27,7 @@ import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.core.util.PlanStreamInput;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
import org.elasticsearch.xpack.esql.score.ExpressionScoreMapper;
import java.io.IOException;
import java.util.Arrays;
@ -29,7 +35,10 @@ import java.util.List;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isBoolean;
public abstract class BinaryLogic extends BinaryOperator<Boolean, Boolean, Boolean, BinaryLogicOperation> implements TranslationAware {
public abstract class BinaryLogic extends BinaryOperator<Boolean, Boolean, Boolean, BinaryLogicOperation>
implements
TranslationAware,
ExpressionScoreMapper {
protected BinaryLogic(Source source, Expression left, Expression right, BinaryLogicOperation operation) {
super(source, left, right, operation);
@ -108,4 +117,33 @@ public abstract class BinaryLogic extends BinaryOperator<Boolean, Boolean, Boole
}
return new BoolQuery(source, isAnd, queries);
}
@Override
public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) {
return context -> new BinaryLogicScorer(context, toScorer.toScorer(left()).get(context), toScorer.toScorer(right()).get(context));
}
/**
* Binary logic adds together scores coming from the left and right expressions, both for conjunctions and disjunctions
*/
private record BinaryLogicScorer(DriverContext driverContext, ScoreOperator.ExpressionScorer left, ScoreOperator.ExpressionScorer right)
implements
ScoreOperator.ExpressionScorer {
@Override
public DoubleBlock score(Page page) {
DoubleVector.Builder builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(page.getPositionCount());
try (DoubleVector leftVector = left.score(page).asVector(); DoubleVector rightVector = right.score(page).asVector()) {
for (int i = 0; i < page.getPositionCount(); i++) {
builder.appendDouble(leftVector.getDouble(i) + rightVector.getDouble(i));
}
}
return builder.build().asBlock();
}
@Override
public void close() {
left.close();
right.close();
}
}
}
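
Since BinaryLogicScorer adds the left and right scores for conjunctions and disjunctions alike, the combination semantics can be shown with a standalone sketch over plain arrays (not the ES implementation, which works on DoubleVector blocks):

```java
// Hypothetical sketch of the additive score combination implemented by
// BinaryLogicScorer above: per-position sum of left and right clause scores.
public class AdditiveScores {
    static double[] combine(double[] left, double[] right) {
        double[] out = new double[left.length];
        for (int i = 0; i < left.length; i++) {
            out[i] = left[i] + right[i]; // same accumulation for AND and OR
        }
        return out;
    }

    public static void main(String[] args) {
        double[] fullTextClause = { 1.75, 0.0 }; // scores from a full text clause
        double[] otherClause = { 1.0, 1.0 };     // default contribution of the other clause
        for (double score : combine(fullTextClause, otherClause)) {
            System.out.println(score); // 2.75, then 1.0
        }
    }
}
```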

View File

@ -35,6 +35,7 @@ import org.elasticsearch.compute.operator.Operator.OperatorFactory;
import org.elasticsearch.compute.operator.OutputOperator.OutputOperatorFactory;
import org.elasticsearch.compute.operator.RowInTableLookupOperator;
import org.elasticsearch.compute.operator.RrfScoreEvalOperator;
import org.elasticsearch.compute.operator.ScoreOperator;
import org.elasticsearch.compute.operator.ShowOperator;
import org.elasticsearch.compute.operator.SinkOperator;
import org.elasticsearch.compute.operator.SinkOperator.SinkOperatorFactory;
@ -105,6 +106,7 @@ import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders.ShardContext;
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
import org.elasticsearch.xpack.esql.score.ScoreMapper;
import org.elasticsearch.xpack.esql.session.Configuration;
import java.util.ArrayList;
@ -729,10 +731,29 @@ public class LocalExecutionPlanner {
private PhysicalOperation planFilter(FilterExec filter, LocalExecutionPlannerContext context) {
PhysicalOperation source = plan(filter.child(), context);
// TODO: should this be extracted into a separate eval block?
return source.with(
PhysicalOperation filterOperation = source.with(
new FilterOperatorFactory(EvalMapper.toEvaluator(context.foldCtx(), filter.condition(), source.layout, shardContexts)),
source.layout
);
if (PlannerUtils.usesScoring(filter)) {
// Add a score operator that folds the filter expression scores into the overall score
int scoreBlock = 0;
for (Attribute attribute : filter.output()) {
if (MetadataAttribute.SCORE.equals(attribute.name())) {
break;
}
scoreBlock++;
}
if (scoreBlock == filter.output().size()) {
throw new IllegalStateException("Couldn't find _score attribute in a WHERE clause");
}
filterOperation = filterOperation.with(
new ScoreOperator.ScoreOperatorFactory(ScoreMapper.toScorer(filter.condition(), shardContexts), scoreBlock),
filterOperation.layout
);
}
return filterOperation;
}
private PhysicalOperation planLimit(LimitExec limit, LocalExecutionPlannerContext context) {

View File

@ -23,6 +23,7 @@ import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.Holder;
@ -32,6 +33,7 @@ import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalPlanOptimizer;
import org.elasticsearch.xpack.esql.plan.QueryPlan;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
@ -278,4 +280,8 @@ public class PlannerUtils {
new NoopCircuitBreaker("noop-esql-breaker"),
BigArrays.NON_RECYCLING_INSTANCE
);
public static boolean usesScoring(QueryPlan<?> plan) {
return plan.output().stream().anyMatch(attr -> attr instanceof MetadataAttribute ma && ma.name().equals(MetadataAttribute.SCORE));
}
}

View File

@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.score;
import org.elasticsearch.compute.operator.ScoreOperator.ExpressionScorer;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders;
import java.util.List;
/**
* Implemented by expressions that can be mapped to an {@link ExpressionScorer}, allowing an expression to be transformed into its corresponding scores.
*/
public interface ExpressionScoreMapper {
interface ToScorer {
ExpressionScorer.Factory toScorer(Expression expression);
default List<EsPhysicalOperationProviders.ShardContext> shardContexts() {
throw new UnsupportedOperationException("Shard contexts should only be needed for scoring operations");
}
}
ExpressionScorer.Factory toScorer(ToScorer toScorer);
}

View File

@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.score;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.ScoreOperator;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders.ShardContext;
import java.util.List;
/**
* Maps an expression tree into an ExpressionScorer.Factory, so scores can be evaluated for the whole tree.
*/
public class ScoreMapper {
public static ScoreOperator.ExpressionScorer.Factory toScorer(Expression expression, List<ShardContext> shardContexts) {
if (expression instanceof ExpressionScoreMapper mapper) {
return mapper.toScorer(new ExpressionScoreMapper.ToScorer() {
@Override
public ScoreOperator.ExpressionScorer.Factory toScorer(Expression expression) {
return ScoreMapper.toScorer(expression, shardContexts);
}
@Override
public List<ShardContext> shardContexts() {
return shardContexts;
}
});
}
return driverContext -> new DefaultScoreMapper().get(driverContext);
}
public static class DefaultScoreMapper implements ScoreOperator.ExpressionScorer.Factory {
@Override
public ScoreOperator.ExpressionScorer get(DriverContext driverContext) {
return new ScoreOperator.ExpressionScorer() {
@Override
public DoubleBlock score(Page page) {
return driverContext.blockFactory().newConstantDoubleBlockWith(0.0, page.getPositionCount());
}
@Override
public void close() {}
};
}
}
}
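
The fallback branch means any expression without a score mapping contributes a constant 0.0 per position, so only mapped clauses move the score. A standalone sketch of that effect (hypothetical; it mirrors the role of DefaultScoreMapper rather than calling it):

```java
// Hypothetical sketch: combining a mapped clause's scores with the constant
// 0.0-per-position fallback leaves the mapped scores unchanged.
public class ScoreFallbackSketch {
    public static void main(String[] args) {
        double[] mappedClause = { 2.5, 0.0, 1.2 }; // e.g. scores from a full text clause
        double[] fallback = new double[mappedClause.length]; // all 0.0, like the fallback
        for (int i = 0; i < mappedClause.length; i++) {
            System.out.println(mappedClause[i] + fallback[i]); // 2.5, 0.0, 1.2
        }
    }
}
```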

View File

@ -284,6 +284,11 @@ public class CsvTests extends ESTestCase {
"CSV tests cannot currently handle the _source field mapping directives",
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.SOURCE_FIELD_MAPPING.capabilityName())
);
assumeFalse(
"CSV tests cannot currently handle scoring that depends on Lucene",
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.METADATA_SCORE.capabilityName())
);
if (Build.current().isSnapshot()) {
assertThat(
"Capability is not included in the enabled list capabilities on a snapshot build. Spelling mistake?",

View File

@ -9,7 +9,6 @@ package org.elasticsearch.xpack.esql.analysis;
import org.elasticsearch.Build;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.VerificationException;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
@ -1467,11 +1466,12 @@ public class VerifierTests extends ESTestCase {
private void checkWithFullTextFunctionsDisjunctions(String functionInvocation) {
// Disjunctions with non-pushable functions - scoring
checkdisjunctionScoringError("1:35", functionInvocation + " or length(first_name) > 10");
checkdisjunctionScoringError("1:35", "match(last_name, \"Anneke\") or (" + functionInvocation + " and length(first_name) > 10)");
checkdisjunctionScoringError(
"1:35",
"(" + functionInvocation + " and length(first_name) > 0) or (match(last_name, \"Anneke\") and length(first_name) > 10)"
query("from test | where " + functionInvocation + " or length(first_name) > 10");
query("from test | where match(last_name, \"Anneke\") or (" + functionInvocation + " and length(first_name) > 10)");
query(
"from test | where ("
+ functionInvocation
+ " and length(first_name) > 0) or (match(last_name, \"Anneke\") and length(first_name) > 10)"
);
// Disjunctions with non-pushable functions - no scoring
@ -1503,19 +1503,6 @@ public class VerifierTests extends ESTestCase {
}
private void checkdisjunctionScoringError(String position, String expression) {
assertEquals(
LoggerMessageFormat.format(
null,
"{}: Invalid condition when using METADATA _score [{}]. Full text functions can be used in an OR condition, "
+ "but only if just full text functions are used in the OR condition",
position,
expression
),
error("from test metadata _score | where " + expression)
);
}
public void testQueryStringFunctionWithNonBooleanFunctions() {
checkFullTextFunctionsWithNonBooleanFunctions("QSTR", "qstr(\"first_name: Anna\")", "function");
}

View File

@ -1704,7 +1704,7 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
assertThat(queryBuilder.value(), is(123456));
}
public void testMatchFunctionWithPushableConjunction() {
public void testMatchFunctionWithNonPushableConjunction() {
String query = """
from test
| where match(last_name, "Smith") and length(first_name) > 10
@ -1723,6 +1723,24 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
assertThat(esQuery.query(), instanceOf(MatchQueryBuilder.class));
}
public void testMatchFunctionWithPushableConjunction() {
String query = """
from test metadata _score
| where match(last_name, "Smith") and salary > 10000
""";
var plan = plannerOptimizer.plan(query);
var limit = as(plan, LimitExec.class);
var exchange = as(limit.child(), ExchangeExec.class);
var project = as(exchange.child(), ProjectExec.class);
var fieldExtract = as(project.child(), FieldExtractExec.class);
var esQuery = as(fieldExtract.child(), EsQueryExec.class);
Source source = new Source(2, 38, "salary > 10000");
BoolQueryBuilder expected = new BoolQueryBuilder().must(new MatchQueryBuilder("last_name", "Smith").lenient(true))
.must(wrapWithSingleQuery(query, QueryBuilders.rangeQuery("salary").gt(10000), "salary", source));
assertThat(esQuery.query().toString(), equalTo(expected.toString()));
}
public void testMatchFunctionWithNonPushableDisjunction() {
String query = """
from test
@ -1754,7 +1772,6 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
var project = as(exchange.child(), ProjectExec.class);
var fieldExtract = as(project.child(), FieldExtractExec.class);
var esQuery = as(fieldExtract.child(), EsQueryExec.class);
var boolQuery = as(esQuery.query(), BoolQueryBuilder.class);
Source source = new Source(2, 37, "emp_no > 10");
BoolQueryBuilder expected = new BoolQueryBuilder().should(new MatchQueryBuilder("last_name", "Smith").lenient(true))
.should(wrapWithSingleQuery(query, QueryBuilders.rangeQuery("emp_no").gt(10), "emp_no", source));