ES|QL - Add scoring for full text functions disjunctions (#121793)

parent e11d89d76b
commit 2b40e73fe9

@@ -0,0 +1,5 @@
pr: 121793
summary: "ES|QL - Add scoring for full text functions disjunctions"
area: ES|QL
type: enhancement
issues: []
@@ -117,9 +117,7 @@ In addition, when [querying multiple indexes](docs-content://explore-analyze/que
 
 ## Full-text search [esql-limitations-full-text-search]
 
-[preview] {{esql}}'s support for [full-text search](/reference/query-languages/esql/esql-functions-operators.md#esql-search-functions) is currently in Technical Preview. One limitation of full-text search is that it is necessary to use the search function, like [`MATCH`](/reference/query-languages/esql/esql-functions-operators.md#esql-match), in a [`WHERE`](/reference/query-languages/esql/esql-commands.md#esql-where) command directly after the [`FROM`](/reference/query-languages/esql/esql-commands.md#esql-from) source command, or close enough to it. Otherwise, the query will fail with a validation error. Another limitation is that any [`WHERE`](/reference/query-languages/esql/esql-commands.md#esql-where) command containing a full-text search function cannot use disjunctions (`OR`), unless:
-
-* All functions used in the OR clauses are full-text functions themselves, or scoring is not used
+[preview] {{esql}}'s support for [full-text search](/reference/query-languages/esql/esql-functions-operators.md#esql-search-functions) is currently in Technical Preview. One limitation of full-text search is that it is necessary to use the search function, like [`MATCH`](/reference/query-languages/esql/esql-functions-operators.md#esql-match), in a [`WHERE`](/reference/query-languages/esql/esql-commands.md#esql-where) command directly after the [`FROM`](/reference/query-languages/esql/esql-commands.md#esql-from) source command, or close enough to it. Otherwise, the query will fail with a validation error.
 
 For example, this query is valid:
 
@@ -136,27 +134,6 @@ FROM books
 | WHERE MATCH(author, "Faulkner")
 ```
-
-And this query that uses a disjunction will succeed:
-
-```esql
-FROM books
-| WHERE MATCH(author, "Faulkner") OR QSTR("author: Hemingway")
-```
-
-However, using scoring will fail because it uses a non-full-text function as part of the disjunction:
-
-```esql
-FROM books METADATA _score
-| WHERE MATCH(author, "Faulkner") OR author LIKE "Hemingway"
-```
-
-Scoring will work in the following query, as it uses full-text functions on both `OR` clauses:
-
-```esql
-FROM books METADATA _score
-| WHERE MATCH(author, "Faulkner") OR QSTR("author: Hemingway")
-```
 
 Note that, because of [the way {{esql}} treats `text` values](#esql-limitations-text-fields), any queries on `text` fields that do not explicitly use the full-text functions, [`MATCH`](/reference/query-languages/esql/esql-functions-operators.md#esql-match), [`QSTR`](/reference/query-languages/esql/esql-functions-operators.md#esql-qstr) or [`KQL`](/reference/query-languages/esql/esql-functions-operators.md#esql-kql), will behave as if the fields are actually `keyword` fields: they are case-sensitive and need to match the full string.
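For illustration, a sketch of the kind of query this change enables (based on the commit summary and the limitation removed above; the `books` index and fields come from the existing examples): a disjunction containing a full-text function can now be used together with scoring.

```esql
FROM books METADATA _score
| WHERE MATCH(author, "Faulkner") OR author LIKE "Hemingway"
| SORT _score DESC
```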
@@ -0,0 +1,403 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.lucene;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.elasticsearch.common.CheckedBiConsumer;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.Vector;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Consumer;

/**
 * Base class for evaluating a Lucene query in the compute engine and providing a {@link Block} as a result.
 * Subclasses can override methods to decide what type of {@link Block} should be returned, and how to add results to it
 * based on whether documents on the Page match the query or not.
 * See {@link LuceneQueryExpressionEvaluator} for an example of how to use this class and {@link LuceneQueryScoreEvaluator} for
 * examples of subclasses that provide different types of scoring results for different ESQL constructs.
 * It's much faster to push queries to the {@link LuceneSourceOperator} or the like, but sometimes this isn't possible. So
 * this class is here to save the day.
 */
public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements Releasable {

    public record ShardConfig(Query query, IndexSearcher searcher) {}

    private final BlockFactory blockFactory;
    private final ShardConfig[] shards;

    private final List<ShardState> perShardState;

    protected LuceneQueryEvaluator(BlockFactory blockFactory, ShardConfig[] shards) {
        this.blockFactory = blockFactory;
        this.shards = shards;
        this.perShardState = new ArrayList<>(Collections.nCopies(shards.length, null));
    }

    public Block executeQuery(Page page) {
        // Lucene based operators retrieve DocVectors as first block
        Block block = page.getBlock(0);
        assert block instanceof DocBlock : "LuceneQueryExpressionEvaluator expects DocBlock as input";
        DocVector docs = (DocVector) block.asVector();
        try {
            if (docs.singleSegmentNonDecreasing()) {
                return evalSingleSegmentNonDecreasing(docs).asBlock();
            } else {
                return evalSlow(docs).asBlock();
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Evaluate {@link DocVector#singleSegmentNonDecreasing()} documents.
     * <p>
     * ESQL receives documents in DocVector, and they can be in one of two
     * states. Either the DocVector contains documents from a single segment
     * in non-decreasing order, or it doesn't. That first case is much more like
     * how Lucene likes to process documents, and it's much more common. So we
     * optimize for it.
     * <p>
     * Vectors that are {@link DocVector#singleSegmentNonDecreasing()}
     * represent many documents from a single Lucene segment. In Elasticsearch
     * terms that's a segment in a single shard. And the document ids are in
     * non-decreasing order. Probably just {@code 0, 1, 2, 3, 4, 5...}.
     * But maybe something like {@code 0, 5, 6, 10, 10, 10}. Both of those are
     * very much like how Lucene "natively" processes documents and this optimizes
     * those accesses.
     * </p>
     * <p>
     * If the documents are literally {@code 0, 1, ... n} then we use
     * {@link BulkScorer#score(LeafCollector, Bits, int, int)} which processes
     * a whole range. This should be quite common in the case where we don't
     * have deleted documents because that's the order that
     * {@link LuceneSourceOperator} produces them.
     * </p>
     * <p>
     * If there are gaps in the sequence we use {@link Scorer} calls to
     * score the sequence. This'll be less fast but isn't going to be particularly
     * common.
     * </p>
     */
    private Vector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException {
        ShardState shardState = shardState(docs.shards().getInt(0));
        SegmentState segmentState = shardState.segmentState(docs.segments().getInt(0));
        int min = docs.docs().getInt(0);
        int max = docs.docs().getInt(docs.getPositionCount() - 1);
        int length = max - min + 1;
        try (T scoreBuilder = createVectorBuilder(blockFactory, length)) {
            if (length == docs.getPositionCount() && length > 1) {
                return segmentState.scoreDense(scoreBuilder, min, max);
            }
            return segmentState.scoreSparse(scoreBuilder, docs.docs());
        }
    }

    /**
     * Evaluate non-{@link DocVector#singleSegmentNonDecreasing()} documents. See
     * {@link #evalSingleSegmentNonDecreasing} for the meaning of
     * {@link DocVector#singleSegmentNonDecreasing()} and how we can efficiently
     * evaluate those segments.
     * <p>
     * This processes the worst case blocks of documents. These can be from any
     * number of shards and any number of segments and in any order. We do this
     * by iterating the docs in {@code shard ASC, segment ASC, doc ASC} order.
     * So, that's segment by segment, docs ascending. We build a boolean block
     * out of that. Then we <strong>sort</strong> that to put the booleans in
     * the order that the {@link DocVector} came in.
     * </p>
     */
    private Vector evalSlow(DocVector docs) throws IOException {
        int[] map = docs.shardSegmentDocMapForwards();
        // Clear any state flags from the previous run
        int prevShard = -1;
        int prevSegment = -1;
        SegmentState segmentState = null;
        try (T scoreBuilder = createVectorBuilder(blockFactory, docs.getPositionCount())) {
            for (int i = 0; i < docs.getPositionCount(); i++) {
                int shard = docs.shards().getInt(docs.shards().getInt(map[i]));
                int segment = docs.segments().getInt(map[i]);
                if (segmentState == null || prevShard != shard || prevSegment != segment) {
                    segmentState = shardState(shard).segmentState(segment);
                    segmentState.initScorer(docs.docs().getInt(map[i]));
                    prevShard = shard;
                    prevSegment = segment;
                }
                if (segmentState.noMatch) {
                    appendNoMatch(scoreBuilder);
                } else {
                    segmentState.scoreSingleDocWithScorer(scoreBuilder, docs.docs().getInt(map[i]));
                }
            }
            try (Vector outOfOrder = scoreBuilder.build()) {
                return outOfOrder.filter(docs.shardSegmentDocMapBackwards());
            }
        }
    }

    @Override
    public void close() {}

    private ShardState shardState(int shard) throws IOException {
        ShardState shardState = perShardState.get(shard);
        if (shardState != null) {
            return shardState;
        }
        shardState = new ShardState(shards[shard]);
        perShardState.set(shard, shardState);
        return shardState;
    }

    /**
     * Contains shard search related information, like the weight and index searcher
     */
    private class ShardState {
        private final Weight weight;
        private final IndexSearcher searcher;
        private final List<SegmentState> perSegmentState;

        ShardState(ShardConfig config) throws IOException {
            weight = config.searcher.createWeight(config.query, scoreMode(), 1.0f);
            searcher = config.searcher;
            perSegmentState = new ArrayList<>(Collections.nCopies(searcher.getLeafContexts().size(), null));
        }

        SegmentState segmentState(int segment) throws IOException {
            SegmentState segmentState = perSegmentState.get(segment);
            if (segmentState != null) {
                return segmentState;
            }
            segmentState = new SegmentState(weight, searcher.getLeafContexts().get(segment));
            perSegmentState.set(segment, segmentState);
            return segmentState;
        }
    }

    /**
     * Contains segment search related information, like the leaf reader context and bulk scorer
     */
    private class SegmentState {
        private final Weight weight;
        private final LeafReaderContext ctx;

        /**
         * Lazily initialized {@link Scorer} for this. {@code null} here means uninitialized
         * <strong>or</strong> that {@link #noMatch} is true.
         */
        private Scorer scorer;

        /**
         * Thread that initialized the {@link #scorer}.
         */
        private Thread scorerThread;

        /**
         * Lazily initialized {@link BulkScorer} for this. {@code null} here means uninitialized
         * <strong>or</strong> that {@link #noMatch} is true.
         */
        private BulkScorer bulkScorer;

        /**
         * Thread that initialized the {@link #bulkScorer}.
         */
        private Thread bulkScorerThread;

        /**
         * Set to {@code true} if, in the process of building a {@link Scorer} or {@link BulkScorer},
         * the {@link Weight} tells us there aren't any matches.
         */
        private boolean noMatch;

        private SegmentState(Weight weight, LeafReaderContext ctx) {
            this.weight = weight;
            this.ctx = ctx;
        }

        /**
         * Score a range using the {@link BulkScorer}. This should be faster
         * than using {@link #scoreSparse} for dense doc ids.
         */
        Vector scoreDense(T scoreBuilder, int min, int max) throws IOException {
            if (noMatch) {
                return createNoMatchVector(blockFactory, max - min + 1);
            }
            if (bulkScorer == null || // The bulkScorer wasn't initialized
                Thread.currentThread() != bulkScorerThread // The bulkScorer was initialized on a different thread
            ) {
                bulkScorerThread = Thread.currentThread();
                bulkScorer = weight.bulkScorer(ctx);
                if (bulkScorer == null) {
                    noMatch = true;
                    return createNoMatchVector(blockFactory, max - min + 1);
                }
            }
            try (
                DenseCollector<T> collector = new DenseCollector<>(
                    min,
                    max,
                    scoreBuilder,
                    LuceneQueryEvaluator.this::appendNoMatch,
                    LuceneQueryEvaluator.this::appendMatch
                )
            ) {
                bulkScorer.score(collector, ctx.reader().getLiveDocs(), min, max + 1);
                return collector.build();
            }
        }

        /**
         * Score a vector of doc ids using {@link Scorer}. If you have a dense range of
         * doc ids it'd be faster to use {@link #scoreDense}.
         */
        Vector scoreSparse(T scoreBuilder, IntVector docs) throws IOException {
            initScorer(docs.getInt(0));
            if (noMatch) {
                return createNoMatchVector(blockFactory, docs.getPositionCount());
            }
            for (int i = 0; i < docs.getPositionCount(); i++) {
                scoreSingleDocWithScorer(scoreBuilder, docs.getInt(i));
            }
            return scoreBuilder.build();
        }

        private void initScorer(int minDocId) throws IOException {
            if (noMatch) {
                return;
            }
            if (scorer == null || // Scorer not initialized
                scorerThread != Thread.currentThread() || // Scorer initialized on a different thread
                scorer.iterator().docID() > minDocId // The previous block came "after" this one
            ) {
                scorerThread = Thread.currentThread();
                scorer = weight.scorer(ctx);
                if (scorer == null) {
                    noMatch = true;
                }
            }
        }

        private void scoreSingleDocWithScorer(T builder, int doc) throws IOException {
            if (scorer.iterator().docID() == doc) {
                appendMatch(builder, scorer);
            } else if (scorer.iterator().docID() > doc) {
                appendNoMatch(builder);
            } else {
                if (scorer.iterator().advance(doc) == doc) {
                    appendMatch(builder, scorer);
                } else {
                    appendNoMatch(builder);
                }
            }
        }
    }

    /**
     * Collects matching information for a dense range of doc ids. This assumes that
     * doc ids are sent to {@link LeafCollector#collect(int)} in ascending order
     * which isn't documented, but @jpountz swears is true.
     */
    static class DenseCollector<U extends Vector.Builder> implements LeafCollector, Releasable {
        private final U scoreBuilder;
        private final int max;
        private final Consumer<U> appendNoMatch;
        private final CheckedBiConsumer<U, Scorable, IOException> appendMatch;

        private Scorable scorer;
        int next;

        DenseCollector(
            int min,
            int max,
            U scoreBuilder,
            Consumer<U> appendNoMatch,
            CheckedBiConsumer<U, Scorable, IOException> appendMatch
        ) {
            this.scoreBuilder = scoreBuilder;
            this.max = max;
            next = min;
            this.appendNoMatch = appendNoMatch;
            this.appendMatch = appendMatch;
        }

        @Override
        public void setScorer(Scorable scorable) {
            this.scorer = scorable;
        }

        @Override
        public void collect(int doc) throws IOException {
            while (next++ < doc) {
                appendNoMatch.accept(scoreBuilder);
            }
            appendMatch.accept(scoreBuilder, scorer);
        }

        public Vector build() {
            return scoreBuilder.build();
        }

        @Override
        public void finish() {
            while (next++ <= max) {
                appendNoMatch.accept(scoreBuilder);
            }
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(scoreBuilder);
        }
    }

    /**
     * Returns the score mode to use on searches
     */
    protected abstract ScoreMode scoreMode();

    /**
     * Creates a vector where all positions correspond to elements that don't match the query
     */
    protected abstract Vector createNoMatchVector(BlockFactory blockFactory, int size);

    /**
     * Creates the corresponding vector builder to store the results of evaluating the query
     */
    protected abstract T createVectorBuilder(BlockFactory blockFactory, int size);

    /**
     * Appends a matching result to a builder created by {@link #createVectorBuilder}
     */
    protected abstract void appendMatch(T builder, Scorable scorer) throws IOException;

    /**
     * Appends a non-matching result to a builder created by {@link #createVectorBuilder}
     */
    protected abstract void appendNoMatch(T builder);

}
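As an illustration of the extension points above, here is a minimal hypothetical subclass, not part of this commit, that records matches as 1/0 in an `IntVector`; `newConstantIntVector` and `newIntVectorFixedBuilder` are assumed to exist on `BlockFactory` by analogy with the boolean and double variants used elsewhere in this diff.

```java
// Hypothetical sketch, not from this commit: a LuceneQueryEvaluator subclass
// that emits 1 for matching documents and 0 otherwise.
class LuceneQueryMatchFlagEvaluator extends LuceneQueryEvaluator<IntVector.Builder> {
    LuceneQueryMatchFlagEvaluator(BlockFactory blockFactory, ShardConfig[] shards) {
        super(blockFactory, shards);
    }

    @Override
    protected ScoreMode scoreMode() {
        // Only matching is needed, so skip score computation entirely.
        return ScoreMode.COMPLETE_NO_SCORES;
    }

    @Override
    protected Vector createNoMatchVector(BlockFactory blockFactory, int size) {
        // Assumed BlockFactory method, by analogy with newConstantBooleanVector.
        return blockFactory.newConstantIntVector(0, size);
    }

    @Override
    protected IntVector.Builder createVectorBuilder(BlockFactory blockFactory, int size) {
        // Assumed BlockFactory method, by analogy with newBooleanVectorFixedBuilder.
        return blockFactory.newIntVectorFixedBuilder(size);
    }

    @Override
    protected void appendMatch(IntVector.Builder builder, Scorable scorer) {
        builder.appendInt(1); // the Scorable is ignored because no scores are needed
    }

    @Override
    protected void appendNoMatch(IntVector.Builder builder) {
        builder.appendInt(0);
    }
}
```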
@@ -7,350 +7,64 @@
 
 package org.elasticsearch.compute.lucene;
 
-import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.search.BulkScorer;
-import org.apache.lucene.search.IndexSearcher;
-import org.apache.lucene.search.LeafCollector;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.Scorable;
 import org.apache.lucene.search.ScoreMode;
-import org.apache.lucene.search.Scorer;
-import org.apache.lucene.search.Weight;
-import org.apache.lucene.util.ArrayUtil;
-import org.apache.lucene.util.Bits;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.BooleanVector;
-import org.elasticsearch.compute.data.DocBlock;
-import org.elasticsearch.compute.data.DocVector;
-import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.data.Vector;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.compute.operator.EvalOperator;
-import org.elasticsearch.core.Releasable;
-import org.elasticsearch.core.Releasables;
 
 import java.io.IOException;
-import java.io.UncheckedIOException;
 
 /**
  * {@link EvalOperator.ExpressionEvaluator} to run a Lucene {@link Query} during
  * the compute engine's normal execution, yielding matches/does not match into
- * a {@link BooleanVector}. It's much faster to push these to the
- * {@link LuceneSourceOperator} or the like, but sometimes this isn't possible. So
- * this evaluator is here to save the day.
+ * a {@link BooleanVector}.
+ * @see LuceneQueryScoreEvaluator
  */
-public class LuceneQueryExpressionEvaluator implements EvalOperator.ExpressionEvaluator {
-    public record ShardConfig(Query query, IndexSearcher searcher) {}
+public class LuceneQueryExpressionEvaluator extends LuceneQueryEvaluator<BooleanVector.Builder>
+    implements
+        EvalOperator.ExpressionEvaluator {
 
-    private final BlockFactory blockFactory;
-    private final ShardConfig[] shards;
-
-    private ShardState[] perShardState = EMPTY_SHARD_STATES;
-
-    public LuceneQueryExpressionEvaluator(BlockFactory blockFactory, ShardConfig[] shards) {
-        this.blockFactory = blockFactory;
-        this.shards = shards;
+    LuceneQueryExpressionEvaluator(BlockFactory blockFactory, ShardConfig[] shards) {
+        super(blockFactory, shards);
     }
 
     @Override
     public Block eval(Page page) {
-        // Lucene based operators retrieve DocVectors as first block
-        Block block = page.getBlock(0);
-        assert block instanceof DocBlock : "LuceneQueryExpressionEvaluator expects DocBlock as input";
-        DocVector docs = (DocVector) block.asVector();
-        try {
-            if (docs.singleSegmentNonDecreasing()) {
-                return evalSingleSegmentNonDecreasing(docs).asBlock();
-            } else {
-                return evalSlow(docs).asBlock();
-            }
-        } catch (IOException e) {
-            throw new UncheckedIOException(e);
-        }
-    }
-
-    /**
-     * Evaluate {@link DocVector#singleSegmentNonDecreasing()} documents.
-     * <p>
-     * ESQL receives documents in DocVector, and they can be in one of two
-     * states. Either the DocVector contains documents from a single segment
-     * non-decreasing order, or it doesn't. that first case is much more like
-     * how Lucene likes to process documents. and it's much more common. So we
-     * optimize for it.
-     * <p>
-     * Vectors that are {@link DocVector#singleSegmentNonDecreasing()}
-     * represent many documents from a single Lucene segment. In Elasticsearch
-     * terms that's a segment in a single shard. And the document ids are in
-     * non-decreasing order. Probably just {@code 0, 1, 2, 3, 4, 5...}.
-     * But maybe something like {@code 0, 5, 6, 10, 10, 10}. Both of those are
-     * very like how lucene "natively" processes documents and this optimizes
-     * those accesses.
-     * </p>
-     * <p>
-     * If the documents are literally {@code 0, 1, ... n} then we use
-     * {@link BulkScorer#score(LeafCollector, Bits, int, int)} which processes
-     * a whole range. This should be quite common in the case where we don't
-     * have deleted documents because that's the order that
-     * {@link LuceneSourceOperator} produces them.
-     * </p>
-     * <p>
-     * If there are gaps in the sequence we use {@link Scorer} calls to
-     * score the sequence. This'll be less fast but isn't going be particularly
-     * common.
-     * </p>
-     */
-    private BooleanVector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException {
-        ShardState shardState = shardState(docs.shards().getInt(0));
-        SegmentState segmentState = shardState.segmentState(docs.segments().getInt(0));
-        int min = docs.docs().getInt(0);
-        int max = docs.docs().getInt(docs.getPositionCount() - 1);
-        int length = max - min + 1;
-        if (length == docs.getPositionCount() && length > 1) {
-            return segmentState.scoreDense(min, max);
-        }
-        return segmentState.scoreSparse(docs.docs());
-    }
-
-    /**
-     * Evaluate non-{@link DocVector#singleSegmentNonDecreasing()} documents. See
-     * {@link #evalSingleSegmentNonDecreasing} for the meaning of
-     * {@link DocVector#singleSegmentNonDecreasing()} and how we can efficiently
-     * evaluate those segments.
-     * <p>
-     * This processes the worst case blocks of documents. These can be from any
-     * number of shards and any number of segments and in any order. We do this
-     * by iterating the docs in {@code shard ASC, segment ASC, doc ASC} order.
-     * So, that's segment by segment, docs ascending. We build a boolean block
-     * out of that. Then we <strong>sort</strong> that to put the booleans in
-     * the order that the {@link DocVector} came in.
-     * </p>
-     */
-    private BooleanVector evalSlow(DocVector docs) throws IOException {
-        int[] map = docs.shardSegmentDocMapForwards();
-        // Clear any state flags from the previous run
-        int prevShard = -1;
-        int prevSegment = -1;
-        SegmentState segmentState = null;
-        try (BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount())) {
-            for (int i = 0; i < docs.getPositionCount(); i++) {
-                int shard = docs.shards().getInt(docs.shards().getInt(map[i]));
-                int segment = docs.segments().getInt(map[i]);
-                if (segmentState == null || prevShard != shard || prevSegment != segment) {
-                    segmentState = shardState(shard).segmentState(segment);
-                    segmentState.initScorer(docs.docs().getInt(map[i]));
-                    prevShard = shard;
-                    prevSegment = segment;
-                }
-                if (segmentState.noMatch) {
-                    builder.appendBoolean(false);
-                } else {
-                    segmentState.scoreSingleDocWithScorer(builder, docs.docs().getInt(map[i]));
-                }
-            }
-            try (BooleanVector outOfOrder = builder.build()) {
-                return outOfOrder.filter(docs.shardSegmentDocMapBackwards());
-            }
-        }
+        return executeQuery(page);
     }
 
     @Override
-    public void close() {
-
+    protected ScoreMode scoreMode() {
+        return ScoreMode.COMPLETE_NO_SCORES;
     }
 
-    private ShardState shardState(int shard) throws IOException {
-        if (shard >= perShardState.length) {
-            perShardState = ArrayUtil.grow(perShardState, shard + 1);
-        } else if (perShardState[shard] != null) {
-            return perShardState[shard];
-        }
-        perShardState[shard] = new ShardState(shards[shard]);
-        return perShardState[shard];
+    @Override
+    protected Vector createNoMatchVector(BlockFactory blockFactory, int size) {
+        return blockFactory.newConstantBooleanVector(false, size);
     }
 
-    private class ShardState {
-        private final Weight weight;
-        private final IndexSearcher searcher;
-        private SegmentState[] perSegmentState = EMPTY_SEGMENT_STATES;
-
-        ShardState(ShardConfig config) throws IOException {
-            weight = config.searcher.createWeight(config.query, ScoreMode.COMPLETE_NO_SCORES, 0.0f);
-            searcher = config.searcher;
-        }
-
-        SegmentState segmentState(int segment) throws IOException {
-            if (segment >= perSegmentState.length) {
-                perSegmentState = ArrayUtil.grow(perSegmentState, segment + 1);
-            } else if (perSegmentState[segment] != null) {
-                return perSegmentState[segment];
-            }
-            perSegmentState[segment] = new SegmentState(weight, searcher.getLeafContexts().get(segment));
-            return perSegmentState[segment];
-        }
+    @Override
+    protected BooleanVector.Builder createVectorBuilder(BlockFactory blockFactory, int size) {
+        return blockFactory.newBooleanVectorFixedBuilder(size);
     }
 
-    private class SegmentState {
-        private final Weight weight;
-        private final LeafReaderContext ctx;
-
-        /**
-         * Lazily initialed {@link Scorer} for this. {@code null} here means uninitialized
-         * <strong>or</strong> that {@link #noMatch} is true.
-         */
-        private Scorer scorer;
-
-        /**
-         * Thread that initialized the {@link #scorer}.
-         */
-        private Thread scorerThread;
-
-        /**
-         * Lazily initialed {@link BulkScorer} for this. {@code null} here means uninitialized
-         * <strong>or</strong> that {@link #noMatch} is true.
-         */
-        private BulkScorer bulkScorer;
-
-        /**
-         * Thread that initialized the {@link #bulkScorer}.
-         */
-        private Thread bulkScorerThread;
-
-        /**
-         * Set to {@code true} if, in the process of building a {@link Scorer} or {@link BulkScorer},
-         * the {@link Weight} tells us there aren't any matches.
-         */
-        private boolean noMatch;
-
-        private SegmentState(Weight weight, LeafReaderContext ctx) {
-            this.weight = weight;
-            this.ctx = ctx;
-        }
-
-        /**
-         * Score a range using the {@link BulkScorer}. This should be faster
-         * than using {@link #scoreSparse} for dense doc ids.
-         */
-        BooleanVector scoreDense(int min, int max) throws IOException {
-            int length = max - min + 1;
-            if (noMatch) {
-                return blockFactory.newConstantBooleanVector(false, length);
-            }
-            if (bulkScorer == null || // The bulkScorer wasn't initialized
-                Thread.currentThread() != bulkScorerThread // The bulkScorer was initialized on a different thread
-            ) {
-                bulkScorerThread = Thread.currentThread();
-                bulkScorer = weight.bulkScorer(ctx);
-                if (bulkScorer == null) {
-                    noMatch = true;
-                    return blockFactory.newConstantBooleanVector(false, length);
-                }
-            }
-            try (DenseCollector collector = new DenseCollector(blockFactory, min, max)) {
-                bulkScorer.score(collector, ctx.reader().getLiveDocs(), min, max + 1);
-                return collector.build();
-            }
-        }
-
-        /**
-         * Score a vector of doc ids using {@link Scorer}. If you have a dense range of
-         * doc ids it'd be faster to use {@link #scoreDense}.
-         */
-        BooleanVector scoreSparse(IntVector docs) throws IOException {
-            initScorer(docs.getInt(0));
-            if (noMatch) {
-                return blockFactory.newConstantBooleanVector(false, docs.getPositionCount());
-            }
-            try (BooleanVector.Builder builder = blockFactory.newBooleanVectorFixedBuilder(docs.getPositionCount())) {
-                for (int i = 0; i < docs.getPositionCount(); i++) {
-                    scoreSingleDocWithScorer(builder, docs.getInt(i));
-                }
-                return builder.build();
-            }
-        }
-
-        private void initScorer(int minDocId) throws IOException {
-            if (noMatch) {
-                return;
-            }
-            if (scorer == null || // Scorer not initialized
-                scorerThread != Thread.currentThread() || // Scorer initialized on a different thread
-                scorer.iterator().docID() > minDocId // The previous block came "after" this one
-            ) {
-                scorerThread = Thread.currentThread();
-                scorer = weight.scorer(ctx);
-                if (scorer == null) {
-                    noMatch = true;
-                }
-            }
-        }
-
-        private void scoreSingleDocWithScorer(BooleanVector.Builder builder, int doc) throws IOException {
-            if (scorer.iterator().docID() == doc) {
-                builder.appendBoolean(true);
-            } else if (scorer.iterator().docID() > doc) {
-                builder.appendBoolean(false);
-            } else {
-                builder.appendBoolean(scorer.iterator().advance(doc) == doc);
-            }
-        }
-    }
+    @Override
+    protected void appendNoMatch(BooleanVector.Builder builder) {
+        builder.appendBoolean(false);
+    }
 
-    private static final ShardState[] EMPTY_SHARD_STATES = new ShardState[0];
-    private static final SegmentState[] EMPTY_SEGMENT_STATES = new SegmentState[0];
-
-    /**
-     * Collects matching information for dense range of doc ids. This assumes that
-     * doc ids are sent to {@link LeafCollector#collect(int)} in ascending order
-     * which isn't documented, but @jpountz swears is true.
-     */
-    static class DenseCollector implements LeafCollector, Releasable {
-        private final BooleanVector.FixedBuilder builder;
-        private final int max;
-
-        int next;
-
-        DenseCollector(BlockFactory blockFactory, int min, int max) {
-            this.builder = blockFactory.newBooleanVectorFixedBuilder(max - min + 1);
-            this.max = max;
-            next = min;
-        }
-
-        @Override
-        public void setScorer(Scorable scorable) {}
-
-        @Override
-        public void collect(int doc) {
-            while (next++ < doc) {
-                builder.appendBoolean(false);
-            }
-            builder.appendBoolean(true);
-        }
-
-        public BooleanVector build() {
-            return builder.build();
-        }
-
-        @Override
-        public void finish() {
-            while (next++ <= max) {
-                builder.appendBoolean(false);
-            }
-        }
-
-        @Override
-        public void close() {
-            Releasables.closeExpectNoException(builder);
-        }
-    }
+    @Override
+    protected void appendMatch(BooleanVector.Builder builder, Scorable scorer) throws IOException {
+        builder.appendBoolean(true);
+    }
 
-    public static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
-        private final ShardConfig[] shardConfigs;
-
-        public Factory(ShardConfig[] shardConfigs) {
-            this.shardConfigs = shardConfigs;
-        }
-
+    public record Factory(ShardConfig[] shardConfigs) implements EvalOperator.ExpressionEvaluator.Factory {
         @Override
         public EvalOperator.ExpressionEvaluator get(DriverContext context) {
             return new LuceneQueryExpressionEvaluator(context.blockFactory(), shardConfigs);
         }
     }
 }
@@ -0,0 +1,74 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.lucene;

import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.Vector;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.ScoreOperator;

import java.io.IOException;

/**
 * {@link ScoreOperator.ExpressionScorer} to run a Lucene {@link Query} during
 * the compute engine's normal execution, yielding the corresponding scores into
 * a {@link DoubleVector}.
 * Elements that don't match will have a score of {@link #NO_MATCH_SCORE}.
 * @see LuceneQueryExpressionEvaluator
 */
public class LuceneQueryScoreEvaluator extends LuceneQueryEvaluator<DoubleVector.Builder> implements ScoreOperator.ExpressionScorer {

    public static final double NO_MATCH_SCORE = 0.0;

    LuceneQueryScoreEvaluator(BlockFactory blockFactory, ShardConfig[] shards) {
        super(blockFactory, shards);
    }

    @Override
    public DoubleBlock score(Page page) {
        return (DoubleBlock) executeQuery(page);
    }

    @Override
    protected ScoreMode scoreMode() {
        return ScoreMode.COMPLETE;
    }

    @Override
    protected Vector createNoMatchVector(BlockFactory blockFactory, int size) {
        return blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, size);
    }

    @Override
    protected DoubleVector.Builder createVectorBuilder(BlockFactory blockFactory, int size) {
        return blockFactory.newDoubleVectorFixedBuilder(size);
    }

    @Override
    protected void appendNoMatch(DoubleVector.Builder builder) {
        builder.appendDouble(NO_MATCH_SCORE);
    }

    @Override
    protected void appendMatch(DoubleVector.Builder builder, Scorable scorer) throws IOException {
        builder.appendDouble(scorer.score());
    }

    public record Factory(ShardConfig[] shardConfigs) implements ScoreOperator.ExpressionScorer.Factory {
        @Override
        public ScoreOperator.ExpressionScorer get(DriverContext context) {
            return new LuceneQueryScoreEvaluator(context.blockFactory(), shardConfigs);
        }
    }
}
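A hedged usage sketch (the `searcher`, `query`, `driverContext`, and `page` variables are assumed to exist; the wiring mirrors the test code later in this diff): the factory is handed one query/searcher pair per shard and produces a scorer whose output is `0.0` for non-matching rows.

```java
// Sketch: wiring a LuceneQueryScoreEvaluator. `searcher`, `query`,
// `driverContext`, and `page` are assumed to exist in the surrounding code.
LuceneQueryEvaluator.ShardConfig[] shards = new LuceneQueryEvaluator.ShardConfig[] {
    new LuceneQueryEvaluator.ShardConfig(searcher.rewrite(query), searcher) };
ScoreOperator.ExpressionScorer scorer = new LuceneQueryScoreEvaluator.Factory(shards).get(driverContext);
try (DoubleBlock scores = scorer.score(page)) {
    // scores.getDouble(i) is the Lucene score for row i,
    // or NO_MATCH_SCORE (0.0) when the row doesn't match the query.
}
```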
@@ -0,0 +1,102 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.operator;

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;

/**
 * Evaluates scores for an {@link ExpressionScorer}. The scores are added to the existing scores in the input page.
 */
public class ScoreOperator extends AbstractPageMappingOperator {

    public record ScoreOperatorFactory(ExpressionScorer.Factory scorerFactory, int scoreBlockPosition) implements OperatorFactory {

        @Override
        public Operator get(DriverContext driverContext) {
            return new ScoreOperator(driverContext.blockFactory(), scorerFactory.get(driverContext), scoreBlockPosition);
        }

        @Override
        public String describe() {
            return "ScoreOperator[scorer=" + scorerFactory + "]";
        }
    }

    private final BlockFactory blockFactory;
    private final ExpressionScorer scorer;
    private final int scoreBlockPosition;

    public ScoreOperator(BlockFactory blockFactory, ExpressionScorer scorer, int scoreBlockPosition) {
        this.blockFactory = blockFactory;
        this.scorer = scorer;
        this.scoreBlockPosition = scoreBlockPosition;
    }

    @Override
    protected Page process(Page page) {
        assert page.getBlockCount() >= 2 : "Expected at least 2 blocks, got " + page.getBlockCount();
        assert page.getBlock(0).asVector() instanceof DocVector : "Expected a DocVector, got " + page.getBlock(0).asVector();
        assert page.getBlock(1).asVector() instanceof DoubleVector : "Expected a DoubleVector, got " + page.getBlock(1).asVector();

        Block[] blocks = new Block[page.getBlockCount()];
        for (int i = 0; i < page.getBlockCount(); i++) {
            if (i == scoreBlockPosition) {
                blocks[i] = calculateScoresBlock(page);
            } else {
                blocks[i] = page.getBlock(i);
            }
        }

        return new Page(blocks);
    }

    private Block calculateScoresBlock(Page page) {
        try (DoubleBlock evalScores = scorer.score(page); DoubleBlock existingScores = page.getBlock(scoreBlockPosition)) {
            // TODO Optimize for constant scores?
            int rowCount = page.getPositionCount();
            DoubleVector.Builder builder = blockFactory.newDoubleVectorFixedBuilder(rowCount);
            for (int i = 0; i < rowCount; i++) {
                builder.appendDouble(existingScores.getDouble(i) + evalScores.getDouble(i));
            }
            return builder.build().asBlock();
        }
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() + "[scorer=" + scorer + "]";
    }

    @Override
    public void close() {
        Releasables.closeExpectNoException(scorer, super::close);
    }

    /**
     * Evaluates the score of an expression one {@link Page} at a time.
     */
    public interface ExpressionScorer extends Releasable {
        /** A Factory for creating ExpressionScorers. */
        interface Factory {
            ExpressionScorer get(DriverContext context);
        }

        /**
         * Scores the expression.
         * @return the returned Block has its own reference and the caller is responsible for releasing it.
         */
        DoubleBlock score(Page page);
    }
}
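To make the summing in `calculateScoresBlock` concrete, here is a small standalone sketch with invented values; plain arrays stand in for the `DoubleBlock`s.

```java
double[] existingScores = { 1.2, 0.0, 3.4 }; // scores already on the page
double[] evalScores = { 0.5, 0.0, 0.0 };     // scores produced by the ExpressionScorer
double[] combined = new double[existingScores.length];
for (int i = 0; i < combined.length; i++) {
    // Rows the query doesn't match contribute NO_MATCH_SCORE (0.0),
    // so their existing score is unchanged.
    combined[i] = existingScores[i] + evalScores[i];
}
// combined == { 1.7, 0.0, 3.4 }
```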
@@ -0,0 +1,302 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.lucene;

import org.apache.lucene.document.Field;
import org.apache.lucene.document.KeywordField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MultiTermQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.TermInSetQuery;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.store.BaseDirectoryWrapper;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.OperatorTests;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.Vector;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.ShuffleDocsOperator;
import org.elasticsearch.compute.test.ComputeTestCase;
import org.elasticsearch.compute.test.OperatorTestCase;
import org.elasticsearch.compute.test.TestDriverFactory;
import org.elasticsearch.compute.test.TestResultPageSinkOperator;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.index.mapper.BlockDocValuesReader;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

import static org.elasticsearch.compute.test.OperatorTestCase.randomPageSize;
import static org.hamcrest.Matchers.equalTo;

/**
 * Base class for testing Lucene query evaluators.
 */
public abstract class LuceneQueryEvaluatorTests<T extends Vector, U extends Vector.Builder> extends ComputeTestCase {

    private static final String FIELD = "g";

    @SuppressWarnings("unchecked")
    public void testDenseCollectorSmall() throws IOException {
        try (LuceneQueryEvaluator.DenseCollector<U> collector = createDenseCollector(0, 2)) {
            collector.setScorer(getScorer());
            collector.collect(0);
            collector.collect(1);
            collector.collect(2);
            collector.finish();
            try (T result = (T) collector.build()) {
                for (int i = 0; i <= 2; i++) {
                    assertCollectedResultMatch(result, i, true);
                }
            }
        }
    }

    @SuppressWarnings("unchecked")
    public void testDenseCollectorSimple() throws IOException {
        try (LuceneQueryEvaluator.DenseCollector<U> collector = createDenseCollector(0, 10)) {
            collector.setScorer(getScorer());
            collector.collect(2);
            collector.collect(5);
            collector.finish();
            try (T result = (T) collector.build()) {
                for (int i = 0; i < 11; i++) {
                    assertCollectedResultMatch(result, i, i == 2 || i == 5);
                }
            }
        }
    }

    @SuppressWarnings("unchecked")
    public void testDenseCollector() throws IOException {
        int length = between(1, 10_000);
        int min = between(0, Integer.MAX_VALUE - length - 1);
        int max = min + length;
        boolean[] expected = new boolean[length];
        try (LuceneQueryEvaluator.DenseCollector<U> collector = createDenseCollector(min, max)) {
            collector.setScorer(getScorer());
            for (int i = 0; i < length; i++) {
                expected[i] = randomBoolean();
                if (expected[i]) {
                    collector.collect(min + i);
                }
            }
            collector.finish();
            try (T result = (T) collector.build()) {
                for (int i = 0; i < length; i++) {
                    assertCollectedResultMatch(result, i, expected[i]);
                }
            }
        }
    }

    /**
     * Create a dense collector for the given range.
     */
    protected abstract LuceneQueryEvaluator.DenseCollector<U> createDenseCollector(int min, int max);

    /**
     * Checks that the collected result at the given position corresponds to a match or no match
     */
    protected abstract void assertCollectedResultMatch(T resultVector, int position, boolean isMatch);

    public void testTermQuery() throws IOException {
        Set<String> values = values();
        String term = values.iterator().next();
        List<Page> results = runQuery(values, new TermQuery(new Term(FIELD, term)), false);
        assertTermsQuery(results, Set.of(term), 1);
    }

    public void testTermQueryShuffled() throws IOException {
        Set<String> values = values();
        String term = values.iterator().next();
        List<Page> results = runQuery(values, new ConstantScoreQuery(new TermQuery(new Term(FIELD, term))), true);
        assertTermsQuery(results, Set.of(term), 1);
    }

    public void testTermsQuery() throws IOException {
        testTermsQuery(false);
    }

    public void testTermsQueryShuffled() throws IOException {
        testTermsQuery(true);
    }

    private void testTermsQuery(boolean shuffleDocs) throws IOException {
        Set<String> values = values();
        Iterator<String> itr = values.iterator();
        TreeSet<String> matching = new TreeSet<>();
        TreeSet<BytesRef> matchingBytes = new TreeSet<>();
        int expectedMatchCount = between(2, values.size());
        for (int i = 0; i < expectedMatchCount; i++) {
            String v = itr.next();
            matching.add(v);
            matchingBytes.add(new BytesRef(v));
        }
        List<Page> results = runQuery(values, new TermInSetQuery(MultiTermQuery.CONSTANT_SCORE_REWRITE, FIELD, matchingBytes), shuffleDocs);
        assertTermsQuery(results, matching, expectedMatchCount);
    }

    protected void assertTermsQuery(List<Page> results, Set<String> matching, int expectedMatchCount) {
        int matchCount = 0;
        for (Page page : results) {
            int initialBlockIndex = termsBlockIndex(page);
            BytesRefVector terms = page.<BytesRefBlock>getBlock(initialBlockIndex).asVector();
            @SuppressWarnings("unchecked")
            T resultVector = (T) page.getBlock(resultsBlockIndex(page)).asVector();
            for (int i = 0; i < page.getPositionCount(); i++) {
                BytesRef termAtPosition = terms.getBytesRef(i, new BytesRef());
                boolean isMatch = matching.contains(termAtPosition.utf8ToString());
                assertTermResultMatch(resultVector, i, isMatch);
                if (isMatch) {
                    matchCount++;
                }
            }
        }
        assertThat(matchCount, equalTo(expectedMatchCount));
    }

    /**
     * Checks that the result at the given position corresponds to a term match or no match
     */
    protected abstract void assertTermResultMatch(T resultVector, int position, boolean isMatch);

    private List<Page> runQuery(Set<String> values, Query query, boolean shuffleDocs) throws IOException {
        DriverContext driverContext = driverContext();
        BlockFactory blockFactory = driverContext.blockFactory();
        return withReader(values, reader -> {
            IndexSearcher searcher = new IndexSearcher(reader);
            LuceneQueryEvaluator.ShardConfig shard = new LuceneQueryEvaluator.ShardConfig(searcher.rewrite(query), searcher);
            List<Operator> operators = new ArrayList<>();
            if (shuffleDocs) {
                operators.add(new ShuffleDocsOperator(blockFactory));
            }
            operators.add(
                new ValuesSourceReaderOperator(
                    blockFactory,
                    List.of(
                        new ValuesSourceReaderOperator.FieldInfo(
                            FIELD,
                            ElementType.BYTES_REF,
                            unused -> new BlockDocValuesReader.BytesRefsFromOrdsBlockLoader(FIELD)
                        )
                    ),
                    List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> {
                        throw new UnsupportedOperationException();
                    })),
                    0
                )
            );
            LuceneQueryEvaluator.ShardConfig[] shards = new LuceneQueryEvaluator.ShardConfig[] {
                new LuceneQueryEvaluator.ShardConfig(searcher.rewrite(query), searcher) };
            operators.add(createOperator(blockFactory, shards));
            List<Page> results = new ArrayList<>();
            Driver driver = TestDriverFactory.create(
                driverContext,
                LuceneQueryEvaluatorTests.luceneOperatorFactory(reader, new MatchAllDocsQuery(), usesScoring()).get(driverContext),
                operators,
                new TestResultPageSinkOperator(results::add)
            );
            OperatorTestCase.runDriver(driver);
            OperatorTests.assertDriverContext(driverContext);
            return results;
        });
    }

    private <T> T withReader(Set<String> values, CheckedFunction<DirectoryReader, T, IOException> run) throws IOException {
        try (BaseDirectoryWrapper dir = newDirectory(); RandomIndexWriter writer = new RandomIndexWriter(random(), dir)) {
            for (String value : values) {
                writer.addDocument(List.of(new KeywordField(FIELD, value, Field.Store.NO)));
            }
            writer.commit();
            try (DirectoryReader reader = writer.getReader()) {
                return run.apply(reader);
            }
        }
    }

    private Set<String> values() {
        int maxNumDocs = between(10, 1_000);
        int keyLength = randomIntBetween(1, 10);
        Set<String> values = new HashSet<>();
        for (int i = 0; i < maxNumDocs; i++) {
            values.add(randomAlphaOfLength(keyLength));
        }
        return values;
    }

    /**
     * A {@link DriverContext} with a non-breaking BigArrays.
     */
    private DriverContext driverContext() {
        BlockFactory blockFactory = blockFactory();
        return new DriverContext(blockFactory.bigArrays(), blockFactory);
    }

    // Returns the initial block index, ignoring the score block if scoring is enabled
    protected int termsBlockIndex(Page page) {
        assert page.getBlock(0) instanceof DocBlock : "expected doc block at index 0";
        if (usesScoring()) {
            assert page.getBlock(1) instanceof DoubleBlock : "expected double block at index 1";
            return 2;
        } else {
            return 1;
        }
    }

    private static LuceneOperator.Factory luceneOperatorFactory(IndexReader reader, Query query, boolean scoring) {
        final ShardContext searchContext = new LuceneSourceOperatorTests.MockShardContext(reader, 0);
        return new LuceneSourceOperator.Factory(
            List.of(searchContext),
            ctx -> query,
            randomFrom(DataPartitioning.values()),
            randomIntBetween(1, 10),
            randomPageSize(),
            LuceneOperator.NO_LIMIT,
            scoring
        );
    }

    // Returns the block index for the results to check
    protected abstract int resultsBlockIndex(Page page);

    /**
     * Returns a test scorer to use for scoring docs. May be {@code null}.
     */
    protected abstract Scorable getScorer();

    /**
     * Create the operator to test
     */
    protected abstract Operator createOperator(BlockFactory blockFactory, LuceneQueryEvaluator.ShardConfig[] shards);

    /**
     * Should the test use scoring?
     */
    protected abstract boolean usesScoring();
}
@ -7,275 +7,59 @@

package org.elasticsearch.compute.lucene;

import org.apache.lucene.document.Field;
import org.apache.lucene.document.KeywordField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MultiTermQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermInSetQuery;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.store.BaseDirectoryWrapper;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.OperatorTests;
import org.apache.lucene.search.Scorable;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.BooleanVector;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.lucene.LuceneQueryExpressionEvaluator.DenseCollector;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.lucene.LuceneQueryEvaluator.DenseCollector;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.ShuffleDocsOperator;
import org.elasticsearch.compute.test.ComputeTestCase;
import org.elasticsearch.compute.test.OperatorTestCase;
import org.elasticsearch.compute.test.TestDriverFactory;
import org.elasticsearch.compute.test.TestResultPageSinkOperator;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.index.mapper.BlockDocValuesReader;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

import static org.elasticsearch.compute.test.OperatorTestCase.randomPageSize;
import static org.hamcrest.Matchers.equalTo;

public class LuceneQueryExpressionEvaluatorTests extends ComputeTestCase {
    private static final String FIELD = "g";
public class LuceneQueryExpressionEvaluatorTests extends LuceneQueryEvaluatorTests<BooleanVector, BooleanVector.Builder> {

    public void testDenseCollectorSmall() {
        try (DenseCollector collector = new DenseCollector(blockFactory(), 0, 2)) {
            collector.collect(0);
            collector.collect(1);
            collector.collect(2);
            collector.finish();
            try (BooleanVector result = collector.build()) {
                for (int i = 0; i <= 2; i++) {
                    assertThat(result.getBoolean(i), equalTo(true));
                }
            }
        }
    }
    private final boolean useScoring = randomBoolean();

    public void testDenseCollectorSimple() {
        try (DenseCollector collector = new DenseCollector(blockFactory(), 0, 10)) {
            collector.collect(2);
            collector.collect(5);
            collector.finish();
            try (BooleanVector result = collector.build()) {
                for (int i = 0; i < 11; i++) {
                    assertThat(result.getBoolean(i), equalTo(i == 2 || i == 5));
                }
            }
        }
    }

    public void testDenseCollector() {
        int length = between(1, 10_000);
        int min = between(0, Integer.MAX_VALUE - length - 1);
        int max = min + length + 1;
        boolean[] expected = new boolean[length];
        try (DenseCollector collector = new DenseCollector(blockFactory(), min, max)) {
            for (int i = 0; i < length; i++) {
                expected[i] = randomBoolean();
                if (expected[i]) {
                    collector.collect(min + i);
                }
            }
            collector.finish();
            try (BooleanVector result = collector.build()) {
                for (int i = 0; i < length; i++) {
                    assertThat(result.getBoolean(i), equalTo(expected[i]));
                }
            }
        }
    }

    public void testTermQuery() throws IOException {
        Set<String> values = values();
        String term = values.iterator().next();
        List<Page> results = runQuery(values, new TermQuery(new Term(FIELD, term)), false);
        assertTermQuery(term, results);
    }

    public void testTermQueryShuffled() throws IOException {
        Set<String> values = values();
        String term = values.iterator().next();
        List<Page> results = runQuery(values, new TermQuery(new Term(FIELD, term)), true);
        assertTermQuery(term, results);
    }

    private void assertTermQuery(String term, List<Page> results) {
        int matchCount = 0;
        for (Page page : results) {
            int initialBlockIndex = initialBlockIndex(page);
            BytesRefVector terms = page.<BytesRefBlock>getBlock(initialBlockIndex).asVector();
            BooleanVector matches = page.<BooleanBlock>getBlock(initialBlockIndex + 1).asVector();
            for (int i = 0; i < page.getPositionCount(); i++) {
                BytesRef termAtPosition = terms.getBytesRef(i, new BytesRef());
                assertThat(matches.getBoolean(i), equalTo(termAtPosition.utf8ToString().equals(term)));
                if (matches.getBoolean(i)) {
                    matchCount++;
                }
            }
        }
        assertThat(matchCount, equalTo(1));
    }

    public void testTermsQuery() throws IOException {
        testTermsQuery(false);
    }

    public void testTermsQueryShuffled() throws IOException {
        testTermsQuery(true);
    }

    private void testTermsQuery(boolean shuffleDocs) throws IOException {
        Set<String> values = values();
        Iterator<String> itr = values.iterator();
        TreeSet<String> matching = new TreeSet<>();
        TreeSet<BytesRef> matchingBytes = new TreeSet<>();
        int expectedMatchCount = between(2, values.size());
        for (int i = 0; i < expectedMatchCount; i++) {
            String v = itr.next();
            matching.add(v);
            matchingBytes.add(new BytesRef(v));
        }
        List<Page> results = runQuery(values, new TermInSetQuery(MultiTermQuery.CONSTANT_SCORE_REWRITE, FIELD, matchingBytes), shuffleDocs);
        int matchCount = 0;
        for (Page page : results) {
            int initialBlockIndex = initialBlockIndex(page);
            BytesRefVector terms = page.<BytesRefBlock>getBlock(initialBlockIndex).asVector();
            BooleanVector matches = page.<BooleanBlock>getBlock(initialBlockIndex + 1).asVector();
            for (int i = 0; i < page.getPositionCount(); i++) {
                BytesRef termAtPosition = terms.getBytesRef(i, new BytesRef());
                assertThat(matches.getBoolean(i), equalTo(matching.contains(termAtPosition.utf8ToString())));
                if (matches.getBoolean(i)) {
                    matchCount++;
                }
            }
        }
        assertThat(matchCount, equalTo(expectedMatchCount));
    }

    private List<Page> runQuery(Set<String> values, Query query, boolean shuffleDocs) throws IOException {
        DriverContext driverContext = driverContext();
        BlockFactory blockFactory = driverContext.blockFactory();
        return withReader(values, reader -> {
            IndexSearcher searcher = new IndexSearcher(reader);
            LuceneQueryExpressionEvaluator.ShardConfig shard = new LuceneQueryExpressionEvaluator.ShardConfig(
                searcher.rewrite(query),
                searcher
            );
            LuceneQueryExpressionEvaluator luceneQueryEvaluator = new LuceneQueryExpressionEvaluator(
                blockFactory,
                new LuceneQueryExpressionEvaluator.ShardConfig[] { shard }
            );

            List<Operator> operators = new ArrayList<>();
            if (shuffleDocs) {
                operators.add(new ShuffleDocsOperator(blockFactory));
            }
            operators.add(
                new ValuesSourceReaderOperator(
                    blockFactory,
                    List.of(
                        new ValuesSourceReaderOperator.FieldInfo(
                            FIELD,
                            ElementType.BYTES_REF,
                            unused -> new BlockDocValuesReader.BytesRefsFromOrdsBlockLoader(FIELD)
                        )
                    ),
                    List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> {
                        throw new UnsupportedOperationException();
                    })),
                    0
                )
            );
            operators.add(new EvalOperator(blockFactory, luceneQueryEvaluator));
            List<Page> results = new ArrayList<>();
            Driver driver = TestDriverFactory.create(
                driverContext,
                luceneOperatorFactory(reader, new MatchAllDocsQuery(), LuceneOperator.NO_LIMIT, scoring).get(driverContext),
                operators,
                new TestResultPageSinkOperator(results::add)
            );
            OperatorTestCase.runDriver(driver);
            OperatorTests.assertDriverContext(driverContext);
            return results;
        });
    }

    private <T> T withReader(Set<String> values, CheckedFunction<DirectoryReader, T, IOException> run) throws IOException {
        try (BaseDirectoryWrapper dir = newDirectory(); RandomIndexWriter writer = new RandomIndexWriter(random(), dir)) {
            for (String value : values) {
                writer.addDocument(List.of(new KeywordField(FIELD, value, Field.Store.NO)));
            }
            writer.commit();
            try (DirectoryReader reader = writer.getReader()) {
                return run.apply(reader);
            }
        }
    }

    private Set<String> values() {
        int maxNumDocs = between(10, 1_000);
        int keyLength = randomIntBetween(1, 10);
        Set<String> values = new HashSet<>();
        for (int i = 0; i < maxNumDocs; i++) {
            values.add(randomAlphaOfLength(keyLength));
        }
        return values;
    }

    /**
     * A {@link DriverContext} with a non-breaking-BigArrays.
     */
    private DriverContext driverContext() {
        BlockFactory blockFactory = blockFactory();
        return new DriverContext(blockFactory.bigArrays(), blockFactory);
    }

    // Scores are not interesting to this test; they are enabled randomly just for coverage and effectively ignored.
    private final boolean scoring = randomBoolean();

    // Returns the initial block index, ignoring the score block if scoring is enabled
    private int initialBlockIndex(Page page) {
        assert page.getBlock(0) instanceof DocBlock : "expected doc block at index 0";
        if (scoring) {
            assert page.getBlock(1) instanceof DoubleBlock : "expected double block at index 1";
            return 2;
        } else {
            return 1;
        }
    }

    static LuceneOperator.Factory luceneOperatorFactory(IndexReader reader, Query query, int limit, boolean scoring) {
        final ShardContext searchContext = new LuceneSourceOperatorTests.MockShardContext(reader, 0);
        return new LuceneSourceOperator.Factory(
            List.of(searchContext),
            ctx -> query,
            randomFrom(DataPartitioning.values()),
            randomIntBetween(1, 10),
            randomPageSize(),
            limit,
            scoring
    @Override
    protected DenseCollector<BooleanVector.Builder> createDenseCollector(int min, int max) {
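        // The boolean evaluator only records whether each position matched; the score from the Scorable (the `s` parameter below) is ignored.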
        return new LuceneQueryEvaluator.DenseCollector<>(
            min,
            max,
            blockFactory().newBooleanVectorFixedBuilder(max - min + 1),
            b -> b.appendBoolean(false),
            (b, s) -> b.appendBoolean(true)
        );
    }

    @Override
    protected Scorable getScorer() {
        return null;
    }

    @Override
    protected Operator createOperator(BlockFactory blockFactory, LuceneQueryEvaluator.ShardConfig[] shards) {
        return new EvalOperator(blockFactory, new LuceneQueryExpressionEvaluator(blockFactory, shards));
    }

    @Override
    protected boolean usesScoring() {
        // Be consistent for a single test execution
        return useScoring;
    }

    @Override
    protected int resultsBlockIndex(Page page) {
        return page.getBlockCount() - 1;
    }

    @Override
    protected void assertCollectedResultMatch(BooleanVector resultVector, int position, boolean isMatch) {
        assertThat(resultVector.getBoolean(position), equalTo(isMatch));
    }

    @Override
    protected void assertTermResultMatch(BooleanVector resultVector, int position, boolean isMatch) {
        assertThat(resultVector.getBoolean(position), equalTo(isMatch));
    }
}

@ -0,0 +1,84 @@

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.compute.lucene;

import org.apache.lucene.search.Scorable;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.ScoreOperator;

import java.io.IOException;

import static org.elasticsearch.compute.lucene.LuceneQueryScoreEvaluator.NO_MATCH_SCORE;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;

public class LuceneQueryScoreEvaluatorTests extends LuceneQueryEvaluatorTests<DoubleVector, DoubleVector.Builder> {

    private static final float TEST_SCORE = 1.5f;
    private static final Double DEFAULT_SCORE = 1.0;

    @Override
    protected LuceneQueryEvaluator.DenseCollector<DoubleVector.Builder> createDenseCollector(int min, int max) {
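        // Non-matching positions get NO_MATCH_SCORE; matching positions record the current score from the Scorable.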
        return new LuceneQueryEvaluator.DenseCollector<>(
            min,
            max,
            blockFactory().newDoubleVectorFixedBuilder(max - min + 1),
            b -> b.appendDouble(NO_MATCH_SCORE),
            (b, s) -> b.appendDouble(s.score())
        );
    }

    @Override
    protected Scorable getScorer() {
        return new Scorable() {
            @Override
            public float score() throws IOException {
                return TEST_SCORE;
            }
        };
    }

    @Override
    protected Operator createOperator(BlockFactory blockFactory, LuceneQueryEvaluator.ShardConfig[] shards) {
        return new ScoreOperator(blockFactory, new LuceneQueryScoreEvaluator(blockFactory, shards), 1);
    }

    @Override
    protected boolean usesScoring() {
        return true;
    }

    @Override
    protected int resultsBlockIndex(Page page) {
        // Reuses the score block
        return 1;
    }

    @Override
    protected void assertCollectedResultMatch(DoubleVector resultVector, int position, boolean isMatch) {
        if (isMatch) {
            assertThat(resultVector.getDouble(position), equalTo((double) TEST_SCORE));
        } else {
            // Docs the query did not collect get NO_MATCH_SCORE from the collector
            assertThat(resultVector.getDouble(position), equalTo(NO_MATCH_SCORE));
        }
    }

    @Override
    protected void assertTermResultMatch(DoubleVector resultVector, int position, boolean isMatch) {
        if (isMatch) {
            assertThat(resultVector.getDouble(position), greaterThan(DEFAULT_SCORE));
        } else {
            // All docs have a default score coming from Lucene
            assertThat(resultVector.getDouble(position), equalTo(DEFAULT_SCORE));
        }
    }
}

@ -357,3 +357,115 @@ _id:keyword
2
3
;

scoresNonPushableFunctions
required_capability: metadata_score

from books metadata _score
| where length(title) > 100
| keep book_no, _score
| sort _score desc, book_no asc
;

book_no:keyword | _score:double
2924            | 1.0
8678            | 1.0
;

scoresPushableFunctions
required_capability: metadata_score

from books metadata _score
| where year >= 2017
| keep book_no, _score
| sort _score desc, book_no asc
;

book_no:keyword | _score:double
6818            | 1.0
7400            | 1.0
8480            | 1.0
8534            | 1.0
8615            | 1.0
;

conjunctionScoresPushableNonPushableFunctions
required_capability: metadata_score
required_capability: match_function

from books metadata _score
| where match(title, "Lord") and length(title) > 20
| keep book_no, _score
| sort _score desc, book_no asc
;

book_no:keyword | _score:double
2675            | 2.5619282722473145
2714            | 1.9245924949645996
7140            | 1.746896743774414
4023            | 1.5062403678894043
;
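
// Scoring a disjunction that mixes a full text function with non-pushable predicates depends on
// the full_text_functions_disjunctions_score capability introduced by this change.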
disjunctionScoresPushableNonPushableFunctions
required_capability: metadata_score
required_capability: match_operator_colon
required_capability: full_text_functions_disjunctions_score

from books metadata _score
| where match(title, "Lord") or length(title) > 100
| keep book_no, _score
| sort _score desc, book_no asc
;

book_no:keyword | _score:double
2675            | 3.5619282722473145
2714            | 2.9245924949645996
7140            | 2.746896743774414
4023            | 2.5062403678894043
2924            | 1.0
8678            | 1.0
;

disjunctionScoresMultipleClauses
required_capability: metadata_score
required_capability: match_operator_colon
required_capability: full_text_functions_disjunctions_score

from books metadata _score
| where (title: "Lord" and length(title) > 40) or (author: "Dostoevsky" and length(title) > 40)
| keep book_no, _score
| sort _score desc, book_no asc
;

book_no:keyword | _score:double
8086            | 2.786686897277832
9801            | 2.786686897277832
1937            | 2.1503653526306152
8534            | 2.1503653526306152
2714            | 1.9245924949645996
7140            | 1.746896743774414
4023            | 1.5062403678894043
2924            | 1.2732219696044922
;

@ -0,0 +1,257 @@

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.plugin;

import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
import org.junit.Before;

import java.util.List;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getValuesList;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;

//@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE,org.elasticsearch.compute:TRACE", reason = "debug")
public class ScoringIT extends AbstractEsqlIntegTestCase {

    @Before
    public void setupIndex() {
        createAndPopulateIndex();
    }

    public void testDefaultScoring() {
        var query = """
            FROM test METADATA _score
            | KEEP id, _score
            | SORT _score DESC, id ASC
            """;

        try (var resp = run(query)) {
            assertColumnNames(resp.columns(), List.of("id", "_score"));
            assertColumnTypes(resp.columns(), List.of("integer", "double"));
            List<List<Object>> values = getValuesList(resp);

            assertThat(values.size(), equalTo(6));
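
            // Without a WHERE clause, every document keeps Lucene's default score of 1.0.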
            for (int i = 0; i < 6; i++) {
                assertThat(values.get(i).get(1), equalTo(1.0));
            }
        }
    }

    public void testScoringNonPushableFunctions() {
        var query = """
            FROM test METADATA _score
            | WHERE length(content) < 20
            | KEEP id, _score
            | SORT _score DESC, id ASC
            """;

        try (var resp = run(query)) {
            assertColumnNames(resp.columns(), List.of("id", "_score"));
            assertColumnTypes(resp.columns(), List.of("integer", "double"));
            List<List<Object>> values = getValuesList(resp);
            assertThat(values.size(), equalTo(2));

            assertThat(values.get(0).get(0), equalTo(1));
            assertThat(values.get(1).get(0), equalTo(2));

            assertThat((Double) values.get(0).get(1), is(1.0));
            assertThat((Double) values.get(1).get(1), is(1.0));
        }
    }

    public void testDisjunctionScoring() {
        var query = """
            FROM test METADATA _score
            | WHERE match(content, "fox") OR length(content) < 20
            | KEEP id, _score
            | SORT _score DESC, id ASC
            """;

        try (var resp = run(query)) {
            assertColumnNames(resp.columns(), List.of("id", "_score"));
            assertColumnTypes(resp.columns(), List.of("integer", "double"));
            List<List<Object>> values = getValuesList(resp);
            assertThat(values.size(), equalTo(3));

            assertThat(values.get(0).get(0), equalTo(1));
            assertThat(values.get(1).get(0), equalTo(6));
            assertThat(values.get(2).get(0), equalTo(2));

            // Matches full text query and non pushable query
            assertThat((Double) values.get(0).get(1), greaterThan(1.0));
            assertThat((Double) values.get(1).get(1), greaterThan(1.0));
            // Matches just the non pushable query
            assertThat((Double) values.get(2).get(1), equalTo(1.0));
        }
    }

    public void testConjunctionPushableScoring() {
        var query = """
            FROM test METADATA _score
            | WHERE match(content, "fox") AND id > 4
            | KEEP id, _score
            | SORT _score DESC, id ASC
            """;

        try (var resp = run(query)) {
            assertColumnNames(resp.columns(), List.of("id", "_score"));
            assertColumnTypes(resp.columns(), List.of("integer", "double"));
            List<List<Object>> values = getValuesList(resp);
            assertThat(values.size(), equalTo(1));

            assertThat(values.get(0).get(0), equalTo(6));

            // Matches full text query and pushable query
            assertThat((Double) values.get(0).get(1), greaterThan(1.0));
        }
    }

    public void testConjunctionNonPushableScoring() {
        var query = """
            FROM test METADATA _score
            | WHERE match(content, "fox") AND length(content) < 20
            | KEEP id, _score
            | SORT _score DESC, id ASC
            """;

        try (var resp = run(query)) {
            assertColumnNames(resp.columns(), List.of("id", "_score"));
            assertColumnTypes(resp.columns(), List.of("integer", "double"));
            List<List<Object>> values = getValuesList(resp);
            assertThat(values.size(), equalTo(1));

            assertThat(values.get(0).get(0), equalTo(1));

            // Matches full text query and non pushable query
            assertThat((Double) values.get(0).get(1), greaterThan(1.0));
        }
    }

    public void testDisjunctionScoringPushableFunctions() {
        var query = """
            FROM test METADATA _score
            | WHERE match(content, "fox") OR match(content, "quick")
            | KEEP id, _score
            | SORT _score DESC, id ASC
            """;

        try (var resp = run(query)) {
            assertColumnNames(resp.columns(), List.of("id", "_score"));
            assertColumnTypes(resp.columns(), List.of("integer", "double"));
            List<List<Object>> values = getValuesList(resp);
            assertThat(values.size(), equalTo(2));

            assertThat(values.get(0).get(0), equalTo(6));
            assertThat(values.get(1).get(0), equalTo(1));

            // Matches both conditions
            assertThat((Double) values.get(0).get(1), greaterThan(2.0));
            // Matches a single condition
            assertThat((Double) values.get(1).get(1), greaterThan(1.0));
        }
    }

    public void testDisjunctionScoringMultipleNonPushableFunctions() {
        var query = """
            FROM test METADATA _score
            | WHERE match(content, "fox") OR length(content) < 20 AND id > 2
            | KEEP id, _score
            | SORT _score DESC
            """;

        try (var resp = run(query)) {
            assertColumnNames(resp.columns(), List.of("id", "_score"));
            assertColumnTypes(resp.columns(), List.of("integer", "double"));
            List<List<Object>> values = getValuesList(resp);
            assertThat(values.size(), equalTo(2));

            assertThat(values.get(0).get(0), equalTo(1));
            assertThat(values.get(1).get(0), equalTo(6));

            // Matches the full text query and the two pushable queries
            assertThat((Double) values.get(0).get(1), greaterThan(2.0));
            assertThat((Double) values.get(0).get(1), lessThan(3.0));
            // Matches just the match function
            assertThat((Double) values.get(1).get(1), lessThan(2.0));
            assertThat((Double) values.get(1).get(1), greaterThan(1.0));
        }
    }

    public void testDisjunctionScoringWithNot() {
        var query = """
            FROM test METADATA _score
            | WHERE NOT(match(content, "dog")) OR length(content) > 50
            | KEEP id, _score
            | SORT _score DESC, id ASC
            """;

        try (var resp = run(query)) {
            assertColumnNames(resp.columns(), List.of("id", "_score"));
            assertColumnTypes(resp.columns(), List.of("integer", "double"));
            List<List<Object>> values = getValuesList(resp);
            assertThat(values.size(), equalTo(3));

            assertThat(values.get(0).get(0), equalTo(1));
            assertThat(values.get(1).get(0), equalTo(4));
            assertThat(values.get(2).get(0), equalTo(5));

            // Docs matching the NOT clause contribute 0.0, so the default score of 1.0 remains
            assertThat((Double) values.get(0).get(1), equalTo(1.0));
            assertThat((Double) values.get(1).get(1), equalTo(1.0));
            assertThat((Double) values.get(2).get(1), equalTo(1.0));
        }
    }

    public void testScoringWithNoFullTextFunction() {
        var query = """
            FROM test METADATA _score
            | WHERE length(content) > 50
            | KEEP id, _score
            | SORT _score DESC, id ASC
            """;

        try (var resp = run(query)) {
            assertColumnNames(resp.columns(), List.of("id", "_score"));
            assertColumnTypes(resp.columns(), List.of("integer", "double"));
            List<List<Object>> values = getValuesList(resp);
            assertThat(values.size(), equalTo(1));

            assertThat(values.get(0).get(0), equalTo(4));

            // Non pushable query gets score of 0.0, summed with 1.0 coming from Lucene
            assertThat((Double) values.get(0).get(1), equalTo(1.0));
        }
    }

    private void createAndPopulateIndex() {
        var indexName = "test";
        var client = client().admin().indices();
        var createRequest = client.prepareCreate(indexName)
            .setSettings(Settings.builder().put("index.number_of_shards", 1))
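            // A single shard keeps index statistics, and therefore scores, stable across runs.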
            .setMapping("id", "type=integer", "content", "type=text");
        assertAcked(createRequest);
        client().prepareBulk()
            .add(new IndexRequest(indexName).id("1").source("id", 1, "content", "This is a brown fox"))
            .add(new IndexRequest(indexName).id("2").source("id", 2, "content", "This is a brown dog"))
            .add(new IndexRequest(indexName).id("3").source("id", 3, "content", "This dog is really brown"))
            .add(new IndexRequest(indexName).id("4").source("id", 4, "content", "The dog is brown but this document is very very long"))
            .add(new IndexRequest(indexName).id("5").source("id", 5, "content", "There is also a white cat"))
            .add(new IndexRequest(indexName).id("6").source("id", 6, "content", "The quick brown fox jumps over the lazy dog"))
            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
            .get();
        ensureYellow(indexName);
    }
}

@ -865,7 +865,12 @@ public class EsqlCapabilities {
    /**
     * Support for RRF command
     */
    RRF(Build.current().isSnapshot());
    RRF(Build.current().isSnapshot()),

    /**
     * Full text functions can be scored when they are part of a disjunction
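     * (e.g. {@code FROM books METADATA _score | WHERE match(title, "Lord") OR length(title) > 100})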
     */
    FULL_TEXT_FUNCTIONS_DISJUNCTIONS_SCORE;

    private final boolean enabled;

@ -177,7 +177,7 @@ public final class EvalMapper {
    static class Nots extends ExpressionMapper<Not> {
        @Override
        public ExpressionEvaluator.Factory map(FoldContext foldCtx, Not not, Layout layout, List<ShardContext> shardContexts) {
            var expEval = toEvaluator(foldCtx, not.field(), layout);
            var expEval = toEvaluator(foldCtx, not.field(), layout, shardContexts);
            return dvrCtx -> new org.elasticsearch.xpack.esql.evaluator.predicate.operator.logical.NotEvaluator(
                not.source(),
                expEval.get(dvrCtx),

@ -281,7 +281,7 @@ public final class EvalMapper {

        @Override
        public ExpressionEvaluator.Factory map(FoldContext foldCtx, IsNull isNull, Layout layout, List<ShardContext> shardContexts) {
            var field = toEvaluator(foldCtx, isNull.field(), layout);
            var field = toEvaluator(foldCtx, isNull.field(), layout, shardContexts);
            return new IsNullEvaluatorFactory(field);
        }

@ -329,7 +329,7 @@ public final class EvalMapper {

        @Override
        public ExpressionEvaluator.Factory map(FoldContext foldCtx, IsNotNull isNotNull, Layout layout, List<ShardContext> shardContexts) {
            return new IsNotNullEvaluatorFactory(toEvaluator(foldCtx, isNotNull.field(), layout));
            return new IsNotNullEvaluatorFactory(toEvaluator(foldCtx, isNotNull.field(), layout, shardContexts));
        }

        record IsNotNullEvaluatorFactory(EvalOperator.ExpressionEvaluator.Factory field) implements ExpressionEvaluator.Factory {

@ -8,27 +8,26 @@
package org.elasticsearch.xpack.esql.expression.function.fulltext;

import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.compute.lucene.LuceneQueryEvaluator.ShardConfig;
import org.elasticsearch.compute.lucene.LuceneQueryExpressionEvaluator;
import org.elasticsearch.compute.lucene.LuceneQueryExpressionEvaluator.ShardConfig;
import org.elasticsearch.compute.lucene.LuceneQueryScoreEvaluator;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.ScoreOperator;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic;
import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
import org.elasticsearch.xpack.esql.expression.predicate.logical.Or;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;

@ -39,6 +38,7 @@ import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders;
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
import org.elasticsearch.xpack.esql.querydsl.query.TranslationAwareExpressionQuery;
import org.elasticsearch.xpack.esql.score.ExpressionScoreMapper;

import java.util.List;
import java.util.Locale;

@ -56,7 +56,12 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isStr
 * These functions need to be pushed down to Lucene queries to be executed - there's no Evaluator for them; they depend on
 * {@link org.elasticsearch.xpack.esql.optimizer.LocalPhysicalPlanOptimizer} to rewrite them into Lucene queries.
 */
public abstract class FullTextFunction extends Function implements TranslationAware, PostAnalysisPlanVerificationAware, EvaluatorMapper {
public abstract class FullTextFunction extends Function
    implements
        TranslationAware,
        PostAnalysisPlanVerificationAware,
        EvaluatorMapper,
        ExpressionScoreMapper {

    private final Expression query;
    private final QueryBuilder queryBuilder;

@ -204,13 +209,6 @@ public abstract class FullTextFunction extends Function implements TranslationAw
                failures
            );
            checkFullTextFunctionsParents(condition, failures);

            boolean usesScore = plan.output()
                .stream()
                .anyMatch(attr -> attr instanceof MetadataAttribute ma && ma.name().equals(MetadataAttribute.SCORE));
            if (usesScore) {
                checkFullTextSearchDisjunctions(condition, failures);
            }
        } else {
            plan.forEachExpression(FullTextFunction.class, ftf -> {
                failures.add(fail(ftf, "[{}] {} is only supported in WHERE commands", ftf.functionName(), ftf.functionType()));

@ -218,65 +216,6 @@ public abstract class FullTextFunction extends Function implements TranslationAw
        }
    }

    /**
     * Checks whether a condition contains a disjunction with a full text search.
     * If it does, check that every element of the disjunction is a full text search or combinations (AND, OR, NOT) of them.
     * If not, add a failure to the failures collection.
     *
     * @param condition condition to check for disjunctions of full text searches
     * @param failures  failures collection to add to
     */
    private static void checkFullTextSearchDisjunctions(Expression condition, Failures failures) {
        Holder<Boolean> isInvalid = new Holder<>(false);
        condition.forEachDown(Or.class, or -> {
            if (isInvalid.get()) {
                // Exit early if we already have a failure
                return;
            }
            if (checkDisjunctionPushable(or) == false) {
                isInvalid.set(true);
                failures.add(
                    fail(
                        or,
                        "Invalid condition when using METADATA _score [{}]. Full text functions can be used in an OR condition, "
                            + "but only if just full text functions are used in the OR condition",
                        or.sourceText()
                    )
                );
            }
        });
    }

    /**
     * Checks if a disjunction is pushable from the point of view of FullTextFunctions. Either it has no FullTextFunctions or
     * all it contains are FullTextFunctions.
     *
     * @param or disjunction to check
     * @return true if the disjunction is pushable, false otherwise
     */
    private static boolean checkDisjunctionPushable(Or or) {
        boolean hasFullText = or.anyMatch(FullTextFunction.class::isInstance);
        return hasFullText == false || onlyFullTextFunctionsInExpression(or);
    }

    /**
     * Checks whether an expression contains just full text functions or negations (NOT) and combinations (AND, OR) of full text functions
     *
     * @param expression expression to check
     * @return true if all children are full text functions or negations of full text functions, false otherwise
     */
    private static boolean onlyFullTextFunctionsInExpression(Expression expression) {
        if (expression instanceof FullTextFunction) {
            return true;
        } else if (expression instanceof Not) {
            return onlyFullTextFunctionsInExpression(expression.children().get(0));
        } else if (expression instanceof BinaryLogic binaryLogic) {
            return onlyFullTextFunctionsInExpression(binaryLogic.left()) && onlyFullTextFunctionsInExpression(binaryLogic.right());
        }

        return false;
    }

    /**
     * Checks all commands that exist before a specific type satisfy conditions.
     *

@ -365,4 +304,15 @@ public abstract class FullTextFunction extends Function implements TranslationAw
        }
        return new LuceneQueryExpressionEvaluator.Factory(shardConfigs);
    }

    @Override
    public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) {
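        // Mirrors toEvaluator() above: build one Lucene query per shard so scores come from each shard's own searcher.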
        List<EsPhysicalOperationProviders.ShardContext> shardContexts = toScorer.shardContexts();
        ShardConfig[] shardConfigs = new ShardConfig[shardContexts.size()];
        int i = 0;
        for (EsPhysicalOperationProviders.ShardContext shardContext : shardContexts) {
            shardConfigs[i++] = new ShardConfig(shardContext.toQuery(queryBuilder()), shardContext.searcher());
        }
        return new LuceneQueryScoreEvaluator.Factory(shardConfigs);
    }
}

@ -8,6 +8,11 @@ package org.elasticsearch.xpack.esql.expression.predicate.logical;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.ScoreOperator;
import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Nullability;

@ -22,6 +27,7 @@ import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.core.util.PlanStreamInput;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
import org.elasticsearch.xpack.esql.score.ExpressionScoreMapper;

import java.io.IOException;
import java.util.Arrays;

@ -29,7 +35,10 @@ import java.util.List;

import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isBoolean;

public abstract class BinaryLogic extends BinaryOperator<Boolean, Boolean, Boolean, BinaryLogicOperation> implements TranslationAware {
public abstract class BinaryLogic extends BinaryOperator<Boolean, Boolean, Boolean, BinaryLogicOperation>
    implements
        TranslationAware,
        ExpressionScoreMapper {

    protected BinaryLogic(Source source, Expression left, Expression right, BinaryLogicOperation operation) {
        super(source, left, right, operation);

@ -108,4 +117,33 @@ public abstract class BinaryLogic extends BinaryOperator<Boolean, Boolean, Boole
        }
        return new BoolQuery(source, isAnd, queries);
    }

    @Override
    public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) {
        return context -> new BinaryLogicScorer(context, toScorer.toScorer(left()).get(context), toScorer.toScorer(right()).get(context));
    }

    /**
     * Binary logic adds together scores coming from the left and right expressions, both for conjunctions and disjunctions.
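     * For example, with {@code match(a, "x") OR match(b, "y")}, a document matching both clauses gets score(left) + score(right).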
     */
    private record BinaryLogicScorer(DriverContext driverContext, ScoreOperator.ExpressionScorer left, ScoreOperator.ExpressionScorer right)
        implements
            ScoreOperator.ExpressionScorer {
        @Override
        public DoubleBlock score(Page page) {
            DoubleVector.Builder builder = driverContext.blockFactory().newDoubleVectorFixedBuilder(page.getPositionCount());
            try (DoubleVector leftVector = left.score(page).asVector(); DoubleVector rightVector = right.score(page).asVector()) {
                for (int i = 0; i < page.getPositionCount(); i++) {
                    builder.appendDouble(leftVector.getDouble(i) + rightVector.getDouble(i));
                }
            }
            return builder.build().asBlock();
        }

        @Override
        public void close() {
            left.close();
            right.close();
        }
    }
}

@ -35,6 +35,7 @@ import org.elasticsearch.compute.operator.Operator.OperatorFactory;
import org.elasticsearch.compute.operator.OutputOperator.OutputOperatorFactory;
import org.elasticsearch.compute.operator.RowInTableLookupOperator;
import org.elasticsearch.compute.operator.RrfScoreEvalOperator;
import org.elasticsearch.compute.operator.ScoreOperator;
import org.elasticsearch.compute.operator.ShowOperator;
import org.elasticsearch.compute.operator.SinkOperator;
import org.elasticsearch.compute.operator.SinkOperator.SinkOperatorFactory;

@ -105,6 +106,7 @@ import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders.ShardContext;
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
import org.elasticsearch.xpack.esql.score.ScoreMapper;
import org.elasticsearch.xpack.esql.session.Configuration;

import java.util.ArrayList;

@ -729,10 +731,29 @@ public class LocalExecutionPlanner {
    private PhysicalOperation planFilter(FilterExec filter, LocalExecutionPlannerContext context) {
        PhysicalOperation source = plan(filter.child(), context);
        // TODO: should this be extracted into a separate eval block?
        return source.with(
        PhysicalOperation filterOperation = source.with(
            new FilterOperatorFactory(EvalMapper.toEvaluator(context.foldCtx(), filter.condition(), source.layout, shardContexts)),
            source.layout
        );
        if (PlannerUtils.usesScoring(filter)) {
            // Add scorer operator to add the filter expression scores to the overall scores
            int scoreBlock = 0;
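            // filter.output() lists one attribute per block, so the position of _score is also its block index.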
            for (Attribute attribute : filter.output()) {
                if (MetadataAttribute.SCORE.equals(attribute.name())) {
                    break;
                }
                scoreBlock++;
            }
            if (scoreBlock == filter.output().size()) {
                throw new IllegalStateException("Couldn't find _score attribute in a WHERE clause");
            }

            filterOperation = filterOperation.with(
                new ScoreOperator.ScoreOperatorFactory(ScoreMapper.toScorer(filter.condition(), shardContexts), scoreBlock),
                filterOperation.layout
            );
        }
        return filterOperation;
    }

    private PhysicalOperation planLimit(LimitExec limit, LocalExecutionPlannerContext context) {

@ -23,6 +23,7 @@ import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.Holder;

@ -32,6 +33,7 @@ import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalPlanOptimizer;
import org.elasticsearch.xpack.esql.plan.QueryPlan;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;

@ -278,4 +280,8 @@ public class PlannerUtils {
        new NoopCircuitBreaker("noop-esql-breaker"),
        BigArrays.NON_RECYCLING_INSTANCE
    );
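
    /**
     * Returns true when the plan outputs the {@code _score} metadata attribute, i.e. when scores need to be computed.
     */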
    public static boolean usesScoring(QueryPlan<?> plan) {
        return plan.output().stream().anyMatch(attr -> attr instanceof MetadataAttribute ma && ma.name().equals(MetadataAttribute.SCORE));
    }
}

@ -0,0 +1,29 @@

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.score;

import org.elasticsearch.compute.operator.ScoreOperator.ExpressionScorer;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders;

import java.util.List;

/**
 * Maps expressions that have a mapping to an {@link ExpressionScorer}. Allows for transforming expressions into their corresponding scores.
 */
public interface ExpressionScoreMapper {
    interface ToScorer {
        ExpressionScorer.Factory toScorer(Expression expression);

        default List<EsPhysicalOperationProviders.ShardContext> shardContexts() {
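            // Only mappers that push queries down to Lucene (like full text functions) need shard contexts to score.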
            throw new UnsupportedOperationException("Shard contexts should only be needed for scoring operations");
        }
    }

    ExpressionScorer.Factory toScorer(ToScorer toScorer);
}

@ -0,0 +1,56 @@

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.score;

import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.ScoreOperator;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders.ShardContext;

import java.util.List;

/**
 * Maps an expression tree into ExpressionScorer.Factory, so scores can be evaluated for an expression tree.
 */
public class ScoreMapper {

    public static ScoreOperator.ExpressionScorer.Factory toScorer(Expression expression, List<ShardContext> shardContexts) {
        if (expression instanceof ExpressionScoreMapper mapper) {
            return mapper.toScorer(new ExpressionScoreMapper.ToScorer() {
                @Override
                public ScoreOperator.ExpressionScorer.Factory toScorer(Expression expression) {
                    return ScoreMapper.toScorer(expression, shardContexts);
                }

                @Override
                public List<ShardContext> shardContexts() {
                    return shardContexts;
                }
            });
        }

        return driverContext -> new DefaultScoreMapper().get(driverContext);
    }

    public static class DefaultScoreMapper implements ScoreOperator.ExpressionScorer.Factory {
        @Override
        public ScoreOperator.ExpressionScorer get(DriverContext driverContext) {
            return new ScoreOperator.ExpressionScorer() {
                @Override
                public DoubleBlock score(Page page) {
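                    // Constant 0.0 is the neutral element here: BinaryLogic sums clause scores, so expressions without scoring logic leave the total unchanged.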
                    return driverContext.blockFactory().newConstantDoubleBlockWith(0.0, page.getPositionCount());
                }

                @Override
                public void close() {}
            };
        }
    }
}

@ -284,6 +284,11 @@ public class CsvTests extends ESTestCase {
            "CSV tests cannot currently handle the _source field mapping directives",
            testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.SOURCE_FIELD_MAPPING.capabilityName())
        );
        assumeFalse(
            "CSV tests cannot currently handle scoring that depends on Lucene",
            testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.METADATA_SCORE.capabilityName())
        );

        if (Build.current().isSnapshot()) {
            assertThat(
                "Capability is not included in the enabled list capabilities on a snapshot build. Spelling mistake?",

@ -9,7 +9,6 @@ package org.elasticsearch.xpack.esql.analysis;

import org.elasticsearch.Build;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.VerificationException;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;

@ -1467,11 +1466,12 @@ public class VerifierTests extends ESTestCase {
    private void checkWithFullTextFunctionsDisjunctions(String functionInvocation) {

        // Disjunctions with non-pushable functions - scoring
        checkdisjunctionScoringError("1:35", functionInvocation + " or length(first_name) > 10");
        checkdisjunctionScoringError("1:35", "match(last_name, \"Anneke\") or (" + functionInvocation + " and length(first_name) > 10)");
        checkdisjunctionScoringError(
            "1:35",
            "(" + functionInvocation + " and length(first_name) > 0) or (match(last_name, \"Anneke\") and length(first_name) > 10)"
        query("from test | where " + functionInvocation + " or length(first_name) > 10");
        query("from test | where match(last_name, \"Anneke\") or (" + functionInvocation + " and length(first_name) > 10)");
        query(
            "from test | where ("
                + functionInvocation
                + " and length(first_name) > 0) or (match(last_name, \"Anneke\") and length(first_name) > 10)"
        );

        // Disjunctions with non-pushable functions - no scoring

@ -1503,19 +1503,6 @@ public class VerifierTests extends ESTestCase {

    }

    private void checkdisjunctionScoringError(String position, String expression) {
        assertEquals(
            LoggerMessageFormat.format(
                null,
                "{}: Invalid condition when using METADATA _score [{}]. Full text functions can be used in an OR condition, "
                    + "but only if just full text functions are used in the OR condition",
                position,
                expression
            ),
            error("from test metadata _score | where " + expression)
        );
    }

    public void testQueryStringFunctionWithNonBooleanFunctions() {
        checkFullTextFunctionsWithNonBooleanFunctions("QSTR", "qstr(\"first_name: Anna\")", "function");
    }

@ -1704,7 +1704,7 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
        assertThat(queryBuilder.value(), is(123456));
    }

    public void testMatchFunctionWithPushableConjunction() {
    public void testMatchFunctionWithNonPushableConjunction() {
        String query = """
            from test
            | where match(last_name, "Smith") and length(first_name) > 10

@ -1723,6 +1723,24 @@
        assertThat(esQuery.query(), instanceOf(MatchQueryBuilder.class));
    }

    public void testMatchFunctionWithPushableConjunction() {
        String query = """
            from test metadata _score
            | where match(last_name, "Smith") and salary > 10000
            """;
        var plan = plannerOptimizer.plan(query);

        var limit = as(plan, LimitExec.class);
        var exchange = as(limit.child(), ExchangeExec.class);
        var project = as(exchange.child(), ProjectExec.class);
        var fieldExtract = as(project.child(), FieldExtractExec.class);
        var esQuery = as(fieldExtract.child(), EsQueryExec.class);
        Source source = new Source(2, 38, "salary > 10000");
        BoolQueryBuilder expected = new BoolQueryBuilder().must(new MatchQueryBuilder("last_name", "Smith").lenient(true))
            .must(wrapWithSingleQuery(query, QueryBuilders.rangeQuery("salary").gt(10000), "salary", source));
        assertThat(esQuery.query().toString(), equalTo(expected.toString()));
    }

    public void testMatchFunctionWithNonPushableDisjunction() {
        String query = """
            from test

@ -1754,7 +1772,6 @@
        var project = as(exchange.child(), ProjectExec.class);
        var fieldExtract = as(project.child(), FieldExtractExec.class);
        var esQuery = as(fieldExtract.child(), EsQueryExec.class);
        var boolQuery = as(esQuery.query(), BoolQueryBuilder.class);
        Source source = new Source(2, 37, "emp_no > 10");
        BoolQueryBuilder expected = new BoolQueryBuilder().should(new MatchQueryBuilder("last_name", "Smith").lenient(true))
            .should(wrapWithSingleQuery(query, QueryBuilders.rangeQuery("emp_no").gt(10), "emp_no", source));