Adding ES|QL RERANK command in snapshot builds (#123074)
This commit is contained in:
parent 8f38b13059
commit a4a271415d
@@ -0,0 +1,5 @@
pr: 123074
summary: Adding ES|QL Reranker command in snapshot builds
area: Ranking
type: feature
issues: []
@@ -369,6 +369,9 @@ tests:
- class: org.elasticsearch.snapshots.SharedClusterSnapshotRestoreIT
  method: testDeletionOfFailingToRecoverIndexShouldStopRestore
  issue: https://github.com/elastic/elasticsearch/issues/126204
- class: org.elasticsearch.xpack.esql.inference.RerankOperatorTests
  method: testSimpleCircuitBreaking
  issue: https://github.com/elastic/elasticsearch/issues/124337
- class: org.elasticsearch.index.engine.ThreadPoolMergeSchedulerTests
  method: testSchedulerCloseWaitsForRunningMerge
  issue: https://github.com/elastic/elasticsearch/issues/125236
@@ -64,6 +64,10 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
        public static final ParseField TOP_N = new ParseField("top_n");
        public static final ParseField TIMEOUT = new ParseField("timeout");

        public static Builder builder(String inferenceEntityId, TaskType taskType) {
            return new Builder().setInferenceEntityId(inferenceEntityId).setTaskType(taskType);
        }

        static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
        static {
            PARSER.declareStringArray(Request.Builder::setInput, INPUT);
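A hedged usage sketch of the new builder entry point (an editor's illustration, not part of the diff). Only the setters and the ObjectParser wiring appear in the hunk above, so the build() call, the placement of the helper on the Request class, and the TaskType import are assumptions:

    import org.elasticsearch.inference.TaskType;

    // Sketch only: builder(...), setInferenceEntityId(...) and setTaskType(...) are shown
    // above; setInput(...) is implied by the ObjectParser wiring; build() is assumed.
    InferenceAction.Request request = InferenceAction.Request.builder("test_reranker", TaskType.RERANK)
        .setInput(List.of("Deep Sea Exploration"))
        .build();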
@@ -0,0 +1,153 @@
// Generated from /Users/afoucret/git/elasticsearch/x-pack/plugin/esql/src/main/antlr/EsqlBaseLexer.g4 by ANTLR 4.13.2

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import org.antlr.v4.runtime.Lexer;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.TokenStream;
import org.antlr.v4.runtime.*;
import org.antlr.v4.runtime.atn.*;
import org.antlr.v4.runtime.dfa.DFA;
import org.antlr.v4.runtime.misc.*;

@SuppressWarnings({"all", "warnings", "unchecked", "unused", "cast", "CheckReturnValue", "this-escape"})
public class EsqlBaseLexer extends LexerConfig {
    static { RuntimeMetaData.checkVersion("4.13.2", RuntimeMetaData.VERSION); }

    protected static final DFA[] _decisionToDFA;
    protected static final PredictionContextCache _sharedContextCache =
        new PredictionContextCache();
    public static final int
        LINE_COMMENT=1, MULTILINE_COMMENT=2, WS=3;
    public static String[] channelNames = {
        "DEFAULT_TOKEN_CHANNEL", "HIDDEN"
    };

    public static String[] modeNames = {
        "DEFAULT_MODE"
    };

    private static String[] makeRuleNames() {
        return new String[] {
            "LINE_COMMENT", "MULTILINE_COMMENT", "WS"
        };
    }
    public static final String[] ruleNames = makeRuleNames();

    private static String[] makeLiteralNames() {
        return new String[] {
        };
    }
    private static final String[] _LITERAL_NAMES = makeLiteralNames();
    private static String[] makeSymbolicNames() {
        return new String[] {
            null, "LINE_COMMENT", "MULTILINE_COMMENT", "WS"
        };
    }
    private static final String[] _SYMBOLIC_NAMES = makeSymbolicNames();
    public static final Vocabulary VOCABULARY = new VocabularyImpl(_LITERAL_NAMES, _SYMBOLIC_NAMES);

    /**
     * @deprecated Use {@link #VOCABULARY} instead.
     */
    @Deprecated
    public static final String[] tokenNames;
    static {
        tokenNames = new String[_SYMBOLIC_NAMES.length];
        for (int i = 0; i < tokenNames.length; i++) {
            tokenNames[i] = VOCABULARY.getLiteralName(i);
            if (tokenNames[i] == null) {
                tokenNames[i] = VOCABULARY.getSymbolicName(i);
            }

            if (tokenNames[i] == null) {
                tokenNames[i] = "<INVALID>";
            }
        }
    }

    @Override
    @Deprecated
    public String[] getTokenNames() {
        return tokenNames;
    }

    @Override

    public Vocabulary getVocabulary() {
        return VOCABULARY;
    }


    public EsqlBaseLexer(CharStream input) {
        super(input);
        _interp = new LexerATNSimulator(this,_ATN,_decisionToDFA,_sharedContextCache);
    }

    @Override
    public String getGrammarFileName() { return "EsqlBaseLexer.g4"; }

    @Override
    public String[] getRuleNames() { return ruleNames; }

    @Override
    public String getSerializedATN() { return _serializedATN; }

    @Override
    public String[] getChannelNames() { return channelNames; }

    @Override
    public String[] getModeNames() { return modeNames; }

    @Override
    public ATN getATN() { return _ATN; }

    public static final String _serializedATN =
        "\u0004\u0000\u0003.\u0006\uffff\uffff\u0002\u0000\u0007\u0000\u0002\u0001"+
        "\u0007\u0001\u0002\u0002\u0007\u0002\u0001\u0000\u0001\u0000\u0001\u0000"+
        "\u0001\u0000\u0005\u0000\f\b\u0000\n\u0000\f\u0000\u000f\t\u0000\u0001"+
        "\u0000\u0003\u0000\u0012\b\u0000\u0001\u0000\u0003\u0000\u0015\b\u0000"+
        "\u0001\u0000\u0001\u0000\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001"+
        "\u0001\u0001\u0005\u0001\u001e\b\u0001\n\u0001\f\u0001!\t\u0001\u0001"+
        "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0002\u0004"+
        "\u0002)\b\u0002\u000b\u0002\f\u0002*\u0001\u0002\u0001\u0002\u0001\u001f"+
        "\u0000\u0003\u0001\u0001\u0003\u0002\u0005\u0003\u0001\u0000\u0002\u0002"+
        "\u0000\n\n\r\r\u0003\u0000\t\n\r\r 3\u0000\u0001\u0001\u0000\u0000\u0000"+
        "\u0000\u0003\u0001\u0000\u0000\u0000\u0000\u0005\u0001\u0000\u0000\u0000"+
        "\u0001\u0007\u0001\u0000\u0000\u0000\u0003\u0018\u0001\u0000\u0000\u0000"+
        "\u0005(\u0001\u0000\u0000\u0000\u0007\b\u0005/\u0000\u0000\b\t\u0005/"+
        "\u0000\u0000\t\r\u0001\u0000\u0000\u0000\n\f\b\u0000\u0000\u0000\u000b"+
        "\n\u0001\u0000\u0000\u0000\f\u000f\u0001\u0000\u0000\u0000\r\u000b\u0001"+
        "\u0000\u0000\u0000\r\u000e\u0001\u0000\u0000\u0000\u000e\u0011\u0001\u0000"+
        "\u0000\u0000\u000f\r\u0001\u0000\u0000\u0000\u0010\u0012\u0005\r\u0000"+
        "\u0000\u0011\u0010\u0001\u0000\u0000\u0000\u0011\u0012\u0001\u0000\u0000"+
        "\u0000\u0012\u0014\u0001\u0000\u0000\u0000\u0013\u0015\u0005\n\u0000\u0000"+
        "\u0014\u0013\u0001\u0000\u0000\u0000\u0014\u0015\u0001\u0000\u0000\u0000"+
        "\u0015\u0016\u0001\u0000\u0000\u0000\u0016\u0017\u0006\u0000\u0000\u0000"+
        "\u0017\u0002\u0001\u0000\u0000\u0000\u0018\u0019\u0005/\u0000\u0000\u0019"+
        "\u001a\u0005*\u0000\u0000\u001a\u001f\u0001\u0000\u0000\u0000\u001b\u001e"+
        "\u0003\u0003\u0001\u0000\u001c\u001e\t\u0000\u0000\u0000\u001d\u001b\u0001"+
        "\u0000\u0000\u0000\u001d\u001c\u0001\u0000\u0000\u0000\u001e!\u0001\u0000"+
        "\u0000\u0000\u001f \u0001\u0000\u0000\u0000\u001f\u001d\u0001\u0000\u0000"+
        "\u0000 \"\u0001\u0000\u0000\u0000!\u001f\u0001\u0000\u0000\u0000\"#\u0005"+
        "*\u0000\u0000#$\u0005/\u0000\u0000$%\u0001\u0000\u0000\u0000%&\u0006\u0001"+
        "\u0000\u0000&\u0004\u0001\u0000\u0000\u0000\')\u0007\u0001\u0000\u0000"+
        "(\'\u0001\u0000\u0000\u0000)*\u0001\u0000\u0000\u0000*(\u0001\u0000\u0000"+
        "\u0000*+\u0001\u0000\u0000\u0000+,\u0001\u0000\u0000\u0000,-\u0006\u0002"+
        "\u0000\u0000-\u0006\u0001\u0000\u0000\u0000\u0007\u0000\r\u0011\u0014"+
        "\u001d\u001f*\u0001\u0000\u0001\u0000";
    public static final ATN _ATN =
        new ATNDeserializer().deserialize(_serializedATN.toCharArray());
    static {
        _decisionToDFA = new DFA[_ATN.getNumberOfDecisions()];
        for (int i = 0; i < _ATN.getNumberOfDecisions(); i++) {
            _decisionToDFA[i] = new DFA(_ATN.getDecisionState(i), i);
        }
    }
}
@@ -0,0 +1,3 @@
LINE_COMMENT=1
MULTILINE_COMMENT=2
WS=3
@@ -52,6 +52,7 @@ import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.INLINESTA
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.JOIN_LOOKUP_V12;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.JOIN_PLANNING_V1;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.METADATA_FIELDS_REMOTE_TEST;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.RERANK;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.UNMAPPED_FIELDS;
import static org.elasticsearch.xpack.esql.qa.rest.EsqlSpecTestCase.Mode.SYNC;
import static org.mockito.ArgumentMatchers.any;

@@ -130,6 +131,8 @@ public class MultiClusterSpecIT extends EsqlSpecTestCase {
        assumeFalse("LOOKUP JOIN not yet supported in CCS", testCase.requiredCapabilities.contains(JOIN_LOOKUP_V12.capabilityName()));
        // Unmapped fields require a correct capability response from every cluster, which isn't currently implemented.
        assumeFalse("UNMAPPED FIELDS not yet supported in CCS", testCase.requiredCapabilities.contains(UNMAPPED_FIELDS.capabilityName()));
        // Need to do additional development to get CCS support for the rerank command.
        assumeFalse("RERANK not yet supported in CCS", testCase.requiredCapabilities.contains(RERANK.capabilityName()));
    }

    @Override
@@ -66,8 +66,11 @@ import static org.elasticsearch.xpack.esql.CsvTestUtils.isEnabled;
import static org.elasticsearch.xpack.esql.CsvTestUtils.loadCsvSpecValues;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.availableDatasetsForEs;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasRerankInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createRerankInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteRerankInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.loadDataSetIntoEs;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.classpathResources;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.METRICS_COMMAND;

@@ -134,6 +137,10 @@ public abstract class EsqlSpecTestCase extends ESRestTestCase {
            createInferenceEndpoint(client());
        }

        if (supportsInferenceTestService() && clusterHasRerankInferenceEndpoint(client()) == false) {
            createRerankInferenceEndpoint(client());
        }

        boolean supportsLookup = supportsIndexModeLookup();
        boolean supportsSourceMapping = supportsSourceFieldMapping();
        if (indexExists(availableDatasetsForEs(client(), supportsLookup, supportsSourceMapping).iterator().next().indexName()) == false) {

@@ -153,6 +160,7 @@ public abstract class EsqlSpecTestCase extends ESRestTestCase {
        }

        deleteInferenceEndpoint(client());
        deleteRerankInferenceEndpoint(client());
    }

    public boolean logResults() {
@@ -0,0 +1,192 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.qa.rest;

import org.elasticsearch.client.Request;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createRerankInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteRerankInferenceEndpoint;
import static org.hamcrest.core.StringContains.containsString;

public class RestRerankTestCase extends ESRestTestCase {

    @Before
    public void skipWhenRerankDisabled() throws IOException {
        assumeTrue(
            "Requires RERANK capability",
            EsqlSpecTestCase.hasCapabilities(adminClient(), List.of(EsqlCapabilities.Cap.RERANK.capabilityName()))
        );
    }

    @Before
    @After
    public void assertRequestBreakerEmpty() throws Exception {
        EsqlSpecTestCase.assertRequestBreakerEmpty();
    }

    @Before
    public void setUpInferenceEndpoint() throws IOException {
        createRerankInferenceEndpoint(adminClient());
    }

    @Before
    public void setUpTestIndex() throws IOException {
        Request request = new Request("PUT", "/rerank-test-index");
        request.setJsonEntity("""
            {
              "mappings": {
                "properties": {
                  "title": { "type": "text" },
                  "author": { "type": "text" }
                }
              }
            }""");
        assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode());

        request = new Request("POST", "/rerank-test-index/_bulk");
        request.addParameter("refresh", "true");
        request.setJsonEntity("""
            { "index": {"_id": 1} }
            { "title": "The Future of Exploration", "author": "John Doe" }
            { "index": {"_id": 2} }
            { "title": "Deep Sea Exploration", "author": "Jane Smith" }
            { "index": {"_id": 3} }
            { "title": "History of Space Exploration", "author": "Alice Johnson" }
            """);
        assertEquals(200, client().performRequest(request).getStatusLine().getStatusCode());
    }

    @After
    public void wipeData() throws IOException {
        try {
            adminClient().performRequest(new Request("DELETE", "/rerank-test-index"));
        } catch (ResponseException e) {
            // 404 here just means we had no indexes
            if (e.getResponse().getStatusLine().getStatusCode() != 404) {
                throw e;
            }
        }

        deleteRerankInferenceEndpoint(adminClient());
    }

    public void testRerankWithSingleField() throws IOException {
        String query = """
            FROM rerank-test-index
            | WHERE match(title, "exploration")
            | RERANK "exploration" ON title WITH test_reranker
            | EVAL _score = ROUND(_score, 5)
            """;

        Map<String, Object> result = runEsqlQuery(query);

        var expectedValues = List.of(
            List.of("Jane Smith", "Deep Sea Exploration", 0.02941d),
            List.of("John Doe", "The Future of Exploration", 0.02632d),
            List.of("Alice Johnson", "History of Space Exploration", 0.02381d)
        );

        assertResultMap(result, defaultOutputColumns(), expectedValues);
    }

    public void testRerankWithMultipleFields() throws IOException {
        String query = """
            FROM rerank-test-index
            | WHERE match(title, "exploration")
            | RERANK "exploration" ON title, author WITH test_reranker
            | EVAL _score = ROUND(_score, 5)
            """;

        Map<String, Object> result = runEsqlQuery(query);

        var expectedValues = List.of(
            List.of("Jane Smith", "Deep Sea Exploration", 0.01818d),
            List.of("John Doe", "The Future of Exploration", 0.01754d),
            List.of("Alice Johnson", "History of Space Exploration", 0.01515d)
        );

        assertResultMap(result, defaultOutputColumns(), expectedValues);
    }

    public void testRerankWithPositionalParams() throws IOException {
        String query = """
            FROM rerank-test-index
            | WHERE match(title, "exploration")
            | RERANK ? ON title WITH ?
            | EVAL _score = ROUND(_score, 5)
            """;

        Map<String, Object> result = runEsqlQuery(query, "[\"exploration\", \"test_reranker\"]");

        var expectedValues = List.of(
            List.of("Jane Smith", "Deep Sea Exploration", 0.02941d),
            List.of("John Doe", "The Future of Exploration", 0.02632d),
            List.of("Alice Johnson", "History of Space Exploration", 0.02381d)
        );

        assertResultMap(result, defaultOutputColumns(), expectedValues);
    }

    public void testRerankWithNamedParams() throws IOException {
        String query = """
            FROM rerank-test-index
            | WHERE match(title, ?queryText)
            | RERANK ?queryText ON title WITH ?inferenceId
            | EVAL _score = ROUND(_score, 5)
            """;

        Map<String, Object> result = runEsqlQuery(query, "[{\"queryText\": \"exploration\"}, {\"inferenceId\": \"test_reranker\"}]");

        var expectedValues = List.of(
            List.of("Jane Smith", "Deep Sea Exploration", 0.02941d),
            List.of("John Doe", "The Future of Exploration", 0.02632d),
            List.of("Alice Johnson", "History of Space Exploration", 0.02381d)
        );

        assertResultMap(result, defaultOutputColumns(), expectedValues);
    }

    public void testRerankWithMissingInferenceId() {
        String query = """
            FROM rerank-test-index
            | WHERE match(title, "exploration")
            | RERANK "exploration" ON title WITH test_missing
            | EVAL _score = ROUND(_score, 5)
            """;

        ResponseException re = expectThrows(ResponseException.class, () -> runEsqlQuery(query));
        assertThat(re.getMessage(), containsString("Inference endpoint not found"));
    }

    private static List<Map<String, String>> defaultOutputColumns() {
        return List.of(
            Map.of("name", "author", "type", "text"),
            Map.of("name", "title", "type", "text"),
            Map.of("name", "_score", "type", "double")
        );
    }

    private Map<String, Object> runEsqlQuery(String query) throws IOException {
        RestEsqlTestCase.RequestObjectBuilder builder = RestEsqlTestCase.requestObjectBuilder().query(query);
        return RestEsqlTestCase.runEsqlSync(builder);
    }

    private Map<String, Object> runEsqlQuery(String query, String params) throws IOException {
        RestEsqlTestCase.RequestObjectBuilder builder = RestEsqlTestCase.requestObjectBuilder().query(query).params(params);
        return RestEsqlTestCase.runEsqlSync(builder);
    }
}
@@ -397,6 +397,47 @@ public class CsvTestsDataLoader {
        return true;
    }

    public static void createRerankInferenceEndpoint(RestClient client) throws IOException {
        Request request = new Request("PUT", "_inference/rerank/test_reranker");
        request.setJsonEntity("""
            {
              "service": "test_reranking_service",
              "service_settings": {
                "model_id": "my_model",
                "api_key": "abc64"
              },
              "task_settings": {
                "use_text_length": true
              }
            }
            """);
        client.performRequest(request);
    }

    public static void deleteRerankInferenceEndpoint(RestClient client) throws IOException {
        try {
            client.performRequest(new Request("DELETE", "_inference/rerank/test_reranker"));
        } catch (ResponseException e) {
            // 404 here means the endpoint was not created
            if (e.getResponse().getStatusLine().getStatusCode() != 404) {
                throw e;
            }
        }
    }

    public static boolean clusterHasRerankInferenceEndpoint(RestClient client) throws IOException {
        Request request = new Request("GET", "_inference/rerank/test_reranker");
        try {
            client.performRequest(request);
        } catch (ResponseException e) {
            if (e.getResponse().getStatusLine().getStatusCode() == 404) {
                return false;
            }
            throw e;
        }
        return true;
    }

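Taken together, the three helpers above form a create-if-missing, always-delete lifecycle around the test_reranker endpoint. A hedged sketch of the intended composition (the try/finally framing is illustrative, not part of the diff):

    // Illustrative lifecycle for the test_reranker endpoint, using only the helpers above.
    if (clusterHasRerankInferenceEndpoint(client) == false) {
        createRerankInferenceEndpoint(client);
    }
    try {
        // ... run RERANK csv-spec tests against the cluster ...
    } finally {
        deleteRerankInferenceEndpoint(client); // tolerates a 404 if the endpoint is already gone
    }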
    private static void loadEnrichPolicy(RestClient client, String policyName, String policyFileName, Logger logger) throws IOException {
        URL policyMapping = getResource("/" + policyFileName);
        String entity = readTextFile(policyMapping);
@@ -11,6 +11,7 @@ import org.apache.lucene.document.InetAddressPoint;
import org.apache.lucene.sandbox.document.HalfFloatPoint;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.RemoteException;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;

@@ -68,6 +69,8 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Les
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.parser.QueryParam;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;

@@ -149,6 +152,8 @@ import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

public final class EsqlTestUtils {

@@ -376,9 +381,21 @@ public final class EsqlTestUtils {
            null,
            mock(ClusterService.class),
            mock(IndexNameExpressionResolver.class),
            null
            null,
            mockInferenceRunner()
        );

    @SuppressWarnings("unchecked")
    private static InferenceRunner mockInferenceRunner() {
        InferenceRunner inferenceRunner = mock(InferenceRunner.class);
        doAnswer(i -> {
            i.getArgument(1, ActionListener.class).onResponse(emptyInferenceResolution());
            return null;
        }).when(inferenceRunner).resolveInferenceIds(any(), any());

        return inferenceRunner;
    }

    private EsqlTestUtils() {}

    public static Configuration configuration(QueryPragmas pragmas, String query) {

@@ -454,6 +471,10 @@ public final class EsqlTestUtils {
        return new EnrichResolution();
    }

    public static InferenceResolution emptyInferenceResolution() {
        return InferenceResolution.EMPTY;
    }

    public static SearchStats statsForExistingField(String... names) {
        return fieldMatchingExistOrMissing(true, names);
    }
@@ -0,0 +1,142 @@
// Note:
// The "test_reranker" service scores the row from the inputText length and does not really score by relevance.
// This makes the output more predictable which is helpful here.


reranker using a single field
required_capability: rerank
required_capability: match_operator_colon

FROM books METADATA _score
| WHERE title:"war and peace" AND author:"Tolstoy"
| RERANK "war and peace" ON title WITH test_reranker
| KEEP book_no, title, author, _score
| EVAL _score = ROUND(_score, 5)
;

book_no:keyword | title:text | author:text | _score:double
5327 | War and Peace | Leo Tolstoy | 0.03846
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.02222
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.02083
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.01515
;


reranker using multiple fields
required_capability: rerank
required_capability: match_operator_colon

FROM books METADATA _score
| WHERE title:"war and peace" AND author:"Tolstoy"
| RERANK "war and peace" ON title, author WITH test_reranker
| KEEP book_no, title, author, _score
| EVAL _score = ROUND(_score, 5)
;

book_no:keyword | title:text | author:text | _score:double
5327 | War and Peace | Leo Tolstoy | 0.02083
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.01429
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.01136
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.00952
;


reranker after a limit
required_capability: rerank
required_capability: match_operator_colon

FROM books METADATA _score
| WHERE title:"war and peace" AND author:"Tolstoy"
| SORT _score DESC
| LIMIT 3
| RERANK "war and peace" ON title WITH test_reranker
| KEEP book_no, title, author, _score
| EVAL _score = ROUND(_score, 5)
;

book_no:keyword | title:text | author:text | _score:double
5327 | War and Peace | Leo Tolstoy | 0.03846
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.02222
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.02083
;


reranker before a limit
required_capability: rerank
required_capability: match_operator_colon

FROM books METADATA _score
| WHERE title:"war and peace" AND author:"Tolstoy"
| RERANK "war and peace" ON title WITH test_reranker
| KEEP book_no, title, author, _score
| LIMIT 3
| EVAL _score = ROUND(_score, 5)
;

book_no:keyword | title:text | author:text | _score:double
5327 | War and Peace | Leo Tolstoy | 0.03846
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.02222
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.02083
;


reranker add the _score column when missing
required_capability: rerank
required_capability: match_operator_colon

FROM books
| WHERE title:"war and peace" AND author:"Tolstoy"
| RERANK "war and peace" ON title WITH test_reranker
| KEEP book_no, title, author, _score
| EVAL _score = ROUND(_score, 5)
;

book_no:keyword | title:text | author:text | _score:double
5327 | War and Peace | Leo Tolstoy | 0.03846
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.02222
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.02083
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.01515
;


reranker using another sort order
required_capability: rerank
required_capability: match_operator_colon

FROM books
| WHERE title:"war and peace" AND author:"Tolstoy"
| RERANK "war and peace" ON title WITH test_reranker
| KEEP book_no, title, author, _score
| SORT author, title
| LIMIT 3
| EVAL _score = ROUND(_score, 5)
;

book_no:keyword | title:text | author:text | _score:double
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.02222
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.01515
5327 | War and Peace | Leo Tolstoy | 0.03846
;


reranker after RRF
required_capability: fork
required_capability: rrf
required_capability: match_operator_colon
required_capability: rerank

FROM books METADATA _id, _index, _score
| FORK ( WHERE title:"Tolkien" | SORT _score, _id DESC | LIMIT 3 )
       ( WHERE author:"Tolkien" | SORT _score, _id DESC | LIMIT 3 )
| RRF
| RERANK "Tolkien" ON title WITH test_reranker
| LIMIT 2
| KEEP book_no, title, author, _score
| EVAL _score = ROUND(_score, 5)
;

book_no:keyword | title:keyword | author:keyword | _score:double
5335 | Letters of J R R Tolkien | J.R.R. Tolkien | 0.02632
2130 | The J. R. R. Tolkien Audio Collection | [Christopher Tolkien, John Ronald Reuel Tolkien] | 0.01961
;
@@ -13,49 +13,49 @@ SORT=12
STATS=13
WHERE=14
DEV_INLINESTATS=15
FROM=16
DEV_TIME_SERIES=17
DEV_FORK=18
JOIN_LOOKUP=19
DEV_JOIN_FULL=20
DEV_JOIN_LEFT=21
DEV_JOIN_RIGHT=22
DEV_LOOKUP=23
MV_EXPAND=24
DROP=25
KEEP=26
DEV_INSIST=27
DEV_RRF=28
RENAME=29
SHOW=30
UNKNOWN_CMD=31
CHANGE_POINT_LINE_COMMENT=32
CHANGE_POINT_MULTILINE_COMMENT=33
CHANGE_POINT_WS=34
ON=35
WITH=36
ENRICH_POLICY_NAME=37
ENRICH_LINE_COMMENT=38
ENRICH_MULTILINE_COMMENT=39
ENRICH_WS=40
ENRICH_FIELD_LINE_COMMENT=41
ENRICH_FIELD_MULTILINE_COMMENT=42
ENRICH_FIELD_WS=43
SETTING=44
SETTING_LINE_COMMENT=45
SETTTING_MULTILINE_COMMENT=46
SETTING_WS=47
EXPLAIN_WS=48
EXPLAIN_LINE_COMMENT=49
EXPLAIN_MULTILINE_COMMENT=50
PIPE=51
QUOTED_STRING=52
INTEGER_LITERAL=53
DECIMAL_LITERAL=54
BY=55
AND=56
ASC=57
ASSIGN=58
DEV_RERANK=16
FROM=17
DEV_TIME_SERIES=18
DEV_FORK=19
JOIN_LOOKUP=20
DEV_JOIN_FULL=21
DEV_JOIN_LEFT=22
DEV_JOIN_RIGHT=23
DEV_LOOKUP=24
MV_EXPAND=25
DROP=26
KEEP=27
DEV_INSIST=28
DEV_RRF=29
RENAME=30
SHOW=31
UNKNOWN_CMD=32
CHANGE_POINT_LINE_COMMENT=33
CHANGE_POINT_MULTILINE_COMMENT=34
CHANGE_POINT_WS=35
ENRICH_POLICY_NAME=36
ENRICH_LINE_COMMENT=37
ENRICH_MULTILINE_COMMENT=38
ENRICH_WS=39
ENRICH_FIELD_LINE_COMMENT=40
ENRICH_FIELD_MULTILINE_COMMENT=41
ENRICH_FIELD_WS=42
SETTING=43
SETTING_LINE_COMMENT=44
SETTTING_MULTILINE_COMMENT=45
SETTING_WS=46
EXPLAIN_WS=47
EXPLAIN_LINE_COMMENT=48
EXPLAIN_MULTILINE_COMMENT=49
PIPE=50
QUOTED_STRING=51
INTEGER_LITERAL=52
DECIMAL_LITERAL=53
AND=54
AS=55
ASC=56
ASSIGN=57
BY=58
CAST_OP=59
COLON=60
COMMA=61

@@ -70,70 +70,71 @@ LIKE=69
NOT=70
NULL=71
NULLS=72
OR=73
PARAM=74
RLIKE=75
TRUE=76
EQ=77
CIEQ=78
NEQ=79
LT=80
LTE=81
GT=82
GTE=83
PLUS=84
MINUS=85
ASTERISK=86
SLASH=87
PERCENT=88
LEFT_BRACES=89
RIGHT_BRACES=90
DOUBLE_PARAMS=91
NAMED_OR_POSITIONAL_PARAM=92
NAMED_OR_POSITIONAL_DOUBLE_PARAMS=93
OPENING_BRACKET=94
CLOSING_BRACKET=95
LP=96
RP=97
UNQUOTED_IDENTIFIER=98
QUOTED_IDENTIFIER=99
EXPR_LINE_COMMENT=100
EXPR_MULTILINE_COMMENT=101
EXPR_WS=102
METADATA=103
UNQUOTED_SOURCE=104
FROM_LINE_COMMENT=105
FROM_MULTILINE_COMMENT=106
FROM_WS=107
FORK_WS=108
FORK_LINE_COMMENT=109
FORK_MULTILINE_COMMENT=110
JOIN=111
USING=112
JOIN_LINE_COMMENT=113
JOIN_MULTILINE_COMMENT=114
JOIN_WS=115
LOOKUP_LINE_COMMENT=116
LOOKUP_MULTILINE_COMMENT=117
LOOKUP_WS=118
LOOKUP_FIELD_LINE_COMMENT=119
LOOKUP_FIELD_MULTILINE_COMMENT=120
LOOKUP_FIELD_WS=121
MVEXPAND_LINE_COMMENT=122
MVEXPAND_MULTILINE_COMMENT=123
MVEXPAND_WS=124
ID_PATTERN=125
PROJECT_LINE_COMMENT=126
PROJECT_MULTILINE_COMMENT=127
PROJECT_WS=128
AS=129
RENAME_LINE_COMMENT=130
RENAME_MULTILINE_COMMENT=131
RENAME_WS=132
INFO=133
SHOW_LINE_COMMENT=134
SHOW_MULTILINE_COMMENT=135
SHOW_WS=136
ON=73
OR=74
PARAM=75
RLIKE=76
TRUE=77
WITH=78
EQ=79
CIEQ=80
NEQ=81
LT=82
LTE=83
GT=84
GTE=85
PLUS=86
MINUS=87
ASTERISK=88
SLASH=89
PERCENT=90
LEFT_BRACES=91
RIGHT_BRACES=92
DOUBLE_PARAMS=93
NAMED_OR_POSITIONAL_PARAM=94
NAMED_OR_POSITIONAL_DOUBLE_PARAMS=95
OPENING_BRACKET=96
CLOSING_BRACKET=97
LP=98
RP=99
UNQUOTED_IDENTIFIER=100
QUOTED_IDENTIFIER=101
EXPR_LINE_COMMENT=102
EXPR_MULTILINE_COMMENT=103
EXPR_WS=104
METADATA=105
UNQUOTED_SOURCE=106
FROM_LINE_COMMENT=107
FROM_MULTILINE_COMMENT=108
FROM_WS=109
FORK_WS=110
FORK_LINE_COMMENT=111
FORK_MULTILINE_COMMENT=112
JOIN=113
USING=114
JOIN_LINE_COMMENT=115
JOIN_MULTILINE_COMMENT=116
JOIN_WS=117
LOOKUP_LINE_COMMENT=118
LOOKUP_MULTILINE_COMMENT=119
LOOKUP_WS=120
LOOKUP_FIELD_LINE_COMMENT=121
LOOKUP_FIELD_MULTILINE_COMMENT=122
LOOKUP_FIELD_WS=123
MVEXPAND_LINE_COMMENT=124
MVEXPAND_MULTILINE_COMMENT=125
MVEXPAND_WS=126
ID_PATTERN=127
PROJECT_LINE_COMMENT=128
PROJECT_MULTILINE_COMMENT=129
PROJECT_WS=130
RENAME_LINE_COMMENT=131
RENAME_MULTILINE_COMMENT=132
RENAME_WS=133
INFO=134
SHOW_LINE_COMMENT=135
SHOW_MULTILINE_COMMENT=136
SHOW_WS=137
'enrich'=5
'explain'=6
'dissect'=7

@@ -144,20 +145,19 @@ SHOW_WS=136
'sort'=12
'stats'=13
'where'=14
'from'=16
'lookup'=19
'mv_expand'=24
'drop'=25
'keep'=26
'rename'=29
'show'=30
'on'=35
'with'=36
'|'=51
'by'=55
'and'=56
'asc'=57
'='=58
'from'=17
'lookup'=20
'mv_expand'=25
'drop'=26
'keep'=27
'rename'=30
'show'=31
'|'=50
'and'=54
'as'=55
'asc'=56
'='=57
'by'=58
'::'=59
':'=60
','=61

@@ -172,29 +172,30 @@ SHOW_WS=136
'not'=70
'null'=71
'nulls'=72
'or'=73
'?'=74
'rlike'=75
'true'=76
'=='=77
'=~'=78
'!='=79
'<'=80
'<='=81
'>'=82
'>='=83
'+'=84
'-'=85
'*'=86
'/'=87
'%'=88
'{'=89
'}'=90
'??'=91
']'=95
')'=97
'metadata'=103
'join'=111
'USING'=112
'as'=129
'info'=133
'on'=73
'or'=74
'?'=75
'rlike'=76
'true'=77
'with'=78
'=='=79
'=~'=80
'!='=81
'<'=82
'<='=83
'>'=84
'>='=85
'+'=86
'-'=87
'*'=88
'/'=89
'%'=90
'{'=91
'}'=92
'??'=93
']'=97
')'=99
'metadata'=105
'join'=113
'USING'=114
'info'=134
@@ -61,6 +61,7 @@ processingCommand
    | {this.isDevVersion()}? changePointCommand
    | {this.isDevVersion()}? insistCommand
    | {this.isDevVersion()}? forkCommand
    | {this.isDevVersion()}? rerankCommand
    | {this.isDevVersion()}? rrfCommand
    ;

@@ -288,3 +289,7 @@ forkSubQueryProcessingCommand
rrfCommand
    : DEV_RRF
    ;

rerankCommand
    : DEV_RERANK queryText=constant ON fields WITH inferenceId=identifierOrParameter
    ;
@@ -13,49 +13,49 @@ SORT=12
STATS=13
WHERE=14
DEV_INLINESTATS=15
FROM=16
DEV_TIME_SERIES=17
DEV_FORK=18
JOIN_LOOKUP=19
DEV_JOIN_FULL=20
DEV_JOIN_LEFT=21
DEV_JOIN_RIGHT=22
DEV_LOOKUP=23
MV_EXPAND=24
DROP=25
KEEP=26
DEV_INSIST=27
DEV_RRF=28
RENAME=29
SHOW=30
UNKNOWN_CMD=31
CHANGE_POINT_LINE_COMMENT=32
CHANGE_POINT_MULTILINE_COMMENT=33
CHANGE_POINT_WS=34
ON=35
WITH=36
ENRICH_POLICY_NAME=37
ENRICH_LINE_COMMENT=38
ENRICH_MULTILINE_COMMENT=39
ENRICH_WS=40
ENRICH_FIELD_LINE_COMMENT=41
ENRICH_FIELD_MULTILINE_COMMENT=42
ENRICH_FIELD_WS=43
SETTING=44
SETTING_LINE_COMMENT=45
SETTTING_MULTILINE_COMMENT=46
SETTING_WS=47
EXPLAIN_WS=48
EXPLAIN_LINE_COMMENT=49
EXPLAIN_MULTILINE_COMMENT=50
PIPE=51
QUOTED_STRING=52
INTEGER_LITERAL=53
DECIMAL_LITERAL=54
BY=55
AND=56
ASC=57
ASSIGN=58
DEV_RERANK=16
FROM=17
DEV_TIME_SERIES=18
DEV_FORK=19
JOIN_LOOKUP=20
DEV_JOIN_FULL=21
DEV_JOIN_LEFT=22
DEV_JOIN_RIGHT=23
DEV_LOOKUP=24
MV_EXPAND=25
DROP=26
KEEP=27
DEV_INSIST=28
DEV_RRF=29
RENAME=30
SHOW=31
UNKNOWN_CMD=32
CHANGE_POINT_LINE_COMMENT=33
CHANGE_POINT_MULTILINE_COMMENT=34
CHANGE_POINT_WS=35
ENRICH_POLICY_NAME=36
ENRICH_LINE_COMMENT=37
ENRICH_MULTILINE_COMMENT=38
ENRICH_WS=39
ENRICH_FIELD_LINE_COMMENT=40
ENRICH_FIELD_MULTILINE_COMMENT=41
ENRICH_FIELD_WS=42
SETTING=43
SETTING_LINE_COMMENT=44
SETTTING_MULTILINE_COMMENT=45
SETTING_WS=46
EXPLAIN_WS=47
EXPLAIN_LINE_COMMENT=48
EXPLAIN_MULTILINE_COMMENT=49
PIPE=50
QUOTED_STRING=51
INTEGER_LITERAL=52
DECIMAL_LITERAL=53
AND=54
AS=55
ASC=56
ASSIGN=57
BY=58
CAST_OP=59
COLON=60
COMMA=61

@@ -70,70 +70,71 @@ LIKE=69
NOT=70
NULL=71
NULLS=72
OR=73
PARAM=74
RLIKE=75
TRUE=76
EQ=77
CIEQ=78
NEQ=79
LT=80
LTE=81
GT=82
GTE=83
PLUS=84
MINUS=85
ASTERISK=86
SLASH=87
PERCENT=88
LEFT_BRACES=89
RIGHT_BRACES=90
DOUBLE_PARAMS=91
NAMED_OR_POSITIONAL_PARAM=92
NAMED_OR_POSITIONAL_DOUBLE_PARAMS=93
OPENING_BRACKET=94
CLOSING_BRACKET=95
LP=96
RP=97
UNQUOTED_IDENTIFIER=98
QUOTED_IDENTIFIER=99
EXPR_LINE_COMMENT=100
EXPR_MULTILINE_COMMENT=101
EXPR_WS=102
METADATA=103
UNQUOTED_SOURCE=104
FROM_LINE_COMMENT=105
FROM_MULTILINE_COMMENT=106
FROM_WS=107
FORK_WS=108
FORK_LINE_COMMENT=109
FORK_MULTILINE_COMMENT=110
JOIN=111
USING=112
JOIN_LINE_COMMENT=113
JOIN_MULTILINE_COMMENT=114
JOIN_WS=115
LOOKUP_LINE_COMMENT=116
LOOKUP_MULTILINE_COMMENT=117
LOOKUP_WS=118
LOOKUP_FIELD_LINE_COMMENT=119
LOOKUP_FIELD_MULTILINE_COMMENT=120
LOOKUP_FIELD_WS=121
MVEXPAND_LINE_COMMENT=122
MVEXPAND_MULTILINE_COMMENT=123
MVEXPAND_WS=124
ID_PATTERN=125
PROJECT_LINE_COMMENT=126
PROJECT_MULTILINE_COMMENT=127
PROJECT_WS=128
AS=129
RENAME_LINE_COMMENT=130
RENAME_MULTILINE_COMMENT=131
RENAME_WS=132
INFO=133
SHOW_LINE_COMMENT=134
SHOW_MULTILINE_COMMENT=135
SHOW_WS=136
ON=73
OR=74
PARAM=75
RLIKE=76
TRUE=77
WITH=78
EQ=79
CIEQ=80
NEQ=81
LT=82
LTE=83
GT=84
GTE=85
PLUS=86
MINUS=87
ASTERISK=88
SLASH=89
PERCENT=90
LEFT_BRACES=91
RIGHT_BRACES=92
DOUBLE_PARAMS=93
NAMED_OR_POSITIONAL_PARAM=94
NAMED_OR_POSITIONAL_DOUBLE_PARAMS=95
OPENING_BRACKET=96
CLOSING_BRACKET=97
LP=98
RP=99
UNQUOTED_IDENTIFIER=100
QUOTED_IDENTIFIER=101
EXPR_LINE_COMMENT=102
EXPR_MULTILINE_COMMENT=103
EXPR_WS=104
METADATA=105
UNQUOTED_SOURCE=106
FROM_LINE_COMMENT=107
FROM_MULTILINE_COMMENT=108
FROM_WS=109
FORK_WS=110
FORK_LINE_COMMENT=111
FORK_MULTILINE_COMMENT=112
JOIN=113
USING=114
JOIN_LINE_COMMENT=115
JOIN_MULTILINE_COMMENT=116
JOIN_WS=117
LOOKUP_LINE_COMMENT=118
LOOKUP_MULTILINE_COMMENT=119
LOOKUP_WS=120
LOOKUP_FIELD_LINE_COMMENT=121
LOOKUP_FIELD_MULTILINE_COMMENT=122
LOOKUP_FIELD_WS=123
MVEXPAND_LINE_COMMENT=124
MVEXPAND_MULTILINE_COMMENT=125
MVEXPAND_WS=126
ID_PATTERN=127
PROJECT_LINE_COMMENT=128
PROJECT_MULTILINE_COMMENT=129
PROJECT_WS=130
RENAME_LINE_COMMENT=131
RENAME_MULTILINE_COMMENT=132
RENAME_WS=133
INFO=134
SHOW_LINE_COMMENT=135
SHOW_MULTILINE_COMMENT=136
SHOW_WS=137
'enrich'=5
'explain'=6
'dissect'=7

@@ -144,20 +145,19 @@ SHOW_WS=136
'sort'=12
'stats'=13
'where'=14
'from'=16
'lookup'=19
'mv_expand'=24
'drop'=25
'keep'=26
'rename'=29
'show'=30
'on'=35
'with'=36
'|'=51
'by'=55
'and'=56
'asc'=57
'='=58
'from'=17
'lookup'=20
'mv_expand'=25
'drop'=26
'keep'=27
'rename'=30
'show'=31
'|'=50
'and'=54
'as'=55
'asc'=56
'='=57
'by'=58
'::'=59
':'=60
','=61

@@ -172,29 +172,30 @@ SHOW_WS=136
'not'=70
'null'=71
'nulls'=72
'or'=73
'?'=74
'rlike'=75
'true'=76
'=='=77
'=~'=78
'!='=79
'<'=80
'<='=81
'>'=82
'>='=83
'+'=84
'-'=85
'*'=86
'/'=87
'%'=88
'{'=89
'}'=90
'??'=91
']'=95
')'=97
'metadata'=103
'join'=111
'USING'=112
'as'=129
'info'=133
'on'=73
'or'=74
'?'=75
'rlike'=76
'true'=77
'with'=78
'=='=79
'=~'=80
'!='=81
'<'=82
'<='=83
'>'=84
'>='=85
'+'=86
'-'=87
'*'=88
'/'=89
'%'=90
'{'=91
'}'=92
'??'=93
']'=97
')'=99
'metadata'=105
'join'=113
'USING'=114
'info'=134
@@ -16,8 +16,8 @@ mode ENRICH_MODE;
ENRICH_PIPE : PIPE -> type(PIPE), popMode;
ENRICH_OPENING_BRACKET : OPENING_BRACKET -> type(OPENING_BRACKET), pushMode(SETTING_MODE);

ON : 'on' -> pushMode(ENRICH_FIELD_MODE);
WITH : 'with' -> pushMode(ENRICH_FIELD_MODE);
ENRICH_ON : ON -> type(ON), pushMode(ENRICH_FIELD_MODE);
ENRICH_WITH : WITH -> type(WITH), pushMode(ENRICH_FIELD_MODE);

// similar to that of an index
// see https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-create-index.html#indices-create-api-path-params
@@ -19,6 +19,7 @@ STATS : 'stats' -> pushMode(EXPRESSION_MODE);
WHERE : 'where' -> pushMode(EXPRESSION_MODE);

DEV_INLINESTATS : {this.isDevVersion()}? 'inlinestats' -> pushMode(EXPRESSION_MODE);
DEV_RERANK : {this.isDevVersion()}? 'rerank' -> pushMode(EXPRESSION_MODE);


mode EXPRESSION_MODE;

@@ -82,11 +83,12 @@ DECIMAL_LITERAL
    | DOT DIGIT+ EXPONENT
    ;

BY : 'by';

AND : 'and';
AS: 'as';
ASC : 'asc';
ASSIGN : '=';
BY : 'by';
CAST_OP : '::';
COLON : ':';
COMMA : ',';
@@ -101,10 +103,12 @@ LIKE: 'like';
NOT : 'not';
NULL : 'null';
NULLS : 'nulls';
ON: 'on';
OR : 'or';
PARAM: '?';
RLIKE: 'rlike';
TRUE : 'true';
WITH: 'with';

EQ : '==';
CIEQ : '=~';
@@ -22,7 +22,7 @@ RENAME_NAMED_OR_POSITIONAL_PARAM : NAMED_OR_POSITIONAL_PARAM -> type(NAMED_OR_PO
RENAME_DOUBLE_PARAMS : DOUBLE_PARAMS -> type(DOUBLE_PARAMS);
RENAME_NAMED_OR_POSITIONAL_DOUBLE_PARAMS : NAMED_OR_POSITIONAL_DOUBLE_PARAMS -> type(NAMED_OR_POSITIONAL_DOUBLE_PARAMS);

AS : 'as';
RENAME_AS : AS -> type(AS);

RENAME_ID_PATTERN
    : ID_PATTERN -> type(ID_PATTERN)
@@ -873,6 +873,11 @@ public class EsqlCapabilities {
         */
        FORK(Build.current().isSnapshot()),

        /**
         * Support for RERANK command
         */
        RERANK(Build.current().isSnapshot()),

        /**
         * Allow mixed numeric types in conditional functions - case, greatest and least
         */
@@ -33,7 +33,7 @@ import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.nanoTimeTo
import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.spatialToString;
import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.versionToString;

abstract class PositionToXContent {
public abstract class PositionToXContent {
    protected final Block block;

    PositionToXContent(Block block) {
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.analysis;

import org.elasticsearch.common.logging.HeaderWarning;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.core.Strings;
import org.elasticsearch.index.IndexMode;

@@ -27,6 +28,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;

@@ -67,6 +69,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Esq
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.index.IndexResolution;
import org.elasticsearch.xpack.esql.inference.ResolvedInference;
import org.elasticsearch.xpack.esql.parser.ParsingException;
import org.elasticsearch.xpack.esql.plan.IndexPattern;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;

@@ -86,6 +89,8 @@ import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.Rename;
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinType;

@@ -122,7 +127,6 @@ import java.util.stream.Collectors;

import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.GEO_MATCH_TYPE;
import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME;

@@ -154,7 +158,15 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
    );

    private static final List<Batch<LogicalPlan>> RULES = List.of(
        new Batch<>("Initialize", Limiter.ONCE, new ResolveTable(), new ResolveEnrich(), new ResolveLookupTables(), new ResolveFunctions()),
        new Batch<>(
            "Initialize",
            Limiter.ONCE,
            new ResolveTable(),
            new ResolveEnrich(),
            new ResolveInference(),
            new ResolveLookupTables(),
            new ResolveFunctions()
        ),
        new Batch<>(
            "Resolution",
            /*
@@ -380,6 +392,34 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
        }
    }

    private static class ResolveInference extends ParameterizedAnalyzerRule<InferencePlan, AnalyzerContext> {
        @Override
        protected LogicalPlan rule(InferencePlan plan, AnalyzerContext context) {
            assert plan.inferenceId().resolved() && plan.inferenceId().foldable();

            String inferenceId = plan.inferenceId().fold(FoldContext.small()).toString();
            ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);

            if (resolvedInference != null && resolvedInference.taskType() == plan.taskType()) {
                return plan;
            } else if (resolvedInference != null) {
                String error = "cannot use inference endpoint ["
                    + inferenceId
                    + "] with task type ["
                    + resolvedInference.taskType()
                    + "] within a "
                    + plan.nodeName()
                    + " command. Only inference endpoints with the task type ["
                    + plan.taskType()
                    + "] are supported.";
                return plan.withInferenceResolutionError(inferenceId, error);
            } else {
                String error = context.inferenceResolution().getError(inferenceId);
                return plan.withInferenceResolutionError(inferenceId, error);
            }
        }
    }
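As a worked illustration of the message concatenation above (an editor's sketch, not part of the diff; "my_embedder" is a hypothetical endpoint id, and the exact task-type rendering depends on TaskType.toString()):

    // For a Rerank plan referencing an endpoint resolved with a non-rerank task type,
    // the mismatch branch above produces an error along these lines:
    //   cannot use inference endpoint [my_embedder] with task type [sparse_embedding] within a
    //   Rerank command. Only inference endpoints with the task type [rerank] are supported.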
|
    private static class ResolveLookupTables extends ParameterizedAnalyzerRule<Lookup, AnalyzerContext> {

        @Override
@@ -491,6 +531,10 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
            return resolveRrfScoreEval(rrf, childrenOutput);
        }

        if (plan instanceof Rerank r) {
            return resolveRerank(r, childrenOutput);
        }

        return plan.transformExpressionsOnly(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
    }
@@ -670,6 +714,33 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
            return join;
        }

        private LogicalPlan resolveRerank(Rerank rerank, List<Attribute> childrenOutput) {
            List<Alias> newFields = new ArrayList<>();
            boolean changed = false;

            // First resolving fields used in expression
            for (Alias field : rerank.rerankFields()) {
                Alias result = (Alias) field.transformUp(UnresolvedAttribute.class, ua -> resolveAttribute(ua, childrenOutput));
                newFields.add(result);
                changed |= result != field;
            }

            if (changed) {
                rerank = rerank.withRerankFields(newFields);
            }

            // Ensure the score attribute is present in the output.
            if (rerank.scoreAttribute() instanceof UnresolvedAttribute ua) {
                Attribute resolved = resolveAttribute(ua, childrenOutput);
                if (resolved.resolved() == false || resolved.dataType() != DOUBLE) {
                    resolved = MetadataAttribute.create(Source.EMPTY, MetadataAttribute.SCORE);
                }
                rerank = rerank.withScoreAttribute(resolved);
            }

            return rerank;
        }

        private List<Attribute> resolveUsingColumns(List<Attribute> cols, List<Attribute> output, String side) {
            List<Attribute> resolved = new ArrayList<>(cols.size());
            for (Attribute col : cols) {
@@ -987,7 +1058,7 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
                var u = resolved;
                var previousAliasName = reverseAliasing.get(resolved.name());
                if (previousAliasName != null) {
                    String message = format(
                    String message = LoggerMessageFormat.format(
                        null,
                        "Column [{}] renamed to [{}] and is no longer available [{}]",
                        resolved.name(),

@@ -1380,7 +1451,7 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
    }

    private static UnresolvedAttribute unresolvedAttribute(Expression value, String type, Exception e) {
        String message = format(
        String message = LoggerMessageFormat.format(
            "Cannot convert string [{}] to [{}], error [{}]",
            value.fold(FoldContext.small() /* TODO remove me */),
            type,
@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.esql.analysis;

import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.IndexResolution;
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
import org.elasticsearch.xpack.esql.session.Configuration;

import java.util.Map;

@@ -18,7 +19,8 @@ public record AnalyzerContext(
    EsqlFunctionRegistry functionRegistry,
    IndexResolution indexResolution,
    Map<String, IndexResolution> lookupResolution,
    EnrichResolution enrichResolution
    EnrichResolution enrichResolution,
    InferenceResolution inferenceResolution
) {
    // Currently for tests only, since most do not test lookups
    // TODO: make this even simpler, remove the enrichResolution for tests that do not require it (most tests)

@@ -26,8 +28,9 @@ public record AnalyzerContext(
        Configuration configuration,
        EsqlFunctionRegistry functionRegistry,
        IndexResolution indexResolution,
        EnrichResolution enrichResolution
        EnrichResolution enrichResolution,
        InferenceResolution inferenceResolution
    ) {
        this(configuration, functionRegistry, indexResolution, Map.of(), enrichResolution);
        this(configuration, functionRegistry, indexResolution, Map.of(), enrichResolution, inferenceResolution);
    }
}
@@ -12,6 +12,7 @@ import org.elasticsearch.xpack.esql.plan.IndexPattern;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;

import java.util.ArrayList;
import java.util.HashSet;

@@ -26,15 +27,22 @@ import static java.util.Collections.emptyList;
public class PreAnalyzer {

    public static class PreAnalysis {
        public static final PreAnalysis EMPTY = new PreAnalysis(emptyList(), emptyList(), emptyList());
        public static final PreAnalysis EMPTY = new PreAnalysis(emptyList(), emptyList(), emptyList(), emptyList());

        public final List<IndexPattern> indices;
        public final List<Enrich> enriches;
        public final List<InferencePlan> inferencePlans;
        public final List<IndexPattern> lookupIndices;

        public PreAnalysis(List<IndexPattern> indices, List<Enrich> enriches, List<IndexPattern> lookupIndices) {
        public PreAnalysis(
            List<IndexPattern> indices,
            List<Enrich> enriches,
            List<InferencePlan> inferencePlans,
            List<IndexPattern> lookupIndices
        ) {
            this.indices = indices;
            this.enriches = enriches;
            this.inferencePlans = inferencePlans;
            this.lookupIndices = lookupIndices;
        }
    }

@@ -52,13 +60,15 @@ public class PreAnalyzer {

        List<Enrich> unresolvedEnriches = new ArrayList<>();
        List<IndexPattern> lookupIndices = new ArrayList<>();
        List<InferencePlan> unresolvedInferencePlans = new ArrayList<>();

        plan.forEachUp(UnresolvedRelation.class, p -> (p.indexMode() == IndexMode.LOOKUP ? lookupIndices : indices).add(p.indexPattern()));
        plan.forEachUp(Enrich.class, unresolvedEnriches::add);
        plan.forEachUp(InferencePlan.class, unresolvedInferencePlans::add);

        // mark plan as preAnalyzed (if it were marked, there would be no analysis)
        plan.forEachUp(LogicalPlan::setPreAnalyzed);

        return new PreAnalysis(indices.stream().toList(), unresolvedEnriches, lookupIndices);
        return new PreAnalysis(indices.stream().toList(), unresolvedEnriches, unresolvedInferencePlans, lookupIndices);
    }
}
@@ -0,0 +1,77 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.inference;

import org.elasticsearch.common.util.concurrent.ConcurrentCollections;

import java.util.Collection;
import java.util.Collections;
import java.util.Map;

public class InferenceResolution {

    public static final InferenceResolution EMPTY = new InferenceResolution.Builder().build();

    public static InferenceResolution.Builder builder() {
        return new Builder();
    }

    private final Map<String, ResolvedInference> resolvedInferences;
    private final Map<String, String> errors;

    private InferenceResolution(Map<String, ResolvedInference> resolvedInferences, Map<String, String> errors) {
        this.resolvedInferences = Collections.unmodifiableMap(resolvedInferences);
        this.errors = Collections.unmodifiableMap(errors);
    }

    public ResolvedInference getResolvedInference(String inferenceId) {
        return resolvedInferences.get(inferenceId);
    }

    public Collection<ResolvedInference> resolvedInferences() {
        return resolvedInferences.values();
    }

    public boolean hasError() {
        return errors.isEmpty() == false;
    }

    public String getError(String inferenceId) {
        final String error = errors.get(inferenceId);
        if (error != null) {
            return error;
        } else {
            return "unresolved inference [" + inferenceId + "]";
        }
    }

    public static class Builder {

        private final Map<String, ResolvedInference> resolvedInferences;
        private final Map<String, String> errors;

        private Builder() {
            this.resolvedInferences = ConcurrentCollections.newConcurrentMap();
            this.errors = ConcurrentCollections.newConcurrentMap();
        }

        public Builder withResolvedInference(ResolvedInference resolvedInference) {
            resolvedInferences.putIfAbsent(resolvedInference.inferenceId(), resolvedInference);
            return this;
        }

        public Builder withError(String inferenceId, String reason) {
            errors.putIfAbsent(inferenceId, reason);
            return this;
        }

        public InferenceResolution build() {
            return new InferenceResolution(resolvedInferences, errors);
        }
    }
}
@@ -0,0 +1,78 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.inference;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class InferenceRunner {

    private final Client client;

    public InferenceRunner(Client client) {
        this.client = client;
    }

    public ThreadContext getThreadContext() {
        return client.threadPool().getThreadContext();
    }

    public void resolveInferenceIds(List<InferencePlan> plans, ActionListener<InferenceResolution> listener) {
        resolveInferenceIds(plans.stream().map(InferenceRunner::planInferenceId).collect(Collectors.toSet()), listener);

    }

    private void resolveInferenceIds(Set<String> inferenceIds, ActionListener<InferenceResolution> listener) {

        if (inferenceIds.isEmpty()) {
            listener.onResponse(InferenceResolution.EMPTY);
            return;
        }

        final InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder();

        final CountDownActionListener countdownListener = new CountDownActionListener(
            inferenceIds.size(),
            ActionListener.wrap(_r -> listener.onResponse(inferenceResolutionBuilder.build()), listener::onFailure)
        );

        for (var inferenceId : inferenceIds) {
            client.execute(
                GetInferenceModelAction.INSTANCE,
                new GetInferenceModelAction.Request(inferenceId, TaskType.ANY),
                ActionListener.wrap(r -> {
                    ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType());
                    inferenceResolutionBuilder.withResolvedInference(resolvedInference);
                    countdownListener.onResponse(null);
                }, e -> {
                    inferenceResolutionBuilder.withError(inferenceId, e.getMessage());
                    countdownListener.onResponse(null);
                })
            );
        }
    }

    private static String planInferenceId(InferencePlan plan) {
        return plan.inferenceId().fold(FoldContext.small()).toString();
    }

    public void doInference(InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
        client.execute(InferenceAction.INSTANCE, request, listener);
    }
}
@@ -0,0 +1,198 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.inference;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AsyncOperator;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;

import java.util.List;

public class RerankOperator extends AsyncOperator<Page> {

    // Move to a setting.
    private static final int MAX_INFERENCE_WORKER = 10;

    public record Factory(
        InferenceRunner inferenceRunner,
        String inferenceId,
        String queryText,
        ExpressionEvaluator.Factory rowEncoderFactory,
        int scoreChannel
    ) implements OperatorFactory {

        @Override
        public String describe() {
            return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]";
        }

        @Override
        public Operator get(DriverContext driverContext) {
            return new RerankOperator(
                driverContext,
                inferenceRunner,
                inferenceId,
                queryText,
                rowEncoderFactory().get(driverContext),
                scoreChannel
            );
        }
    }

    private final InferenceRunner inferenceRunner;
    private final BlockFactory blockFactory;
    private final String inferenceId;
    private final String queryText;
    private final ExpressionEvaluator rowEncoder;
    private final int scoreChannel;

    public RerankOperator(
        DriverContext driverContext,
        InferenceRunner inferenceRunner,
        String inferenceId,
        String queryText,
        ExpressionEvaluator rowEncoder,
        int scoreChannel
    ) {
        super(driverContext, inferenceRunner.getThreadContext(), MAX_INFERENCE_WORKER);

        assert inferenceRunner.getThreadContext() != null;

        this.blockFactory = driverContext.blockFactory();
        this.inferenceRunner = inferenceRunner;
        this.inferenceId = inferenceId;
        this.queryText = queryText;
        this.rowEncoder = rowEncoder;
        this.scoreChannel = scoreChannel;
    }

    @Override
    protected void performAsync(Page inputPage, ActionListener<Page> listener) {
        // Ensure input page blocks are released when the listener is called.
        final ActionListener<Page> outputListener = ActionListener.runAfter(listener, () -> { releasePageOnAnyThread(inputPage); });

        try {
            inferenceRunner.doInference(
                buildInferenceRequest(inputPage),
                ActionListener.wrap(
                    inferenceResponse -> outputListener.onResponse(buildOutput(inputPage, inferenceResponse)),
                    outputListener::onFailure
                )
            );
        } catch (Exception e) {
            outputListener.onFailure(e);
        }
    }

    @Override
    protected void doClose() {
        Releasables.closeExpectNoException(rowEncoder);
    }

    @Override
    protected void releaseFetchedOnAnyThread(Page page) {
        releasePageOnAnyThread(page);
    }

    @Override
    public Page getOutput() {
        return fetchFromBuffer();
    }

    @Override
    public String toString() {
        return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]";
    }

    private Page buildOutput(Page inputPage, InferenceAction.Response inferenceResponse) {
        if (inferenceResponse.getResults() instanceof RankedDocsResults rankedDocsResults) {
            return buildOutput(inputPage, rankedDocsResults);

        }

        throw new IllegalStateException(
            "Inference result has wrong type. Got ["
                + inferenceResponse.getResults().getClass()
                + "] while expecting ["
                + RankedDocsResults.class
                + "]"
        );
    }

    private Page buildOutput(Page inputPage, RankedDocsResults rankedDocsResults) {
        int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1);
        Block[] blocks = new Block[blockCount];

        try {
            for (int b = 0; b < blockCount; b++) {
                if (b == scoreChannel) {
                    blocks[b] = buildScoreBlock(inputPage, rankedDocsResults);
                } else {
                    blocks[b] = inputPage.getBlock(b);
                    blocks[b].incRef();
                }
            }
            return new Page(blocks);
        } catch (Exception e) {
            Releasables.closeExpectNoException(blocks);
            throw (e);
        }
    }

    private Block buildScoreBlock(Page inputPage, RankedDocsResults rankedDocsResults) {
        Double[] sortedRankedDocsScores = new Double[inputPage.getPositionCount()];

        try (DoubleBlock.Builder scoreBlockFactory = blockFactory.newDoubleBlockBuilder(inputPage.getPositionCount())) {
            for (RankedDocsResults.RankedDoc rankedDoc : rankedDocsResults.getRankedDocs()) {
                sortedRankedDocsScores[rankedDoc.index()] = (double) rankedDoc.relevanceScore();
            }

            for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
                if (sortedRankedDocsScores[pos] != null) {
                    scoreBlockFactory.appendDouble(sortedRankedDocsScores[pos]);
                } else {
                    scoreBlockFactory.appendNull();
                }
            }

            return scoreBlockFactory.build();
        }
    }

    private InferenceAction.Request buildInferenceRequest(Page inputPage) {
        try (BytesRefBlock encodedRowsBlock = (BytesRefBlock) rowEncoder.eval(inputPage)) {
            assert (encodedRowsBlock.getPositionCount() == inputPage.getPositionCount());
            String[] inputs = new String[inputPage.getPositionCount()];
            BytesRef buffer = new BytesRef();

            for (int pos = 0; pos < inputPage.getPositionCount(); pos++) {
                if (encodedRowsBlock.isNull(pos)) {
                    inputs[pos] = "";
                } else {
                    buffer = encodedRowsBlock.getBytesRef(encodedRowsBlock.getFirstValueIndex(pos), buffer);
                    inputs[pos] = BytesRefs.toString(buffer);
                }
            }

            return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(List.of(inputs)).setQuery(queryText).build();
        }
    }
}
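Worth noting in the operator above: the reranker's response tags each score with RankedDocsResults.RankedDoc.index(), which is the position of the corresponding row in the request's input list, so buildScoreBlock() can write every score back to its original position regardless of the ranked order the endpoint returns. Positions absent from the response simply get a null score rather than failing the page.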
@@ -0,0 +1,28 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.inference;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.TaskType;

import java.io.IOException;

public record ResolvedInference(String inferenceId, TaskType taskType) implements Writeable {

    public ResolvedInference(StreamInput in) throws IOException {
        this(in.readString(), TaskType.valueOf(in.readString()));
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(inferenceId);
        out.writeString(taskType.name());
    }
}
@@ -0,0 +1,145 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.inference;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.BytesRefStreamOutput;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.esql.action.ColumnInfoImpl;
import org.elasticsearch.xpack.esql.action.PositionToXContent;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Encodes rows into an XContent format (JSON,YAML,...) for further processing.
 * Extracted columns can be specified using {@link ExpressionEvaluator}
 */
public class XContentRowEncoder implements ExpressionEvaluator {
    private final XContentType xContentType;
    private final BlockFactory blockFactory;
    private final ColumnInfoImpl[] columnsInfo;
    private final ExpressionEvaluator[] fieldsValueEvaluators;

    /**
     * Creates a factory for YAML XContent row encoding.
     *
     * @param fieldsEvaluatorFactories A map of column information to expression evaluators.
     * @return A Factory instance for creating YAML row encoder for the specified column.
     */
    public static Factory yamlRowEncoderFactory(Map<ColumnInfoImpl, ExpressionEvaluator.Factory> fieldsEvaluatorFactories) {
        return new Factory(XContentType.YAML, fieldsEvaluatorFactories);
    }

    private XContentRowEncoder(
        XContentType xContentType,
        BlockFactory blockFactory,
        ColumnInfoImpl[] columnsInfo,
        ExpressionEvaluator[] fieldsValueEvaluators
    ) {
        assert columnsInfo.length == fieldsValueEvaluators.length;
        this.xContentType = xContentType;
        this.blockFactory = blockFactory;
        this.columnsInfo = columnsInfo;
        this.fieldsValueEvaluators = fieldsValueEvaluators;
    }

    @Override
    public void close() {
        Releasables.closeExpectNoException(fieldsValueEvaluators);
    }

    /**
     * Process the provided Page and encode its rows into a BytesRefBlock containing XContent-formatted rows.
     *
     * @param page The input Page containing row data.
     * @return A BytesRefBlock containing the encoded rows.
     */
    @Override
    public BytesRefBlock eval(Page page) {
        Block[] fieldValueBlocks = new Block[fieldsValueEvaluators.length];
        try (
            BytesRefStreamOutput outputStream = new BytesRefStreamOutput();
            XContentBuilder xContentBuilder = XContentFactory.contentBuilder(xContentType, outputStream);
            BytesRefBlock.Builder outputBlockBuilder = blockFactory.newBytesRefBlockBuilder(page.getPositionCount());
        ) {

            PositionToXContent[] toXContents = new PositionToXContent[fieldsValueEvaluators.length];
            for (int b = 0; b < fieldValueBlocks.length; b++) {
                fieldValueBlocks[b] = fieldsValueEvaluators[b].eval(page);
                toXContents[b] = PositionToXContent.positionToXContent(columnsInfo[b], fieldValueBlocks[b], new BytesRef());
            }

            for (int pos = 0; pos < page.getPositionCount(); pos++) {
                xContentBuilder.startObject();
                for (int i = 0; i < fieldValueBlocks.length; i++) {
                    String fieldName = columnsInfo[i].name();
                    Block currentBlock = fieldValueBlocks[i];
                    if (currentBlock.isNull(pos) || currentBlock.getValueCount(pos) < 1) {
                        continue;
                    }
                    toXContents[i].positionToXContent(xContentBuilder.field(fieldName), ToXContent.EMPTY_PARAMS, pos);
                }
                xContentBuilder.endObject().flush();
                outputBlockBuilder.appendBytesRef(outputStream.get());
                outputStream.reset();
            }

            return outputBlockBuilder.build();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } finally {
            Releasables.closeExpectNoException(fieldValueBlocks);
        }
    }

    public List<String> fieldNames() {
        return Arrays.stream(columnsInfo).map(ColumnInfoImpl::name).collect(Collectors.toList());
    }

    @Override
    public String toString() {
        return "XContentRowEncoder[content_type=[" + xContentType.toString() + "], field_names=" + fieldNames() + "]";
    }

    public static class Factory implements ExpressionEvaluator.Factory {
        private final XContentType xContentType;
        private final Map<ColumnInfoImpl, ExpressionEvaluator.Factory> fieldsEvaluatorFactories;

        private Factory(XContentType xContentType, Map<ColumnInfoImpl, ExpressionEvaluator.Factory> fieldsEvaluatorFactories) {
            this.xContentType = xContentType;
            this.fieldsEvaluatorFactories = fieldsEvaluatorFactories;
        }

        public XContentRowEncoder get(DriverContext context) {
            return new XContentRowEncoder(xContentType, context.blockFactory(), columnsInfo(), fieldsValueEvaluators(context));
        }

        private ColumnInfoImpl[] columnsInfo() {
            return fieldsEvaluatorFactories.keySet().toArray(ColumnInfoImpl[]::new);
        }

        private ExpressionEvaluator[] fieldsValueEvaluators(DriverContext context) {
            return fieldsEvaluatorFactories.values().stream().map(factory -> factory.get(context)).toArray(ExpressionEvaluator[]::new);
        }
    }
}
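For illustration (the field names here are hypothetical, and the exact rendering depends on how XContentBuilder serializes YAML): with rerank fields `title` and `author`, each position in the page becomes one small YAML document along the lines of `title: "..."` followed by `author: "..."`, and that string is the single inference input for the row. Null or empty positions are skipped entirely, per the isNull/getValueCount check above, so missing fields never appear as explicit nulls in the encoded document.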
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
@@ -740,6 +740,18 @@ public class EsqlBaseParserBaseListener implements EsqlBaseParserListener {
 	 * <p>The default implementation does nothing.</p>
 	 */
 	@Override public void exitRrfCommand(EsqlBaseParser.RrfCommandContext ctx) { }
+	/**
+	 * {@inheritDoc}
+	 *
+	 * <p>The default implementation does nothing.</p>
+	 */
+	@Override public void enterRerankCommand(EsqlBaseParser.RerankCommandContext ctx) { }
+	/**
+	 * {@inheritDoc}
+	 *
+	 * <p>The default implementation does nothing.</p>
+	 */
+	@Override public void exitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) { }
 	/**
 	 * {@inheritDoc}
 	 *
@@ -440,6 +440,13 @@ public class EsqlBaseParserBaseVisitor<T> extends AbstractParseTreeVisitor<T> im
 	 * {@link #visitChildren} on {@code ctx}.</p>
 	 */
 	@Override public T visitRrfCommand(EsqlBaseParser.RrfCommandContext ctx) { return visitChildren(ctx); }
+	/**
+	 * {@inheritDoc}
+	 *
+	 * <p>The default implementation returns the result of calling
+	 * {@link #visitChildren} on {@code ctx}.</p>
+	 */
+	@Override public T visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) { return visitChildren(ctx); }
 	/**
 	 * {@inheritDoc}
 	 *
@@ -635,6 +635,16 @@ public interface EsqlBaseParserListener extends ParseTreeListener {
 	 * @param ctx the parse tree
 	 */
 	void exitRrfCommand(EsqlBaseParser.RrfCommandContext ctx);
+	/**
+	 * Enter a parse tree produced by {@link EsqlBaseParser#rerankCommand}.
+	 * @param ctx the parse tree
+	 */
+	void enterRerankCommand(EsqlBaseParser.RerankCommandContext ctx);
+	/**
+	 * Exit a parse tree produced by {@link EsqlBaseParser#rerankCommand}.
+	 * @param ctx the parse tree
+	 */
+	void exitRerankCommand(EsqlBaseParser.RerankCommandContext ctx);
 	/**
 	 * Enter a parse tree produced by the {@code matchExpression}
 	 * labeled alternative in {@link EsqlBaseParser#booleanExpression}.
@@ -388,6 +388,12 @@ public interface EsqlBaseParserVisitor<T> extends ParseTreeVisitor<T> {
 	 * @return the visitor result
 	 */
 	T visitRrfCommand(EsqlBaseParser.RrfCommandContext ctx);
+	/**
+	 * Visit a parse tree produced by {@link EsqlBaseParser#rerankCommand}.
+	 * @param ctx the parse tree
+	 * @return the visitor result
+	 */
+	T visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx);
 	/**
 	 * Visit a parse tree produced by the {@code matchExpression}
 	 * labeled alternative in {@link EsqlBaseParser#booleanExpression}.
@@ -68,6 +68,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Row;
 import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
 import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
 import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
 import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;
 import org.elasticsearch.xpack.esql.plan.logical.show.ShowInfo;
 import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
@@ -707,4 +708,56 @@ public class LogicalPlanBuilder extends ExpressionBuilder {
             return new OrderBy(source, dedup, order);
         };
     }
+
+    @Override
+    public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
+        var source = source(ctx);
+
+        if (false == EsqlCapabilities.Cap.RERANK.isEnabled()) {
+            throw new ParsingException(source, "RERANK is in preview and only available in SNAPSHOT build");
+        }
+
+        Expression queryText = expression(ctx.queryText);
+        if (queryText instanceof Literal queryTextLiteral && DataType.isString(queryText.dataType())) {
+            if (queryTextLiteral.value() == null) {
+                throw new ParsingException(
+                    source(ctx.queryText),
+                    "Query text cannot be null or undefined in RERANK",
+                    ctx.queryText.getText()
+                );
+            }
+        } else {
+            throw new ParsingException(
+                source(ctx.queryText),
+                "RERANK only support string as query text but [{}] cannot be used as string",
+                ctx.queryText.getText()
+            );
+        }
+
+        return p -> new Rerank(source, p, inferenceId(ctx.inferenceId), queryText, visitFields(ctx.fields()));
+    }
+
+    public Literal inferenceId(EsqlBaseParser.IdentifierOrParameterContext ctx) {
+        if (ctx.identifier() != null) {
+            return new Literal(source(ctx), visitIdentifier(ctx.identifier()), KEYWORD);
+        }
+
+        if (expression(ctx.parameter()) instanceof Literal literalParam) {
+            if (literalParam.value() != null) {
+                return literalParam;
+            }
+
+            throw new ParsingException(
+                source(ctx.parameter()),
+                "Query parameter [{}] is null or undefined and cannot be used as inference id",
+                ctx.parameter().getText()
+            );
+        }
+
+        throw new ParsingException(
+            source(ctx.parameter()),
+            "Query parameter [{}] is not a string and cannot be used as inference id",
+            ctx.parameter().getText()
+        );
+    }
 }
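Taken together with the grammar changes (the generated parser diffs are suppressed above), visitRerankCommand accepts a command of roughly the following shape. This is a minimal sketch rather than output from the PR: the index, field names, and the `test_reranker` endpoint id are hypothetical, and the command only parses on snapshot builds where EsqlCapabilities.Cap.RERANK is enabled:

    FROM books
    | RERANK "italian food recipe" ON title WITH test_reranker
    | KEEP title, _score
    | LIMIT 10

Per the checks above, the query text must be a non-null string literal, and the inference id must be an identifier or a non-null string parameter; anything else raises a ParsingException.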
@@ -23,6 +23,7 @@ import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
 import org.elasticsearch.xpack.esql.plan.logical.Project;
 import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
 import org.elasticsearch.xpack.esql.plan.logical.TopN;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
 import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
 import org.elasticsearch.xpack.esql.plan.logical.join.Join;
 import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;

@@ -49,6 +50,7 @@ import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
 import org.elasticsearch.xpack.esql.plan.physical.SubqueryExec;
 import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec;
 import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
+import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;

 import java.util.ArrayList;
 import java.util.List;

@@ -81,6 +83,7 @@ public class PlanWritables {
         MvExpand.ENTRY,
         OrderBy.ENTRY,
         Project.ENTRY,
+        Rerank.ENTRY,
         TimeSeriesAggregate.ENTRY,
         TopN.ENTRY
     );

@@ -106,6 +109,7 @@ public class PlanWritables {
         LocalSourceExec.ENTRY,
         MvExpandExec.ENTRY,
         ProjectExec.ENTRY,
+        RerankExec.ENTRY,
         ShowExec.ENTRY,
         SubqueryExec.ENTRY,
         TimeSeriesAggregateExec.ENTRY,
@@ -0,0 +1,67 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.plan.logical.inference;

import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;

import java.io.IOException;
import java.util.Objects;

public abstract class InferencePlan extends UnaryPlan {

    private final Expression inferenceId;

    protected InferencePlan(Source source, LogicalPlan child, Expression inferenceId) {
        super(source, child);
        this.inferenceId = inferenceId;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        Source.EMPTY.writeTo(out);
        out.writeNamedWriteable(child());
        out.writeNamedWriteable(inferenceId());
    }

    public Expression inferenceId() {
        return inferenceId;
    }

    @Override
    public boolean expressionsResolved() {
        return inferenceId.resolved();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        if (super.equals(o) == false) return false;
        InferencePlan other = (InferencePlan) o;
        return Objects.equals(inferenceId(), other.inferenceId());
    }

    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), inferenceId());
    }

    public abstract TaskType taskType();

    public abstract InferencePlan withInferenceId(Expression newInferenceId);

    public InferencePlan withInferenceResolutionError(String inferenceId, String error) {
        return withInferenceId(new UnresolvedAttribute(inferenceId().source(), inferenceId, error));
    }
}
@@ -0,0 +1,190 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.plan.logical.inference;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.core.capabilities.Resolvables;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.Order;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.plan.QueryPlan;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.SortAgnostic;
import org.elasticsearch.xpack.esql.plan.logical.SurrogateLogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.esql.core.expression.Expressions.asAttributes;
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;

public class Rerank extends InferencePlan implements SortAgnostic, SurrogateLogicalPlan {

    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Rerank", Rerank::new);
    private final Attribute scoreAttribute;
    private final Expression queryText;
    private final List<Alias> rerankFields;
    private List<Attribute> lazyOutput;

    public Rerank(Source source, LogicalPlan child, Expression inferenceId, Expression queryText, List<Alias> rerankFields) {
        super(source, child, inferenceId);
        this.queryText = queryText;
        this.rerankFields = rerankFields;
        this.scoreAttribute = new UnresolvedAttribute(source, MetadataAttribute.SCORE);
    }

    public Rerank(
        Source source,
        LogicalPlan child,
        Expression inferenceId,
        Expression queryText,
        List<Alias> rerankFields,
        Attribute scoreAttribute
    ) {
        super(source, child, inferenceId);
        this.queryText = queryText;
        this.rerankFields = rerankFields;
        this.scoreAttribute = scoreAttribute;
    }

    public Rerank(StreamInput in) throws IOException {
        this(
            Source.readFrom((PlanStreamInput) in),
            in.readNamedWriteable(LogicalPlan.class),
            in.readNamedWriteable(Expression.class),
            in.readNamedWriteable(Expression.class),
            in.readCollectionAsList(Alias::new),
            in.readNamedWriteable(Attribute.class)
        );
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeNamedWriteable(queryText);
        out.writeCollection(rerankFields());
        out.writeNamedWriteable(scoreAttribute);
    }

    public Expression queryText() {
        return queryText;
    }

    public List<Alias> rerankFields() {
        return rerankFields;
    }

    public Attribute scoreAttribute() {
        return scoreAttribute;
    }

    @Override
    public TaskType taskType() {
        return TaskType.RERANK;
    }

    @Override
    public Rerank withInferenceId(Expression newInferenceId) {
        return new Rerank(source(), child(), newInferenceId, queryText, rerankFields, scoreAttribute);
    }

    public Rerank withRerankFields(List<Alias> newRerankFields) {
        return new Rerank(source(), child(), inferenceId(), queryText, newRerankFields, scoreAttribute);
    }

    public Rerank withScoreAttribute(Attribute newScoreAttribute) {
        return new Rerank(source(), child(), inferenceId(), queryText, rerankFields, newScoreAttribute);
    }

    @Override
    public String getWriteableName() {
        return ENTRY.name;
    }

    @Override
    public UnaryPlan replaceChild(LogicalPlan newChild) {
        return new Rerank(source(), newChild, inferenceId(), queryText, rerankFields, scoreAttribute);
    }

    @Override
    protected AttributeSet computeReferences() {
        AttributeSet.Builder refs = computeReferences(rerankFields).asBuilder();

        if (planHasAttribute(child(), scoreAttribute)) {
            refs.add(scoreAttribute);
        }

        return refs.build();
    }

    public static AttributeSet computeReferences(List<Alias> fields) {
        AttributeSet rerankFields = AttributeSet.of(asAttributes(fields));
        return Expressions.references(fields).subtract(rerankFields);
    }

    @Override
    public boolean expressionsResolved() {
        return super.expressionsResolved() && queryText.resolved() && Resolvables.resolved(rerankFields) && scoreAttribute.resolved();
    }

    @Override
    protected NodeInfo<? extends LogicalPlan> info() {
        return NodeInfo.create(this, Rerank::new, child(), inferenceId(), queryText, rerankFields, scoreAttribute);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        if (super.equals(o) == false) return false;
        Rerank rerank = (Rerank) o;
        return Objects.equals(queryText, rerank.queryText)
            && Objects.equals(rerankFields, rerank.rerankFields)
            && Objects.equals(scoreAttribute, rerank.scoreAttribute);
    }

    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), queryText, rerankFields, scoreAttribute);
    }

    @Override
    public LogicalPlan surrogate() {
        Order sortOrder = new Order(source(), scoreAttribute, Order.OrderDirection.DESC, Order.NullsPosition.ANY);
        return new OrderBy(source(), this, List.of(sortOrder));
    }

    @Override
    public List<Attribute> output() {
        if (lazyOutput == null) {
            lazyOutput = planHasAttribute(child(), scoreAttribute)
                ? child().output()
                : mergeOutputAttributes(List.of(scoreAttribute), child().output());
        }

        return lazyOutput;
    }

    public static boolean planHasAttribute(QueryPlan<?> plan, Attribute attribute) {
        return plan.outputSet().stream().anyMatch(attr -> attr.equals(attribute));
    }
}
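Because Rerank implements SurrogateLogicalPlan, surrogate() above rewrites the node into an OrderBy over itself, sorting on the score attribute in descending order. In other words, RERANK implicitly leaves rows ordered by the new _score without an explicit SORT in the query, with null scores placed per Order.NullsPosition.ANY.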
@@ -0,0 +1,51 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.plan.physical.inference;

import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;

import java.io.IOException;
import java.util.Objects;

public abstract class InferenceExec extends UnaryExec {
    private final Expression inferenceId;

    protected InferenceExec(Source source, PhysicalPlan child, Expression inferenceId) {
        super(source, child);
        this.inferenceId = inferenceId;
    }

    public Expression inferenceId() {
        return inferenceId;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        Source.EMPTY.writeTo(out);
        out.writeNamedWriteable(child());
        out.writeNamedWriteable(inferenceId());
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        if (super.equals(o) == false) return false;
        InferenceExec that = (InferenceExec) o;
        return inferenceId.equals(that.inferenceId);
    }

    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), inferenceId());
    }
}
@@ -0,0 +1,138 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.plan.physical.inference;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
import static org.elasticsearch.xpack.esql.plan.logical.inference.Rerank.planHasAttribute;

public class RerankExec extends InferenceExec {

    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
        PhysicalPlan.class,
        "RerankExec",
        RerankExec::new
    );

    private final Expression queryText;
    private final List<Alias> rerankFields;
    private final Attribute scoreAttribute;

    public RerankExec(
        Source source,
        PhysicalPlan child,
        Expression inferenceId,
        Expression queryText,
        List<Alias> rerankFields,
        Attribute scoreAttribute
    ) {
        super(source, child, inferenceId);
        this.queryText = queryText;
        this.rerankFields = rerankFields;
        this.scoreAttribute = scoreAttribute;
    }

    public RerankExec(StreamInput in) throws IOException {
        this(
            Source.readFrom((PlanStreamInput) in),
            in.readNamedWriteable(PhysicalPlan.class),
            in.readNamedWriteable(Expression.class),
            in.readNamedWriteable(Expression.class),
            in.readCollectionAsList(Alias::new),
            in.readNamedWriteable(Attribute.class)
        );
    }

    public Expression queryText() {
        return queryText;
    }

    public List<Alias> rerankFields() {
        return rerankFields;
    }

    public Attribute scoreAttribute() {
        return scoreAttribute;
    }

    @Override
    public String getWriteableName() {
        return ENTRY.name;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeNamedWriteable(queryText());
        out.writeCollection(rerankFields());
        out.writeNamedWriteable(scoreAttribute);
    }

    @Override
    protected NodeInfo<? extends PhysicalPlan> info() {
        return NodeInfo.create(this, RerankExec::new, child(), inferenceId(), queryText, rerankFields, scoreAttribute);
    }

    @Override
    public UnaryExec replaceChild(PhysicalPlan newChild) {
        return new RerankExec(source(), newChild, inferenceId(), queryText, rerankFields, scoreAttribute);
    }

    @Override
    public List<Attribute> output() {
        if (planHasAttribute(child(), scoreAttribute)) {
            return child().output();
        }

        return mergeOutputAttributes(List.of(scoreAttribute), child().output());
    }

    @Override
    protected AttributeSet computeReferences() {
        AttributeSet.Builder refs = Rerank.computeReferences(rerankFields).asBuilder();

        if (planHasAttribute(child(), scoreAttribute)) {
            refs.add(scoreAttribute);
        }

        return refs.build();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        if (super.equals(o) == false) return false;
        RerankExec rerank = (RerankExec) o;
        return Objects.equals(queryText, rerank.queryText)
            && Objects.equals(rerankFields, rerank.rerankFields)
            && Objects.equals(scoreAttribute, rerank.scoreAttribute);
    }

    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), queryText, rerankFields, scoreAttribute);
    }
}
@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.esql.planner;

 import org.elasticsearch.cluster.ClusterName;
+import org.elasticsearch.common.lucene.BytesRefs;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.compute.Describable;

@@ -23,6 +24,7 @@ import org.elasticsearch.compute.operator.ColumnExtractOperator;
 import org.elasticsearch.compute.operator.ColumnLoadOperator;
 import org.elasticsearch.compute.operator.Driver;
 import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.EvalOperator;
 import org.elasticsearch.compute.operator.EvalOperator.EvalOperatorFactory;
 import org.elasticsearch.compute.operator.FilterOperator.FilterOperatorFactory;
 import org.elasticsearch.compute.operator.LimitOperator;

@@ -57,6 +59,7 @@ import org.elasticsearch.logging.Logger;
 import org.elasticsearch.node.Node;
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
+import org.elasticsearch.xpack.esql.action.ColumnInfoImpl;
 import org.elasticsearch.xpack.esql.core.expression.Alias;
 import org.elasticsearch.xpack.esql.core.expression.Attribute;
 import org.elasticsearch.xpack.esql.core.expression.Expression;

@@ -78,6 +81,9 @@ import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
 import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
 import org.elasticsearch.xpack.esql.evaluator.command.GrokEvaluatorExtracter;
 import org.elasticsearch.xpack.esql.expression.Order;
+import org.elasticsearch.xpack.esql.inference.InferenceRunner;
+import org.elasticsearch.xpack.esql.inference.RerankOperator;
+import org.elasticsearch.xpack.esql.inference.XContentRowEncoder;
 import org.elasticsearch.xpack.esql.plan.logical.Fork;
 import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
 import org.elasticsearch.xpack.esql.plan.physical.ChangePointExec;

@@ -104,6 +110,7 @@ import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
 import org.elasticsearch.xpack.esql.plan.physical.RrfScoreEvalExec;
 import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
 import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
+import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
 import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders.ShardContext;
 import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
 import org.elasticsearch.xpack.esql.score.ScoreMapper;

@@ -112,6 +119,7 @@ import org.elasticsearch.xpack.esql.session.Configuration;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;

@@ -144,6 +152,7 @@ public class LocalExecutionPlanner {
     private final Supplier<ExchangeSink> exchangeSinkSupplier;
     private final EnrichLookupService enrichLookupService;
     private final LookupFromIndexService lookupFromIndexService;
+    private final InferenceRunner inferenceRunner;
     private final PhysicalOperationProviders physicalOperationProviders;
     private final List<ShardContext> shardContexts;

@@ -159,6 +168,7 @@ public class LocalExecutionPlanner {
         Supplier<ExchangeSink> exchangeSinkSupplier,
         EnrichLookupService enrichLookupService,
         LookupFromIndexService lookupFromIndexService,
+        InferenceRunner inferenceRunner,
         PhysicalOperationProviders physicalOperationProviders,
         List<ShardContext> shardContexts
     ) {

@@ -174,6 +184,7 @@ public class LocalExecutionPlanner {
         this.exchangeSinkSupplier = exchangeSinkSupplier;
         this.enrichLookupService = enrichLookupService;
         this.lookupFromIndexService = lookupFromIndexService;
+        this.inferenceRunner = inferenceRunner;
         this.physicalOperationProviders = physicalOperationProviders;
         this.shardContexts = shardContexts;
     }

@@ -242,6 +253,8 @@ public class LocalExecutionPlanner {
             return planLimit(limit, context);
         } else if (node instanceof MvExpandExec mvExpand) {
             return planMvExpand(mvExpand, context);
+        } else if (node instanceof RerankExec rerank) {
+            return planRerank(rerank, context);
         } else if (node instanceof ChangePointExec changePoint) {
             return planChangePoint(changePoint, context);
         }

@@ -543,6 +556,36 @@ public class LocalExecutionPlanner {
         );
     }

+    private PhysicalOperation planRerank(RerankExec rerank, LocalExecutionPlannerContext context) {
+        PhysicalOperation source = plan(rerank.child(), context);
+
+        Map<ColumnInfoImpl, EvalOperator.ExpressionEvaluator.Factory> rerankFieldsEvaluatorSuppliers = new LinkedHashMap<>();
+
+        for (var rerankField : rerank.rerankFields()) {
+            rerankFieldsEvaluatorSuppliers.put(
+                new ColumnInfoImpl(rerankField.name(), rerankField.dataType(), null),
+                EvalMapper.toEvaluator(context.foldCtx(), rerankField.child(), source.layout)
+            );
+        }
+
+        XContentRowEncoder.Factory rowEncoderFactory = XContentRowEncoder.yamlRowEncoderFactory(rerankFieldsEvaluatorSuppliers);
+
+        String inferenceId = BytesRefs.toString(rerank.inferenceId().fold(context.foldCtx()));
+        String queryText = BytesRefs.toString(rerank.queryText().fold(context.foldCtx()));
+
+        Layout outputLayout = source.layout;
+        if (source.layout.get(rerank.scoreAttribute().id()) == null) {
+            outputLayout = source.layout.builder().append(rerank.scoreAttribute()).build();
+        }
+
+        int scoreChannel = outputLayout.get(rerank.scoreAttribute().id()).channel();
+
+        return source.with(
+            new RerankOperator.Factory(inferenceRunner, inferenceId, queryText, rowEncoderFactory, scoreChannel),
+            outputLayout
+        );
+    }
+
     private PhysicalOperation planHashJoin(HashJoinExec join, LocalExecutionPlannerContext context) {
         PhysicalOperation source = plan(join.left(), context);
         int positionsChannel = source.layout.numberOfChannels();
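One detail of planRerank() worth calling out: the score channel is appended to the layout only when the child plan does not already expose the score attribute; when the child already produces _score, the existing channel is reused and the operator overwrites it in place. This mirrors Rerank.output() and RerankExec.output() above, which likewise only merge a new score attribute into the output when the child lacks one.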
@@ -23,6 +23,7 @@ import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
 import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
 import org.elasticsearch.xpack.esql.plan.logical.TopN;
 import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
 import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
 import org.elasticsearch.xpack.esql.plan.logical.join.Join;
 import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig;

@@ -38,6 +39,7 @@ import org.elasticsearch.xpack.esql.plan.physical.MergeExec;
 import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
 import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
 import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
+import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;

 import java.util.ArrayList;
 import java.util.List;

@@ -173,6 +175,18 @@ public class Mapper {
             return new TopNExec(topN.source(), mappedChild, topN.order(), topN.limit(), null);
         }

+        if (unary instanceof Rerank rerank) {
+            mappedChild = addExchangeForFragment(rerank, mappedChild);
+            return new RerankExec(
+                rerank.source(),
+                mappedChild,
+                rerank.inferenceId(),
+                rerank.queryText(),
+                rerank.rerankFields(),
+                rerank.scoreAttribute()
+            );
+        }
+
         //
         // Pipeline operators
         //
@@ -26,6 +26,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Project;
 import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
 import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
 import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
 import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
 import org.elasticsearch.xpack.esql.plan.logical.show.ShowInfo;
 import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;

@@ -42,6 +43,7 @@ import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
 import org.elasticsearch.xpack.esql.plan.physical.RrfScoreEvalExec;
 import org.elasticsearch.xpack.esql.plan.physical.ShowExec;
 import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec;
+import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
 import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders;

 import java.util.List;

@@ -86,6 +88,17 @@ class MapperUtils {
             return new GrokExec(grok.source(), child, grok.input(), grok.parser(), grok.extractedFields());
         }

+        if (p instanceof Rerank rerank) {
+            return new RerankExec(
+                rerank.source(),
+                child,
+                rerank.inferenceId(),
+                rerank.queryText(),
+                rerank.rerankFields(),
+                rerank.scoreAttribute()
+            );
+        }
+
         if (p instanceof Enrich enrich) {
             return new EnrichExec(
                 enrich.source(),
@@ -49,6 +49,7 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.enrich.EnrichLookupService;
import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
import org.elasticsearch.xpack.esql.plan.physical.OutputExec;

@@ -125,6 +126,7 @@ public class ComputeService {
    private final DriverTaskRunner driverRunner;
    private final EnrichLookupService enrichLookupService;
    private final LookupFromIndexService lookupFromIndexService;
    private final InferenceRunner inferenceRunner;
    private final ClusterService clusterService;
    private final AtomicLong childSessionIdGenerator = new AtomicLong();
    private final DataNodeComputeHandler dataNodeComputeHandler;

@@ -133,25 +135,24 @@ public class ComputeService {

    @SuppressWarnings("this-escape")
    public ComputeService(
        SearchService searchService,
        TransportService transportService,
        ExchangeService exchangeService,
        TransportActionServices transportActionServices,
        EnrichLookupService enrichLookupService,
        LookupFromIndexService lookupFromIndexService,
        ClusterService clusterService,
        ThreadPool threadPool,
        BigArrays bigArrays,
        BlockFactory blockFactory
    ) {
        this.searchService = searchService;
        this.transportService = transportService;
        this.searchService = transportActionServices.searchService();
        this.transportService = transportActionServices.transportService();
        this.exchangeService = transportActionServices.exchangeService();
        this.bigArrays = bigArrays.withCircuitBreaking();
        this.blockFactory = blockFactory;
        var esqlExecutor = threadPool.executor(ThreadPool.Names.SEARCH);
        this.driverRunner = new DriverTaskRunner(transportService, esqlExecutor);
        this.enrichLookupService = enrichLookupService;
        this.lookupFromIndexService = lookupFromIndexService;
        this.clusterService = clusterService;
        this.inferenceRunner = transportActionServices.inferenceRunner();
        this.clusterService = transportActionServices.clusterService();
        this.dataNodeComputeHandler = new DataNodeComputeHandler(this, searchService, transportService, exchangeService, esqlExecutor);
        this.clusterComputeHandler = new ClusterComputeHandler(
            this,

@@ -160,7 +161,6 @@ public class ComputeService {
            esqlExecutor,
            dataNodeComputeHandler
        );
        this.exchangeService = exchangeService;
    }

    public void execute(

@@ -428,6 +428,7 @@ public class ComputeService {
            context.exchangeSinkSupplier(),
            enrichLookupService,
            lookupFromIndexService,
            inferenceRunner,
            new EsPhysicalOperationProviders(context.foldCtx(), contexts, searchService.getIndicesService().getAnalysis()),
            contexts
        );

@@ -13,6 +13,7 @@ import org.elasticsearch.compute.operator.exchange.ExchangeService;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.usage.UsageService;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;

public record TransportActionServices(
    TransportService transportService,

@@ -20,5 +21,6 @@ public record TransportActionServices(
    ExchangeService exchangeService,
    ClusterService clusterService,
    IndexNameExpressionResolver indexNameExpressionResolver,
    UsageService usageService
    UsageService usageService,
    InferenceRunner inferenceRunner
) {}

@@ -49,6 +49,7 @@ import org.elasticsearch.xpack.esql.enrich.EnrichLookupService;
import org.elasticsearch.xpack.esql.enrich.EnrichPolicyResolver;
import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
import org.elasticsearch.xpack.esql.execution.PlanExecutor;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.session.Configuration;
import org.elasticsearch.xpack.esql.session.EsqlSession.PlanRunner;
import org.elasticsearch.xpack.esql.session.Result;

@@ -126,17 +127,7 @@ public class TransportEsqlQueryAction extends HandledTransportAction<EsqlQueryRe
            bigArrays,
            blockFactoryProvider.blockFactory()
        );
        this.computeService = new ComputeService(
            searchService,
            transportService,
            exchangeService,
            enrichLookupService,
            lookupFromIndexService,
            clusterService,
            threadPool,
            bigArrays,
            blockFactoryProvider.blockFactory()
        );

        this.asyncTaskManagementService = new AsyncTaskManagementService<>(
            XPackPlugin.ASYNC_RESULTS_INDEX,
            client,

@@ -159,8 +150,19 @@ public class TransportEsqlQueryAction extends HandledTransportAction<EsqlQueryRe
            exchangeService,
            clusterService,
            indexNameExpressionResolver,
            usageService
            usageService,
            new InferenceRunner(client)
        );

        this.computeService = new ComputeService(
            services,
            enrichLookupService,
            lookupFromIndexService,
            threadPool,
            bigArrays,
            blockFactoryProvider.blockFactory()
        );

        defaultAllowPartialResults = EsqlPlugin.QUERY_ALLOW_PARTIAL_RESULTS.get(clusterService.getSettings());
        clusterService.getClusterSettings()
            .addSettingsUpdateConsumer(EsqlPlugin.QUERY_ALLOW_PARTIAL_RESULTS, v -> defaultAllowPartialResults = v);

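Reading the last three files together: TransportEsqlQueryAction now owns the single `new InferenceRunner(client)` call, parks the runner in the TransportActionServices record, and ComputeService (like EsqlSession further down this diff) retrieves it through the generated record accessor instead of taking it as a constructor argument. A compressed sketch of that ownership chain, using trivial stand-in types rather than the real signatures:

// Stand-ins for the real classes; only the wiring shape is the point.
record InferenceRunner(String client) {}

record TransportActionServices(InferenceRunner inferenceRunner) {}

class ComputeServiceSketch {
    private final InferenceRunner inferenceRunner;

    ComputeServiceSketch(TransportActionServices services) {
        // Mirrors: this.inferenceRunner = transportActionServices.inferenceRunner();
        this.inferenceRunner = services.inferenceRunner();
    }

    InferenceRunner inferenceRunner() {
        return inferenceRunner;
    }
}

class WiringDemo {
    public static void main(String[] args) {
        // The transport action builds the runner once from the node client...
        TransportActionServices services = new TransportActionServices(new InferenceRunner("node-client"));
        // ...and every downstream consumer shares that same instance.
        ComputeServiceSketch computeService = new ComputeServiceSketch(services);
        System.out.println(computeService.inferenceRunner() == services.inferenceRunner()); // true
    }
}
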
@@ -50,6 +50,8 @@ import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.index.IndexResolution;
import org.elasticsearch.xpack.esql.index.MappingException;
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizer;

@@ -63,6 +65,7 @@ import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.RegexExtract;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;

@@ -119,6 +122,7 @@ public class EsqlSession {
    private final PlanTelemetry planTelemetry;
    private final IndicesExpressionGrouper indicesExpressionGrouper;
    private Set<String> configuredClusters;
    private final InferenceRunner inferenceRunner;

    public EsqlSession(
        String sessionId,

@@ -146,6 +150,7 @@ public class EsqlSession {
        this.physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(configuration));
        this.planTelemetry = planTelemetry;
        this.indicesExpressionGrouper = indicesExpressionGrouper;
        this.inferenceRunner = services.inferenceRunner();
        this.preMapper = new PreMapper(services);
    }

@@ -335,7 +340,7 @@ public class EsqlSession {

        Function<PreAnalysisResult, LogicalPlan> analyzeAction = (l) -> {
            Analyzer analyzer = new Analyzer(
                new AnalyzerContext(configuration, functionRegistry, l.indices, l.lookupIndices, l.enrichResolution),
                new AnalyzerContext(configuration, functionRegistry, l.indices, l.lookupIndices, l.enrichResolution, l.inferenceResolution),
                verifier
            );
            LogicalPlan plan = analyzer.analyze(parsed);

@@ -367,7 +372,9 @@ public class EsqlSession {

        var listener = SubscribableListener.<EnrichResolution>newForked(
            l -> enrichPolicyResolver.resolvePolicies(targetClusters, unresolvedPolicies, l)
        ).<PreAnalysisResult>andThen((l, enrichResolution) -> resolveFieldNames(parsed, enrichResolution, l));
        )
            .<PreAnalysisResult>andThen((l, enrichResolution) -> resolveFieldNames(parsed, enrichResolution, l))
            .<PreAnalysisResult>andThen((l, preAnalysisResult) -> resolveInferences(preAnalysis.inferencePlans, preAnalysisResult, l));
        // first resolve the lookup indices, then the main indices
        for (var index : preAnalysis.lookupIndices) {
            listener = listener.andThen((l, preAnalysisResult) -> { preAnalyzeLookupIndex(index, preAnalysisResult, l); });

@@ -580,6 +587,14 @@ public class EsqlSession {
        }
    }

    private void resolveInferences(
        List<InferencePlan> inferencePlans,
        PreAnalysisResult preAnalysisResult,
        ActionListener<PreAnalysisResult> l
    ) {
        inferenceRunner.resolveInferenceIds(inferencePlans, l.map(preAnalysisResult::withInferenceResolution));
    }

    static PreAnalysisResult fieldNames(LogicalPlan parsed, Set<String> enrichPolicyMatchFields, PreAnalysisResult result) {
        if (false == parsed.anyMatch(plan -> plan instanceof Aggregate || plan instanceof Project)) {
            // no explicit columns selection, for example "from employees"

@@ -746,18 +761,44 @@ public class EsqlSession {
        Map<String, IndexResolution> lookupIndices,
        EnrichResolution enrichResolution,
        Set<String> fieldNames,
        Set<String> wildcardJoinIndices
        Set<String> wildcardJoinIndices,
        InferenceResolution inferenceResolution
    ) {
        PreAnalysisResult(EnrichResolution newEnrichResolution) {
            this(null, new HashMap<>(), newEnrichResolution, Set.of(), Set.of());
            this(null, new HashMap<>(), newEnrichResolution, Set.of(), Set.of(), InferenceResolution.EMPTY);
        }

        PreAnalysisResult withEnrichResolution(EnrichResolution newEnrichResolution) {
            return new PreAnalysisResult(indices(), lookupIndices(), newEnrichResolution, fieldNames(), wildcardJoinIndices());
            return new PreAnalysisResult(
                indices(),
                lookupIndices(),
                newEnrichResolution,
                fieldNames(),
                wildcardJoinIndices(),
                inferenceResolution()
            );
        }

        PreAnalysisResult withInferenceResolution(InferenceResolution newInferenceResolution) {
            return new PreAnalysisResult(
                indices(),
                lookupIndices(),
                enrichResolution(),
                fieldNames(),
                wildcardJoinIndices(),
                newInferenceResolution
            );
        }

        PreAnalysisResult withIndexResolution(IndexResolution newIndexResolution) {
            return new PreAnalysisResult(newIndexResolution, lookupIndices(), enrichResolution(), fieldNames(), wildcardJoinIndices());
            return new PreAnalysisResult(
                newIndexResolution,
                lookupIndices(),
                enrichResolution(),
                fieldNames(),
                wildcardJoinIndices(),
                inferenceResolution()
            );
        }

        PreAnalysisResult addLookupIndexResolution(String index, IndexResolution newIndexResolution) {

@@ -766,11 +807,25 @@ public class EsqlSession {
        }

        PreAnalysisResult withFieldNames(Set<String> newFields) {
            return new PreAnalysisResult(indices(), lookupIndices(), enrichResolution(), newFields, wildcardJoinIndices());
            return new PreAnalysisResult(
                indices(),
                lookupIndices(),
                enrichResolution(),
                newFields,
                wildcardJoinIndices(),
                inferenceResolution()
            );
        }

        public PreAnalysisResult withWildcardJoinIndices(Set<String> wildcardJoinIndices) {
            return new PreAnalysisResult(indices(), lookupIndices(), enrichResolution(), fieldNames(), wildcardJoinIndices);
            return new PreAnalysisResult(
                indices(),
                lookupIndices(),
                enrichResolution(),
                fieldNames(),
                wildcardJoinIndices,
                inferenceResolution()
            );
        }
    }
}

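One pattern worth calling out in the PreAnalysisResult hunks: the record is immutable, so every with* helper rebuilds the entire record and copies each untouched component, which is why adding the inferenceResolution component forces edits to all of them. A compact illustration of that copy-on-write style, using a made-up two-component record rather than the real PreAnalysisResult:

import java.util.Set;

// A stripped-down analogue of PreAnalysisResult: immutable, so "updates"
// return a fresh copy with one component swapped out.
record AnalysisState(Set<String> fieldNames, String inferenceResolution) {

    AnalysisState withFieldNames(Set<String> newFields) {
        return new AnalysisState(newFields, inferenceResolution());
    }

    AnalysisState withInferenceResolution(String newResolution) {
        return new AnalysisState(fieldNames(), newResolution);
    }
}

class WitherDemo {
    public static void main(String[] args) {
        AnalysisState initial = new AnalysisState(Set.of(), "EMPTY");
        // Each step produces a new value; nothing is mutated in place.
        AnalysisState resolved = initial.withFieldNames(Set.of("title")).withInferenceResolution("rerank-endpoint");
        System.out.println(resolved);
    }
}
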
@@ -66,6 +66,7 @@ import org.elasticsearch.xpack.esql.enrich.ResolvedEnrichPolicy;
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.index.IndexResolution;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;

@@ -119,6 +120,7 @@ import static org.elasticsearch.xpack.esql.CsvTestUtils.loadPageFromCsv;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.CSV_DATASET_MAP;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.classpathResources;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.everyItem;

@@ -261,6 +263,10 @@ public class CsvTests extends ESTestCase {
            "enrich can't load fields in csv tests",
            testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.ENRICH_LOAD.capabilityName())
        );
        assumeFalse(
            "can't use rerank in csv tests",
            testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.RERANK.capabilityName())
        );
        assumeFalse(
            "can't use match in csv tests",
            testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.MATCH_OPERATOR_COLON.capabilityName())

@@ -478,7 +484,10 @@ public class CsvTests extends ESTestCase {
    private LogicalPlan analyzedPlan(LogicalPlan parsed, CsvTestsDataLoader.MultiIndexTestDataset datasets) {
        var indexResolution = loadIndexResolution(datasets);
        var enrichPolicies = loadEnrichPolicies();
        var analyzer = new Analyzer(new AnalyzerContext(configuration, functionRegistry, indexResolution, enrichPolicies), TEST_VERIFIER);
        var analyzer = new Analyzer(
            new AnalyzerContext(configuration, functionRegistry, indexResolution, enrichPolicies, emptyInferenceResolution()),
            TEST_VERIFIER
        );
        LogicalPlan plan = analyzer.analyze(parsed);
        plan.setAnalyzed();
        LOGGER.debug("Analyzed plan:\n{}", plan);

@@ -666,6 +675,7 @@ public class CsvTests extends ESTestCase {
            () -> exchangeSink.createExchangeSink(() -> {}),
            Mockito.mock(EnrichLookupService.class),
            Mockito.mock(LookupFromIndexService.class),
            Mockito.mock(InferenceRunner.class),
            physicalOperationProviders,
            List.of()
        );

@@ -8,12 +8,15 @@
package org.elasticsearch.xpack.esql.analysis;

import org.elasticsearch.index.IndexMode;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.enrich.EnrichPolicy;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.enrich.ResolvedEnrichPolicy;
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.index.IndexResolution;
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
import org.elasticsearch.xpack.esql.inference.ResolvedInference;
import org.elasticsearch.xpack.esql.parser.EsqlParser;
import org.elasticsearch.xpack.esql.parser.QueryParams;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;

@@ -29,6 +32,7 @@ import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.MATCH_TYPE;
import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.RANGE_TYPE;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;

public final class AnalyzerTestUtils {

@@ -57,7 +61,8 @@ public final class AnalyzerTestUtils {
                new EsqlFunctionRegistry(),
                indexResolution,
                defaultLookupResolution(),
                defaultEnrichResolution()
                defaultEnrichResolution(),
                emptyInferenceResolution()
            ),
            verifier
        );

@@ -70,7 +75,8 @@ public final class AnalyzerTestUtils {
                new EsqlFunctionRegistry(),
                indexResolution,
                lookupResolution,
                defaultEnrichResolution()
                defaultEnrichResolution(),
                defaultInferenceResolution()
            ),
            verifier
        );

@@ -78,7 +84,14 @@ public final class AnalyzerTestUtils {

    public static Analyzer analyzer(IndexResolution indexResolution, Verifier verifier, Configuration config) {
        return new Analyzer(
            new AnalyzerContext(config, new EsqlFunctionRegistry(), indexResolution, defaultLookupResolution(), defaultEnrichResolution()),
            new AnalyzerContext(
                config,
                new EsqlFunctionRegistry(),
                indexResolution,
                defaultLookupResolution(),
                defaultEnrichResolution(),
                defaultInferenceResolution()
            ),
            verifier
        );
    }

@@ -90,7 +103,8 @@ public final class AnalyzerTestUtils {
                new EsqlFunctionRegistry(),
                analyzerDefaultMapping(),
                defaultLookupResolution(),
                defaultEnrichResolution()
                defaultEnrichResolution(),
                defaultInferenceResolution()
            ),
            verifier
        );

@@ -162,6 +176,14 @@ public final class AnalyzerTestUtils {
        return enrichResolution;
    }

    public static InferenceResolution defaultInferenceResolution() {
        return InferenceResolution.builder()
            .withResolvedInference(new ResolvedInference("reranking-inference-id", TaskType.RERANK))
            .withResolvedInference(new ResolvedInference("completion-inference-id", TaskType.COMPLETION))
            .withError("error-inference-id", "error with inference resolution")
            .build();
    }

    public static void loadEnrichPolicyResolution(
        EnrichResolution enrich,
        String policyType,

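defaultInferenceResolution() above doubles as a small tour of the InferenceResolution API this commit introduces: endpoints register as ResolvedInference entries and failures as named error messages. A sketch of how a consumer reads such a resolution back, using only the methods that appear elsewhere in this diff (builder(), withResolvedInference(), withError(), resolvedInferences(), hasError(), getError()); it assumes the ES|QL test classpath:

import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
import org.elasticsearch.xpack.esql.inference.ResolvedInference;

class InferenceResolutionDemo {
    public static void main(String[] args) {
        InferenceResolution resolution = InferenceResolution.builder()
            .withResolvedInference(new ResolvedInference("reranking-inference-id", TaskType.RERANK))
            .withError("error-inference-id", "error with inference resolution")
            .build();

        // Successful endpoints are listed with the task type the verifier checks...
        for (ResolvedInference inference : resolution.resolvedInferences()) {
            System.out.println(inference.inferenceId() + " -> " + inference.taskType());
        }

        // ...while failures surface as per-endpoint messages.
        if (resolution.hasError()) {
            System.out.println(resolution.getError("error-inference-id"));
        }
    }
}
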
@@ -45,6 +45,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchOperator;
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.esql.index.EsIndex;

@@ -67,6 +68,7 @@ import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.Row;
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
import org.elasticsearch.xpack.esql.session.IndexResolver;

@@ -86,6 +88,7 @@ import static org.elasticsearch.test.MapMatcher.assertMap;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsConstant;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsIdentifier;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsPattern;

@@ -108,6 +111,7 @@ import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.matchesRegex;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.startsWith;

//@TestLogging(value = "org.elasticsearch.xpack.esql.analysis:TRACE", reason = "debug")

@@ -1626,7 +1630,13 @@ public class AnalyzerTests extends ESTestCase {
        enrichResolution.addError("languages", Enrich.Mode.ANY, "error-2");
        enrichResolution.addError("foo", Enrich.Mode.ANY, "foo-error-101");

        AnalyzerContext context = new AnalyzerContext(configuration("from test"), new EsqlFunctionRegistry(), testIndex, enrichResolution);
        AnalyzerContext context = new AnalyzerContext(
            configuration("from test"),
            new EsqlFunctionRegistry(),
            testIndex,
            enrichResolution,
            emptyInferenceResolution()
        );
        Analyzer analyzer = new Analyzer(context, TEST_VERIFIER);
        {
            LogicalPlan plan = analyze("from test | EVAL x = to_string(languages) | ENRICH _coordinator:languages ON x", analyzer);

@@ -1776,7 +1786,13 @@ public class AnalyzerTests extends ESTestCase {
                languageIndex.get().mapping()
            )
        );
        AnalyzerContext context = new AnalyzerContext(configuration(query), new EsqlFunctionRegistry(), testIndex, enrichResolution);
        AnalyzerContext context = new AnalyzerContext(
            configuration(query),
            new EsqlFunctionRegistry(),
            testIndex,
            enrichResolution,
            emptyInferenceResolution()
        );
        Analyzer analyzer = new Analyzer(context, TEST_VERIFIER);
        LogicalPlan plan = analyze(query, analyzer);
        var limit = as(plan, Limit.class);

@@ -2172,7 +2188,8 @@ public class AnalyzerTests extends ESTestCase {
                new EsqlFunctionRegistry(),
                analyzerDefaultMapping(),
                Map.of("foobar", missingLookupIndex),
                defaultEnrichResolution()
                defaultEnrichResolution(),
                emptyInferenceResolution()
            ),
            TEST_VERIFIER
        );

@@ -3267,6 +3284,193 @@ public class AnalyzerTests extends ESTestCase {
        assertThat(esRelation.output(), equalTo(NO_FIELDS));
    }

    public void testResolveRerankInferenceId() {
        assumeTrue("Requires RERANK command", EsqlCapabilities.Cap.RERANK.isEnabled());

        {
            LogicalPlan plan = analyze(
                " FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `reranking-inference-id`",
                "mapping-books.json"
            );
            Rerank rerank = as(as(plan, Limit.class).child(), Rerank.class);
            assertThat(rerank.inferenceId(), equalTo(string("reranking-inference-id")));
        }

        {
            VerificationException ve = expectThrows(
                VerificationException.class,
                () -> analyze(
                    "FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `completion-inference-id`",
                    "mapping-books.json"
                )
            );
            assertThat(
                ve.getMessage(),
                containsString(
                    "cannot use inference endpoint [completion-inference-id] with task type [completion] within a Rerank command. "
                        + "Only inference endpoints with the task type [rerank] are supported"
                )
            );
        }

        {
            VerificationException ve = expectThrows(
                VerificationException.class,
                () -> analyze(
                    "FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `error-inference-id`",
                    "mapping-books.json"
                )
            );
            assertThat(ve.getMessage(), containsString("error with inference resolution"));
        }

        {
            VerificationException ve = expectThrows(
                VerificationException.class,
                () -> analyze(
                    "FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `unknown-inference-id`",
                    "mapping-books.json"
                )
            );
            assertThat(ve.getMessage(), containsString("unresolved inference [unknown-inference-id]"));
        }
    }

    public void testResolveRerankFields() {
        assumeTrue("Requires RERANK command", EsqlCapabilities.Cap.RERANK.isEnabled());

        {
            // Single field.
            LogicalPlan plan = analyze("""
                FROM books METADATA _score
                | WHERE title:"italian food recipe" OR description:"italian food recipe"
                | KEEP description, title, year, _score
                | DROP description
                | RERANK "italian food recipe" ON title WITH `reranking-inference-id`
                """, "mapping-books.json");

            Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
            Rerank rerank = as(limit.child(), Rerank.class);
            EsqlProject keep = as(rerank.child(), EsqlProject.class);
            EsqlProject drop = as(keep.child(), EsqlProject.class);
            Filter filter = as(drop.child(), Filter.class);
            EsRelation relation = as(filter.child(), EsRelation.class);

            Attribute titleAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("title")).findFirst().get();
            assertThat(titleAttribute, notNullValue());

            assertThat(rerank.queryText(), equalTo(string("italian food recipe")));
            assertThat(rerank.inferenceId(), equalTo(string("reranking-inference-id")));
            assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", titleAttribute))));
            assertThat(
                rerank.scoreAttribute(),
                equalTo(relation.output().stream().filter(attr -> attr.name().equals(MetadataAttribute.SCORE)).findFirst().get())
            );
        }

        {
            // Multiple fields.
            LogicalPlan plan = analyze("""
                FROM books METADATA _score
                | WHERE title:"food"
                | RERANK "food" ON title, description=SUBSTRING(description, 0, 100), yearRenamed=year WITH `reranking-inference-id`
                """, "mapping-books.json");

            Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
            Rerank rerank = as(limit.child(), Rerank.class);
            Filter filter = as(rerank.child(), Filter.class);
            EsRelation relation = as(filter.child(), EsRelation.class);

            assertThat(rerank.queryText(), equalTo(string("food")));
            assertThat(rerank.inferenceId(), equalTo(string("reranking-inference-id")));

            assertThat(rerank.rerankFields(), hasSize(3));
            Attribute titleAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("title")).findFirst().get();
            assertThat(titleAttribute, notNullValue());
            assertThat(rerank.rerankFields().get(0), equalTo(alias("title", titleAttribute)));

            Attribute descriptionAttribute = relation.output()
                .stream()
                .filter(attribute -> attribute.name().equals("description"))
                .findFirst()
                .get();
            assertThat(descriptionAttribute, notNullValue());
            Alias descriptionAlias = rerank.rerankFields().get(1);
            assertThat(descriptionAlias.name(), equalTo("description"));
            assertThat(
                as(descriptionAlias.child(), Substring.class).children(),
                equalTo(List.of(descriptionAttribute, literal(0), literal(100)))
            );

            Attribute yearAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("year")).findFirst().get();
            assertThat(yearAttribute, notNullValue());
            assertThat(rerank.rerankFields().get(2), equalTo(alias("yearRenamed", yearAttribute)));
            assertThat(
                rerank.scoreAttribute(),
                equalTo(relation.output().stream().filter(attr -> attr.name().equals(MetadataAttribute.SCORE)).findFirst().get())
            );
        }

        {
            VerificationException ve = expectThrows(
                VerificationException.class,
                () -> analyze(
                    "FROM books METADATA _score | RERANK \"italian food recipe\" ON missingField WITH `reranking-inference-id`",
                    "mapping-books.json"
                )
            );
            assertThat(ve.getMessage(), containsString("Unknown column [missingField]"));
        }
    }

    public void testResolveRerankScoreField() {
        assumeTrue("Requires RERANK command", EsqlCapabilities.Cap.RERANK.isEnabled());

        {
            // When the metadata field is required in FROM, it is reused.
            LogicalPlan plan = analyze("""
                FROM books METADATA _score
                | WHERE title:"italian food recipe" OR description:"italian food recipe"
                | RERANK "italian food recipe" ON title WITH `reranking-inference-id`
                """, "mapping-books.json");

            Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
            Rerank rerank = as(limit.child(), Rerank.class);
            Filter filter = as(rerank.child(), Filter.class);
            EsRelation relation = as(filter.child(), EsRelation.class);

            Attribute metadataScoreAttribute = relation.output()
                .stream()
                .filter(attr -> attr.name().equals(MetadataAttribute.SCORE))
                .findFirst()
                .get();
            assertThat(rerank.scoreAttribute(), equalTo(metadataScoreAttribute));
            assertThat(rerank.output(), hasItem(metadataScoreAttribute));
        }

        {
            // When the metadata field is not required in FROM, it is added to the output of RERANK
            LogicalPlan plan = analyze("""
                FROM books
                | WHERE title:"italian food recipe" OR description:"italian food recipe"
                | RERANK "italian food recipe" ON title WITH `reranking-inference-id`
                """, "mapping-books.json");

            Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
            Rerank rerank = as(limit.child(), Rerank.class);
            Filter filter = as(rerank.child(), Filter.class);
            EsRelation relation = as(filter.child(), EsRelation.class);

            assertThat(relation.output().stream().noneMatch(attr -> attr.name().equals(MetadataAttribute.SCORE)), is(true));
            assertThat(rerank.scoreAttribute(), equalTo(MetadataAttribute.create(EMPTY, MetadataAttribute.SCORE)));
            assertThat(rerank.output(), hasItem(rerank.scoreAttribute()));
        }
    }

    @Override
    protected IndexAnalyzers createDefaultIndexAnalyzers() {
        return super.createDefaultIndexAnalyzers();

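Taken together, the new analyzer tests double as syntax documentation for the command this commit gates behind snapshot builds. A condensed cheat sheet of the accepted RERANK shapes, with the queries lifted from the tests above (the backtick-quoted inference IDs are test fixtures, not real endpoints):

// RERANK query shapes accepted by the analyzer in this commit, per the tests above.
class RerankQueryExamples {
    // Single field, reusing the _score metadata attribute requested in FROM.
    static final String SINGLE_FIELD = """
        FROM books METADATA _score
        | RERANK "italian food recipe" ON title WITH `reranking-inference-id`
        """;

    // Multiple fields, including computed and renamed ones.
    static final String MULTIPLE_FIELDS = """
        FROM books METADATA _score
        | RERANK "food" ON title, description=SUBSTRING(description, 0, 100), yearRenamed=year WITH `reranking-inference-id`
        """;

    // Without METADATA _score, the analyzer appends a fresh _score attribute to RERANK's output.
    static final String IMPLICIT_SCORE = """
        FROM books
        | RERANK "italian food recipe" ON title WITH `reranking-inference-id`
        """;
}
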
@@ -35,6 +35,7 @@ import java.util.List;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;

@@ -45,7 +46,7 @@ public class ParsingTests extends ESTestCase {

    private final IndexResolution defaultIndex = loadIndexResolution("mapping-basic.json");
    private final Analyzer defaultAnalyzer = new Analyzer(
        new AnalyzerContext(TEST_CFG, new EsqlFunctionRegistry(), defaultIndex, emptyPolicyResolution()),
        new AnalyzerContext(TEST_CFG, new EsqlFunctionRegistry(), defaultIndex, emptyPolicyResolution(), emptyInferenceResolution()),
        TEST_VERIFIER
    );

@@ -34,6 +34,7 @@ import org.elasticsearch.xpack.esql.telemetry.Metrics;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzerDefaultMapping;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultEnrichResolution;
import static org.hamcrest.Matchers.containsString;

@@ -90,7 +91,13 @@ public class CheckLicenseTests extends ESTestCase {

    private static Analyzer analyzer(EsqlFunctionRegistry registry, License.OperationMode operationMode) {
        return new Analyzer(
            new AnalyzerContext(EsqlTestUtils.TEST_CFG, registry, analyzerDefaultMapping(), defaultEnrichResolution()),
            new AnalyzerContext(
                EsqlTestUtils.TEST_CFG,
                registry,
                analyzerDefaultMapping(),
                defaultEnrichResolution(),
                emptyInferenceResolution()
            ),
            new Verifier(new Metrics(new EsqlFunctionRegistry()), getLicenseState(operationMode))
        );
    }

@@ -0,0 +1,146 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.inference;

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;

import java.util.List;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class InferenceRunnerTests extends ESTestCase {
    public void testResolveInferenceIds() throws Exception {
        InferenceRunner inferenceRunner = new InferenceRunner(mockClient());
        List<InferencePlan> inferencePlans = List.of(mockInferencePlan("rerank-plan"));
        SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();

        inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
            throw new RuntimeException(e);
        }));

        assertBusy(() -> {
            InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get();
            assertNotNull(inferenceResolution);
            assertThat(inferenceResolution.resolvedInferences(), contains(new ResolvedInference("rerank-plan", TaskType.RERANK)));
            assertThat(inferenceResolution.hasError(), equalTo(false));
        });
    }

    public void testResolveMultipleInferenceIds() throws Exception {
        InferenceRunner inferenceRunner = new InferenceRunner(mockClient());
        List<InferencePlan> inferencePlans = List.of(
            mockInferencePlan("rerank-plan"),
            mockInferencePlan("rerank-plan"),
            mockInferencePlan("completion-plan")
        );
        SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();

        inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
            throw new RuntimeException(e);
        }));

        assertBusy(() -> {
            InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get();
            assertNotNull(inferenceResolution);

            assertThat(
                inferenceResolution.resolvedInferences(),
                contains(
                    new ResolvedInference("rerank-plan", TaskType.RERANK),
                    new ResolvedInference("completion-plan", TaskType.COMPLETION)
                )
            );
            assertThat(inferenceResolution.hasError(), equalTo(false));
        });
    }

    public void testResolveMissingInferenceIds() throws Exception {
        InferenceRunner inferenceRunner = new InferenceRunner(mockClient());
        List<InferencePlan> inferencePlans = List.of(mockInferencePlan("missing-plan"));

        SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();

        inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
            throw new RuntimeException(e);
        }));

        assertBusy(() -> {
            InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get();
            assertNotNull(inferenceResolution);

            assertThat(inferenceResolution.resolvedInferences(), empty());
            assertThat(inferenceResolution.hasError(), equalTo(true));
            assertThat(inferenceResolution.getError("missing-plan"), equalTo("inference endpoint not found"));
        });
    }

    @SuppressWarnings({ "unchecked", "raw-types" })
    private static Client mockClient() {
        Client client = mock(Client.class);
        doAnswer(i -> {
            GetInferenceModelAction.Request request = i.getArgument(1, GetInferenceModelAction.Request.class);
            ActionListener<ActionResponse> listener = (ActionListener<ActionResponse>) i.getArgument(2, ActionListener.class);
            ActionResponse response = getInferenceModelResponse(request);

            if (response == null) {
                listener.onFailure(new ResourceNotFoundException("inference endpoint not found"));
            } else {
                listener.onResponse(response);
            }

            return null;
        }).when(client).execute(eq(GetInferenceModelAction.INSTANCE), any(), any());
        return client;
    }

    private static ActionResponse getInferenceModelResponse(GetInferenceModelAction.Request request) {
        GetInferenceModelAction.Response response = mock(GetInferenceModelAction.Response.class);

        if (request.getInferenceEntityId().equals("rerank-plan")) {
            when(response.getEndpoints()).thenReturn(List.of(mockModelConfig("rerank-plan", TaskType.RERANK)));
            return response;
        }

        if (request.getInferenceEntityId().equals("completion-plan")) {
            when(response.getEndpoints()).thenReturn(List.of(mockModelConfig("completion-plan", TaskType.COMPLETION)));
            return response;
        }

        return null;
    }

    private static ModelConfigurations mockModelConfig(String inferenceId, TaskType taskType) {
        return new ModelConfigurations(inferenceId, taskType, randomIdentifier(), mock(ServiceSettings.class));
    }

    private static InferencePlan mockInferencePlan(String inferenceId) {
        InferencePlan plan = mock(InferencePlan.class);
        when(plan.inferenceId()).thenReturn(new Literal(Source.EMPTY, inferenceId, DataType.KEYWORD));
        return plan;
    }
}

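The mockClient() helper is the load-bearing piece of these tests: Client.execute is void and reports through its listener, so the stub must answer by invoking the captured ActionListener rather than returning a value. A minimal standalone illustration of that doAnswer pattern (plain Mockito, with a made-up listener-based service in place of the ES client):

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

import java.util.function.Consumer;

public class ListenerStubDemo {
    // A made-up stand-in for a listener-style async API such as Client.execute.
    interface AsyncService {
        void fetch(String id, Consumer<String> listener);
    }

    public static void main(String[] args) {
        AsyncService service = mock(AsyncService.class);
        // The stub completes the callback inline, like mockClient() does
        // with ActionListener.onResponse / onFailure.
        doAnswer(invocation -> {
            Consumer<String> listener = invocation.getArgument(1);
            listener.accept("result-for-" + invocation.getArgument(0));
            return null; // void method
        }).when(service).fetch(any(), any());

        service.fetch("rerank-plan", result -> System.out.println(result));
    }
}
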
@@ -0,0 +1,297 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.inference;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.FloatBlock;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AsyncOperator;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.SourceOperator;
import org.elasticsearch.compute.test.AbstractBlockSourceOperator;
import org.elasticsearch.compute.test.OperatorTestCase;
import org.elasticsearch.compute.test.RandomBlock;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.hamcrest.Matcher;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.notNullValue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class RerankOperatorTests extends OperatorTestCase {
    private static final String ESQL_TEST_EXECUTOR = "esql_test_executor";
    private static final String SIMPLE_INFERENCE_ID = "test_reranker";
    private static final String SIMPLE_QUERY = "query text";
    private ThreadPool threadPool;
    private List<ElementType> inputChannelElementTypes;
    private XContentRowEncoder.Factory rowEncoderFactory;
    private int scoreChannel;

    @Before
    private void initChannels() {
        int channelCount = randomIntBetween(2, 10);
        scoreChannel = randomIntBetween(0, channelCount - 1);
        inputChannelElementTypes = IntStream.range(0, channelCount).sorted().mapToObj(this::randomElementType).collect(Collectors.toList());
        rowEncoderFactory = mockRowEncoderFactory();
    }

    @Before
    public void setThreadPool() {
        int numThreads = randomBoolean() ? 1 : between(2, 16);
        threadPool = new TestThreadPool(
            "test",
            new FixedExecutorBuilder(Settings.EMPTY, ESQL_TEST_EXECUTOR, numThreads, 1024, "esql", EsExecutors.TaskTrackingConfig.DEFAULT)
        );
    }

    @After
    public void shutdownThreadPool() {
        terminate(threadPool);
    }

    @Override
    protected Operator.OperatorFactory simple() {
        InferenceRunner inferenceRunner = mockedSimpleInferenceRunner();
        return new RerankOperator.Factory(inferenceRunner, SIMPLE_INFERENCE_ID, SIMPLE_QUERY, rowEncoderFactory, scoreChannel);
    }

    private InferenceRunner mockedSimpleInferenceRunner() {
        InferenceRunner inferenceRunner = mock(InferenceRunner.class);
        when(inferenceRunner.getThreadContext()).thenReturn(threadPool.getThreadContext());
        doAnswer(invocation -> {
            @SuppressWarnings("unchecked")
            ActionListener<InferenceAction.Response> listener = (ActionListener<InferenceAction.Response>) invocation.getArgument(
                1,
                ActionListener.class
            );
            InferenceAction.Response inferenceResponse = mock(InferenceAction.Response.class);
            when(inferenceResponse.getResults()).thenReturn(
                mockedRankedDocResults(invocation.getArgument(0, InferenceAction.Request.class))
            );
            listener.onResponse(inferenceResponse);
            return null;
        }).when(inferenceRunner).doInference(any(), any());

        return inferenceRunner;
    }

    private RankedDocsResults mockedRankedDocResults(InferenceAction.Request request) {
        List<RankedDocsResults.RankedDoc> rankedDocs = new ArrayList<>();
        for (int rank = 0; rank < request.getInput().size(); rank++) {
            if (rank % 10 != 0) {
                rankedDocs.add(new RankedDocsResults.RankedDoc(rank, 1f / rank, request.getInput().get(rank)));
            }
        }
        return new RankedDocsResults(rankedDocs);
    }

    @Override
    protected Matcher<String> expectedDescriptionOfSimple() {
        return expectedToStringOfSimple();
    }

    @Override
    protected Matcher<String> expectedToStringOfSimple() {
        return equalTo(
            "RerankOperator[inference_id=[" + SIMPLE_INFERENCE_ID + "], query=[" + SIMPLE_QUERY + "], score_channel=[" + scoreChannel + "]]"
        );
    }

    @Override
    protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
        return new AbstractBlockSourceOperator(blockFactory, 8 * 1024) {
            @Override
            protected int remaining() {
                return size - currentPosition;
            }

            @Override
            protected Page createPage(int positionOffset, int length) {
                Block[] blocks = new Block[inputChannelElementTypes.size()];
                try {
                    currentPosition += length;
                    for (int b = 0; b < inputChannelElementTypes.size(); b++) {
                        blocks[b] = RandomBlock.randomBlock(
                            blockFactory,
                            inputChannelElementTypes.get(b),
                            length,
                            randomBoolean(),
                            0,
                            10,
                            0,
                            10
                        ).block();
                    }
                    return new Page(blocks);
                } catch (Exception e) {
                    Releasables.closeExpectNoException(blocks);
                    throw e;
                }
            }
        };
    }

    /**
     * Ensures that the Operator.Status of this operator has the standard fields.
     */
    public void testOperatorStatus() throws IOException {
        DriverContext driverContext = driverContext();
        try (var operator = simple().get(driverContext)) {
            AsyncOperator.Status status = asInstanceOf(AsyncOperator.Status.class, operator.status());

            assertThat(status, notNullValue());
            assertThat(status.receivedPages(), equalTo(0L));
            assertThat(status.completedPages(), equalTo(0L));
            assertThat(status.procesNanos(), greaterThanOrEqualTo(0L));
        }
    }

    @Override
    protected void assertSimpleOutput(List<Page> inputPages, List<Page> resultPages) {
        assertThat(inputPages, hasSize(resultPages.size()));

        for (int pageId = 0; pageId < inputPages.size(); pageId++) {
            Page inputPage = inputPages.get(pageId);
            Page resultPage = resultPages.get(pageId);

            // Check all rows are present and the output shape is unchanged.
            assertThat(inputPage.getPositionCount(), equalTo(resultPage.getPositionCount()));
            assertThat(inputPage.getBlockCount(), equalTo(resultPage.getBlockCount()));

            BytesRef readBuffer = new BytesRef();

            for (int channel = 0; channel < inputPage.getBlockCount(); channel++) {
                Block inputBlock = inputPage.getBlock(channel);
                Block resultBlock = resultPage.getBlock(channel);

                assertThat(resultBlock.getPositionCount(), equalTo(resultPage.getPositionCount()));
                assertThat(resultBlock.elementType(), equalTo(inputBlock.elementType()));

                if (channel == scoreChannel) {
                    assertExpectedScore(asInstanceOf(DoubleBlock.class, resultBlock));
                } else {
                    switch (inputBlock.elementType()) {
                        case BOOLEAN -> assertBlockContentEquals(inputBlock, resultBlock, BooleanBlock::getBoolean, BooleanBlock.class);
                        case INT -> assertBlockContentEquals(inputBlock, resultBlock, IntBlock::getInt, IntBlock.class);
                        case LONG -> assertBlockContentEquals(inputBlock, resultBlock, LongBlock::getLong, LongBlock.class);
                        case FLOAT -> assertBlockContentEquals(inputBlock, resultBlock, FloatBlock::getFloat, FloatBlock.class);
                        case DOUBLE -> assertBlockContentEquals(inputBlock, resultBlock, DoubleBlock::getDouble, DoubleBlock.class);
                        case BYTES_REF -> assertByteRefsBlockContentEquals(inputBlock, resultBlock, readBuffer);
                        default -> throw new AssertionError(
                            LoggerMessageFormat.format("Unexpected block type {}", inputBlock.elementType())
                        );
                    }
                }
            }
        }
    }

    private int inputChannelCount() {
        return inputChannelElementTypes.size();
    }

    private ElementType randomElementType(int channel) {
        return channel == scoreChannel ? ElementType.DOUBLE : randomFrom(ElementType.FLOAT, ElementType.DOUBLE, ElementType.LONG);
    }

    private XContentRowEncoder.Factory mockRowEncoderFactory() {
        XContentRowEncoder.Factory factory = mock(XContentRowEncoder.Factory.class);
        doAnswer(factoryInvocation -> {
            DriverContext driverContext = factoryInvocation.getArgument(0, DriverContext.class);
            XContentRowEncoder rowEncoder = mock(XContentRowEncoder.class);
            doAnswer(encoderInvocation -> {
                Page inputPage = encoderInvocation.getArgument(0, Page.class);
                return driverContext.blockFactory()
                    .newConstantBytesRefBlockWith(new BytesRef(randomRealisticUnicodeOfCodepointLength(4)), inputPage.getPositionCount());
            }).when(rowEncoder).eval(any(Page.class));

            return rowEncoder;
        }).when(factory).get(any(DriverContext.class));

        return factory;
    }

    private void assertExpectedScore(DoubleBlock scoreBlockResult) {
        assertAllPositions(scoreBlockResult, (pos) -> {
            if (pos % 10 == 0) {
                assertThat(scoreBlockResult.isNull(pos), equalTo(true));
            } else {
                assertThat(scoreBlockResult.getValueCount(pos), equalTo(1));
                assertThat(scoreBlockResult.getDouble(scoreBlockResult.getFirstValueIndex(pos)), equalTo((double) (1f / pos)));
            }
        });
    }

    <V extends Block, U> void assertBlockContentEquals(
        Block input,
        Block result,
        BiFunction<V, Integer, U> valueReader,
        Class<V> blockClass
    ) {
        V inputBlock = asInstanceOf(blockClass, input);
        V resultBlock = asInstanceOf(blockClass, result);

        assertAllPositions(inputBlock, (pos) -> {
            if (inputBlock.isNull(pos)) {
                assertThat(resultBlock.isNull(pos), equalTo(inputBlock.isNull(pos)));
            } else {
                assertThat(resultBlock.getValueCount(pos), equalTo(inputBlock.getValueCount(pos)));
                assertThat(resultBlock.getFirstValueIndex(pos), equalTo(inputBlock.getFirstValueIndex(pos)));
                for (int i = 0; i < inputBlock.getValueCount(pos); i++) {
                    assertThat(
                        valueReader.apply(resultBlock, resultBlock.getFirstValueIndex(pos) + i),
                        equalTo(valueReader.apply(inputBlock, inputBlock.getFirstValueIndex(pos) + i))
                    );
                }
            }
        });
    }

    private void assertAllPositions(Block block, Consumer<Integer> consumer) {
        for (int pos = 0; pos < block.getPositionCount(); pos++) {
            consumer.accept(pos);
        }
    }

    private <V extends Block, U> void assertByteRefsBlockContentEquals(Block input, Block result, BytesRef readBuffer) {
        assertBlockContentEquals(input, result, (BytesRefBlock b, Integer pos) -> b.getBytesRef(pos, readBuffer), BytesRefBlock.class);
    }
}

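A detail that makes these assertions line up: mockedRankedDocResults() skips every rank where rank % 10 == 0, which also sidesteps the 1f / rank division at rank 0, and assertExpectedScore() expects exactly those positions to come back null while all others score 1/rank. A standalone rendering of that shared rule:

class RerankScoreRuleDemo {
    public static void main(String[] args) {
        // Same rule as the mock reranker and the score assertion in the test above.
        for (int rank = 0; rank < 25; rank++) {
            if (rank % 10 == 0) {
                System.out.printf("position %2d -> null (dropped by the mock reranker)%n", rank);
            } else {
                System.out.printf("position %2d -> score %.4f (1/rank)%n", rank, 1f / rank);
            }
        }
    }
}
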
@@ -0,0 +1,41 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.inference;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.AbstractWireTestCase;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;

public class ResolvedInferenceTests extends AbstractWireTestCase<ResolvedInference> {

    @Override
    protected ResolvedInference createTestInstance() {
        return new ResolvedInference(randomIdentifier(), randomTaskType());
    }

    @Override
    protected ResolvedInference mutateInstance(ResolvedInference instance) throws IOException {
        if (randomBoolean()) {
            return new ResolvedInference(randomValueOtherThan(instance.inferenceId(), ESTestCase::randomIdentifier), instance.taskType());
        }

        return new ResolvedInference(instance.inferenceId(), randomValueOtherThan(instance.taskType(), this::randomTaskType));
    }

    @Override
    protected ResolvedInference copyInstance(ResolvedInference instance, TransportVersion version) throws IOException {
        return copyInstance(instance, getNamedWriteableRegistry(), (out, v) -> v.writeTo(out), in -> new ResolvedInference(in), version);
    }

    private TaskType randomTaskType() {
        return randomFrom(TaskType.values());
    }
}

@ -66,6 +66,8 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.THREE;
|
|||
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TWO;
|
||||
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
|
||||
import static org.elasticsearch.xpack.esql.EsqlTestUtils.asLimit;
|
||||
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
|
||||
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
|
||||
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute;
|
||||
import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf;
|
||||
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
|
||||
|
@ -98,7 +100,13 @@ public class LocalLogicalPlanOptimizerTests extends ESTestCase {
|
|||
logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext());
|
||||
|
||||
analyzer = new Analyzer(
|
||||
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, EsqlTestUtils.emptyPolicyResolution()),
|
||||
new AnalyzerContext(
|
||||
EsqlTestUtils.TEST_CFG,
|
||||
new EsqlFunctionRegistry(),
|
||||
getIndexResult,
|
||||
emptyPolicyResolution(),
|
||||
emptyInferenceResolution()
|
||||
),
|
||||
TEST_VERIFIER
|
||||
);
|
||||
}
|
||||
|
@ -449,7 +457,13 @@ public class LocalLogicalPlanOptimizerTests extends ESTestCase {
|
|||
var logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext());
|
||||
|
||||
var analyzer = new Analyzer(
|
||||
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, EsqlTestUtils.emptyPolicyResolution()),
|
||||
new AnalyzerContext(
|
||||
EsqlTestUtils.TEST_CFG,
|
||||
new EsqlFunctionRegistry(),
|
||||
getIndexResult,
|
||||
emptyPolicyResolution(),
|
||||
emptyInferenceResolution()
|
||||
),
|
||||
TEST_VERIFIER
|
||||
);
|
||||
|
||||
|
|
|
@ -101,6 +101,7 @@ import static org.elasticsearch.index.query.QueryBuilders.termQuery;
|
|||
import static org.elasticsearch.index.query.QueryBuilders.termsQuery;
|
||||
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
|
||||
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
|
||||
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
|
||||
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
|
||||
import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
|
||||
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
|
||||
|
@ -184,7 +185,7 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
|
|||
IndexResolution getIndexResult = IndexResolution.valid(test);
|
||||
|
||||
return new Analyzer(
|
||||
new AnalyzerContext(config, new EsqlFunctionRegistry(), getIndexResult, enrichResolution),
|
||||
new AnalyzerContext(config, new EsqlFunctionRegistry(), getIndexResult, enrichResolution, emptyInferenceResolution()),
|
||||
new Verifier(new Metrics(new EsqlFunctionRegistry()), new XPackLicenseState(() -> 0L))
|
||||
);
|
||||
}
|
||||
|
|
|
@ -155,6 +155,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.THREE;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TWO;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.asLimit;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptySource;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.fieldAttribute;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute;

@ -248,7 +249,8 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
    new EsqlFunctionRegistry(),
    getIndexResult,
    defaultLookupResolution(),
-    enrichResolution
+    enrichResolution,
+    emptyInferenceResolution()
),
TEST_VERIFIER
);

@ -258,7 +260,13 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
EsIndex airports = new EsIndex("airports", mappingAirports, Map.of("airports", IndexMode.STANDARD));
IndexResolution getIndexResultAirports = IndexResolution.valid(airports);
analyzerAirports = new Analyzer(
-    new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultAirports, enrichResolution),
+    new AnalyzerContext(
+        EsqlTestUtils.TEST_CFG,
+        new EsqlFunctionRegistry(),
+        getIndexResultAirports,
+        enrichResolution,
+        emptyInferenceResolution()
+    ),
    TEST_VERIFIER
);

@ -267,7 +275,13 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
EsIndex types = new EsIndex("types", mappingTypes, Map.of("types", IndexMode.STANDARD));
IndexResolution getIndexResultTypes = IndexResolution.valid(types);
analyzerTypes = new Analyzer(
-    new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultTypes, enrichResolution),
+    new AnalyzerContext(
+        EsqlTestUtils.TEST_CFG,
+        new EsqlFunctionRegistry(),
+        getIndexResultTypes,
+        enrichResolution,
+        emptyInferenceResolution()
+    ),
    TEST_VERIFIER
);

@ -276,14 +290,26 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
EsIndex extra = new EsIndex("extra", mappingExtra, Map.of("extra", IndexMode.STANDARD));
IndexResolution getIndexResultExtra = IndexResolution.valid(extra);
analyzerExtra = new Analyzer(
-    new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultExtra, enrichResolution),
+    new AnalyzerContext(
+        EsqlTestUtils.TEST_CFG,
+        new EsqlFunctionRegistry(),
+        getIndexResultExtra,
+        enrichResolution,
+        emptyInferenceResolution()
+    ),
    TEST_VERIFIER
);

metricMapping = loadMapping("k8s-mappings.json");
var metricsIndex = IndexResolution.valid(new EsIndex("k8s", metricMapping, Map.of("k8s", IndexMode.TIME_SERIES)));
metricsAnalyzer = new Analyzer(
-    new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), metricsIndex, enrichResolution),
+    new AnalyzerContext(
+        EsqlTestUtils.TEST_CFG,
+        new EsqlFunctionRegistry(),
+        metricsIndex,
+        enrichResolution,
+        emptyInferenceResolution()
+    ),
    TEST_VERIFIER
);

@ -298,7 +324,13 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
    )
);
multiIndexAnalyzer = new Analyzer(
-    new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), multiIndex, enrichResolution),
+    new AnalyzerContext(
+        EsqlTestUtils.TEST_CFG,
+        new EsqlFunctionRegistry(),
+        multiIndex,
+        enrichResolution,
+        emptyInferenceResolution()
+    ),
    TEST_VERIFIER
);
}

@ -5268,7 +5300,13 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
EsIndex empty = new EsIndex("empty_test", emptyMap(), Map.of());
IndexResolution getIndexResultAirports = IndexResolution.valid(empty);
var analyzer = new Analyzer(
-    new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultAirports, enrichResolution),
+    new AnalyzerContext(
+        EsqlTestUtils.TEST_CFG,
+        new EsqlFunctionRegistry(),
+        getIndexResultAirports,
+        enrichResolution,
+        emptyInferenceResolution()
+    ),
    TEST_VERIFIER
);

@ -164,6 +164,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_SEARCH_STATS;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.statsForMissingField;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;

@ -356,7 +357,7 @@ public class PhysicalPlanOptimizerTests extends ESTestCase {
EsIndex index = new EsIndex(indexName, mapping, Map.of("test", IndexMode.STANDARD));
IndexResolution getIndexResult = IndexResolution.valid(index);
Analyzer analyzer = new Analyzer(
-    new AnalyzerContext(config, functionRegistry, getIndexResult, lookupResolution, enrichResolution),
+    new AnalyzerContext(config, functionRegistry, getIndexResult, lookupResolution, enrichResolution, emptyInferenceResolution()),
    TEST_VERIFIER
);
return new TestDataSource(mapping, index, analyzer, stats);

@ -7673,6 +7674,7 @@ public class PhysicalPlanOptimizerTests extends ESTestCase {
    () -> exchangeSinkHandler.createExchangeSink(() -> {}),
    null,
    null,
+    null,
    new EsPhysicalOperationProviders(FoldContext.small(), List.of(), null),
    List.of()
);

@ -40,6 +40,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
+import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultInferenceResolution;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultLookupResolution;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.hasSize;

@ -63,7 +64,8 @@ public class PropagateInlineEvalsTests extends ESTestCase {
    new EsqlFunctionRegistry(),
    getIndexResult,
    defaultLookupResolution(),
-    new EnrichResolution()
+    new EnrichResolution(),
+    defaultInferenceResolution()
),
TEST_VERIFIER
);

@ -30,6 +30,10 @@ public class GrammarInDevelopmentParsingTests extends ESTestCase {
    parse("row a = 1 | match foo", "match");
}

+public void testDevelopmentRerank() {
+    parse("row a = 1 | rerank \"foo\" ON title WITH reranker", "rerank");
+}
+
void parse(String query, String errorMessage) {
    ParsingException pe = expectThrows(ParsingException.class, () -> parser().createStatement(query));
    assertThat(pe.getMessage(), containsString("mismatched input '" + errorMessage + "'"));

@ -64,6 +64,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Row;
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;

@ -3332,6 +3333,87 @@ public class StatementParserTests extends AbstractStatementParserTests {
    expectError("explain [row x = 1", "line 1:19: missing ']' at '<EOF>'");
}

+public void testRerankSingleField() {
+    assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
+
+    var plan = processingCommand("RERANK \"query text\" ON title WITH inferenceID");
+    var rerank = as(plan, Rerank.class);
+
+    assertThat(rerank.queryText(), equalTo(literalString("query text")));
+    assertThat(rerank.inferenceId(), equalTo(literalString("inferenceID")));
+    assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
+}
+
+public void testRerankMultipleFields() {
+    assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
+
+    var plan = processingCommand("RERANK \"query text\" ON title, description, authors_renamed=authors WITH inferenceID");
+    var rerank = as(plan, Rerank.class);
+
+    assertThat(rerank.queryText(), equalTo(literalString("query text")));
+    assertThat(rerank.inferenceId(), equalTo(literalString("inferenceID")));
+    assertThat(
+        rerank.rerankFields(),
+        equalTo(
+            List.of(
+                alias("title", attribute("title")),
+                alias("description", attribute("description")),
+                alias("authors_renamed", attribute("authors"))
+            )
+        )
+    );
+}
+
+public void testRerankComputedFields() {
+    assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
+
+    var plan = processingCommand("RERANK \"query text\" ON title, short_description = SUBSTRING(description, 0, 100) WITH inferenceID");
+    var rerank = as(plan, Rerank.class);
+
+    assertThat(rerank.queryText(), equalTo(literalString("query text")));
+    assertThat(rerank.inferenceId(), equalTo(literalString("inferenceID")));
+    assertThat(
+        rerank.rerankFields(),
+        equalTo(
+            List.of(
+                alias("title", attribute("title")),
+                alias("short_description", function("SUBSTRING", List.of(attribute("description"), integer(0), integer(100))))
+            )
+        )
+    );
+}
+
+public void testRerankWithPositionalParameters() {
+    assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
+
+    var queryParams = new QueryParams(List.of(paramAsConstant(null, "query text"), paramAsConstant(null, "reranker")));
+    var rerank = as(parser.createStatement("row a = 1 | RERANK ? ON title WITH ?", queryParams), Rerank.class);
+
+    assertThat(rerank.queryText(), equalTo(literalString("query text")));
+    assertThat(rerank.inferenceId(), equalTo(literalString("reranker")));
+    assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
+}
+
+public void testRerankWithNamedParameters() {
+    assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
+
+    var queryParams = new QueryParams(List.of(paramAsConstant("queryText", "query text"), paramAsConstant("inferenceId", "reranker")));
+    var rerank = as(parser.createStatement("row a = 1 | RERANK ?queryText ON title WITH ?inferenceId", queryParams), Rerank.class);
+
+    assertThat(rerank.queryText(), equalTo(literalString("query text")));
+    assertThat(rerank.inferenceId(), equalTo(literalString("reranker")));
+    assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
+}
+
+public void testInvalidRerank() {
+    assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
+    expectError("FROM foo* | RERANK ON title WITH inferenceId", "line 1:20: mismatched input 'ON' expecting {QUOTED_STRING");
+
+    expectError("FROM foo* | RERANK \"query text\" WITH inferenceId", "line 1:33: mismatched input 'WITH' expecting 'on'");
+
+    expectError("FROM foo* | RERANK \"query text\" ON title", "line 1:41: mismatched input '<EOF>' expecting {'and',");
+}
+
+static Alias alias(String name, Expression value) {
+    return new Alias(EMPTY, name, value);
+}

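Taken together, the tests above pin down the command's surface syntax: RERANK "query text" ON field [, alias = expression, ...] WITH inferenceId, with ? and ?name parameter markers accepted for both the query text and the inference endpoint id. A hedged end-to-end illustration, held in a Java text block; the index and field names are invented, and RERANK parses only on snapshot builds where EsqlCapabilities.Cap.RERANK is enabled:

// Illustrative only: an ES|QL query exercising the grammar accepted above.
String esql = """
    FROM books
    | RERANK "science fiction" ON title, short_description = SUBSTRING(description, 0, 100) WITH inferenceId
    | LIMIT 10
    """;
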
@ -0,0 +1,66 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.plan.logical.inference;

import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.AliasTests;
import org.elasticsearch.xpack.esql.plan.logical.AbstractLogicalPlanSerializationTests;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;

public class RerankSerializationTests extends AbstractLogicalPlanSerializationTests<Rerank> {
    @Override
    protected Rerank createTestInstance() {
        Source source = randomSource();
        LogicalPlan child = randomChild(0);
        return new Rerank(source, child, string(randomIdentifier()), string(randomIdentifier()), randomFields(), scoreAttribute());
    }

    @Override
    protected Rerank mutateInstance(Rerank instance) throws IOException {
        LogicalPlan child = instance.child();
        Expression inferenceId = instance.inferenceId();
        Expression queryText = instance.queryText();
        List<Alias> fields = instance.rerankFields();

        switch (between(0, 3)) {
            case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
            case 1 -> inferenceId = randomValueOtherThan(inferenceId, () -> string(RerankSerializationTests.randomIdentifier()));
            case 2 -> queryText = randomValueOtherThan(queryText, () -> string(RerankSerializationTests.randomIdentifier()));
            case 3 -> fields = randomValueOtherThan(fields, this::randomFields);
        }
        return new Rerank(instance.source(), child, inferenceId, queryText, fields, instance.scoreAttribute());
    }

    @Override
    protected boolean alwaysEmptySource() {
        return true;
    }

    private List<Alias> randomFields() {
        return randomList(0, 10, AliasTests::randomAlias);
    }

    private Literal string(String value) {
        return new Literal(EMPTY, value, DataType.KEYWORD);
    }

    private Attribute scoreAttribute() {
        return new MetadataAttribute(EMPTY, MetadataAttribute.SCORE, DataType.DOUBLE, randomBoolean());
    }
}

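Both serialization suites rely on the mutate-one-field idiom: mutateInstance changes exactly one constructor argument at random, so equality is exercised per field after a wire round trip. A self-contained sketch of the idiom with invented types:

import java.util.concurrent.ThreadLocalRandom;

// Invented stand-in for Rerank: two fields, mutate exactly one at random.
record Point(int x, int y) {}

class MutateDemo {
    static int randomIntOtherThan(int current) {
        int candidate;
        do {
            candidate = ThreadLocalRandom.current().nextInt(100);
        } while (candidate == current);
        return candidate;
    }

    // Mirrors mutateInstance above: the result differs from the input in
    // exactly one field, so equals()/hashCode() coverage is per-field.
    static Point mutate(Point instance) {
        int x = instance.x();
        int y = instance.y();
        switch (ThreadLocalRandom.current().nextInt(2)) {
            case 0 -> x = randomIntOtherThan(x);
            case 1 -> y = randomIntOtherThan(y);
        }
        return new Point(x, y);
    }
}
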
@ -0,0 +1,66 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.plan.physical.inference;

import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.AliasTests;
import org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;

public class RerankExecSerializationTests extends AbstractPhysicalPlanSerializationTests<RerankExec> {
    @Override
    protected RerankExec createTestInstance() {
        Source source = randomSource();
        PhysicalPlan child = randomChild(0);
        return new RerankExec(source, child, string(randomIdentifier()), string(randomIdentifier()), randomFields(), scoreAttribute());
    }

    @Override
    protected RerankExec mutateInstance(RerankExec instance) throws IOException {
        PhysicalPlan child = instance.child();
        Expression inferenceId = instance.inferenceId();
        Expression queryText = instance.queryText();
        List<Alias> fields = instance.rerankFields();

        switch (between(0, 3)) {
            case 0 -> child = randomValueOtherThan(child, () -> randomChild(0));
            case 1 -> inferenceId = randomValueOtherThan(inferenceId, () -> string(RerankExecSerializationTests.randomIdentifier()));
            case 2 -> queryText = randomValueOtherThan(queryText, () -> string(RerankExecSerializationTests.randomIdentifier()));
            case 3 -> fields = randomValueOtherThan(fields, this::randomFields);
        }
        return new RerankExec(instance.source(), child, inferenceId, queryText, fields, scoreAttribute());
    }

    @Override
    protected boolean alwaysEmptySource() {
        return true;
    }

    private List<Alias> randomFields() {
        return randomList(0, 10, AliasTests::randomAlias);
    }

    static Literal string(String value) {
        return new Literal(EMPTY, value, DataType.KEYWORD);
    }

    private Attribute scoreAttribute() {
        return new MetadataAttribute(EMPTY, MetadataAttribute.SCORE, DataType.DOUBLE, randomBoolean());
    }
}

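RerankExec mirrors the logical Rerank node field for field, which suggests a one-to-one translation when the plan goes physical. A hypothetical sketch of that mapping; this method is not part of the commit, and the argument order follows the constructors exercised by the two suites above:

// Hypothetical (not from this commit): the logical-to-physical translation the
// matching field lists imply. Argument order follows mutateInstance above.
static RerankExec mapRerank(Rerank rerank, PhysicalPlan mappedChild) {
    return new RerankExec(
        rerank.source(),
        mappedChild,
        rerank.inferenceId(),
        rerank.queryText(),
        rerank.rerankFields(),
        rerank.scoreAttribute()
    );
}
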
@ -50,6 +50,7 @@ import static java.util.Arrays.asList;
import static org.elasticsearch.index.query.QueryBuilders.rangeQuery;
import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;

@ -84,7 +85,13 @@ public class FilterTests extends ESTestCase {
mapper = new Mapper();

analyzer = new Analyzer(
-    new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, EsqlTestUtils.emptyPolicyResolution()),
+    new AnalyzerContext(
+        EsqlTestUtils.TEST_CFG,
+        new EsqlFunctionRegistry(),
+        getIndexResult,
+        EsqlTestUtils.emptyPolicyResolution(),
+        emptyInferenceResolution()
+    ),
    TEST_VERIFIER
);
}

@ -232,6 +232,7 @@ public class LocalExecutionPlannerTests extends MapperServiceTestCase {
    null,
    null,
    null,
+    null,
    esPhysicalOperationProviders(shardContexts),
    shardContexts
);

@ -13,7 +13,6 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.analysis.Analyzer;
import org.elasticsearch.xpack.esql.analysis.AnalyzerContext;
-import org.elasticsearch.xpack.esql.analysis.EnrichResolution;
import org.elasticsearch.xpack.esql.analysis.Verifier;
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.EsIndex;

@ -28,6 +27,8 @@ import org.junit.BeforeClass;
import java.util.List;
import java.util.Map;

+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
import static org.hamcrest.Matchers.containsString;

@ -46,7 +47,13 @@ public class QueryTranslatorTests extends ESTestCase {
IndexResolution getIndexResult = IndexResolution.valid(test);

return new Analyzer(
-    new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, new EnrichResolution()),
+    new AnalyzerContext(
+        EsqlTestUtils.TEST_CFG,
+        new EsqlFunctionRegistry(),
+        getIndexResult,
+        emptyPolicyResolution(),
+        emptyInferenceResolution()
+    ),
    new Verifier(new Metrics(new EsqlFunctionRegistry()), new XPackLicenseState(() -> 0L))
);
}

@ -39,6 +39,7 @@ import java.util.Map;
import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration;
import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;

@ -191,7 +192,13 @@ public class ClusterRequestTests extends AbstractWireSerializingTestCase<Cluster
IndexResolution getIndexResult = IndexResolution.valid(test);
var logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext());
var analyzer = new Analyzer(
-    new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()),
+    new AnalyzerContext(
+        EsqlTestUtils.TEST_CFG,
+        new EsqlFunctionRegistry(),
+        getIndexResult,
+        emptyPolicyResolution(),
+        emptyInferenceResolution()
+    ),
    TEST_VERIFIER
);
return logicalOptimizer.optimize(analyzer.analyze(new EsqlParser().createStatement(query)));

@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfigur
import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;

@ -292,7 +293,13 @@ public class DataNodeRequestSerializationTests extends AbstractWireSerializingTe
IndexResolution getIndexResult = IndexResolution.valid(test);
var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(TEST_CFG, FoldContext.small()));
var analyzer = new Analyzer(
-    new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()),
+    new AnalyzerContext(
+        EsqlTestUtils.TEST_CFG,
+        new EsqlFunctionRegistry(),
+        getIndexResult,
+        emptyPolicyResolution(),
+        emptyInferenceResolution()
+    ),
    TEST_VERIFIER
);
return logicalOptimizer.optimize(analyzer.analyze(new EsqlParser().createStatement(query)));

@ -86,7 +86,7 @@ public abstract class AbstractTestInferenceService implements InferenceService {
var secretSettings = TestSecretSettings.fromMap(secretSettingsMap);

var taskSettingsMap = getTaskSettingsMap(config);
-var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
+var taskSettings = getTasksSettingsFromMap(taskSettingsMap);

return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings);
}

@ -99,11 +99,15 @@ public abstract class AbstractTestInferenceService implements InferenceService {
var serviceSettings = getServiceSettingsFromMap(serviceSettingsMap);

var taskSettingsMap = getTaskSettingsMap(config);
-var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
+var taskSettings = getTasksSettingsFromMap(taskSettingsMap);

return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, null);
}

+protected TaskSettings getTasksSettingsFromMap(Map<String, Object> taskSettingsMap) {
+    return TestTaskSettings.fromMap(taskSettingsMap);
+}
+
protected abstract ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap);

@Override

@ -149,15 +153,15 @@ public abstract class AbstractTestInferenceService implements InferenceService {
    TaskType taskType,
    String service,
    ServiceSettings serviceSettings,
-    TestTaskSettings taskSettings,
+    TaskSettings taskSettings,
    TestSecretSettings secretSettings
) {
    super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings));
}

@Override
-public TestTaskSettings getTaskSettings() {
-    return (TestTaskSettings) super.getTaskSettings();
+public TaskSettings getTaskSettings() {
+    return super.getTaskSettings();
}

@Override

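The hunks above replace a hard-coded TestTaskSettings.fromMap(...) call with a protected getTasksSettingsFromMap hook, so subclasses such as the reranking extension can substitute their own task-settings type; a template-method refactor. A standalone sketch of the pattern with invented types:

import java.util.Map;

// Invented types; only the shape of the hook mirrors the refactor above.
interface Settings {}

record BasicSettings() implements Settings {}

record RerankSettings(boolean useTextLength) implements Settings {}

abstract class BaseService {
    // Shared parsing flow stays in the base class...
    Settings parse(Map<String, Object> map) {
        return settingsFromMap(map);
    }

    // ...while the concrete settings type is a pluggable factory method.
    protected Settings settingsFromMap(Map<String, Object> map) {
        return new BasicSettings();
    }
}

class RerankService extends BaseService {
    @Override
    protected Settings settingsFromMap(Map<String, Object> map) {
        return new RerankSettings(Boolean.parseBoolean(String.valueOf(map.getOrDefault("use_text_length", "false"))));
    }
}
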
@ -45,6 +45,11 @@ public class TestInferenceServicePlugin extends Plugin {
|
|||
TestRerankingServiceExtension.TestServiceSettings.NAME,
|
||||
TestRerankingServiceExtension.TestServiceSettings::new
|
||||
),
|
||||
new NamedWriteableRegistry.Entry(
|
||||
TaskSettings.class,
|
||||
TestRerankingServiceExtension.TestTaskSettings.NAME,
|
||||
TestRerankingServiceExtension.TestTaskSettings::new
|
||||
),
|
||||
new NamedWriteableRegistry.Entry(
|
||||
ServiceSettings.class,
|
||||
TestStreamingCompletionServiceExtension.TestServiceSettings.NAME,
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SettingsConfiguration;
+import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;

@ -43,6 +44,8 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;

+import static org.elasticsearch.xpack.inference.mock.AbstractTestInferenceService.random;
+
public class TestRerankingServiceExtension implements InferenceServiceExtension {

@Override

@ -84,11 +87,15 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap);

var taskSettingsMap = getTaskSettingsMap(config);
-var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
+var taskSettings = TestRerankingServiceExtension.TestTaskSettings.fromMap(taskSettingsMap);

parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings));
}

+protected TaskSettings getTasksSettingsFromMap(Map<String, Object> taskSettingsMap) {
+    return TestRerankingServiceExtension.TestTaskSettings.fromMap(taskSettingsMap);
+}
+
@Override
public InferenceServiceConfiguration getConfiguration() {
    return Configuration.get();

@ -107,13 +114,15 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
    @Nullable Integer topN,
    List<String> input,
    boolean stream,
-    Map<String, Object> taskSettings,
+    Map<String, Object> taskSettingsMap,
    InputType inputType,
    TimeValue timeout,
    ActionListener<InferenceServiceResults> listener
) {
+    TaskSettings taskSettings = model.getTaskSettings().updatedTaskSettings(taskSettingsMap);
+
    switch (model.getConfigurations().getTaskType()) {
-        case ANY, RERANK -> listener.onResponse(makeResults(input));
+        case ANY, RERANK -> listener.onResponse(makeResults(input, (TestRerankingServiceExtension.TestTaskSettings) taskSettings));
        default -> listener.onFailure(
            new ElasticsearchStatusException(
                TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),

@ -151,7 +160,7 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
    );
}

-private RankedDocsResults makeResults(List<String> input) {
+private RankedDocsResults makeResults(List<String> input, TestRerankingServiceExtension.TestTaskSettings taskSettings) {
    int totalResults = input.size();
    try {
        List<RankedDocsResults.RankedDoc> results = new ArrayList<>();

@ -161,17 +170,19 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
        return new RankedDocsResults(results.stream().sorted(Comparator.reverseOrder()).toList());
    } catch (NumberFormatException ex) {
        List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
-        float minScore = random.nextFloat(-1f, 1f);
-        float resultDiff = 0.2f;

+        float minScore = taskSettings.minScore();
+        float resultDiff = taskSettings.resultDiff();
        for (int i = 0; i < input.size(); i++) {
-            results.add(
-                new RankedDocsResults.RankedDoc(
-                    totalResults - 1 - i,
-                    minScore + resultDiff * (totalResults - i),
-                    input.get(totalResults - 1 - i)
-                )
-            );
+            float relevanceScore = minScore + resultDiff * (totalResults - i);
+            String inputText = input.get(totalResults - 1 - i);
+            if (taskSettings.useTextLength()) {
+                relevanceScore = 1f / inputText.length();
+            }
+            results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, relevanceScore, inputText));
        }
+        // Ensure results are sorted by descending score
+        results.sort((a, b) -> -Float.compare(a.relevanceScore(), b.relevanceScore()));
        return new RankedDocsResults(results);
    }
}

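With task settings threaded into makeResults, the mock reranker's fallback path (taken when inputs are not numeric, hence the NumberFormatException catch) becomes deterministic on demand: use_text_length scores each input as 1 / length, otherwise scores ramp linearly from min_score in result_diff steps. A detached sketch of just that scoring rule, away from the Elasticsearch types:

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

class MockRerankScoring {
    // Mirrors the fallback scoring above; returns {originalIndex, score} pairs
    // sorted by descending score, like RankedDocsResults expects.
    static List<float[]> score(List<String> input, boolean useTextLength, float minScore, float resultDiff) {
        int total = input.size();
        List<float[]> results = new ArrayList<>();
        for (int i = 0; i < total; i++) {
            int index = total - 1 - i;                          // walk inputs from the end
            float score = minScore + resultDiff * (total - i);  // linear ramp by default
            if (useTextLength) {
                score = 1f / input.get(index).length();         // deterministic alternative
            }
            results.add(new float[] { index, score });
        }
        results.sort(Comparator.comparing((float[] r) -> r[1]).reversed());
        return results;
    }
}
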
@ -208,6 +219,77 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
    }
}

+public record TestTaskSettings(boolean useTextLength, float minScore, float resultDiff) implements TaskSettings {
+
+    static final String NAME = "test_reranking_task_settings";
+
+    public static TestTaskSettings fromMap(Map<String, Object> map) {
+        boolean useTextLength = false;
+        float minScore = random.nextFloat(-1f, 1f);
+        float resultDiff = 0.2f;
+
+        if (map.containsKey("use_text_length")) {
+            useTextLength = Boolean.parseBoolean(map.remove("use_text_length").toString());
+        }
+
+        if (map.containsKey("min_score")) {
+            minScore = Float.parseFloat(map.remove("min_score").toString());
+        }
+
+        if (map.containsKey("result_diff")) {
+            resultDiff = Float.parseFloat(map.remove("result_diff").toString());
+        }
+
+        return new TestTaskSettings(useTextLength, minScore, resultDiff);
+    }
+
+    public TestTaskSettings(StreamInput in) throws IOException {
+        this(in.readBoolean(), in.readFloat(), in.readFloat());
+    }
+
+    @Override
+    public boolean isEmpty() {
+        return false;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeBoolean(useTextLength);
+        out.writeFloat(minScore);
+        out.writeFloat(resultDiff);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field("use_text_length", useTextLength);
+        builder.field("min_score", minScore);
+        builder.field("result_diff", resultDiff);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
+    }
+
+    @Override
+    public TaskSettings updatedTaskSettings(Map<String, Object> newSettingsMap) {
+        TestTaskSettings newSettingsObject = fromMap(Map.copyOf(newSettingsMap));
+        return new TestTaskSettings(
+            newSettingsMap.containsKey("use_text_length") ? newSettingsObject.useTextLength() : useTextLength,
+            newSettingsMap.containsKey("min_score") ? newSettingsObject.minScore() : minScore,
+            newSettingsMap.containsKey("result_diff") ? newSettingsObject.resultDiff() : resultDiff
+        );
+    }
+}
+
public record TestServiceSettings(String modelId) implements ServiceSettings {

    static final String NAME = "test_reranking_service_settings";

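TestTaskSettings keeps writeTo and its StreamInput constructor in field-for-field agreement (boolean, float, float), which is the whole wire contract. A standalone round-trip sketch of that contract using plain java.io streams in place of Elasticsearch's StreamInput/StreamOutput:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

// Standalone illustration: write order and read order must agree exactly.
record WireSettings(boolean useTextLength, float minScore, float resultDiff) {
    void writeTo(DataOutputStream out) throws IOException {
        out.writeBoolean(useTextLength);
        out.writeFloat(minScore);
        out.writeFloat(resultDiff);
    }

    static WireSettings readFrom(DataInputStream in) throws IOException {
        return new WireSettings(in.readBoolean(), in.readFloat(), in.readFloat());
    }
}

class WireDemo {
    public static void main(String[] args) throws IOException {
        var original = new WireSettings(true, 0.1f, 0.2f);
        var bytes = new ByteArrayOutputStream();
        original.writeTo(new DataOutputStream(bytes));
        var copy = WireSettings.readFrom(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        assert original.equals(copy); // round trip preserves every field
    }
}
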