diff --git a/docs/changelog/117642.yaml b/docs/changelog/117642.yaml new file mode 100644 index 000000000000..dbddbbf5e64e --- /dev/null +++ b/docs/changelog/117642.yaml @@ -0,0 +1,5 @@ +pr: 117642 +summary: Adding endpoint creation validation to `ElasticInferenceService` +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index fee66a9f84ac..737c549255a7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -54,6 +54,7 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.util.ArrayList; @@ -557,11 +558,8 @@ public class ElasticInferenceService extends SenderService { @Override public void checkModelConfig(Model model, ActionListener listener) { - if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel embeddingsModel) { - listener.onResponse(updateModelWithEmbeddingDetails(embeddingsModel)); - } else { - listener.onResponse(model); - } + // TODO: Remove this function once all services have been updated to use the new model validators + ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); } private static List translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) { @@ -576,18 +574,6 @@ public class ElasticInferenceService extends SenderService { } } - private ElasticInferenceServiceSparseEmbeddingsModel updateModelWithEmbeddingDetails( - ElasticInferenceServiceSparseEmbeddingsModel model - ) { - ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings = new ElasticInferenceServiceSparseEmbeddingsServiceSettings( - model.getServiceSettings().modelId(), - model.getServiceSettings().maxInputTokens(), - model.getServiceSettings().rateLimitSettings() - ); - - return new ElasticInferenceServiceSparseEmbeddingsModel(model, serviceSettings); - } - private TraceContext getCurrentTraceInfo() { var threadPool = getServiceComponents().threadPool(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 414c2a3f943d..5d98a90ec2bf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -317,7 +317,21 @@ public class ElasticInferenceServiceTests extends ESTestCase { public void testCheckModelConfig_ReturnsNewModelReference() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = createService(senderFactory, getUrl(webServer))) { + String responseJson = """ + { + "data": [ + { + "hello": 2.1259406, + "greet": 1.7073475 + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id"); PlainActionFuture listener = new PlainActionFuture<>(); service.checkModelConfig(model, listener);