Adding endpoint creation validation to ElasticInferenceService (#117642)

* Adding endpoint creation validation to ElasticInferenceService

* Fix unit tests

* Update docs/changelog/117642.yaml

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
Dan Rubinstein 2025-02-19 12:24:21 -05:00 committed by GitHub
parent 3b1825571d
commit bea8df3c8e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 22 additions and 17 deletions

View File

@@ -0,0 +1,5 @@
pr: 117642
summary: Adding endpoint creation validation to `ElasticInferenceService`
area: Machine Learning
type: enhancement
issues: []

View File

@@ -54,6 +54,7 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import java.util.ArrayList;
@@ -557,11 +558,8 @@ public class ElasticInferenceService extends SenderService {
@Override
public void checkModelConfig(Model model, ActionListener<Model> listener) {
    // TODO: Remove this function once all services have been updated to use the new model validators
    // Delegate to the task-type-specific validator, which is responsible for completing
    // the listener exactly once. The previous inline branch that called
    // listener.onResponse(...) directly must not run as well: keeping both paths would
    // complete the same ActionListener twice (once with the unvalidated model, once
    // from the validator), which violates the listener contract.
    ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
}
private static List<ChunkedInference> translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) {
@@ -576,18 +574,6 @@ public class ElasticInferenceService extends SenderService {
}
}
// Returns a new sparse-embeddings model wrapping the given one, rebuilt with a fresh
// copy of its service settings (model id, max input tokens, rate limit).
private ElasticInferenceServiceSparseEmbeddingsModel updateModelWithEmbeddingDetails(
    ElasticInferenceServiceSparseEmbeddingsModel model
) {
    var currentSettings = model.getServiceSettings();
    var rebuiltSettings = new ElasticInferenceServiceSparseEmbeddingsServiceSettings(
        currentSettings.modelId(),
        currentSettings.maxInputTokens(),
        currentSettings.rateLimitSettings()
    );
    return new ElasticInferenceServiceSparseEmbeddingsModel(model, rebuiltSettings);
}
private TraceContext getCurrentTraceInfo() {
var threadPool = getServiceComponents().threadPool();

View File

@@ -317,7 +317,21 @@ public class ElasticInferenceServiceTests extends ESTestCase {
public void testCheckModelConfig_ReturnsNewModelReference() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createService(senderFactory, getUrl(webServer))) {
String responseJson = """
{
"data": [
{
"hello": 2.1259406,
"greet": 1.7073475
}
]
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id");
PlainActionFuture<Model> listener = new PlainActionFuture<>();
service.checkModelConfig(model, listener);