ES|QL random sampling (#125570)

Jan Kuipers 2025-04-23 17:48:07 +02:00 committed by GitHub
parent 85a87e71d6
commit bd1a638c03
GPG Key ID: B5690EEEBB952194
59 changed files with 4555 additions and 2483 deletions

View File

@ -0,0 +1,5 @@
pr: 125570
summary: ES|QL random sampling
area: Machine Learning
type: feature
issues: []

View File

@ -226,6 +226,7 @@ public class TransportVersions {
public static final TransportVersion SYNONYMS_REFRESH_PARAM = def(9_060_0_00);
public static final TransportVersion DOC_FIELDS_AS_LIST = def(9_061_0_00);
public static final TransportVersion DENSE_VECTOR_OFF_HEAP_STATS = def(9_062_0_00);
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER = def(9_063_0_00);
/*
* STOP! READ THIS FIRST! No, really,

View File

@ -134,6 +134,7 @@ import org.elasticsearch.search.aggregations.bucket.sampler.SamplerAggregationBu
import org.elasticsearch.search.aggregations.bucket.sampler.UnmappedSampler;
import org.elasticsearch.search.aggregations.bucket.sampler.random.InternalRandomSampler;
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplerAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQueryBuilder;
import org.elasticsearch.search.aggregations.bucket.terms.DoubleTerms;
import org.elasticsearch.search.aggregations.bucket.terms.LongRareTerms;
import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
@ -1186,6 +1187,9 @@ public class SearchModule {
registerQuery(new QuerySpec<>(ExactKnnQueryBuilder.NAME, ExactKnnQueryBuilder::new, parser -> {
throw new IllegalArgumentException("[exact_knn] queries cannot be provided directly");
}));
registerQuery(
new QuerySpec<>(RandomSamplingQueryBuilder.NAME, RandomSamplingQueryBuilder::new, RandomSamplingQueryBuilder::fromXContent)
);
registerFromPlugin(plugins, SearchPlugin::getQueries, this::registerQuery);
}

View File

@ -44,14 +44,34 @@ public final class RandomSamplingQuery extends Query {
* can be generated
*/
public RandomSamplingQuery(double p, int seed, int hash) {
if (p <= 0.0 || p >= 1.0) {
throw new IllegalArgumentException("RandomSampling probability must be between 0.0 and 1.0, was [" + p + "]");
}
checkProbabilityRange(p);
this.p = p;
this.seed = seed;
this.hash = hash;
}
/**
* Verifies that the probability is within the (0.0, 1.0) range.
* @throws IllegalArgumentException in case of an invalid probability.
*/
public static void checkProbabilityRange(double p) throws IllegalArgumentException {
if (p <= 0.0 || p >= 1.0) {
throw new IllegalArgumentException("RandomSampling probability must be strictly between 0.0 and 1.0, was [" + p + "]");
}
}
public double probability() {
return p;
}
public int seed() {
return seed;
}
public int hash() {
return hash;
}
@Override
public String toString(String field) {
return "RandomSamplingQuery{" + "p=" + p + ", seed=" + seed + ", hash=" + hash + '}';
@ -98,13 +118,13 @@ public final class RandomSamplingQuery extends Query {
/**
* A DocIdSetIterator that skips a geometrically distributed number of documents
*/
static class RandomSamplingIterator extends DocIdSetIterator {
public static class RandomSamplingIterator extends DocIdSetIterator {
private final int maxDoc;
private final double p;
private final FastGeometric distribution;
private int doc = -1;
RandomSamplingIterator(int maxDoc, double p, IntSupplier rng) {
public RandomSamplingIterator(int maxDoc, double p, IntSupplier rng) {
this.maxDoc = maxDoc;
this.p = p;
this.distribution = new FastGeometric(rng, p);
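For intuition: instead of testing every doc id against the probability p, the iterator jumps ahead by geometrically distributed gaps, so the expected work is proportional to the number of sampled documents rather than to maxDoc. A minimal, self-contained sketch of the same idea (hypothetical demo class; plain java.util.Random stands in for FastGeometric):

import java.util.Random;

// Hypothetical demo: emit roughly a fraction p of the doc ids in [0, maxDoc)
// by skipping geometrically distributed gaps, like RandomSamplingIterator.
class GeometricSkipDemo {
    public static void main(String[] args) {
        int maxDoc = 1_000_000;
        double p = 0.1;
        Random rng = new Random(42);
        int sampled = 0;
        int doc = -1;
        while (true) {
            // Inverse-CDF sample of Geometric(p): gap = ceil(ln(1 - U) / ln(1 - p)).
            int gap = (int) Math.ceil(Math.log(1 - rng.nextDouble()) / Math.log(1 - p));
            doc += Math.max(1, gap);
            if (doc >= maxDoc) {
                break;
            }
            sampled++; // 'doc' is part of the sample
        }
        System.out.println("sampled " + sampled + " of " + maxDoc); // ~100000
    }
}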

View File

@ -0,0 +1,149 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.search.aggregations.bucket.sampler.random;
import org.apache.lucene.search.Query;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQuery.checkProbabilityRange;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class RandomSamplingQueryBuilder extends AbstractQueryBuilder<RandomSamplingQueryBuilder> {
public static final String NAME = "random_sampling";
static final ParseField PROBABILITY = new ParseField("probability");
static final ParseField SEED = new ParseField("seed");
static final ParseField HASH = new ParseField("hash");
private final double probability;
private int seed = Randomness.get().nextInt();
private int hash = 0;
public RandomSamplingQueryBuilder(double probability) {
checkProbabilityRange(probability);
this.probability = probability;
}
public RandomSamplingQueryBuilder seed(int seed) {
this.seed = seed;
return this;
}
public RandomSamplingQueryBuilder(StreamInput in) throws IOException {
super(in);
this.probability = in.readDouble();
this.seed = in.readInt();
this.hash = in.readInt();
}
public RandomSamplingQueryBuilder hash(int hash) {
this.hash = hash;
return this;
}
public double probability() {
return probability;
}
public int seed() {
return seed;
}
public int hash() {
return hash;
}
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeDouble(probability);
out.writeInt(seed);
out.writeInt(hash);
}
@Override
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
builder.field(PROBABILITY.getPreferredName(), probability);
builder.field(SEED.getPreferredName(), seed);
builder.field(HASH.getPreferredName(), hash);
builder.endObject();
}
private static final ConstructingObjectParser<RandomSamplingQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
NAME,
false,
args -> {
var randomSamplingQueryBuilder = new RandomSamplingQueryBuilder((double) args[0]);
if (args[1] != null) {
randomSamplingQueryBuilder.seed((int) args[1]);
}
if (args[2] != null) {
randomSamplingQueryBuilder.hash((int) args[2]);
}
return randomSamplingQueryBuilder;
}
);
static {
PARSER.declareDouble(constructorArg(), PROBABILITY);
PARSER.declareInt(optionalConstructorArg(), SEED);
PARSER.declareInt(optionalConstructorArg(), HASH);
}
public static RandomSamplingQueryBuilder fromXContent(XContentParser parser) throws IOException {
return PARSER.apply(parser, null);
}
@Override
protected Query doToQuery(SearchExecutionContext context) throws IOException {
return new RandomSamplingQuery(probability, seed, hash);
}
@Override
protected boolean doEquals(RandomSamplingQueryBuilder other) {
return probability == other.probability && seed == other.seed && hash == other.hash;
}
@Override
protected int doHashCode() {
return Objects.hash(probability, seed, hash);
}
/**
* Returns the name of the writeable object
*/
@Override
public String getWriteableName() {
return NAME;
}
/**
* The minimal version of the recipient this object can be sent to
*/
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.RANDOM_SAMPLER_QUERY_BUILDER;
}
}
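A hedged usage sketch (builder API as defined above; the JSON shape follows doXContent):

// Sample roughly 10% of matching documents, with a fixed seed for reproducibility.
RandomSamplingQueryBuilder sampling = new RandomSamplingQueryBuilder(0.1).seed(42).hash(0);
// Serializes to: { "random_sampling": { "probability": 0.1, "seed": 42, "hash": 0 } }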

View File

@ -444,6 +444,7 @@ public class SearchModuleTests extends ESTestCase {
"range",
"regexp",
"knn_score_doc",
"random_sampling",
"script",
"script_score",
"simple_query_string",

View File

@ -0,0 +1,75 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.search.aggregations.bucket.sampler.random;
import org.apache.lucene.search.Query;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.test.AbstractQueryTestCase;
import org.elasticsearch.xcontent.XContentParseException;
import java.io.IOException;
import static org.hamcrest.Matchers.equalTo;
public class RandomSamplingQueryBuilderTests extends AbstractQueryTestCase<RandomSamplingQueryBuilder> {
@Override
protected RandomSamplingQueryBuilder doCreateTestQueryBuilder() {
double probability = randomDoubleBetween(0.0, 1.0, false);
var builder = new RandomSamplingQueryBuilder(probability);
if (randomBoolean()) {
builder.seed(randomInt());
}
if (randomBoolean()) {
builder.hash(randomInt());
}
return builder;
}
@Override
protected void doAssertLuceneQuery(RandomSamplingQueryBuilder queryBuilder, Query query, SearchExecutionContext context)
throws IOException {
var rsQuery = asInstanceOf(RandomSamplingQuery.class, query);
assertThat(rsQuery.probability(), equalTo(queryBuilder.probability()));
assertThat(rsQuery.seed(), equalTo(queryBuilder.seed()));
assertThat(rsQuery.hash(), equalTo(queryBuilder.hash()));
}
@Override
protected boolean supportsBoost() {
return false;
}
@Override
protected boolean supportsQueryName() {
return false;
}
@Override
public void testUnknownField() {
var json = "{ \""
+ RandomSamplingQueryBuilder.NAME
+ "\" : {\"bogusField\" : \"someValue\", \""
+ RandomSamplingQueryBuilder.PROBABILITY.getPreferredName()
+ "\" : \""
+ randomBoolean()
+ "\", \""
+ RandomSamplingQueryBuilder.SEED.getPreferredName()
+ "\" : \""
+ randomInt()
+ "\", \""
+ RandomSamplingQueryBuilder.HASH.getPreferredName()
+ "\" : \""
+ randomInt()
+ "\" } }";
var e = expectThrows(XContentParseException.class, () -> parseQuery(json));
assertTrue(e.getMessage().contains("bogusField"));
}
}

View File

@ -172,6 +172,10 @@ public class MetadataAttribute extends TypedAttribute {
return ATTRIBUTES_MAP.containsKey(name);
}
public static boolean isScoreAttribute(Expression a) {
return a instanceof MetadataAttribute ma && ma.name().equals(SCORE);
}
@Override
@SuppressWarnings("checkstyle:EqualsHashCode")// equals is implemented in parent. See innerEquals instead
public int hashCode() {

View File

@ -294,4 +294,21 @@ public final class Page implements Writeable {
}
}
}
public Page filter(int... positions) {
Block[] filteredBlocks = new Block[blocks.length];
boolean success = false;
try {
for (int i = 0; i < blocks.length; i++) {
filteredBlocks[i] = getBlock(i).filter(positions);
}
success = true;
} finally {
releaseBlocks();
if (success == false) {
Releasables.closeExpectNoException(filteredBlocks);
}
}
return new Page(filteredBlocks);
}
}
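Note that filter takes ownership of the page: the original blocks are always released, and the partially built copies are closed if any per-block filter fails. A hedged usage sketch:

// Hypothetical: keep only positions 0 and 2; `page` must not be reused afterwards.
Page filtered = page.filter(0, 2);
try {
    assert filtered.getPositionCount() == 2;
} finally {
    filtered.releaseBlocks();
}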

View File

@ -19,9 +19,9 @@ import org.elasticsearch.xpack.ml.aggs.MlAggsHelper;
import org.elasticsearch.xpack.ml.aggs.changepoint.ChangePointDetector;
import org.elasticsearch.xpack.ml.aggs.changepoint.ChangeType;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
/**
@ -68,8 +68,8 @@ public class ChangePointOperator implements Operator {
this.sourceColumn = sourceColumn;
finished = false;
inputPages = new LinkedList<>();
outputPages = new LinkedList<>();
inputPages = new ArrayDeque<>();
outputPages = new ArrayDeque<>();
warnings = null;
}

View File

@ -7,7 +7,6 @@
package org.elasticsearch.compute.operator;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
@ -69,20 +68,7 @@ public class FilterOperator extends AbstractPageMappingOperator {
}
positions = Arrays.copyOf(positions, rowCount);
Block[] filteredBlocks = new Block[page.getBlockCount()];
boolean success = false;
try {
for (int i = 0; i < page.getBlockCount(); i++) {
filteredBlocks[i] = page.getBlock(i).filter(positions);
}
success = true;
} finally {
page.releaseBlocks();
if (success == false) {
Releasables.closeExpectNoException(filteredBlocks);
}
}
return new Page(filteredBlocks);
return page.filter(positions);
}
}

View File

@ -0,0 +1,228 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.operator;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQuery;
import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.Objects;
import java.util.SplittableRandom;
public class SampleOperator implements Operator {
public record Factory(double probability, int seed) implements OperatorFactory {
@Override
public SampleOperator get(DriverContext driverContext) {
return new SampleOperator(probability, seed);
}
@Override
public String describe() {
return "SampleOperator[probability = " + probability + ", seed = " + seed + "]";
}
}
private final Deque<Page> outputPages;
/**
* At any time this iterator points to the next document that still
* needs to be sampled. If this document is on the current page, it's
* added to the output and the iterator is advanced. If the document is
* not on the current page, the current page is finished and the index
* is used for the next page.
*/
private final RandomSamplingQuery.RandomSamplingIterator randomSamplingIterator;
private boolean finished;
private int pagesProcessed = 0;
private int rowsReceived = 0;
private int rowsEmitted = 0;
private long collectNanos;
private long emitNanos;
private SampleOperator(double probability, int seed) {
finished = false;
outputPages = new ArrayDeque<>();
SplittableRandom random = new SplittableRandom(seed);
randomSamplingIterator = new RandomSamplingQuery.RandomSamplingIterator(Integer.MAX_VALUE, probability, random::nextInt);
// Initialize the iterator to the next document that needs to be sampled.
randomSamplingIterator.nextDoc();
}
/**
* whether the given operator can accept more input pages
*/
@Override
public boolean needsInput() {
return finished == false;
}
/**
* Adds an input page to the operator. Only called when needsInput() == true and isFinished() == false.
*
* @param page the page to sample; this operator takes ownership and releases its blocks
* @throws UnsupportedOperationException if the operator is a {@link SourceOperator}
*/
@Override
public void addInput(Page page) {
long startTime = System.nanoTime();
createOutputPage(page);
rowsReceived += page.getPositionCount();
page.releaseBlocks();
pagesProcessed++;
collectNanos += System.nanoTime() - startTime;
}
private void createOutputPage(Page page) {
final int[] sampledPositions = new int[page.getPositionCount()];
int sampledIdx = 0;
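// Note on the index arithmetic below: rowsReceived still holds the number of rows
// seen before this page (addInput increments it only after this method returns),
// so a global row id i from the iterator maps to page-relative position
// i - rowsReceived. The loop stops once the iterator points past this page,
// leaving it positioned for the next page.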
for (int i = randomSamplingIterator.docID(); i - rowsReceived < page.getPositionCount(); i = randomSamplingIterator.nextDoc()) {
sampledPositions[sampledIdx++] = i - rowsReceived;
}
if (sampledIdx > 0) {
outputPages.add(page.filter(Arrays.copyOf(sampledPositions, sampledIdx)));
}
}
/**
* notifies the operator that it won't receive any more input pages
*/
@Override
public void finish() {
finished = true;
}
/**
* whether the operator has finished processing all input pages and made the corresponding output pages available
*/
@Override
public boolean isFinished() {
return finished && outputPages.isEmpty();
}
@Override
public Page getOutput() {
final var emitStart = System.nanoTime();
Page page;
if (outputPages.isEmpty()) {
page = null;
} else {
page = outputPages.removeFirst();
rowsEmitted += page.getPositionCount();
}
emitNanos += System.nanoTime() - emitStart;
return page;
}
/**
* notifies the operator that it won't be used anymore (i.e. none of the other methods called),
* and its resources can be cleaned up
*/
@Override
public void close() {
for (Page page : outputPages) {
page.releaseBlocks();
}
}
@Override
public String toString() {
return "SampleOperator[sampled = " + rowsEmitted + "/" + rowsReceived + "]";
}
@Override
public Operator.Status status() {
return new Status(collectNanos, emitNanos, pagesProcessed, rowsReceived, rowsEmitted);
}
private record Status(long collectNanos, long emitNanos, int pagesProcessed, int rowsReceived, int rowsEmitted)
implements
Operator.Status {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
Operator.Status.class,
"sample",
Status::new
);
Status(StreamInput streamInput) throws IOException {
this(streamInput.readVLong(), streamInput.readVLong(), streamInput.readVInt(), streamInput.readVInt(), streamInput.readVInt());
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVLong(collectNanos);
out.writeVLong(emitNanos);
out.writeVInt(pagesProcessed);
out.writeVInt(rowsReceived);
out.writeVInt(rowsEmitted);
}
@Override
public String getWriteableName() {
return ENTRY.name;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("collect_nanos", collectNanos);
if (builder.humanReadable()) {
builder.field("collect_time", TimeValue.timeValueNanos(collectNanos));
}
builder.field("emit_nanos", emitNanos);
if (builder.humanReadable()) {
builder.field("emit_time", TimeValue.timeValueNanos(emitNanos));
}
builder.field("pages_processed", pagesProcessed);
builder.field("rows_received", rowsReceived);
builder.field("rows_emitted", rowsEmitted);
return builder.endObject();
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Status other = (Status) o;
return collectNanos == other.collectNanos
&& emitNanos == other.emitNanos
&& pagesProcessed == other.pagesProcessed
&& rowsReceived == other.rowsReceived
&& rowsEmitted == other.rowsEmitted;
}
@Override
public int hashCode() {
return Objects.hash(collectNanos, emitNanos, pagesProcessed, rowsReceived, rowsEmitted);
}
@Override
public String toString() {
return Strings.toString(this);
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ZERO;
}
}
}
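For reference, a human-readable rendering of this status would look like the following (illustrative values only; field names per toXContent above):

{
  "collect_nanos": 1200000,
  "collect_time": "1.2ms",
  "emit_nanos": 400000,
  "emit_time": "400micros",
  "pages_processed": 3,
  "rows_received": 30000,
  "rows_emitted": 15017
}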

View File

@ -0,0 +1,75 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.compute.operator;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.test.OperatorTestCase;
import org.elasticsearch.compute.test.SequenceLongBlockSourceOperator;
import org.hamcrest.Matcher;
import java.util.List;
import java.util.stream.LongStream;
import static org.hamcrest.Matchers.both;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.matchesPattern;
public class SampleOperatorTests extends OperatorTestCase {
@Override
protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
return new SequenceLongBlockSourceOperator(blockFactory, LongStream.range(0, size));
}
@Override
protected void assertSimpleOutput(List<Page> input, List<Page> results) {
int inputCount = input.stream().mapToInt(Page::getPositionCount).sum();
int outputCount = results.stream().mapToInt(Page::getPositionCount).sum();
double meanExpectedOutputCount = 0.5 * inputCount;
double stdDevExpectedOutputCount = Math.sqrt(meanExpectedOutputCount);
assertThat((double) outputCount, closeTo(meanExpectedOutputCount, 10 * stdDevExpectedOutputCount));
}
@Override
protected SampleOperator.Factory simple() {
return new SampleOperator.Factory(0.5, randomInt());
}
@Override
protected Matcher<String> expectedDescriptionOfSimple() {
return matchesPattern("SampleOperator\\[probability = 0.5, seed = -?\\d+]");
}
@Override
protected Matcher<String> expectedToStringOfSimple() {
return equalTo("SampleOperator[sampled = 0/0]");
}
public void testAccuracy() {
BlockFactory blockFactory = driverContext().blockFactory();
int totalPositionCount = 0;
for (int iter = 0; iter < 10000; iter++) {
SampleOperator operator = simple().get(driverContext());
operator.addInput(new Page(blockFactory.newConstantNullBlock(20000)));
Page output = operator.getOutput();
// 10000 expected rows, stddev=sqrt(10000)=100, so this is 10 stddevs.
assertThat(output.getPositionCount(), both(greaterThan(9000)).and(lessThan(11000)));
totalPositionCount += output.getPositionCount();
output.releaseBlocks();
}
int averagePositionCount = totalPositionCount / 10000;
// Running 10000 times, the stddev is divided by sqrt(10000)=100, so this is again 10 stddevs.
assertThat(averagePositionCount, both(greaterThan(9990)).and(lessThan(10010)));
}
}

View File

@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.qa.single_node;
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;
import org.elasticsearch.test.TestClustersThreadFilter;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.xpack.esql.qa.rest.RestSampleTestCase;
import org.junit.ClassRule;
@ThreadLeakFilters(filters = TestClustersThreadFilter.class)
public class RestSampleIT extends RestSampleTestCase {
@ClassRule
public static ElasticsearchCluster cluster = Clusters.testCluster();
@Override
protected String getTestRestCluster() {
return cluster.getHttpAddresses();
}
}

View File

@ -0,0 +1,148 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.qa.rest;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.hamcrest.Description;
import org.hamcrest.TypeSafeMatcher;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.both;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.lessThan;
public class RestSampleTestCase extends ESRestTestCase {
@Before
public void skipWhenSampleDisabled() throws IOException {
assumeTrue(
"Requires SAMPLE capability",
EsqlSpecTestCase.hasCapabilities(adminClient(), List.of(EsqlCapabilities.Cap.SAMPLE.capabilityName()))
);
}
@Before
@After
public void assertRequestBreakerEmpty() throws Exception {
EsqlSpecTestCase.assertRequestBreakerEmpty();
}
/**
* Matcher for the results of sampling 50% of the elements 0,1,2,...,998,999.
* The results should consist of unique numbers in [0,999]. Furthermore, the
* size should be 500 on average. Allowing a margin of well over 10 standard
* deviations, the size should be in [250,750].
*/
private static final TypeSafeMatcher<List<List<Integer>>> RESULT_MATCHER = new TypeSafeMatcher<>() {
@Override
public void describeTo(Description description) {
description.appendText("a list with between 250 and 750 unique elements in [0,999]");
}
@Override
protected boolean matchesSafely(List<List<Integer>> lists) {
if (lists.size() < 250 || lists.size() > 750) {
return false;
}
Set<Integer> values = new HashSet<>();
for (List<Integer> list : lists) {
if (list.size() != 1) {
return false;
}
Integer value = list.get(0);
if (value == null || value < 0 || value >= 1000) {
return false;
}
values.add(value);
}
return values.size() == lists.size();
}
};
/**
* This tests sampling in the Lucene query.
*/
public void testSample_withFrom() throws IOException {
createTestIndex();
test("FROM sample-test-index | SAMPLE 0.5 | LIMIT 1000");
deleteTestIndex();
}
/**
* This tests sampling in the ES|QL operator.
*/
public void testSample_withRow() throws IOException {
List<Integer> numbers = IntStream.range(0, 1000).boxed().toList();
test("ROW value = " + numbers + " | MV_EXPAND value | SAMPLE 0.5 | LIMIT 1000");
}
private void test(String query) throws IOException {
int iterationCount = 1000;
int totalResultSize = 0;
for (int iteration = 0; iteration < iterationCount; iteration++) {
Map<String, Object> result = runEsqlQuery(query);
assertResultMap(result, defaultOutputColumns(), RESULT_MATCHER);
totalResultSize += ((List<?>) result.get("values")).size();
}
// On average there are 500 elements in the result set.
// Allowing a wide margin of standard deviations, the average should be in [490,510].
assertThat(totalResultSize / iterationCount, both(greaterThan(490)).and(lessThan(510)));
}
private static List<Map<String, String>> defaultOutputColumns() {
return List.of(Map.of("name", "value", "type", "integer"));
}
private Map<String, Object> runEsqlQuery(String query) throws IOException {
RestEsqlTestCase.RequestObjectBuilder builder = RestEsqlTestCase.requestObjectBuilder().query(query);
return RestEsqlTestCase.runEsqlSync(builder);
}
private void createTestIndex() throws IOException {
Request request = new Request("PUT", "/sample-test-index");
request.setJsonEntity("""
{
"mappings": {
"properties": {
"value": { "type": "integer" }
}
}
}""");
assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode());
StringBuilder requestJsonEntity = new StringBuilder();
for (int i = 0; i < 1000; i++) {
requestJsonEntity.append("{ \"index\": {\"_id\": " + i + "} }\n");
requestJsonEntity.append("{ \"value\": " + i + " }\n");
}
request = new Request("POST", "/sample-test-index/_bulk");
request.addParameter("refresh", "true");
request.setJsonEntity(requestJsonEntity.toString());
assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode());
}
private void deleteTestIndex() throws IOException {
try {
adminClient().performRequest(new Request("DELETE", "/sample-test-index"));
} catch (ResponseException e) {
// Deleting a missing index is fine; rethrow anything else.
if (e.getResponse().getStatusLine().getStatusCode() != 404) {
throw e;
}
}
}
}

View File

@ -0,0 +1,231 @@
// Tests focused on the SAMPLE command
// Note: these tests cover only basic behavior, because of limitations of the CSV tests.
// Most tests assert that the count, average and sum of some values are within a
// range. These stats should be correctly adjusted for the sampling. Furthermore,
// they also assert the value of MV_COUNT(VALUES(...)), which is not adjusted for
// the sampling and therefore gives the size of the sample.
// All ranges are very loose, so that the tests should fail less often than once in a billion runs.
// The range checks are done in ES|QL, resulting in one boolean value (is_expected),
// because the CSV tests don't support such assertions.
row
required_capability: sample
ROW x = 1 | SAMPLE .999999999
;
x:integer
1
;
row and mv_expand
required_capability: sample
ROW x = [1,2,3,4,5] | MV_EXPAND x | SAMPLE .999999999
;
x:integer
1
2
3
4
5
;
adjust stats for sampling
required_capability: sample
FROM employees
| SAMPLE 0.5
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no), sum_emp_no = SUM(emp_no)
| EVAL is_expected = count >= 40 AND count <= 160 AND
values_count >= 20 AND values_count <= 80 AND
avg_emp_no > 10010 AND avg_emp_no < 10090 AND
sum_emp_no > 40*10010 AND sum_emp_no < 160*10090
| KEEP is_expected
;
is_expected:boolean
true
;
before where
required_capability: sample
FROM employees
| SAMPLE 0.5
| WHERE emp_no > 10050
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no)
| EVAL is_expected = count >= 10 AND count <= 90 AND
values_count >= 5 AND values_count <= 45 AND
avg_emp_no > 10055 AND avg_emp_no < 10095
| KEEP is_expected
;
is_expected:boolean
true
;
after where
required_capability: sample
FROM employees
| WHERE emp_no <= 10050
| SAMPLE 0.5
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no)
| EVAL is_expected = count >= 10 AND count <= 90 AND
values_count >= 5 AND values_count <= 45 AND
avg_emp_no > 10005 AND avg_emp_no < 10045
| KEEP is_expected
;
is_expected:boolean
true
;
before sort
required_capability: sample
FROM employees
| SAMPLE 0.5
| SORT emp_no
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no)
| EVAL is_expected = count >= 40 AND count <= 160 AND
values_count >= 20 AND values_count <= 80 AND
avg_emp_no > 10010 AND avg_emp_no < 10090
| KEEP is_expected
;
is_expected:boolean
true
;
after sort
required_capability: sample
FROM employees
| SORT emp_no
| SAMPLE 0.5
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no)
| EVAL is_expected = count >= 40 AND count <= 160 AND
values_count >= 20 AND values_count <= 80 AND
avg_emp_no > 10010 AND avg_emp_no < 10090
| KEEP is_expected
;
is_expected:boolean
true
;
before limit
required_capability: sample
FROM employees
| SAMPLE 0.5
| LIMIT 10
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no))
| EVAL is_expected = count == 10 AND values_count == 10
| KEEP is_expected
;
is_expected:boolean
true
;
after limit
required_capability: sample
FROM employees
| LIMIT 50
| SAMPLE 0.5
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no))
| EVAL is_expected = count >= 10 AND count <= 90 AND
values_count >= 5 AND values_count <= 45
| KEEP is_expected
;
is_expected:boolean
true
;
before mv_expand
required_capability: sample
ROW x = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50], y = [1,2]
| MV_EXPAND x
| SAMPLE 0.85
| MV_EXPAND y
| STATS count = COUNT() BY x
| STATS counts = VALUES(count)
| EVAL is_expected = MV_COUNT(counts) == 1 AND MV_MIN(counts) == 2
| KEEP is_expected
;
is_expected:boolean
true
;
after mv_expand
required_capability: sample
ROW x = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50], y = [1,2]
| MV_EXPAND x
| MV_EXPAND y
| SAMPLE 0.85
| STATS count = COUNT() BY x
| STATS counts = VALUES(count)
| EVAL is_expected = MV_COUNT(counts) == 2 AND MV_MIN(counts) == 1 AND MV_MAX(counts) == 2
| KEEP is_expected
;
is_expected:boolean
true
;
multiple samples
required_capability: sample
FROM employees
| SAMPLE 0.7
| SAMPLE 0.8
| SAMPLE 0.9
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no)
| EVAL is_expected = count >= 40 AND count <= 160 AND
values_count >= 20 AND values_count <= 80 AND
avg_emp_no > 10010 AND avg_emp_no < 10090
| KEEP is_expected
;
is_expected:boolean
true
;
after stats
required_capability: sample
FROM employees
| SAMPLE 0.5
| STATS avg_salary = AVG(salary) BY job_positions
| SAMPLE 0.8
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(avg_salary)), avg_avg_salary = AVG(avg_salary)
| EVAL is_expected = count >= 1 AND count <= 20 AND
values_count >= 1 AND values_count <= 16 AND
avg_avg_salary > 25000 AND avg_avg_salary < 75000
| KEEP is_expected
;
is_expected:boolean
true
;

View File

@ -15,127 +15,128 @@ WHERE=14
DEV_COMPLETION=15
DEV_INLINESTATS=16
DEV_RERANK=17
FROM=18
DEV_TIME_SERIES=19
DEV_FORK=20
JOIN_LOOKUP=21
DEV_JOIN_FULL=22
DEV_JOIN_LEFT=23
DEV_JOIN_RIGHT=24
DEV_LOOKUP=25
MV_EXPAND=26
DROP=27
KEEP=28
DEV_INSIST=29
DEV_RRF=30
RENAME=31
SHOW=32
UNKNOWN_CMD=33
CHANGE_POINT_LINE_COMMENT=34
CHANGE_POINT_MULTILINE_COMMENT=35
CHANGE_POINT_WS=36
ENRICH_POLICY_NAME=37
ENRICH_LINE_COMMENT=38
ENRICH_MULTILINE_COMMENT=39
ENRICH_WS=40
ENRICH_FIELD_LINE_COMMENT=41
ENRICH_FIELD_MULTILINE_COMMENT=42
ENRICH_FIELD_WS=43
SETTING=44
SETTING_LINE_COMMENT=45
SETTTING_MULTILINE_COMMENT=46
SETTING_WS=47
EXPLAIN_WS=48
EXPLAIN_LINE_COMMENT=49
EXPLAIN_MULTILINE_COMMENT=50
PIPE=51
QUOTED_STRING=52
INTEGER_LITERAL=53
DECIMAL_LITERAL=54
AND=55
AS=56
ASC=57
ASSIGN=58
BY=59
CAST_OP=60
COLON=61
COMMA=62
DESC=63
DOT=64
FALSE=65
FIRST=66
IN=67
IS=68
LAST=69
LIKE=70
NOT=71
NULL=72
NULLS=73
ON=74
OR=75
PARAM=76
RLIKE=77
TRUE=78
WITH=79
EQ=80
CIEQ=81
NEQ=82
LT=83
LTE=84
GT=85
GTE=86
PLUS=87
MINUS=88
ASTERISK=89
SLASH=90
PERCENT=91
LEFT_BRACES=92
RIGHT_BRACES=93
DOUBLE_PARAMS=94
NAMED_OR_POSITIONAL_PARAM=95
NAMED_OR_POSITIONAL_DOUBLE_PARAMS=96
OPENING_BRACKET=97
CLOSING_BRACKET=98
LP=99
RP=100
UNQUOTED_IDENTIFIER=101
QUOTED_IDENTIFIER=102
EXPR_LINE_COMMENT=103
EXPR_MULTILINE_COMMENT=104
EXPR_WS=105
METADATA=106
UNQUOTED_SOURCE=107
FROM_LINE_COMMENT=108
FROM_MULTILINE_COMMENT=109
FROM_WS=110
FORK_WS=111
FORK_LINE_COMMENT=112
FORK_MULTILINE_COMMENT=113
JOIN=114
USING=115
JOIN_LINE_COMMENT=116
JOIN_MULTILINE_COMMENT=117
JOIN_WS=118
LOOKUP_LINE_COMMENT=119
LOOKUP_MULTILINE_COMMENT=120
LOOKUP_WS=121
LOOKUP_FIELD_LINE_COMMENT=122
LOOKUP_FIELD_MULTILINE_COMMENT=123
LOOKUP_FIELD_WS=124
MVEXPAND_LINE_COMMENT=125
MVEXPAND_MULTILINE_COMMENT=126
MVEXPAND_WS=127
ID_PATTERN=128
PROJECT_LINE_COMMENT=129
PROJECT_MULTILINE_COMMENT=130
PROJECT_WS=131
RENAME_LINE_COMMENT=132
RENAME_MULTILINE_COMMENT=133
RENAME_WS=134
INFO=135
SHOW_LINE_COMMENT=136
SHOW_MULTILINE_COMMENT=137
SHOW_WS=138
DEV_SAMPLE=18
FROM=19
DEV_TIME_SERIES=20
DEV_FORK=21
JOIN_LOOKUP=22
DEV_JOIN_FULL=23
DEV_JOIN_LEFT=24
DEV_JOIN_RIGHT=25
DEV_LOOKUP=26
MV_EXPAND=27
DROP=28
KEEP=29
DEV_INSIST=30
DEV_RRF=31
RENAME=32
SHOW=33
UNKNOWN_CMD=34
CHANGE_POINT_LINE_COMMENT=35
CHANGE_POINT_MULTILINE_COMMENT=36
CHANGE_POINT_WS=37
ENRICH_POLICY_NAME=38
ENRICH_LINE_COMMENT=39
ENRICH_MULTILINE_COMMENT=40
ENRICH_WS=41
ENRICH_FIELD_LINE_COMMENT=42
ENRICH_FIELD_MULTILINE_COMMENT=43
ENRICH_FIELD_WS=44
SETTING=45
SETTING_LINE_COMMENT=46
SETTTING_MULTILINE_COMMENT=47
SETTING_WS=48
EXPLAIN_WS=49
EXPLAIN_LINE_COMMENT=50
EXPLAIN_MULTILINE_COMMENT=51
PIPE=52
QUOTED_STRING=53
INTEGER_LITERAL=54
DECIMAL_LITERAL=55
AND=56
AS=57
ASC=58
ASSIGN=59
BY=60
CAST_OP=61
COLON=62
COMMA=63
DESC=64
DOT=65
FALSE=66
FIRST=67
IN=68
IS=69
LAST=70
LIKE=71
NOT=72
NULL=73
NULLS=74
ON=75
OR=76
PARAM=77
RLIKE=78
TRUE=79
WITH=80
EQ=81
CIEQ=82
NEQ=83
LT=84
LTE=85
GT=86
GTE=87
PLUS=88
MINUS=89
ASTERISK=90
SLASH=91
PERCENT=92
LEFT_BRACES=93
RIGHT_BRACES=94
DOUBLE_PARAMS=95
NAMED_OR_POSITIONAL_PARAM=96
NAMED_OR_POSITIONAL_DOUBLE_PARAMS=97
OPENING_BRACKET=98
CLOSING_BRACKET=99
LP=100
RP=101
UNQUOTED_IDENTIFIER=102
QUOTED_IDENTIFIER=103
EXPR_LINE_COMMENT=104
EXPR_MULTILINE_COMMENT=105
EXPR_WS=106
METADATA=107
UNQUOTED_SOURCE=108
FROM_LINE_COMMENT=109
FROM_MULTILINE_COMMENT=110
FROM_WS=111
FORK_WS=112
FORK_LINE_COMMENT=113
FORK_MULTILINE_COMMENT=114
JOIN=115
USING=116
JOIN_LINE_COMMENT=117
JOIN_MULTILINE_COMMENT=118
JOIN_WS=119
LOOKUP_LINE_COMMENT=120
LOOKUP_MULTILINE_COMMENT=121
LOOKUP_WS=122
LOOKUP_FIELD_LINE_COMMENT=123
LOOKUP_FIELD_MULTILINE_COMMENT=124
LOOKUP_FIELD_WS=125
MVEXPAND_LINE_COMMENT=126
MVEXPAND_MULTILINE_COMMENT=127
MVEXPAND_WS=128
ID_PATTERN=129
PROJECT_LINE_COMMENT=130
PROJECT_MULTILINE_COMMENT=131
PROJECT_WS=132
RENAME_LINE_COMMENT=133
RENAME_MULTILINE_COMMENT=134
RENAME_WS=135
INFO=136
SHOW_LINE_COMMENT=137
SHOW_MULTILINE_COMMENT=138
SHOW_WS=139
'change_point'=4
'enrich'=5
'explain'=6
@ -147,57 +148,57 @@ SHOW_WS=138
'sort'=12
'stats'=13
'where'=14
'from'=18
'lookup'=21
'mv_expand'=26
'drop'=27
'keep'=28
'rename'=31
'show'=32
'|'=51
'and'=55
'as'=56
'asc'=57
'='=58
'by'=59
'::'=60
':'=61
','=62
'desc'=63
'.'=64
'false'=65
'first'=66
'in'=67
'is'=68
'last'=69
'like'=70
'not'=71
'null'=72
'nulls'=73
'on'=74
'or'=75
'?'=76
'rlike'=77
'true'=78
'with'=79
'=='=80
'=~'=81
'!='=82
'<'=83
'<='=84
'>'=85
'>='=86
'+'=87
'-'=88
'*'=89
'/'=90
'%'=91
'{'=92
'}'=93
'??'=94
']'=98
')'=100
'metadata'=106
'join'=114
'USING'=115
'info'=135
'from'=19
'lookup'=22
'mv_expand'=27
'drop'=28
'keep'=29
'rename'=32
'show'=33
'|'=52
'and'=56
'as'=57
'asc'=58
'='=59
'by'=60
'::'=61
':'=62
','=63
'desc'=64
'.'=65
'false'=66
'first'=67
'in'=68
'is'=69
'last'=70
'like'=71
'not'=72
'null'=73
'nulls'=74
'on'=75
'or'=76
'?'=77
'rlike'=78
'true'=79
'with'=80
'=='=81
'=~'=82
'!='=83
'<'=84
'<='=85
'>'=86
'>='=87
'+'=88
'-'=89
'*'=90
'/'=91
'%'=92
'{'=93
'}'=94
'??'=95
']'=99
')'=101
'metadata'=107
'join'=115
'USING'=116
'info'=136

View File

@ -64,6 +64,7 @@ processingCommand
| {this.isDevVersion()}? forkCommand
| {this.isDevVersion()}? rerankCommand
| {this.isDevVersion()}? rrfCommand
| {this.isDevVersion()}? sampleCommand
;
whereCommand
@ -301,3 +302,7 @@ rerankCommand
completionCommand
: DEV_COMPLETION prompt=primaryExpression WITH inferenceId=identifierOrParameter (AS targetField=qualifiedName)?
;
sampleCommand
: DEV_SAMPLE probability=decimalValue seed=integerValue?
;
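As an illustration, `FROM employees | SAMPLE 0.5 42` parses with probability = 0.5 and seed = 42; the trailing integer seed is optional and is omitted in all the SAMPLE calls in the CSV tests above.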

View File

@ -15,127 +15,128 @@ WHERE=14
DEV_COMPLETION=15
DEV_INLINESTATS=16
DEV_RERANK=17
FROM=18
DEV_TIME_SERIES=19
DEV_FORK=20
JOIN_LOOKUP=21
DEV_JOIN_FULL=22
DEV_JOIN_LEFT=23
DEV_JOIN_RIGHT=24
DEV_LOOKUP=25
MV_EXPAND=26
DROP=27
KEEP=28
DEV_INSIST=29
DEV_RRF=30
RENAME=31
SHOW=32
UNKNOWN_CMD=33
CHANGE_POINT_LINE_COMMENT=34
CHANGE_POINT_MULTILINE_COMMENT=35
CHANGE_POINT_WS=36
ENRICH_POLICY_NAME=37
ENRICH_LINE_COMMENT=38
ENRICH_MULTILINE_COMMENT=39
ENRICH_WS=40
ENRICH_FIELD_LINE_COMMENT=41
ENRICH_FIELD_MULTILINE_COMMENT=42
ENRICH_FIELD_WS=43
SETTING=44
SETTING_LINE_COMMENT=45
SETTTING_MULTILINE_COMMENT=46
SETTING_WS=47
EXPLAIN_WS=48
EXPLAIN_LINE_COMMENT=49
EXPLAIN_MULTILINE_COMMENT=50
PIPE=51
QUOTED_STRING=52
INTEGER_LITERAL=53
DECIMAL_LITERAL=54
AND=55
AS=56
ASC=57
ASSIGN=58
BY=59
CAST_OP=60
COLON=61
COMMA=62
DESC=63
DOT=64
FALSE=65
FIRST=66
IN=67
IS=68
LAST=69
LIKE=70
NOT=71
NULL=72
NULLS=73
ON=74
OR=75
PARAM=76
RLIKE=77
TRUE=78
WITH=79
EQ=80
CIEQ=81
NEQ=82
LT=83
LTE=84
GT=85
GTE=86
PLUS=87
MINUS=88
ASTERISK=89
SLASH=90
PERCENT=91
LEFT_BRACES=92
RIGHT_BRACES=93
DOUBLE_PARAMS=94
NAMED_OR_POSITIONAL_PARAM=95
NAMED_OR_POSITIONAL_DOUBLE_PARAMS=96
OPENING_BRACKET=97
CLOSING_BRACKET=98
LP=99
RP=100
UNQUOTED_IDENTIFIER=101
QUOTED_IDENTIFIER=102
EXPR_LINE_COMMENT=103
EXPR_MULTILINE_COMMENT=104
EXPR_WS=105
METADATA=106
UNQUOTED_SOURCE=107
FROM_LINE_COMMENT=108
FROM_MULTILINE_COMMENT=109
FROM_WS=110
FORK_WS=111
FORK_LINE_COMMENT=112
FORK_MULTILINE_COMMENT=113
JOIN=114
USING=115
JOIN_LINE_COMMENT=116
JOIN_MULTILINE_COMMENT=117
JOIN_WS=118
LOOKUP_LINE_COMMENT=119
LOOKUP_MULTILINE_COMMENT=120
LOOKUP_WS=121
LOOKUP_FIELD_LINE_COMMENT=122
LOOKUP_FIELD_MULTILINE_COMMENT=123
LOOKUP_FIELD_WS=124
MVEXPAND_LINE_COMMENT=125
MVEXPAND_MULTILINE_COMMENT=126
MVEXPAND_WS=127
ID_PATTERN=128
PROJECT_LINE_COMMENT=129
PROJECT_MULTILINE_COMMENT=130
PROJECT_WS=131
RENAME_LINE_COMMENT=132
RENAME_MULTILINE_COMMENT=133
RENAME_WS=134
INFO=135
SHOW_LINE_COMMENT=136
SHOW_MULTILINE_COMMENT=137
SHOW_WS=138
DEV_SAMPLE=18
FROM=19
DEV_TIME_SERIES=20
DEV_FORK=21
JOIN_LOOKUP=22
DEV_JOIN_FULL=23
DEV_JOIN_LEFT=24
DEV_JOIN_RIGHT=25
DEV_LOOKUP=26
MV_EXPAND=27
DROP=28
KEEP=29
DEV_INSIST=30
DEV_RRF=31
RENAME=32
SHOW=33
UNKNOWN_CMD=34
CHANGE_POINT_LINE_COMMENT=35
CHANGE_POINT_MULTILINE_COMMENT=36
CHANGE_POINT_WS=37
ENRICH_POLICY_NAME=38
ENRICH_LINE_COMMENT=39
ENRICH_MULTILINE_COMMENT=40
ENRICH_WS=41
ENRICH_FIELD_LINE_COMMENT=42
ENRICH_FIELD_MULTILINE_COMMENT=43
ENRICH_FIELD_WS=44
SETTING=45
SETTING_LINE_COMMENT=46
SETTTING_MULTILINE_COMMENT=47
SETTING_WS=48
EXPLAIN_WS=49
EXPLAIN_LINE_COMMENT=50
EXPLAIN_MULTILINE_COMMENT=51
PIPE=52
QUOTED_STRING=53
INTEGER_LITERAL=54
DECIMAL_LITERAL=55
AND=56
AS=57
ASC=58
ASSIGN=59
BY=60
CAST_OP=61
COLON=62
COMMA=63
DESC=64
DOT=65
FALSE=66
FIRST=67
IN=68
IS=69
LAST=70
LIKE=71
NOT=72
NULL=73
NULLS=74
ON=75
OR=76
PARAM=77
RLIKE=78
TRUE=79
WITH=80
EQ=81
CIEQ=82
NEQ=83
LT=84
LTE=85
GT=86
GTE=87
PLUS=88
MINUS=89
ASTERISK=90
SLASH=91
PERCENT=92
LEFT_BRACES=93
RIGHT_BRACES=94
DOUBLE_PARAMS=95
NAMED_OR_POSITIONAL_PARAM=96
NAMED_OR_POSITIONAL_DOUBLE_PARAMS=97
OPENING_BRACKET=98
CLOSING_BRACKET=99
LP=100
RP=101
UNQUOTED_IDENTIFIER=102
QUOTED_IDENTIFIER=103
EXPR_LINE_COMMENT=104
EXPR_MULTILINE_COMMENT=105
EXPR_WS=106
METADATA=107
UNQUOTED_SOURCE=108
FROM_LINE_COMMENT=109
FROM_MULTILINE_COMMENT=110
FROM_WS=111
FORK_WS=112
FORK_LINE_COMMENT=113
FORK_MULTILINE_COMMENT=114
JOIN=115
USING=116
JOIN_LINE_COMMENT=117
JOIN_MULTILINE_COMMENT=118
JOIN_WS=119
LOOKUP_LINE_COMMENT=120
LOOKUP_MULTILINE_COMMENT=121
LOOKUP_WS=122
LOOKUP_FIELD_LINE_COMMENT=123
LOOKUP_FIELD_MULTILINE_COMMENT=124
LOOKUP_FIELD_WS=125
MVEXPAND_LINE_COMMENT=126
MVEXPAND_MULTILINE_COMMENT=127
MVEXPAND_WS=128
ID_PATTERN=129
PROJECT_LINE_COMMENT=130
PROJECT_MULTILINE_COMMENT=131
PROJECT_WS=132
RENAME_LINE_COMMENT=133
RENAME_MULTILINE_COMMENT=134
RENAME_WS=135
INFO=136
SHOW_LINE_COMMENT=137
SHOW_MULTILINE_COMMENT=138
SHOW_WS=139
'change_point'=4
'enrich'=5
'explain'=6
@ -147,57 +148,57 @@ SHOW_WS=138
'sort'=12
'stats'=13
'where'=14
'from'=18
'lookup'=21
'mv_expand'=26
'drop'=27
'keep'=28
'rename'=31
'show'=32
'|'=51
'and'=55
'as'=56
'asc'=57
'='=58
'by'=59
'::'=60
':'=61
','=62
'desc'=63
'.'=64
'false'=65
'first'=66
'in'=67
'is'=68
'last'=69
'like'=70
'not'=71
'null'=72
'nulls'=73
'on'=74
'or'=75
'?'=76
'rlike'=77
'true'=78
'with'=79
'=='=80
'=~'=81
'!='=82
'<'=83
'<='=84
'>'=85
'>='=86
'+'=87
'-'=88
'*'=89
'/'=90
'%'=91
'{'=92
'}'=93
'??'=94
']'=98
')'=100
'metadata'=106
'join'=114
'USING'=115
'info'=135
'from'=19
'lookup'=22
'mv_expand'=27
'drop'=28
'keep'=29
'rename'=32
'show'=33
'|'=52
'and'=56
'as'=57
'asc'=58
'='=59
'by'=60
'::'=61
':'=62
','=63
'desc'=64
'.'=65
'false'=66
'first'=67
'in'=68
'is'=69
'last'=70
'like'=71
'not'=72
'null'=73
'nulls'=74
'on'=75
'or'=76
'?'=77
'rlike'=78
'true'=79
'with'=80
'=='=81
'=~'=82
'!='=83
'<'=84
'<='=85
'>'=86
'>='=87
'+'=88
'-'=89
'*'=90
'/'=91
'%'=92
'{'=93
'}'=94
'??'=95
']'=99
')'=101
'metadata'=107
'join'=115
'USING'=116
'info'=136

View File

@ -18,9 +18,10 @@ SORT : 'sort' -> pushMode(EXPRESSION_MODE);
STATS : 'stats' -> pushMode(EXPRESSION_MODE);
WHERE : 'where' -> pushMode(EXPRESSION_MODE);
DEV_COMPLETION : {this.isDevVersion()}? 'completion' -> pushMode(EXPRESSION_MODE);
DEV_INLINESTATS : {this.isDevVersion()}? 'inlinestats' -> pushMode(EXPRESSION_MODE);
DEV_RERANK : {this.isDevVersion()}? 'rerank' -> pushMode(EXPRESSION_MODE);
DEV_COMPLETION : {this.isDevVersion()}? 'completion' -> pushMode(EXPRESSION_MODE);
DEV_INLINESTATS : {this.isDevVersion()}? 'inlinestats' -> pushMode(EXPRESSION_MODE);
DEV_RERANK : {this.isDevVersion()}? 'rerank' -> pushMode(EXPRESSION_MODE);
DEV_SAMPLE : {this.isDevVersion()}? 'sample' -> pushMode(EXPRESSION_MODE);
mode EXPRESSION_MODE;

View File

@ -1033,7 +1033,12 @@ public class EsqlCapabilities {
/**
* Support last_over_time aggregation that gets evaluated per time-series
*/
LAST_OVER_TIME(Build.current().isSnapshot());
LAST_OVER_TIME(Build.current().isSnapshot()),
/**
* Support for the SAMPLE command
*/
SAMPLE(Build.current().isSnapshot());
private final boolean enabled;

View File

@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.capabilities;
import org.elasticsearch.xpack.esql.common.Failures;
/**
* Interface implemented by expressions that require validation post physical optimization.
*/
public interface PostPhysicalOptimizationVerificationAware {
/**
* Validates the implementing expression - discovered failures are reported to the given
* {@link Failures} class.
*/
void postPhysicalOptimizationVerification(Failures failures);
}
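A hedged sketch of an implementer (the expression type and the executesOnDataNode() helper are hypothetical; fail is org.elasticsearch.xpack.esql.common.Failure#fail):

// Hypothetical: an expression that is only valid on the coordinating node.
@Override
public void postPhysicalOptimizationVerification(Failures failures) {
    if (executesOnDataNode()) { // assumed helper, not part of this PR
        failures.add(fail(this, "[{}] must run on the coordinating node", sourceText()));
    }
}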

View File

@ -25,8 +25,10 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.FunctionType;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount;
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
import org.elasticsearch.xpack.esql.planner.ToAggregator;
@ -37,9 +39,11 @@ import static java.util.Collections.emptyList;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
public class Count extends AggregateFunction implements ToAggregator, SurrogateExpression {
public class Count extends AggregateFunction implements ToAggregator, SurrogateExpression, HasSampleCorrection {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Count", Count::new);
private final boolean isSampleCorrected;
@FunctionInfo(
returnType = "long",
description = "Returns the total number (count) of input values.",
@ -94,11 +98,20 @@ public class Count extends AggregateFunction implements ToAggregator, SurrogateE
}
public Count(Source source, Expression field, Expression filter) {
this(source, field, filter, false);
}
private Count(Source source, Expression field, Expression filter, boolean isSampleCorrected) {
super(source, field, filter, emptyList());
this.isSampleCorrected = isSampleCorrected;
}
private Count(StreamInput in) throws IOException {
super(in);
// isSampleCorrected is only used during query optimization to mark
// whether this function has been processed. Hence there's no need to
// serialize it.
this.isSampleCorrected = false;
}
@Override
@ -169,4 +182,14 @@ public class Count extends AggregateFunction implements ToAggregator, SurrogateE
return null;
}
@Override
public boolean isSampleCorrected() {
return isSampleCorrected;
}
@Override
public Expression sampleCorrection(Expression sampleProbability) {
return new ToLong(source(), new Div(source(), new Count(source(), field(), filter(), true), sampleProbability));
}
}
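Worked example of the correction: after SAMPLE 0.5, an observed count of 50 sampled rows becomes TO_LONG(50 / 0.5) = 100, an unbiased estimate of the pre-sampling count. Sum below applies the same 1/p scaling, additionally converting the result back to LONG when the original sum type was LONG.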

View File

@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.expression.function.aggregate;
import org.elasticsearch.xpack.esql.core.expression.Expression;
/**
* Interface signaling to the planner that an aggregation function has to be
* corrected in the presence of random sampling.
*/
public interface HasSampleCorrection {
boolean isSampleCorrected();
Expression sampleCorrection(Expression sampleProbability);
}

View File

@ -26,7 +26,9 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.FunctionType;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
import java.io.IOException;
@ -43,9 +45,11 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.UNSIGNED_LONG;
/**
* Sum all values of a field in matching documents.
*/
public class Sum extends NumericAggregate implements SurrogateExpression {
public class Sum extends NumericAggregate implements SurrogateExpression, HasSampleCorrection {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Sum", Sum::new);
private final boolean isSampleCorrected;
@FunctionInfo(
returnType = { "long", "double" },
description = "The sum of a numeric expression.",
@ -65,11 +69,20 @@ public class Sum extends NumericAggregate implements SurrogateExpression {
}
public Sum(Source source, Expression field, Expression filter) {
this(source, field, filter, false);
}
private Sum(Source source, Expression field, Expression filter, boolean isSampleCorrected) {
super(source, field, filter, emptyList());
this.isSampleCorrected = isSampleCorrected;
}
private Sum(StreamInput in) throws IOException {
super(in);
// isSampleCorrected is only used during query optimization to mark
// whether this function has been processed. Hence there's no need to
// serialize it.
this.isSampleCorrected = false;
}
@Override
@ -147,4 +160,19 @@ public class Sum extends NumericAggregate implements SurrogateExpression {
? new Mul(s, new MvSum(s, field), new Count(s, new Literal(s, StringUtils.WILDCARD, DataType.KEYWORD)))
: null;
}
@Override
public boolean isSampleCorrected() {
return isSampleCorrected;
}
@Override
public Expression sampleCorrection(Expression sampleProbability) {
Expression correctedSum = new Div(source(), new Sum(source(), field(), filter(), true), sampleProbability);
return switch (dataType()) {
case DOUBLE -> correctedSum;
case LONG -> new ToLong(source(), correctedSum);
default -> throw new IllegalStateException("unexpected data type [" + dataType() + "]");
};
}
}

View File

@ -8,11 +8,12 @@
package org.elasticsearch.xpack.esql.optimizer;
import org.elasticsearch.xpack.esql.VerificationException;
import org.elasticsearch.xpack.esql.common.Failure;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.EnableSpatialDistancePushdown;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.InsertFieldExtraction;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.PushFiltersToSource;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.PushLimitToSource;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.PushSampleToSource;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.PushStatsToSource;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.PushTopNToSource;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.ReplaceSourceAttributes;
@ -23,7 +24,6 @@ import org.elasticsearch.xpack.esql.rule.ParameterizedRuleExecutor;
import org.elasticsearch.xpack.esql.rule.Rule;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
/**
@ -45,8 +45,8 @@ public class LocalPhysicalPlanOptimizer extends ParameterizedRuleExecutor<Physic
}
PhysicalPlan verify(PhysicalPlan plan) {
Collection<Failure> failures = verifier.verify(plan);
if (failures.isEmpty() == false) {
Failures failures = verifier.verify(plan);
if (failures.hasFailures()) {
throw new VerificationException(failures);
}
return plan;
@ -64,6 +64,7 @@ public class LocalPhysicalPlanOptimizer extends ParameterizedRuleExecutor<Physic
esSourceRules.add(new PushTopNToSource());
esSourceRules.add(new PushLimitToSource());
esSourceRules.add(new PushFiltersToSource());
esSourceRules.add(new PushSampleToSource());
esSourceRules.add(new PushStatsToSource());
esSourceRules.add(new EnableSpatialDistancePushdown());
}

View File

@ -10,6 +10,7 @@ package org.elasticsearch.xpack.esql.optimizer;
import org.elasticsearch.xpack.esql.VerificationException;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ApplySampleCorrections;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.BooleanFunctionEqualsElimination;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.BooleanSimplification;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.CombineBinaryComparisons;
@ -38,6 +39,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.logical.PruneUnusedIndexMode
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineFilters;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineLimits;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineOrderBy;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineSample;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownCompletion;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEnrich;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEval;
@ -127,6 +129,7 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
return new Batch<>(
"Substitutions",
Limiter.ONCE,
new ApplySampleCorrections(),
new SubstituteSurrogatePlans(),
// Translate filtered expressions into aggregate with filters - can't use surrogate expressions because it was
// retrofitted for constant folding - this needs to be fixed.
@ -191,6 +194,7 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
new PruneLiteralsInOrderBy(),
new PushDownAndCombineLimits(),
new PushDownAndCombineFilters(),
new PushDownAndCombineSample(),
new PushDownCompletion(),
new PushDownEval(),
new PushDownRegexExtract(),

View File

@ -8,14 +8,13 @@
package org.elasticsearch.xpack.esql.optimizer;
import org.elasticsearch.xpack.esql.VerificationException;
import org.elasticsearch.xpack.esql.common.Failure;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.ProjectAwayColumns;
import org.elasticsearch.xpack.esql.plan.physical.FragmentExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.rule.ParameterizedRuleExecutor;
import org.elasticsearch.xpack.esql.rule.RuleExecutor;
import java.util.Collection;
import java.util.List;
/**
@ -39,8 +38,8 @@ public class PhysicalPlanOptimizer extends ParameterizedRuleExecutor<PhysicalPla
}
PhysicalPlan verify(PhysicalPlan plan) {
Collection<Failure> failures = verifier.verify(plan);
if (failures.isEmpty() == false) {
Failures failures = verifier.verify(plan);
if (failures.hasFailures()) {
throw new VerificationException(failures);
}
return plan;

View File

@ -7,7 +7,7 @@
package org.elasticsearch.xpack.esql.optimizer;
import org.elasticsearch.xpack.esql.common.Failure;
import org.elasticsearch.xpack.esql.capabilities.PostPhysicalOptimizationVerificationAware;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
@ -17,10 +17,6 @@ import org.elasticsearch.xpack.esql.plan.physical.EnrichExec;
import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.Set;
import static org.elasticsearch.xpack.esql.common.Failure.fail;
/** Physical plan verifier. */
@ -31,8 +27,8 @@ public final class PhysicalVerifier {
private PhysicalVerifier() {}
/** Verifies the physical plan. */
public Collection<Failure> verify(PhysicalPlan plan) {
Set<Failure> failures = new LinkedHashSet<>();
public Failures verify(PhysicalPlan plan) {
Failures failures = new Failures();
Failures depFailures = new Failures();
// AwaitsFix https://github.com/elastic/elasticsearch/issues/118531
@ -56,6 +52,17 @@ public final class PhysicalVerifier {
}
}
PlanConsistencyChecker.checkPlan(p, depFailures);
if (failures.hasFailures() == false) {
if (p instanceof PostPhysicalOptimizationVerificationAware va) {
va.postPhysicalOptimizationVerification(failures);
}
p.forEachExpression(ex -> {
if (ex instanceof PostPhysicalOptimizationVerificationAware va) {
va.postPhysicalOptimizationVerification(failures);
}
});
}
});
if (depFailures.hasFailures()) {

View File

@ -0,0 +1,55 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.aggregate.HasSampleCorrection;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Limit;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Sample;
import org.elasticsearch.xpack.esql.rule.Rule;
import java.util.ArrayList;
import java.util.List;
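/**
 * Applies sample corrections to aggregate functions whose input has been
 * thinned by one or more SAMPLE commands. For example (illustrative, not from
 * this commit's tests), in <code>FROM idx | SAMPLE 0.5 | SAMPLE 0.2 | STATS c = COUNT(*)</code>
 * the aggregate sees the combined probability 0.5 * 0.2 = 0.1, so a function
 * like COUNT can compensate, e.g. by scaling its raw result by 1 / 0.1.
 * Aggregations and limits reset the collected probabilities, since the rows
 * they produce are no longer a plain random sample of the source.
 */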
public class ApplySampleCorrections extends Rule<LogicalPlan, LogicalPlan> {
@Override
public LogicalPlan apply(LogicalPlan logicalPlan) {
List<Expression> sampleProbabilities = new ArrayList<>();
return logicalPlan.transformUp(plan -> {
if (plan instanceof Sample sample) {
sampleProbabilities.add(sample.probability());
}
if (plan instanceof Aggregate && sampleProbabilities.isEmpty() == false) {
plan = plan.transformExpressionsOnly(
e -> e instanceof HasSampleCorrection hsc && hsc.isSampleCorrected() == false
? hsc.sampleCorrection(getSampleProbability(sampleProbabilities, e.source()))
: e
);
}
// Operations that map many rows to many rows break/reset sampling.
// Therefore, the collected sample probabilities are cleared.
if (plan instanceof Aggregate || plan instanceof Limit) {
sampleProbabilities.clear();
}
return plan;
});
}
private static Expression getSampleProbability(List<Expression> sampleProbabilities, Source source) {
Expression result = null;
for (Expression probability : sampleProbabilities) {
result = result == null ? probability : new Mul(source, result, probability);
}
return result;
}
}

View File

@ -0,0 +1,103 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Foldables;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.logical.Insist;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.RegexExtract;
import org.elasticsearch.xpack.esql.plan.logical.Sample;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
/**
* Pushes down the SAMPLE operator. SAMPLE can be pushed down through an
* operator if
* <p>
* <code>| SAMPLE p | OPERATOR</code>
* <p>
* is equivalent to
* <p>
* <code>| OPERATOR | SAMPLE p</code>
* <p>
* statistically (i.e. same possible output with same probabilities).
* In that case, we push down sampling to Lucene for efficiency.
* <p>
*
As a rule of thumb, an operator can be swapped with sampling if it maps:
* <ul>
* <li>
* one row to one row (e.g. <code>DISSECT</code>, <code>DROP</code>, <code>ENRICH</code>,
* <code>EVAL</code>, <code>GROK</code>, <code>KEEP</code>, <code>RENAME</code>)
* </li>
* <li>
* one row to zero or one row (<code>WHERE</code>)
* </li>
* <li>
all rows to the same rows, possibly in a different order (<code>SORT</code>)
* </li>
* </ul>
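* <p>
* For example (illustrative), <code>FROM idx | EVAL x = 1 | SAMPLE .5</code>
* is rewritten to the statistically equivalent
* <code>FROM idx | SAMPLE .5 | EVAL x = 1</code>, moving the sample towards
* the source so that it can later be pushed into the Lucene query.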
*/
public class PushDownAndCombineSample extends OptimizerRules.ParameterizedOptimizerRule<Sample, LogicalOptimizerContext> {
public PushDownAndCombineSample() {
super(OptimizerRules.TransformDirection.DOWN);
}
@Override
protected LogicalPlan rule(Sample sample, LogicalOptimizerContext context) {
LogicalPlan plan = sample;
var child = sample.child();
if (child instanceof Sample sampleChild) {
var probability = combinedProbability(context, sample, sampleChild);
var seed = combinedSeed(context, sample, sampleChild);
plan = new Sample(sample.source(), probability, seed, sampleChild.child());
} else if (child instanceof Enrich
|| child instanceof Eval
|| child instanceof Filter
|| child instanceof Insist
|| child instanceof OrderBy
|| child instanceof Project
|| child instanceof RegexExtract) {
var unaryChild = (UnaryPlan) child;
plan = unaryChild.replaceChild(sample.replaceChild(unaryChild.child()));
}
return plan;
}
private static Expression combinedProbability(LogicalOptimizerContext context, Sample parent, Sample child) {
var parentProbability = (double) Foldables.valueOf(context.foldCtx(), parent.probability());
var childProbability = (double) Foldables.valueOf(context.foldCtx(), child.probability());
return Literal.of(parent.probability(), parentProbability * childProbability);
}
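// Illustrative example: SAMPLE .3 5 followed by SAMPLE .5 10 combines into a
// single SAMPLE with probability 0.3 * 0.5 = 0.15; the seeds, when both are
// present, are XORed (5 ^ 10 == 15), see combinedSeed below.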
private static Expression combinedSeed(LogicalOptimizerContext context, Sample parent, Sample child) {
var parentSeed = parent.seed();
var childSeed = child.seed();
Expression seed;
if (parentSeed != null) {
if (childSeed != null) {
var seedValue = (int) Foldables.valueOf(context.foldCtx(), parentSeed);
seedValue ^= (int) Foldables.valueOf(context.foldCtx(), childSeed);
seed = Literal.of(parentSeed, seedValue);
} else {
seed = parentSeed;
}
} else {
seed = childSeed;
}
return seed;
}
}

View File

@ -0,0 +1,46 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.optimizer.rules.physical.local;
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQueryBuilder;
import org.elasticsearch.xpack.esql.core.expression.Foldables;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules;
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.SampleExec;
import static org.elasticsearch.index.query.QueryBuilders.boolQuery;
import static org.elasticsearch.xpack.esql.planner.mapper.MapperUtils.hasScoreAttribute;
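/**
 * Pushes SAMPLE into the Lucene query of an {@link EsQueryExec} by adding a
 * random sampling filter clause. For example (illustrative),
 * <code>FROM test | SAMPLE 0.1 234</code> yields a query along the lines of
 * <code>{"bool":{"filter":[{"sampling":{"probability":0.1,"seed":234,"hash":0}}]}}</code>,
 * so non-sampled documents are discarded during collection instead of in an
 * ES|QL operator. Any existing query is kept as a must clause when scores are
 * needed and as a filter clause otherwise.
 */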
public class PushSampleToSource extends PhysicalOptimizerRules.ParameterizedOptimizerRule<SampleExec, LocalPhysicalOptimizerContext> {
@Override
protected PhysicalPlan rule(SampleExec sample, LocalPhysicalOptimizerContext ctx) {
PhysicalPlan plan = sample;
if (sample.child() instanceof EsQueryExec queryExec) {
var fullQuery = boolQuery();
if (queryExec.query() != null) {
if (hasScoreAttribute(queryExec.output())) {
fullQuery.must(queryExec.query());
} else {
fullQuery.filter(queryExec.query());
}
}
var sampleQuery = new RandomSamplingQueryBuilder((double) Foldables.valueOf(ctx.foldCtx(), sample.probability()));
if (sample.seed() != null) {
sampleQuery.seed((int) Foldables.valueOf(ctx.foldCtx(), sample.seed()));
}
fullQuery.filter(sampleQuery);
plan = queryExec.withQuery(fullQuery);
}
return plan;
}
}

View File

@ -200,7 +200,7 @@ public class PushTopNToSource extends PhysicalOptimizerRules.ParameterizedOptimi
// allow only exact FieldAttributes (no expressions) for sorting
BiFunction<Expression, LucenePushdownPredicates, Boolean> isSortableAttribute = (exp, lpp) -> lpp.isPushableFieldAttribute(exp)
// TODO: https://github.com/elastic/elasticsearch/issues/120219
|| (exp instanceof MetadataAttribute ma && MetadataAttribute.SCORE.equals(ma.name()));
|| MetadataAttribute.isScoreAttribute(exp);
return orders.stream().allMatch(o -> isSortableAttribute.apply(o.child(), lucenePushdownPredicates));
}
@ -209,7 +209,7 @@ public class PushTopNToSource extends PhysicalOptimizerRules.ParameterizedOptimi
for (Order o : orders) {
if (o.child() instanceof FieldAttribute fa) {
sorts.add(new EsQueryExec.FieldSort(fa.exactAttribute(), o.direction(), o.nullsPosition()));
} else if (o.child() instanceof MetadataAttribute ma && MetadataAttribute.SCORE.equals(ma.name())) {
} else if (MetadataAttribute.isScoreAttribute(o.child())) {
sorts.add(new EsQueryExec.ScoreSort(o.direction()));
} else {
assert false : "unexpected ordering on expression type " + o.child().getClass();

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -764,6 +764,18 @@ public class EsqlBaseParserBaseListener implements EsqlBaseParserListener {
* <p>The default implementation does nothing.</p>
*/
@Override public void exitCompletionCommand(EsqlBaseParser.CompletionCommandContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterSampleCommand(EsqlBaseParser.SampleCommandContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitSampleCommand(EsqlBaseParser.SampleCommandContext ctx) { }
/**
* {@inheritDoc}
*

View File

@ -454,6 +454,13 @@ public class EsqlBaseParserBaseVisitor<T> extends AbstractParseTreeVisitor<T> im
* {@link #visitChildren} on {@code ctx}.</p>
*/
@Override public T visitCompletionCommand(EsqlBaseParser.CompletionCommandContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*
* <p>The default implementation returns the result of calling
* {@link #visitChildren} on {@code ctx}.</p>
*/
@Override public T visitSampleCommand(EsqlBaseParser.SampleCommandContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*

View File

@ -655,6 +655,16 @@ public interface EsqlBaseParserListener extends ParseTreeListener {
* @param ctx the parse tree
*/
void exitCompletionCommand(EsqlBaseParser.CompletionCommandContext ctx);
/**
* Enter a parse tree produced by {@link EsqlBaseParser#sampleCommand}.
* @param ctx the parse tree
*/
void enterSampleCommand(EsqlBaseParser.SampleCommandContext ctx);
/**
* Exit a parse tree produced by {@link EsqlBaseParser#sampleCommand}.
* @param ctx the parse tree
*/
void exitSampleCommand(EsqlBaseParser.SampleCommandContext ctx);
/**
* Enter a parse tree produced by the {@code matchExpression}
* labeled alternative in {@link EsqlBaseParser#booleanExpression}.

View File

@ -400,6 +400,12 @@ public interface EsqlBaseParserVisitor<T> extends ParseTreeVisitor<T> {
* @return the visitor result
*/
T visitCompletionCommand(EsqlBaseParser.CompletionCommandContext ctx);
/**
* Visit a parse tree produced by {@link EsqlBaseParser#sampleCommand}.
* @param ctx the parse tree
* @return the visitor result
*/
T visitSampleCommand(EsqlBaseParser.SampleCommandContext ctx);
/**
* Visit a parse tree produced by the {@code matchExpression}
* labeled alternative in {@link EsqlBaseParser#booleanExpression}.

View File

@ -66,6 +66,7 @@ import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.Rename;
import org.elasticsearch.xpack.esql.plan.logical.Row;
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.Sample;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
@ -768,4 +769,23 @@ public class LogicalPlanBuilder extends ExpressionBuilder {
ctx.parameter().getText()
);
}
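/**
 * Builds a {@link Sample} plan node from <code>SAMPLE probability [seed]</code>,
 * e.g. <code>| SAMPLE .1 42</code>. The seed, when present, must fit in an
 * integer.
 */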
@Override
public PlanFactory visitSampleCommand(EsqlBaseParser.SampleCommandContext ctx) {
var probability = visitDecimalValue(ctx.probability);
Literal seed;
if (ctx.seed != null) {
seed = visitIntegerValue(ctx.seed);
if (seed.dataType() != DataType.INTEGER) {
throw new ParsingException(
seed.source(),
"seed must be an integer, provided [{}] of type [{}]",
ctx.seed.getText(),
seed.dataType()
);
}
} else {
seed = null;
}
return plan -> new Sample(source(ctx), probability, seed, plan);
}
}

View File

@ -21,6 +21,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Lookup;
import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.Sample;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
@ -47,6 +48,7 @@ import org.elasticsearch.xpack.esql.plan.physical.LimitExec;
import org.elasticsearch.xpack.esql.plan.physical.LocalSourceExec;
import org.elasticsearch.xpack.esql.plan.physical.MvExpandExec;
import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
import org.elasticsearch.xpack.esql.plan.physical.SampleExec;
import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
import org.elasticsearch.xpack.esql.plan.physical.SubqueryExec;
import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec;
@ -87,6 +89,7 @@ public class PlanWritables {
OrderBy.ENTRY,
Project.ENTRY,
Rerank.ENTRY,
Sample.ENTRY,
TimeSeriesAggregate.ENTRY,
TopN.ENTRY
);
@ -114,6 +117,7 @@ public class PlanWritables {
MvExpandExec.ENTRY,
ProjectExec.ENTRY,
RerankExec.ENTRY,
SampleExec.ENTRY,
ShowExec.ENTRY,
SubqueryExec.ENTRY,
TimeSeriesAggregateExec.ENTRY,

View File

@ -0,0 +1,113 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.plan.logical;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQuery;
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisVerificationAware;
import org.elasticsearch.xpack.esql.capabilities.TelemetryAware;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Foldables;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.xpack.esql.common.Failure.fail;
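/**
 * Logical plan node for the ES|QL <code>SAMPLE probability [seed]</code>
 * command. The probability must be strictly between 0.0 and 1.0, which is
 * enforced in {@link #postAnalysisVerification}; the optional integer seed
 * makes the sampling reproducible when it can be pushed down to Lucene.
 */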
public class Sample extends UnaryPlan implements TelemetryAware, PostAnalysisVerificationAware {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Sample", Sample::new);
private final Expression probability;
private final Expression seed;
public Sample(Source source, Expression probability, @Nullable Expression seed, LogicalPlan child) {
super(source, child);
this.probability = probability;
this.seed = seed;
}
private Sample(StreamInput in) throws IOException {
this(
Source.readFrom((PlanStreamInput) in),
in.readNamedWriteable(Expression.class), // probability
in.readOptionalNamedWriteable(Expression.class), // seed
in.readNamedWriteable(LogicalPlan.class) // child
);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
source().writeTo(out);
out.writeNamedWriteable(probability);
out.writeOptionalNamedWriteable(seed);
out.writeNamedWriteable(child());
}
@Override
public String getWriteableName() {
return ENTRY.name;
}
@Override
protected NodeInfo<Sample> info() {
return NodeInfo.create(this, Sample::new, probability, seed, child());
}
@Override
public Sample replaceChild(LogicalPlan newChild) {
return new Sample(source(), probability, seed, newChild);
}
public Expression probability() {
return probability;
}
public Expression seed() {
return seed;
}
@Override
public boolean expressionsResolved() {
return probability.resolved() && (seed == null || seed.resolved());
}
@Override
public int hashCode() {
return Objects.hash(probability, seed, child());
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
var other = (Sample) obj;
return Objects.equals(probability, other.probability) && Objects.equals(seed, other.seed) && Objects.equals(child(), other.child());
}
@Override
public void postAnalysisVerification(Failures failures) {
try {
RandomSamplingQuery.checkProbabilityRange((double) Foldables.valueOf(FoldContext.small(), probability));
} catch (IllegalArgumentException e) {
failures.add(fail(probability, e.getMessage()));
}
}
}

View File

@ -308,6 +308,12 @@ public class EsQueryExec extends LeafExec implements EstimatesRowSize {
: new EsQueryExec(source(), indexPattern, indexMode, indexNameWithModes, attrs, query, limit, sorts, estimatedRowSize);
}
public EsQueryExec withQuery(QueryBuilder query) {
return Objects.equals(this.query, query)
? this
: new EsQueryExec(source(), indexPattern, indexMode, indexNameWithModes, attrs, query, limit, sorts, estimatedRowSize);
}
@Override
public int hashCode() {
return Objects.hash(indexPattern, indexMode, indexNameWithModes, attrs, query, limit, sorts);

View File

@ -0,0 +1,114 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.plan.physical;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.esql.capabilities.PostPhysicalOptimizationVerificationAware;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.xpack.esql.common.Failure.fail;
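/**
 * Physical plan node for the SAMPLE command. When the sample cannot be pushed
 * down into the Lucene query, it is executed by a SampleOperator instead; in
 * that case a fixed seed is disallowed, see
 * {@link #postPhysicalOptimizationVerification}.
 */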
public class SampleExec extends UnaryExec implements PostPhysicalOptimizationVerificationAware {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
PhysicalPlan.class,
"SampleExec",
SampleExec::new
);
private final Expression probability;
private final Expression seed;
public SampleExec(Source source, PhysicalPlan child, Expression probability, @Nullable Expression seed) {
super(source, child);
this.probability = probability;
this.seed = seed;
}
public SampleExec(StreamInput in) throws IOException {
this(
Source.readFrom((PlanStreamInput) in),
in.readNamedWriteable(PhysicalPlan.class), // child
in.readNamedWriteable(Expression.class), // probability
in.readOptionalNamedWriteable(Expression.class) // seed
);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
source().writeTo(out);
out.writeNamedWriteable(child());
out.writeNamedWriteable(probability);
out.writeOptionalNamedWriteable(seed);
}
@Override
public UnaryExec replaceChild(PhysicalPlan newChild) {
return new SampleExec(source(), newChild, probability, seed);
}
@Override
protected NodeInfo<? extends PhysicalPlan> info() {
return NodeInfo.create(this, SampleExec::new, child(), probability, seed);
}
/**
* Returns the name of the writeable object
*/
@Override
public String getWriteableName() {
return ENTRY.name;
}
public Expression probability() {
return probability;
}
public Expression seed() {
return seed;
}
@Override
public int hashCode() {
return Objects.hash(child(), probability, seed);
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
var other = (SampleExec) obj;
return Objects.equals(child(), other.child()) && Objects.equals(probability, other.probability) && Objects.equals(seed, other.seed);
}
@Override
public void postPhysicalOptimizationVerification(Failures failures) {
// It's currently impossible in ES|QL to handle all data in deterministic order, therefore
// a fixed random seed in the sample operator doesn't work as intended and is disallowed.
// TODO: fix this.
if (seed != null) {
// TODO: reconsider this error message; it exposes the Lucene pushdown implementation detail to the user.
failures.add(fail(seed, "Seed not supported when sampling can't be pushed down to Lucene"));
}
}
}

View File

@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.planner;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
@ -37,6 +38,7 @@ import org.elasticsearch.compute.operator.Operator.OperatorFactory;
import org.elasticsearch.compute.operator.OutputOperator.OutputOperatorFactory;
import org.elasticsearch.compute.operator.RowInTableLookupOperator;
import org.elasticsearch.compute.operator.RrfScoreEvalOperator;
import org.elasticsearch.compute.operator.SampleOperator;
import org.elasticsearch.compute.operator.ScoreOperator;
import org.elasticsearch.compute.operator.ShowOperator;
import org.elasticsearch.compute.operator.SinkOperator;
@ -66,6 +68,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Foldables;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.NameId;
@ -107,6 +110,7 @@ import org.elasticsearch.xpack.esql.plan.physical.OutputExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
import org.elasticsearch.xpack.esql.plan.physical.RrfScoreEvalExec;
import org.elasticsearch.xpack.esql.plan.physical.SampleExec;
import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
@ -253,6 +257,8 @@ public class LocalExecutionPlanner {
return planRerank(rerank, context);
} else if (node instanceof ChangePointExec changePoint) {
return planChangePoint(changePoint, context);
} else if (node instanceof SampleExec sample) {
return planSample(sample, context);
}
// source nodes
else if (node instanceof EsQueryExec esQuery) {
@ -800,6 +806,13 @@ public class LocalExecutionPlanner {
);
}
private PhysicalOperation planSample(SampleExec rsx, LocalExecutionPlannerContext context) {
PhysicalOperation source = plan(rsx.child(), context);
var probability = (double) Foldables.valueOf(context.foldCtx(), rsx.probability());
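// No seed given in the query: draw a random one, so each execution samples a different subset.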
var seed = rsx.seed() != null ? (int) Foldables.valueOf(context.foldCtx(), rsx.seed()) : Randomness.get().nextInt();
return source.with(new SampleOperator.Factory(probability, seed), source.layout);
}
/**
* Immutable physical operation.
*/

View File

@ -17,6 +17,7 @@ import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
import org.elasticsearch.xpack.esql.plan.logical.LeafPlan;
import org.elasticsearch.xpack.esql.plan.logical.Limit;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Sample;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
@ -28,6 +29,7 @@ import org.elasticsearch.xpack.esql.plan.physical.LimitExec;
import org.elasticsearch.xpack.esql.plan.physical.LocalSourceExec;
import org.elasticsearch.xpack.esql.plan.physical.LookupJoinExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.SampleExec;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import java.util.List;
@ -83,6 +85,10 @@ public class LocalMapper {
return new TopNExec(topN.source(), mappedChild, topN.order(), topN.limit(), null);
}
if (unary instanceof Sample sample) {
return new SampleExec(sample.source(), mappedChild, sample.probability(), sample.seed());
}
//
// Pipeline operators
//

View File

@ -21,6 +21,7 @@ import org.elasticsearch.xpack.esql.plan.logical.LeafPlan;
import org.elasticsearch.xpack.esql.plan.logical.Limit;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.Sample;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
@ -37,6 +38,7 @@ import org.elasticsearch.xpack.esql.plan.physical.LocalSourceExec;
import org.elasticsearch.xpack.esql.plan.physical.LookupJoinExec;
import org.elasticsearch.xpack.esql.plan.physical.MergeExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.SampleExec;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
@ -186,6 +188,12 @@ public class Mapper {
);
}
// TODO: share code with local LocalMapper?
if (unary instanceof Sample sample) {
mappedChild = addExchangeForFragment(sample, mappedChild);
return new SampleExec(sample.source(), mappedChild, sample.probability(), sample.seed());
}
//
// Pipeline operators
//

View File

@ -12,6 +12,7 @@ import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.ChangePoint;
import org.elasticsearch.xpack.esql.plan.logical.Dissect;
@ -53,7 +54,7 @@ import java.util.List;
/**
* Class for sharing code across Mappers.
*/
class MapperUtils {
public class MapperUtils {
private MapperUtils() {}
static PhysicalPlan mapLeaf(LeafPlan p) {
@ -177,4 +178,13 @@ class MapperUtils {
static PhysicalPlan unsupported(LogicalPlan p) {
throw new EsqlIllegalArgumentException("unsupported logical plan node [" + p.nodeName() + "]");
}
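/** Returns true if any of the given attributes is the score metadata attribute. */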
public static boolean hasScoreAttribute(List<? extends Attribute> attributes) {
for (Attribute attr : attributes) {
if (MetadataAttribute.isScoreAttribute(attr)) {
return true;
}
}
return false;
}
}

View File

@ -31,6 +31,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.Rename;
import org.elasticsearch.xpack.esql.plan.logical.Row;
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.Sample;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
@ -68,7 +69,8 @@ public enum FeatureMetric {
INSIST(Insist.class::isInstance),
FORK(Fork.class::isInstance),
RRF(RrfScoreEval.class::isInstance),
COMPLETION(Completion.class::isInstance);
COMPLETION(Completion.class::isInstance),
SAMPLE(Sample.class::isInstance);
/**
* List here plans we want to exclude from telemetry

View File

@ -26,6 +26,8 @@ import org.elasticsearch.xpack.esql.session.Configuration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.function.Supplier;
import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.GEO_MATCH_TYPE;
import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.MATCH_TYPE;
@ -209,4 +211,13 @@ public final class AnalyzerTestUtils {
public static IndexResolution tsdbIndexResolution() {
return loadMapping("tsdb-mapping.json", "test");
}
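/**
 * Repeatedly invokes the supplier until it returns a value that is not
 * matched by the exclude predicate. Similar to randomValueOtherThan, but
 * takes a predicate instead of a single excluded value.
 */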
public static <E> E randomValueOtherThanTest(Predicate<E> exclude, Supplier<E> supplier) {
while (true) {
E value = supplier.get();
if (exclude.test(value) == false) {
return value;
}
}
}
}

View File

@ -107,6 +107,7 @@ import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzer;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzerDefaultMapping;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultEnrichResolution;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.randomValueOtherThanTest;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.tsdbIndexResolution;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
import static org.hamcrest.Matchers.contains;
@ -3403,6 +3404,19 @@ public class AnalyzerTests extends ESTestCase {
assertThat(e.getMessage(), containsString("Unknown column [_id]"));
}
public void testRandomSampleProbability() {
var e = expectThrows(VerificationException.class, () -> analyze("FROM test | SAMPLE 1."));
assertThat(e.getMessage(), containsString("RandomSampling probability must be strictly between 0.0 and 1.0, was [1.0]"));
e = expectThrows(VerificationException.class, () -> analyze("FROM test | SAMPLE .0"));
assertThat(e.getMessage(), containsString("RandomSampling probability must be strictly between 0.0 and 1.0, was [0.0]"));
double p = randomValueOtherThanTest(d -> 0 < d && d < 1, () -> randomDoubleBetween(0, Double.MAX_VALUE, false));
e = expectThrows(VerificationException.class, () -> analyze("FROM test | SAMPLE " + p));
assertThat(e.getMessage(), containsString("RandomSampling probability must be strictly between 0.0 and 1.0, was [" + p + "]"));
}
// TODO There's too much boilerplate involved here! We need a better way of creating FieldCapabilitiesResponses from a mapping or index.
private static FieldCapabilitiesIndexResponse fieldCapabilitiesIndexResponse(
String indexName,
Map<String, IndexFieldCapabilities> fields

View File

@ -109,6 +109,7 @@ import org.elasticsearch.xpack.esql.parser.EsqlParser;
import org.elasticsearch.xpack.esql.parser.ParsingException;
import org.elasticsearch.xpack.esql.plan.GeneratingPlan;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.ChangePoint;
import org.elasticsearch.xpack.esql.plan.logical.Dissect;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
@ -121,6 +122,7 @@ import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.Row;
import org.elasticsearch.xpack.esql.plan.logical.Sample;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
@ -7804,4 +7806,153 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
var mvExpand2 = as(mvExpand.child(), MvExpand.class);
as(mvExpand2.child(), Row.class);
}
/**
* Eval[[1[INTEGER] AS irrelevant1, 2[INTEGER] AS irrelevant2]]
* \_Limit[1000[INTEGER],false]
* \_Sample[0.015[DOUBLE],15[INTEGER]]
* \_EsRelation[test][_meta_field{f}#12, emp_no{f}#6, first_name{f}#7, ge..]
*/
public void testSampleMerged() {
assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
var query = """
FROM TEST
| SAMPLE .3 5
| EVAL irrelevant1 = 1
| SAMPLE .5 10
| EVAL irrelevant2 = 2
| SAMPLE .1
""";
var optimized = optimizedPlan(query);
var eval = as(optimized, Eval.class);
var limit = as(eval.child(), Limit.class);
var sample = as(limit.child(), Sample.class);
var source = as(sample.child(), EsRelation.class);
assertThat(sample.probability().fold(FoldContext.small()), equalTo(0.015));
assertThat(sample.seed().fold(FoldContext.small()), equalTo(5 ^ 10));
}
public void testSamplePushDown() {
assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
for (var command : List.of(
"ENRICH languages_idx on first_name",
"EVAL x = 1",
// "INSIST emp_no", // TODO
"KEEP emp_no",
"DROP emp_no",
"RENAME emp_no AS x",
"GROK first_name \"%{WORD:bar}\"",
"DISSECT first_name \"%{z}\""
)) {
var query = "FROM TEST | " + command + " | SAMPLE .5";
var optimized = optimizedPlan(query);
var unary = as(optimized, UnaryPlan.class);
var limit = as(unary.child(), Limit.class);
var sample = as(limit.child(), Sample.class);
var source = as(sample.child(), EsRelation.class);
assertThat(sample.probability().fold(FoldContext.small()), equalTo(0.5));
assertNull(sample.seed());
}
}
public void testSamplePushDown_where() {
assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
var query = "FROM TEST | WHERE emp_no > 0 | SAMPLE 0.5 | LIMIT 100";
var optimized = optimizedPlan(query);
var limit = as(optimized, Limit.class);
var filter = as(limit.child(), Filter.class);
var sample = as(filter.child(), Sample.class);
var source = as(sample.child(), EsRelation.class);
assertThat(sample.probability().fold(FoldContext.small()), equalTo(0.5));
assertNull(sample.seed());
}
public void testSamplePushDown_sort() {
assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
var query = "FROM TEST | SORT emp_no | SAMPLE 0.5 | LIMIT 100";
var optimized = optimizedPlan(query);
var topN = as(optimized, TopN.class);
var sample = as(topN.child(), Sample.class);
var source = as(sample.child(), EsRelation.class);
assertThat(sample.probability().fold(FoldContext.small()), equalTo(0.5));
assertNull(sample.seed());
}
public void testSampleNoPushDown() {
assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
for (var command : List.of("LIMIT 100", "MV_EXPAND languages", "STATS COUNT()")) {
var query = "FROM TEST | " + command + " | SAMPLE .5";
var optimized = optimizedPlan(query);
var limit = as(optimized, Limit.class);
var sample = as(limit.child(), Sample.class);
var unary = as(sample.child(), UnaryPlan.class);
var source = as(unary.child(), EsRelation.class);
}
}
/**
* Limit[1000[INTEGER],false]
* \_Sample[0.5[DOUBLE],null]
* \_Join[LEFT,[language_code{r}#4],[language_code{r}#4],[language_code{f}#17]]
* |_Eval[[emp_no{f}#6 AS language_code]]
* | \_EsRelation[test][_meta_field{f}#12, emp_no{f}#6, first_name{f}#7, ge..]
* \_EsRelation[languages_lookup][LOOKUP][language_code{f}#17, language_name{f}#18]
*/
public void testSampleNoPushDownLookupJoin() {
assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
var query = """
FROM TEST
| EVAL language_code = emp_no
| LOOKUP JOIN languages_lookup ON language_code
| SAMPLE .5
""";
var optimized = optimizedPlan(query);
var limit = as(optimized, Limit.class);
var sample = as(limit.child(), Sample.class);
var join = as(sample.child(), Join.class);
var eval = as(join.left(), Eval.class);
var source = as(eval.child(), EsRelation.class);
}
/**
* Limit[1000[INTEGER],false]
* \_Sample[0.5[DOUBLE],null]
* \_Limit[1000[INTEGER],false]
* \_ChangePoint[emp_no{f}#6,hire_date{f}#13,type{r}#4,pvalue{r}#5]
* \_TopN[[Order[hire_date{f}#13,ASC,ANY]],1001[INTEGER]]
* \_EsRelation[test][_meta_field{f}#12, emp_no{f}#6, first_name{f}#7, ge..]
*/
public void testSampleNoPushDownChangePoint() {
assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
var query = """
FROM TEST
| CHANGE_POINT emp_no ON hire_date
| SAMPLE .5 -55
""";
var optimized = optimizedPlan(query);
var limit = as(optimized, Limit.class);
var sample = as(limit.child(), Sample.class);
limit = as(sample.child(), Limit.class);
var changePoint = as(limit.child(), ChangePoint.class);
var topN = as(changePoint.child(), TopN.class);
var source = as(topN.child(), EsRelation.class);
}
}

View File

@ -35,6 +35,7 @@ import org.elasticsearch.index.query.RegexpQueryBuilder;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.index.query.WildcardQueryBuilder;
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQueryBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.GeoDistanceSortBuilder;
import org.elasticsearch.test.ESTestCase;
@ -185,6 +186,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_SHAPE;
import static org.elasticsearch.xpack.esql.core.util.TestUtils.stripThrough;
import static org.elasticsearch.xpack.esql.parser.ExpressionBuilder.MAX_EXPRESSION_DEPTH;
import static org.elasticsearch.xpack.esql.parser.LogicalPlanBuilder.MAX_QUERY_DEPTH;
import static org.elasticsearch.xpack.esql.planner.mapper.MapperUtils.hasScoreAttribute;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
@ -7732,7 +7734,7 @@ public class PhysicalPlanOptimizerTests extends ESTestCase {
EsRelation esRelation = as(filter.child(), EsRelation.class);
assertTrue(esRelation.optimized());
assertTrue(esRelation.resolved());
assertTrue(esRelation.output().stream().anyMatch(a -> a.name().equals(MetadataAttribute.SCORE) && a instanceof MetadataAttribute));
assertTrue(hasScoreAttribute(esRelation.output()));
}
public void testScoreTopN() {
@ -7754,7 +7756,7 @@ public class PhysicalPlanOptimizerTests extends ESTestCase {
Order scoreOrder = order.getFirst();
assertEquals(Order.OrderDirection.DESC, scoreOrder.direction());
Expression child = scoreOrder.child();
assertTrue(child instanceof MetadataAttribute ma && ma.name().equals(MetadataAttribute.SCORE));
assertTrue(MetadataAttribute.isScoreAttribute(child));
Filter filter = as(topN.child(), Filter.class);
Match match = as(filter.condition(), Match.class);
@ -7764,7 +7766,7 @@ public class PhysicalPlanOptimizerTests extends ESTestCase {
EsRelation esRelation = as(filter.child(), EsRelation.class);
assertTrue(esRelation.optimized());
assertTrue(esRelation.resolved());
assertTrue(esRelation.output().stream().anyMatch(a -> a.name().equals(MetadataAttribute.SCORE) && a instanceof MetadataAttribute));
assertTrue(hasScoreAttribute(esRelation.output()));
}
public void testReductionPlanForTopN() {
@ -7822,6 +7824,54 @@ public class PhysicalPlanOptimizerTests extends ESTestCase {
as(limit2.child(), FilterExec.class);
}
/*
* LimitExec[1000[INTEGER]]
* \_ExchangeExec[[_meta_field{f}#8, emp_no{f}#2, first_name{f}#3, gender{f}#4, hire_date{f}#9, job{f}#10, job.raw{f}#11, langua
* ges{f}#5, last_name{f}#6, long_noidx{f}#12, salary{f}#7],false]
* \_ProjectExec[[_meta_field{f}#8, emp_no{f}#2, first_name{f}#3, gender{f}#4, hire_date{f}#9, job{f}#10, job.raw{f}#11, langua
* ges{f}#5, last_name{f}#6, long_noidx{f}#12, salary{f}#7]]
* \_FieldExtractExec[_meta_field{f}#8, emp_no{f}#2, first_name{f}#3, gen..]<[],[]>
* \_EsQueryExec[test], indexMode[standard],
* query[{"bool":{"filter":[{"sampling":{"probability":0.1,"seed":234,"hash":0}}],"boost":1.0}}]
* [_doc{f}#24], limit[1000], sort[] estimatedRowSize[332]
*/
public void testSamplePushDown() {
assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
var plan = physicalPlan("""
FROM test
| SAMPLE +0.1 -234
""");
var optimized = optimizedPlan(plan);
var limit = as(optimized, LimitExec.class);
var exchange = as(limit.child(), ExchangeExec.class);
var project = as(exchange.child(), ProjectExec.class);
var fieldExtract = as(project.child(), FieldExtractExec.class);
var esQuery = as(fieldExtract.child(), EsQueryExec.class);
var boolQuery = as(esQuery.query(), BoolQueryBuilder.class);
var filter = boolQuery.filter();
var randomSampling = as(filter.get(0), RandomSamplingQueryBuilder.class);
assertThat(randomSampling.probability(), equalTo(0.1));
assertThat(randomSampling.seed(), equalTo(-234));
assertThat(randomSampling.hash(), equalTo(0));
}
public void testSample_seedNotSupportedInOperator() {
assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE.isEnabled());
optimizedPlan(physicalPlan("FROM test | SAMPLE 0.1"));
optimizedPlan(physicalPlan("FROM test | SAMPLE 0.1 42"));
optimizedPlan(physicalPlan("FROM test | MV_EXPAND first_name | SAMPLE 0.1"));
VerificationException e = expectThrows(
VerificationException.class,
() -> optimizedPlan(physicalPlan("FROM test | MV_EXPAND first_name | SAMPLE 0.1 42"))
);
assertThat(e.getMessage(), equalTo("Found 1 problem\nline 1:47: Seed not supported when sampling can't be pushed down to Lucene"));
}
@SuppressWarnings("SameParameterValue")
private static void assertFilterCondition(
Filter filter,
@ -8005,7 +8055,7 @@ public class PhysicalPlanOptimizerTests extends ESTestCase {
var logical = logicalOptimizer.optimize(dataSource.analyzer.analyze(parser.createStatement(query)));
// System.out.println("Logical\n" + logical);
var physical = mapper.map(logical);
// System.out.println(physical);
// System.out.println("Physical\n" + physical);
if (assertSerialization) {
assertSerialization(physical);
}

View File

@ -3485,6 +3485,14 @@ public class StatementParserTests extends AbstractStatementParserTests {
expectError("FROM foo* | COMPLETION prompt AS targetField", "line 1:31: mismatched input 'AS' expecting {");
}
public void testSample() {
expectError("FROM test | SAMPLE .1 2 3", "line 1:25: extraneous input '3' expecting <EOF>");
expectError("FROM test | SAMPLE .1 \"2\"", "line 1:23: extraneous input '\"2\"' expecting <EOF>");
expectError("FROM test | SAMPLE 1", "line 1:20: mismatched input '1' expecting {DECIMAL_LITERAL, '+', '-'}");
expectError("FROM test | SAMPLE", "line 1:19: mismatched input '<EOF>' expecting {DECIMAL_LITERAL, '+', '-'}");
expectError("FROM test | SAMPLE +.1 2147483648", "line 1:24: seed must be an integer, provided [2147483648] of type [LONG]");
}
static Alias alias(String name, Expression value) {
return new Alias(EMPTY, name, value);
}

View File

@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.plan.logical;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.type.DataType;
import java.io.IOException;
public class SampleSerializationTests extends AbstractLogicalPlanSerializationTests<Sample> {
/**
* Creates a random test instance to use in the tests. This method will be
* called multiple times during test execution and should return a different
* random instance each time it is called.
*/
@Override
protected Sample createTestInstance() {
return new Sample(randomSource(), randomProbability(), randomSeed(), randomChild(0));
}
public static Literal randomProbability() {
return new Literal(randomSource(), randomDoubleBetween(0, 1, false), DataType.DOUBLE);
}
public static Literal randomSeed() {
return randomBoolean() ? new Literal(randomSource(), randomInt(), DataType.INTEGER) : null;
}
/**
* Returns an instance which is mutated slightly so it should not be equal
* to the given instance.
*
* @param instance
*/
@Override
protected Sample mutateInstance(Sample instance) throws IOException {
var probability = instance.probability();
var seed = instance.seed();
var child = instance.child();
int updateSelector = randomIntBetween(0, 2);
switch (updateSelector) {
case 0 -> probability = randomValueOtherThan(probability, SampleSerializationTests::randomProbability);
case 1 -> seed = randomValueOtherThan(seed, SampleSerializationTests::randomSeed);
case 2 -> child = randomValueOtherThan(child, () -> randomChild(0));
default -> throw new IllegalArgumentException("Invalid selector: " + updateSelector);
}
return new Sample(instance.source(), probability, seed, child);
}
}

View File

@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.plan.physical;
import org.elasticsearch.xpack.esql.plan.logical.SampleSerializationTests;
import java.io.IOException;
import static org.elasticsearch.xpack.esql.plan.logical.SampleSerializationTests.randomProbability;
import static org.elasticsearch.xpack.esql.plan.logical.SampleSerializationTests.randomSeed;
public class SampleExecSerializationTests extends AbstractPhysicalPlanSerializationTests<SampleExec> {
/**
* Creates a random test instance to use in the tests. This method will be
* called multiple times during test execution and should return a different
* random instance each time it is called.
*/
@Override
protected SampleExec createTestInstance() {
return new SampleExec(randomSource(), randomChild(0), randomProbability(), randomSeed());
}
/**
* Returns an instance which is mutated slightly so it should not be equal
* to the given instance.
*
* @param instance
*/
@Override
protected SampleExec mutateInstance(SampleExec instance) throws IOException {
var probability = instance.probability();
var seed = instance.seed();
var child = instance.child();
int updateSelector = randomIntBetween(0, 2);
switch (updateSelector) {
case 0 -> probability = randomValueOtherThan(probability, SampleSerializationTests::randomProbability);
case 1 -> seed = randomValueOtherThan(seed, SampleSerializationTests::randomSeed);
case 2 -> child = randomValueOtherThan(child, () -> randomChild(0));
default -> throw new IllegalArgumentException("Invalid selector: " + updateSelector);
}
return new SampleExec(instance.source(), child, probability, seed);
}
}

View File

@ -39,7 +39,7 @@ setup:
- do: {xpack.usage: {}}
- match: { esql.available: true }
- match: { esql.enabled: true }
- length: { esql.features: 25 }
- length: { esql.features: 26 }
- set: {esql.features.dissect: dissect_counter}
- set: {esql.features.drop: drop_counter}
- set: {esql.features.eval: eval_counter}
@ -65,6 +65,7 @@ setup:
- set: {esql.features.fork: fork_counter}
- set: {esql.features.rrf: rrf_counter}
- set: {esql.features.completion: completion_counter}
- set: {esql.features.sample: sample_counter}
- length: { esql.queries: 3 }
- set: {esql.queries.rest.total: rest_total_counter}
- set: {esql.queries.rest.failed: rest_failed_counter}
@ -108,6 +109,7 @@ setup:
- match: {esql.features.fork: $fork_counter}
- match: {esql.features.rrf: $rrf_counter}
- match: {esql.features.completion: $completion_counter}
- match: {esql.features.sample: $sample_counter}
- gt: {esql.queries.rest.total: $rest_total_counter}
- match: {esql.queries.rest.failed: $rest_failed_counter}
- match: {esql.queries.kibana.total: $kibana_total_counter}