Adding ES|QL RERANK command in snapshot builds (#123074)

This commit is contained in:
Aurélien FOUCRET 2025-04-04 16:39:18 +02:00 committed by GitHub
parent 8f38b13059
commit a4a271415d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
73 changed files with 5556 additions and 2438 deletions

View File

@ -0,0 +1,5 @@
pr: 123074
summary: Adding ES|QL Reranker command in snapshot builds
area: Ranking
type: feature
issues: []

View File

@ -369,6 +369,9 @@ tests:
- class: org.elasticsearch.snapshots.SharedClusterSnapshotRestoreIT
method: testDeletionOfFailingToRecoverIndexShouldStopRestore
issue: https://github.com/elastic/elasticsearch/issues/126204
- class: org.elasticsearch.xpack.esql.inference.RerankOperatorTests
method: testSimpleCircuitBreaking
issue: https://github.com/elastic/elasticsearch/issues/124337
- class: org.elasticsearch.index.engine.ThreadPoolMergeSchedulerTests
method: testSchedulerCloseWaitsForRunningMerge
issue: https://github.com/elastic/elasticsearch/issues/125236

View File

@ -64,6 +64,10 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
public static final ParseField TOP_N = new ParseField("top_n");
public static final ParseField TIMEOUT = new ParseField("timeout");
public static Builder builder(String inferenceEntityId, TaskType taskType) {
return new Builder().setInferenceEntityId(inferenceEntityId).setTaskType(taskType);
}
static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
static {
PARSER.declareStringArray(Request.Builder::setInput, INPUT);

View File

@ -0,0 +1,153 @@
// Generated from /Users/afoucret/git/elasticsearch/x-pack/plugin/esql/src/main/antlr/EsqlBaseLexer.g4 by ANTLR 4.13.2
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import org.antlr.v4.runtime.Lexer;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.TokenStream;
import org.antlr.v4.runtime.*;
import org.antlr.v4.runtime.atn.*;
import org.antlr.v4.runtime.dfa.DFA;
import org.antlr.v4.runtime.misc.*;
@SuppressWarnings({"all", "warnings", "unchecked", "unused", "cast", "CheckReturnValue", "this-escape"})
public class EsqlBaseLexer extends LexerConfig {
static { RuntimeMetaData.checkVersion("4.13.2", RuntimeMetaData.VERSION); }
protected static final DFA[] _decisionToDFA;
protected static final PredictionContextCache _sharedContextCache =
new PredictionContextCache();
public static final int
LINE_COMMENT=1, MULTILINE_COMMENT=2, WS=3;
public static String[] channelNames = {
"DEFAULT_TOKEN_CHANNEL", "HIDDEN"
};
public static String[] modeNames = {
"DEFAULT_MODE"
};
private static String[] makeRuleNames() {
return new String[] {
"LINE_COMMENT", "MULTILINE_COMMENT", "WS"
};
}
public static final String[] ruleNames = makeRuleNames();
private static String[] makeLiteralNames() {
return new String[] {
};
}
private static final String[] _LITERAL_NAMES = makeLiteralNames();
private static String[] makeSymbolicNames() {
return new String[] {
null, "LINE_COMMENT", "MULTILINE_COMMENT", "WS"
};
}
private static final String[] _SYMBOLIC_NAMES = makeSymbolicNames();
public static final Vocabulary VOCABULARY = new VocabularyImpl(_LITERAL_NAMES, _SYMBOLIC_NAMES);
/**
* @deprecated Use {@link #VOCABULARY} instead.
*/
@Deprecated
public static final String[] tokenNames;
static {
tokenNames = new String[_SYMBOLIC_NAMES.length];
for (int i = 0; i < tokenNames.length; i++) {
tokenNames[i] = VOCABULARY.getLiteralName(i);
if (tokenNames[i] == null) {
tokenNames[i] = VOCABULARY.getSymbolicName(i);
}
if (tokenNames[i] == null) {
tokenNames[i] = "<INVALID>";
}
}
}
@Override
@Deprecated
public String[] getTokenNames() {
return tokenNames;
}
@Override
public Vocabulary getVocabulary() {
return VOCABULARY;
}
public EsqlBaseLexer(CharStream input) {
super(input);
_interp = new LexerATNSimulator(this,_ATN,_decisionToDFA,_sharedContextCache);
}
@Override
public String getGrammarFileName() { return "EsqlBaseLexer.g4"; }
@Override
public String[] getRuleNames() { return ruleNames; }
@Override
public String getSerializedATN() { return _serializedATN; }
@Override
public String[] getChannelNames() { return channelNames; }
@Override
public String[] getModeNames() { return modeNames; }
@Override
public ATN getATN() { return _ATN; }
public static final String _serializedATN =
"\u0004\u0000\u0003.\u0006\uffff\uffff\u0002\u0000\u0007\u0000\u0002\u0001"+
"\u0007\u0001\u0002\u0002\u0007\u0002\u0001\u0000\u0001\u0000\u0001\u0000"+
"\u0001\u0000\u0005\u0000\f\b\u0000\n\u0000\f\u0000\u000f\t\u0000\u0001"+
"\u0000\u0003\u0000\u0012\b\u0000\u0001\u0000\u0003\u0000\u0015\b\u0000"+
"\u0001\u0000\u0001\u0000\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001"+
"\u0001\u0001\u0005\u0001\u001e\b\u0001\n\u0001\f\u0001!\t\u0001\u0001"+
"\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0002\u0004"+
"\u0002)\b\u0002\u000b\u0002\f\u0002*\u0001\u0002\u0001\u0002\u0001\u001f"+
"\u0000\u0003\u0001\u0001\u0003\u0002\u0005\u0003\u0001\u0000\u0002\u0002"+
"\u0000\n\n\r\r\u0003\u0000\t\n\r\r 3\u0000\u0001\u0001\u0000\u0000\u0000"+
"\u0000\u0003\u0001\u0000\u0000\u0000\u0000\u0005\u0001\u0000\u0000\u0000"+
"\u0001\u0007\u0001\u0000\u0000\u0000\u0003\u0018\u0001\u0000\u0000\u0000"+
"\u0005(\u0001\u0000\u0000\u0000\u0007\b\u0005/\u0000\u0000\b\t\u0005/"+
"\u0000\u0000\t\r\u0001\u0000\u0000\u0000\n\f\b\u0000\u0000\u0000\u000b"+
"\n\u0001\u0000\u0000\u0000\f\u000f\u0001\u0000\u0000\u0000\r\u000b\u0001"+
"\u0000\u0000\u0000\r\u000e\u0001\u0000\u0000\u0000\u000e\u0011\u0001\u0000"+
"\u0000\u0000\u000f\r\u0001\u0000\u0000\u0000\u0010\u0012\u0005\r\u0000"+
"\u0000\u0011\u0010\u0001\u0000\u0000\u0000\u0011\u0012\u0001\u0000\u0000"+
"\u0000\u0012\u0014\u0001\u0000\u0000\u0000\u0013\u0015\u0005\n\u0000\u0000"+
"\u0014\u0013\u0001\u0000\u0000\u0000\u0014\u0015\u0001\u0000\u0000\u0000"+
"\u0015\u0016\u0001\u0000\u0000\u0000\u0016\u0017\u0006\u0000\u0000\u0000"+
"\u0017\u0002\u0001\u0000\u0000\u0000\u0018\u0019\u0005/\u0000\u0000\u0019"+
"\u001a\u0005*\u0000\u0000\u001a\u001f\u0001\u0000\u0000\u0000\u001b\u001e"+
"\u0003\u0003\u0001\u0000\u001c\u001e\t\u0000\u0000\u0000\u001d\u001b\u0001"+
"\u0000\u0000\u0000\u001d\u001c\u0001\u0000\u0000\u0000\u001e!\u0001\u0000"+
"\u0000\u0000\u001f \u0001\u0000\u0000\u0000\u001f\u001d\u0001\u0000\u0000"+
"\u0000 \"\u0001\u0000\u0000\u0000!\u001f\u0001\u0000\u0000\u0000\"#\u0005"+
"*\u0000\u0000#$\u0005/\u0000\u0000$%\u0001\u0000\u0000\u0000%&\u0006\u0001"+
"\u0000\u0000&\u0004\u0001\u0000\u0000\u0000\')\u0007\u0001\u0000\u0000"+
"(\'\u0001\u0000\u0000\u0000)*\u0001\u0000\u0000\u0000*(\u0001\u0000\u0000"+
"\u0000*+\u0001\u0000\u0000\u0000+,\u0001\u0000\u0000\u0000,-\u0006\u0002"+
"\u0000\u0000-\u0006\u0001\u0000\u0000\u0000\u0007\u0000\r\u0011\u0014"+
"\u001d\u001f*\u0001\u0000\u0001\u0000";
public static final ATN _ATN =
new ATNDeserializer().deserialize(_serializedATN.toCharArray());
static {
_decisionToDFA = new DFA[_ATN.getNumberOfDecisions()];
for (int i = 0; i < _ATN.getNumberOfDecisions(); i++) {
_decisionToDFA[i] = new DFA(_ATN.getDecisionState(i), i);
}
}
}

View File

@ -0,0 +1,3 @@
LINE_COMMENT=1
MULTILINE_COMMENT=2
WS=3

View File

@ -52,6 +52,7 @@ import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.INLINESTA
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.JOIN_LOOKUP_V12;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.JOIN_PLANNING_V1;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.METADATA_FIELDS_REMOTE_TEST;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.RERANK;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.UNMAPPED_FIELDS;
import static org.elasticsearch.xpack.esql.qa.rest.EsqlSpecTestCase.Mode.SYNC;
import static org.mockito.ArgumentMatchers.any;
@ -130,6 +131,8 @@ public class MultiClusterSpecIT extends EsqlSpecTestCase {
assumeFalse("LOOKUP JOIN not yet supported in CCS", testCase.requiredCapabilities.contains(JOIN_LOOKUP_V12.capabilityName()));
// Unmapped fields require a coorect capability response from every cluster, which isn't currently implemented.
assumeFalse("UNMAPPED FIELDS not yet supported in CCS", testCase.requiredCapabilities.contains(UNMAPPED_FIELDS.capabilityName()));
// Need to do additional developmnet to get CSS support for the rerank coammnd
assumeFalse("RERANK not yet supported in CCS", testCase.requiredCapabilities.contains(RERANK.capabilityName()));
}
@Override

View File

@ -66,8 +66,11 @@ import static org.elasticsearch.xpack.esql.CsvTestUtils.isEnabled;
import static org.elasticsearch.xpack.esql.CsvTestUtils.loadCsvSpecValues;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.availableDatasetsForEs;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasRerankInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createRerankInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteRerankInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.loadDataSetIntoEs;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.classpathResources;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.METRICS_COMMAND;
@ -134,6 +137,10 @@ public abstract class EsqlSpecTestCase extends ESRestTestCase {
createInferenceEndpoint(client());
}
if (supportsInferenceTestService() && clusterHasRerankInferenceEndpoint(client()) == false) {
createRerankInferenceEndpoint(client());
}
boolean supportsLookup = supportsIndexModeLookup();
boolean supportsSourceMapping = supportsSourceFieldMapping();
if (indexExists(availableDatasetsForEs(client(), supportsLookup, supportsSourceMapping).iterator().next().indexName()) == false) {
@ -153,6 +160,7 @@ public abstract class EsqlSpecTestCase extends ESRestTestCase {
}
deleteInferenceEndpoint(client());
deleteRerankInferenceEndpoint(client());
}
public boolean logResults() {

View File

@ -0,0 +1,192 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.qa.rest;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createRerankInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteRerankInferenceEndpoint;
import static org.hamcrest.core.StringContains.containsString;
public class RestRerankTestCase extends ESRestTestCase {
@Before
public void skipWhenRerankDisabled() throws IOException {
assumeTrue(
"Requires RERANK capability",
EsqlSpecTestCase.hasCapabilities(adminClient(), List.of(EsqlCapabilities.Cap.RERANK.capabilityName()))
);
}
@Before
@After
public void assertRequestBreakerEmpty() throws Exception {
EsqlSpecTestCase.assertRequestBreakerEmpty();
}
@Before
public void setUpInferenceEndpoint() throws IOException {
createRerankInferenceEndpoint(adminClient());
}
@Before
public void setUpTestIndex() throws IOException {
Request request = new Request("PUT", "/rerank-test-index");
request.setJsonEntity("""
{
"mappings": {
"properties": {
"title": { "type": "text" },
"author": { "type": "text" }
}
}
}""");
assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode());
request = new Request("POST", "/rerank-test-index/_bulk");
request.addParameter("refresh", "true");
request.setJsonEntity("""
{ "index": {"_id": 1} }
{ "title": "The Future of Exploration", "author": "John Doe" }
{ "index": {"_id": 2} }
{ "title": "Deep Sea Exploration", "author": "Jane Smith" }
{ "index": {"_id": 3} }
{ "title": "History of Space Exploration", "author": "Alice Johnson" }
""");
assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode());
}
@After
public void wipeData() throws IOException {
try {
adminClient().performRequest(new Request("DELETE", "/rerank-test-index"));
} catch (ResponseException e) {
// 404 here just means we had no indexes
if (e.getResponse().getStatusLine().getStatusCode() != 404) {
throw e;
}
}
deleteRerankInferenceEndpoint(adminClient());
}
public void testRerankWithSingleField() throws IOException {
String query = """
FROM rerank-test-index
| WHERE match(title, "exploration")
| RERANK "exploration" ON title WITH test_reranker
| EVAL _score = ROUND(_score, 5)
""";
Map<String, Object> result = runEsqlQuery(query);
var expectedValues = List.of(
List.of("Jane Smith", "Deep Sea Exploration", 0.02941d),
List.of("John Doe", "The Future of Exploration", 0.02632d),
List.of("Alice Johnson", "History of Space Exploration", 0.02381d)
);
assertResultMap(result, defaultOutputColumns(), expectedValues);
}
public void testRerankWithMultipleFields() throws IOException {
String query = """
FROM rerank-test-index
| WHERE match(title, "exploration")
| RERANK "exploration" ON title, author WITH test_reranker
| EVAL _score = ROUND(_score, 5)
""";
Map<String, Object> result = runEsqlQuery(query);
;
var expectedValues = List.of(
List.of("Jane Smith", "Deep Sea Exploration", 0.01818d),
List.of("John Doe", "The Future of Exploration", 0.01754d),
List.of("Alice Johnson", "History of Space Exploration", 0.01515d)
);
assertResultMap(result, defaultOutputColumns(), expectedValues);
}
public void testRerankWithPositionalParams() throws IOException {
String query = """
FROM rerank-test-index
| WHERE match(title, "exploration")
| RERANK ? ON title WITH ?
| EVAL _score = ROUND(_score, 5)
""";
Map<String, Object> result = runEsqlQuery(query, "[\"exploration\", \"test_reranker\"]");
var expectedValues = List.of(
List.of("Jane Smith", "Deep Sea Exploration", 0.02941d),
List.of("John Doe", "The Future of Exploration", 0.02632d),
List.of("Alice Johnson", "History of Space Exploration", 0.02381d)
);
assertResultMap(result, defaultOutputColumns(), expectedValues);
}
public void testRerankWithNamedParams() throws IOException {
String query = """
FROM rerank-test-index
| WHERE match(title, ?queryText)
| RERANK ?queryText ON title WITH ?inferenceId
| EVAL _score = ROUND(_score, 5)
""";
Map<String, Object> result = runEsqlQuery(query, "[{\"queryText\": \"exploration\"}, {\"inferenceId\": \"test_reranker\"}]");
var expectedValues = List.of(
List.of("Jane Smith", "Deep Sea Exploration", 0.02941d),
List.of("John Doe", "The Future of Exploration", 0.02632d),
List.of("Alice Johnson", "History of Space Exploration", 0.02381d)
);
assertResultMap(result, defaultOutputColumns(), expectedValues);
}
public void testRerankWithMissingInferenceId() {
String query = """
FROM rerank-test-index
| WHERE match(title, "exploration")
| RERANK "exploration" ON title WITH test_missing
| EVAL _score = ROUND(_score, 5)
""";
ResponseException re = expectThrows(ResponseException.class, () -> runEsqlQuery(query));
assertThat(re.getMessage(), containsString("Inference endpoint not found"));
}
private static List<Map<String, String>> defaultOutputColumns() {
return List.of(
Map.of("name", "author", "type", "text"),
Map.of("name", "title", "type", "text"),
Map.of("name", "_score", "type", "double")
);
}
private Map<String, Object> runEsqlQuery(String query) throws IOException {
RestEsqlTestCase.RequestObjectBuilder builder = RestEsqlTestCase.requestObjectBuilder().query(query);
return RestEsqlTestCase.runEsqlSync(builder);
}
private Map<String, Object> runEsqlQuery(String query, String params) throws IOException {
RestEsqlTestCase.RequestObjectBuilder builder = RestEsqlTestCase.requestObjectBuilder().query(query).params(params);
return RestEsqlTestCase.runEsqlSync(builder);
}
}

View File

@ -397,6 +397,47 @@ public class CsvTestsDataLoader {
return true;
}
public static void createRerankInferenceEndpoint(RestClient client) throws IOException {
Request request = new Request("PUT", "_inference/rerank/test_reranker");
request.setJsonEntity("""
{
"service": "test_reranking_service",
"service_settings": {
"model_id": "my_model",
"api_key": "abc64"
},
"task_settings": {
"use_text_length": true
}
}
""");
client.performRequest(request);
}
public static void deleteRerankInferenceEndpoint(RestClient client) throws IOException {
try {
client.performRequest(new Request("DELETE", "_inference/rerank/test_reranker"));
} catch (ResponseException e) {
// 404 here means the endpoint was not created
if (e.getResponse().getStatusLine().getStatusCode() != 404) {
throw e;
}
}
}
public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throws IOException {
Request request = new Request("GET", "_inference/rerank/test_reranker");
try {
client.performRequest(request);
} catch (ResponseException e) {
if (e.getResponse().getStatusLine().getStatusCode() == 404) {
return false;
}
throw e;
}
return true;
}
private static void loadEnrichPolicy(RestClient client, String policyName, String policyFileName, Logger logger) throws IOException {
URL policyMapping = getResource("/" + policyFileName);
String entity = readTextFile(policyMapping);

View File

@ -11,6 +11,7 @@ import org.apache.lucene.document.InetAddressPoint;
import org.apache.lucene.sandbox.document.HalfFloatPoint;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.RemoteException;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
@ -68,6 +69,8 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Les
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.parser.QueryParam;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
@ -149,6 +152,8 @@ import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
public final class EsqlTestUtils {
@ -376,9 +381,21 @@ public final class EsqlTestUtils {
null,
mock(ClusterService.class),
mock(IndexNameExpressionResolver.class),
null
null,
mockInferenceRunner()
);
@SuppressWarnings("unchecked")
private static InferenceRunner mockInferenceRunner() {
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
doAnswer(i -> {
i.getArgument(1, ActionListener.class).onResponse(emptyInferenceResolution());
return null;
}).when(inferenceRunner).resolveInferenceIds(any(), any());
return inferenceRunner;
}
private EsqlTestUtils() {}
public static Configuration configuration(QueryPragmas pragmas, String query) {
@ -454,6 +471,10 @@ public final class EsqlTestUtils {
return new EnrichResolution();
}
public static InferenceResolution emptyInferenceResolution() {
return InferenceResolution.EMPTY;
}
public static SearchStats statsForExistingField(String... names) {
return fieldMatchingExistOrMissing(true, names);
}

View File

@ -0,0 +1,142 @@
// Note:
// The "test_reranker" service scores the row from the inputText length and does not really score by relevance.
// This makes the output more predictable which is helpful here.
reranker using a single field
required_capability: rerank
required_capability: match_operator_colon
FROM books METADATA _score
| WHERE title:"war and peace" AND author:"Tolstoy"
| RERANK "war and peace" ON title WITH test_reranker
| KEEP book_no, title, author, _score
| EVAL _score = ROUND(_score, 5)
;
book_no:keyword | title:text | author:text | _score:double
5327 | War and Peace | Leo Tolstoy | 0.03846
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.02222
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.02083
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.01515
;
reranker using multiple fields
required_capability: rerank
required_capability: match_operator_colon
FROM books METADATA _score
| WHERE title:"war and peace" AND author:"Tolstoy"
| RERANK "war and peace" ON title, author WITH test_reranker
| KEEP book_no, title, author, _score
| EVAL _score = ROUND(_score, 5)
;
book_no:keyword | title:text | author:text | _score:double
5327 | War and Peace | Leo Tolstoy | 0.02083
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.01429
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.01136
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.00952
;
reranker after a limit
required_capability: rerank
required_capability: match_operator_colon
FROM books METADATA _score
| WHERE title:"war and peace" AND author:"Tolstoy"
| SORT _score DESC
| LIMIT 3
| RERANK "war and peace" ON title WITH test_reranker
| KEEP book_no, title, author, _score
| EVAL _score = ROUND(_score, 5)
;
book_no:keyword | title:text | author:text | _score:double
5327 | War and Peace | Leo Tolstoy | 0.03846
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.02222
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.02083
;
reranker before a limit
required_capability: rerank
required_capability: match_operator_colon
FROM books METADATA _score
| WHERE title:"war and peace" AND author:"Tolstoy"
| RERANK "war and peace" ON title WITH test_reranker
| KEEP book_no, title, author, _score
| LIMIT 3
| EVAL _score = ROUND(_score, 5)
;
book_no:keyword | title:text | author:text | _score:double
5327 | War and Peace | Leo Tolstoy | 0.03846
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.02222
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.02083
;
reranker add the _score column when missing
required_capability: rerank
required_capability: match_operator_colon
FROM books
| WHERE title:"war and peace" AND author:"Tolstoy"
| RERANK "war and peace" ON title WITH test_reranker
| KEEP book_no, title, author, _score
| EVAL _score = ROUND(_score, 5)
;
book_no:keyword | title:text | author:text | _score:double
5327 | War and Peace | Leo Tolstoy | 0.03846
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.02222
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.02083
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.01515
;
reranker using another sort order
required_capability: rerank
required_capability: match_operator_colon
FROM books
| WHERE title:"war and peace" AND author:"Tolstoy"
| RERANK "war and peace" ON title WITH test_reranker
| KEEP book_no, title, author, _score
| SORT author, title
| LIMIT 3
| EVAL _score = ROUND(_score, 5)
;
book_no:keyword | title:text | author:text | _score:double
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.02222
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.01515
5327 | War and Peace | Leo Tolstoy | 0.03846
;
reranker after RRF
required_capability: fork
required_capability: rrf
required_capability: match_operator_colon
required_capability: rerank
FROM books METADATA _id, _index, _score
| FORK ( WHERE title:"Tolkien" | SORT _score, _id DESC | LIMIT 3 )
( WHERE author:"Tolkien" | SORT _score, _id DESC | LIMIT 3 )
| RRF
| RERANK "Tolkien" ON title WITH test_reranker
| LIMIT 2
| KEEP book_no, title, author, _score
| EVAL _score = ROUND(_score, 5)
;
book_no:keyword | title:keyword | author:keyword | _score:double
5335 | Letters of J R R Tolkien | J.R.R. Tolkien | 0.02632
2130 | The J. R. R. Tolkien Audio Collection | [Christopher Tolkien, John Ronald Reuel Tolkien] | 0.01961
;

View File

@ -13,49 +13,49 @@ SORT=12
STATS=13
WHERE=14
DEV_INLINESTATS=15
FROM=16
DEV_TIME_SERIES=17
DEV_FORK=18
JOIN_LOOKUP=19
DEV_JOIN_FULL=20
DEV_JOIN_LEFT=21
DEV_JOIN_RIGHT=22
DEV_LOOKUP=23
MV_EXPAND=24
DROP=25
KEEP=26
DEV_INSIST=27
DEV_RRF=28
RENAME=29
SHOW=30
UNKNOWN_CMD=31
CHANGE_POINT_LINE_COMMENT=32
CHANGE_POINT_MULTILINE_COMMENT=33
CHANGE_POINT_WS=34
ON=35
WITH=36
ENRICH_POLICY_NAME=37
ENRICH_LINE_COMMENT=38
ENRICH_MULTILINE_COMMENT=39
ENRICH_WS=40
ENRICH_FIELD_LINE_COMMENT=41
ENRICH_FIELD_MULTILINE_COMMENT=42
ENRICH_FIELD_WS=43
SETTING=44
SETTING_LINE_COMMENT=45
SETTTING_MULTILINE_COMMENT=46
SETTING_WS=47
EXPLAIN_WS=48
EXPLAIN_LINE_COMMENT=49
EXPLAIN_MULTILINE_COMMENT=50
PIPE=51
QUOTED_STRING=52
INTEGER_LITERAL=53
DECIMAL_LITERAL=54
BY=55
AND=56
ASC=57
ASSIGN=58
DEV_RERANK=16
FROM=17
DEV_TIME_SERIES=18
DEV_FORK=19
JOIN_LOOKUP=20
DEV_JOIN_FULL=21
DEV_JOIN_LEFT=22
DEV_JOIN_RIGHT=23
DEV_LOOKUP=24
MV_EXPAND=25
DROP=26
KEEP=27
DEV_INSIST=28
DEV_RRF=29
RENAME=30
SHOW=31
UNKNOWN_CMD=32
CHANGE_POINT_LINE_COMMENT=33
CHANGE_POINT_MULTILINE_COMMENT=34
CHANGE_POINT_WS=35
ENRICH_POLICY_NAME=36
ENRICH_LINE_COMMENT=37
ENRICH_MULTILINE_COMMENT=38
ENRICH_WS=39
ENRICH_FIELD_LINE_COMMENT=40
ENRICH_FIELD_MULTILINE_COMMENT=41
ENRICH_FIELD_WS=42
SETTING=43
SETTING_LINE_COMMENT=44
SETTTING_MULTILINE_COMMENT=45
SETTING_WS=46
EXPLAIN_WS=47
EXPLAIN_LINE_COMMENT=48
EXPLAIN_MULTILINE_COMMENT=49
PIPE=50
QUOTED_STRING=51
INTEGER_LITERAL=52
DECIMAL_LITERAL=53
AND=54
AS=55
ASC=56
ASSIGN=57
BY=58
CAST_OP=59
COLON=60
COMMA=61
@ -70,70 +70,71 @@ LIKE=69
NOT=70
NULL=71
NULLS=72
OR=73
PARAM=74
RLIKE=75
TRUE=76
EQ=77
CIEQ=78
NEQ=79
LT=80
LTE=81
GT=82
GTE=83
PLUS=84
MINUS=85
ASTERISK=86
SLASH=87
PERCENT=88
LEFT_BRACES=89
RIGHT_BRACES=90
DOUBLE_PARAMS=91
NAMED_OR_POSITIONAL_PARAM=92
NAMED_OR_POSITIONAL_DOUBLE_PARAMS=93
OPENING_BRACKET=94
CLOSING_BRACKET=95
LP=96
RP=97
UNQUOTED_IDENTIFIER=98
QUOTED_IDENTIFIER=99
EXPR_LINE_COMMENT=100
EXPR_MULTILINE_COMMENT=101
EXPR_WS=102
METADATA=103
UNQUOTED_SOURCE=104
FROM_LINE_COMMENT=105
FROM_MULTILINE_COMMENT=106
FROM_WS=107
FORK_WS=108
FORK_LINE_COMMENT=109
FORK_MULTILINE_COMMENT=110
JOIN=111
USING=112
JOIN_LINE_COMMENT=113
JOIN_MULTILINE_COMMENT=114
JOIN_WS=115
LOOKUP_LINE_COMMENT=116
LOOKUP_MULTILINE_COMMENT=117
LOOKUP_WS=118
LOOKUP_FIELD_LINE_COMMENT=119
LOOKUP_FIELD_MULTILINE_COMMENT=120
LOOKUP_FIELD_WS=121
MVEXPAND_LINE_COMMENT=122
MVEXPAND_MULTILINE_COMMENT=123
MVEXPAND_WS=124
ID_PATTERN=125
PROJECT_LINE_COMMENT=126
PROJECT_MULTILINE_COMMENT=127
PROJECT_WS=128
AS=129
RENAME_LINE_COMMENT=130
RENAME_MULTILINE_COMMENT=131
RENAME_WS=132
INFO=133
SHOW_LINE_COMMENT=134
SHOW_MULTILINE_COMMENT=135
SHOW_WS=136
ON=73
OR=74
PARAM=75
RLIKE=76
TRUE=77
WITH=78
EQ=79
CIEQ=80
NEQ=81
LT=82
LTE=83
GT=84
GTE=85
PLUS=86
MINUS=87
ASTERISK=88
SLASH=89
PERCENT=90
LEFT_BRACES=91
RIGHT_BRACES=92
DOUBLE_PARAMS=93
NAMED_OR_POSITIONAL_PARAM=94
NAMED_OR_POSITIONAL_DOUBLE_PARAMS=95
OPENING_BRACKET=96
CLOSING_BRACKET=97
LP=98
RP=99
UNQUOTED_IDENTIFIER=100
QUOTED_IDENTIFIER=101
EXPR_LINE_COMMENT=102
EXPR_MULTILINE_COMMENT=103
EXPR_WS=104
METADATA=105
UNQUOTED_SOURCE=106
FROM_LINE_COMMENT=107
FROM_MULTILINE_COMMENT=108
FROM_WS=109
FORK_WS=110
FORK_LINE_COMMENT=111
FORK_MULTILINE_COMMENT=112
JOIN=113
USING=114
JOIN_LINE_COMMENT=115
JOIN_MULTILINE_COMMENT=116
JOIN_WS=117
LOOKUP_LINE_COMMENT=118
LOOKUP_MULTILINE_COMMENT=119
LOOKUP_WS=120
LOOKUP_FIELD_LINE_COMMENT=121
LOOKUP_FIELD_MULTILINE_COMMENT=122
LOOKUP_FIELD_WS=123
MVEXPAND_LINE_COMMENT=124
MVEXPAND_MULTILINE_COMMENT=125
MVEXPAND_WS=126
ID_PATTERN=127
PROJECT_LINE_COMMENT=128
PROJECT_MULTILINE_COMMENT=129
PROJECT_WS=130
RENAME_LINE_COMMENT=131
RENAME_MULTILINE_COMMENT=132
RENAME_WS=133
INFO=134
SHOW_LINE_COMMENT=135
SHOW_MULTILINE_COMMENT=136
SHOW_WS=137
'enrich'=5
'explain'=6
'dissect'=7
@ -144,20 +145,19 @@ SHOW_WS=136
'sort'=12
'stats'=13
'where'=14
'from'=16
'lookup'=19
'mv_expand'=24
'drop'=25
'keep'=26
'rename'=29
'show'=30
'on'=35
'with'=36
'|'=51
'by'=55
'and'=56
'asc'=57
'='=58
'from'=17
'lookup'=20
'mv_expand'=25
'drop'=26
'keep'=27
'rename'=30
'show'=31
'|'=50
'and'=54
'as'=55
'asc'=56
'='=57
'by'=58
'::'=59
':'=60
','=61
@ -172,29 +172,30 @@ SHOW_WS=136
'not'=70
'null'=71
'nulls'=72
'or'=73
'?'=74
'rlike'=75
'true'=76
'=='=77
'=~'=78
'!='=79
'<'=80
'<='=81
'>'=82
'>='=83
'+'=84
'-'=85
'*'=86
'/'=87
'%'=88
'{'=89
'}'=90
'??'=91
']'=95
')'=97
'metadata'=103
'join'=111
'USING'=112
'as'=129
'info'=133
'on'=73
'or'=74
'?'=75
'rlike'=76
'true'=77
'with'=78
'=='=79
'=~'=80
'!='=81
'<'=82
'<='=83
'>'=84
'>='=85
'+'=86
'-'=87
'*'=88
'/'=89
'%'=90
'{'=91
'}'=92
'??'=93
']'=97
')'=99
'metadata'=105
'join'=113
'USING'=114
'info'=134

View File

@ -61,6 +61,7 @@ processingCommand
| {this.isDevVersion()}? changePointCommand
| {this.isDevVersion()}? insistCommand
| {this.isDevVersion()}? forkCommand
| {this.isDevVersion()}? rerankCommand
| {this.isDevVersion()}? rrfCommand
;
@ -288,3 +289,7 @@ forkSubQueryProcessingCommand
rrfCommand
: DEV_RRF
;
rerankCommand
: DEV_RERANK queryText=constant ON fields WITH inferenceId=identifierOrParameter
;

View File

@ -13,49 +13,49 @@ SORT=12
STATS=13
WHERE=14
DEV_INLINESTATS=15
FROM=16
DEV_TIME_SERIES=17
DEV_FORK=18
JOIN_LOOKUP=19
DEV_JOIN_FULL=20
DEV_JOIN_LEFT=21
DEV_JOIN_RIGHT=22
DEV_LOOKUP=23
MV_EXPAND=24
DROP=25
KEEP=26
DEV_INSIST=27
DEV_RRF=28
RENAME=29
SHOW=30
UNKNOWN_CMD=31
CHANGE_POINT_LINE_COMMENT=32
CHANGE_POINT_MULTILINE_COMMENT=33
CHANGE_POINT_WS=34
ON=35
WITH=36
ENRICH_POLICY_NAME=37
ENRICH_LINE_COMMENT=38
ENRICH_MULTILINE_COMMENT=39
ENRICH_WS=40
ENRICH_FIELD_LINE_COMMENT=41
ENRICH_FIELD_MULTILINE_COMMENT=42
ENRICH_FIELD_WS=43
SETTING=44
SETTING_LINE_COMMENT=45
SETTTING_MULTILINE_COMMENT=46
SETTING_WS=47
EXPLAIN_WS=48
EXPLAIN_LINE_COMMENT=49
EXPLAIN_MULTILINE_COMMENT=50
PIPE=51
QUOTED_STRING=52
INTEGER_LITERAL=53
DECIMAL_LITERAL=54
BY=55
AND=56
ASC=57
ASSIGN=58
DEV_RERANK=16
FROM=17
DEV_TIME_SERIES=18
DEV_FORK=19
JOIN_LOOKUP=20
DEV_JOIN_FULL=21
DEV_JOIN_LEFT=22
DEV_JOIN_RIGHT=23
DEV_LOOKUP=24
MV_EXPAND=25
DROP=26
KEEP=27
DEV_INSIST=28
DEV_RRF=29
RENAME=30
SHOW=31
UNKNOWN_CMD=32
CHANGE_POINT_LINE_COMMENT=33
CHANGE_POINT_MULTILINE_COMMENT=34
CHANGE_POINT_WS=35
ENRICH_POLICY_NAME=36
ENRICH_LINE_COMMENT=37
ENRICH_MULTILINE_COMMENT=38
ENRICH_WS=39
ENRICH_FIELD_LINE_COMMENT=40
ENRICH_FIELD_MULTILINE_COMMENT=41
ENRICH_FIELD_WS=42
SETTING=43
SETTING_LINE_COMMENT=44
SETTTING_MULTILINE_COMMENT=45
SETTING_WS=46
EXPLAIN_WS=47
EXPLAIN_LINE_COMMENT=48
EXPLAIN_MULTILINE_COMMENT=49
PIPE=50
QUOTED_STRING=51
INTEGER_LITERAL=52
DECIMAL_LITERAL=53
AND=54
AS=55
ASC=56
ASSIGN=57
BY=58
CAST_OP=59
COLON=60
COMMA=61
@ -70,70 +70,71 @@ LIKE=69
NOT=70
NULL=71
NULLS=72
OR=73
PARAM=74
RLIKE=75
TRUE=76
EQ=77
CIEQ=78
NEQ=79
LT=80
LTE=81
GT=82
GTE=83
PLUS=84
MINUS=85
ASTERISK=86
SLASH=87
PERCENT=88
LEFT_BRACES=89
RIGHT_BRACES=90
DOUBLE_PARAMS=91
NAMED_OR_POSITIONAL_PARAM=92
NAMED_OR_POSITIONAL_DOUBLE_PARAMS=93
OPENING_BRACKET=94
CLOSING_BRACKET=95
LP=96
RP=97
UNQUOTED_IDENTIFIER=98
QUOTED_IDENTIFIER=99
EXPR_LINE_COMMENT=100
EXPR_MULTILINE_COMMENT=101
EXPR_WS=102
METADATA=103
UNQUOTED_SOURCE=104
FROM_LINE_COMMENT=105
FROM_MULTILINE_COMMENT=106
FROM_WS=107
FORK_WS=108
FORK_LINE_COMMENT=109
FORK_MULTILINE_COMMENT=110
JOIN=111
USING=112
JOIN_LINE_COMMENT=113
JOIN_MULTILINE_COMMENT=114
JOIN_WS=115
LOOKUP_LINE_COMMENT=116
LOOKUP_MULTILINE_COMMENT=117
LOOKUP_WS=118
LOOKUP_FIELD_LINE_COMMENT=119
LOOKUP_FIELD_MULTILINE_COMMENT=120
LOOKUP_FIELD_WS=121
MVEXPAND_LINE_COMMENT=122
MVEXPAND_MULTILINE_COMMENT=123
MVEXPAND_WS=124
ID_PATTERN=125
PROJECT_LINE_COMMENT=126
PROJECT_MULTILINE_COMMENT=127
PROJECT_WS=128
AS=129
RENAME_LINE_COMMENT=130
RENAME_MULTILINE_COMMENT=131
RENAME_WS=132
INFO=133
SHOW_LINE_COMMENT=134
SHOW_MULTILINE_COMMENT=135
SHOW_WS=136
ON=73
OR=74
PARAM=75
RLIKE=76
TRUE=77
WITH=78
EQ=79
CIEQ=80
NEQ=81
LT=82
LTE=83
GT=84
GTE=85
PLUS=86
MINUS=87
ASTERISK=88
SLASH=89
PERCENT=90
LEFT_BRACES=91
RIGHT_BRACES=92
DOUBLE_PARAMS=93
NAMED_OR_POSITIONAL_PARAM=94
NAMED_OR_POSITIONAL_DOUBLE_PARAMS=95
OPENING_BRACKET=96
CLOSING_BRACKET=97
LP=98
RP=99
UNQUOTED_IDENTIFIER=100
QUOTED_IDENTIFIER=101
EXPR_LINE_COMMENT=102
EXPR_MULTILINE_COMMENT=103
EXPR_WS=104
METADATA=105
UNQUOTED_SOURCE=106
FROM_LINE_COMMENT=107
FROM_MULTILINE_COMMENT=108
FROM_WS=109
FORK_WS=110
FORK_LINE_COMMENT=111
FORK_MULTILINE_COMMENT=112
JOIN=113
USING=114
JOIN_LINE_COMMENT=115
JOIN_MULTILINE_COMMENT=116
JOIN_WS=117
LOOKUP_LINE_COMMENT=118
LOOKUP_MULTILINE_COMMENT=119
LOOKUP_WS=120
LOOKUP_FIELD_LINE_COMMENT=121
LOOKUP_FIELD_MULTILINE_COMMENT=122
LOOKUP_FIELD_WS=123
MVEXPAND_LINE_COMMENT=124
MVEXPAND_MULTILINE_COMMENT=125
MVEXPAND_WS=126
ID_PATTERN=127
PROJECT_LINE_COMMENT=128
PROJECT_MULTILINE_COMMENT=129
PROJECT_WS=130
RENAME_LINE_COMMENT=131
RENAME_MULTILINE_COMMENT=132
RENAME_WS=133
INFO=134
SHOW_LINE_COMMENT=135
SHOW_MULTILINE_COMMENT=136
SHOW_WS=137
'enrich'=5
'explain'=6
'dissect'=7
@ -144,20 +145,19 @@ SHOW_WS=136
'sort'=12
'stats'=13
'where'=14
'from'=16
'lookup'=19
'mv_expand'=24
'drop'=25
'keep'=26
'rename'=29
'show'=30
'on'=35
'with'=36
'|'=51
'by'=55
'and'=56
'asc'=57
'='=58
'from'=17
'lookup'=20
'mv_expand'=25
'drop'=26
'keep'=27
'rename'=30
'show'=31
'|'=50
'and'=54
'as'=55
'asc'=56
'='=57
'by'=58
'::'=59
':'=60
','=61
@ -172,29 +172,30 @@ SHOW_WS=136
'not'=70
'null'=71
'nulls'=72
'or'=73
'?'=74
'rlike'=75
'true'=76
'=='=77
'=~'=78
'!='=79
'<'=80
'<='=81
'>'=82
'>='=83
'+'=84
'-'=85
'*'=86
'/'=87
'%'=88
'{'=89
'}'=90
'??'=91
']'=95
')'=97
'metadata'=103
'join'=111
'USING'=112
'as'=129
'info'=133
'on'=73
'or'=74
'?'=75
'rlike'=76
'true'=77
'with'=78
'=='=79
'=~'=80
'!='=81
'<'=82
'<='=83
'>'=84
'>='=85
'+'=86
'-'=87
'*'=88
'/'=89
'%'=90
'{'=91
'}'=92
'??'=93
']'=97
')'=99
'metadata'=105
'join'=113
'USING'=114
'info'=134

View File

@ -16,8 +16,8 @@ mode ENRICH_MODE;
ENRICH_PIPE : PIPE -> type(PIPE), popMode;
ENRICH_OPENING_BRACKET : OPENING_BRACKET -> type(OPENING_BRACKET), pushMode(SETTING_MODE);
ON : 'on' -> pushMode(ENRICH_FIELD_MODE);
WITH : 'with' -> pushMode(ENRICH_FIELD_MODE);
ENRICH_ON : ON -> type(ON), pushMode(ENRICH_FIELD_MODE);
ENRICH_WITH : WITH -> type(WITH), pushMode(ENRICH_FIELD_MODE);
// similar to that of an index
// see https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-create-index.html#indices-create-api-path-params

View File

@ -19,6 +19,7 @@ STATS : 'stats' -> pushMode(EXPRESSION_MODE);
WHERE : 'where' -> pushMode(EXPRESSION_MODE);
DEV_INLINESTATS : {this.isDevVersion()}? 'inlinestats' -> pushMode(EXPRESSION_MODE);
DEV_RERANK : {this.isDevVersion()}? 'rerank' -> pushMode(EXPRESSION_MODE);
mode EXPRESSION_MODE;
@ -82,11 +83,12 @@ DECIMAL_LITERAL
| DOT DIGIT+ EXPONENT
;
BY : 'by';
AND : 'and';
AS: 'as';
ASC : 'asc';
ASSIGN : '=';
BY : 'by';
CAST_OP : '::';
COLON : ':';
COMMA : ',';
@ -101,10 +103,12 @@ LIKE: 'like';
NOT : 'not';
NULL : 'null';
NULLS : 'nulls';
ON: 'on';
OR : 'or';
PARAM: '?';
RLIKE: 'rlike';
TRUE : 'true';
WITH: 'with';
EQ : '==';
CIEQ : '=~';

View File

@ -22,7 +22,7 @@ RENAME_NAMED_OR_POSITIONAL_PARAM : NAMED_OR_POSITIONAL_PARAM -> type(NAMED_OR_PO
RENAME_DOUBLE_PARAMS : DOUBLE_PARAMS -> type(DOUBLE_PARAMS);
RENAME_NAMED_OR_POSITIONAL_DOUBLE_PARAMS : NAMED_OR_POSITIONAL_DOUBLE_PARAMS -> type(NAMED_OR_POSITIONAL_DOUBLE_PARAMS);
AS : 'as';
RENAME_AS : AS -> type(AS);
RENAME_ID_PATTERN
: ID_PATTERN -> type(ID_PATTERN)

View File

@ -873,6 +873,11 @@ public class EsqlCapabilities {
*/
FORK(Build.current().isSnapshot()),
/**
* Support for RERANK command
*/
RERANK(Build.current().isSnapshot()),
/**
* Allow mixed numeric types in conditional functions - case, greatest and least
*/

View File

@ -33,7 +33,7 @@ import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.nanoTimeTo
import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.spatialToString;
import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.versionToString;
abstract class PositionToXContent {
public abstract class PositionToXContent {
protected final Block block;
PositionToXContent(Block block) {

View File

@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.analysis;
import org.elasticsearch.common.logging.HeaderWarning;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.core.Strings;
import org.elasticsearch.index.IndexMode;
@ -27,6 +28,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
@ -67,6 +69,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Esq
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.index.IndexResolution;
import org.elasticsearch.xpack.esql.inference.ResolvedInference;
import org.elasticsearch.xpack.esql.parser.ParsingException;
import org.elasticsearch.xpack.esql.plan.IndexPattern;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
@ -86,6 +89,8 @@ import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.Rename;
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinType;
@ -122,7 +127,6 @@ import java.util.stream.Collectors;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.GEO_MATCH_TYPE;
import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME;
@ -154,7 +158,15 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
);
private static final List<Batch<LogicalPlan>> RULES = List.of(
new Batch<>("Initialize", Limiter.ONCE, new ResolveTable(), new ResolveEnrich(), new ResolveLookupTables(), new ResolveFunctions()),
new Batch<>(
"Initialize",
Limiter.ONCE,
new ResolveTable(),
new ResolveEnrich(),
new ResolveInference(),
new ResolveLookupTables(),
new ResolveFunctions()
),
new Batch<>(
"Resolution",
/*
@ -380,6 +392,34 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
}
}
private static class ResolveInference extends ParameterizedAnalyzerRule<InferencePlan, AnalyzerContext> {
@Override
protected LogicalPlan rule(InferencePlan plan, AnalyzerContext context) {
assert plan.inferenceId().resolved() && plan.inferenceId().foldable();
String inferenceId = plan.inferenceId().fold(FoldContext.small()).toString();
ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);
if (resolvedInference != null && resolvedInference.taskType() == plan.taskType()) {
return plan;
} else if (resolvedInference != null) {
String error = "cannot use inference endpoint ["
+ inferenceId
+ "] with task type ["
+ resolvedInference.taskType()
+ "] within a "
+ plan.nodeName()
+ " command. Only inference endpoints with the task type ["
+ plan.taskType()
+ "] are supported.";
return plan.withInferenceResolutionError(inferenceId, error);
} else {
String error = context.inferenceResolution().getError(inferenceId);
return plan.withInferenceResolutionError(inferenceId, error);
}
}
}
private static class ResolveLookupTables extends ParameterizedAnalyzerRule<Lookup, AnalyzerContext> {
@Override
@ -491,6 +531,10 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
return resolveRrfScoreEval(rrf, childrenOutput);
}
if (plan instanceof Rerank r) {
return resolveRerank(r, childrenOutput);
}
return plan.transformExpressionsOnly(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
}
@ -670,6 +714,33 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
return join;
}
private LogicalPlan resolveRerank(Rerank rerank, List<Attribute> childrenOutput) {
List<Alias> newFields = new ArrayList<>();
boolean changed = false;
// First resolving fields used in expression
for (Alias field : rerank.rerankFields()) {
Alias result = (Alias) field.transformUp(UnresolvedAttribute.class, ua -> resolveAttribute(ua, childrenOutput));
newFields.add(result);
changed |= result != field;
}
if (changed) {
rerank = rerank.withRerankFields(newFields);
}
// Ensure the score attribute is present in the output.
if (rerank.scoreAttribute() instanceof UnresolvedAttribute ua) {
Attribute resolved = resolveAttribute(ua, childrenOutput);
if (resolved.resolved() == false || resolved.dataType() != DOUBLE) {
resolved = MetadataAttribute.create(Source.EMPTY, MetadataAttribute.SCORE);
}
rerank = rerank.withScoreAttribute(resolved);
}
return rerank;
}
private List<Attribute> resolveUsingColumns(List<Attribute> cols, List<Attribute> output, String side) {
List<Attribute> resolved = new ArrayList<>(cols.size());
for (Attribute col : cols) {
@ -987,7 +1058,7 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
var u = resolved;
var previousAliasName = reverseAliasing.get(resolved.name());
if (previousAliasName != null) {
String message = format(
String message = LoggerMessageFormat.format(
null,
"Column [{}] renamed to [{}] and is no longer available [{}]",
resolved.name(),
@ -1380,7 +1451,7 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
}
private static UnresolvedAttribute unresolvedAttribute(Expression value, String type, Exception e) {
String message = format(
String message = LoggerMessageFormat.format(
"Cannot convert string [{}] to [{}], error [{}]",
value.fold(FoldContext.small() /* TODO remove me */),
type,

View File

@ -9,6 +9,7 @@ package org.elasticsearch.xpack.esql.analysis;
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.IndexResolution;
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
import org.elasticsearch.xpack.esql.session.Configuration;
import java.util.Map;
@ -18,7 +19,8 @@ public record AnalyzerContext(
EsqlFunctionRegistry functionRegistry,
IndexResolution indexResolution,
Map<String, IndexResolution> lookupResolution,
EnrichResolution enrichResolution
EnrichResolution enrichResolution,
InferenceResolution inferenceResolution
) {
// Currently for tests only, since most do not test lookups
// TODO: make this even simpler, remove the enrichResolution for tests that do not require it (most tests)
@ -26,8 +28,9 @@ public record AnalyzerContext(
Configuration configuration,
EsqlFunctionRegistry functionRegistry,
IndexResolution indexResolution,
EnrichResolution enrichResolution
EnrichResolution enrichResolution,
InferenceResolution inferenceResolution
) {
this(configuration, functionRegistry, indexResolution, Map.of(), enrichResolution);
this(configuration, functionRegistry, indexResolution, Map.of(), enrichResolution, inferenceResolution);
}
}

View File

@ -12,6 +12,7 @@ import org.elasticsearch.xpack.esql.plan.IndexPattern;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import java.util.ArrayList;
import java.util.HashSet;
@ -26,15 +27,22 @@ import static java.util.Collections.emptyList;
public class PreAnalyzer {
public static class PreAnalysis {
public static final PreAnalysis EMPTY = new PreAnalysis(emptyList(), emptyList(), emptyList());
public static final PreAnalysis EMPTY = new PreAnalysis(emptyList(), emptyList(), emptyList(), emptyList());
public final List<IndexPattern> indices;
public final List<Enrich> enriches;
public final List<InferencePlan> inferencePlans;
public final List<IndexPattern> lookupIndices;
public PreAnalysis(List<IndexPattern> indices, List<Enrich> enriches, List<IndexPattern> lookupIndices) {
public PreAnalysis(
List<IndexPattern> indices,
List<Enrich> enriches,
List<InferencePlan> inferencePlans,
List<IndexPattern> lookupIndices
) {
this.indices = indices;
this.enriches = enriches;
this.inferencePlans = inferencePlans;
this.lookupIndices = lookupIndices;
}
}
@ -52,13 +60,15 @@ public class PreAnalyzer {
List<Enrich> unresolvedEnriches = new ArrayList<>();
List<IndexPattern> lookupIndices = new ArrayList<>();
List<InferencePlan> unresolvedInferencePlans = new ArrayList<>();
plan.forEachUp(UnresolvedRelation.class, p -> (p.indexMode() == IndexMode.LOOKUP ? lookupIndices : indices).add(p.indexPattern()));
plan.forEachUp(Enrich.class, unresolvedEnriches::add);
plan.forEachUp(InferencePlan.class, unresolvedInferencePlans::add);
// mark plan as preAnalyzed (if it were marked, there would be no analysis)
plan.forEachUp(LogicalPlan::setPreAnalyzed);
return new PreAnalysis(indices.stream().toList(), unresolvedEnriches, lookupIndices);
return new PreAnalysis(indices.stream().toList(), unresolvedEnriches, unresolvedInferencePlans, lookupIndices);
}
}

View File

@ -0,0 +1,77 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.inference;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
public class InferenceResolution {
public static final InferenceResolution EMPTY = new InferenceResolution.Builder().build();
public static InferenceResolution.Builder builder() {
return new Builder();
}
private final Map<String, ResolvedInference> resolvedInferences;
private final Map<String, String> errors;
private InferenceResolution(Map<String, ResolvedInference> resolvedInferences, Map<String, String> errors) {
this.resolvedInferences = Collections.unmodifiableMap(resolvedInferences);
this.errors = Collections.unmodifiableMap(errors);
}
public ResolvedInference getResolvedInference(String inferenceId) {
return resolvedInferences.get(inferenceId);
}
public Collection<ResolvedInference> resolvedInferences() {
return resolvedInferences.values();
}
public boolean hasError() {
return errors.isEmpty() == false;
}
public String getError(String inferenceId) {
final String error = errors.get(inferenceId);
if (error != null) {
return error;
} else {
return "unresolved inference [" + inferenceId + "]";
}
}
public static class Builder {
private final Map<String, ResolvedInference> resolvedInferences;
private final Map<String, String> errors;
private Builder() {
this.resolvedInferences = ConcurrentCollections.newConcurrentMap();
this.errors = ConcurrentCollections.newConcurrentMap();
}
public Builder withResolvedInference(ResolvedInference resolvedInference) {
resolvedInferences.putIfAbsent(resolvedInference.inferenceId(), resolvedInference);
return this;
}
public Builder withError(String inferenceId, String reason) {
errors.putIfAbsent(inferenceId, reason);
return this;
}
public InferenceResolution build() {
return new InferenceResolution(resolvedInferences, errors);
}
}
}

View File

@ -0,0 +1,78 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.inference;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
public class InferenceRunner {
private final Client client;
public InferenceRunner(Client client) {
this.client = client;
}
public ThreadContext getThreadContext() {
return client.threadPool().getThreadContext();
}
public void resolveInferenceIds(List<InferencePlan> plans, ActionListener<InferenceResolution> listener) {
resolveInferenceIds(plans.stream().map(InferenceRunner::planInferenceId).collect(Collectors.toSet()), listener);
}
private void resolveInferenceIds(Set<String> inferenceIds, ActionListener<InferenceResolution> listener) {
if (inferenceIds.isEmpty()) {
listener.onResponse(InferenceResolution.EMPTY);
return;
}
final InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder();
final CountDownActionListener countdownListener = new CountDownActionListener(
inferenceIds.size(),
ActionListener.wrap(_r -> listener.onResponse(inferenceResolutionBuilder.build()), listener::onFailure)
);
for (var inferenceId : inferenceIds) {
client.execute(
GetInferenceModelAction.INSTANCE,
new GetInferenceModelAction.Request(inferenceId, TaskType.ANY),
ActionListener.wrap(r -> {
ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType());
inferenceResolutionBuilder.withResolvedInference(resolvedInference);
countdownListener.onResponse(null);
}, e -> {
inferenceResolutionBuilder.withError(inferenceId, e.getMessage());
countdownListener.onResponse(null);
})
);
}
}
private static String planInferenceId(InferencePlan plan) {
return plan.inferenceId().fold(FoldContext.small()).toString();
}
public void doInference(InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
client.execute(InferenceAction.INSTANCE, request, listener);
}
}

View File

@ -0,0 +1,198 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.inference;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AsyncOperator;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import java.util.List;
public class RerankOperator extends AsyncOperator<Page> {
// Move to a setting.
private static final int MAX_INFERENCE_WORKER = 10;
public record Factory(
InferenceRunner inferenceRunner,
String inferenceId,
String queryText,
ExpressionEvaluator.Factory rowEncoderFactory,
int scoreChannel
) implements OperatorFactory {
@Override
public String describe() {
return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]";
}
@Override
public Operator get(DriverContext driverContext) {
return new RerankOperator(
driverContext,
inferenceRunner,
inferenceId,
queryText,
rowEncoderFactory().get(driverContext),
scoreChannel
);
}
}
private final InferenceRunner inferenceRunner;
private final BlockFactory blockFactory;
private final String inferenceId;
private final String queryText;
private final ExpressionEvaluator rowEncoder;
private final int scoreChannel;
public RerankOperator(
DriverContext driverContext,
InferenceRunner inferenceRunner,
String inferenceId,
String queryText,
ExpressionEvaluator rowEncoder,
int scoreChannel
) {
super(driverContext, inferenceRunner.getThreadContext(), MAX_INFERENCE_WORKER);
assert inferenceRunner.getThreadContext() != null;
this.blockFactory = driverContext.blockFactory();
this.inferenceRunner = inferenceRunner;
this.inferenceId = inferenceId;
this.queryText = queryText;
this.rowEncoder = rowEncoder;
this.scoreChannel = scoreChannel;
}
@Override
protected void performAsync(Page inputPage, ActionListener<Page> listener) {
// Ensure input page blocks are released when the listener is called.
final ActionListener<Page> outputListener = ActionListener.runAfter(listener, () -> { releasePageOnAnyThread(inputPage); });
try {
inferenceRunner.doInference(
buildInferenceRequest(inputPage),
ActionListener.wrap(
inferenceResponse -> outputListener.onResponse(buildOutput(inputPage, inferenceResponse)),
outputListener::onFailure
)
);
} catch (Exception e) {
outputListener.onFailure(e);
}
}
@Override
protected void doClose() {
Releasables.closeExpectNoException(rowEncoder);
}
@Override
protected void releaseFetchedOnAnyThread(Page page) {
releasePageOnAnyThread(page);
}
@Override
public Page getOutput() {
return fetchFromBuffer();
}
@Override
public String toString() {
return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]";
}
private Page buildOutput(Page inputPage, InferenceAction.Response inferenceResponse) {
if (inferenceResponse.getResults() instanceof RankedDocsResults rankedDocsResults) {
return buildOutput(inputPage, rankedDocsResults);
}
throw new IllegalStateException(
"Inference result has wrong type. Got ["
+ inferenceResponse.getResults().getClass()
+ "] while expecting ["
+ RankedDocsResults.class
+ "]"
);
}
private Page buildOutput(Page inputPage, RankedDocsResults rankedDocsResults) {
int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1);
Block[] blocks = new Block[blockCount];
try {
for (int b = 0; b < blockCount; b++) {
if (b == scoreChannel) {
blocks[b] = buildScoreBlock(inputPage, rankedDocsResults);
} else {
blocks[b] = inputPage.getBlock(b);
blocks[b].incRef();
}
}
return new Page(blocks);
} catch (Exception e) {
Releasables.closeExpectNoException(blocks);
throw (e);
}
}
private Block buildScoreBlock(Page inputPage, RankedDocsResults rankedDocsResults) {
Double[] sortedRankedDocsScores = new Double[inputPage.getPositionCount()];
try (DoubleBlock.Builder scoreBlockFactory = blockFactory.newDoubleBlockBuilder(inputPage.getPositionCount())) {
for (RankedDocsResults.RankedDoc rankedDoc : rankedDocsResults.getRankedDocs()) {
sortedRankedDocsScores[rankedDoc.index()] = (double) rankedDoc.relevanceScore();
}
for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
if (sortedRankedDocsScores[pos] != null) {
scoreBlockFactory.appendDouble(sortedRankedDocsScores[pos]);
} else {
scoreBlockFactory.appendNull();
}
}
return scoreBlockFactory.build();
}
}
private InferenceAction.Request buildInferenceRequest(Page inputPage) {
try (BytesRefBlock encodedRowsBlock = (BytesRefBlock) rowEncoder.eval(inputPage)) {
assert (encodedRowsBlock.getPositionCount() == inputPage.getPositionCount());
String[] inputs = new String[inputPage.getPositionCount()];
BytesRef buffer = new BytesRef();
for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
if (encodedRowsBlock.isNull(pos)) {
inputs[pos] = "";
} else {
buffer = encodedRowsBlock.getBytesRef(encodedRowsBlock.getFirstValueIndex(pos), buffer);
inputs[pos] = BytesRefs.toString(buffer);
}
}
return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build();
}
}
}

View File

@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.inference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.TaskType;
import java.io.IOException;
public record ResolvedInference(String inferenceId, TaskType taskType) implements Writeable {
public ResolvedInference(StreamInput in) throws IOException {
this(in.readString(), TaskType.valueOf(in.readString()));
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(inferenceId);
out.writeString(taskType.name());
}
}

View File

@ -0,0 +1,145 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.inference;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.BytesRefStreamOutput;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.esql.action.ColumnInfoImpl;
import org.elasticsearch.xpack.esql.action.PositionToXContent;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* Encodes rows into an XContent format (JSON,YAML,...) for further processing.
* Extracted columns can be specified using {@link ExpressionEvaluator}
*/
public class XContentRowEncoder implements ExpressionEvaluator {
private final XContentType xContentType;
private final BlockFactory blockFactory;
private final ColumnInfoImpl[] columnsInfo;
private final ExpressionEvaluator[] fieldsValueEvaluators;
/**
* Creates a factory for YAML XContent row encoding.
*
* @param fieldsEvaluatorFactories A map of column information to expression evaluators.
* @return A Factory instance for creating YAML row encoder for the specified column.
*/
public static Factory yamlRowEncoderFactory(Map<ColumnInfoImpl, ExpressionEvaluator.Factory> fieldsEvaluatorFactories) {
return new Factory(XContentType.YAML, fieldsEvaluatorFactories);
}
private XContentRowEncoder(
XContentType xContentType,
BlockFactory blockFactory,
ColumnInfoImpl[] columnsInfo,
ExpressionEvaluator[] fieldsValueEvaluators
) {
assert columnsInfo.length == fieldsValueEvaluators.length;
this.xContentType = xContentType;
this.blockFactory = blockFactory;
this.columnsInfo = columnsInfo;
this.fieldsValueEvaluators = fieldsValueEvaluators;
}
@Override
public void close() {
Releasables.closeExpectNoException(fieldsValueEvaluators);
}
/**
* Process the provided Page and encode its rows into a BytesRefBlock containing XContent-formatted rows.
*
* @param page The input Page containing row data.
* @return A BytesRefBlock containing the encoded rows.
*/
@Override
public BytesRefBlock eval(Page page) {
Block[] fieldValueBlocks = new Block[fieldsValueEvaluators.length];
try (
BytesRefStreamOutput outputStream = new BytesRefStreamOutput();
XContentBuilder xContentBuilder = XContentFactory.contentBuilder(xContentType, outputStream);
BytesRefBlock.Builder outputBlockBuilder = blockFactory.newBytesRefBlockBuilder(page.getPositionCount());
) {
PositionToXContent[] toXContents = new PositionToXContent[fieldsValueEvaluators.length];
for (int b = 0; b < fieldValueBlocks.length; b++) {
fieldValueBlocks[b] = fieldsValueEvaluators[b].eval(page);
toXContents[b] = PositionToXContent.positionToXContent(columnsInfo[b], fieldValueBlocks[b], new BytesRef());
}
for (int pos = 0; pos < page.getPositionCount(); pos++) {
xContentBuilder.startObject();
for (int i = 0; i < fieldValueBlocks.length; i++) {
String fieldName = columnsInfo[i].name();
Block currentBlock = fieldValueBlocks[i];
if (currentBlock.isNull(pos) || currentBlock.getValueCount(pos) < 1) {
continue;
}
toXContents[i].positionToXContent(xContentBuilder.field(fieldName), ToXContent.EMPTY_PARAMS, pos);
}
xContentBuilder.endObject().flush();
outputBlockBuilder.appendBytesRef(outputStream.get());
outputStream.reset();
}
return outputBlockBuilder.build();
} catch (IOException e) {
throw new UncheckedIOException(e);
} finally {
Releasables.closeExpectNoException(fieldValueBlocks);
}
}
public List<String> fieldNames() {
return Arrays.stream(columnsInfo).map(ColumnInfoImpl::name).collect(Collectors.toList());
}
@Override
public String toString() {
return "XContentRowEncoder[content_type=[" + xContentType.toString() + "], field_names=" + fieldNames() + "]";
}
public static class Factory implements ExpressionEvaluator.Factory {
private final XContentType xContentType;
private final Map<ColumnInfoImpl, ExpressionEvaluator.Factory> fieldsEvaluatorFactories;
private Factory(XContentType xContentType, Map<ColumnInfoImpl, ExpressionEvaluator.Factory> fieldsEvaluatorFactories) {
this.xContentType = xContentType;
this.fieldsEvaluatorFactories = fieldsEvaluatorFactories;
}
public XContentRowEncoder get(DriverContext context) {
return new XContentRowEncoder(xContentType, context.blockFactory(), columnsInfo(), fieldsValueEvaluators(context));
}
private ColumnInfoImpl[] columnsInfo() {
return fieldsEvaluatorFactories.keySet().toArray(ColumnInfoImpl[]::new);
}
private ExpressionEvaluator[] fieldsValueEvaluators(DriverContext context) {
return fieldsEvaluatorFactories.values().stream().map(factory -> factory.get(context)).toArray(ExpressionEvaluator[]::new);
}
}
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -740,6 +740,18 @@ public class EsqlBaseParserBaseListener implements EsqlBaseParserListener {
* <p>The default implementation does nothing.</p>
*/
@Override public void exitRrfCommand(EsqlBaseParser.RrfCommandContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterRerankCommand(EsqlBaseParser.RerankCommandContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) { }
/**
* {@inheritDoc}
*

View File

@ -440,6 +440,13 @@ public class EsqlBaseParserBaseVisitor<T> extends AbstractParseTreeVisitor<T> im
* {@link #visitChildren} on {@code ctx}.</p>
*/
@Override public T visitRrfCommand(EsqlBaseParser.RrfCommandContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*
* <p>The default implementation returns the result of calling
* {@link #visitChildren} on {@code ctx}.</p>
*/
@Override public T visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*

View File

@ -635,6 +635,16 @@ public interface EsqlBaseParserListener extends ParseTreeListener {
* @param ctx the parse tree
*/
void exitRrfCommand(EsqlBaseParser.RrfCommandContext ctx);
/**
* Enter a parse tree produced by {@link EsqlBaseParser#rerankCommand}.
* @param ctx the parse tree
*/
void enterRerankCommand(EsqlBaseParser.RerankCommandContext ctx);
/**
* Exit a parse tree produced by {@link EsqlBaseParser#rerankCommand}.
* @param ctx the parse tree
*/
void exitRerankCommand(EsqlBaseParser.RerankCommandContext ctx);
/**
* Enter a parse tree produced by the {@code matchExpression}
* labeled alternative in {@link EsqlBaseParser#booleanExpression}.

View File

@ -388,6 +388,12 @@ public interface EsqlBaseParserVisitor<T> extends ParseTreeVisitor<T> {
* @return the visitor result
*/
T visitRrfCommand(EsqlBaseParser.RrfCommandContext ctx);
/**
* Visit a parse tree produced by {@link EsqlBaseParser#rerankCommand}.
* @param ctx the parse tree
* @return the visitor result
*/
T visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx);
/**
* Visit a parse tree produced by the {@code matchExpression}
* labeled alternative in {@link EsqlBaseParser#booleanExpression}.

View File

@ -68,6 +68,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Row;
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;
import org.elasticsearch.xpack.esql.plan.logical.show.ShowInfo;
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
@ -707,4 +708,56 @@ public class LogicalPlanBuilder extends ExpressionBuilder {
return new OrderBy(source, dedup, order);
};
}
@Override
public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
var source = source(ctx);
if (false == EsqlCapabilities.Cap.RERANK.isEnabled()) {
throw new ParsingException(source, "RERANK is in preview and only available in SNAPSHOT build");
}
Expression queryText = expression(ctx.queryText);
if (queryText instanceof Literal queryTextLiteral && DataType.isString(queryText.dataType())) {
if (queryTextLiteral.value() == null) {
throw new ParsingException(
source(ctx.queryText),
"Query text cannot be null or undefined in RERANK",
ctx.queryText.getText()
);
}
} else {
throw new ParsingException(
source(ctx.queryText),
"RERANK only support string as query text but [{}] cannot be used as string",
ctx.queryText.getText()
);
}
return p -> new Rerank(source, p, inferenceId(ctx.inferenceId), queryText, visitFields(ctx.fields()));
}
public Literal inferenceId(EsqlBaseParser.IdentifierOrParameterContext ctx) {
if (ctx.identifier() != null) {
return new Literal(source(ctx), visitIdentifier(ctx.identifier()), KEYWORD);
}
if (expression(ctx.parameter()) instanceof Literal literalParam) {
if (literalParam.value() != null) {
return literalParam;
}
throw new ParsingException(
source(ctx.parameter()),
"Query parameter [{}] is null or undefined and cannot be used as inference id",
ctx.parameter().getText()
);
}
throw new ParsingException(
source(ctx.parameter()),
"Query parameter [{}] is not a string and cannot be used as inference id",
ctx.parameter().getText()
);
}
}

View File

@ -23,6 +23,7 @@ import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
@ -49,6 +50,7 @@ import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
import org.elasticsearch.xpack.esql.plan.physical.SubqueryExec;
import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
import java.util.ArrayList;
import java.util.List;
@ -81,6 +83,7 @@ public class PlanWritables {
MvExpand.ENTRY,
OrderBy.ENTRY,
Project.ENTRY,
Rerank.ENTRY,
TimeSeriesAggregate.ENTRY,
TopN.ENTRY
);
@ -106,6 +109,7 @@ public class PlanWritables {
LocalSourceExec.ENTRY,
MvExpandExec.ENTRY,
ProjectExec.ENTRY,
RerankExec.ENTRY,
ShowExec.ENTRY,
SubqueryExec.ENTRY,
TimeSeriesAggregateExec.ENTRY,

View File

@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.plan.logical.inference;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import java.io.IOException;
import java.util.Objects;
public abstract class InferencePlan extends UnaryPlan {
private final Expression inferenceId;
protected InferencePlan(Source source, LogicalPlan child, Expression inferenceId) {
super(source, child);
this.inferenceId = inferenceId;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
Source.EMPTY.writeTo(out);
out.writeNamedWriteable(child());
out.writeNamedWriteable(inferenceId());
}
public Expression inferenceId() {
return inferenceId;
}
@Override
public boolean expressionsResolved() {
return inferenceId.resolved();
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
InferencePlan other = (InferencePlan) o;
return Objects.equals(inferenceId(), other.inferenceId());
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), inferenceId());
}
public abstract TaskType taskType();
public abstract InferencePlan withInferenceId(Expression newInferenceId);
public InferencePlan withInferenceResolutionError(String inferenceId, String error) {
return withInferenceId(new UnresolvedAttribute(inferenceId().source(), inferenceId, error));
}
}

View File

@ -0,0 +1,190 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.plan.logical.inference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.core.capabilities.Resolvables;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.Order;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.plan.QueryPlan;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.SortAgnostic;
import org.elasticsearch.xpack.esql.plan.logical.SurrogateLogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.esql.core.expression.Expressions.asAttributes;
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
public class Rerank extends InferencePlan implements SortAgnostic, SurrogateLogicalPlan {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Rerank", Rerank::new);
private final Attribute scoreAttribute;
private final Expression queryText;
private final List<Alias> rerankFields;
private List<Attribute> lazyOutput;
public Rerank(Source source, LogicalPlan child, Expression inferenceId, Expression queryText, List<Alias> rerankFields) {
super(source, child, inferenceId);
this.queryText = queryText;
this.rerankFields = rerankFields;
this.scoreAttribute = new UnresolvedAttribute(source, MetadataAttribute.SCORE);
}
public Rerank(
Source source,
LogicalPlan child,
Expression inferenceId,
Expression queryText,
List<Alias> rerankFields,
Attribute scoreAttribute
) {
super(source, child, inferenceId);
this.queryText = queryText;
this.rerankFields = rerankFields;
this.scoreAttribute = scoreAttribute;
}
public Rerank(StreamInput in) throws IOException {
this(
Source.readFrom((PlanStreamInput) in),
in.readNamedWriteable(LogicalPlan.class),
in.readNamedWriteable(Expression.class),
in.readNamedWriteable(Expression.class),
in.readCollectionAsList(Alias::new),
in.readNamedWriteable(Attribute.class)
);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeNamedWriteable(queryText);
out.writeCollection(rerankFields());
out.writeNamedWriteable(scoreAttribute);
}
public Expression queryText() {
return queryText;
}
public List<Alias> rerankFields() {
return rerankFields;
}
public Attribute scoreAttribute() {
return scoreAttribute;
}
@Override
public TaskType taskType() {
return TaskType.RERANK;
}
@Override
public Rerank withInferenceId(Expression newInferenceId) {
return new Rerank(source(), child(), newInferenceId, queryText, rerankFields, scoreAttribute);
}
public Rerank withRerankFields(List<Alias> newRerankFields) {
return new Rerank(source(), child(), inferenceId(), queryText, newRerankFields, scoreAttribute);
}
public Rerank withScoreAttribute(Attribute newScoreAttribute) {
return new Rerank(source(), child(), inferenceId(), queryText, rerankFields, newScoreAttribute);
}
@Override
public String getWriteableName() {
return ENTRY.name;
}
@Override
public UnaryPlan replaceChild(LogicalPlan newChild) {
return new Rerank(source(), newChild, inferenceId(), queryText, rerankFields, scoreAttribute);
}
@Override
protected AttributeSet computeReferences() {
AttributeSet.Builder refs = computeReferences(rerankFields).asBuilder();
if (planHasAttribute(child(), scoreAttribute)) {
refs.add(scoreAttribute);
}
return refs.build();
}
public static AttributeSet computeReferences(List<Alias> fields) {
AttributeSet rerankFields = AttributeSet.of(asAttributes(fields));
return Expressions.references(fields).subtract(rerankFields);
}
@Override
public boolean expressionsResolved() {
return super.expressionsResolved() && queryText.resolved() && Resolvables.resolved(rerankFields) && scoreAttribute.resolved();
}
@Override
protected NodeInfo<? extends LogicalPlan> info() {
return NodeInfo.create(this, Rerank::new, child(), inferenceId(), queryText, rerankFields, scoreAttribute);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
Rerank rerank = (Rerank) o;
return Objects.equals(queryText, rerank.queryText)
&& Objects.equals(rerankFields, rerank.rerankFields)
&& Objects.equals(scoreAttribute, rerank.scoreAttribute);
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), queryText, rerankFields, scoreAttribute);
}
@Override
public LogicalPlan surrogate() {
Order sortOrder = new Order(source(), scoreAttribute, Order.OrderDirection.DESC, Order.NullsPosition.ANY);
return new OrderBy(source(), this, List.of(sortOrder));
}
@Override
public List<Attribute> output() {
if (lazyOutput == null) {
lazyOutput = planHasAttribute(child(), scoreAttribute)
? child().output()
: mergeOutputAttributes(List.of(scoreAttribute), child().output());
}
return lazyOutput;
}
public static boolean planHasAttribute(QueryPlan<?> plan, Attribute attribute) {
return plan.outputSet().stream().anyMatch(attr -> attr.equals(attribute));
}
}

View File

@ -0,0 +1,51 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.plan.physical.inference;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
import java.io.IOException;
import java.util.Objects;
public abstract class InferenceExec extends UnaryExec {
private final Expression inferenceId;
protected InferenceExec(Source source, PhysicalPlan child, Expression inferenceId) {
super(source, child);
this.inferenceId = inferenceId;
}
public Expression inferenceId() {
return inferenceId;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
Source.EMPTY.writeTo(out);
out.writeNamedWriteable(child());
out.writeNamedWriteable(inferenceId());
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
InferenceExec that = (InferenceExec) o;
return inferenceId.equals(that.inferenceId);
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), inferenceId());
}
}

View File

@ -0,0 +1,138 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.plan.physical.inference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
import static org.elasticsearch.xpack.esql.plan.logical.inference.Rerank.planHasAttribute;
public class RerankExec extends InferenceExec {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
PhysicalPlan.class,
"RerankExec",
RerankExec::new
);
private final Expression queryText;
private final List<Alias> rerankFields;
private final Attribute scoreAttribute;
public RerankExec(
Source source,
PhysicalPlan child,
Expression inferenceId,
Expression queryText,
List<Alias> rerankFields,
Attribute scoreAttribute
) {
super(source, child, inferenceId);
this.queryText = queryText;
this.rerankFields = rerankFields;
this.scoreAttribute = scoreAttribute;
}
public RerankExec(StreamInput in) throws IOException {
this(
Source.readFrom((PlanStreamInput) in),
in.readNamedWriteable(PhysicalPlan.class),
in.readNamedWriteable(Expression.class),
in.readNamedWriteable(Expression.class),
in.readCollectionAsList(Alias::new),
in.readNamedWriteable(Attribute.class)
);
}
public Expression queryText() {
return queryText;
}
public List<Alias> rerankFields() {
return rerankFields;
}
public Attribute scoreAttribute() {
return scoreAttribute;
}
@Override
public String getWriteableName() {
return ENTRY.name;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeNamedWriteable(queryText());
out.writeCollection(rerankFields());
out.writeNamedWriteable(scoreAttribute);
}
@Override
protected NodeInfo<? extends PhysicalPlan> info() {
return NodeInfo.create(this, RerankExec::new, child(), inferenceId(), queryText, rerankFields, scoreAttribute);
}
@Override
public UnaryExec replaceChild(PhysicalPlan newChild) {
return new RerankExec(source(), newChild, inferenceId(), queryText, rerankFields, scoreAttribute);
}
@Override
public List<Attribute> output() {
if (planHasAttribute(child(), scoreAttribute)) {
return child().output();
}
return mergeOutputAttributes(List.of(scoreAttribute), child().output());
}
@Override
protected AttributeSet computeReferences() {
AttributeSet.Builder refs = Rerank.computeReferences(rerankFields).asBuilder();
if (planHasAttribute(child(), scoreAttribute)) {
refs.add(scoreAttribute);
}
return refs.build();
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
RerankExec rerank = (RerankExec) o;
return Objects.equals(queryText, rerank.queryText)
&& Objects.equals(rerankFields, rerank.rerankFields)
&& Objects.equals(scoreAttribute, rerank.scoreAttribute);
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), queryText, rerankFields, scoreAttribute);
}
}

View File

@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.planner;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.Describable;
@ -23,6 +24,7 @@ import org.elasticsearch.compute.operator.ColumnExtractOperator;
import org.elasticsearch.compute.operator.ColumnLoadOperator;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.EvalOperator.EvalOperatorFactory;
import org.elasticsearch.compute.operator.FilterOperator.FilterOperatorFactory;
import org.elasticsearch.compute.operator.LimitOperator;
@ -57,6 +59,7 @@ import org.elasticsearch.logging.Logger;
import org.elasticsearch.node.Node;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.action.ColumnInfoImpl;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
@ -78,6 +81,9 @@ import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
import org.elasticsearch.xpack.esql.evaluator.command.GrokEvaluatorExtracter;
import org.elasticsearch.xpack.esql.expression.Order;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.inference.RerankOperator;
import org.elasticsearch.xpack.esql.inference.XContentRowEncoder;
import org.elasticsearch.xpack.esql.plan.logical.Fork;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.ChangePointExec;
@ -104,6 +110,7 @@ import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
import org.elasticsearch.xpack.esql.plan.physical.RrfScoreEvalExec;
import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders.ShardContext;
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
import org.elasticsearch.xpack.esql.score.ScoreMapper;
@ -112,6 +119,7 @@ import org.elasticsearch.xpack.esql.session.Configuration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -144,6 +152,7 @@ public class LocalExecutionPlanner {
private final Supplier<ExchangeSink> exchangeSinkSupplier;
private final EnrichLookupService enrichLookupService;
private final LookupFromIndexService lookupFromIndexService;
private final InferenceRunner inferenceRunner;
private final PhysicalOperationProviders physicalOperationProviders;
private final List<ShardContext> shardContexts;
@ -159,6 +168,7 @@ public class LocalExecutionPlanner {
Supplier<ExchangeSink> exchangeSinkSupplier,
EnrichLookupService enrichLookupService,
LookupFromIndexService lookupFromIndexService,
InferenceRunner inferenceRunner,
PhysicalOperationProviders physicalOperationProviders,
List<ShardContext> shardContexts
) {
@ -174,6 +184,7 @@ public class LocalExecutionPlanner {
this.exchangeSinkSupplier = exchangeSinkSupplier;
this.enrichLookupService = enrichLookupService;
this.lookupFromIndexService = lookupFromIndexService;
this.inferenceRunner = inferenceRunner;
this.physicalOperationProviders = physicalOperationProviders;
this.shardContexts = shardContexts;
}
@ -242,6 +253,8 @@ public class LocalExecutionPlanner {
return planLimit(limit, context);
} else if (node instanceof MvExpandExec mvExpand) {
return planMvExpand(mvExpand, context);
} else if (node instanceof RerankExec rerank) {
return planRerank(rerank, context);
} else if (node instanceof ChangePointExec changePoint) {
return planChangePoint(changePoint, context);
}
@ -543,6 +556,36 @@ public class LocalExecutionPlanner {
);
}
private PhysicalOperation planRerank(RerankExec rerank, LocalExecutionPlannerContext context) {
PhysicalOperation source = plan(rerank.child(), context);
Map<ColumnInfoImpl, EvalOperator.ExpressionEvaluator.Factory> rerankFieldsEvaluatorSuppliers = new LinkedHashMap<>();
for (var rerankField : rerank.rerankFields()) {
rerankFieldsEvaluatorSuppliers.put(
new ColumnInfoImpl(rerankField.name(), rerankField.dataType(), null),
EvalMapper.toEvaluator(context.foldCtx(), rerankField.child(), source.layout)
);
}
XContentRowEncoder.Factory rowEncoderFactory = XContentRowEncoder.yamlRowEncoderFactory(rerankFieldsEvaluatorSuppliers);
String inferenceId = BytesRefs.toString(rerank.inferenceId().fold(context.foldCtx));
String queryText = BytesRefs.toString(rerank.queryText().fold(context.foldCtx));
Layout outputLayout = source.layout;
if (source.layout.get(rerank.scoreAttribute().id()) == null) {
outputLayout = source.layout.builder().append(rerank.scoreAttribute()).build();
}
int scoreChannel = outputLayout.get(rerank.scoreAttribute().id()).channel();
return source.with(
new RerankOperator.Factory(inferenceRunner, inferenceId, queryText, rowEncoderFactory, scoreChannel),
outputLayout
);
}
private PhysicalOperation planHashJoin(HashJoinExec join, LocalExecutionPlannerContext context) {
PhysicalOperation source = plan(join.left(), context);
int positionsChannel = source.layout.numberOfChannels();

View File

@ -23,6 +23,7 @@ import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig;
@ -38,6 +39,7 @@ import org.elasticsearch.xpack.esql.plan.physical.MergeExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
import java.util.ArrayList;
import java.util.List;
@ -173,6 +175,18 @@ public class Mapper {
return new TopNExec(topN.source(), mappedChild, topN.order(), topN.limit(), null);
}
if (unary instanceof Rerank rerank) {
mappedChild = addExchangeForFragment(rerank, mappedChild);
return new RerankExec(
rerank.source(),
mappedChild,
rerank.inferenceId(),
rerank.queryText(),
rerank.rerankFields(),
rerank.scoreAttribute()
);
}
//
// Pipeline operators
//

View File

@ -26,6 +26,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
import org.elasticsearch.xpack.esql.plan.logical.show.ShowInfo;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
@ -42,6 +43,7 @@ import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
import org.elasticsearch.xpack.esql.plan.physical.RrfScoreEvalExec;
import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders;
import java.util.List;
@ -86,6 +88,17 @@ class MapperUtils {
return new GrokExec(grok.source(), child, grok.input(), grok.parser(), grok.extractedFields());
}
if (p instanceof Rerank rerank) {
return new RerankExec(
rerank.source(),
child,
rerank.inferenceId(),
rerank.queryText(),
rerank.rerankFields(),
rerank.scoreAttribute()
);
}
if (p instanceof Enrich enrich) {
return new EnrichExec(
enrich.source(),

View File

@ -49,6 +49,7 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.enrich.EnrichLookupService;
import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
import org.elasticsearch.xpack.esql.plan.physical.OutputExec;
@ -125,6 +126,7 @@ public class ComputeService {
private final DriverTaskRunner driverRunner;
private final EnrichLookupService enrichLookupService;
private final LookupFromIndexService lookupFromIndexService;
private final InferenceRunner inferenceRunner;
private final ClusterService clusterService;
private final AtomicLong childSessionIdGenerator = new AtomicLong();
private final DataNodeComputeHandler dataNodeComputeHandler;
@ -133,25 +135,24 @@ public class ComputeService {
@SuppressWarnings("this-escape")
public ComputeService(
SearchService searchService,
TransportService transportService,
ExchangeService exchangeService,
TransportActionServices transportActionServices,
EnrichLookupService enrichLookupService,
LookupFromIndexService lookupFromIndexService,
ClusterService clusterService,
ThreadPool threadPool,
BigArrays bigArrays,
BlockFactory blockFactory
) {
this.searchService = searchService;
this.transportService = transportService;
this.searchService = transportActionServices.searchService();
this.transportService = transportActionServices.transportService();
this.exchangeService = transportActionServices.exchangeService();
this.bigArrays = bigArrays.withCircuitBreaking();
this.blockFactory = blockFactory;
var esqlExecutor = threadPool.executor(ThreadPool.Names.SEARCH);
this.driverRunner = new DriverTaskRunner(transportService, esqlExecutor);
this.enrichLookupService = enrichLookupService;
this.lookupFromIndexService = lookupFromIndexService;
this.clusterService = clusterService;
this.inferenceRunner = transportActionServices.inferenceRunner();
this.clusterService = transportActionServices.clusterService();
this.dataNodeComputeHandler = new DataNodeComputeHandler(this, searchService, transportService, exchangeService, esqlExecutor);
this.clusterComputeHandler = new ClusterComputeHandler(
this,
@ -160,7 +161,6 @@ public class ComputeService {
esqlExecutor,
dataNodeComputeHandler
);
this.exchangeService = exchangeService;
}
public void execute(
@ -428,6 +428,7 @@ public class ComputeService {
context.exchangeSinkSupplier(),
enrichLookupService,
lookupFromIndexService,
inferenceRunner,
new EsPhysicalOperationProviders(context.foldCtx(), contexts, searchService.getIndicesService().getAnalysis()),
contexts
);

View File

@ -13,6 +13,7 @@ import org.elasticsearch.compute.operator.exchange.ExchangeService;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.usage.UsageService;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
public record TransportActionServices(
TransportService transportService,
@ -20,5 +21,6 @@ public record TransportActionServices(
ExchangeService exchangeService,
ClusterService clusterService,
IndexNameExpressionResolver indexNameExpressionResolver,
UsageService usageService
UsageService usageService,
InferenceRunner inferenceRunner
) {}

View File

@ -49,6 +49,7 @@ import org.elasticsearch.xpack.esql.enrich.EnrichLookupService;
import org.elasticsearch.xpack.esql.enrich.EnrichPolicyResolver;
import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
import org.elasticsearch.xpack.esql.execution.PlanExecutor;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.session.Configuration;
import org.elasticsearch.xpack.esql.session.EsqlSession.PlanRunner;
import org.elasticsearch.xpack.esql.session.Result;
@ -126,17 +127,7 @@ public class TransportEsqlQueryAction extends HandledTransportAction<EsqlQueryRe
bigArrays,
blockFactoryProvider.blockFactory()
);
this.computeService = new ComputeService(
searchService,
transportService,
exchangeService,
enrichLookupService,
lookupFromIndexService,
clusterService,
threadPool,
bigArrays,
blockFactoryProvider.blockFactory()
);
this.asyncTaskManagementService = new AsyncTaskManagementService<>(
XPackPlugin.ASYNC_RESULTS_INDEX,
client,
@ -159,8 +150,19 @@ public class TransportEsqlQueryAction extends HandledTransportAction<EsqlQueryRe
exchangeService,
clusterService,
indexNameExpressionResolver,
usageService
usageService,
new InferenceRunner(client)
);
this.computeService = new ComputeService(
services,
enrichLookupService,
lookupFromIndexService,
threadPool,
bigArrays,
blockFactoryProvider.blockFactory()
);
defaultAllowPartialResults = EsqlPlugin.QUERY_ALLOW_PARTIAL_RESULTS.get(clusterService.getSettings());
clusterService.getClusterSettings()
.addSettingsUpdateConsumer(EsqlPlugin.QUERY_ALLOW_PARTIAL_RESULTS, v -> defaultAllowPartialResults = v);

View File

@ -50,6 +50,8 @@ import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.index.IndexResolution;
import org.elasticsearch.xpack.esql.index.MappingException;
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizer;
@ -63,6 +65,7 @@ import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.RegexExtract;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;
@ -119,6 +122,7 @@ public class EsqlSession {
private final PlanTelemetry planTelemetry;
private final IndicesExpressionGrouper indicesExpressionGrouper;
private Set<String> configuredClusters;
private final InferenceRunner inferenceRunner;
public EsqlSession(
String sessionId,
@ -146,6 +150,7 @@ public class EsqlSession {
this.physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(configuration));
this.planTelemetry = planTelemetry;
this.indicesExpressionGrouper = indicesExpressionGrouper;
this.inferenceRunner = services.inferenceRunner();
this.preMapper = new PreMapper(services);
}
@ -335,7 +340,7 @@ public class EsqlSession {
Function<PreAnalysisResult, LogicalPlan> analyzeAction = (l) -> {
Analyzer analyzer = new Analyzer(
new AnalyzerContext(configuration, functionRegistry, l.indices, l.lookupIndices, l.enrichResolution),
new AnalyzerContext(configuration, functionRegistry, l.indices, l.lookupIndices, l.enrichResolution, l.inferenceResolution),
verifier
);
LogicalPlan plan = analyzer.analyze(parsed);
@ -367,7 +372,9 @@ public class EsqlSession {
var listener = SubscribableListener.<EnrichResolution>newForked(
l -> enrichPolicyResolver.resolvePolicies(targetClusters, unresolvedPolicies, l)
).<PreAnalysisResult>andThen((l, enrichResolution) -> resolveFieldNames(parsed, enrichResolution, l));
)
.<PreAnalysisResult>andThen((l, enrichResolution) -> resolveFieldNames(parsed, enrichResolution, l))
.<PreAnalysisResult>andThen((l, preAnalysisResult) -> resolveInferences(preAnalysis.inferencePlans, preAnalysisResult, l));
// first resolve the lookup indices, then the main indices
for (var index : preAnalysis.lookupIndices) {
listener = listener.andThen((l, preAnalysisResult) -> { preAnalyzeLookupIndex(index, preAnalysisResult, l); });
@ -580,6 +587,14 @@ public class EsqlSession {
}
}
private void resolveInferences(
List<InferencePlan> inferencePlans,
PreAnalysisResult preAnalysisResult,
ActionListener<PreAnalysisResult> l
) {
inferenceRunner.resolveInferenceIds(inferencePlans, l.map(preAnalysisResult::withInferenceResolution));
}
static PreAnalysisResult fieldNames(LogicalPlan parsed, Set<String> enrichPolicyMatchFields, PreAnalysisResult result) {
if (false == parsed.anyMatch(plan -> plan instanceof Aggregate || plan instanceof Project)) {
// no explicit columns selection, for example "from employees"
@ -746,18 +761,44 @@ public class EsqlSession {
Map<String, IndexResolution> lookupIndices,
EnrichResolution enrichResolution,
Set<String> fieldNames,
Set<String> wildcardJoinIndices
Set<String> wildcardJoinIndices,
InferenceResolution inferenceResolution
) {
PreAnalysisResult(EnrichResolution newEnrichResolution) {
this(null, new HashMap<>(), newEnrichResolution, Set.of(), Set.of());
this(null, new HashMap<>(), newEnrichResolution, Set.of(), Set.of(), InferenceResolution.EMPTY);
}
PreAnalysisResult withEnrichResolution(EnrichResolution newEnrichResolution) {
return new PreAnalysisResult(indices(), lookupIndices(), newEnrichResolution, fieldNames(), wildcardJoinIndices());
return new PreAnalysisResult(
indices(),
lookupIndices(),
newEnrichResolution,
fieldNames(),
wildcardJoinIndices(),
inferenceResolution()
);
}
PreAnalysisResult withInferenceResolution(InferenceResolution newInferenceResolution) {
return new PreAnalysisResult(
indices(),
lookupIndices(),
enrichResolution(),
fieldNames(),
wildcardJoinIndices(),
newInferenceResolution
);
}
PreAnalysisResult withIndexResolution(IndexResolution newIndexResolution) {
return new PreAnalysisResult(newIndexResolution, lookupIndices(), enrichResolution(), fieldNames(), wildcardJoinIndices());
return new PreAnalysisResult(
newIndexResolution,
lookupIndices(),
enrichResolution(),
fieldNames(),
wildcardJoinIndices(),
inferenceResolution()
);
}
PreAnalysisResult addLookupIndexResolution(String index, IndexResolution newIndexResolution) {
@ -766,11 +807,25 @@ public class EsqlSession {
}
PreAnalysisResult withFieldNames(Set<String> newFields) {
return new PreAnalysisResult(indices(), lookupIndices(), enrichResolution(), newFields, wildcardJoinIndices());
return new PreAnalysisResult(
indices(),
lookupIndices(),
enrichResolution(),
newFields,
wildcardJoinIndices(),
inferenceResolution()
);
}
public PreAnalysisResult withWildcardJoinIndices(Set<String> wildcardJoinIndices) {
return new PreAnalysisResult(indices(), lookupIndices(), enrichResolution(), fieldNames(), wildcardJoinIndices);
return new PreAnalysisResult(
indices(),
lookupIndices(),
enrichResolution(),
fieldNames(),
wildcardJoinIndices,
inferenceResolution()
);
}
}
}

View File

@ -66,6 +66,7 @@ import org.elasticsearch.xpack.esql.enrich.ResolvedEnrichPolicy;
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.index.IndexResolution;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
@ -119,6 +120,7 @@ import static org.elasticsearch.xpack.esql.CsvTestUtils.loadPageFromCsv;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.CSV_DATASET_MAP;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.classpathResources;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.everyItem;
@ -261,6 +263,10 @@ public class CsvTests extends ESTestCase {
"enrich can't load fields in csv tests",
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.ENRICH_LOAD.capabilityName())
);
assumeFalse(
"can't use rereank in csv tests",
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.RERANK.capabilityName())
);
assumeFalse(
"can't use match in csv tests",
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.MATCH_OPERATOR_COLON.capabilityName())
@ -478,7 +484,10 @@ public class CsvTests extends ESTestCase {
private LogicalPlan analyzedPlan(LogicalPlan parsed, CsvTestsDataLoader.MultiIndexTestDataset datasets) {
var indexResolution = loadIndexResolution(datasets);
var enrichPolicies = loadEnrichPolicies();
var analyzer = new Analyzer(new AnalyzerContext(configuration, functionRegistry, indexResolution, enrichPolicies), TEST_VERIFIER);
var analyzer = new Analyzer(
new AnalyzerContext(configuration, functionRegistry, indexResolution, enrichPolicies, emptyInferenceResolution()),
TEST_VERIFIER
);
LogicalPlan plan = analyzer.analyze(parsed);
plan.setAnalyzed();
LOGGER.debug("Analyzed plan:\n{}", plan);
@ -666,6 +675,7 @@ public class CsvTests extends ESTestCase {
() -> exchangeSink.createExchangeSink(() -> {}),
Mockito.mock(EnrichLookupService.class),
Mockito.mock(LookupFromIndexService.class),
Mockito.mock(InferenceRunner.class),
physicalOperationProviders,
List.of()
);

View File

@ -8,12 +8,15 @@
package org.elasticsearch.xpack.esql.analysis;
import org.elasticsearch.index.IndexMode;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.enrich.EnrichPolicy;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.enrich.ResolvedEnrichPolicy;
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.index.IndexResolution;
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
import org.elasticsearch.xpack.esql.inference.ResolvedInference;
import org.elasticsearch.xpack.esql.parser.EsqlParser;
import org.elasticsearch.xpack.esql.parser.QueryParams;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
@ -29,6 +32,7 @@ import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.MATCH_TYPE;
import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.RANGE_TYPE;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
public final class AnalyzerTestUtils {
@ -57,7 +61,8 @@ public final class AnalyzerTestUtils {
new EsqlFunctionRegistry(),
indexResolution,
defaultLookupResolution(),
defaultEnrichResolution()
defaultEnrichResolution(),
emptyInferenceResolution()
),
verifier
);
@ -70,7 +75,8 @@ public final class AnalyzerTestUtils {
new EsqlFunctionRegistry(),
indexResolution,
lookupResolution,
defaultEnrichResolution()
defaultEnrichResolution(),
defaultInferenceResolution()
),
verifier
);
@ -78,7 +84,14 @@ public final class AnalyzerTestUtils {
public static Analyzer analyzer(IndexResolution indexResolution, Verifier verifier, Configuration config) {
return new Analyzer(
new AnalyzerContext(config, new EsqlFunctionRegistry(), indexResolution, defaultLookupResolution(), defaultEnrichResolution()),
new AnalyzerContext(
config,
new EsqlFunctionRegistry(),
indexResolution,
defaultLookupResolution(),
defaultEnrichResolution(),
defaultInferenceResolution()
),
verifier
);
}
@ -90,7 +103,8 @@ public final class AnalyzerTestUtils {
new EsqlFunctionRegistry(),
analyzerDefaultMapping(),
defaultLookupResolution(),
defaultEnrichResolution()
defaultEnrichResolution(),
defaultInferenceResolution()
),
verifier
);
@ -162,6 +176,14 @@ public final class AnalyzerTestUtils {
return enrichResolution;
}
public static InferenceResolution defaultInferenceResolution() {
return InferenceResolution.builder()
.withResolvedInference(new ResolvedInference("reranking-inference-id", TaskType.RERANK))
.withResolvedInference(new ResolvedInference("completion-inference-id", TaskType.COMPLETION))
.withError("error-inference-id", "error with inference resolution")
.build();
}
public static void loadEnrichPolicyResolution(
EnrichResolution enrich,
String policyType,

View File

@ -45,6 +45,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchOperator;
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.esql.index.EsIndex;
@ -67,6 +68,7 @@ import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.Row;
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
import org.elasticsearch.xpack.esql.session.IndexResolver;
@ -86,6 +88,7 @@ import static org.elasticsearch.test.MapMatcher.assertMap;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsConstant;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsIdentifier;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsPattern;
@ -108,6 +111,7 @@ import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.matchesRegex;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.startsWith;
//@TestLogging(value = "org.elasticsearch.xpack.esql.analysis:TRACE", reason = "debug")
@ -1626,7 +1630,13 @@ public class AnalyzerTests extends ESTestCase {
enrichResolution.addError("languages", Enrich.Mode.ANY, "error-2");
enrichResolution.addError("foo", Enrich.Mode.ANY, "foo-error-101");
AnalyzerContext context = new AnalyzerContext(configuration("from test"), new EsqlFunctionRegistry(), testIndex, enrichResolution);
AnalyzerContext context = new AnalyzerContext(
configuration("from test"),
new EsqlFunctionRegistry(),
testIndex,
enrichResolution,
emptyInferenceResolution()
);
Analyzer analyzer = new Analyzer(context, TEST_VERIFIER);
{
LogicalPlan plan = analyze("from test | EVAL x = to_string(languages) | ENRICH _coordinator:languages ON x", analyzer);
@ -1776,7 +1786,13 @@ public class AnalyzerTests extends ESTestCase {
languageIndex.get().mapping()
)
);
AnalyzerContext context = new AnalyzerContext(configuration(query), new EsqlFunctionRegistry(), testIndex, enrichResolution);
AnalyzerContext context = new AnalyzerContext(
configuration(query),
new EsqlFunctionRegistry(),
testIndex,
enrichResolution,
emptyInferenceResolution()
);
Analyzer analyzer = new Analyzer(context, TEST_VERIFIER);
LogicalPlan plan = analyze(query, analyzer);
var limit = as(plan, Limit.class);
@ -2172,7 +2188,8 @@ public class AnalyzerTests extends ESTestCase {
new EsqlFunctionRegistry(),
analyzerDefaultMapping(),
Map.of("foobar", missingLookupIndex),
defaultEnrichResolution()
defaultEnrichResolution(),
emptyInferenceResolution()
),
TEST_VERIFIER
);
@ -3267,6 +3284,193 @@ public class AnalyzerTests extends ESTestCase {
assertThat(esRelation.output(), equalTo(NO_FIELDS));
}
public void testResolveRerankInferenceId() {
assumeTrue("Requires RERANK command", EsqlCapabilities.Cap.RERANK.isEnabled());
{
LogicalPlan plan = analyze(
" FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `reranking-inference-id`",
"mapping-books.json"
);
Rerank rerank = as(as(plan, Limit.class).child(), Rerank.class);
assertThat(rerank.inferenceId(), equalTo(string("reranking-inference-id")));
}
{
VerificationException ve = expectThrows(
VerificationException.class,
() -> analyze(
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `completion-inference-id`",
"mapping-books.json"
)
);
assertThat(
ve.getMessage(),
containsString(
"cannot use inference endpoint [completion-inference-id] with task type [completion] within a Rerank command. "
+ "Only inference endpoints with the task type [rerank] are supported"
)
);
}
{
VerificationException ve = expectThrows(
VerificationException.class,
() -> analyze(
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `error-inference-id`",
"mapping-books.json"
)
);
assertThat(ve.getMessage(), containsString("error with inference resolution"));
}
{
VerificationException ve = expectThrows(
VerificationException.class,
() -> analyze(
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `unknown-inference-id`",
"mapping-books.json"
)
);
assertThat(ve.getMessage(), containsString("unresolved inference [unknown-inference-id]"));
}
}
public void testResolveRerankFields() {
assumeTrue("Requires RERANK command", EsqlCapabilities.Cap.RERANK.isEnabled());
{
// Single field.
LogicalPlan plan = analyze("""
FROM books METADATA _score
| WHERE title:"italian food recipe" OR description:"italian food recipe"
| KEEP description, title, year, _score
| DROP description
| RERANK "italian food recipe" ON title WITH `reranking-inference-id`
""", "mapping-books.json");
Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
Rerank rerank = as(limit.child(), Rerank.class);
EsqlProject keep = as(rerank.child(), EsqlProject.class);
EsqlProject drop = as(keep.child(), EsqlProject.class);
Filter filter = as(drop.child(), Filter.class);
EsRelation relation = as(filter.child(), EsRelation.class);
Attribute titleAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("title")).findFirst().get();
assertThat(titleAttribute, notNullValue());
assertThat(rerank.queryText(), equalTo(string("italian food recipe")));
assertThat(rerank.inferenceId(), equalTo(string("reranking-inference-id")));
assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", titleAttribute))));
assertThat(
rerank.scoreAttribute(),
equalTo(relation.output().stream().filter(attr -> attr.name().equals(MetadataAttribute.SCORE)).findFirst().get())
);
}
{
// Multiple fields.
LogicalPlan plan = analyze("""
FROM books METADATA _score
| WHERE title:"food"
| RERANK "food" ON title, description=SUBSTRING(description, 0, 100), yearRenamed=year WITH `reranking-inference-id`
""", "mapping-books.json");
Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
Rerank rerank = as(limit.child(), Rerank.class);
Filter filter = as(rerank.child(), Filter.class);
EsRelation relation = as(filter.child(), EsRelation.class);
assertThat(rerank.queryText(), equalTo(string("food")));
assertThat(rerank.inferenceId(), equalTo(string("reranking-inference-id")));
assertThat(rerank.rerankFields(), hasSize(3));
Attribute titleAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("title")).findFirst().get();
assertThat(titleAttribute, notNullValue());
assertThat(rerank.rerankFields().get(0), equalTo(alias("title", titleAttribute)));
Attribute descriptionAttribute = relation.output()
.stream()
.filter(attribute -> attribute.name().equals("description"))
.findFirst()
.get();
assertThat(descriptionAttribute, notNullValue());
Alias descriptionAlias = rerank.rerankFields().get(1);
assertThat(descriptionAlias.name(), equalTo("description"));
assertThat(
as(descriptionAlias.child(), Substring.class).children(),
equalTo(List.of(descriptionAttribute, literal(0), literal(100)))
);
Attribute yearAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("year")).findFirst().get();
assertThat(yearAttribute, notNullValue());
assertThat(rerank.rerankFields().get(2), equalTo(alias("yearRenamed", yearAttribute)));
assertThat(
rerank.scoreAttribute(),
equalTo(relation.output().stream().filter(attr -> attr.name().equals(MetadataAttribute.SCORE)).findFirst().get())
);
}
{
VerificationException ve = expectThrows(
VerificationException.class,
() -> analyze(
"FROM books METADATA _score | RERANK \"italian food recipe\" ON missingField WITH `reranking-inference-id`",
"mapping-books.json"
)
);
assertThat(ve.getMessage(), containsString("Unknown column [missingField]"));
}
}
public void testResolveRerankScoreField() {
assumeTrue("Requires RERANK command", EsqlCapabilities.Cap.RERANK.isEnabled());
{
// When the metadata field is required in FROM, it is reused.
LogicalPlan plan = analyze("""
FROM books METADATA _score
| WHERE title:"italian food recipe" OR description:"italian food recipe"
| RERANK "italian food recipe" ON title WITH `reranking-inference-id`
""", "mapping-books.json");
Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
Rerank rerank = as(limit.child(), Rerank.class);
Filter filter = as(rerank.child(), Filter.class);
EsRelation relation = as(filter.child(), EsRelation.class);
Attribute metadataScoreAttribute = relation.output()
.stream()
.filter(attr -> attr.name().equals(MetadataAttribute.SCORE))
.findFirst()
.get();
assertThat(rerank.scoreAttribute(), equalTo(metadataScoreAttribute));
assertThat(rerank.output(), hasItem(metadataScoreAttribute));
}
{
// When the metadata field is not required in FROM, it is added to the output of RERANK
LogicalPlan plan = analyze("""
FROM books
| WHERE title:"italian food recipe" OR description:"italian food recipe"
| RERANK "italian food recipe" ON title WITH `reranking-inference-id`
""", "mapping-books.json");
Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
Rerank rerank = as(limit.child(), Rerank.class);
Filter filter = as(rerank.child(), Filter.class);
EsRelation relation = as(filter.child(), EsRelation.class);
assertThat(relation.output().stream().noneMatch(attr -> attr.name().equals(MetadataAttribute.SCORE)), is(true));
assertThat(rerank.scoreAttribute(), equalTo(MetadataAttribute.create(EMPTY, MetadataAttribute.SCORE)));
assertThat(rerank.output(), hasItem(rerank.scoreAttribute()));
}
}
@Override
protected IndexAnalyzers createDefaultIndexAnalyzers() {
return super.createDefaultIndexAnalyzers();

View File

@ -35,6 +35,7 @@ import java.util.List;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@ -45,7 +46,7 @@ public class ParsingTests extends ESTestCase {
private final IndexResolution defaultIndex = loadIndexResolution("mapping-basic.json");
private final Analyzer defaultAnalyzer = new Analyzer(
new AnalyzerContext(TEST_CFG, new EsqlFunctionRegistry(), defaultIndex, emptyPolicyResolution()),
new AnalyzerContext(TEST_CFG, new EsqlFunctionRegistry(), defaultIndex, emptyPolicyResolution(), emptyInferenceResolution()),
TEST_VERIFIER
);

View File

@ -34,6 +34,7 @@ import org.elasticsearch.xpack.esql.telemetry.Metrics;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzerDefaultMapping;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultEnrichResolution;
import static org.hamcrest.Matchers.containsString;
@ -90,7 +91,13 @@ public class CheckLicenseTests extends ESTestCase {
private static Analyzer analyzer(EsqlFunctionRegistry registry, License.OperationMode operationMode) {
return new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, registry, analyzerDefaultMapping(), defaultEnrichResolution()),
new AnalyzerContext(
EsqlTestUtils.TEST_CFG,
registry,
analyzerDefaultMapping(),
defaultEnrichResolution(),
emptyInferenceResolution()
),
new Verifier(new Metrics(new EsqlFunctionRegistry()), getLicenseState(operationMode))
);
}

View File

@ -0,0 +1,146 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.inference;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import java.util.List;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class InferenceRunnerTests extends ESTestCase {
public void testResolveInferenceIds() throws Exception {
InferenceRunner inferenceRunner = new InferenceRunner(mockClient());
List<InferencePlan> inferencePlans = List.of(mockInferencePlan("rerank-plan"));
SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();
inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
throw new RuntimeException(e);
}));
assertBusy(() -> {
InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get();
assertNotNull(inferenceResolution);
assertThat(inferenceResolution.resolvedInferences(), contains(new ResolvedInference("rerank-plan", TaskType.RERANK)));
assertThat(inferenceResolution.hasError(), equalTo(false));
});
}
public void testResolveMultipleInferenceIds() throws Exception {
InferenceRunner inferenceRunner = new InferenceRunner(mockClient());
List<InferencePlan> inferencePlans = List.of(
mockInferencePlan("rerank-plan"),
mockInferencePlan("rerank-plan"),
mockInferencePlan("completion-plan")
);
SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();
inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
throw new RuntimeException(e);
}));
assertBusy(() -> {
InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get();
assertNotNull(inferenceResolution);
assertThat(
inferenceResolution.resolvedInferences(),
contains(
new ResolvedInference("rerank-plan", TaskType.RERANK),
new ResolvedInference("completion-plan", TaskType.COMPLETION)
)
);
assertThat(inferenceResolution.hasError(), equalTo(false));
});
}
public void testResolveMissingInferenceIds() throws Exception {
InferenceRunner inferenceRunner = new InferenceRunner(mockClient());
List<InferencePlan> inferencePlans = List.of(mockInferencePlan("missing-plan"));
SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();
inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
throw new RuntimeException(e);
}));
assertBusy(() -> {
InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get();
assertNotNull(inferenceResolution);
assertThat(inferenceResolution.resolvedInferences(), empty());
assertThat(inferenceResolution.hasError(), equalTo(true));
assertThat(inferenceResolution.getError("missing-plan"), equalTo("inference endpoint not found"));
});
}
@SuppressWarnings({ "unchecked", "raw-types" })
private static Client mockClient() {
Client client = mock(Client.class);
doAnswer(i -> {
GetInferenceModelAction.Request request = i.getArgument(1, GetInferenceModelAction.Request.class);
ActionListener<ActionResponse> listener = (ActionListener<ActionResponse>) i.getArgument(2, ActionListener.class);
ActionResponse response = getInferenceModelResponse(request);
if (response == null) {
listener.onFailure(new ResourceNotFoundException("inference endpoint not found"));
} else {
listener.onResponse(response);
}
return null;
}).when(client).execute(eq(GetInferenceModelAction.INSTANCE), any(), any());
return client;
}
private static ActionResponse getInferenceModelResponse(GetInferenceModelAction.Request request) {
GetInferenceModelAction.Response response = mock(GetInferenceModelAction.Response.class);
if (request.getInferenceEntityId().equals("rerank-plan")) {
when(response.getEndpoints()).thenReturn(List.of(mockModelConfig("rerank-plan", TaskType.RERANK)));
return response;
}
if (request.getInferenceEntityId().equals("completion-plan")) {
when(response.getEndpoints()).thenReturn(List.of(mockModelConfig("completion-plan", TaskType.COMPLETION)));
return response;
}
return null;
}
private static ModelConfigurations mockModelConfig(String inferenceId, TaskType taskType) {
return new ModelConfigurations(inferenceId, taskType, randomIdentifier(), mock(ServiceSettings.class));
}
private static InferencePlan mockInferencePlan(String inferenceId) {
InferencePlan plan = mock(InferencePlan.class);
when(plan.inferenceId()).thenReturn(new Literal(Source.EMPTY, inferenceId, DataType.KEYWORD));
return plan;
}
}

View File

@ -0,0 +1,297 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.inference;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.FloatBlock;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AsyncOperator;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.SourceOperator;
import org.elasticsearch.compute.test.AbstractBlockSourceOperator;
import org.elasticsearch.compute.test.OperatorTestCase;
import org.elasticsearch.compute.test.RandomBlock;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.hamcrest.Matcher;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.notNullValue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class RerankOperatorTests extends OperatorTestCase {
private static final String ESQL_TEST_EXECUTOR = "esql_test_executor";
private static final String SIMPLE_INFERENCE_ID = "test_reranker";
private static final String SIMPLE_QUERY = "query text";
private ThreadPool threadPool;
private List<ElementType> inputChannelElementTypes;
private XContentRowEncoder.Factory rowEncoderFactory;
private int scoreChannel;
@Before
private void initChannels() {
int channelCount = randomIntBetween(2, 10);
scoreChannel = randomIntBetween(0, channelCount - 1);
inputChannelElementTypes = IntStream.range(0, channelCount).sorted().mapToObj(this::randomElementType).collect(Collectors.toList());
rowEncoderFactory = mockRowEncoderFactory();
}
@Before
public void setThreadPool() {
int numThreads = randomBoolean() ? 1 : between(2, 16);
threadPool = new TestThreadPool(
"test",
new FixedExecutorBuilder(Settings.EMPTY, ESQL_TEST_EXECUTOR, numThreads, 1024, "esql", EsExecutors.TaskTrackingConfig.DEFAULT)
);
}
@After
public void shutdownThreadPool() {
terminate(threadPool);
}
@Override
protected Operator.OperatorFactory simple() {
InferenceRunner inferenceRunner = mockedSimpleInferenceRunner();
return new RerankOperator.Factory(inferenceRunner, SIMPLE_INFERENCE_ID, SIMPLE_QUERY, rowEncoderFactory, scoreChannel);
}
private InferenceRunner mockedSimpleInferenceRunner() {
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
when(inferenceRunner.getThreadContext()).thenReturn(threadPool.getThreadContext());
doAnswer(invocation -> {
@SuppressWarnings("unchecked")
ActionListener<InferenceAction.Response> listener = (ActionListener<InferenceAction.Response>) invocation.getArgument(
1,
ActionListener.class
);
InferenceAction.Response inferenceResponse = mock(InferenceAction.Response.class);
when(inferenceResponse.getResults()).thenReturn(
mockedRankedDocResults(invocation.getArgument(0, InferenceAction.Request.class))
);
listener.onResponse(inferenceResponse);
return null;
}).when(inferenceRunner).doInference(any(), any());
return inferenceRunner;
}
private RankedDocsResults mockedRankedDocResults(InferenceAction.Request request) {
List<RankedDocsResults.RankedDoc> rankedDocs = new ArrayList<>();
for (int rank = 0; rank < request.getInput().size(); rank++) {
if (rank % 10 != 0) {
rankedDocs.add(new RankedDocsResults.RankedDoc(rank, 1f / rank, request.getInput().get(rank)));
}
}
return new RankedDocsResults(rankedDocs);
}
@Override
protected Matcher<String> expectedDescriptionOfSimple() {
return expectedToStringOfSimple();
}
@Override
protected Matcher<String> expectedToStringOfSimple() {
return equalTo(
"RerankOperator[inference_id=[" + SIMPLE_INFERENCE_ID + "], query=[" + SIMPLE_QUERY + "], score_channel=[" + scoreChannel + "]]"
);
}
@Override
protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
return new AbstractBlockSourceOperator(blockFactory, 8 * 1024) {
@Override
protected int remaining() {
return size - currentPosition;
}
@Override
protected Page createPage(int positionOffset, int length) {
Block[] blocks = new Block[inputChannelElementTypes.size()];
try {
currentPosition += length;
for (int b = 0; b < inputChannelElementTypes.size(); b++) {
blocks[b] = RandomBlock.randomBlock(
blockFactory,
inputChannelElementTypes.get(b),
length,
randomBoolean(),
0,
10,
0,
10
).block();
}
return new Page(blocks);
} catch (Exception e) {
Releasables.closeExpectNoException(blocks);
throw (e);
}
}
};
}
/**
* Ensures that the Operator.Status of this operator has the standard fields.
*/
public void testOperatorStatus() throws IOException {
DriverContext driverContext = driverContext();
try (var operator = simple().get(driverContext)) {
AsyncOperator.Status status = asInstanceOf(AsyncOperator.Status.class, operator.status());
assertThat(status, notNullValue());
assertThat(status.receivedPages(), equalTo(0L));
assertThat(status.completedPages(), equalTo(0L));
assertThat(status.procesNanos(), greaterThanOrEqualTo(0L));
}
}
@Override
protected void assertSimpleOutput(List<Page> inputPages, List<Page> resultPages) {
assertThat(inputPages, hasSize(resultPages.size()));
for (int pageId = 0; pageId < inputPages.size(); pageId++) {
Page inputPage = inputPages.get(pageId);
Page resultPage = resultPages.get(pageId);
// Check all rows are present and the output shape is unchanged.
assertThat(inputPage.getPositionCount(), equalTo(resultPage.getPositionCount()));
assertThat(inputPage.getBlockCount(), equalTo(resultPage.getBlockCount()));
BytesRef readBuffer = new BytesRef();
for (int channel = 0; channel < inputPage.getBlockCount(); channel++) {
Block inputBlock = inputPage.getBlock(channel);
Block resultBlock = resultPage.getBlock(channel);
assertThat(resultBlock.getPositionCount(), equalTo(resultPage.getPositionCount()));
assertThat(resultBlock.elementType(), equalTo(inputBlock.elementType()));
if (channel == scoreChannel) {
assertExpectedScore(asInstanceOf(DoubleBlock.class, resultBlock));
} else {
switch (inputBlock.elementType()) {
case BOOLEAN -> assertBlockContentEquals(inputBlock, resultBlock, BooleanBlock::getBoolean, BooleanBlock.class);
case INT -> assertBlockContentEquals(inputBlock, resultBlock, IntBlock::getInt, IntBlock.class);
case LONG -> assertBlockContentEquals(inputBlock, resultBlock, LongBlock::getLong, LongBlock.class);
case FLOAT -> assertBlockContentEquals(inputBlock, resultBlock, FloatBlock::getFloat, FloatBlock.class);
case DOUBLE -> assertBlockContentEquals(inputBlock, resultBlock, DoubleBlock::getDouble, DoubleBlock.class);
case BYTES_REF -> assertByteRefsBlockContentEquals(inputBlock, resultBlock, readBuffer);
default -> throw new AssertionError(
LoggerMessageFormat.format("Unexpected block type {}", inputBlock.elementType())
);
}
}
}
}
}
private int inputChannelCount() {
return inputChannelElementTypes.size();
}
private ElementType randomElementType(int channel) {
return channel == scoreChannel ? ElementType.DOUBLE : randomFrom(ElementType.FLOAT, ElementType.DOUBLE, ElementType.LONG);
}
private XContentRowEncoder.Factory mockRowEncoderFactory() {
XContentRowEncoder.Factory factory = mock(XContentRowEncoder.Factory.class);
doAnswer(factoryInvocation -> {
DriverContext driverContext = factoryInvocation.getArgument(0, DriverContext.class);
XContentRowEncoder rowEncoder = mock(XContentRowEncoder.class);
doAnswer(encoderInvocation -> {
Page inputPage = encoderInvocation.getArgument(0, Page.class);
return driverContext.blockFactory()
.newConstantBytesRefBlockWith(new BytesRef(randomRealisticUnicodeOfCodepointLength(4)), inputPage.getPositionCount());
}).when(rowEncoder).eval(any(Page.class));
return rowEncoder;
}).when(factory).get(any(DriverContext.class));
return factory;
}
private void assertExpectedScore(DoubleBlock scoreBlockResult) {
assertAllPositions(scoreBlockResult, (pos) -> {
if (pos % 10 == 0) {
assertThat(scoreBlockResult.isNull(pos), equalTo(true));
} else {
assertThat(scoreBlockResult.getValueCount(pos), equalTo(1));
assertThat(scoreBlockResult.getDouble(scoreBlockResult.getFirstValueIndex(pos)), equalTo((double) (1f / pos)));
}
});
}
<V extends Block, U> void assertBlockContentEquals(
Block input,
Block result,
BiFunction<V, Integer, U> valueReader,
Class<V> blockClass
) {
V inputBlock = asInstanceOf(blockClass, input);
V resultBlock = asInstanceOf(blockClass, result);
assertAllPositions(inputBlock, (pos) -> {
if (inputBlock.isNull(pos)) {
assertThat(resultBlock.isNull(pos), equalTo(inputBlock.isNull(pos)));
} else {
assertThat(resultBlock.getValueCount(pos), equalTo(inputBlock.getValueCount(pos)));
assertThat(resultBlock.getFirstValueIndex(pos), equalTo(inputBlock.getFirstValueIndex(pos)));
for (int i = 0; i < inputBlock.getValueCount(pos); i++) {
assertThat(
valueReader.apply(resultBlock, resultBlock.getFirstValueIndex(pos) + i),
equalTo(valueReader.apply(inputBlock, inputBlock.getFirstValueIndex(pos) + i))
);
}
}
});
}
private void assertAllPositions(Block block, Consumer<Integer> consumer) {
for (int pos = 0; pos < block.getPositionCount(); pos++) {
consumer.accept(pos);
}
}
private <V extends Block, U> void assertByteRefsBlockContentEquals(Block input, Block result, BytesRef readBuffer) {
assertBlockContentEquals(input, result, (BytesRefBlock b, Integer pos) -> b.getBytesRef(pos, readBuffer), BytesRefBlock.class);
}
}

View File

@ -0,0 +1,41 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.inference;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.AbstractWireTestCase;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
public class ResolvedInferenceTests extends AbstractWireTestCase<ResolvedInference> {
@Override
protected ResolvedInference createTestInstance() {
return new ResolvedInference(randomIdentifier(), randomTaskType());
}
@Override
protected ResolvedInference mutateInstance(ResolvedInference instance) throws IOException {
if (randomBoolean()) {
return new ResolvedInference(randomValueOtherThan(instance.inferenceId(), ESTestCase::randomIdentifier), instance.taskType());
}
return new ResolvedInference(instance.inferenceId(), randomValueOtherThan(instance.taskType(), this::randomTaskType));
}
@Override
protected ResolvedInference copyInstance(ResolvedInference instance, TransportVersion version) throws IOException {
return copyInstance(instance, getNamedWriteableRegistry(), (out, v) -> v.writeTo(out), in -> new ResolvedInference(in), version);
}
private TaskType randomTaskType() {
return randomFrom(TaskType.values());
}
}

View File

@ -66,6 +66,8 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.THREE;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TWO;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.asLimit;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
@ -98,7 +100,13 @@ public class LocalLogicalPlanOptimizerTests extends ESTestCase {
logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext());
analyzer = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, EsqlTestUtils.emptyPolicyResolution()),
new AnalyzerContext(
EsqlTestUtils.TEST_CFG,
new EsqlFunctionRegistry(),
getIndexResult,
emptyPolicyResolution(),
emptyInferenceResolution()
),
TEST_VERIFIER
);
}
@ -449,7 +457,13 @@ public class LocalLogicalPlanOptimizerTests extends ESTestCase {
var logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext());
var analyzer = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, EsqlTestUtils.emptyPolicyResolution()),
new AnalyzerContext(
EsqlTestUtils.TEST_CFG,
new EsqlFunctionRegistry(),
getIndexResult,
emptyPolicyResolution(),
emptyInferenceResolution()
),
TEST_VERIFIER
);

View File

@ -101,6 +101,7 @@ import static org.elasticsearch.index.query.QueryBuilders.termQuery;
import static org.elasticsearch.index.query.QueryBuilders.termsQuery;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
@ -184,7 +185,7 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
IndexResolution getIndexResult = IndexResolution.valid(test);
return new Analyzer(
new AnalyzerContext(config, new EsqlFunctionRegistry(), getIndexResult, enrichResolution),
new AnalyzerContext(config, new EsqlFunctionRegistry(), getIndexResult, enrichResolution, emptyInferenceResolution()),
new Verifier(new Metrics(new EsqlFunctionRegistry()), new XPackLicenseState(() -> 0L))
);
}

View File

@ -155,6 +155,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.THREE;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TWO;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.asLimit;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptySource;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.fieldAttribute;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute;
@ -248,7 +249,8 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
new EsqlFunctionRegistry(),
getIndexResult,
defaultLookupResolution(),
enrichResolution
enrichResolution,
emptyInferenceResolution()
),
TEST_VERIFIER
);
@ -258,7 +260,13 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
EsIndex airports = new EsIndex("airports", mappingAirports, Map.of("airports", IndexMode.STANDARD));
IndexResolution getIndexResultAirports = IndexResolution.valid(airports);
analyzerAirports = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultAirports, enrichResolution),
new AnalyzerContext(
EsqlTestUtils.TEST_CFG,
new EsqlFunctionRegistry(),
getIndexResultAirports,
enrichResolution,
emptyInferenceResolution()
),
TEST_VERIFIER
);
@ -267,7 +275,13 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
EsIndex types = new EsIndex("types", mappingTypes, Map.of("types", IndexMode.STANDARD));
IndexResolution getIndexResultTypes = IndexResolution.valid(types);
analyzerTypes = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultTypes, enrichResolution),
new AnalyzerContext(
EsqlTestUtils.TEST_CFG,
new EsqlFunctionRegistry(),
getIndexResultTypes,
enrichResolution,
emptyInferenceResolution()
),
TEST_VERIFIER
);
@ -276,14 +290,26 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
EsIndex extra = new EsIndex("extra", mappingExtra, Map.of("extra", IndexMode.STANDARD));
IndexResolution getIndexResultExtra = IndexResolution.valid(extra);
analyzerExtra = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultExtra, enrichResolution),
new AnalyzerContext(
EsqlTestUtils.TEST_CFG,
new EsqlFunctionRegistry(),
getIndexResultExtra,
enrichResolution,
emptyInferenceResolution()
),
TEST_VERIFIER
);
metricMapping = loadMapping("k8s-mappings.json");
var metricsIndex = IndexResolution.valid(new EsIndex("k8s", metricMapping, Map.of("k8s", IndexMode.TIME_SERIES)));
metricsAnalyzer = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), metricsIndex, enrichResolution),
new AnalyzerContext(
EsqlTestUtils.TEST_CFG,
new EsqlFunctionRegistry(),
metricsIndex,
enrichResolution,
emptyInferenceResolution()
),
TEST_VERIFIER
);
@ -298,7 +324,13 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
)
);
multiIndexAnalyzer = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), multiIndex, enrichResolution),
new AnalyzerContext(
EsqlTestUtils.TEST_CFG,
new EsqlFunctionRegistry(),
multiIndex,
enrichResolution,
emptyInferenceResolution()
),
TEST_VERIFIER
);
}
@ -5268,7 +5300,13 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
EsIndex empty = new EsIndex("empty_test", emptyMap(), Map.of());
IndexResolution getIndexResultAirports = IndexResolution.valid(empty);
var analyzer = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultAirports, enrichResolution),
new AnalyzerContext(
EsqlTestUtils.TEST_CFG,
new EsqlFunctionRegistry(),
getIndexResultAirports,
enrichResolution,
emptyInferenceResolution()
),
TEST_VERIFIER
);

View File

@ -164,6 +164,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_SEARCH_STATS;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.statsForMissingField;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
@ -356,7 +357,7 @@ public class PhysicalPlanOptimizerTests extends ESTestCase {
EsIndex index = new EsIndex(indexName, mapping, Map.of("test", IndexMode.STANDARD));
IndexResolution getIndexResult = IndexResolution.valid(index);
Analyzer analyzer = new Analyzer(
new AnalyzerContext(config, functionRegistry, getIndexResult, lookupResolution, enrichResolution),
new AnalyzerContext(config, functionRegistry, getIndexResult, lookupResolution, enrichResolution, emptyInferenceResolution()),
TEST_VERIFIER
);
return new TestDataSource(mapping, index, analyzer, stats);
@ -7673,6 +7674,7 @@ public class PhysicalPlanOptimizerTests extends ESTestCase {
() -> exchangeSinkHandler.createExchangeSink(() -> {}),
null,
null,
null,
new EsPhysicalOperationProviders(FoldContext.small(), List.of(), null),
List.of()
);

View File

@ -40,6 +40,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultInferenceResolution;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultLookupResolution;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.hasSize;
@ -63,7 +64,8 @@ public class PropagateInlineEvalsTests extends ESTestCase {
new EsqlFunctionRegistry(),
getIndexResult,
defaultLookupResolution(),
new EnrichResolution()
new EnrichResolution(),
defaultInferenceResolution()
),
TEST_VERIFIER
);

View File

@ -30,6 +30,10 @@ public class GrammarInDevelopmentParsingTests extends ESTestCase {
parse("row a = 1 | match foo", "match");
}
public void testDevelopmentRerank() {
parse("row a = 1 | rerank \"foo\" ON title WITH reranker", "rerank");
}
void parse(String query, String errorMessage) {
ParsingException pe = expectThrows(ParsingException.class, () -> parser().createStatement(query));
assertThat(pe.getMessage(), containsString("mismatched input '" + errorMessage + "'"));

View File

@ -64,6 +64,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Row;
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;
@ -3332,6 +3333,87 @@ public class StatementParserTests extends AbstractStatementParserTests {
expectError("explain [row x = 1", "line 1:19: missing ']' at '<EOF>'");
}
public void testRerankSingleField() {
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
var plan = processingCommand("RERANK \"query text\" ON title WITH inferenceID");
var rerank = as(plan, Rerank.class);
assertThat(rerank.queryText(), equalTo(literalString("query text")));
assertThat(rerank.inferenceId(), equalTo(literalString("inferenceID")));
assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
}
public void testRerankMultipleFields() {
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
var plan = processingCommand("RERANK \"query text\" ON title, description, authors_renamed=authors WITH inferenceID");
var rerank = as(plan, Rerank.class);
assertThat(rerank.queryText(), equalTo(literalString("query text")));
assertThat(rerank.inferenceId(), equalTo(literalString("inferenceID")));
assertThat(
rerank.rerankFields(),
equalTo(
List.of(
alias("title", attribute("title")),
alias("description", attribute("description")),
alias("authors_renamed", attribute("authors"))
)
)
);
}
public void testRerankComputedFields() {
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
var plan = processingCommand("RERANK \"query text\" ON title, short_description = SUBSTRING(description, 0, 100) WITH inferenceID");
var rerank = as(plan, Rerank.class);
assertThat(rerank.queryText(), equalTo(literalString("query text")));
assertThat(rerank.inferenceId(), equalTo(literalString("inferenceID")));
assertThat(
rerank.rerankFields(),
equalTo(
List.of(
alias("title", attribute("title")),
alias("short_description", function("SUBSTRING", List.of(attribute("description"), integer(0), integer(100))))
)
)
);
}
public void testRerankWithPositionalParameters() {
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
var queryParams = new QueryParams(List.of(paramAsConstant(null, "query text"), paramAsConstant(null, "reranker")));
var rerank = as(parser.createStatement("row a = 1 | RERANK ? ON title WITH ?", queryParams), Rerank.class);
assertThat(rerank.queryText(), equalTo(literalString("query text")));
assertThat(rerank.inferenceId(), equalTo(literalString("reranker")));
assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
}
public void testRerankWithNamedParameters() {
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
var queryParams = new QueryParams(List.of(paramAsConstant("queryText", "query text"), paramAsConstant("inferenceId", "reranker")));
var rerank = as(parser.createStatement("row a = 1 | RERANK ?queryText ON title WITH ?inferenceId", queryParams), Rerank.class);
assertThat(rerank.queryText(), equalTo(literalString("query text")));
assertThat(rerank.inferenceId(), equalTo(literalString("reranker")));
assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
}
public void testInvalidRerank() {
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
expectError("FROM foo* | RERANK ON title WITH inferenceId", "line 1:20: mismatched input 'ON' expecting {QUOTED_STRING");
expectError("FROM foo* | RERANK \"query text\" WITH inferenceId", "line 1:33: mismatched input 'WITH' expecting 'on'");
expectError("FROM foo* | RERANK \"query text\" ON title", "line 1:41: mismatched input '<EOF>' expecting {'and',");
}
static Alias alias(String name, Expression value) {
return new Alias(EMPTY, name, value);
}

View File

@ -0,0 +1,66 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.plan.logical.inference;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.AliasTests;
import org.elasticsearch.xpack.esql.plan.logical.AbstractLogicalPlanSerializationTests;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
public class RerankSerializationTests extends AbstractLogicalPlanSerializationTests<Rerank> {
@Override
protected Rerank createTestInstance() {
Source source = randomSource();
LogicalPlan child = randomChild(0);
return new Rerank(source, child, string(randomIdentifier()), string(randomIdentifier()), randomFields(), scoreAttribute());
}
@Override
protected Rerank mutateInstance(Rerank instance) throws IOException {
LogicalPlan child = instance.child();
Expression inferenceId = instance.inferenceId();
Expression queryText = instance.queryText();
List<Alias> fields = instance.rerankFields();
switch (between(0, 3)) {
case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
case 1 -> inferenceId = randomValueOtherThan(inferenceId, () -> string(RerankSerializationTests.randomIdentifier()));
case 2 -> queryText = randomValueOtherThan(queryText, () -> string(RerankSerializationTests.randomIdentifier()));
case 3 -> fields = randomValueOtherThan(fields, this::randomFields);
}
return new Rerank(instance.source(), child, inferenceId, queryText, fields, instance.scoreAttribute());
}
@Override
protected boolean alwaysEmptySource() {
return true;
}
private List<Alias> randomFields() {
return randomList(0, 10, AliasTests::randomAlias);
}
private Literal string(String value) {
return new Literal(EMPTY, value, DataType.KEYWORD);
}
private Attribute scoreAttribute() {
return new MetadataAttribute(EMPTY, MetadataAttribute.SCORE, DataType.DOUBLE, randomBoolean());
}
}

View File

@ -0,0 +1,66 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.plan.physical.inference;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.AliasTests;
import org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
public class RerankExecSerializationTests extends AbstractPhysicalPlanSerializationTests<RerankExec> {
@Override
protected RerankExec createTestInstance() {
Source source = randomSource();
PhysicalPlan child = randomChild(0);
return new RerankExec(source, child, string(randomIdentifier()), string(randomIdentifier()), randomFields(), scoreAttribute());
}
@Override
protected RerankExec mutateInstance(RerankExec instance) throws IOException {
PhysicalPlan child = instance.child();
Expression inferenceId = instance.inferenceId();
Expression queryText = instance.queryText();
List<Alias> fields = instance.rerankFields();
switch (between(0, 3)) {
case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
case 1 -> inferenceId = randomValueOtherThan(inferenceId, () -> string(RerankExecSerializationTests.randomIdentifier()));
case 2 -> queryText = randomValueOtherThan(queryText, () -> string(RerankExecSerializationTests.randomIdentifier()));
case 3 -> fields = randomValueOtherThan(fields, this::randomFields);
}
return new RerankExec(instance.source(), child, inferenceId, queryText, fields, scoreAttribute());
}
@Override
protected boolean alwaysEmptySource() {
return true;
}
private List<Alias> randomFields() {
return randomList(0, 10, AliasTests::randomAlias);
}
static Literal string(String value) {
return new Literal(EMPTY, value, DataType.KEYWORD);
}
private Attribute scoreAttribute() {
return new MetadataAttribute(EMPTY, MetadataAttribute.SCORE, DataType.DOUBLE, randomBoolean());
}
}

View File

@ -50,6 +50,7 @@ import static java.util.Arrays.asList;
import static org.elasticsearch.index.query.QueryBuilders.rangeQuery;
import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
@ -84,7 +85,13 @@ public class FilterTests extends ESTestCase {
mapper = new Mapper();
analyzer = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, EsqlTestUtils.emptyPolicyResolution()),
new AnalyzerContext(
EsqlTestUtils.TEST_CFG,
new EsqlFunctionRegistry(),
getIndexResult,
EsqlTestUtils.emptyPolicyResolution(),
emptyInferenceResolution()
),
TEST_VERIFIER
);
}

View File

@ -232,6 +232,7 @@ public class LocalExecutionPlannerTests extends MapperServiceTestCase {
null,
null,
null,
null,
esPhysicalOperationProviders(shardContexts),
shardContexts
);

View File

@ -13,7 +13,6 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.analysis.Analyzer;
import org.elasticsearch.xpack.esql.analysis.AnalyzerContext;
import org.elasticsearch.xpack.esql.analysis.EnrichResolution;
import org.elasticsearch.xpack.esql.analysis.Verifier;
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.EsIndex;
@ -28,6 +27,8 @@ import org.junit.BeforeClass;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
import static org.hamcrest.Matchers.containsString;
@ -46,7 +47,13 @@ public class QueryTranslatorTests extends ESTestCase {
IndexResolution getIndexResult = IndexResolution.valid(test);
return new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, new EnrichResolution()),
new AnalyzerContext(
EsqlTestUtils.TEST_CFG,
new EsqlFunctionRegistry(),
getIndexResult,
emptyPolicyResolution(),
emptyInferenceResolution()
),
new Verifier(new Metrics(new EsqlFunctionRegistry()), new XPackLicenseState(() -> 0L))
);
}

View File

@ -39,6 +39,7 @@ import java.util.Map;
import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration;
import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
@ -191,7 +192,13 @@ public class ClusterRequestTests extends AbstractWireSerializingTestCase<Cluster
IndexResolution getIndexResult = IndexResolution.valid(test);
var logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext());
var analyzer = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()),
new AnalyzerContext(
EsqlTestUtils.TEST_CFG,
new EsqlFunctionRegistry(),
getIndexResult,
emptyPolicyResolution(),
emptyInferenceResolution()
),
TEST_VERIFIER
);
return logicalOptimizer.optimize(analyzer.analyze(new EsqlParser().createStatement(query)));

View File

@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfigur
import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
@ -292,7 +293,13 @@ public class DataNodeRequestSerializationTests extends AbstractWireSerializingTe
IndexResolution getIndexResult = IndexResolution.valid(test);
var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(TEST_CFG, FoldContext.small()));
var analyzer = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()),
new AnalyzerContext(
EsqlTestUtils.TEST_CFG,
new EsqlFunctionRegistry(),
getIndexResult,
emptyPolicyResolution(),
emptyInferenceResolution()
),
TEST_VERIFIER
);
return logicalOptimizer.optimize(analyzer.analyze(new EsqlParser().createStatement(query)));

View File

@ -86,7 +86,7 @@ public abstract class AbstractTestInferenceService implements InferenceService {
var secretSettings = TestSecretSettings.fromMap(secretSettingsMap);
var taskSettingsMap = getTaskSettingsMap(config);
var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
var taskSettings = getTasksSettingsFromMap(taskSettingsMap);
return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings);
}
@ -99,11 +99,15 @@ public abstract class AbstractTestInferenceService implements InferenceService {
var serviceSettings = getServiceSettingsFromMap(serviceSettingsMap);
var taskSettingsMap = getTaskSettingsMap(config);
var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
var taskSettings = getTasksSettingsFromMap(taskSettingsMap);
return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, null);
}
protected TaskSettings getTasksSettingsFromMap(Map<String, Object> taskSettingsMap) {
return TestTaskSettings.fromMap(taskSettingsMap);
}
protected abstract ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap);
@Override
@ -149,15 +153,15 @@ public abstract class AbstractTestInferenceService implements InferenceService {
TaskType taskType,
String service,
ServiceSettings serviceSettings,
TestTaskSettings taskSettings,
TaskSettings taskSettings,
TestSecretSettings secretSettings
) {
super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings));
}
@Override
public TestTaskSettings getTaskSettings() {
return (TestTaskSettings) super.getTaskSettings();
public TaskSettings getTaskSettings() {
return super.getTaskSettings();
}
@Override

View File

@ -45,6 +45,11 @@ public class TestInferenceServicePlugin extends Plugin {
TestRerankingServiceExtension.TestServiceSettings.NAME,
TestRerankingServiceExtension.TestServiceSettings::new
),
new NamedWriteableRegistry.Entry(
TaskSettings.class,
TestRerankingServiceExtension.TestTaskSettings.NAME,
TestRerankingServiceExtension.TestTaskSettings::new
),
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
TestStreamingCompletionServiceExtension.TestServiceSettings.NAME,

View File

@ -27,6 +27,7 @@ import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
@ -43,6 +44,8 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.inference.mock.AbstractTestInferenceService.random;
public class TestRerankingServiceExtension implements InferenceServiceExtension {
@Override
@ -84,11 +87,15 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap);
var taskSettingsMap = getTaskSettingsMap(config);
var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
var taskSettings = TestRerankingServiceExtension.TestTaskSettings.fromMap(taskSettingsMap);
parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings));
}
protected TaskSettings getTasksSettingsFromMap(Map<String, Object> taskSettingsMap) {
return TestRerankingServiceExtension.TestTaskSettings.fromMap(taskSettingsMap);
}
@Override
public InferenceServiceConfiguration getConfiguration() {
return Configuration.get();
@ -107,13 +114,15 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
@Nullable Integer topN,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,
Map<String, Object> taskSettingsMap,
InputType inputType,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
TaskSettings taskSettings = model.getTaskSettings().updatedTaskSettings(taskSettingsMap);
switch (model.getConfigurations().getTaskType()) {
case ANY, RERANK -> listener.onResponse(makeResults(input));
case ANY, RERANK -> listener.onResponse(makeResults(input, (TestRerankingServiceExtension.TestTaskSettings) taskSettings));
default -> listener.onFailure(
new ElasticsearchStatusException(
TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
@ -151,7 +160,7 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
);
}
private RankedDocsResults makeResults(List<String> input) {
private RankedDocsResults makeResults(List<String> input, TestRerankingServiceExtension.TestTaskSettings taskSettings) {
int totalResults = input.size();
try {
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
@ -161,17 +170,19 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
return new RankedDocsResults(results.stream().sorted(Comparator.reverseOrder()).toList());
} catch (NumberFormatException ex) {
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
float minScore = random.nextFloat(-1f, 1f);
float resultDiff = 0.2f;
float minScore = taskSettings.minScore();
float resultDiff = taskSettings.resultDiff();
for (int i = 0; i < input.size(); i++) {
results.add(
new RankedDocsResults.RankedDoc(
totalResults - 1 - i,
minScore + resultDiff * (totalResults - i),
input.get(totalResults - 1 - i)
)
);
float relevanceScore = minScore + resultDiff * (totalResults - i);
String inputText = input.get(totalResults - 1 - i);
if (taskSettings.useTextLength()) {
relevanceScore = 1f / inputText.length();
}
results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, relevanceScore, inputText));
}
// Ensure result are sorted by descending score
results.sort((a, b) -> -Float.compare(a.relevanceScore(), b.relevanceScore()));
return new RankedDocsResults(results);
}
}
@ -208,6 +219,77 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
}
}
public record TestTaskSettings(boolean useTextLength, float minScore, float resultDiff) implements TaskSettings {
static final String NAME = "test_reranking_task_settings";
public static TestTaskSettings fromMap(Map<String, Object> map) {
boolean useTextLength = false;
float minScore = random.nextFloat(-1f, 1f);
float resultDiff = 0.2f;
if (map.containsKey("use_text_length")) {
useTextLength = Boolean.parseBoolean(map.remove("use_text_length").toString());
}
if (map.containsKey("min_score")) {
minScore = Float.parseFloat(map.remove("min_score").toString());
}
if (map.containsKey("result_diff")) {
resultDiff = Float.parseFloat(map.remove("result_diff").toString());
}
return new TestTaskSettings(useTextLength, minScore, resultDiff);
}
public TestTaskSettings(StreamInput in) throws IOException {
this(in.readBoolean(), in.readFloat(), in.readFloat());
}
@Override
public boolean isEmpty() {
return false;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(useTextLength);
out.writeFloat(minScore);
out.writeFloat(resultDiff);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("use_text_length", useTextLength);
builder.field("min_score", minScore);
builder.field("result_diff", resultDiff);
builder.endObject();
return builder;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
}
@Override
public TaskSettings updatedTaskSettings(Map<String, Object> newSettingsMap) {
TestTaskSettings newSettingsObject = fromMap(Map.copyOf(newSettingsMap));
return new TestTaskSettings(
newSettingsMap.containsKey("use_text_length") ? newSettingsObject.useTextLength() : useTextLength,
newSettingsMap.containsKey("min_score") ? newSettingsObject.minScore() : minScore,
newSettingsMap.containsKey("result_diff") ? newSettingsObject.resultDiff() : resultDiff
);
}
}
public record TestServiceSettings(String modelId) implements ServiceSettings {
static final String NAME = "test_reranking_service_settings";