Adding endpoint creation validation to ElasticInferenceService (#117642)
* Adding endpoint creation validation to ElasticInferenceService * Fix unit tests * Update docs/changelog/117642.yaml --------- Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
parent
3b1825571d
commit
bea8df3c8e
|
@ -0,0 +1,5 @@
|
||||||
|
pr: 117642
|
||||||
|
summary: Adding endpoint creation validation to `ElasticInferenceService`
|
||||||
|
area: Machine Learning
|
||||||
|
type: enhancement
|
||||||
|
issues: []
|
|
@ -54,6 +54,7 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
|
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
|
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
|
||||||
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
|
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -557,11 +558,8 @@ public class ElasticInferenceService extends SenderService {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void checkModelConfig(Model model, ActionListener<Model> listener) {
|
public void checkModelConfig(Model model, ActionListener<Model> listener) {
|
||||||
if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel embeddingsModel) {
|
// TODO: Remove this function once all services have been updated to use the new model validators
|
||||||
listener.onResponse(updateModelWithEmbeddingDetails(embeddingsModel));
|
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
|
||||||
} else {
|
|
||||||
listener.onResponse(model);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<ChunkedInference> translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) {
|
private static List<ChunkedInference> translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) {
|
||||||
|
@ -576,18 +574,6 @@ public class ElasticInferenceService extends SenderService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private ElasticInferenceServiceSparseEmbeddingsModel updateModelWithEmbeddingDetails(
|
|
||||||
ElasticInferenceServiceSparseEmbeddingsModel model
|
|
||||||
) {
|
|
||||||
ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings = new ElasticInferenceServiceSparseEmbeddingsServiceSettings(
|
|
||||||
model.getServiceSettings().modelId(),
|
|
||||||
model.getServiceSettings().maxInputTokens(),
|
|
||||||
model.getServiceSettings().rateLimitSettings()
|
|
||||||
);
|
|
||||||
|
|
||||||
return new ElasticInferenceServiceSparseEmbeddingsModel(model, serviceSettings);
|
|
||||||
}
|
|
||||||
|
|
||||||
private TraceContext getCurrentTraceInfo() {
|
private TraceContext getCurrentTraceInfo() {
|
||||||
var threadPool = getServiceComponents().threadPool();
|
var threadPool = getServiceComponents().threadPool();
|
||||||
|
|
||||||
|
|
|
@ -317,7 +317,21 @@ public class ElasticInferenceServiceTests extends ESTestCase {
|
||||||
|
|
||||||
public void testCheckModelConfig_ReturnsNewModelReference() throws IOException {
|
public void testCheckModelConfig_ReturnsNewModelReference() throws IOException {
|
||||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
try (var service = createService(senderFactory, getUrl(webServer))) {
|
try (var service = createService(senderFactory, getUrl(webServer))) {
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"hello": 2.1259406,
|
||||||
|
"greet": 1.7073475
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
|
||||||
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id");
|
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id");
|
||||||
PlainActionFuture<Model> listener = new PlainActionFuture<>();
|
PlainActionFuture<Model> listener = new PlainActionFuture<>();
|
||||||
service.checkModelConfig(model, listener);
|
service.checkModelConfig(model, listener);
|
||||||
|
|
Loading…
Reference in New Issue