[ES|QL] Infer the score mode to use from the Lucene collector (#125930)

This change uses the Lucene collector to infer which score mode to use
when the topN collector is used.
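
In outline, the TopN factory now builds a size-1 collector up front, asks it
which ScoreMode it needs, and creates the Weight with exactly that mode instead
of hard-coding TOP_DOCS or TOP_DOCS_WITH_SCORES. A minimal sketch of the idea
(illustrative only; the real code is the new weightFunction in
LuceneTopNSourceOperator below):

    import java.io.IOException;
    import org.apache.lucene.search.IndexSearcher;
    import org.apache.lucene.search.Query;
    import org.apache.lucene.search.ScoreMode;
    import org.apache.lucene.search.Sort;
    import org.apache.lucene.search.TopFieldCollectorManager;
    import org.apache.lucene.search.Weight;

    // Build a throwaway collector (limit 1) just to read off the ScoreMode it
    // reports, then create the Weight with exactly that mode.
    static Weight weightFor(IndexSearcher searcher, Query query, Sort sort) throws IOException {
        ScoreMode scoreMode = new TopFieldCollectorManager(sort, 1, null, 0).newCollector().scoreMode();
        return searcher.createWeight(searcher.rewrite(query), scoreMode, 1);
    }

Letting the collector decide means the Weight always matches what collection
actually requires, and Lucene can apply TopN optimizations (e.g. skipping
non-competitive hits) whenever the reported mode allows it.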
Jim Ferenczi 2025-04-01 11:52:27 +01:00 committed by GitHub
parent 8028d5adde
commit 42b7b78a31
13 changed files with 125 additions and 79 deletions

View File

@@ -0,0 +1,5 @@
+pr: 125930
+summary: Infer the score mode to use from the Lucene collector
+area: "ES|QL"
+type: enhancement
+issues: []

View File

@@ -49,7 +49,7 @@ public class LuceneCountOperator extends LuceneOperator {
int taskConcurrency,
int limit
) {
-super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
+super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
}
@Override

View File

@@ -23,6 +23,8 @@ import java.io.IOException;
import java.util.List;
import java.util.function.Function;
+import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;
/**
* Factory that generates an operator that finds the max value of a field using the {@link LuceneMinMaxOperator}.
*/
@@ -121,7 +123,7 @@ public final class LuceneMaxFactory extends LuceneOperator.Factory {
NumberType numberType,
int limit
) {
-super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
+super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
this.fieldName = fieldName;
this.numberType = numberType;
}

View File

@@ -23,6 +23,8 @@ import java.io.IOException;
import java.util.List;
import java.util.function.Function;
+import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;
/**
* Factory that generates an operator that finds the min value of a field using the {@link LuceneMinMaxOperator}.
*/
@@ -121,7 +123,7 @@ public final class LuceneMinFactory extends LuceneOperator.Factory {
NumberType numberType,
int limit
) {
-super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
+super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
this.fieldName = fieldName;
this.numberType = numberType;
}

View File

@@ -11,7 +11,6 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.DocIdSetIterator;
-import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
@@ -84,28 +83,27 @@ public abstract class LuceneOperator extends SourceOperator {
protected final DataPartitioning dataPartitioning;
protected final int taskConcurrency;
protected final int limit;
-protected final ScoreMode scoreMode;
+protected final boolean needsScore;
protected final LuceneSliceQueue sliceQueue;
/**
* Build the factory.
*
-* @param scoreMode the {@link ScoreMode} passed to {@link IndexSearcher#createWeight}
+* @param needsScore Whether the score is needed.
*/
protected Factory(
List<? extends ShardContext> contexts,
-Function<ShardContext, Query> queryFunction,
+Function<ShardContext, Weight> weightFunction,
DataPartitioning dataPartitioning,
int taskConcurrency,
int limit,
-ScoreMode scoreMode
+boolean needsScore
) {
this.limit = limit;
-this.scoreMode = scoreMode;
this.dataPartitioning = dataPartitioning;
-var weightFunction = weightFunction(queryFunction, scoreMode);
this.sliceQueue = LuceneSliceQueue.create(contexts, weightFunction, dataPartitioning, taskConcurrency);
this.taskConcurrency = Math.min(sliceQueue.totalSlices(), taskConcurrency);
+this.needsScore = needsScore;
}
public final int taskConcurrency() {

View File

@@ -11,7 +11,6 @@ import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
-import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DocVector;
@@ -56,9 +55,16 @@ public class LuceneSourceOperator extends LuceneOperator {
int taskConcurrency,
int maxPageSize,
int limit,
-boolean scoring
+boolean needsScore
) {
-super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? COMPLETE : COMPLETE_NO_SCORES);
+super(
+contexts,
+weightFunction(queryFunction, needsScore ? COMPLETE : COMPLETE_NO_SCORES),
+dataPartitioning,
+taskConcurrency,
+limit,
+needsScore
+);
this.maxPageSize = maxPageSize;
// TODO: use a single limiter for multiple stage execution
this.limiter = limit == NO_LIMIT ? Limiter.NO_LIMIT : new Limiter(limit);
@@ -66,7 +72,7 @@ public class LuceneSourceOperator extends LuceneOperator {
@Override
public SourceOperator get(DriverContext driverContext) {
-return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, scoreMode);
+return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, needsScore);
}
public int maxPageSize() {
@@ -81,8 +87,8 @@ public class LuceneSourceOperator extends LuceneOperator {
+ maxPageSize
+ ", limit = "
+ limit
-+ ", scoreMode = "
-+ scoreMode
++ ", needsScore = "
++ needsScore
+ "]";
}
}
@@ -94,7 +100,7 @@ public class LuceneSourceOperator extends LuceneOperator {
LuceneSliceQueue sliceQueue,
int limit,
Limiter limiter,
-ScoreMode scoreMode
+boolean needsScore
) {
super(blockFactory, maxPageSize, sliceQueue);
this.minPageSize = Math.max(1, maxPageSize / 2);
@@ -104,7 +110,7 @@ public class LuceneSourceOperator extends LuceneOperator {
boolean success = false;
try {
this.docsBuilder = blockFactory.newIntVectorBuilder(estimatedSize);
-if (scoreMode.needsScores()) {
+if (needsScore) {
scoreBuilder = blockFactory.newDoubleVectorBuilder(estimatedSize);
this.leafCollector = new ScoringCollector();
} else {

View File

@@ -14,12 +14,12 @@ import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollectorManager;
import org.apache.lucene.search.TopScoreDocCollectorManager;
+import org.apache.lucene.search.Weight;
import org.elasticsearch.common.Strings;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocBlock;
@@ -36,6 +36,7 @@ import org.elasticsearch.search.sort.SortAndFormats;
import org.elasticsearch.search.sort.SortBuilder;
import java.io.IOException;
+import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@@ -43,9 +44,6 @@ import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
-import static org.apache.lucene.search.ScoreMode.TOP_DOCS;
-import static org.apache.lucene.search.ScoreMode.TOP_DOCS_WITH_SCORES;
/**
* Source operator that builds Pages out of the output of a TopFieldCollector (aka TopN)
*/
@@ -63,16 +61,16 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
int maxPageSize,
int limit,
List<SortBuilder<?>> sorts,
-boolean scoring
+boolean needsScore
) {
-super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? TOP_DOCS_WITH_SCORES : TOP_DOCS);
+super(contexts, weightFunction(queryFunction, sorts, needsScore), dataPartitioning, taskConcurrency, limit, needsScore);
this.maxPageSize = maxPageSize;
this.sorts = sorts;
}
@Override
public SourceOperator get(DriverContext driverContext) {
-return new LuceneTopNSourceOperator(driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, scoreMode);
+return new LuceneTopNSourceOperator(driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, needsScore);
}
public int maxPageSize() {
@@ -88,8 +86,8 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
+ maxPageSize
+ ", limit = "
+ limit
-+ ", scoreMode = "
-+ scoreMode
++ ", needsScore = "
++ needsScore
+ ", sorts = ["
+ notPrettySorts
+ "]]";
@@ -108,7 +106,7 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
private PerShardCollector perShardCollector;
private final List<SortBuilder<?>> sorts;
private final int limit;
-private final ScoreMode scoreMode;
+private final boolean needsScore;
public LuceneTopNSourceOperator(
BlockFactory blockFactory,
@@ -116,12 +114,12 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
List<SortBuilder<?>> sorts,
int limit,
LuceneSliceQueue sliceQueue,
-ScoreMode scoreMode
+boolean needsScore
) {
super(blockFactory, maxPageSize, sliceQueue);
this.sorts = sorts;
this.limit = limit;
-this.scoreMode = scoreMode;
+this.needsScore = needsScore;
}
@Override
@@ -163,7 +161,7 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
try {
if (perShardCollector == null || perShardCollector.shardContext.index() != scorer.shardContext().index()) {
// TODO: share the bottom between shardCollectors
-perShardCollector = newPerShardCollector(scorer.shardContext(), sorts, limit);
+perShardCollector = newPerShardCollector(scorer.shardContext(), sorts, needsScore, limit);
}
var leafCollector = perShardCollector.getLeafCollector(scorer.leafReaderContext());
scorer.scoreNextRange(leafCollector, scorer.leafReaderContext().reader().getLiveDocs(), maxPageSize);
@@ -261,7 +259,7 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
}
private DoubleVector.Builder scoreVectorOrNull(int size) {
-if (scoreMode.needsScores()) {
+if (needsScore) {
return blockFactory.newDoubleVectorFixedBuilder(size);
} else {
return null;
@@ -271,37 +269,11 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
@Override
protected void describe(StringBuilder sb) {
sb.append(", limit = ").append(limit);
-sb.append(", scoreMode = ").append(scoreMode);
+sb.append(", needsScore = ").append(needsScore);
String notPrettySorts = sorts.stream().map(Strings::toString).collect(Collectors.joining(","));
sb.append(", sorts = [").append(notPrettySorts).append("]");
}
-PerShardCollector newPerShardCollector(ShardContext shardContext, List<SortBuilder<?>> sorts, int limit) throws IOException {
-Optional<SortAndFormats> sortAndFormats = shardContext.buildSort(sorts);
-if (sortAndFormats.isEmpty()) {
-throw new IllegalStateException("sorts must not be disabled in TopN");
-}
-if (scoreMode.needsScores() == false) {
-return new NonScoringPerShardCollector(shardContext, sortAndFormats.get().sort, limit);
-} else {
-SortField[] sortFields = sortAndFormats.get().sort.getSort();
-if (sortFields != null && sortFields.length == 1 && sortFields[0].needsScores() && sortFields[0].getReverse() == false) {
-// SORT _score DESC
-return new ScoringPerShardCollector(shardContext, new TopScoreDocCollectorManager(limit, null, 0).newCollector());
-} else {
-// SORT ..., _score, ...
-var sort = new Sort();
-if (sortFields != null) {
-var l = new ArrayList<>(Arrays.asList(sortFields));
-l.add(SortField.FIELD_DOC);
-l.add(SortField.FIELD_SCORE);
-sort = new Sort(l.toArray(SortField[]::new));
-}
-return new ScoringPerShardCollector(shardContext, new TopFieldCollectorManager(sort, limit, null, 0).newCollector());
-}
-}
-}
abstract static class PerShardCollector {
private final ShardContext shardContext;
private final TopDocsCollector<?> collector;
@@ -336,4 +308,45 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
super(shardContext, topDocsCollector);
}
}
+private static Function<ShardContext, Weight> weightFunction(
+Function<ShardContext, Query> queryFunction,
+List<SortBuilder<?>> sorts,
+boolean needsScore
+) {
+return ctx -> {
+final var query = queryFunction.apply(ctx);
+final var searcher = ctx.searcher();
+try {
+// we create a collector with a limit of 1 to determine the appropriate score mode to use.
+var scoreMode = newPerShardCollector(ctx, sorts, needsScore, 1).collector.scoreMode();
+return searcher.createWeight(searcher.rewrite(query), scoreMode, 1);
+} catch (IOException e) {
+throw new UncheckedIOException(e);
+}
+};
+}
+private static PerShardCollector newPerShardCollector(ShardContext context, List<SortBuilder<?>> sorts, boolean needsScore, int limit)
+throws IOException {
+Optional<SortAndFormats> sortAndFormats = context.buildSort(sorts);
+if (sortAndFormats.isEmpty()) {
+throw new IllegalStateException("sorts must not be disabled in TopN");
+}
+if (needsScore == false) {
+return new NonScoringPerShardCollector(context, sortAndFormats.get().sort, limit);
+}
+Sort sort = sortAndFormats.get().sort;
+if (Sort.RELEVANCE.equals(sort)) {
+// SORT _score DESC
+return new ScoringPerShardCollector(context, new TopScoreDocCollectorManager(limit, null, 0).newCollector());
+}
+// SORT ..., _score, ...
+var l = new ArrayList<>(Arrays.asList(sort.getSort()));
+l.add(SortField.FIELD_DOC);
+l.add(SortField.FIELD_SCORE);
+sort = new Sort(l.toArray(SortField[]::new));
+return new ScoringPerShardCollector(context, new TopFieldCollectorManager(sort, limit, null, 0).newCollector());
+}
}
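
A note on the collector choice above: the old branch tested
sortFields.length == 1 && sortFields[0].needsScores() && sortFields[0].getReverse() == false,
which is exactly the sort that Sort.RELEVANCE denotes (a single descending
_score sort), so the Sort.RELEVANCE.equals(sort) check is the same fast path
written directly. Assuming standard Lucene semantics:

    // Sort.RELEVANCE is defined as a single descending score sort, so the old
    // field-by-field check and the new equals() check select the same case:
    assert Sort.RELEVANCE.equals(new Sort(SortField.FIELD_SCORE));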

View File

@@ -33,6 +33,8 @@ import java.io.UncheckedIOException;
import java.util.List;
import java.util.function.Function;
+import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;
/**
* Creates a source operator that takes advantage of the natural sorting of segments in a tsdb index.
* <p>
@@ -56,7 +58,7 @@ public class TimeSeriesSortedSourceOperatorFactory extends LuceneOperator.Factory {
int maxPageSize,
int limit
) {
-super(contexts, queryFunction, DataPartitioning.SHARD, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
+super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), DataPartitioning.SHARD, taskConcurrency, limit, false);
this.maxPageSize = maxPageSize;
}

View File

@@ -120,7 +120,7 @@ public class LuceneSourceOperatorTests extends AnyOperatorTestCase {
protected Matcher<String> expectedDescriptionOfSimple() {
return matchesRegex(
"LuceneSourceOperator"
-+ "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, scoreMode = (COMPLETE|COMPLETE_NO_SCORES)]"
++ "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, needsScore = (true|false)]"
);
}

View File

@@ -110,8 +110,7 @@ public class LuceneTopNSourceOperatorScoringTests extends LuceneTopNSourceOperatorTests {
@Override
protected Matcher<String> expectedToStringOfSimple() {
return matchesRegex(
-"LuceneTopNSourceOperator\\[shards = \\[test], "
-+ "maxPageSize = \\d+, limit = 100, scoreMode = TOP_DOCS_WITH_SCORES, sorts = \\[\\{.+}]]"
+"LuceneTopNSourceOperator\\[shards = \\[test], " + "maxPageSize = \\d+, limit = 100, needsScore = true, sorts = \\[\\{.+}]]"
);
}
@@ -119,7 +118,7 @@ public class LuceneTopNSourceOperatorScoringTests extends LuceneTopNSourceOperatorTests {
protected Matcher<String> expectedDescriptionOfSimple() {
return matchesRegex(
"LuceneTopNSourceOperator\\[dataPartitioning = (DOC|SHARD|SEGMENT), "
-+ "maxPageSize = \\d+, limit = 100, scoreMode = TOP_DOCS_WITH_SCORES, sorts = \\[\\{.+}]]"
++ "maxPageSize = \\d+, limit = 100, needsScore = true, sorts = \\[\\{.+}]]"
);
}

View File

@@ -114,19 +114,19 @@ public class LuceneTopNSourceOperatorTests extends AnyOperatorTestCase {
@Override
protected Matcher<String> expectedToStringOfSimple() {
-var s = scoring ? "TOP_DOCS_WITH_SCORES" : "TOP_DOCS";
return matchesRegex(
-"LuceneTopNSourceOperator\\[shards = \\[test], maxPageSize = \\d+, limit = 100, scoreMode = " + s + ", sorts = \\[\\{.+}]]"
+"LuceneTopNSourceOperator\\[shards = \\[test], maxPageSize = \\d+, limit = 100, needsScore = "
++ scoring
++ ", sorts = \\[\\{.+}]]"
);
}
@Override
protected Matcher<String> expectedDescriptionOfSimple() {
-var s = scoring ? "TOP_DOCS_WITH_SCORES" : "TOP_DOCS";
return matchesRegex(
"LuceneTopNSourceOperator"
-+ "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, scoreMode = "
-+ s
++ "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, needsScore = "
++ scoring
+ ", sorts = \\[\\{.+}]]"
);
}

View File

@@ -167,7 +167,7 @@ public class EsqlActionTaskIT extends AbstractPausableIntegTestCase {
\\_AggregationOperator[mode = INITIAL, aggs = sum of longs]
\\_ExchangeSinkOperator""".replace(
"sourceStatus",
-"dataPartitioning = SHARD, maxPageSize = " + pageSize() + ", limit = 2147483647, scoreMode = COMPLETE_NO_SCORES"
+"dataPartitioning = SHARD, maxPageSize = " + pageSize() + ", limit = 2147483647, needsScore = false"
)
)
);
@@ -502,7 +502,7 @@ public class EsqlActionTaskIT extends AbstractPausableIntegTestCase {
[{"pause_me":{"order":"asc","missing":"_last","unmapped_type":"long"}}]""";
String sourceStatus = "dataPartitioning = SHARD, maxPageSize = "
+ pageSize()
-+ ", limit = 1000, scoreMode = TOP_DOCS, sorts = "
++ ", limit = 1000, needsScore = false, sorts = "
+ sortStatus;
assertThat(dataTasks(tasks).get(0).description(), equalTo("""
\\_LuceneTopNSourceOperator[sourceStatus]
@@ -545,7 +545,7 @@ public class EsqlActionTaskIT extends AbstractPausableIntegTestCase {
scriptPermits.release(pageSize() - prereleasedDocs);
List<TaskInfo> tasks = getTasksRunning();
assertThat(dataTasks(tasks).get(0).description(), equalTo("""
-\\_LuceneSourceOperator[dataPartitioning = SHARD, maxPageSize = pageSize(), limit = limit(), scoreMode = COMPLETE_NO_SCORES]
+\\_LuceneSourceOperator[dataPartitioning = SHARD, maxPageSize = pageSize(), limit = limit(), needsScore = false]
\\_ValuesSourceReaderOperator[fields = [pause_me]]
\\_ProjectOperator[projection = [1]]
\\_ExchangeSinkOperator""".replace("pageSize()", Integer.toString(pageSize())).replace("limit()", limit)));
@@ -573,8 +573,10 @@ public class EsqlActionTaskIT extends AbstractPausableIntegTestCase {
logger.info("unblocking script");
scriptPermits.release(pageSize());
List<TaskInfo> tasks = getTasksRunning();
-String sourceStatus = "dataPartitioning = SHARD, maxPageSize = pageSize(), limit = 2147483647, scoreMode = COMPLETE_NO_SCORES"
-.replace("pageSize()", Integer.toString(pageSize()));
+String sourceStatus = "dataPartitioning = SHARD, maxPageSize = pageSize(), limit = 2147483647, needsScore = false".replace(
+"pageSize()",
+Integer.toString(pageSize())
+);
assertThat(
dataTasks(tasks).get(0).description(),
equalTo(

View File

@@ -31,6 +31,8 @@ import org.elasticsearch.index.IndexMode;
import org.elasticsearch.index.cache.query.TrivialQueryCachingPolicy;
import org.elasticsearch.index.mapper.MapperServiceTestCase;
import org.elasticsearch.node.Node;
+import org.elasticsearch.plugins.ExtensiblePlugin;
+import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.ContextIndexSearcher;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
@@ -46,11 +48,14 @@ import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
import org.elasticsearch.xpack.esql.session.Configuration;
+import org.elasticsearch.xpack.spatial.SpatialPlugin;
import org.hamcrest.Matcher;
import org.junit.After;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.List;
import java.util.Map;
@@ -58,6 +63,7 @@ import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
public class LocalExecutionPlannerTests extends MapperServiceTestCase {
@ParametersFactory
public static Iterable<Object[]> parameters() throws Exception {
List<Object[]> params = new ArrayList<>();
@@ -78,6 +84,19 @@ public class LocalExecutionPlannerTests extends MapperServiceTestCase {
this.estimatedRowSizeIsHuge = estimatedRowSizeIsHuge;
}
+@Override
+protected Collection<Plugin> getPlugins() {
+var plugin = new SpatialPlugin();
+plugin.loadExtensions(new ExtensiblePlugin.ExtensionLoader() {
+@Override
+public <T> List<T> loadExtensions(Class<T> extensionPointType) {
+return List.of();
+}
+});
+return Collections.singletonList(plugin);
+}
@After
public void closeIndex() throws IOException {
IOUtils.close(reader, directory, () -> Releasables.close(releasables), releasables::clear);
@@ -251,11 +270,9 @@ public class LocalExecutionPlannerTests extends MapperServiceTestCase {
);
for (int i = 0; i < numShards; i++) {
shardContexts.add(
-new EsPhysicalOperationProviders.DefaultShardContext(
-i,
-createSearchExecutionContext(createMapperService(mapping(b -> {})), searcher),
-AliasFilter.EMPTY
-)
+new EsPhysicalOperationProviders.DefaultShardContext(i, createSearchExecutionContext(createMapperService(mapping(b -> {
+b.startObject("point").field("type", "geo_point").endObject();
+})), searcher), AliasFilter.EMPTY)
);
}
releasables.add(searcher);