Add ModelRegistryMetadata to Cluster State (#121106)

This commit integrates `MinimalServiceSettings` (introduced in #120560) into the cluster state for all registered models in the `ModelRegistry`.
These settings allow consumers to access configuration details without requiring asynchronous calls to retrieve full model configurations.

To ensure consistency, the cluster state metadata must remain synchronized with the models in the inference index.
If a mismatch is detected during startup, the master node performs an upgrade to load all model settings from the index.
This commit is contained in:
Jim Ferenczi 2025-03-18 10:12:51 +00:00 committed by GitHub
parent d20528b27c
commit 270ec538c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
41 changed files with 1473 additions and 614 deletions

View File

@ -0,0 +1,5 @@
pr: 121106
summary: Add `ModelRegistryMetadata` to Cluster State
area: Machine Learning
type: enhancement
issues: []

View File

@ -187,6 +187,7 @@ public class TransportVersions {
public static final TransportVersion ML_INFERENCE_DEEPSEEK = def(9_029_00_0);
public static final TransportVersion ESQL_FAILURE_FROM_REMOTE = def(9_030_00_0);
public static final TransportVersion INDEX_RESHARDING_METADATA = def(9_031_0_00);
public static final TransportVersion INFERENCE_MODEL_REGISTRY_METADATA = def(9_032_0_00);
/*
* STOP! READ THIS FIRST! No, really,

View File

@ -9,6 +9,12 @@
package org.elasticsearch.inference;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.cluster.Diff;
import org.elasticsearch.cluster.SimpleDiffable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.xcontent.ConstructingObjectParser;
@ -46,12 +52,16 @@ import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
* @param elementType the type of elements in the embeddings, applicable only for {@link TaskType#TEXT_EMBEDDING} (nullable).
*/
public record MinimalServiceSettings(
@Nullable String service,
TaskType taskType,
@Nullable Integer dimensions,
@Nullable SimilarityMeasure similarity,
@Nullable ElementType elementType
) implements ToXContentObject {
) implements ServiceSettings, SimpleDiffable<MinimalServiceSettings> {
public static final String NAME = "minimal_service_settings";
public static final String SERVICE_FIELD = "service";
public static final String TASK_TYPE_FIELD = "task_type";
static final String DIMENSIONS_FIELD = "dimensions";
static final String SIMILARITY_FIELD = "similarity";
@ -61,17 +71,20 @@ public record MinimalServiceSettings(
"model_settings",
true,
args -> {
TaskType taskType = TaskType.fromString((String) args[0]);
Integer dimensions = (Integer) args[1];
SimilarityMeasure similarity = args[2] == null ? null : SimilarityMeasure.fromString((String) args[2]);
DenseVectorFieldMapper.ElementType elementType = args[3] == null
String service = (String) args[0];
TaskType taskType = TaskType.fromString((String) args[1]);
Integer dimensions = (Integer) args[2];
SimilarityMeasure similarity = args[3] == null ? null : SimilarityMeasure.fromString((String) args[3]);
DenseVectorFieldMapper.ElementType elementType = args[4] == null
? null
: DenseVectorFieldMapper.ElementType.fromString((String) args[3]);
return new MinimalServiceSettings(taskType, dimensions, similarity, elementType);
: DenseVectorFieldMapper.ElementType.fromString((String) args[4]);
return new MinimalServiceSettings(service, taskType, dimensions, similarity, elementType);
}
);
private static final String UNKNOWN_SERVICE = "_unknown_";
static {
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(SERVICE_FIELD));
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(TASK_TYPE_FIELD));
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(DIMENSIONS_FIELD));
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(SIMILARITY_FIELD));
@ -82,28 +95,39 @@ public record MinimalServiceSettings(
return PARSER.parse(parser, null);
}
public static MinimalServiceSettings textEmbedding(int dimensions, SimilarityMeasure similarity, ElementType elementType) {
return new MinimalServiceSettings(TEXT_EMBEDDING, dimensions, similarity, elementType);
public static MinimalServiceSettings textEmbedding(
String serviceName,
int dimensions,
SimilarityMeasure similarity,
ElementType elementType
) {
return new MinimalServiceSettings(serviceName, TEXT_EMBEDDING, dimensions, similarity, elementType);
}
public static MinimalServiceSettings sparseEmbedding() {
return new MinimalServiceSettings(SPARSE_EMBEDDING, null, null, null);
public static MinimalServiceSettings sparseEmbedding(String serviceName) {
return new MinimalServiceSettings(serviceName, SPARSE_EMBEDDING, null, null, null);
}
public static MinimalServiceSettings rerank() {
return new MinimalServiceSettings(RERANK, null, null, null);
public static MinimalServiceSettings rerank(String serviceName) {
return new MinimalServiceSettings(serviceName, RERANK, null, null, null);
}
public static MinimalServiceSettings completion() {
return new MinimalServiceSettings(COMPLETION, null, null, null);
public static MinimalServiceSettings completion(String serviceName) {
return new MinimalServiceSettings(serviceName, COMPLETION, null, null, null);
}
public static MinimalServiceSettings chatCompletion() {
return new MinimalServiceSettings(CHAT_COMPLETION, null, null, null);
public static MinimalServiceSettings chatCompletion(String serviceName) {
return new MinimalServiceSettings(serviceName, CHAT_COMPLETION, null, null, null);
}
public MinimalServiceSettings {
Objects.requireNonNull(taskType, "task type must not be null");
validate(taskType, dimensions, similarity, elementType);
}
public MinimalServiceSettings(Model model) {
this(
model.getConfigurations().getService(),
model.getTaskType(),
model.getServiceSettings().dimensions(),
model.getServiceSettings().similarity(),
@ -111,22 +135,55 @@ public record MinimalServiceSettings(
);
}
public MinimalServiceSettings(
TaskType taskType,
@Nullable Integer dimensions,
@Nullable SimilarityMeasure similarity,
@Nullable ElementType elementType
) {
this.taskType = Objects.requireNonNull(taskType, "task type must not be null");
this.dimensions = dimensions;
this.similarity = similarity;
this.elementType = elementType;
validate();
// Deserialization constructor: the read order must stay in lock-step with writeTo,
// otherwise wire compatibility breaks.
public MinimalServiceSettings(StreamInput in) throws IOException {
this(
in.readOptionalString(),
TaskType.fromStream(in),
in.readOptionalInt(),
in.readOptionalEnum(SimilarityMeasure.class),
in.readOptionalEnum(ElementType.class)
);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
// Write order mirrors the StreamInput constructor; do not reorder.
out.writeOptionalString(service);
taskType.writeTo(out);
out.writeOptionalInt(dimensions);
out.writeOptionalEnum(similarity);
out.writeOptionalEnum(elementType);
}
@Override
public String getWriteableName() {
// Registered as "minimal_service_settings" (NAME) in the named-writeable registry.
return NAME;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
// Serializable only to nodes that understand the model-registry cluster-state metadata.
return TransportVersions.INFERENCE_MODEL_REGISTRY_METADATA;
}
@Override
public ToXContentObject getFilteredXContentObject() {
// Minimal settings carry no secret fields, so the filtered view is the full rendering.
return this::toXContent;
}
@Override
public String modelId() {
// Minimal settings do not retain the underlying model id; callers must handle null.
return null;
}
// Reads a cluster-state diff; with SimpleDiffable a diff is a full-object replacement.
public static Diff<MinimalServiceSettings> readDiffFrom(StreamInput in) throws IOException {
return SimpleDiffable.readDiffFrom(MinimalServiceSettings::new, in);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (service != null) {
builder.field(SERVICE_FIELD, service);
}
builder.field(TASK_TYPE_FIELD, taskType.toString());
if (dimensions != null) {
builder.field(DIMENSIONS_FIELD, dimensions);
@ -143,7 +200,8 @@ public record MinimalServiceSettings(
@Override
public String toString() {
final StringBuilder sb = new StringBuilder();
sb.append("task_type=").append(taskType);
sb.append("service=").append(service);
sb.append(", task_type=").append(taskType);
if (dimensions != null) {
sb.append(", dimensions=").append(dimensions);
}
@ -156,31 +214,46 @@ public record MinimalServiceSettings(
return sb.toString();
}
private void validate() {
private static void validate(TaskType taskType, Integer dimensions, SimilarityMeasure similarity, ElementType elementType) {
switch (taskType) {
case TEXT_EMBEDDING:
validateFieldPresent(DIMENSIONS_FIELD, dimensions);
validateFieldPresent(SIMILARITY_FIELD, similarity);
validateFieldPresent(ELEMENT_TYPE_FIELD, elementType);
validateFieldPresent(DIMENSIONS_FIELD, dimensions, taskType);
validateFieldPresent(SIMILARITY_FIELD, similarity, taskType);
validateFieldPresent(ELEMENT_TYPE_FIELD, elementType, taskType);
break;
default:
validateFieldNotPresent(DIMENSIONS_FIELD, dimensions);
validateFieldNotPresent(SIMILARITY_FIELD, similarity);
validateFieldNotPresent(ELEMENT_TYPE_FIELD, elementType);
validateFieldNotPresent(DIMENSIONS_FIELD, dimensions, taskType);
validateFieldNotPresent(SIMILARITY_FIELD, similarity, taskType);
validateFieldNotPresent(ELEMENT_TYPE_FIELD, elementType, taskType);
break;
}
}
private void validateFieldPresent(String field, Object fieldValue) {
// Ensures a field that is mandatory for the given task type was supplied.
private static void validateFieldPresent(String field, Object fieldValue, TaskType taskType) {
    if (fieldValue != null) {
        return;
    }
    throw new IllegalArgumentException("required [" + field + "] field is missing for task_type [" + taskType.name() + "]");
}
private void validateFieldNotPresent(String field, Object fieldValue) {
// Rejects a field that must stay unset for the given task type.
private static void validateFieldNotPresent(String field, Object fieldValue, TaskType taskType) {
    if (fieldValue == null) {
        return;
    }
    throw new IllegalArgumentException("[" + field + "] is not allowed for task_type [" + taskType.name() + "]");
}
// Rehydrates a full ModelConfigurations from these minimal settings, substituting a
// placeholder service name when the original service was not recorded.
public ModelConfigurations toModelConfigurations(String inferenceEntityId) {
    final String serviceName = service != null ? service : UNKNOWN_SERVICE;
    return new ModelConfigurations(inferenceEntityId, taskType, serviceName, this);
}
/**
 * Returns true when {@code other} describes the same model shape: identical task type,
 * dimensions, similarity and element type. A null {@code service} on this instance acts
 * as a wildcard that matches any service on {@code other} — note the check is therefore
 * not symmetric.
 */
public boolean canMergeWith(MinimalServiceSettings other) {
return taskType == other.taskType
&& Objects.equals(dimensions, other.dimensions)
&& similarity == other.similarity
&& elementType == other.elementType
&& (service == null || service.equals(other.service));
}
}

View File

@ -16,8 +16,7 @@ import org.elasticsearch.xcontent.XContentParser;
import java.io.IOException;
public class MinimalServiceSettingsTests extends AbstractXContentTestCase<MinimalServiceSettings> {
@Override
protected MinimalServiceSettings createTestInstance() {
public static MinimalServiceSettings randomInstance() {
TaskType taskType = randomFrom(TaskType.values());
Integer dimensions = null;
SimilarityMeasure similarity = null;
@ -28,7 +27,12 @@ public class MinimalServiceSettingsTests extends AbstractXContentTestCase<Minima
similarity = randomFrom(SimilarityMeasure.values());
elementType = randomFrom(DenseVectorFieldMapper.ElementType.values());
}
return new MinimalServiceSettings(taskType, dimensions, similarity, elementType);
return new MinimalServiceSettings(randomBoolean() ? null : randomAlphaOfLength(10), taskType, dimensions, similarity, elementType);
}
@Override
protected MinimalServiceSettings createTestInstance() {
// Delegates to the shared static factory so other test suites can reuse it.
return randomInstance();
}
@Override

View File

@ -44,18 +44,23 @@ public class GetInferenceModelAction extends ActionType<GetInferenceModelAction.
// no effect when getting a single model
private final boolean persistDefaultConfig;
// For testing only, retrieves the minimal config from the cluster state.
private final boolean returnMinimalConfig;
public Request(String inferenceEntityId, TaskType taskType) {
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
this.taskType = Objects.requireNonNull(taskType);
this.persistDefaultConfig = PERSIST_DEFAULT_CONFIGS;
this(inferenceEntityId, taskType, PERSIST_DEFAULT_CONFIGS);
}
// Convenience overload: returnMinimalConfig defaults to false (the non-test path).
public Request(String inferenceEntityId, TaskType taskType, boolean persistDefaultConfig) {
this(inferenceEntityId, taskType, persistDefaultConfig, false);
}
/**
 * @param returnMinimalConfig testing-only flag; when true the minimal config is
 *                            retrieved from the cluster state instead of the index
 */
public Request(String inferenceEntityId, TaskType taskType, boolean persistDefaultConfig, boolean returnMinimalConfig) {
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
this.taskType = Objects.requireNonNull(taskType);
this.persistDefaultConfig = persistDefaultConfig;
this.returnMinimalConfig = returnMinimalConfig;
}
public Request(StreamInput in) throws IOException {
@ -68,6 +73,12 @@ public class GetInferenceModelAction extends ActionType<GetInferenceModelAction.
this.persistDefaultConfig = PERSIST_DEFAULT_CONFIGS;
}
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_MODEL_REGISTRY_METADATA)) {
this.returnMinimalConfig = in.readBoolean();
} else {
this.returnMinimalConfig = false;
}
}
public String getInferenceEntityId() {
@ -82,6 +93,10 @@ public class GetInferenceModelAction extends ActionType<GetInferenceModelAction.
return persistDefaultConfig;
}
// Testing-only flag: request the minimal config from the cluster state (see field comment).
public boolean isReturnMinimalConfig() {
return returnMinimalConfig;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
@ -90,6 +105,10 @@ public class GetInferenceModelAction extends ActionType<GetInferenceModelAction.
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
out.writeBoolean(this.persistDefaultConfig);
}
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_MODEL_REGISTRY_METADATA)) {
out.writeBoolean(returnMinimalConfig);
}
}
@Override
@ -99,12 +118,13 @@ public class GetInferenceModelAction extends ActionType<GetInferenceModelAction.
Request request = (Request) o;
return Objects.equals(inferenceEntityId, request.inferenceEntityId)
&& taskType == request.taskType
&& persistDefaultConfig == request.persistDefaultConfig;
&& persistDefaultConfig == request.persistDefaultConfig
&& returnMinimalConfig == request.returnMinimalConfig;
}
@Override
public int hashCode() {
return Objects.hash(inferenceEntityId, taskType, persistDefaultConfig);
return Objects.hash(inferenceEntityId, taskType, persistDefaultConfig, returnMinimalConfig);
}
}

View File

@ -17,7 +17,9 @@ import org.junit.ClassRule;
@ThreadLeakFilters(filters = TestClustersThreadFilter.class)
public class SemanticMatchIT extends SemanticMatchTestCase {
@ClassRule
public static ElasticsearchCluster cluster = Clusters.testCluster(spec -> spec.plugin("inference-service-test"));
public static ElasticsearchCluster cluster = Clusters.testCluster(
spec -> spec.module("x-pack-inference").plugin("inference-service-test")
);
@Override
protected String getTestRestCluster() {

View File

@ -18,6 +18,7 @@ import org.junit.Before;
import java.io.IOException;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.core.StringContains.containsString;
public abstract class SemanticMatchTestCase extends ESRestTestCase {
@ -88,16 +89,22 @@ public abstract class SemanticMatchTestCase extends ESRestTestCase {
Request request = new Request("PUT", "_inference/text_embedding/test_dense_inference");
request.setJsonEntity("""
{
"service": "test_service",
"service": "text_embedding_test_service",
"service_settings": {
"model": "my_model",
"api_key": "abc64"
"api_key": "abc64",
"dimensions": 128
},
"task_settings": {
}
}
""");
adminClient().performRequest(request);
try {
adminClient().performRequest(request);
} catch (ResponseException exc) {
// in case the removal failed
assertThat(exc.getResponse().getStatusLine().getStatusCode(), equalTo(400));
}
}
@After

View File

@ -51,7 +51,7 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
putModel("se_model_" + i, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
}
for (int i = 0; i < 4; i++) {
putModel("te_model_" + i, mockSparseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
putModel("te_model_" + i, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
}
var getAllModels = getAllModels();
@ -147,7 +147,9 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
{
"service": "openai",
"service_settings": {
"api_key": "XXXX"
"api_key": "XXXX",
"dimensions": 128,
"similarity": "cosine"
},
"task_settings": {
"model": "text-embedding-ada-002"

View File

@ -158,7 +158,7 @@ public class HuggingFaceServiceUpgradeIT extends InferenceUpgradeTestCase {
assertThat(inferenceMap.entrySet(), not(empty()));
}
private String embeddingConfig(String url) {
static String embeddingConfig(String url) {
return Strings.format("""
{
"service": "hugging_face",
@ -181,7 +181,7 @@ public class HuggingFaceServiceUpgradeIT extends InferenceUpgradeTestCase {
""";
}
private String elserConfig(String url) {
static String elserConfig(String url) {
return Strings.format("""
{
"service": "hugging_face",
@ -193,7 +193,7 @@ public class HuggingFaceServiceUpgradeIT extends InferenceUpgradeTestCase {
""", url);
}
private String elserResponse() {
static String elserResponse() {
return """
[
{

View File

@ -11,6 +11,7 @@ import com.carrotsearch.randomizedtesting.annotations.Name;
import org.elasticsearch.client.Request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
@ -24,6 +25,7 @@ import java.util.List;
import java.util.Map;
import static org.elasticsearch.core.Strings.format;
import static org.hamcrest.Matchers.containsString;
public class InferenceUpgradeTestCase extends ParameterizedRollingUpgradeTestCase {
@ -86,6 +88,15 @@ public class InferenceUpgradeTestCase extends ParameterizedRollingUpgradeTestCas
return entityAsMap(response);
}
// Fetches the metadata.model_registry section of the cluster state and returns the
// per-model minimal settings map keyed by inference id.
@SuppressWarnings("unchecked")
protected Map<String, Map<String, Object>> getMinimalConfig() throws IOException {
var endpoint = "_cluster/state?filter_path=metadata.model_registry";
var request = new Request("GET", endpoint);
var response = client().performRequest(request);
assertOK(response);
return (Map<String, Map<String, Object>>) XContentMapValues.extractValue("metadata.model_registry.models", entityAsMap(response));
}
protected Map<String, Object> inference(String inferenceId, TaskType taskType, String input) throws IOException {
var endpoint = Strings.format("_inference/%s/%s", taskType, inferenceId);
var request = new Request("POST", endpoint);
@ -125,6 +136,18 @@ public class InferenceUpgradeTestCase extends ParameterizedRollingUpgradeTestCas
assertOKAndConsume(response);
}
// Best-effort removal of every inference endpoint; deleting built-in endpoints is
// expected to fail with a "reserved inference endpoint" error, which is tolerated.
@SuppressWarnings("unchecked")
protected void deleteAll() throws IOException {
var endpoints = (List<Map<String, Object>>) get(TaskType.ANY, "*").get("endpoints");
for (var endpoint : endpoints) {
try {
delete((String) endpoint.get("inference_id"));
} catch (Exception exc) {
assertThat(exc.getMessage(), containsString("reserved inference endpoint"));
}
}
}
@SuppressWarnings("unchecked")
// in version 8.15, there was a breaking change where "models" was renamed to "endpoints"
LinkedList<Map<String, Object>> getConfigsWithBreakingChangeHandling(TaskType testTaskType, String oldClusterId) throws IOException {

View File

@ -0,0 +1,147 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.application;
import com.carrotsearch.randomizedtesting.annotations.Name;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.application.HuggingFaceServiceUpgradeIT.elserConfig;
import static org.elasticsearch.xpack.application.HuggingFaceServiceUpgradeIT.elserResponse;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
public class ModelRegistryUpgradeIT extends InferenceUpgradeTestCase {
private static MockWebServer embeddingsServer;
private static MockWebServer elserServer;
@BeforeClass
public static void startWebServer() throws IOException {
embeddingsServer = new MockWebServer();
embeddingsServer.start();
elserServer = new MockWebServer();
elserServer.start();
}
@AfterClass
public static void shutdown() {
embeddingsServer.close();
elserServer.close();
}
public ModelRegistryUpgradeIT(@Name("upgradedNodes") int upgradedNodes) {
super(upgradedNodes);
}
public void testUpgradeModels() throws Exception {
if (isOldCluster()) {
int numModels = randomIntBetween(5, 10);
for (int i = 0; i < numModels; i++) {
var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING);
if (taskType == TaskType.TEXT_EMBEDDING) {
int numDimensions = randomIntBetween(2, 50);
try {
embeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse(numDimensions)));
put("test-inference-" + i, embeddingConfig(getUrl(embeddingsServer)), taskType);
} finally {
embeddingsServer.clearRequests();
}
} else {
try {
elserServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse()));
put("test-inference-" + i, elserConfig(getUrl(elserServer)), taskType);
} finally {
elserServer.clearRequests();
}
}
}
} else if (isUpgradedCluster()) {
// check upgraded model in the cluster state
assertBusy(() -> assertMinimalModelsAreUpgraded());
deleteAll();
}
}
@SuppressWarnings("unchecked")
private void assertMinimalModelsAreUpgraded() throws IOException {
var fullModels = (List<Map<String, Object>>) get(TaskType.ANY, "*").get("endpoints");
var minimalModels = getMinimalConfig();
assertMinimalModelsAreUpgraded(
fullModels.stream().collect(Collectors.toMap(a -> (String) a.get("inference_id"), a -> a)),
minimalModels
);
}
@SuppressWarnings("unchecked")
private void assertMinimalModelsAreUpgraded(
Map<String, Map<String, Object>> fullModels,
Map<String, Map<String, Object>> minimalModels
) {
assertThat(fullModels.size(), greaterThan(0));
assertThat(fullModels.size(), equalTo(minimalModels.size()));
for (var entry : fullModels.entrySet()) {
var fullModel = entry.getValue();
var fullModelSettings = (Map<String, Object>) fullModel.get("service_settings");
var minimalModelSettings = minimalModels.get(entry.getKey());
assertNotNull(minimalModelSettings);
assertThat(minimalModelSettings.get("service"), equalTo(fullModel.get("service")));
assertThat(minimalModelSettings.get("task_type"), equalTo(fullModel.get("task_type")));
var taskType = TaskType.fromString((String) minimalModelSettings.get("task_type"));
if (taskType == TaskType.TEXT_EMBEDDING) {
assertNotNull(minimalModelSettings.get("dimensions"));
assertNotNull(minimalModelSettings.get("similarity"));
// For default models, dimensions and similarity are not exposed since they are predefined.
if (fullModelSettings.containsKey("dimensions")) {
assertThat(minimalModelSettings.get("dimensions"), equalTo(fullModelSettings.get("dimensions")));
}
if (fullModelSettings.containsKey("similarity")) {
assertThat(minimalModelSettings.get("similarity"), equalTo(fullModelSettings.get("similarity")));
}
}
}
}
private String embeddingResponse(int numDimensions) {
StringBuilder result = new StringBuilder();
result.append("[[");
for (int i = 0; i < numDimensions; i++) {
if (i > 0) {
result.append(", ");
}
result.append(randomFloat());
}
result.append("]]");
return result.toString();
}
static String embeddingConfig(String url) {
return Strings.format("""
{
"service": "hugging_face",
"service_settings": {
"url": "%s",
"api_key": "XXXX"
}
}
""", url, randomFrom(DenseVectorFieldMapper.ElementType.values()), randomFrom(SimilarityMeasure.values()));
}
}

View File

@ -334,7 +334,7 @@ public class TestDenseInferenceServiceExtension implements InferenceServiceExten
public TestServiceSettings(StreamInput in) throws IOException {
this(
in.readString(),
in.readOptionalInt(),
in.readInt(),
in.readOptionalEnum(SimilarityMeasure.class),
in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class)
);

View File

@ -28,6 +28,7 @@ import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.junit.Before;
import java.util.Arrays;
@ -57,9 +58,10 @@ public class ShardBulkInferenceActionFilterBasicLicenseIT extends ESIntegTestCas
@Before
public void setup() throws Exception {
Utils.storeSparseModel(client());
ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
Utils.storeSparseModel(modelRegistry);
Utils.storeDenseModel(
client(),
modelRegistry,
randomIntBetween(1, 100),
// dot product means that we need normalized vectors; it's not worth doing that in this test
randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values())),

View File

@ -35,6 +35,7 @@ import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.junit.Before;
import java.util.Collection;
@ -73,6 +74,7 @@ public class ShardBulkInferenceActionFilterIT extends ESIntegTestCase {
@Before
public void setup() throws Exception {
ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
DenseVectorFieldMapper.ElementType elementType = randomFrom(DenseVectorFieldMapper.ElementType.values());
// dot product means that we need normalized vectors; it's not worth doing that in this test
SimilarityMeasure similarity = randomValueOtherThan(
@ -80,9 +82,8 @@ public class ShardBulkInferenceActionFilterIT extends ESIntegTestCase {
() -> randomFrom(DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType))
);
int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100);
Utils.storeSparseModel(client());
Utils.storeDenseModel(client(), dimensions, similarity, elementType);
Utils.storeSparseModel(modelRegistry);
Utils.storeDenseModel(modelRegistry, dimensions, similarity, elementType);
}
@Override

View File

@ -22,6 +22,7 @@ import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
@ -57,7 +58,7 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
threadPool = createThreadPool(inferenceUtilityPool());
webServer.start();
gatewayUrl = getUrl(webServer);
modelRegistry = new ModelRegistry(client());
modelRegistry = node().injector().getInstance(ModelRegistry.class);
}
@After
@ -73,7 +74,7 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return pluginList(ReindexPlugin.class);
return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class);
}
public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception {
@ -97,7 +98,11 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
service.defaultConfigIds(),
is(
List.of(
new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service)
new InferenceService.DefaultConfigId(
".rainbow-sprinkles-elastic",
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
service
)
)
)
);
@ -134,7 +139,7 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
List.of(
new InferenceService.DefaultConfigId(
".rainbow-sprinkles-elastic",
MinimalServiceSettings.chatCompletion(),
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
service
)
)
@ -207,10 +212,14 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
service.defaultConfigIds(),
is(
List.of(
new InferenceService.DefaultConfigId(".elser-v2-elastic", MinimalServiceSettings.sparseEmbedding(), service),
new InferenceService.DefaultConfigId(
".elser-v2-elastic",
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
service
),
new InferenceService.DefaultConfigId(
".rainbow-sprinkles-elastic",
MinimalServiceSettings.chatCompletion(),
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
service
)
)
@ -254,7 +263,11 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
service.defaultConfigIds(),
is(
List.of(
new InferenceService.DefaultConfigId(".elser-v2-elastic", MinimalServiceSettings.sparseEmbedding(), service)
new InferenceService.DefaultConfigId(
".elser-v2-elastic",
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
service
)
)
)
);

View File

@ -21,6 +21,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.MinimalServiceSettingsTests;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
@ -38,6 +39,7 @@ import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.model.TestModel;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.registry.ModelRegistryTests;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
@ -66,7 +68,6 @@ import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.ArgumentMatchers.any;
@ -80,7 +81,8 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
@Before
public void createComponents() {
modelRegistry = new ModelRegistry(client());
modelRegistry = node().injector().getInstance(ModelRegistry.class);
modelRegistry.clearDefaultIds();
}
@Override
@ -91,44 +93,30 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
public void testStoreModel() throws Exception {
    String inferenceEntityId = "test-store-model";
    Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
    // assertStoreModel performs the store and verifies both the acknowledgement and
    // the absence of failures. The duplicate blockingCall-based store that preceded
    // this call has been removed: storing the same id twice fails with "already exists".
    ModelRegistryTests.assertStoreModel(modelRegistry, model);
}
public void testStoreModelWithUnknownFields() throws Exception {
    String inferenceEntityId = "test-store-model-unknown-field";
    Model model = buildModelWithUnknownField(inferenceEntityId);
    // The inference index mapping is strict, so a model document carrying an unknown
    // field must be rejected at store time with an ElasticsearchStatusException.
    ElasticsearchStatusException statusException = expectThrows(
        ElasticsearchStatusException.class,
        () -> ModelRegistryTests.assertStoreModel(modelRegistry, model)
    );
    // The root cause carries the strict-mapping rejection from the index.
    assertThat(
        statusException.getRootCause().getMessage(),
        containsString("mapping set to strict, dynamic introduction of [unknown_field] within [_doc] is not allowed")
    );
    // The top-level message identifies which endpoint failed to store.
    assertThat(statusException.getMessage(), containsString("Failed to store inference endpoint [" + inferenceEntityId + "]"));
}
public void testGetModel() throws Exception {
String inferenceEntityId = "test-get-model";
Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
ModelRegistryTests.assertStoreModel(modelRegistry, model);
// now get the model
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
AtomicReference<UnparsedModel> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder);
assertThat(exceptionHolder.get(), is(nullValue()));
@ -156,32 +144,18 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
public void testStoreModelFailsWhenModelExists() throws Exception {
    String inferenceEntityId = "test-put-trained-model-config-exists";
    Model model = buildElserModelConfig(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
    // First store succeeds.
    ModelRegistryTests.assertStoreModel(modelRegistry, model);

    // A model with the same id already exists, so a second store must fail.
    var exc = expectThrows(Exception.class, () -> ModelRegistryTests.assertStoreModel(modelRegistry, model));
    assertThat(exc.getMessage(), containsString("Inference endpoint [test-put-trained-model-config-exists] already exists"));
}
public void testDeleteModel() throws Exception {
// put models
for (var id : new String[] { "model1", "model2", "model3" }) {
Model model = buildElserModelConfig(id, TaskType.SPARSE_EMBEDDING);
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
ModelRegistryTests.assertStoreModel(modelRegistry, model);
}
AtomicReference<Boolean> deleteResponseHolder = new AtomicReference<>();
@ -220,7 +194,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
var defaultConfigs = new ArrayList<Model>();
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
for (var id : new String[] { "model1", "model2", "model3" }) {
var modelSettings = ModelRegistryTests.randomMinimalServiceSettings();
var modelSettings = MinimalServiceSettingsTests.randomInstance();
defaultConfigs.add(createModel(id, modelSettings.taskType(), "name"));
defaultIds.add(new InferenceService.DefaultConfigId(id, modelSettings, service));
}
@ -260,11 +234,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
sparseAndTextEmbeddingModels.add(createModel(randomAlphaOfLength(5), TaskType.TEXT_EMBEDDING, service));
for (var model : sparseAndTextEmbeddingModels) {
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
ModelRegistryTests.assertStoreModel(modelRegistry, model);
}
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
@ -303,10 +273,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
for (int i = 0; i < modelCount; i++) {
var model = createModel(randomAlphaOfLength(5), randomFrom(TaskType.values()), service);
createdModels.add(model);
blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
assertNull(exceptionHolder.get());
ModelRegistryTests.assertStoreModel(modelRegistry, model);
}
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
@ -331,14 +298,10 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
var inferenceEntityId = "model-with-secrets";
var secret = "abc";
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
var modelWithSecrets = createModelWithSecrets(inferenceEntityId, randomFrom(TaskType.values()), service, secret);
blockingCall(listener -> modelRegistry.storeModel(modelWithSecrets, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
assertNull(exceptionHolder.get());
ModelRegistryTests.assertStoreModel(modelRegistry, modelWithSecrets);
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
AtomicReference<UnparsedModel> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder);
assertThat(modelHolder.get().secrets().keySet(), hasSize(1));
@ -364,7 +327,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
for (int i = 0; i < defaultModelCount; i++) {
var id = "default-" + i;
var modelSettings = ModelRegistryTests.randomMinimalServiceSettings();
var modelSettings = MinimalServiceSettingsTests.randomInstance();
defaultConfigs.add(createModel(id, modelSettings.taskType(), serviceName));
defaultIds.add(new InferenceService.DefaultConfigId(id, modelSettings, service));
}
@ -385,9 +348,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
var id = randomAlphaOfLength(5) + i;
var model = createModel(id, randomFrom(TaskType.values()), serviceName);
createdModels.put(id, model);
blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
assertNull(exceptionHolder.get());
ModelRegistryTests.assertStoreModel(modelRegistry, model);
}
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
@ -429,7 +390,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
for (int i = 0; i < defaultModelCount; i++) {
var id = "default-" + i;
var modelSettings = ModelRegistryTests.randomMinimalServiceSettings();
var modelSettings = MinimalServiceSettingsTests.randomInstance();
defaultConfigs.add(createModel(id, modelSettings.taskType(), serviceName));
defaultIds.add(new InferenceService.DefaultConfigId(id, modelSettings, service));
}
@ -439,7 +400,6 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
listener.onResponse(defaultConfigs);
return Void.TYPE;
}).when(service).defaultConfigs(any());
defaultIds.forEach(modelRegistry::addDefaultIds);
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
@ -471,7 +431,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
for (int i = 0; i < defaultModelCount; i++) {
var id = "default-" + i;
var modelSettings = ModelRegistryTests.randomMinimalServiceSettings();
var modelSettings = MinimalServiceSettingsTests.randomInstance();
defaultConfigs.add(createModel(id, modelSettings.taskType(), serviceName));
defaultIds.add(new InferenceService.DefaultConfigId(id, modelSettings, service));
}
@ -511,11 +471,13 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
defaultConfigs.add(createModel("default-sparse", TaskType.SPARSE_EMBEDDING, serviceName));
defaultConfigs.add(createModel("default-text", TaskType.TEXT_EMBEDDING, serviceName));
defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", MinimalServiceSettings.sparseEmbedding(), service));
defaultIds.add(
new InferenceService.DefaultConfigId("default-sparse", MinimalServiceSettings.sparseEmbedding(serviceName), service)
);
defaultIds.add(
new InferenceService.DefaultConfigId(
"default-text",
MinimalServiceSettings.textEmbedding(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT),
MinimalServiceSettings.textEmbedding(serviceName, 384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT),
service
)
);
@ -527,17 +489,12 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
}).when(service).defaultConfigs(any());
defaultIds.forEach(modelRegistry::addDefaultIds);
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
var configured1 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), serviceName);
var configured2 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), serviceName);
blockingCall(listener -> modelRegistry.storeModel(configured1, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
blockingCall(listener -> modelRegistry.storeModel(configured2, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
assertNull(exceptionHolder.get());
ModelRegistryTests.assertStoreModel(modelRegistry, configured1);
ModelRegistryTests.assertStoreModel(modelRegistry, configured2);
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
AtomicReference<UnparsedModel> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getModel("default-sparse", listener), modelHolder, exceptionHolder);
assertNull(exceptionHolder.get());
@ -563,15 +520,17 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
var service = mock(InferenceService.class);
var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", MinimalServiceSettings.sparseEmbedding(), service));
defaultIds.add(
new InferenceService.DefaultConfigId("default-sparse", MinimalServiceSettings.sparseEmbedding(serviceName), service)
);
defaultIds.add(
new InferenceService.DefaultConfigId(
"default-text",
MinimalServiceSettings.textEmbedding(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT),
MinimalServiceSettings.textEmbedding(serviceName, 384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT),
service
)
);
defaultIds.add(new InferenceService.DefaultConfigId("default-chat", MinimalServiceSettings.completion(), service));
defaultIds.add(new InferenceService.DefaultConfigId("default-chat", MinimalServiceSettings.completion(serviceName), service));
doAnswer(invocation -> {
ActionListener<List<Model>> listener = invocation.getArgument(0);
@ -580,20 +539,14 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
}).when(service).defaultConfigs(any());
defaultIds.forEach(modelRegistry::addDefaultIds);
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
var configuredSparse = createModel("configured-sparse", TaskType.SPARSE_EMBEDDING, serviceName);
var configuredText = createModel("configured-text", TaskType.TEXT_EMBEDDING, serviceName);
var configuredRerank = createModel("configured-rerank", TaskType.RERANK, serviceName);
blockingCall(listener -> modelRegistry.storeModel(configuredSparse, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
blockingCall(listener -> modelRegistry.storeModel(configuredText, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
blockingCall(listener -> modelRegistry.storeModel(configuredRerank, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
assertNull(exceptionHolder.get());
ModelRegistryTests.assertStoreModel(modelRegistry, configuredSparse);
ModelRegistryTests.assertStoreModel(modelRegistry, configuredText);
ModelRegistryTests.assertStoreModel(modelRegistry, configuredRerank);
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder);
if (exceptionHolder.get() != null) {
@ -693,13 +646,27 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
);
}
/**
 * Builds randomized test service settings for the given task type.
 * Text-embedding endpoints require dimension/similarity/element-type settings;
 * every other task type is served by the generic test settings stub.
 */
private static ServiceSettings createServiceSettings(TaskType taskType) {
    if (taskType == TaskType.TEXT_EMBEDDING) {
        return new TestModel.TestServiceSettings(
            "model",
            randomIntBetween(2, 100),
            randomFrom(SimilarityMeasure.values()),
            DenseVectorFieldMapper.ElementType.FLOAT
        );
    }
    return new TestModelOfAnyKind.TestModelServiceSettings();
}
/**
 * Creates a test {@link Model} whose service settings match the requested task type.
 * The stale one-line return that preceded the settings creation (diff residue that
 * made the following statements unreachable) has been removed.
 */
public static Model createModel(String inferenceEntityId, TaskType taskType, String service) {
    var serviceSettings = createServiceSettings(taskType);
    return new Model(new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings));
}
/**
 * Creates a test {@link Model} with secret settings. The duplicated
 * {@code ModelConfigurations} argument line (stale diff residue that broke the
 * argument list) has been removed; only the task-type-aware settings are used.
 */
public static Model createModelWithSecrets(String inferenceEntityId, TaskType taskType, String service, String secret) {
    var serviceSettings = createServiceSettings(taskType);
    return new Model(
        new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings),
        new ModelSecrets(new TestModelOfAnyKind.TestSecretSettings(secret))
    );
}

View File

@ -13,7 +13,9 @@ import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.support.MappedActionFilter;
import org.elasticsearch.cluster.NamedDiff;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.ClusterSettings;
@ -51,6 +53,7 @@ import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.XPackPlugin;
@ -101,6 +104,7 @@ import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankB
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankDoc;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.registry.ModelRegistryMetadata;
import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceEndpointAction;
import org.elasticsearch.xpack.inference.rest.RestGetInferenceDiagnosticsAction;
import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction;
@ -256,7 +260,8 @@ public class InferencePlugin extends Plugin
var amazonBedrockRequestSenderFactory = new AmazonBedrockRequestSender.Factory(serviceComponents.get(), services.clusterService());
amazonBedrockFactory.set(amazonBedrockRequestSenderFactory);
ModelRegistry modelRegistry = new ModelRegistry(services.client());
ModelRegistry modelRegistry = new ModelRegistry(services.clusterService(), services.client());
services.clusterService().addListener(modelRegistry);
if (inferenceServiceExtensions == null) {
inferenceServiceExtensions = new ArrayList<>();
@ -375,9 +380,24 @@ public class InferencePlugin extends Plugin
entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, TextSimilarityRankBuilder.NAME, TextSimilarityRankBuilder::new));
entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, RandomRankBuilder.NAME, RandomRankBuilder::new));
entries.add(new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new));
entries.add(new NamedWriteableRegistry.Entry(Metadata.ProjectCustom.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::new));
entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::readDiffFrom));
return entries;
}
@Override
public List<NamedXContentRegistry.Entry> getNamedXContent() {
    // Register the parser for ModelRegistryMetadata so it can be restored
    // from the persisted (XContent) form of the project metadata.
    var registryMetadataEntry = new NamedXContentRegistry.Entry(
        Metadata.ProjectCustom.class,
        new ParseField(ModelRegistryMetadata.TYPE),
        ModelRegistryMetadata::fromXContent
    );
    List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
    namedXContent.add(registryMetadataEntry);
    return namedXContent;
}
@Override
public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings settings) {

View File

@ -202,7 +202,8 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
} else {
delegate.onFailure(e);
}
})
}),
timeout
)
);

View File

@ -983,7 +983,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
}
private static boolean canMergeModelSettings(MinimalServiceSettings previous, MinimalServiceSettings current, Conflicts conflicts) {
if (Objects.equals(previous, current)) {
if (previous != null && current != null && previous.canMergeWith(current)) {
return true;
}
if (previous == null || current == null) {

View File

@ -26,14 +26,30 @@ import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.cluster.AckedBatchedClusterStateUpdateTask;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateAckListener;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.SimpleBatchedAckListenerTaskExecutor;
import org.elasticsearch.cluster.metadata.ProjectId;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.cluster.service.MasterServiceTaskQueue;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.reindex.BulkByScrollResponse;
import org.elasticsearch.index.reindex.DeleteByQueryAction;
import org.elasticsearch.index.reindex.DeleteByQueryRequest;
import org.elasticsearch.inference.InferenceService;
@ -51,6 +67,10 @@ import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
import org.elasticsearch.xpack.inference.InferenceIndex;
import org.elasticsearch.xpack.inference.InferenceSecretsIndex;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
@ -68,23 +88,36 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.elasticsearch.core.Strings.format;
/**
* Class for persisting and reading inference endpoint configurations.
* Some inference services provide default configurations, the registry is
* made aware of these at start up via {@link #addDefaultIds(InferenceService.DefaultConfigId)}.
* Only the ids and service details are registered at this point
* as the full config definition may not be known at start up.
* The full config is lazily populated on read and persisted to the
* index. This has the effect of creating the backing index on reading
* the configs. {@link #getAllModels(boolean, ActionListener)} has an option
* to not write the default configs to index on read to avoid index creation.
* A class responsible for persisting and reading inference endpoint configurations.
* All endpoint modifications (see {@link PutInferenceModelAction}, {@link UpdateInferenceModelAction} and
* {@link DeleteInferenceEndpointAction}) are executed on the master node to prevent race conditions when modifying models.
*
* <p><strong>Default endpoints:</strong></p>
* Some inference services provide default configurations, which are registered at startup using
* {@link #addDefaultIds(InferenceService.DefaultConfigId)}. At this point, only the IDs and service details
* are registered, as the full configuration definitions may not yet be available.
* The full configurations are populated lazily upon reading and are then persisted to the index.
* This process also triggers the creation of the backing index when reading the configurations.
* To avoid index creation, {@link #getAllModels(boolean, ActionListener)} includes an option to skip writing
* default configurations to the index during reads.
*
* <p><strong>Minimal Service Settings in Cluster State:</strong></p>
* The cluster state is updated with the {@link MinimalServiceSettings} for all registered models,
* ensuring these settings are readily accessible to consumers without requiring an asynchronous call
* to retrieve the full model configurations.
*
* <p><strong>Metadata Upgrades:</strong></p>
* Since cluster state metadata was introduced later, the master node performs an upgrade at startup,
* if necessary, to load all model settings from the {@link InferenceIndex}.
*/
public class ModelRegistry {
public class ModelRegistry implements ClusterStateListener {
public record ModelConfigMap(Map<String, Object> config, Map<String, Object> secrets) {}
public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) {
@ -106,14 +139,28 @@ public class ModelRegistry {
private static final String MODEL_ID_FIELD = "model_id";
private static final Logger logger = LogManager.getLogger(ModelRegistry.class);
private final ClusterService clusterService;
private final OriginSettingClient client;
private final Map<String, InferenceService.DefaultConfigId> defaultConfigIds;
private final MasterServiceTaskQueue<MetadataTask> metadataTaskQueue;
private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false);
private final Set<String> preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>());
/**
 * Creates the registry. The stale single-argument signature line and the duplicated
 * {@code defaultConfigIds} assignment (diff residue) have been removed.
 *
 * @param clusterService used to read cluster state and to submit metadata update tasks
 * @param client wrapped with the inference origin for index reads/writes
 */
public ModelRegistry(ClusterService clusterService, Client client) {
    this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN);
    this.defaultConfigIds = new ConcurrentHashMap<>();
    this.clusterService = clusterService;
    // Batched master-service executor: each task applies its change to the current
    // ModelRegistryMetadata of the task's project and publishes the updated state.
    var executor = new SimpleBatchedAckListenerTaskExecutor<MetadataTask>() {
        @Override
        public Tuple<ClusterState, ClusterStateAckListener> executeTask(MetadataTask task, ClusterState clusterState) throws Exception {
            var projectMetadata = clusterState.metadata().getProject(task.getProjectId());
            var updated = task.executeTask(ModelRegistryMetadata.fromState(projectMetadata));
            var newProjectMetadata = ProjectMetadata.builder(projectMetadata).putCustom(ModelRegistryMetadata.TYPE, updated);
            // The task itself acts as the ack listener for the published state.
            return new Tuple<>(ClusterState.builder(clusterState).putProjectMetadata(newProjectMetadata).build(), task);
        }
    };
    this.metadataTaskQueue = clusterService.createTaskQueue("model_registry", Priority.NORMAL, executor);
}
/**
@ -155,6 +202,39 @@ public class ModelRegistry {
defaultConfigIds.put(defaultConfigId.inferenceId(), defaultConfigId);
}
/**
 * Removes all registered default configuration ids.
 * Visible for testing only.
 */
public void clearDefaultIds() {
    defaultConfigIds.clear();
}
/**
 * Returns the {@link MinimalServiceSettings} registered for {@code inferenceEntityId}.
 *
 * Resolution order: default configuration ids first, then the
 * {@link ModelRegistryMetadata} stored in the cluster state.
 * When the id cannot be resolved:
 * <ul>
 * <li>{@code null} is returned if the metadata has not been fully upgraded yet,
 * meaning the id might still exist in the inference index.</li>
 * <li>a {@link ResourceNotFoundException} is thrown if the metadata is upgraded
 * and therefore authoritative, so the id is guaranteed to be absent.</li>
 * </ul>
 *
 * @param inferenceEntityId the unique identifier for the inference entity.
 * @return the settings for the id, or {@code null} if unavailable locally.
 * @throws ResourceNotFoundException if the id is guaranteed to not exist in the cluster.
 */
public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException {
    var defaultConfig = defaultConfigIds.get(inferenceEntityId);
    if (defaultConfig != null) {
        return defaultConfig.settings();
    }
    var registryMetadata = ModelRegistryMetadata.fromState(clusterService.state().projectState().metadata());
    var settings = registryMetadata.getMinimalServiceSettings(inferenceEntityId);
    if (settings == null && registryMetadata.isUpgraded()) {
        throw new ResourceNotFoundException(inferenceEntityId + " does not exist in this cluster.");
    }
    return settings;
}
/**
* Get a model with its secret settings
* @param inferenceEntityId Model to get
@ -219,27 +299,6 @@ public class ModelRegistry {
client.search(modelSearch, searchListener);
}
/**
 * Retrieves the {@link MinimalServiceSettings} associated with the specified {@code inferenceEntityId}.
 *
 * If the {@code inferenceEntityId} is not found, the method behaves as follows:
 * <ul>
 * <li>Returns {@code null} if the id might exist but its configuration is not available locally.</li>
 * <li>Throws a {@link ResourceNotFoundException} if it is certain that the id does not exist in the cluster.</li>
 * </ul>
 *
 * NOTE(review): this copy only consults the default config ids and otherwise returns
 * {@code null}; it never throws despite the declared {@link ResourceNotFoundException}.
 * It appears to be a stale duplicate of the cluster-state-aware version defined earlier
 * in this class — confirm and remove one of the two copies.
 *
 * @param inferenceEntityId the unique identifier for the inference entity.
 * @return the {@link MinimalServiceSettings} associated with the provided ID, or {@code null} if unavailable locally.
 * @throws ResourceNotFoundException if the specified id is guaranteed to not exist in the cluster.
 */
public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException {
    var config = defaultConfigIds.get(inferenceEntityId);
    if (config != null) {
        return config.settings();
    }
    return null;
}
/** Builds the standard not-found exception for an unknown inference endpoint id. */
private ResourceNotFoundException inferenceNotFoundException(String inferenceEntityId) {
    var notFound = new ResourceNotFoundException("Inference endpoint not found [{}]", inferenceEntityId);
    return notFound;
}
@ -369,7 +428,7 @@ public class ModelRegistry {
}
});
storeModel(preconfigured, ActionListener.runAfter(responseListener, runAfter));
storeModel(preconfigured, ActionListener.runAfter(responseListener, runAfter), AcknowledgedRequest.DEFAULT_ACK_TIMEOUT);
}
private ArrayList<ModelConfigMap> parseHitsAsModels(SearchHits hits) {
@ -472,7 +531,11 @@ public class ModelRegistry {
// Since none of our updates succeeded at this point, we can simply return.
finalListener.onFailure(
new ElasticsearchStatusException(
format("Failed to update inference endpoint [%s] due to [%s]", inferenceEntityId),
format(
"Failed to update inference endpoint [%s] due to [%s]",
inferenceEntityId,
configResponse.buildFailureMessage()
),
RestStatus.INTERNAL_SERVER_ERROR,
configResponse.buildFailureMessage()
)
@ -556,9 +619,8 @@ public class ModelRegistry {
/**
* Note: storeModel does not overwrite existing models and thus does not need to check the lock
*/
public void storeModel(Model model, ActionListener<Boolean> listener) {
ActionListener<BulkResponse> bulkResponseActionListener = getStoreModelListener(model, listener);
public void storeModel(Model model, ActionListener<Boolean> listener, TimeValue timeout) {
ActionListener<BulkResponse> bulkResponseActionListener = getStoreIndexListener(model, listener, timeout);
IndexRequest configRequest = createIndexRequest(
Model.documentId(model.getConfigurations().getInferenceEntityId()),
@ -581,7 +643,7 @@ public class ModelRegistry {
.execute(bulkResponseActionListener);
}
private static ActionListener<BulkResponse> getStoreModelListener(Model model, ActionListener<Boolean> listener) {
private ActionListener<BulkResponse> getStoreIndexListener(Model model, ActionListener<Boolean> listener, TimeValue timeout) {
return ActionListener.wrap(bulkItemResponses -> {
var inferenceEntityId = model.getConfigurations().getInferenceEntityId();
@ -605,7 +667,17 @@ public class ModelRegistry {
BulkItemResponse.Failure failure = getFirstBulkFailure(bulkItemResponses);
if (failure == null) {
listener.onResponse(true);
var storeListener = getStoreMetadataListener(inferenceEntityId, listener);
try {
var projectId = clusterService.state().projectState().projectId();
metadataTaskQueue.submitTask(
"add model [" + inferenceEntityId + "]",
new AddModelMetadataTask(projectId, inferenceEntityId, new MinimalServiceSettings(model), storeListener),
timeout
);
} catch (Exception exc) {
storeListener.onFailure(exc);
}
return;
}
@ -630,6 +702,31 @@ public class ModelRegistry {
});
}
/**
 * Builds the listener invoked after the model's minimal settings have been submitted to the
 * cluster-state metadata task queue. On success the overall store operation is reported as
 * successful; on failure the already-written index document is rolled back so the index and
 * the cluster state do not diverge.
 */
private ActionListener<AcknowledgedResponse> getStoreMetadataListener(String inferenceEntityId, ActionListener<Boolean> listener) {
    return new ActionListener<>() {
        @Override
        public void onResponse(AcknowledgedResponse resp) {
            listener.onResponse(true);
        }

        @Override
        public void onFailure(Exception exc) {
            // The index write succeeded but the cluster state update failed, leaving the registry
            // inconsistent: delete the index document, then surface the failure. Chain the original
            // exception so the root cause is not lost.
            deleteModel(inferenceEntityId, ActionListener.running(() -> {
                listener.onFailure(
                    new ElasticsearchStatusException(
                        format(
                            "Failed to add the inference endpoint [%s]. The service may be in an "
                                + "inconsistent state. Please try deleting and re-adding the endpoint.",
                            inferenceEntityId
                        ),
                        RestStatus.INTERNAL_SERVER_ERROR,
                        exc
                    )
                );
            }));
        }
    };
}
private static void logBulkFailures(String inferenceEntityId, BulkResponse bulkResponse) {
for (BulkItemResponse item : bulkResponse.getItems()) {
if (item.isFailed()) {
@ -687,12 +784,62 @@ public class ModelRegistry {
return;
}
var request = createDeleteRequest(inferenceEntityIds);
client.execute(DeleteByQueryAction.INSTANCE, request, getDeleteModelClusterStateListener(inferenceEntityIds, listener));
}
/**
 * Builds the listener invoked once the delete-by-query against the inference indices completes.
 * On success it submits a cluster-state task removing the corresponding entries from the
 * {@code ModelRegistryMetadata}; the overall delete only succeeds when both steps succeed.
 */
private ActionListener<BulkByScrollResponse> getDeleteModelClusterStateListener(
    Set<String> inferenceEntityIds,
    ActionListener<Boolean> listener
) {
    return new ActionListener<>() {
        @Override
        public void onResponse(BulkByScrollResponse bulkByScrollResponse) {
            var clusterStateListener = new ActionListener<AcknowledgedResponse>() {
                @Override
                public void onResponse(AcknowledgedResponse acknowledgedResponse) {
                    listener.onResponse(acknowledgedResponse.isAcknowledged());
                }

                @Override
                public void onFailure(Exception exc) {
                    // The index documents are already gone but the cluster state update failed,
                    // so the registry is inconsistent. Chain the original exception to preserve
                    // the root cause for diagnostics.
                    listener.onFailure(
                        new ElasticsearchStatusException(
                            format(
                                "Failed to delete the inference endpoint [%s]. The service may be in an "
                                    + "inconsistent state. Please try deleting the endpoint again.",
                                inferenceEntityIds
                            ),
                            RestStatus.INTERNAL_SERVER_ERROR,
                            exc
                        )
                    );
                }
            };
            try {
                var projectId = clusterService.state().projectState().projectId();
                metadataTaskQueue.submitTask(
                    "delete models [" + inferenceEntityIds + "]",
                    new DeleteModelMetadataTask(projectId, inferenceEntityIds, clusterStateListener),
                    null
                );
            } catch (Exception exc) {
                clusterStateListener.onFailure(exc);
            }
        }

        @Override
        public void onFailure(Exception exc) {
            listener.onFailure(exc);
        }
    };
}
private static DeleteByQueryRequest createDeleteRequest(Set<String> inferenceEntityIds) {
DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false);
request.indices(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN);
request.setQuery(documentIdsQuery(inferenceEntityIds));
request.setRefresh(true);
client.execute(DeleteByQueryAction.INSTANCE, request, listener.delegateFailureAndWrap((l, r) -> l.onResponse(Boolean.TRUE)));
return request;
}
private static IndexRequest createIndexRequest(String docId, String indexName, ToXContentObject body, boolean allowOverwriting) {
@ -723,11 +870,11 @@ public class ModelRegistry {
}
}
private QueryBuilder documentIdQuery(String inferenceEntityId) {
private static QueryBuilder documentIdQuery(String inferenceEntityId) {
return QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(inferenceEntityId)));
}
private QueryBuilder documentIdsQuery(Set<String> inferenceEntityIds) {
private static QueryBuilder documentIdsQuery(Set<String> inferenceEntityIds) {
var documentIdsArray = inferenceEntityIds.stream().map(Model::documentId).toArray(String[]::new);
return QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(documentIdsArray));
}
@ -747,4 +894,131 @@ public class ModelRegistry {
.filter(defaultConfigId -> defaultConfigId.settings().taskType().equals(taskType))
.collect(Collectors.toList());
}
/**
 * On the elected master, checks whether the {@code ModelRegistryMetadata} has been upgraded
 * from the inference index and, if not, loads all model configurations and submits a one-time
 * upgrade task. {@code upgradeMetadataInProgress} guards against concurrent upgrade attempts.
 */
@Override
public void clusterChanged(ClusterChangedEvent event) {
    // Only the elected master performs the metadata upgrade.
    if (event.localNodeMaster() == false) {
        return;
    }
    if (event.state().metadata().projects().size() > 1) {
        // TODO: Add support to handle multi-projects
        return;
    }
    var state = ModelRegistryMetadata.fromState(event.state().projectState().metadata());
    if (state.isUpgraded()) {
        // Nothing to do: the cluster state already reflects the models stored in the index.
        return;
    }
    // Ensure only a single upgrade attempt runs at a time.
    if (upgradeMetadataInProgress.compareAndSet(false, true) == false) {
        return;
    }
    // GetInferenceModelAction is used because ModelRegistry does not know how to parse the service settings
    client.execute(
        GetInferenceModelAction.INSTANCE,
        new GetInferenceModelAction.Request("*", TaskType.ANY, false),
        new ActionListener<>() {
            @Override
            public void onResponse(GetInferenceModelAction.Response response) {
                Map<String, MinimalServiceSettings> map = new HashMap<>();
                for (var model : response.getEndpoints()) {
                    map.put(
                        model.getInferenceEntityId(),
                        new MinimalServiceSettings(
                            model.getService(),
                            model.getTaskType(),
                            model.getServiceSettings().dimensions(),
                            model.getServiceSettings().similarity(),
                            model.getServiceSettings().elementType()
                        )
                    );
                }
                try {
                    metadataTaskQueue.submitTask(
                        "model registry auto upgrade",
                        new UpgradeModelsMetadataTask(
                            clusterService.state().metadata().getProject().id(),
                            map,
                            ActionListener.running(() -> upgradeMetadataInProgress.set(false))
                        ),
                        null
                    );
                } catch (Exception exc) {
                    // If the task cannot be submitted, clear the flag so a later cluster state
                    // change can retry the upgrade instead of blocking it forever.
                    upgradeMetadataInProgress.set(false);
                }
            }

            @Override
            public void onFailure(Exception e) {
                upgradeMetadataInProgress.set(false);
            }
        }
    );
}
/**
 * Base class for batched cluster-state update tasks that modify the
 * {@code ModelRegistryMetadata} of a single project.
 */
private abstract static class MetadataTask extends AckedBatchedClusterStateUpdateTask {
    private final ProjectId projectId;

    MetadataTask(ProjectId project, ActionListener<AcknowledgedResponse> listener) {
        super(TimeValue.THIRTY_SECONDS, listener);
        this.projectId = project;
    }

    /** Applies this task to the current metadata and returns the resulting metadata. */
    abstract ModelRegistryMetadata executeTask(ModelRegistryMetadata current);

    /** The project whose registry metadata this task updates. */
    public ProjectId getProjectId() {
        return projectId;
    }
}
/**
 * Task that seeds the registry metadata with the model settings read from the inference
 * index, marking the registry as fully upgraded.
 */
private static class UpgradeModelsMetadataTask extends MetadataTask {
    private final Map<String, MinimalServiceSettings> indexModels;

    UpgradeModelsMetadataTask(
        ProjectId projectId,
        Map<String, MinimalServiceSettings> indexModels,
        ActionListener<AcknowledgedResponse> listener
    ) {
        super(projectId, listener);
        this.indexModels = indexModels;
    }

    @Override
    ModelRegistryMetadata executeTask(ModelRegistryMetadata current) {
        return current.withUpgradedModels(indexModels);
    }
}
/**
 * Task that records the {@code MinimalServiceSettings} of a newly stored model in the
 * registry metadata (clearing any matching tombstone while not yet upgraded).
 */
private static class AddModelMetadataTask extends MetadataTask {
    private final String entityId;
    private final MinimalServiceSettings minimalSettings;

    AddModelMetadataTask(
        ProjectId projectId,
        String entityId,
        MinimalServiceSettings minimalSettings,
        ActionListener<AcknowledgedResponse> listener
    ) {
        super(projectId, listener);
        this.entityId = entityId;
        this.minimalSettings = minimalSettings;
    }

    @Override
    ModelRegistryMetadata executeTask(ModelRegistryMetadata current) {
        return current.withAddedModel(entityId, minimalSettings);
    }
}
/**
 * Task that removes one or more models from the registry metadata, recording tombstones
 * when the registry has not yet been upgraded from the inference index.
 */
private static class DeleteModelMetadataTask extends MetadataTask {
    private final Set<String> inferenceEntityIds;

    // Parameter renamed from the misleading singular "inferenceEntityId": it is a set of ids.
    DeleteModelMetadataTask(ProjectId projectId, Set<String> inferenceEntityIds, ActionListener<AcknowledgedResponse> listener) {
        super(projectId, listener);
        this.inferenceEntityIds = inferenceEntityIds;
    }

    @Override
    ModelRegistryMetadata executeTask(ModelRegistryMetadata current) {
        return current.withRemovedModel(inferenceEntityIds);
    }
}
}

View File

@ -0,0 +1,316 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.registry;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.cluster.Diff;
import org.elasticsearch.cluster.DiffableUtils;
import org.elasticsearch.cluster.NamedDiff;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.ImmutableOpenMap;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.inference.InferenceIndex;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
 * Custom {@link Metadata} implementation for storing the {@link MinimalServiceSettings} of all models in the {@link ModelRegistry}.
 * Deleted models are retained as tombstones until the {@link ModelRegistry} upgrades from the existing inference index.
 * After the upgrade, all active models are registered.
 *
 * <p>Instances are immutable; the {@code with*} methods return updated copies.
 */
public class ModelRegistryMetadata implements Metadata.ProjectCustom {
    public static final String TYPE = "model_registry";
    public static final ModelRegistryMetadata EMPTY = new ModelRegistryMetadata(ImmutableOpenMap.of(), Set.of());

    private static final ParseField UPGRADED_FIELD = new ParseField("upgraded");
    private static final ParseField MODELS_FIELD = new ParseField("models");
    private static final ParseField TOMBSTONES_FIELD = new ParseField("tombstones");

    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<ModelRegistryMetadata, Void> PARSER = new ConstructingObjectParser<>(
        TYPE,
        false,
        args -> {
            var isUpgraded = (boolean) args[0];
            var settingsMap = (ImmutableOpenMap<String, MinimalServiceSettings>) args[1];
            var deletedIDs = (List<String>) args[2];
            if (isUpgraded) {
                return new ModelRegistryMetadata(settingsMap);
            }
            // The tombstones field is declared optional, so guard against a missing value to
            // avoid a NullPointerException from new HashSet<>(null).
            return new ModelRegistryMetadata(settingsMap, deletedIDs == null ? new HashSet<>() : new HashSet<>(deletedIDs));
        }
    );

    static {
        PARSER.declareBoolean(constructorArg(), UPGRADED_FIELD);
        PARSER.declareObject(constructorArg(), (p, c) -> {
            // The models object maps inference entity id -> minimal service settings.
            ImmutableOpenMap.Builder<String, MinimalServiceSettings> modelMap = ImmutableOpenMap.builder();
            while (p.nextToken() != XContentParser.Token.END_OBJECT) {
                String name = p.currentName();
                modelMap.put(name, MinimalServiceSettings.parse(p));
            }
            return modelMap.build();
        }, MODELS_FIELD);
        PARSER.declareStringArray(optionalConstructorArg(), TOMBSTONES_FIELD);
    }

    /** Returns the registry metadata stored in the project, or {@link #EMPTY} if none exists. */
    public static ModelRegistryMetadata fromState(ProjectMetadata projectMetadata) {
        ModelRegistryMetadata resp = projectMetadata.custom(TYPE);
        return resp != null ? resp : EMPTY;
    }

    /**
     * Returns a copy with the given model registered. While not yet upgraded, any tombstone for
     * the same id is cleared. Returns {@code this} if the identical settings are already present.
     */
    public ModelRegistryMetadata withAddedModel(String inferenceEntityId, MinimalServiceSettings settings) {
        final var existing = modelMap.get(inferenceEntityId);
        if (existing != null && settings.equals(existing)) {
            return this;
        }
        var settingsBuilder = ImmutableOpenMap.builder(modelMap);
        settingsBuilder.fPut(inferenceEntityId, settings);
        if (isUpgraded) {
            return new ModelRegistryMetadata(settingsBuilder.build());
        }
        var newTombstone = new HashSet<>(tombstones);
        newTombstone.remove(inferenceEntityId);
        return new ModelRegistryMetadata(settingsBuilder.build(), newTombstone);
    }

    /**
     * Returns a copy with the given models removed. While not yet upgraded, the ids are recorded
     * as tombstones so a later upgrade does not resurrect them from the inference index.
     */
    public ModelRegistryMetadata withRemovedModel(Set<String> inferenceEntityIds) {
        var mapBuilder = ImmutableOpenMap.builder(modelMap);
        for (var toDelete : inferenceEntityIds) {
            mapBuilder.remove(toDelete);
        }
        if (isUpgraded) {
            return new ModelRegistryMetadata(mapBuilder.build());
        }
        var newTombstone = new HashSet<>(tombstones);
        newTombstone.addAll(inferenceEntityIds);
        return new ModelRegistryMetadata(mapBuilder.build(), newTombstone);
    }

    /**
     * Returns an upgraded copy merged with the models read from the inference index.
     * Models already present in the cluster state, and models that were deleted (tombstoned)
     * since the upgrade started, are not overwritten.
     *
     * @throws IllegalArgumentException if this metadata is already upgraded
     */
    public ModelRegistryMetadata withUpgradedModels(Map<String, MinimalServiceSettings> indexModels) {
        if (isUpgraded) {
            throw new IllegalArgumentException("Already upgraded");
        }
        ImmutableOpenMap.Builder<String, MinimalServiceSettings> builder = ImmutableOpenMap.builder(modelMap);
        for (var entry : indexModels.entrySet()) {
            if (builder.containsKey(entry.getKey()) == false && tombstones.contains(entry.getKey()) == false) {
                builder.fPut(entry.getKey(), entry.getValue());
            }
        }
        return new ModelRegistryMetadata(builder.build());
    }

    private final boolean isUpgraded;
    private final ImmutableOpenMap<String, MinimalServiceSettings> modelMap;
    // Ids of deleted models, tracked only while not yet upgraded; null once upgraded.
    private final Set<String> tombstones;

    /** Creates a fully-upgraded metadata (no tombstones are tracked). */
    public ModelRegistryMetadata(ImmutableOpenMap<String, MinimalServiceSettings> modelMap) {
        this.isUpgraded = true;
        this.modelMap = modelMap;
        this.tombstones = null;
    }

    /** Creates a not-yet-upgraded metadata with the given tombstones. */
    public ModelRegistryMetadata(ImmutableOpenMap<String, MinimalServiceSettings> modelMap, Set<String> tombstone) {
        this.isUpgraded = false;
        this.modelMap = modelMap;
        this.tombstones = Collections.unmodifiableSet(tombstone);
    }

    public ModelRegistryMetadata(StreamInput in) throws IOException {
        this.isUpgraded = in.readBoolean();
        this.modelMap = in.readImmutableOpenMap(StreamInput::readString, MinimalServiceSettings::new);
        this.tombstones = isUpgraded ? null : in.readCollectionAsSet(StreamInput::readString);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeBoolean(isUpgraded);
        out.writeMap(modelMap, StreamOutput::writeWriteable);
        if (isUpgraded == false) {
            out.writeStringCollection(tombstones);
        }
    }

    public static ModelRegistryMetadata fromXContent(XContentParser parser) throws IOException {
        return PARSER.parse(parser, null);
    }

    @Override
    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params ignored) {
        return Iterators.concat(
            Iterators.single((b, p) -> b.field(UPGRADED_FIELD.getPreferredName(), isUpgraded)),
            ChunkedToXContentHelper.object(
                MODELS_FIELD.getPreferredName(),
                modelMap,
                e -> (b, p) -> e.getValue().toXContent(b.field(e.getKey()), p)
            ),
            // Tombstones only exist (and are only serialized) before the upgrade completes.
            isUpgraded
                ? Collections.emptyIterator()
                : ChunkedToXContentHelper.array(
                    TOMBSTONES_FIELD.getPreferredName(),
                    Iterators.map(tombstones.iterator(), e -> (b, p) -> b.value(e))
                )
        );
    }

    /**
     * Determines whether all models created prior to {@link TransportVersions#INFERENCE_MODEL_REGISTRY_METADATA}
     * have been successfully restored from the {@link InferenceIndex}.
     *
     * <p>If this method returns false, it indicates that there may still be models in the {@link InferenceIndex}
     * that have not yet been referenced in the {@link #getModelMap()}.
     *
     * @return true if all such models have been restored; false otherwise.
     */
    public boolean isUpgraded() {
        return isUpgraded;
    }

    /**
     * Returns all the registered models.
     */
    public ImmutableOpenMap<String, MinimalServiceSettings> getModelMap() {
        return modelMap;
    }

    /** Returns the settings for the given inference entity id, or {@code null} if unknown. */
    public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) {
        return modelMap.get(inferenceEntityId);
    }

    @Override
    public Diff<Metadata.ProjectCustom> diff(Metadata.ProjectCustom before) {
        return new ModelRegistryMetadataDiff((ModelRegistryMetadata) before, this);
    }

    public static NamedDiff<Metadata.ProjectCustom> readDiffFrom(StreamInput in) throws IOException {
        return new ModelRegistryMetadataDiff(in);
    }

    @Override
    public EnumSet<Metadata.XContentContext> context() {
        return Metadata.ALL_CONTEXTS;
    }

    @Override
    public boolean isRestorable() {
        // this metadata is created automatically from the inference index if it doesn't exist.
        return false;
    }

    @Override
    public String getWriteableName() {
        return TYPE;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.INFERENCE_MODEL_REGISTRY_METADATA;
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.modelMap, this.tombstones, this.isUpgraded);
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == null) {
            return false;
        }
        if (obj.getClass() != getClass()) {
            return false;
        }
        ModelRegistryMetadata other = (ModelRegistryMetadata) obj;
        // Tombstones must participate in equality: hashCode() hashes them, and the
        // equals/hashCode contract requires equal objects to have equal hash codes.
        return Objects.equals(this.modelMap, other.modelMap)
            && Objects.equals(this.tombstones, other.tombstones)
            && isUpgraded == other.isUpgraded;
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }

    /** Returns the tombstoned ids, or {@code null} once the metadata is upgraded. */
    public Collection<String> getTombstones() {
        return tombstones;
    }

    /**
     * Wire diff for {@link ModelRegistryMetadata}: a map diff of the model settings plus the
     * full tombstone set (only while not upgraded).
     */
    static class ModelRegistryMetadataDiff implements NamedDiff<Metadata.ProjectCustom> {
        private static final DiffableUtils.DiffableValueReader<String, MinimalServiceSettings> SETTINGS_DIFF_READER =
            new DiffableUtils.DiffableValueReader<>(MinimalServiceSettings::new, MinimalServiceSettings::readDiffFrom);

        final boolean isUpgraded;
        final DiffableUtils.MapDiff<String, MinimalServiceSettings, ImmutableOpenMap<String, MinimalServiceSettings>> settingsDiff;
        final Set<String> tombstone;

        ModelRegistryMetadataDiff(ModelRegistryMetadata before, ModelRegistryMetadata after) {
            this.isUpgraded = after.isUpgraded;
            this.settingsDiff = DiffableUtils.diff(before.modelMap, after.modelMap, DiffableUtils.getStringKeySerializer());
            this.tombstone = after.isUpgraded ? null : after.tombstones;
        }

        ModelRegistryMetadataDiff(StreamInput in) throws IOException {
            this.isUpgraded = in.readBoolean();
            this.settingsDiff = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), SETTINGS_DIFF_READER);
            this.tombstone = isUpgraded ? null : in.readCollectionAsSet(StreamInput::readString);
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeBoolean(isUpgraded);
            settingsDiff.writeTo(out);
            if (isUpgraded == false) {
                out.writeStringCollection(tombstone);
            }
        }

        @Override
        public String getWriteableName() {
            return TYPE;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.INFERENCE_MODEL_REGISTRY_METADATA;
        }

        @Override
        public Metadata.ProjectCustom apply(Metadata.ProjectCustom part) {
            var metadata = (ModelRegistryMetadata) part;
            if (isUpgraded) {
                return new ModelRegistryMetadata(settingsDiff.apply(metadata.modelMap));
            } else {
                return new ModelRegistryMetadata(settingsDiff.apply(metadata.modelMap), tombstone);
            }
        }
    }
}

View File

@ -33,6 +33,8 @@ public final class Paths {
+ "}/"
+ STREAM_SUFFIX;
public static final String RETURN_MINIMAL_CONFIG = "return_minimal_config";
private Paths() {
}

View File

@ -140,7 +140,7 @@ public class ElasticInferenceService extends SenderService {
EmptySecretSettings.INSTANCE,
elasticInferenceServiceComponents
),
MinimalServiceSettings.chatCompletion()
MinimalServiceSettings.chatCompletion(NAME)
),
DEFAULT_ELSER_MODEL_ID_V2,
new DefaultModelConfig(
@ -153,7 +153,7 @@ public class ElasticInferenceService extends SenderService {
EmptySecretSettings.INSTANCE,
elasticInferenceServiceComponents
),
MinimalServiceSettings.sparseEmbedding()
MinimalServiceSettings.sparseEmbedding(NAME)
)
);
}

View File

@ -88,7 +88,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
public static final String NAME = "elasticsearch";
public static final String OLD_ELSER_SERVICE_NAME = "elser";
static final String MULTILINGUAL_E5_SMALL_MODEL_ID = ".multilingual-e5-small";
public static final String MULTILINGUAL_E5_SMALL_MODEL_ID = ".multilingual-e5-small";
static final String MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 = ".multilingual-e5-small_linux-x86_64";
public static final Set<String> MULTILINGUAL_E5_SMALL_VALID_IDS = Set.of(
MULTILINGUAL_E5_SMALL_MODEL_ID,
@ -858,7 +858,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
return List.of(
new DefaultConfigId(DEFAULT_ELSER_ID, ElserInternalServiceSettings.minimalServiceSettings(), this),
new DefaultConfigId(DEFAULT_E5_ID, MultilingualE5SmallInternalServiceSettings.minimalServiceSettings(), this),
new DefaultConfigId(DEFAULT_RERANK_ID, MinimalServiceSettings.rerank(), this)
new DefaultConfigId(DEFAULT_RERANK_ID, MinimalServiceSettings.rerank(NAME), this)
);
}

View File

@ -23,7 +23,7 @@ public class ElserInternalServiceSettings extends ElasticsearchInternalServiceSe
public static final String NAME = "elser_mlnode_service_settings";
public static MinimalServiceSettings minimalServiceSettings() {
return MinimalServiceSettings.sparseEmbedding();
return MinimalServiceSettings.sparseEmbedding(ElasticsearchInternalService.NAME);
}
public static ElserInternalServiceSettings defaultEndpointSettings(boolean useLinuxOptimizedModel) {

View File

@ -29,7 +29,12 @@ public class MultilingualE5SmallInternalServiceSettings extends ElasticsearchInt
static final SimilarityMeasure SIMILARITY = SimilarityMeasure.COSINE;
public static MinimalServiceSettings minimalServiceSettings() {
return MinimalServiceSettings.textEmbedding(DIMENSIONS, SIMILARITY, DenseVectorFieldMapper.ElementType.FLOAT);
return MinimalServiceSettings.textEmbedding(
ElasticsearchInternalService.NAME,
DIMENSIONS,
SIMILARITY,
DenseVectorFieldMapper.ElementType.FLOAT
);
}
public static MultilingualE5SmallInternalServiceSettings defaultEndpointSettings(boolean useLinuxOptimizedModel) {

View File

@ -96,7 +96,6 @@ public class OpenAiEmbeddingsServiceSettings extends FilteredXContentObject impl
ValidationException validationException,
ConfigurationParseContext context
) {
String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
String organizationId = extractOptionalString(map, ORGANIZATION, ModelConfigurations.SERVICE_SETTINGS, validationException);
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);

View File

@ -67,5 +67,4 @@ public class LocalStateInferencePlugin extends LocalStateCompositeXPackPlugin {
public Collection<MappedActionFilter> getMappedActionFilters() {
return inferencePlugin.getMappedActionFilters();
}
}

View File

@ -8,7 +8,8 @@
package org.elasticsearch.xpack.inference;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
@ -42,9 +43,6 @@ import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -53,7 +51,7 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_P
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -100,16 +98,16 @@ public final class Utils {
);
}
public static void storeSparseModel(Client client) throws Exception {
public static void storeSparseModel(ModelRegistry modelRegistry) throws Exception {
Model model = new TestSparseInferenceServiceExtension.TestSparseModel(
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
new TestSparseInferenceServiceExtension.TestServiceSettings("sparse_model", null, false)
);
storeModel(client, model);
storeModel(modelRegistry, model);
}
public static void storeDenseModel(
Client client,
ModelRegistry modelRegistry,
int dimensions,
SimilarityMeasure similarityMeasure,
DenseVectorFieldMapper.ElementType elementType
@ -118,38 +116,13 @@ public final class Utils {
TestDenseInferenceServiceExtension.TestInferenceService.NAME,
new TestDenseInferenceServiceExtension.TestServiceSettings("dense_model", dimensions, similarityMeasure, elementType)
);
storeModel(client, model);
storeModel(modelRegistry, model);
}
public static void storeModel(Client client, Model model) throws Exception {
ModelRegistry modelRegistry = new ModelRegistry(client);
AtomicReference<Boolean> storeModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder);
assertThat(storeModelHolder.get(), is(true));
assertThat(exceptionHolder.get(), is(nullValue()));
}
private static <T> void blockingCall(
Consumer<ActionListener<T>> function,
AtomicReference<T> response,
AtomicReference<Exception> error
) throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
ActionListener<T> listener = ActionListener.wrap(r -> {
response.set(r);
latch.countDown();
}, e -> {
error.set(e);
latch.countDown();
});
function.accept(listener);
latch.await();
/**
 * Stores the given model through the registry and blocks until the write completes,
 * asserting that the store was acknowledged.
 */
public static void storeModel(ModelRegistry modelRegistry, Model model) throws Exception {
    var future = new PlainActionFuture<Boolean>();
    modelRegistry.storeModel(model, future, AcknowledgedRequest.DEFAULT_ACK_TIMEOUT);
    assertTrue(future.actionGet(TimeValue.THIRTY_SECONDS));
}
public static Model getInvalidModel(String inferenceEntityId, String serviceName, TaskType taskType) {

View File

@ -83,9 +83,8 @@ public class SemanticInferenceMetadataFieldsRecoveryTests extends EngineTestCase
@Override
protected String defaultMapping() {
XContentBuilder builder = null;
try {
builder = JsonXContent.contentBuilder().startObject();
XContentBuilder builder = JsonXContent.contentBuilder().startObject();
if (useIncludesExcludes) {
builder.startObject(SourceFieldMapper.NAME).array("excludes", "field").endObject();
}
@ -104,6 +103,7 @@ public class SemanticInferenceMetadataFieldsRecoveryTests extends EngineTestCase
builder.field("dimensions", model1.getServiceSettings().dimensions());
builder.field("similarity", model1.getServiceSettings().similarity().name());
builder.field("element_type", model1.getServiceSettings().elementType().name());
builder.field("service", model1.getConfigurations().getService());
builder.endObject();
builder.endObject();
@ -112,6 +112,7 @@ public class SemanticInferenceMetadataFieldsRecoveryTests extends EngineTestCase
builder.field("inference_id", model2.getInferenceEntityId());
builder.startObject("model_settings");
builder.field("task_type", model2.getTaskType().name());
builder.field("service", model2.getConfigurations().getService());
builder.endObject();
builder.endObject();

View File

@ -415,7 +415,7 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
MapperService mapperService = mapperServiceForFieldWithModelSettings(
fieldName,
inferenceId,
new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, null, null)
new MinimalServiceSettings("service", TaskType.SPARSE_EMBEDDING, null, null, null)
);
assertSemanticTextField(mapperService, fieldName, true);
assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId);
@ -426,7 +426,7 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
fieldName,
inferenceId,
searchInferenceId,
new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, null, null)
new MinimalServiceSettings("service", TaskType.SPARSE_EMBEDDING, null, null, null)
);
assertSemanticTextField(mapperService, fieldName, true);
assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId);
@ -504,8 +504,8 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
exc.getMessage(),
containsString(
"Cannot update parameter [model_settings] "
+ "from [task_type=sparse_embedding] "
+ "to [task_type=text_embedding, dimensions=10, similarity=cosine, element_type=float]"
+ "from [service=null, task_type=sparse_embedding] "
+ "to [service=null, task_type=text_embedding, dimensions=10, similarity=cosine, element_type=float]"
)
);
}
@ -546,7 +546,7 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
mapperService = mapperServiceForFieldWithModelSettings(
fieldName,
inferenceId,
new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, null, null)
new MinimalServiceSettings("my-service", TaskType.SPARSE_EMBEDDING, null, null, null)
);
assertSemanticTextField(mapperService, fieldName, true);
assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId);
@ -765,7 +765,10 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
useLegacyFormat,
b -> b.startObject("field")
.startObject(INFERENCE_FIELD)
.field(MODEL_SETTINGS_FIELD, new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, null, null))
.field(
MODEL_SETTINGS_FIELD,
new MinimalServiceSettings("my-service", TaskType.SPARSE_EMBEDDING, null, null, null)
)
.field(CHUNKS_FIELD, useLegacyFormat ? List.of() : Map.of())
.endObject()
.endObject()
@ -827,14 +830,26 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
MapperService floatMapperService = mapperServiceForFieldWithModelSettings(
fieldName,
inferenceId,
new MinimalServiceSettings(TaskType.TEXT_EMBEDDING, 1024, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT)
new MinimalServiceSettings(
"my-service",
TaskType.TEXT_EMBEDDING,
1024,
SimilarityMeasure.COSINE,
DenseVectorFieldMapper.ElementType.FLOAT
)
);
assertMapperService.accept(floatMapperService, DenseVectorFieldMapper.ElementType.FLOAT);
MapperService byteMapperService = mapperServiceForFieldWithModelSettings(
fieldName,
inferenceId,
new MinimalServiceSettings(TaskType.TEXT_EMBEDDING, 1024, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.BYTE)
new MinimalServiceSettings(
"my-service",
TaskType.TEXT_EMBEDDING,
1024,
SimilarityMeasure.COSINE,
DenseVectorFieldMapper.ElementType.BYTE
)
);
assertMapperService.accept(byteMapperService, DenseVectorFieldMapper.ElementType.BYTE);
}
@ -924,7 +939,7 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
MapperService mapperService = mapperServiceForFieldWithModelSettings(
fieldName,
inferenceId,
new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, null, null)
new MinimalServiceSettings("my-service", TaskType.SPARSE_EMBEDDING, null, null, null)
);
Mapper mapper = mapperService.mappingLookup().getMapper(fieldName);
@ -941,7 +956,13 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
MapperService mapperService = mapperServiceForFieldWithModelSettings(
fieldName,
inferenceId,
new MinimalServiceSettings(TaskType.TEXT_EMBEDDING, 1024, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT)
new MinimalServiceSettings(
"my-service",
TaskType.TEXT_EMBEDDING,
1024,
SimilarityMeasure.COSINE,
DenseVectorFieldMapper.ElementType.FLOAT
)
);
Mapper mapper = mapperService.mappingLookup().getMapper(fieldName);

View File

@ -140,37 +140,43 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
public void testModelSettingsValidation() {
NullPointerException npe = expectThrows(NullPointerException.class, () -> {
new MinimalServiceSettings(null, 10, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT);
new MinimalServiceSettings("service", null, 10, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT);
});
assertThat(npe.getMessage(), equalTo("task type must not be null"));
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> {
new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, 10, null, null);
new MinimalServiceSettings("service", TaskType.SPARSE_EMBEDDING, 10, null, null);
});
assertThat(ex.getMessage(), containsString("[dimensions] is not allowed"));
ex = expectThrows(IllegalArgumentException.class, () -> {
new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, SimilarityMeasure.COSINE, null);
new MinimalServiceSettings("service", TaskType.SPARSE_EMBEDDING, null, SimilarityMeasure.COSINE, null);
});
assertThat(ex.getMessage(), containsString("[similarity] is not allowed"));
ex = expectThrows(IllegalArgumentException.class, () -> {
new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, null, DenseVectorFieldMapper.ElementType.FLOAT);
new MinimalServiceSettings("service", TaskType.SPARSE_EMBEDDING, null, null, DenseVectorFieldMapper.ElementType.FLOAT);
});
assertThat(ex.getMessage(), containsString("[element_type] is not allowed"));
ex = expectThrows(IllegalArgumentException.class, () -> {
new MinimalServiceSettings(TaskType.TEXT_EMBEDDING, null, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT);
new MinimalServiceSettings(
"service",
TaskType.TEXT_EMBEDDING,
null,
SimilarityMeasure.COSINE,
DenseVectorFieldMapper.ElementType.FLOAT
);
});
assertThat(ex.getMessage(), containsString("required [dimensions] field is missing"));
ex = expectThrows(IllegalArgumentException.class, () -> {
new MinimalServiceSettings(TaskType.TEXT_EMBEDDING, 10, null, DenseVectorFieldMapper.ElementType.FLOAT);
new MinimalServiceSettings("service", TaskType.TEXT_EMBEDDING, 10, null, DenseVectorFieldMapper.ElementType.FLOAT);
});
assertThat(ex.getMessage(), containsString("required [similarity] field is missing"));
ex = expectThrows(IllegalArgumentException.class, () -> {
new MinimalServiceSettings(TaskType.TEXT_EMBEDDING, 10, SimilarityMeasure.COSINE, null);
new MinimalServiceSettings("service", TaskType.TEXT_EMBEDDING, 10, SimilarityMeasure.COSINE, null);
});
assertThat(ex.getMessage(), containsString("required [element_type] field is missing"));
}

View File

@ -14,6 +14,7 @@ import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.junit.Before;
import java.util.Collection;
@ -24,7 +25,8 @@ public class SemanticTextNonDynamicFieldMapperTests extends NonDynamicFieldMappe
@Before
public void setup() throws Exception {
Utils.storeSparseModel(client());
ModelRegistry modelRegistry = node().injector().getInstance(ModelRegistry.class);
Utils.storeSparseModel(modelRegistry);
}
@Override

View File

@ -352,8 +352,9 @@ public class SemanticQueryBuilderTests extends AbstractQueryTestCase<SemanticQue
) throws IOException {
var modelSettings = switch (inferenceResultType) {
case NONE -> null;
case SPARSE_EMBEDDING -> new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, null, null);
case SPARSE_EMBEDDING -> new MinimalServiceSettings("my-service", TaskType.SPARSE_EMBEDDING, null, null, null);
case TEXT_EMBEDDING -> new MinimalServiceSettings(
"my-service",
TaskType.TEXT_EMBEDDING,
TEXT_EMBEDDING_DIMENSION_COUNT,
// l2_norm similarity is required for bit embeddings

View File

@ -0,0 +1,62 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.registry;
import org.elasticsearch.cluster.ClusterModule;
import org.elasticsearch.cluster.Diff;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.MinimalServiceSettingsTests;
import org.elasticsearch.test.SimpleDiffableWireSerializationTestCase;
import java.util.Map;
import java.util.Set;
/**
 * Wire-serialization and diff round-trip tests for {@link ModelRegistryMetadata} as a
 * {@link Metadata.ProjectCustom} cluster-state custom.
 */
public class ModelRegistryMetadataDiffTests extends SimpleDiffableWireSerializationTestCase<Metadata.ProjectCustom> {
    @Override
    protected Metadata.ProjectCustom createTestInstance() {
        return ModelRegistryMetadataTests.randomInstance();
    }

    @Override
    protected Writeable.Reader<Metadata.ProjectCustom> instanceReader() {
        return ModelRegistryMetadata::new;
    }

    @Override
    protected Metadata.ProjectCustom makeTestChanges(Metadata.ProjectCustom testInstance) {
        return mutateInstance((ModelRegistryMetadata) testInstance);
    }

    @Override
    protected Writeable.Reader<Diff<Metadata.ProjectCustom>> diffReader() {
        return ModelRegistryMetadata::readDiffFrom;
    }

    @Override
    protected NamedWriteableRegistry getNamedWriteableRegistry() {
        return new NamedWriteableRegistry(ClusterModule.getNamedWriteables());
    }

    @Override
    protected Metadata.ProjectCustom mutateInstance(Metadata.ProjectCustom instance) {
        return mutateInstance((ModelRegistryMetadata) instance);
    }

    /**
     * Produces a randomly mutated copy: sometimes performs the upgrade transition (only
     * valid when not yet upgraded), otherwise adds a random model or removes an existing one.
     */
    private static ModelRegistryMetadata mutateInstance(ModelRegistryMetadata original) {
        if (original.isUpgraded() == false && randomBoolean()) {
            var upgradedModels = Map.of(randomAlphaOfLength(10), MinimalServiceSettingsTests.randomInstance());
            return original.withUpgradedModels(upgradedModels);
        }
        // Removal is only possible when at least one model is present.
        boolean addModel = randomBoolean() || original.getModelMap().isEmpty();
        if (addModel) {
            return original.withAddedModel(randomAlphaOfLength(10), MinimalServiceSettingsTests.randomInstance());
        }
        return original.withRemovedModel(Set.of(randomFrom(original.getModelMap().keySet())));
    }
}

View File

@ -0,0 +1,107 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.registry;
import org.elasticsearch.common.collect.ImmutableOpenMap;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.MinimalServiceSettingsTests;
import org.elasticsearch.test.AbstractChunkedSerializingTestCase;
import org.elasticsearch.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import static org.hamcrest.Matchers.equalTo;
/**
 * XContent/wire serialization tests for {@link ModelRegistryMetadata}, plus unit tests for
 * the index-to-cluster-state upgrade transition.
 */
public class ModelRegistryMetadataTests extends AbstractChunkedSerializingTestCase<ModelRegistryMetadata> {
    /**
     * Creates a random {@link ModelRegistryMetadata}, randomly choosing whether it is
     * marked as upgraded.
     */
    public static ModelRegistryMetadata randomInstance() {
        return randomInstance(randomBoolean());
    }

    /**
     * Creates a random {@link ModelRegistryMetadata}.
     * <p>
     * NOTE: this may {@code rarely()} return {@link ModelRegistryMetadata#EMPTY} regardless
     * of {@code isUpgraded}; callers that require a non-empty or definitely-upgraded
     * instance must guard for that (see {@link #testUpgrade()} and {@link #testAlreadyUpgraded()}).
     *
     * @param isUpgraded whether the returned metadata should be marked as upgraded
     *                   (upgraded instances carry no deletion tombstones)
     */
    public static ModelRegistryMetadata randomInstance(boolean isUpgraded) {
        if (rarely()) {
            return ModelRegistryMetadata.EMPTY;
        }
        int size = randomIntBetween(1, 5);
        Map<String, MinimalServiceSettings> models = new HashMap<>();
        for (int i = 0; i < size; i++) {
            models.put(randomAlphaOfLength(10), MinimalServiceSettingsTests.randomInstance());
        }
        if (isUpgraded) {
            return new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build());
        }
        // Non-upgraded metadata also tracks IDs deleted before the upgrade completed (tombstones).
        Set<String> deletedIDs = new HashSet<>();
        size = randomIntBetween(0, 3);
        for (int i = 0; i < size; i++) {
            deletedIDs.add(randomAlphaOfLength(10));
        }
        return new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build(), deletedIDs);
    }

    @Override
    protected ModelRegistryMetadata createTestInstance() {
        return randomInstance();
    }

    @Override
    protected ModelRegistryMetadata mutateInstance(ModelRegistryMetadata instance) {
        return randomValueOtherThan(instance, this::createTestInstance);
    }

    @Override
    protected NamedWriteableRegistry getNamedWriteableRegistry() {
        return new NamedWriteableRegistry(
            Collections.singletonList(
                new NamedWriteableRegistry.Entry(ModelRegistryMetadata.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::new)
            )
        );
    }

    @Override
    protected ModelRegistryMetadata doParseInstance(XContentParser parser) throws IOException {
        return ModelRegistryMetadata.fromXContent(parser);
    }

    @Override
    protected Writeable.Reader<ModelRegistryMetadata> instanceReader() {
        return ModelRegistryMetadata::new;
    }

    /**
     * Upgrading merges the index models over the cluster-state models, then drops any
     * IDs recorded as tombstones.
     */
    public void testUpgrade() {
        var metadata = randomInstance(false);
        // randomInstance may rarely return EMPTY; withRemovedModel/randomFrom below need
        // at least one model, so re-roll until the map is non-empty.
        while (metadata.getModelMap().isEmpty()) {
            metadata = randomInstance(false);
        }
        var metadataWithTombstones = metadata.withRemovedModel(Set.of(randomFrom(metadata.getModelMap().keySet())));
        var indexMetadata = metadata.withAddedModel(randomAlphanumericOfLength(10), MinimalServiceSettingsTests.randomInstance());
        var upgraded = metadataWithTombstones.withUpgradedModels(indexMetadata.getModelMap());

        Map<String, MinimalServiceSettings> expectedModelMap = new HashMap<>(metadataWithTombstones.getModelMap());
        expectedModelMap.putAll(indexMetadata.getModelMap());
        for (var id : metadataWithTombstones.getTombstones()) {
            expectedModelMap.remove(id);
        }
        assertTrue(upgraded.isUpgraded());
        assertThat(upgraded.getModelMap(), equalTo(expectedModelMap));
    }

    /**
     * Upgrading an already-upgraded metadata is an illegal state transition.
     */
    public void testAlreadyUpgraded() {
        var metadata = randomInstance(true);
        // randomInstance may rarely return EMPTY, which is not guaranteed to be marked
        // upgraded; re-roll until we definitely have an upgraded instance.
        while (metadata.isUpgraded() == false) {
            metadata = randomInstance(true);
        }
        var indexMetadata = randomInstance(true);
        final var upgradedMetadata = metadata;
        expectThrows(IllegalArgumentException.class, () -> upgradedMetadata.withUpgradedModels(indexMetadata.getModelMap()));
    }
}

View File

@ -7,40 +7,29 @@
package org.elasticsearch.xpack.inference.registry;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.MinimalServiceSettingsTests;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchResponseUtils;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.model.TestModel;
import org.junit.After;
import org.junit.Before;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
@ -49,37 +38,29 @@ import static org.elasticsearch.core.Strings.format;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class ModelRegistryTests extends ESTestCase {
public class ModelRegistryTests extends ESSingleNodeTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private ThreadPool threadPool;
private ModelRegistry registry;
@Before
public void setUpThreadPool() {
threadPool = new TestThreadPool(getTestName());
@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return List.of(LocalStateInferencePlugin.class);
}
@After
public void tearDownThreadPool() {
terminate(threadPool);
@Before
public void createComponents() {
registry = node().injector().getInstance(ModelRegistry.class);
}
public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() {
var client = mockClient();
mockClientExecuteSearch(client, mockSearchResponse(SearchHits.EMPTY));
var registry = new ModelRegistry(client);
var listener = new PlainActionFuture<UnparsedModel>();
registry.getModelWithSecrets("1", listener);
@ -87,80 +68,18 @@ public class ModelRegistryTests extends ESTestCase {
assertThat(exception.getMessage(), is("Inference endpoint not found [1]"));
}
public void testGetUnparsedModelMap_ThrowsIllegalArgumentException_WhenInvalidIndexReceived() {
var client = mockClient();
var unknownIndexHit = SearchResponseUtils.searchHitFromMap(Map.of("_index", "unknown_index"));
mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { unknownIndexHit }));
var registry = new ModelRegistry(client);
var listener = new PlainActionFuture<UnparsedModel>();
registry.getModelWithSecrets("1", listener);
IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
exception.getMessage(),
is("Invalid result while loading inference endpoint [1] index: [unknown_index]. Try deleting and reinitializing the service")
);
}
public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFindInferenceEntry() {
var client = mockClient();
var inferenceSecretsHit = SearchResponseUtils.searchHitFromMap(Map.of("_index", ".secrets-inference"));
mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceSecretsHit }));
var registry = new ModelRegistry(client);
var listener = new PlainActionFuture<UnparsedModel>();
registry.getModelWithSecrets("1", listener);
IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
exception.getMessage(),
is("Failed to load inference endpoint [1]. Endpoint is in an invalid state, try deleting and reinitializing the service")
);
}
public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFindInferenceSecretsEntry() {
var client = mockClient();
var inferenceHit = SearchResponseUtils.searchHitFromMap(Map.of("_index", ".inference"));
mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit }));
var registry = new ModelRegistry(client);
var listener = new PlainActionFuture<UnparsedModel>();
registry.getModelWithSecrets("1", listener);
IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
exception.getMessage(),
is("Failed to load inference endpoint [1]. Endpoint is in an invalid state, try deleting and reinitializing the service")
);
}
public void testGetModelWithSecrets() {
var client = mockClient();
String config = """
{
"model_id": "1",
"task_type": "sparse_embedding",
"service": "foo"
}
""";
String secrets = """
{
"api_key": "secret"
}
""";
var inferenceHit = SearchResponseUtils.searchHitFromMap(Map.of("_index", ".inference"));
inferenceHit.sourceRef(BytesReference.fromByteBuffer(ByteBuffer.wrap(Strings.toUTF8Bytes(config))));
var inferenceSecretsHit = SearchResponseUtils.searchHitFromMap(Map.of("_index", ".secrets-inference"));
inferenceSecretsHit.sourceRef(BytesReference.fromByteBuffer(ByteBuffer.wrap(Strings.toUTF8Bytes(secrets))));
mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit, inferenceSecretsHit }));
var registry = new ModelRegistry(client);
assertStoreModel(
registry,
new TestModel(
"1",
TaskType.SPARSE_EMBEDDING,
"foo",
new TestModel.TestServiceSettings(null, null, null, null),
new TestModel.TestTaskSettings(randomInt(3)),
new TestModel.TestSecretSettings("secret")
)
);
var listener = new PlainActionFuture<UnparsedModel>();
registry.getModelWithSecrets("1", listener);
@ -169,152 +88,64 @@ public class ModelRegistryTests extends ESTestCase {
assertEquals("1", modelConfig.inferenceEntityId());
assertEquals("foo", modelConfig.service());
assertEquals(TaskType.SPARSE_EMBEDDING, modelConfig.taskType());
assertThat(modelConfig.settings().keySet(), empty());
assertNotNull(modelConfig.settings().keySet());
assertThat(modelConfig.secrets().keySet(), hasSize(1));
assertEquals("secret", modelConfig.secrets().get("api_key"));
assertThat(modelConfig.secrets().get("secret_settings"), instanceOf(Map.class));
@SuppressWarnings("unchecked")
var secretSettings = (Map<String, Object>) modelConfig.secrets().get("secret_settings");
assertThat(secretSettings.get("api_key"), equalTo("secret"));
}
public void testGetModelNoSecrets() {
var client = mockClient();
String config = """
{
"model_id": "1",
"task_type": "sparse_embedding",
"service": "foo"
}
""";
assertStoreModel(
registry,
new TestModel(
"1",
TaskType.SPARSE_EMBEDDING,
"foo",
new TestModel.TestServiceSettings(null, null, null, null),
new TestModel.TestTaskSettings(randomInt(3)),
new TestModel.TestSecretSettings(randomAlphaOfLength(4))
)
);
var inferenceHit = SearchResponseUtils.searchHitFromMap(Map.of("_index", ".inference"));
inferenceHit.sourceRef(BytesReference.fromByteBuffer(ByteBuffer.wrap(Strings.toUTF8Bytes(config))));
var getListener = new PlainActionFuture<UnparsedModel>();
registry.getModel("1", getListener);
mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit }));
var registry = new ModelRegistry(client);
var listener = new PlainActionFuture<UnparsedModel>();
registry.getModel("1", listener);
var modelConfig = listener.actionGet(TIMEOUT);
var modelConfig = getListener.actionGet(TIMEOUT);
assertEquals("1", modelConfig.inferenceEntityId());
assertEquals("foo", modelConfig.service());
assertEquals(TaskType.SPARSE_EMBEDDING, modelConfig.taskType());
assertThat(modelConfig.settings().keySet(), empty());
assertNotNull(modelConfig.settings().keySet());
assertThat(modelConfig.secrets().keySet(), empty());
}
public void testStoreModel_ReturnsTrue_WhenNoFailuresOccur() {
var client = mockBulkClient();
var bulkItem = mock(BulkItemResponse.class);
when(bulkItem.isFailed()).thenReturn(false);
var bulkResponse = mock(BulkResponse.class);
when(bulkResponse.getItems()).thenReturn(new BulkItemResponse[] { bulkItem });
mockClientExecuteBulk(client, bulkResponse);
var model = TestModel.createRandomInstance();
var registry = new ModelRegistry(client);
var listener = new PlainActionFuture<Boolean>();
registry.storeModel(model, listener);
assertTrue(listener.actionGet(TIMEOUT));
}
public void testStoreModel_ThrowsException_WhenBulkResponseIsEmpty() {
var client = mockBulkClient();
var bulkResponse = mock(BulkResponse.class);
when(bulkResponse.getItems()).thenReturn(new BulkItemResponse[0]);
mockClientExecuteBulk(client, bulkResponse);
var model = TestModel.createRandomInstance();
var registry = new ModelRegistry(client);
var listener = new PlainActionFuture<Boolean>();
registry.storeModel(model, listener);
ElasticsearchStatusException exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
exception.getMessage(),
is(
format(
"Failed to store inference endpoint [%s], invalid bulk response received. Try reinitializing the service",
model.getConfigurations().getInferenceEntityId()
)
)
);
assertStoreModel(registry, model);
}
public void testStoreModel_ThrowsResourceAlreadyExistsException_WhenFailureIsAVersionConflict() {
var client = mockBulkClient();
var bulkItem = mock(BulkItemResponse.class);
when(bulkItem.isFailed()).thenReturn(true);
var failure = new BulkItemResponse.Failure("index", "id", mock(VersionConflictEngineException.class));
when(bulkItem.getFailure()).thenReturn(failure);
var bulkResponse = mock(BulkResponse.class);
when(bulkResponse.getItems()).thenReturn(new BulkItemResponse[] { bulkItem });
mockClientExecuteBulk(client, bulkResponse);
var model = TestModel.createRandomInstance();
var registry = new ModelRegistry(client);
var listener = new PlainActionFuture<Boolean>();
assertStoreModel(registry, model);
registry.storeModel(model, listener);
ResourceAlreadyExistsException exception = expectThrows(ResourceAlreadyExistsException.class, () -> listener.actionGet(TIMEOUT));
ResourceAlreadyExistsException exception = expectThrows(
ResourceAlreadyExistsException.class,
() -> assertStoreModel(registry, model)
);
assertThat(
exception.getMessage(),
is(format("Inference endpoint [%s] already exists", model.getConfigurations().getInferenceEntityId()))
);
}
public void testStoreModel_ThrowsException_WhenFailureIsNotAVersionConflict() {
var client = mockBulkClient();
var bulkItem = mock(BulkItemResponse.class);
when(bulkItem.isFailed()).thenReturn(true);
var failure = new BulkItemResponse.Failure("index", "id", mock(IllegalStateException.class));
when(bulkItem.getFailure()).thenReturn(failure);
var bulkResponse = mock(BulkResponse.class);
when(bulkResponse.getItems()).thenReturn(new BulkItemResponse[] { bulkItem });
mockClientExecuteBulk(client, bulkResponse);
var model = TestModel.createRandomInstance();
var registry = new ModelRegistry(client);
var listener = new PlainActionFuture<Boolean>();
registry.storeModel(model, listener);
ElasticsearchStatusException exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
exception.getMessage(),
is(format("Failed to store inference endpoint [%s]", model.getConfigurations().getInferenceEntityId()))
);
}
public void testRemoveDefaultConfigs_DoesNotCallClient_WhenPassedAnEmptySet() {
var client = mock(Client.class);
var registry = new ModelRegistry(client);
var listener = new PlainActionFuture<Boolean>();
registry.removeDefaultConfigs(Set.of(), listener);
assertTrue(listener.actionGet(TIMEOUT));
verify(client, times(0)).execute(any(), any(), any());
}
public void testDeleteModels_Returns_ConflictException_WhenModelIsBeingAdded() {
var client = mockClient();
var registry = new ModelRegistry(client);
var model = TestModel.createRandomInstance();
var newModel = TestModel.createRandomInstance();
registry.updateModelTransaction(newModel, model, new PlainActionFuture<>());
@ -333,10 +164,10 @@ public class ModelRegistryTests extends ESTestCase {
public void testIdMatchedDefault() {
var defaultConfigIds = new ArrayList<InferenceService.DefaultConfigId>();
defaultConfigIds.add(
new InferenceService.DefaultConfigId("foo", MinimalServiceSettings.sparseEmbedding(), mock(InferenceService.class))
new InferenceService.DefaultConfigId("foo", MinimalServiceSettings.sparseEmbedding("my_service"), mock(InferenceService.class))
);
defaultConfigIds.add(
new InferenceService.DefaultConfigId("bar", MinimalServiceSettings.sparseEmbedding(), mock(InferenceService.class))
new InferenceService.DefaultConfigId("bar", MinimalServiceSettings.sparseEmbedding("my_service"), mock(InferenceService.class))
);
var matched = ModelRegistry.idMatchedDefault("bar", defaultConfigIds);
@ -346,14 +177,11 @@ public class ModelRegistryTests extends ESTestCase {
}
public void testContainsDefaultConfigId() {
var client = mockClient();
var registry = new ModelRegistry(client);
registry.addDefaultIds(
new InferenceService.DefaultConfigId("foo", MinimalServiceSettings.sparseEmbedding(), mock(InferenceService.class))
new InferenceService.DefaultConfigId("foo", MinimalServiceSettings.sparseEmbedding("my_service"), mock(InferenceService.class))
);
registry.addDefaultIds(
new InferenceService.DefaultConfigId("bar", MinimalServiceSettings.sparseEmbedding(), mock(InferenceService.class))
new InferenceService.DefaultConfigId("bar", MinimalServiceSettings.sparseEmbedding("my_service"), mock(InferenceService.class))
);
assertTrue(registry.containsDefaultConfigId("foo"));
assertFalse(registry.containsDefaultConfigId("baz"));
@ -362,19 +190,21 @@ public class ModelRegistryTests extends ESTestCase {
public void testTaskTypeMatchedDefaults() {
var defaultConfigIds = new ArrayList<InferenceService.DefaultConfigId>();
defaultConfigIds.add(
new InferenceService.DefaultConfigId("s1", MinimalServiceSettings.sparseEmbedding(), mock(InferenceService.class))
new InferenceService.DefaultConfigId("s1", MinimalServiceSettings.sparseEmbedding("my_service"), mock(InferenceService.class))
);
defaultConfigIds.add(
new InferenceService.DefaultConfigId("s2", MinimalServiceSettings.sparseEmbedding(), mock(InferenceService.class))
new InferenceService.DefaultConfigId("s2", MinimalServiceSettings.sparseEmbedding("my_service"), mock(InferenceService.class))
);
defaultConfigIds.add(
new InferenceService.DefaultConfigId(
"d1",
MinimalServiceSettings.textEmbedding(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT),
MinimalServiceSettings.textEmbedding("my_service", 384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT),
mock(InferenceService.class)
)
);
defaultConfigIds.add(new InferenceService.DefaultConfigId("c1", MinimalServiceSettings.completion(), mock(InferenceService.class)));
defaultConfigIds.add(
new InferenceService.DefaultConfigId("c1", MinimalServiceSettings.completion("my_service"), mock(InferenceService.class))
);
var matched = ModelRegistry.taskTypeMatchedDefaults(TaskType.SPARSE_EMBEDDING, defaultConfigIds);
assertThat(matched, contains(defaultConfigIds.get(0), defaultConfigIds.get(1)));
@ -385,19 +215,18 @@ public class ModelRegistryTests extends ESTestCase {
}
public void testDuplicateDefaultIds() {
var client = mockBulkClient();
var registry = new ModelRegistry(client);
var id = "my-inference";
var mockServiceA = mock(InferenceService.class);
when(mockServiceA.name()).thenReturn("service-a");
var mockServiceB = mock(InferenceService.class);
when(mockServiceB.name()).thenReturn("service-b");
registry.addDefaultIds(new InferenceService.DefaultConfigId(id, randomMinimalServiceSettings(), mockServiceA));
registry.addDefaultIds(new InferenceService.DefaultConfigId(id, MinimalServiceSettingsTests.randomInstance(), mockServiceA));
var ise = expectThrows(
IllegalStateException.class,
() -> registry.addDefaultIds(new InferenceService.DefaultConfigId(id, randomMinimalServiceSettings(), mockServiceB))
() -> registry.addDefaultIds(
new InferenceService.DefaultConfigId(id, MinimalServiceSettingsTests.randomInstance(), mockServiceB)
)
);
assertThat(
ise.getMessage(),
@ -408,59 +237,16 @@ public class ModelRegistryTests extends ESTestCase {
);
}
private Client mockBulkClient() {
var client = mockClient();
when(client.prepareBulk()).thenReturn(new BulkRequestBuilder(client));
public static void assertStoreModel(ModelRegistry registry, Model model) {
PlainActionFuture<Boolean> storeListener = new PlainActionFuture<>();
registry.storeModel(model, storeListener, TimeValue.THIRTY_SECONDS);
assertTrue(storeListener.actionGet(TimeValue.THIRTY_SECONDS));
return client;
}
private Client mockClient() {
var client = mock(Client.class);
when(client.threadPool()).thenReturn(threadPool);
return client;
}
private static void mockClientExecuteSearch(Client client, SearchResponse searchResponse) {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
ActionListener<SearchResponse> actionListener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
ActionListener.respondAndRelease(actionListener, searchResponse);
return Void.TYPE;
}).when(client).execute(any(), any(), any());
}
private static void mockClientExecuteBulk(Client client, BulkResponse bulkResponse) {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
ActionListener<BulkResponse> actionListener = (ActionListener<BulkResponse>) invocationOnMock.getArguments()[2];
actionListener.onResponse(bulkResponse);
return Void.TYPE;
}).when(client).execute(any(), any(), any());
}
private static SearchResponse mockSearchResponse(SearchHit[] hits) {
var searchResponse = mock(SearchResponse.class);
SearchHits searchHits = new SearchHits(hits, new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO), 1);
try {
when(searchResponse.getHits()).thenReturn(searchHits.asUnpooled());
} finally {
searchHits.decRef();
}
return searchResponse;
}
public static MinimalServiceSettings randomMinimalServiceSettings() {
TaskType type = randomFrom(TaskType.values());
if (type == TaskType.TEXT_EMBEDDING) {
return MinimalServiceSettings.textEmbedding(
randomIntBetween(2, 384),
randomFrom(SimilarityMeasure.values()),
randomFrom(DenseVectorFieldMapper.ElementType.values())
);
}
return new MinimalServiceSettings(type, null, null, null);
var settings = registry.getMinimalServiceSettings(model.getInferenceEntityId());
assertNotNull(settings);
assertThat(settings.taskType(), equalTo(model.getTaskType()));
assertThat(settings.dimensions(), equalTo(model.getServiceSettings().dimensions()));
assertThat(settings.elementType(), equalTo(model.getServiceSettings().elementType()));
assertThat(settings.dimensions(), equalTo(model.getServiceSettings().dimensions()));
}
}

View File

@ -11,7 +11,6 @@ import org.apache.http.HttpHeaders;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
@ -28,7 +27,8 @@ import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
@ -43,6 +43,7 @@ import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTest
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@ -65,6 +66,7 @@ import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.Collection;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
@ -97,17 +99,25 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
public class ElasticInferenceServiceTests extends ESTestCase {
public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ModelRegistry modelRegistry;
private ThreadPool threadPool;
private HttpClientManager clientManager;
@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return List.of(LocalStateInferencePlugin.class);
}
@Before
public void init() throws Exception {
webServer.start();
modelRegistry = node().injector().getInstance(ModelRegistry.class);
threadPool = createThreadPool(inferenceUtilityPool());
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
}
@ -380,24 +390,6 @@ public class ElasticInferenceServiceTests extends ESTestCase {
verifyNoMoreInteractions(sender);
}
private ModelRegistry mockModelRegistry() {
return mockModelRegistry(threadPool);
}
public static ModelRegistry mockModelRegistry(ThreadPool threadPool) {
var client = mock(Client.class);
when(client.threadPool()).thenReturn(threadPool);
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<Boolean>) invocationOnMock.getArgument(2);
listener.onResponse(true);
return Void.TYPE;
}).when(client).execute(any(), any(), any());
return new ModelRegistry(client);
}
public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException {
var sender = mock(Sender.class);
@ -1166,7 +1158,11 @@ public class ElasticInferenceServiceTests extends ESTestCase {
service.defaultConfigIds(),
is(
List.of(
new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service)
new InferenceService.DefaultConfigId(
".rainbow-sprinkles-elastic",
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
service
)
)
)
);
@ -1205,8 +1201,16 @@ public class ElasticInferenceServiceTests extends ESTestCase {
service.defaultConfigIds(),
is(
List.of(
new InferenceService.DefaultConfigId(".elser-v2-elastic", MinimalServiceSettings.sparseEmbedding(), service),
new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service)
new InferenceService.DefaultConfigId(
".elser-v2-elastic",
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
service
),
new InferenceService.DefaultConfigId(
".rainbow-sprinkles-elastic",
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
service
)
)
)
);
@ -1334,7 +1338,7 @@ public class ElasticInferenceServiceTests extends ESTestCase {
factory,
createWithEmptySettings(threadPool),
new ElasticInferenceServiceSettings(Settings.EMPTY),
mockModelRegistry(),
modelRegistry,
mockAuthHandler
);
}
@ -1363,7 +1367,7 @@ public class ElasticInferenceServiceTests extends ESTestCase {
senderFactory,
createWithEmptySettings(threadPool),
ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL),
mockModelRegistry(),
modelRegistry,
mockAuthHandler
);
}
@ -1376,21 +1380,7 @@ public class ElasticInferenceServiceTests extends ESTestCase {
senderFactory,
createWithEmptySettings(threadPool),
ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL),
mockModelRegistry(),
new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool)
);
}
public static ElasticInferenceService createServiceWithAuthHandler(
HttpRequestSender.Factory senderFactory,
String elasticInferenceServiceURL,
ThreadPool threadPool
) {
return new ElasticInferenceService(
senderFactory,
createWithEmptySettings(threadPool),
ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL),
mockModelRegistry(threadPool),
modelRegistry,
new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool)
);
}

View File

@ -17,11 +17,15 @@ import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceAuthorizationResponseEntity;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
@ -31,6 +35,7 @@ import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInfe
import org.junit.Before;
import java.io.IOException;
import java.util.Collection;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
@ -41,18 +46,24 @@ import java.util.concurrent.atomic.AtomicReference;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.defaultEndpointId;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceTests.mockModelRegistry;
import static org.hamcrest.CoreMatchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
public class ElasticInferenceServiceAuthorizationHandlerTests extends ESTestCase {
public class ElasticInferenceServiceAuthorizationHandlerTests extends ESSingleNodeTestCase {
private DeterministicTaskQueue taskQueue;
private ModelRegistry modelRegistry;
// Registers the local inference plugin so the single test node provides a real ModelRegistry.
@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return List.of(LocalStateInferencePlugin.class);
}
// Per-test setup: a DeterministicTaskQueue lets tests drive async work step-by-step,
// and the registry is fetched from the running node rather than mocked.
@Before
public void init() throws Exception {
taskQueue = new DeterministicTaskQueue();
modelRegistry = getInstanceFromNode(ModelRegistry.class);
}
public void testSendsAnAuthorizationRequestTwice() throws Exception {
@ -104,7 +115,7 @@ public class ElasticInferenceServiceAuthorizationHandlerTests extends ESTestCase
handlerRef.set(
new ElasticInferenceServiceAuthorizationHandler(
createWithEmptySettings(taskQueue.getThreadPool()),
mockModelRegistry(taskQueue.getThreadPool()),
modelRegistry,
requestHandler,
initDefaultEndpoints(),
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION),
@ -124,7 +135,15 @@ public class ElasticInferenceServiceAuthorizationHandlerTests extends ESTestCase
assertThat(handler.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
assertThat(
handler.defaultConfigIds(),
is(List.of(new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), null)))
is(
List.of(
new InferenceService.DefaultConfigId(
".rainbow-sprinkles-elastic",
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
null
)
)
)
);
assertThat(handler.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)));
@ -166,7 +185,7 @@ public class ElasticInferenceServiceAuthorizationHandlerTests extends ESTestCase
EmptySecretSettings.INSTANCE,
ElasticInferenceServiceComponents.EMPTY_INSTANCE
),
MinimalServiceSettings.chatCompletion()
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME)
),
"elser-v2",
new DefaultModelConfig(
@ -179,7 +198,7 @@ public class ElasticInferenceServiceAuthorizationHandlerTests extends ESTestCase
EmptySecretSettings.INSTANCE,
ElasticInferenceServiceComponents.EMPTY_INSTANCE
),
MinimalServiceSettings.sparseEmbedding()
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME)
)
);
}

View File

@ -19,6 +19,7 @@ dependencies {
javaRestTestImplementation project(path: xpackModule('esql'))
javaRestTestImplementation project(path: xpackModule('snapshot-repo-test-kit'))
javaRestTestImplementation project(path: xpackModule('ent-search'))
javaRestTestImplementation project(path: xpackModule('inference'))
javaRestTestImplementation project(':test:external-modules:test-multi-project')
}

View File

@ -97,6 +97,7 @@ import org.elasticsearch.xpack.core.security.authc.TokenMetadata;
import org.elasticsearch.xpack.esql.core.plugin.EsqlCorePlugin;
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
import org.elasticsearch.xpack.ilm.IndexLifecycle;
import org.elasticsearch.xpack.inference.registry.ModelRegistryMetadata;
import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
import org.elasticsearch.xpack.ml.autoscaling.MlScalingReason;
import org.elasticsearch.xpack.slm.SnapshotLifecycle;
@ -424,6 +425,12 @@ abstract class MlNativeIntegTestCase extends ESIntegTestCase {
entries.add(
new NamedWriteableRegistry.Entry(AutoscalingDeciderResult.Reason.class, MlScalingReason.NAME, MlScalingReason::new)
);
entries.add(
new NamedWriteableRegistry.Entry(Metadata.ProjectCustom.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::new)
);
entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::readDiffFrom));
doEnsureClusterStateConsistency(new NamedWriteableRegistry(entries));
}
}