Add Mistral AI Chat Completion support to Inference Plugin (#128538)

* Add Mistral AI Chat Completion support to Inference Plugin

* Add changelog file

* Fix tests and typos

* Refactor Mistral chat completion integration and add tests

* Refactor Mistral error response handling and extract StreamingErrorResponse entity

* Add Mistral chat completion request and response tests

* Enhance error response documentation and clarify StreamingErrorResponse structure

* Refactor Mistral chat completion request handling and introduce skip stream options parameter

* Refactor MistralChatCompletionServiceSettings to include rateLimitSettings in equality checks

* Enhance MistralErrorResponse documentation with detailed error examples

* Add comment for Mistral-specific 422 validation error in OpenAiResponseHandler

* [CI] Auto commit changes from spotless

* Refactor OpenAiUnifiedChatCompletionRequestEntity to remove unused fields and streamline constructor

* Refactor UnifiedChatCompletionRequestEntity and UnifiedCompletionRequest to rename and update stream options parameter

* Refactor MistralChatCompletionRequestEntityTests to improve JSON assertion and remove unused imports

* Add unit tests for MistralUnifiedChatCompletionResponseHandler to validate error handling

* Add unit tests for MistralService

* Update expected service count in testGetServicesWithCompletionTaskType

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
This commit is contained in:
Jan-Kazlouski-elastic 2025-06-04 20:43:33 +03:00 committed by GitHub
parent 3b217b19bd
commit 767d53fefa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
42 changed files with 2611 additions and 306 deletions

View File

@ -0,0 +1,5 @@
pr: 128538
summary: "Added Mistral Chat Completion support to the Inference Plugin"
area: Machine Learning
type: enhancement
issues: []

View File

@ -189,6 +189,7 @@ public class TransportVersions {
public static final TransportVersion DATA_STREAM_OPTIONS_API_REMOVE_INCLUDE_DEFAULTS_8_19 = def(8_841_0_41);
public static final TransportVersion JOIN_ON_ALIASES_8_19 = def(8_841_0_42);
public static final TransportVersion ILM_ADD_SKIP_SETTING_8_19 = def(8_841_0_43);
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_44);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@ -282,6 +283,7 @@ public class TransportVersions {
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES = def(9_087_0_00);
public static final TransportVersion JOIN_ON_ALIASES = def(9_088_0_00);
public static final TransportVersion ILM_ADD_SKIP_SETTING = def(9_089_0_00);
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED = def(9_090_0_00);
/*
* STOP! READ THIS FIRST! No, really,

View File

@ -78,6 +78,14 @@ public record UnifiedCompletionRequest(
* {@link #MAX_COMPLETION_TOKENS_FIELD}. Providers are expected to pass in their supported field name.
*/
private static final String MAX_TOKENS_PARAM = "max_tokens_field";
/**
* Indicates whether to include the `stream_options` field in the JSON output.
* Some providers do not support this field. In such cases, this parameter should be set to "false",
* and the `stream_options` field will be excluded from the output.
* For providers that do support stream options, this parameter is left unset (default behavior),
* which implicitly includes the `stream_options` field in the output.
*/
public static final String INCLUDE_STREAM_OPTIONS_PARAM = "include_stream_options";
/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
@ -91,6 +99,23 @@ public record UnifiedCompletionRequest(
);
}
/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MODEL_FIELD}, Value: modelId
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #MAX_TOKENS_FIELD}
* - Key: {@link #INCLUDE_STREAM_OPTIONS_PARAM}, Value: "false"
*/
public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Params params) {
return new DelegatingMapParams(
Map.ofEntries(
Map.entry(MODEL_ID_PARAM, modelId),
Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD),
Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString())
),
params
);
}
/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MODEL_FIELD}, Value: modelId

View File

@ -134,7 +134,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(13));
assertThat(services.size(), equalTo(14));
var providers = providers(services);
@ -154,7 +154,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
"openai",
"streaming_completion_test_service",
"hugging_face",
"amazon_sagemaker"
"amazon_sagemaker",
"mistral"
).toArray()
)
);
@ -162,7 +163,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(7));
assertThat(services.size(), equalTo(8));
var providers = providers(services);
@ -176,7 +177,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
"streaming_completion_test_service",
"hugging_face",
"amazon_sagemaker",
"googlevertexai"
"googlevertexai",
"mistral"
).toArray()
)
);

View File

@ -100,6 +100,7 @@ import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbedd
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
@ -266,6 +267,13 @@ public class InferenceNamedWriteablesProvider {
MistralEmbeddingsServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
MistralChatCompletionServiceSettings.NAME,
MistralChatCompletionServiceSettings::new
)
);
// note - no task settings for Mistral embeddings...
}

View File

@ -21,12 +21,13 @@ import java.util.Objects;
* A pattern is emerging in how external providers provide error responses.
*
* At a minimum, these return:
* <pre><code>
* {
* "error: {
* "message": "(error message)"
* }
* }
*
* </code></pre>
* Others may return additional information such as error codes specific to the service.
*
* This currently covers error handling for Azure AI Studio, however this pattern

View File

@ -0,0 +1,128 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.response.streaming;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
import java.util.Objects;
import java.util.Optional;
/**
* Represents an error response from a streaming inference service.
* This class extends {@link ErrorResponse} and provides additional fields
* specific to streaming errors, such as code, param, and type.
* An example error response for a streaming service might look like:
* <pre><code>
* {
* "error": {
* "message": "Invalid input",
* "code": "400",
* "param": "input",
* "type": "invalid_request_error"
* }
* }
* </code></pre>
* TODO: {@link ErrorMessageResponseEntity} is nearly identical to this, but doesn't parse as many fields. We must remove the duplication.
*/
public class StreamingErrorResponse extends ErrorResponse {
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
"streaming_error",
true,
args -> Optional.ofNullable((StreamingErrorResponse) args[0])
);
private static final ConstructingObjectParser<StreamingErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
"streaming_error",
true,
args -> new StreamingErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3])
);
static {
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message"));
ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("code"));
ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("param"));
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type"));
ERROR_PARSER.declareObjectOrNull(
ConstructingObjectParser.optionalConstructorArg(),
ERROR_BODY_PARSER,
null,
new ParseField("error")
);
}
/**
* Standard error response parser. This can be overridden for those subclasses that
* have a different error response structure.
* @param response The error response as an HttpResult
*/
public static ErrorResponse fromResponse(HttpResult response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response.body())
) {
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
} catch (Exception e) {
// swallow the error
}
return ErrorResponse.UNDEFINED_ERROR;
}
/**
* Standard error response parser. This can be overridden for those subclasses that
* have a different error response structure.
* @param response The error response as a string
*/
public static ErrorResponse fromString(String response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response)
) {
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
} catch (Exception e) {
// swallow the error
}
return ErrorResponse.UNDEFINED_ERROR;
}
@Nullable
private final String code;
@Nullable
private final String param;
private final String type;
StreamingErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) {
super(errorMessage);
this.code = code;
this.param = param;
this.type = Objects.requireNonNull(type);
}
@Nullable
public String code() {
return code;
}
@Nullable
public String param() {
return param;
}
public String type() {
return type;
}
}

View File

@ -15,6 +15,12 @@ import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.inference.UnifiedCompletionRequest.INCLUDE_STREAM_OPTIONS_PARAM;
/**
* Represents a unified chat completion request entity.
* This class is used to convert the unified chat input into a format that can be serialized to XContent.
*/
public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
public static final String STREAM_FIELD = "stream";
@ -42,7 +48,8 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);
builder.field(STREAM_FIELD, stream);
if (stream) {
// If request is streamed and skip stream options parameter is not true, include stream options in the request.
if (stream && params.paramAsBoolean(INCLUDE_STREAM_OPTIONS_PARAM, true)) {
builder.startObject(STREAM_OPTIONS_FIELD);
builder.field(INCLUDE_USAGE_FIELD, true);
builder.endObject();

View File

@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponse;
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
/**
* Handles non-streaming completion responses for Mistral models, extending the OpenAI completion response handler.
* This class is specifically designed to handle Mistral's error response format.
*/
public class MistralCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
/**
* Constructs a MistralCompletionResponseHandler with the specified request type and response parser.
*
* @param requestType The type of request being handled (e.g., "mistral completions").
* @param parseFunction The function to parse the response.
*/
public MistralCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, MistralErrorResponse::fromResponse);
}
}

View File

@ -9,6 +9,7 @@ package org.elasticsearch.xpack.inference.services.mistral;
public class MistralConstants {
public static final String API_EMBEDDINGS_PATH = "https://api.mistral.ai/v1/embeddings";
public static final String API_COMPLETIONS_PATH = "https://api.mistral.ai/v1/chat/completions";
// note - there is no bounds information available from Mistral,
// so we'll use a sane default here which is the same as Cohere's
@ -18,4 +19,8 @@ public class MistralConstants {
public static final String MODEL_FIELD = "model";
public static final String INPUT_FIELD = "input";
public static final String ENCODING_FORMAT_FIELD = "encoding_format";
public static final String MAX_TOKENS_FIELD = "max_tokens";
public static final String DETAIL_FIELD = "detail";
public static final String MSG_FIELD = "msg";
public static final String MESSAGE_FIELD = "message";
}

View File

@ -22,7 +22,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
import org.elasticsearch.xpack.inference.services.azureopenai.response.AzureMistralOpenAiExternalResponseHandler;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.mistral.request.MistralEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.mistral.request.embeddings.MistralEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.mistral.response.MistralEmbeddingsResponseEntity;
import java.util.List;

View File

@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.net.URI;
import java.net.URISyntaxException;
/**
* Represents a Mistral model that can be used for inference tasks.
* This class extends RateLimitGroupingModel to handle rate limiting based on model and API key.
*/
public abstract class MistralModel extends RateLimitGroupingModel {
protected String model;
protected URI uri;
protected RateLimitSettings rateLimitSettings;
protected MistralModel(ModelConfigurations configurations, ModelSecrets secrets) {
super(configurations, secrets);
}
protected MistralModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) {
super(model, serviceSettings);
}
public String model() {
return this.model;
}
public URI uri() {
return this.uri;
}
@Override
public RateLimitSettings rateLimitSettings() {
return this.rateLimitSettings;
}
@Override
public int rateLimitGroupingHash() {
return 0;
}
// Needed for testing only
public void setURI(String newUri) {
try {
this.uri = new URI(newUri);
} catch (URISyntaxException e) {
// swallow any error
}
}
@Override
public DefaultSecretSettings getSecretSettings() {
return (DefaultSecretSettings) super.getSecretSettings();
}
}

View File

@ -30,7 +30,10 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@ -39,8 +42,11 @@ import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.mistral.action.MistralActionCreator;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -48,6 +54,7 @@ import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
@ -56,14 +63,26 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFrom
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD;
/**
* MistralService is an implementation of the SenderService that handles inference tasks
* using Mistral models. It supports text embedding, completion, and chat completion tasks.
* The service uses MistralActionCreator to create actions for executing inference requests.
*/
public class MistralService extends SenderService {
public static final String NAME = "mistral";
private static final String SERVICE_NAME = "Mistral";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING);
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.COMPLETION,
TaskType.CHAT_COMPLETION
);
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new MistralUnifiedChatCompletionResponseHandler(
"mistral chat completions",
OpenAiChatCompletionResponseEntity::fromResponse
);
public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
@ -79,11 +98,16 @@ public class MistralService extends SenderService {
) {
var actionCreator = new MistralActionCreator(getSender(), getServiceComponents());
if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) {
var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings);
action.execute(inputs, timeout, listener);
} else {
listener.onFailure(createInvalidModelException(model));
switch (model) {
case MistralEmbeddingsModel mistralEmbeddingsModel:
mistralEmbeddingsModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener);
break;
case MistralChatCompletionModel mistralChatCompletionModel:
mistralChatCompletionModel.accept(actionCreator).execute(inputs, timeout, listener);
break;
default:
listener.onFailure(createInvalidModelException(model));
break;
}
}
@ -99,7 +123,24 @@ public class MistralService extends SenderService {
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
throwUnsupportedUnifiedCompletionOperation(NAME);
if (model instanceof MistralChatCompletionModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}
MistralChatCompletionModel mistralChatCompletionModel = (MistralChatCompletionModel) model;
var overriddenModel = MistralChatCompletionModel.of(mistralChatCompletionModel, inputs.getRequest());
var manager = new GenericRequestManager<>(
getServiceComponents().threadPool(),
overriddenModel,
UNIFIED_CHAT_COMPLETION_HANDLER,
unifiedChatInput -> new MistralChatCompletionRequest(unifiedChatInput, overriddenModel),
UnifiedChatInput.class
);
var errorMessage = MistralActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId());
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
action.execute(inputs, timeout, listener);
}
@Override
@ -162,7 +203,7 @@ public class MistralService extends SenderService {
);
}
MistralEmbeddingsModel model = createModel(
MistralModel model = createModel(
modelId,
taskType,
serviceSettingsMap,
@ -184,7 +225,7 @@ public class MistralService extends SenderService {
}
@Override
public Model parsePersistedConfigWithSecrets(
public MistralModel parsePersistedConfigWithSecrets(
String modelId,
TaskType taskType,
Map<String, Object> config,
@ -211,7 +252,7 @@ public class MistralService extends SenderService {
}
@Override
public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
public MistralModel parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
@ -236,7 +277,12 @@ public class MistralService extends SenderService {
return TransportVersions.V_8_15_0;
}
private static MistralEmbeddingsModel createModel(
@Override
public Set<TaskType> supportedStreamingTasks() {
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
}
private static MistralModel createModel(
String modelId,
TaskType taskType,
Map<String, Object> serviceSettings,
@ -246,23 +292,26 @@ public class MistralService extends SenderService {
String failureMessage,
ConfigurationParseContext context
) {
if (taskType == TaskType.TEXT_EMBEDDING) {
return new MistralEmbeddingsModel(
modelId,
taskType,
NAME,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
context
);
switch (taskType) {
case TEXT_EMBEDDING:
return new MistralEmbeddingsModel(
modelId,
taskType,
NAME,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
context
);
case CHAT_COMPLETION, COMPLETION:
return new MistralChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context);
default:
throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
}
throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
}
private MistralEmbeddingsModel createModelFromPersistent(
private MistralModel createModelFromPersistent(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
@ -284,7 +333,7 @@ public class MistralService extends SenderService {
}
@Override
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
public MistralEmbeddingsModel updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
if (model instanceof MistralEmbeddingsModel embeddingsModel) {
var serviceSettings = embeddingsModel.getServiceSettings();
@ -304,6 +353,10 @@ public class MistralService extends SenderService {
}
}
/**
* Configuration class for the Mistral inference service.
* It provides the settings and configurations required for the service.
*/
public static class Configuration {
public static InferenceServiceConfiguration get() {
return configuration.getOrCompute();

View File

@ -0,0 +1,51 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponse;
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
import java.util.Locale;
/**
* Handles streaming chat completion responses and error parsing for Mistral inference endpoints.
* Adapts the OpenAI handler to support Mistral's error schema.
*/
public class MistralUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
private static final String MISTRAL_ERROR = "mistral_error";
public MistralUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, MistralErrorResponse::fromResponse);
}
@Override
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
assert request.isStreaming() : "Only streaming requests support this format";
var responseStatusCode = result.response().getStatusLine().getStatusCode();
if (request.isStreaming()) {
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
var restStatus = toRestStatus(responseStatusCode);
return errorResponse instanceof MistralErrorResponse
? new UnifiedChatCompletionException(restStatus, errorMessage, MISTRAL_ERROR, restStatus.name().toLowerCase(Locale.ROOT))
: new UnifiedChatCompletionException(
restStatus,
errorMessage,
createErrorType(errorResponse),
restStatus.name().toLowerCase(Locale.ROOT)
);
} else {
return super.buildError(message, request, result, errorResponse);
}
}
}

View File

@ -7,19 +7,41 @@
package org.elasticsearch.xpack.inference.services.mistral.action;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.mistral.MistralCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.mistral.MistralEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.core.Strings.format;
/**
* MistralActionCreator is responsible for creating executable actions for Mistral models.
* It implements the MistralActionVisitor interface to provide specific implementations
* for different types of Mistral models.
*/
public class MistralActionCreator implements MistralActionVisitor {
public static final String COMPLETION_ERROR_PREFIX = "Mistral completions";
static final String USER_ROLE = "user";
static final ResponseHandler COMPLETION_HANDLER = new MistralCompletionResponseHandler(
"mistral completions",
OpenAiChatCompletionResponseEntity::fromResponse
);
private final Sender sender;
private final ServiceComponents serviceComponents;
@ -35,7 +57,32 @@ public class MistralActionCreator implements MistralActionVisitor {
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var errorMessage = constructFailedToSendRequestMessage("Mistral embeddings");
var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, embeddingsModel.getInferenceEntityId());
return new SenderExecutableAction(sender, requestManager, errorMessage);
}
@Override
public ExecutableAction create(MistralChatCompletionModel chatCompletionModel) {
var manager = new GenericRequestManager<>(
serviceComponents.threadPool(),
chatCompletionModel,
COMPLETION_HANDLER,
inputs -> new MistralChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), chatCompletionModel),
ChatCompletionInput.class
);
var errorMessage = buildErrorMessage(TaskType.COMPLETION, chatCompletionModel.getInferenceEntityId());
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX);
}
/**
* Builds an error message for Mistral actions.
*
* @param requestType The type of request (e.g., TEXT_EMBEDDING, COMPLETION).
* @param inferenceId The ID of the inference entity.
* @return A formatted error message.
*/
public static String buildErrorMessage(TaskType requestType, String inferenceId) {
return format("Failed to send Mistral %s request from inference entity id [%s]", requestType.toString(), inferenceId);
}
}

View File

@ -8,10 +8,33 @@
package org.elasticsearch.xpack.inference.services.mistral.action;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
import java.util.Map;
/**
* Interface for creating {@link ExecutableAction} instances for Mistral models.
* <p>
* This interface is used to create {@link ExecutableAction} instances for different types of Mistral models, such as
* {@link MistralEmbeddingsModel} and {@link MistralChatCompletionModel}.
*/
public interface MistralActionVisitor {
/**
* Creates an {@link ExecutableAction} for the given {@link MistralEmbeddingsModel}.
*
* @param embeddingsModel The model to create the action for.
* @param taskSettings The task settings to use.
* @return An {@link ExecutableAction} for the given model.
*/
ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings);
/**
* Creates an {@link ExecutableAction} for the given {@link MistralChatCompletionModel}.
*
* @param chatCompletionModel The model to create the action for.
* @return An {@link ExecutableAction} for the given model.
*/
ExecutableAction create(MistralChatCompletionModel chatCompletionModel);
}

View File

@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral.completion;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.mistral.MistralModel;
import org.elasticsearch.xpack.inference.services.mistral.action.MistralActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_COMPLETIONS_PATH;
/**
* Represents a Mistral chat completion model.
* This class extends RateLimitGroupingModel to handle rate limiting based on model and API key.
*/
public class MistralChatCompletionModel extends MistralModel {
/**
* Constructor for MistralChatCompletionModel.
*
* @param inferenceEntityId The unique identifier for the inference entity.
* @param taskType The type of task this model is designed for.
* @param service The name of the service this model belongs to.
* @param serviceSettings The settings specific to the Mistral chat completion service.
* @param secrets The secrets required for accessing the service.
* @param context The context for parsing configuration settings.
*/
public MistralChatCompletionModel(
String inferenceEntityId,
TaskType taskType,
String service,
Map<String, Object> serviceSettings,
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
this(
inferenceEntityId,
taskType,
service,
MistralChatCompletionServiceSettings.fromMap(serviceSettings, context),
DefaultSecretSettings.fromMap(secrets)
);
}
/**
* Creates a new MistralChatCompletionModel with overridden service settings.
*
* @param model The original MistralChatCompletionModel.
* @param request The UnifiedCompletionRequest containing the model override.
* @return A new MistralChatCompletionModel with the overridden model ID.
*/
public static MistralChatCompletionModel of(MistralChatCompletionModel model, UnifiedCompletionRequest request) {
if (request.model() == null) {
// If no model is specified in the request, return the original model
return model;
}
var originalModelServiceSettings = model.getServiceSettings();
var overriddenServiceSettings = new MistralChatCompletionServiceSettings(
request.model(),
originalModelServiceSettings.rateLimitSettings()
);
return new MistralChatCompletionModel(
model.getInferenceEntityId(),
model.getTaskType(),
model.getConfigurations().getService(),
overriddenServiceSettings,
model.getSecretSettings()
);
}
public MistralChatCompletionModel(
String inferenceEntityId,
TaskType taskType,
String service,
MistralChatCompletionServiceSettings serviceSettings,
DefaultSecretSettings secrets
) {
super(
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings()),
new ModelSecrets(secrets)
);
setPropertiesFromServiceSettings(serviceSettings);
}
private void setPropertiesFromServiceSettings(MistralChatCompletionServiceSettings serviceSettings) {
this.model = serviceSettings.modelId();
this.rateLimitSettings = serviceSettings.rateLimitSettings();
setEndpointUrl();
}
@Override
public int rateLimitGroupingHash() {
return Objects.hash(model, getSecretSettings().apiKey());
}
private void setEndpointUrl() {
try {
this.uri = new URI(API_COMPLETIONS_PATH);
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
}
@Override
public MistralChatCompletionServiceSettings getServiceSettings() {
return (MistralChatCompletionServiceSettings) super.getServiceSettings();
}
/**
* Accepts a visitor to create an executable action for this model.
*
* @param creator The visitor that creates the executable action.
* @return An ExecutableAction that can be executed.
*/
public ExecutableAction accept(MistralActionVisitor creator) {
return creator.create(this);
}
}

View File

@ -0,0 +1,129 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral.completion;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD;
/**
* Represents the settings for the Mistral chat completion service.
* This class encapsulates the model ID and rate limit settings for the Mistral chat completion service.
*/
public class MistralChatCompletionServiceSettings extends FilteredXContentObject implements ServiceSettings {
public static final String NAME = "mistral_completions_service_settings";
private final String modelId;
private final RateLimitSettings rateLimitSettings;
// default for Mistral is 5 requests / sec
// setting this to 240 (4 requests / sec) is a sane default for us
protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(240);
public static MistralChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
String model = extractRequiredString(map, MODEL_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
MistralService.NAME,
context
);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
return new MistralChatCompletionServiceSettings(model, rateLimitSettings);
}
public MistralChatCompletionServiceSettings(StreamInput in) throws IOException {
this.modelId = in.readString();
this.rateLimitSettings = new RateLimitSettings(in);
}
public MistralChatCompletionServiceSettings(String modelId, @Nullable RateLimitSettings rateLimitSettings) {
this.modelId = modelId;
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED;
}
@Override
public String modelId() {
return this.modelId;
}
public RateLimitSettings rateLimitSettings() {
return this.rateLimitSettings;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
rateLimitSettings.writeTo(out);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
this.toXContentFragmentOfExposedFields(builder, params);
builder.endObject();
return builder;
}
@Override
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
builder.field(MODEL_FIELD, this.modelId);
rateLimitSettings.toXContent(builder, params);
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
MistralChatCompletionServiceSettings that = (MistralChatCompletionServiceSettings) o;
return Objects.equals(modelId, that.modelId) && Objects.equals(rateLimitSettings, that.rateLimitSettings);
}
@Override
public int hashCode() {
return Objects.hash(modelId, rateLimitSettings);
}
}

View File

@ -10,16 +10,15 @@ package org.elasticsearch.xpack.inference.services.mistral.embeddings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.mistral.MistralModel;
import org.elasticsearch.xpack.inference.services.mistral.action.MistralActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.net.URI;
import java.net.URISyntaxException;
@ -27,10 +26,11 @@ import java.util.Map;
import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_EMBEDDINGS_PATH;
public class MistralEmbeddingsModel extends Model {
protected String model;
protected URI uri;
protected RateLimitSettings rateLimitSettings;
/**
* Represents a Mistral embeddings model.
* This class extends MistralModel to handle embeddings-specific settings and actions.
*/
public class MistralEmbeddingsModel extends MistralModel {
public MistralEmbeddingsModel(
String inferenceEntityId,
@ -58,6 +58,20 @@ public class MistralEmbeddingsModel extends Model {
setPropertiesFromServiceSettings(serviceSettings);
}
private void setPropertiesFromServiceSettings(MistralEmbeddingsServiceSettings serviceSettings) {
this.model = serviceSettings.modelId();
this.rateLimitSettings = serviceSettings.rateLimitSettings();
setEndpointUrl();
}
private void setEndpointUrl() {
try {
this.uri = new URI(API_EMBEDDINGS_PATH);
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
}
public MistralEmbeddingsModel(
String inferenceEntityId,
TaskType taskType,
@ -74,51 +88,11 @@ public class MistralEmbeddingsModel extends Model {
setPropertiesFromServiceSettings(serviceSettings);
}
private void setPropertiesFromServiceSettings(MistralEmbeddingsServiceSettings serviceSettings) {
this.model = serviceSettings.modelId();
this.rateLimitSettings = serviceSettings.rateLimitSettings();
setEndpointUrl();
}
@Override
public MistralEmbeddingsServiceSettings getServiceSettings() {
return (MistralEmbeddingsServiceSettings) super.getServiceSettings();
}
public String model() {
return this.model;
}
public URI uri() {
return this.uri;
}
public RateLimitSettings rateLimitSettings() {
return this.rateLimitSettings;
}
private void setEndpointUrl() {
try {
this.uri = new URI(API_EMBEDDINGS_PATH);
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
}
// Needed for testing only
public void setURI(String newUri) {
try {
this.uri = new URI(newUri);
} catch (URISyntaxException e) {
// swallow any error
}
}
@Override
public DefaultSecretSettings getSecretSettings() {
return (DefaultSecretSettings) super.getSecretSettings();
}
public ExecutableAction accept(MistralActionVisitor creator, Map<String, Object> taskSettings) {
return creator.create(this, taskSettings);
}

View File

@ -0,0 +1,82 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral.request.completion;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
/**
* Mistral Unified Chat Completion Request
* This class is responsible for creating a request to the Mistral chat completion model.
* It constructs an HTTP POST request with the necessary headers and body content.
*/
public class MistralChatCompletionRequest implements Request {
private final MistralChatCompletionModel model;
private final UnifiedChatInput chatInput;
public MistralChatCompletionRequest(UnifiedChatInput chatInput, MistralChatCompletionModel model) {
this.chatInput = Objects.requireNonNull(chatInput);
this.model = Objects.requireNonNull(model);
}
@Override
public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(model.uri());
ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new MistralChatCompletionRequestEntity(chatInput, model)).getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey()));
return new HttpRequest(httpPost, getInferenceEntityId());
}
@Override
public URI getURI() {
return model.uri();
}
@Override
public Request truncate() {
// No truncation for Mistral chat completions
return this;
}
@Override
public boolean[] getTruncationInfo() {
// No truncation for Mistral chat completions
return null;
}
@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
}
@Override
public boolean isStreaming() {
return chatInput.stream();
}
}

View File

@ -0,0 +1,44 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral.request.completion;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
import java.io.IOException;
import java.util.Objects;
/**
* MistralChatCompletionRequestEntity is responsible for creating the request entity for Mistral chat completion.
* It implements ToXContentObject to allow serialization to XContent format.
*/
public class MistralChatCompletionRequestEntity implements ToXContentObject {
private final MistralChatCompletionModel model;
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
public MistralChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, MistralChatCompletionModel model) {
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
this.model = Objects.requireNonNull(model);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
unifiedRequestEntity.toXContent(
builder,
UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(model.getServiceSettings().modelId(), params)
);
builder.endObject();
return builder;
}
}

View File

@ -5,7 +5,7 @@
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral.request;
package org.elasticsearch.xpack.inference.services.mistral.request.embeddings;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;

View File

@ -5,7 +5,7 @@
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral.request;
package org.elasticsearch.xpack.inference.services.mistral.request.embeddings;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

View File

@ -0,0 +1,92 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral.response;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import java.nio.charset.StandardCharsets;
/**
* Represents an error response entity for Mistral inference services.
* This class extends ErrorResponse and provides a method to create an instance
* from an HttpResult, attempting to read the body as a UTF-8 string.
* An example error response for Not Found error would look like:
* <pre><code>
* {
* "detail": "Not Found"
* }
* </code></pre>
* An example error response for Bad Request error would look like:
* <pre><code>
* {
* "object": "error",
* "message": "Invalid model: wrong-model-name",
* "type": "invalid_model",
* "param": null,
* "code": "1500"
* }
* </code></pre>
* An example error response for Unauthorized error would look like:
* <pre><code>
* {
* "message": "Unauthorized",
* "request_id": "ad95a2165083f20b490f8f78a14bb104"
* }
* </code></pre>
* An example error response for Unprocessable Entity error would look like:
* <pre><code>
* {
* "object": "error",
* "message": {
* "detail": [
* {
* "type": "greater_than_equal",
* "loc": [
* "body",
* "max_tokens"
* ],
* "msg": "Input should be greater than or equal to 0",
* "input": -10,
* "ctx": {
* "ge": 0
* }
* }
* ]
* },
* "type": "invalid_request_error",
* "param": null,
* "code": null
* }
* </code></pre>
*/
public class MistralErrorResponse extends ErrorResponse {
public MistralErrorResponse(String message) {
super(message);
}
/**
* Creates an ErrorResponse from the given HttpResult.
* Attempts to read the body as a UTF-8 string and constructs a MistralErrorResponseEntity.
* If reading fails, returns a generic UNDEFINED_ERROR.
*
* @param response the HttpResult containing the error response
* @return an ErrorResponse instance
*/
public static ErrorResponse fromResponse(HttpResult response) {
try {
String errorMessage = new String(response.body(), StandardCharsets.UTF_8);
return new MistralErrorResponse(errorMessage);
} catch (Exception e) {
// swallow the error
}
return ErrorResponse.UNDEFINED_ERROR;
}
}

View File

@ -24,6 +24,7 @@ import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentE
import java.util.concurrent.Flow;
import java.util.function.Function;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown;
public class OpenAiResponseHandler extends BaseResponseHandler {
@ -40,6 +41,7 @@ public class OpenAiResponseHandler extends BaseResponseHandler {
static final String REMAINING_TOKENS = "x-ratelimit-remaining-tokens";
static final String CONTENT_TOO_LARGE_MESSAGE = "Please reduce your prompt; or completion length.";
static final String VALIDATION_ERROR_MESSAGE = "Received an input validation error response";
static final String OPENAI_SERVER_BUSY = "Received a server busy error status code";
@ -86,11 +88,23 @@ public class OpenAiResponseHandler extends BaseResponseHandler {
throw new RetryException(false, buildError(AUTHENTICATION, request, result));
} else if (statusCode >= 300 && statusCode < 400) {
throw new RetryException(false, buildError(REDIRECTION, request, result));
} else if (statusCode == 422) {
// OpenAI does not return 422 at the time of writing, but Mistral does and follows most of OpenAI's format.
// TODO: Revisit this in the future to decouple OpenAI and Mistral error handling.
throw new RetryException(false, buildError(VALIDATION_ERROR_MESSAGE, request, result));
} else if (statusCode == 400) {
throw new RetryException(false, buildError(BAD_REQUEST, request, result));
} else if (statusCode == 404) {
throw new RetryException(false, buildError(resourceNotFoundError(request), request, result));
} else {
throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
}
}
private static String resourceNotFoundError(Request request) {
return format("Resource not found at [%s]", request.getURI());
}
protected RetryException buildExceptionHandling429(Request request, HttpResult result) {
return new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result));
}

View File

@ -7,15 +7,8 @@
package org.elasticsearch.xpack.inference.services.openai;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
@ -24,18 +17,21 @@ import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.StreamingErrorResponse;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Flow;
import java.util.function.Function;
import static org.elasticsearch.core.Strings.format;
/**
* Handles streaming chat completion responses and error parsing for OpenAI inference endpoints.
* This handler is designed to work with the unified OpenAI chat completion API.
*/
public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, OpenAiErrorResponse::fromResponse);
super(requestType, parseFunction, StreamingErrorResponse::fromResponse);
}
public OpenAiUnifiedChatCompletionResponseHandler(
@ -62,7 +58,7 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
if (request.isStreaming()) {
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
var restStatus = toRestStatus(responseStatusCode);
return errorResponse instanceof OpenAiErrorResponse oer
return errorResponse instanceof StreamingErrorResponse oer
? new UnifiedChatCompletionException(restStatus, errorMessage, oer.type(), oer.code(), oer.param())
: new UnifiedChatCompletionException(
restStatus,
@ -84,8 +80,8 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
}
public static UnifiedChatCompletionException buildMidStreamError(String inferenceEntityId, String message, Exception e) {
var errorResponse = OpenAiErrorResponse.fromString(message);
if (errorResponse instanceof OpenAiErrorResponse oer) {
var errorResponse = StreamingErrorResponse.fromString(message);
if (errorResponse instanceof StreamingErrorResponse oer) {
return new UnifiedChatCompletionException(
RestStatus.INTERNAL_SERVER_ERROR,
format(
@ -109,85 +105,4 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
);
}
}
private static class OpenAiErrorResponse extends ErrorResponse {
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
"open_ai_error",
true,
args -> Optional.ofNullable((OpenAiErrorResponse) args[0])
);
private static final ConstructingObjectParser<OpenAiErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
"open_ai_error",
true,
args -> new OpenAiErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3])
);
static {
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message"));
ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("code"));
ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("param"));
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type"));
ERROR_PARSER.declareObjectOrNull(
ConstructingObjectParser.optionalConstructorArg(),
ERROR_BODY_PARSER,
null,
new ParseField("error")
);
}
private static ErrorResponse fromResponse(HttpResult response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response.body())
) {
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
} catch (Exception e) {
// swallow the error
}
return ErrorResponse.UNDEFINED_ERROR;
}
private static ErrorResponse fromString(String response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response)
) {
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
} catch (Exception e) {
// swallow the error
}
return ErrorResponse.UNDEFINED_ERROR;
}
@Nullable
private final String code;
@Nullable
private final String param;
private final String type;
OpenAiErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) {
super(errorMessage);
this.code = code;
this.param = param;
this.type = Objects.requireNonNull(type);
}
@Nullable
public String code() {
return code;
}
@Nullable
public String param() {
return param;
}
public String type() {
return type;
}
}
}

View File

@ -21,15 +21,10 @@ import java.util.Objects;
public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObject {
public static final String USER_FIELD = "user";
private static final String MODEL_FIELD = "model";
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
private final UnifiedChatInput unifiedChatInput;
private final OpenAiChatCompletionModel model;
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) {
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
this.model = Objects.requireNonNull(model);
}

View File

@ -54,6 +54,7 @@ import static org.elasticsearch.common.Strings.format;
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
@ -273,8 +274,10 @@ public class DeepSeekServiceTests extends ESTestCase {
assertThat(
e.getMessage(),
equalTo(
"Received an unsuccessful status code for request from inference entity id [inference-id] status"
+ " [404]. Error message: [The model `deepseek-not-chat` does not exist or you do not have access to it.]"
"Resource not found at ["
+ getUrl(webServer)
+ "] for request from inference entity id [inference-id]"
+ " status [404]. Error message: [The model `deepseek-not-chat` does not exist or you do not have access to it.]"
)
);
}

View File

@ -66,6 +66,7 @@ import java.io.IOException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
@ -377,15 +378,14 @@ public class HuggingFaceServiceTests extends ESTestCase {
}
});
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
assertThat(json, is("""
assertThat(json, is(String.format(Locale.ROOT, """
{\
"error":{\
"code":"not_found",\
"message":"Received an unsuccessful status code for request from inference entity id [id] status \
"message":"Resource not found at [%s] for request from inference entity id [id] status \
[404]. Error message: [Model not found.]",\
"type":"hugging_face_error"\
}}"""));
}}""", getUrl(webServer))));
} catch (IOException ex) {
throw new RuntimeException(ex);
}

View File

@ -11,6 +11,7 @@ import org.apache.http.HttpHeaders;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.bytes.BytesArray;
@ -29,39 +30,53 @@ import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingModelTests;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
import org.hamcrest.CoreMatchers;
import org.hamcrest.Matcher;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
@ -72,12 +87,14 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_KEY_FIELD;
import static org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettingsTests.getServiceSettingsMap;
import static org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettingsTests.createRequestSettingsMap;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -118,11 +135,7 @@ public class MistralServiceTests extends ESTestCase {
service.parseRequestConfig(
"id",
TaskType.TEXT_EMBEDDING,
getRequestConfigMap(
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
getEmbeddingsTaskSettingsMap(),
getSecretSettingsMap("secret")
),
getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), getSecretSettingsMap("secret")),
modelVerificationListener
);
}
@ -144,8 +157,8 @@ public class MistralServiceTests extends ESTestCase {
"id",
TaskType.TEXT_EMBEDDING,
getRequestConfigMap(
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
getEmbeddingsTaskSettingsMap(),
getEmbeddingsServiceSettingsMap(null, null),
getTaskSettingsMap(),
createRandomChunkingSettingsMap(),
getSecretSettingsMap("secret")
),
@ -169,16 +182,289 @@ public class MistralServiceTests extends ESTestCase {
service.parseRequestConfig(
"id",
TaskType.TEXT_EMBEDDING,
getRequestConfigMap(
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
getEmbeddingsTaskSettingsMap(),
getSecretSettingsMap("secret")
),
getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), getSecretSettingsMap("secret")),
modelVerificationListener
);
}
}
public void testParseRequestConfig_CreatesChatCompletionsModel() throws IOException {
var model = "model";
var secret = "secret";
try (var service = createService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(m -> {
assertThat(m, instanceOf(MistralChatCompletionModel.class));
var completionsModel = (MistralChatCompletionModel) m;
assertThat(completionsModel.getServiceSettings().modelId(), is(model));
assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret));
}, exception -> fail("Unexpected exception: " + exception));
service.parseRequestConfig(
"id",
TaskType.COMPLETION,
getRequestConfigMap(getServiceSettingsMap(model), getSecretSettingsMap(secret)),
modelVerificationListener
);
}
}
public void testParseRequestConfig_ThrowsException_WithoutModelId() throws IOException {
var secret = "secret";
try (var service = createService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(m -> {
assertThat(m, instanceOf(MistralChatCompletionModel.class));
var completionsModel = (MistralChatCompletionModel) m;
assertNull(completionsModel.getServiceSettings().modelId());
assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret));
}, exception -> {
assertThat(exception, instanceOf(ValidationException.class));
assertThat(
exception.getMessage(),
is("Validation Failed: 1: [service_settings] does not contain the required setting [model];")
);
});
service.parseRequestConfig(
"id",
TaskType.COMPLETION,
getRequestConfigMap(Collections.emptyMap(), getSecretSettingsMap(secret)),
modelVerificationListener
);
}
}
public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException {
var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender()).thenReturn(sender);
var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION);
try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) {
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
thrownException.getMessage(),
is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
);
verify(factory, times(1)).createSender();
verify(sender, times(1)).start();
}
verify(sender, times(1)).close();
verifyNoMoreInteractions(factory);
verifyNoMoreInteractions(sender);
}
public void testUnifiedCompletionInfer() throws Exception {
// The escapes are because the streaming response must be on a single line
String responseJson = """
data: {\
"id": "37d683fc0b3949b880529fb20973aca7",\
"object": "chat.completion.chunk",\
"created": 1749032579,\
"model": "mistral-small-latest",\
"choices": [\
{\
"index": 0,\
"delta": {\
"content": "Cho"\
},\
"finish_reason": "length",\
"logprobs": null\
}\
],\
"usage": {\
"prompt_tokens": 10,\
"total_tokens": 11,\
"completion_tokens": 1\
}\
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.unifiedCompletionInfer(
model,
UnifiedCompletionRequest.of(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
),
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);
var result = listener.actionGet(TIMEOUT);
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(XContentHelper.stripWhitespace("""
{
"id": "37d683fc0b3949b880529fb20973aca7",
"choices": [{
"delta": {
"content": "Cho"
},
"finish_reason": "length",
"index": 0
}
],
"model": "mistral-small-latest",
"object": "chat.completion.chunk",
"usage": {
"completion_tokens": 1,
"prompt_tokens": 10,
"total_tokens": 11
}
}
"""));
}
}
public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception {
String responseJson = """
{
"detail": "Not Found"
}
""";
webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
var latch = new CountDownLatch(1);
service.unifiedCompletionInfer(
model,
UnifiedCompletionRequest.of(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
),
InferenceAction.Request.DEFAULT_TIMEOUT,
ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> {
try (var builder = XContentFactory.jsonBuilder()) {
var t = unwrapCause(e);
assertThat(t, isA(UnifiedChatCompletionException.class));
((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
try {
xContent.toXContent(builder, EMPTY_PARAMS);
} catch (IOException ex) {
throw new RuntimeException(ex);
}
});
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
assertThat(json, is(String.format(Locale.ROOT, XContentHelper.stripWhitespace("""
{
"error" : {
"code" : "not_found",
"message" : "Resource not found at [%s] for request from inference entity id [id] status \
[404]. Error message: [{\\n \\"detail\\": \\"Not Found\\"\\n}\\n]",
"type" : "mistral_error"
}
}"""), getUrl(webServer))));
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}), latch::countDown)
);
assertTrue(latch.await(30, TimeUnit.SECONDS));
}
}
public void testInfer_StreamRequest() throws Exception {
String responseJson = """
data: {\
"id":"12345",\
"object":"chat.completion.chunk",\
"created":123456789,\
"model":"gpt-4o-mini",\
"system_fingerprint": "123456789",\
"choices":[\
{\
"index":0,\
"delta":{\
"content":"hello, world"\
},\
"logprobs":null,\
"finish_reason":null\
}\
]\
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
streamCompletion().hasNoErrors().hasEvent("""
{"completion":[{"delta":"hello, world"}]}""");
}
private InferenceEventsAssertion streamCompletion() throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
var model = MistralChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model");
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.infer(
model,
null,
null,
null,
List.of("abc"),
true,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);
return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
}
}
public void testInfer_StreamRequest_ErrorResponse() {
String responseJson = """
{
"message": "Unauthorized",
"request_id": "ad95a2165083f20b490f8f78a14bb104"
}""";
webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));
var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion);
assertThat(e.status(), equalTo(RestStatus.UNAUTHORIZED));
assertThat(e.getMessage(), equalTo("""
Received an authentication error status code for request from inference entity id [id] status [401]. Error message: [{
"message": "Unauthorized",
"request_id": "ad95a2165083f20b490f8f78a14bb104"
}]"""));
}
public void testSupportsStreaming() throws IOException {
try (var service = new MistralService(mock(), createWithEmptySettings(mock()))) {
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
}
public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
try (var service = createService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
@ -192,24 +478,37 @@ public class MistralServiceTests extends ESTestCase {
service.parseRequestConfig(
"id",
TaskType.SPARSE_EMBEDDING,
getRequestConfigMap(
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
getEmbeddingsTaskSettingsMap(),
getSecretSettingsMap("secret")
),
getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), getSecretSettingsMap("secret")),
modelVerificationListener
);
}
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig(
getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), getSecretSettingsMap("secret")),
TaskType.TEXT_EMBEDDING
);
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig_Completion() throws IOException {
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig(
getRequestConfigMap(getServiceSettingsMap("mistral-completion"), getSecretSettingsMap("secret")),
TaskType.COMPLETION
);
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig_ChatCompletion() throws IOException {
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig(
getRequestConfigMap(getServiceSettingsMap("mistral-chat-completion"), getSecretSettingsMap("secret")),
TaskType.CHAT_COMPLETION
);
}
private void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig(Map<String, Object> secret, TaskType chatCompletion)
throws IOException {
try (var service = createService()) {
var config = getRequestConfigMap(
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
getEmbeddingsTaskSettingsMap(),
getSecretSettingsMap("secret")
);
config.put("extra_key", "value");
secret.put("extra_key", "value");
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
model -> fail("Expected exception, but got model: " + model),
@ -222,20 +521,40 @@ public class MistralServiceTests extends ESTestCase {
}
);
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
service.parseRequestConfig("id", chatCompletion, secret, modelVerificationListener);
}
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingTaskSettingsMap() throws IOException {
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap(
getEmbeddingsServiceSettingsMap(null, null),
TaskType.TEXT_EMBEDDING
);
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInCompletionTaskSettingsMap() throws IOException {
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap(
getServiceSettingsMap("mistral-completion"),
TaskType.COMPLETION
);
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionTaskSettingsMap() throws IOException {
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap(
getServiceSettingsMap("mistral-chat-completion"),
TaskType.CHAT_COMPLETION
);
}
private void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap(
Map<String, Object> serviceSettingsMap,
TaskType chatCompletion
) throws IOException {
try (var service = createService()) {
var taskSettings = new HashMap<String, Object>();
taskSettings.put("extra_key", "value");
var config = getRequestConfigMap(
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
taskSettings,
getSecretSettingsMap("secret")
);
var config = getRequestConfigMap(serviceSettingsMap, taskSettings, getSecretSettingsMap("secret"));
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
model -> fail("Expected exception, but got model: " + model),
@ -248,7 +567,7 @@ public class MistralServiceTests extends ESTestCase {
}
);
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
service.parseRequestConfig("id", chatCompletion, config, modelVerificationListener);
}
}
@ -257,11 +576,7 @@ public class MistralServiceTests extends ESTestCase {
var secretSettings = getSecretSettingsMap("secret");
secretSettings.put("extra_key", "value");
var config = getRequestConfigMap(
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
getEmbeddingsTaskSettingsMap(),
secretSettings
);
var config = getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), secretSettings);
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
model -> fail("Expected exception, but got model: " + model),
@ -278,11 +593,42 @@ public class MistralServiceTests extends ESTestCase {
}
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInCompletionSecretSettingsMap() throws IOException {
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap("mistral-completion", TaskType.COMPLETION);
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionSecretSettingsMap() throws IOException {
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap("mistral-chat-completion", TaskType.CHAT_COMPLETION);
}
private void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap(String modelId, TaskType chatCompletion)
throws IOException {
try (var service = createService()) {
var secretSettings = getSecretSettingsMap("secret");
secretSettings.put("extra_key", "value");
var config = getRequestConfigMap(getServiceSettingsMap(modelId), secretSettings);
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
model -> fail("Expected exception, but got model: " + model),
exception -> {
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
assertThat(
exception.getMessage(),
is("Configuration contains settings [{extra_key=value}] unknown to the [mistral] service")
);
}
);
service.parseRequestConfig("id", chatCompletion, config, modelVerificationListener);
}
}
public void testParsePersistedConfig_CreatesAMistralEmbeddingsModel() throws IOException {
try (var service = createService()) {
var config = getPersistedConfigMap(
getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null),
getEmbeddingsTaskSettingsMap(),
getEmbeddingsServiceSettingsMap(1024, 512),
getTaskSettingsMap(),
getSecretSettingsMap("secret")
);
@ -298,11 +644,33 @@ public class MistralServiceTests extends ESTestCase {
}
}
public void testParsePersistedConfig_CreatesAMistralCompletionModel() throws IOException {
testParsePersistedConfig_CreatesAMistralModel("mistral-completion", TaskType.COMPLETION);
}
public void testParsePersistedConfig_CreatesAMistralChatCompletionModel() throws IOException {
testParsePersistedConfig_CreatesAMistralModel("mistral-chat-completion", TaskType.CHAT_COMPLETION);
}
private void testParsePersistedConfig_CreatesAMistralModel(String modelId, TaskType chatCompletion) throws IOException {
try (var service = createService()) {
var config = getPersistedConfigMap(getServiceSettingsMap(modelId), getTaskSettingsMap(), getSecretSettingsMap("secret"));
var model = service.parsePersistedConfigWithSecrets("id", chatCompletion, config.config(), config.secrets());
assertThat(model, instanceOf(MistralChatCompletionModel.class));
var embeddingsModel = (MistralChatCompletionModel) model;
assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId));
assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
}
}
public void testParsePersistedConfig_CreatesAMistralEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
try (var service = createService()) {
var config = getPersistedConfigMap(
getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null),
getEmbeddingsTaskSettingsMap(),
getEmbeddingsServiceSettingsMap(1024, 512),
getTaskSettingsMap(),
createRandomChunkingSettingsMap(),
getSecretSettingsMap("secret")
);
@ -323,8 +691,8 @@ public class MistralServiceTests extends ESTestCase {
public void testParsePersistedConfig_CreatesAMistralEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
try (var service = createService()) {
var config = getPersistedConfigMap(
getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null),
getEmbeddingsTaskSettingsMap(),
getEmbeddingsServiceSettingsMap(1024, 512),
getTaskSettingsMap(),
getSecretSettingsMap("secret")
);
@ -354,11 +722,7 @@ public class MistralServiceTests extends ESTestCase {
service.parseRequestConfig(
"id",
TaskType.SPARSE_EMBEDDING,
getRequestConfigMap(
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
getEmbeddingsTaskSettingsMap(),
getSecretSettingsMap("secret")
),
getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), getSecretSettingsMap("secret")),
modelVerificationListener
);
}
@ -367,8 +731,8 @@ public class MistralServiceTests extends ESTestCase {
public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException {
try (var service = createService()) {
var config = getPersistedConfigMap(
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
getEmbeddingsTaskSettingsMap(),
getEmbeddingsServiceSettingsMap(null, null),
getTaskSettingsMap(),
getSecretSettingsMap("secret")
);
@ -384,38 +748,92 @@ public class MistralServiceTests extends ESTestCase {
}
}
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfigEmbeddings() throws IOException {
testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig(
getEmbeddingsServiceSettingsMap(1024, 512),
TaskType.TEXT_EMBEDDING,
instanceOf(MistralEmbeddingsModel.class)
);
}
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfigCompletion() throws IOException {
testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig(
getServiceSettingsMap("mistral-completion"),
TaskType.COMPLETION,
instanceOf(MistralChatCompletionModel.class)
);
}
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfigChatCompletion() throws IOException {
testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig(
getServiceSettingsMap("mistral-chat-completion"),
TaskType.CHAT_COMPLETION,
instanceOf(MistralChatCompletionModel.class)
);
}
private void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig(
Map<String, Object> serviceSettingsMap,
TaskType chatCompletion,
Matcher<MistralModel> matcher
) throws IOException {
try (var service = createService()) {
var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null);
var taskSettings = getEmbeddingsTaskSettingsMap();
var taskSettings = getTaskSettingsMap();
var secretSettings = getSecretSettingsMap("secret");
var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
var config = getPersistedConfigMap(serviceSettingsMap, taskSettings, secretSettings);
config.config().put("extra_key", "value");
var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets());
var model = service.parsePersistedConfigWithSecrets("id", chatCompletion, config.config(), config.secrets());
assertThat(model, instanceOf(MistralEmbeddingsModel.class));
assertThat(model, matcher);
}
}
public void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInEmbeddingServiceSettingsMap() throws IOException {
testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInServiceSettingsMap(
getEmbeddingsServiceSettingsMap(1024, 512),
TaskType.TEXT_EMBEDDING,
instanceOf(MistralEmbeddingsModel.class)
);
}
public void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInCompletionServiceSettingsMap() throws IOException {
testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInServiceSettingsMap(
getServiceSettingsMap("mistral-completion"),
TaskType.COMPLETION,
instanceOf(MistralChatCompletionModel.class)
);
}
public void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInChatCompletionServiceSettingsMap() throws IOException {
testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInServiceSettingsMap(
getServiceSettingsMap("mistral-chat-completion"),
TaskType.CHAT_COMPLETION,
instanceOf(MistralChatCompletionModel.class)
);
}
private void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInServiceSettingsMap(
Map<String, Object> serviceSettingsMap,
TaskType chatCompletion,
Matcher<MistralModel> matcher
) throws IOException {
try (var service = createService()) {
var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null);
serviceSettings.put("extra_key", "value");
serviceSettingsMap.put("extra_key", "value");
var taskSettings = getEmbeddingsTaskSettingsMap();
var taskSettings = getTaskSettingsMap();
var secretSettings = getSecretSettingsMap("secret");
var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
var config = getPersistedConfigMap(serviceSettingsMap, taskSettings, secretSettings);
var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets());
var model = service.parsePersistedConfigWithSecrets("id", chatCompletion, config.config(), config.secrets());
assertThat(model, instanceOf(MistralEmbeddingsModel.class));
assertThat(model, matcher);
}
}
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbeddingTaskSettingsMap() throws IOException {
try (var service = createService()) {
var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null);
var serviceSettings = getEmbeddingsServiceSettingsMap(1024, 512);
var taskSettings = new HashMap<String, Object>();
taskSettings.put("extra_key", "value");
@ -429,27 +847,50 @@ public class MistralServiceTests extends ESTestCase {
}
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException {
testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsSecretSettingsMap(
getEmbeddingsServiceSettingsMap(1024, 512),
TaskType.TEXT_EMBEDDING,
instanceOf(MistralEmbeddingsModel.class)
);
}
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInCompletionSecretSettingsMap() throws IOException {
testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsSecretSettingsMap(
getServiceSettingsMap("mistral-completion"),
TaskType.COMPLETION,
instanceOf(MistralChatCompletionModel.class)
);
}
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInChatCompletionSecretSettingsMap() throws IOException {
testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsSecretSettingsMap(
getServiceSettingsMap("mistral-chat-completion"),
TaskType.CHAT_COMPLETION,
instanceOf(MistralChatCompletionModel.class)
);
}
private void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsSecretSettingsMap(
Map<String, Object> serviceSettingsMap,
TaskType chatCompletion,
Matcher<MistralModel> matcher
) throws IOException {
try (var service = createService()) {
var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null);
var taskSettings = getEmbeddingsTaskSettingsMap();
var taskSettings = getTaskSettingsMap();
var secretSettings = getSecretSettingsMap("secret");
secretSettings.put("extra_key", "value");
var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
var config = getPersistedConfigMap(serviceSettingsMap, taskSettings, secretSettings);
var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets());
var model = service.parsePersistedConfigWithSecrets("id", chatCompletion, config.config(), config.secrets());
assertThat(model, instanceOf(MistralEmbeddingsModel.class));
assertThat(model, matcher);
}
}
public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() throws IOException {
try (var service = createService()) {
var config = getPersistedConfigMap(
getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null),
getEmbeddingsTaskSettingsMap(),
Map.of()
);
var config = getPersistedConfigMap(getEmbeddingsServiceSettingsMap(1024, 512), getTaskSettingsMap(), Map.of());
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config());
@ -465,8 +906,8 @@ public class MistralServiceTests extends ESTestCase {
public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
try (var service = createService()) {
var config = getPersistedConfigMap(
getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null),
getEmbeddingsTaskSettingsMap(),
getEmbeddingsServiceSettingsMap(1024, 512),
getTaskSettingsMap(),
createRandomChunkingSettingsMap(),
Map.of()
);
@ -485,11 +926,7 @@ public class MistralServiceTests extends ESTestCase {
public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
try (var service = createService()) {
var config = getPersistedConfigMap(
getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null),
getEmbeddingsTaskSettingsMap(),
Map.of()
);
var config = getPersistedConfigMap(getEmbeddingsServiceSettingsMap(1024, 512), getTaskSettingsMap(), Map.of());
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config());
@ -780,7 +1217,7 @@ public class MistralServiceTests extends ESTestCase {
{
"service": "mistral",
"name": "Mistral",
"task_types": ["text_embedding"],
"task_types": ["text_embedding", "completion", "chat_completion"],
"configurations": {
"api_key": {
"description": "API Key for the provider you're connecting to.",
@ -789,7 +1226,7 @@ public class MistralServiceTests extends ESTestCase {
"sensitive": true,
"updatable": true,
"type": "str",
"supported_task_types": ["text_embedding"]
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
},
"model": {
"description": "Refer to the Mistral models documentation for the list of available text embedding models.",
@ -798,7 +1235,7 @@ public class MistralServiceTests extends ESTestCase {
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding"]
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
},
"rate_limit.requests_per_minute": {
"description": "Minimize the number of rate limit errors.",
@ -807,7 +1244,7 @@ public class MistralServiceTests extends ESTestCase {
"sensitive": false,
"updatable": false,
"type": "int",
"supported_task_types": ["text_embedding"]
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
},
"max_input_tokens": {
"description": "Allows you to specify the maximum number of tokens per input.",
@ -816,7 +1253,7 @@ public class MistralServiceTests extends ESTestCase {
"sensitive": false,
"updatable": false,
"type": "int",
"supported_task_types": ["text_embedding"]
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
}
}
}
@ -868,16 +1305,19 @@ public class MistralServiceTests extends ESTestCase {
);
}
private static Map<String, Object> getEmbeddingsServiceSettingsMap(
String model,
@Nullable Integer dimensions,
@Nullable Integer maxTokens,
@Nullable SimilarityMeasure similarityMeasure
) {
return createRequestSettingsMap(model, dimensions, maxTokens, similarityMeasure);
private Map<String, Object> getRequestConfigMap(Map<String, Object> serviceSettings, Map<String, Object> secretSettings) {
var builtServiceSettings = new HashMap<>();
builtServiceSettings.putAll(serviceSettings);
builtServiceSettings.putAll(secretSettings);
return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings));
}
private static Map<String, Object> getEmbeddingsTaskSettingsMap() {
private static Map<String, Object> getEmbeddingsServiceSettingsMap(@Nullable Integer dimensions, @Nullable Integer maxTokens) {
return createRequestSettingsMap("mistral-embed", dimensions, maxTokens, null);
}
private static Map<String, Object> getTaskSettingsMap() {
// no task settings for Mistral embeddings
return Map.of();
}
@ -886,25 +1326,4 @@ public class MistralServiceTests extends ESTestCase {
return new HashMap<>(Map.of(API_KEY_FIELD, apiKey));
}
private static final String testEmbeddingResultJson = """
{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [
0.0123,
-0.0123
]
}
],
"model": "text-embedding-ada-002-v2",
"usage": {
"prompt_tokens": 8,
"total_tokens": 8
}
}
""";
}

View File

@ -0,0 +1,155 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral;
import org.apache.http.HttpResponse;
import org.apache.http.StatusLine;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class MistralUnifiedChatCompletionResponseHandlerTests extends ESTestCase {
private final MistralUnifiedChatCompletionResponseHandler responseHandler = new MistralUnifiedChatCompletionResponseHandler(
"chat completions",
(a, b) -> mock()
);
public void testFailNotFound() throws IOException {
var responseJson = XContentHelper.stripWhitespace("""
{
"detail": "Not Found"
}
""");
var errorJson = invalidResponseJson(responseJson, 404);
assertThat(errorJson, is(XContentHelper.stripWhitespace("""
{
"error" : {
"code" : "not_found",
"message" : "Resource not found at [https://api.mistral.ai/v1/chat/completions] for request from inference entity id [id] \
status [404]. Error message: [{\\"detail\\":\\"Not Found\\"}]",
"type" : "mistral_error"
}
}""")));
}
public void testFailUnauthorized() throws IOException {
var responseJson = XContentHelper.stripWhitespace("""
{
"message": "Unauthorized",
"request_id": "a580d263fb1521778782b22104efb415"
}
""");
var errorJson = invalidResponseJson(responseJson, 401);
assertThat(errorJson, is(XContentHelper.stripWhitespace("""
{
"error" : {
"code" : "unauthorized",
"message" : "Received an authentication error status code for request from inference entity id [id] status [401]. Error \
message: [{\\"message\\":\\"Unauthorized\\",\\"request_id\\":\\"a580d263fb1521778782b22104efb415\\"}]",
"type" : "mistral_error"
}
}""")));
}
public void testFailBadRequest() throws IOException {
var responseJson = XContentHelper.stripWhitespace("""
{
"object": "error",
"message": "Invalid model: mistral-small-l2atest",
"type": "invalid_model",
"param": null,
"code": "1500"
}
""");
var errorJson = invalidResponseJson(responseJson, 400);
assertThat(errorJson, is(XContentHelper.stripWhitespace("""
{
"error" : {
"code" : "bad_request",
"message" : "Received a bad request status code for request from inference entity id [id] status [400]. Error message: \
[{\\"object\\":\\"error\\",\\"message\\":\\"Invalid model: mistral-small-l2atest\\",\\"type\\":\\"invalid_model\\",\\"par\
am\\":null,\\"code\\":\\"1500\\"}]",
"type" : "mistral_error"
}
}""")));
}
private String invalidResponseJson(String responseJson, int statusCode) throws IOException {
var exception = invalidResponse(responseJson, statusCode);
assertThat(exception, isA(RetryException.class));
assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class));
return toJson((UnifiedChatCompletionException) unwrapCause(exception));
}
private Exception invalidResponse(String responseJson, int statusCode) {
return expectThrows(
RetryException.class,
() -> responseHandler.validateResponse(
mock(),
mock(),
mockRequest(),
new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)),
true
)
);
}
private static Request mockRequest() throws URISyntaxException {
var request = mock(Request.class);
when(request.getInferenceEntityId()).thenReturn("id");
when(request.isStreaming()).thenReturn(true);
when(request.getURI()).thenReturn(new URI("https://api.mistral.ai/v1/chat/completions"));
return request;
}
private static HttpResponse mockErrorResponse(int statusCode) {
var statusLine = mock(StatusLine.class);
when(statusLine.getStatusCode()).thenReturn(statusCode);
var response = mock(HttpResponse.class);
when(response.getStatusLine()).thenReturn(statusLine);
return response;
}
private String toJson(UnifiedChatCompletionException e) throws IOException {
try (var builder = XContentFactory.jsonBuilder()) {
e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
try {
xContent.toXContent(builder, EMPTY_PARAMS);
} catch (IOException ex) {
throw new RuntimeException(ex);
}
});
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
}
}
}

View File

@ -0,0 +1,175 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral.action;
import org.apache.http.HttpHeaders;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.common.TruncatorTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
public class MistralActionCreatorTests extends ESTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
private HttpClientManager clientManager;
@Before
public void init() throws Exception {
webServer.start();
threadPool = createThreadPool(inferenceUtilityPool());
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
}
@After
public void shutdown() throws IOException {
clientManager.close();
terminate(threadPool);
webServer.close();
}
public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
{
"object": "chat.completion",
"id": "",
"created": 1745855316,
"model": "/repository",
"system_fingerprint": "3.2.3-sha-a1f3ebe",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello there, how may I assist you today?"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 8,
"completion_tokens": 50,
"total_tokens": 58
}
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
PlainActionFuture<InferenceServiceResults> listener = createChatCompletionFuture(sender, createWithEmptySettings(threadPool));
var result = listener.actionGet(TIMEOUT);
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?"))));
assertChatCompletionRequest();
}
}
public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() throws IOException {
var settings = buildSettingsWithRetryFields(
TimeValue.timeValueMillis(1),
TimeValue.timeValueMinutes(1),
TimeValue.timeValueSeconds(0)
);
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
{
"invalid_field": "unexpected"
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
PlainActionFuture<InferenceServiceResults> listener = createChatCompletionFuture(
sender,
new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator())
);
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
thrownException.getMessage(),
is("Failed to send Mistral completion request from inference entity id " + "[id]. Cause: Required [choices]")
);
assertChatCompletionRequest();
}
}
private PlainActionFuture<InferenceServiceResults> createChatCompletionFuture(Sender sender, ServiceComponents threadPool) {
var model = MistralChatCompletionModelTests.createCompletionModel("secret", "model");
model.setURI(getUrl(webServer));
var actionCreator = new MistralActionCreator(sender, threadPool);
var action = actionCreator.create(model);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
return listener;
}
private void assertChatCompletionRequest() throws IOException {
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
equalTo(XContentType.JSON.mediaTypeWithoutParameters())
);
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
assertThat(requestMap.size(), is(4));
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello"))));
assertThat(requestMap.get("model"), is("model"));
assertThat(requestMap.get("n"), is(1));
assertThat(requestMap.get("stream"), is(false));
}
}

View File

@ -0,0 +1,236 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral.action;
import org.apache.http.HttpHeaders;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockRequest;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.mistral.action.MistralActionCreator.COMPLETION_HANDLER;
import static org.elasticsearch.xpack.inference.services.mistral.action.MistralActionCreator.USER_ROLE;
import static org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests.createCompletionModel;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
public class MistralChatCompletionActionTests extends ESTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
private HttpClientManager clientManager;
@Before
public void init() throws Exception {
webServer.start();
threadPool = createThreadPool(inferenceUtilityPool());
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
}
@After
public void shutdown() throws IOException {
clientManager.close();
terminate(threadPool);
webServer.close();
}
public void testExecute_ReturnsSuccessfulResponse() throws IOException {
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
{
"id": "9d80f26810ac4e9582f927fcf0512ec7",
"object": "chat.completion",
"created": 1748596419,
"model": "mistral-small-latest",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"tool_calls": null,
"content": "result content"
},
"finish_reason": "length",
"logprobs": null
}
],
"usage": {
"prompt_tokens": 10,
"total_tokens": 11,
"completion_tokens": 1
}
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var action = createAction(getUrl(webServer), sender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var result = listener.actionGet(TIMEOUT);
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result content"))));
assertThat(webServer.requests(), hasSize(1));
MockRequest request = webServer.requests().get(0);
assertNull(request.getUri().getQuery());
assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()));
assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
var requestMap = entityAsMap(request.getBody());
assertThat(requestMap.size(), is(4));
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc"))));
assertThat(requestMap.get("model"), is("model"));
assertThat(requestMap.get("n"), is(1));
assertThat(requestMap.get("stream"), is(false));
}
}
public void testExecute_ThrowsElasticsearchException() {
var sender = mock(Sender.class);
doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
var action = createAction(getUrl(webServer), sender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(thrownException.getMessage(), is("failed"));
}
public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
var sender = mock(Sender.class);
doAnswer(invocation -> {
ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
listener.onFailure(new IllegalStateException("failed"));
return Void.TYPE;
}).when(sender).send(any(), any(), any(), any());
var action = createAction(getUrl(webServer), sender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(thrownException.getMessage(), is("Failed to send mistral chat completions request. Cause: failed"));
}
public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
{
"id": "9d80f26810ac4e9582f927fcf0512ec7",
"object": "chat.completion",
"created": 1748596419,
"model": "mistral-small-latest",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"tool_calls": null,
"content": "result content"
},
"finish_reason": "length",
"logprobs": null
}
],
"usage": {
"prompt_tokens": 10,
"total_tokens": 11,
"completion_tokens": 1
}
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var action = createAction(getUrl(webServer), sender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(thrownException.getMessage(), is("mistral chat completions only accepts 1 input"));
assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
}
}
private ExecutableAction createAction(String url, Sender sender) {
var model = createCompletionModel("secret", "model");
model.setURI(url);
var manager = new GenericRequestManager<>(
threadPool,
model,
COMPLETION_HANDLER,
inputs -> new MistralChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
ChatCompletionInput.class
);
var errorMessage = constructFailedToSendRequestMessage("mistral chat completions");
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "mistral chat completions");
}
}

View File

@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral.completion;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import java.util.List;
import static org.hamcrest.Matchers.is;
public class MistralChatCompletionModelTests extends ESTestCase {
public static MistralChatCompletionModel createCompletionModel(String apiKey, String modelId) {
return new MistralChatCompletionModel(
"id",
TaskType.COMPLETION,
"service",
new MistralChatCompletionServiceSettings(modelId, null),
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
}
public static MistralChatCompletionModel createCompletionModel(String url, String apiKey, String modelId) {
MistralChatCompletionModel mistralChatCompletionModel = new MistralChatCompletionModel(
"id",
TaskType.COMPLETION,
"service",
new MistralChatCompletionServiceSettings(modelId, null),
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
mistralChatCompletionModel.setURI(url);
return mistralChatCompletionModel;
}
public static MistralChatCompletionModel createChatCompletionModel(String apiKey, String modelId) {
return new MistralChatCompletionModel(
"id",
TaskType.CHAT_COMPLETION,
"service",
new MistralChatCompletionServiceSettings(modelId, null),
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
}
public static MistralChatCompletionModel createChatCompletionModel(String url, String apiKey, String modelId) {
MistralChatCompletionModel mistralChatCompletionModel = new MistralChatCompletionModel(
"id",
TaskType.CHAT_COMPLETION,
"service",
new MistralChatCompletionServiceSettings(modelId, null),
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
mistralChatCompletionModel.setURI(url);
return mistralChatCompletionModel;
}
public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() {
var model = createCompletionModel("api_key", "model_name");
var request = new UnifiedCompletionRequest(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
"different_model",
null,
null,
null,
null,
null,
null
);
var overriddenModel = MistralChatCompletionModel.of(model, request);
assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
}
public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() {
var model = createCompletionModel("api_key", null);
var request = new UnifiedCompletionRequest(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
"different_model",
null,
null,
null,
null,
null,
null
);
var overriddenModel = MistralChatCompletionModel.of(model, request);
assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
}
public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() {
var model = createCompletionModel("api_key", null);
var request = new UnifiedCompletionRequest(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
null,
null,
null,
null,
null,
null,
null
);
var overriddenModel = MistralChatCompletionModel.of(model, request);
assertNull(overriddenModel.getServiceSettings().modelId());
}
public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() {
var model = createCompletionModel("api_key", "model_name");
var request = new UnifiedCompletionRequest(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
null, // not overriding model
null,
null,
null,
null,
null,
null
);
var overriddenModel = MistralChatCompletionModel.of(model, request);
assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name"));
}
}

View File

@ -0,0 +1,171 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral.completion;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.mistral.MistralConstants;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
public class MistralChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase<MistralChatCompletionServiceSettings> {
public static final String MODEL_ID = "some model";
public static final int RATE_LIMIT = 2;
public void testFromMap_AllFields_Success() {
var serviceSettings = MistralChatCompletionServiceSettings.fromMap(
new HashMap<>(
Map.of(
MistralConstants.MODEL_FIELD,
MODEL_ID,
RateLimitSettings.FIELD_NAME,
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
)
),
ConfigurationParseContext.PERSISTENT
);
assertThat(
serviceSettings,
is(
new MistralChatCompletionServiceSettings(
MODEL_ID,
new RateLimitSettings(RATE_LIMIT)
)
)
);
}
public void testFromMap_MissingModelId_ThrowsException() {
var thrownException = expectThrows(
ValidationException.class,
() -> MistralChatCompletionServiceSettings.fromMap(
new HashMap<>(
Map.of(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)))
),
ConfigurationParseContext.PERSISTENT
)
);
assertThat(
thrownException.getMessage(),
containsString("Validation Failed: 1: [service_settings] does not contain the required setting [model];")
);
}
public void testFromMap_MissingRateLimit_Success() {
var serviceSettings = MistralChatCompletionServiceSettings.fromMap(
new HashMap<>(Map.of(MistralConstants.MODEL_FIELD, MODEL_ID)),
ConfigurationParseContext.PERSISTENT
);
assertThat(serviceSettings, is(new MistralChatCompletionServiceSettings(MODEL_ID, null)));
}
public void testToXContent_WritesAllValues() throws IOException {
var serviceSettings = MistralChatCompletionServiceSettings.fromMap(
new HashMap<>(
Map.of(
MistralConstants.MODEL_FIELD,
MODEL_ID,
RateLimitSettings.FIELD_NAME,
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
)
),
ConfigurationParseContext.PERSISTENT
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
serviceSettings.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
var expected = XContentHelper.stripWhitespace("""
{
"model": "some model",
"rate_limit": {
"requests_per_minute": 2
}
}
""");
assertThat(xContentResult, is(expected));
}
public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException {
var serviceSettings = MistralChatCompletionServiceSettings.fromMap(
new HashMap<>(Map.of(MistralConstants.MODEL_FIELD, MODEL_ID)),
ConfigurationParseContext.PERSISTENT
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
serviceSettings.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
var expected = XContentHelper.stripWhitespace("""
{
"model": "some model",
"rate_limit": {
"requests_per_minute": 240
}
}
""");
assertThat(xContentResult, is(expected));
}
@Override
protected Writeable.Reader<MistralChatCompletionServiceSettings> instanceReader() {
return MistralChatCompletionServiceSettings::new;
}
@Override
protected MistralChatCompletionServiceSettings createTestInstance() {
return createRandom();
}
@Override
protected MistralChatCompletionServiceSettings mutateInstance(MistralChatCompletionServiceSettings instance) throws IOException {
return randomValueOtherThan(instance, MistralChatCompletionServiceSettingsTests::createRandom);
}
@Override
protected MistralChatCompletionServiceSettings mutateInstanceForVersion(
MistralChatCompletionServiceSettings instance,
TransportVersion version
) {
return instance;
}
private static MistralChatCompletionServiceSettings createRandom() {
var modelId = randomAlphaOfLength(8);
return new MistralChatCompletionServiceSettings(modelId, RateLimitSettingsTests.createRandom());
}
public static Map<String, Object> getServiceSettingsMap(String model) {
var map = new HashMap<String, Object>();
map.put(MistralConstants.MODEL_FIELD, model);
return map;
}
}

View File

@ -12,6 +12,7 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.mistral.request.embeddings.MistralEmbeddingsRequestEntity;
import java.io.IOException;
import java.util.List;

View File

@ -16,6 +16,7 @@ import org.elasticsearch.xpack.inference.common.TruncatorTests;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.services.mistral.MistralConstants;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingModelTests;
import org.elasticsearch.xpack.inference.services.mistral.request.embeddings.MistralEmbeddingsRequest;
import java.io.IOException;
import java.util.List;

View File

@ -0,0 +1,63 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral.request.completion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
import java.io.IOException;
import java.util.ArrayList;
import static org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests.createCompletionModel;
public class MistralChatCompletionRequestEntityTests extends ESTestCase {
private static final String ROLE = "user";
public void testModelUserFieldsSerialization() throws IOException {
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
null,
null
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
messageList.add(message);
var unifiedRequest = UnifiedCompletionRequest.of(messageList);
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
MistralChatCompletionModel model = createCompletionModel("api-key", "test-endpoint");
MistralChatCompletionRequestEntity entity = new MistralChatCompletionRequestEntity(unifiedChatInput, model);
XContentBuilder builder = JsonXContent.contentBuilder();
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
String expectedJson = """
{
"messages": [
{
"content": "Hello, world!",
"role": "user"
}
],
"model": "test-endpoint",
"n": 1,
"stream": true
}
""";
assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder));
}
}

View File

@ -0,0 +1,74 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral.request.completion;
import org.apache.http.client.methods.HttpPost;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.mistral.MistralConstants;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
public class MistralChatCompletionRequestTests extends ESTestCase {
public void testCreateRequest_WithStreaming() throws IOException {
var request = createRequest("secret", randomAlphaOfLength(15), "model", true);
var httpRequest = request.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
var httpPost = (HttpPost) httpRequest.httpRequestBase();
var requestMap = entityAsMap(httpPost.getEntity().getContent());
assertThat(requestMap.get("stream"), is(true));
}
public void testTruncate_DoesNotReduceInputTextSize() throws IOException {
String input = randomAlphaOfLength(5);
var request = createRequest("secret", input, "model", true);
var truncatedRequest = request.truncate();
assertThat(request.getURI().toString(), is(MistralConstants.API_COMPLETIONS_PATH));
var httpRequest = truncatedRequest.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
var httpPost = (HttpPost) httpRequest.httpRequestBase();
var requestMap = entityAsMap(httpPost.getEntity().getContent());
assertThat(requestMap, aMapWithSize(4));
// We do not truncate for Hugging Face chat completions
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
assertThat(requestMap.get("model"), is("model"));
assertThat(requestMap.get("n"), is(1));
assertTrue((Boolean) requestMap.get("stream"));
assertNull(requestMap.get("stream_options")); // Mistral does not use stream options
}
public void testTruncationInfo_ReturnsNull() {
var request = createRequest("secret", randomAlphaOfLength(5), "model", true);
assertNull(request.getTruncationInfo());
}
public static MistralChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model) {
return createRequest(apiKey, input, model, false);
}
public static MistralChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model, boolean stream) {
var chatCompletionModel = MistralChatCompletionModelTests.createCompletionModel(apiKey, model);
return new MistralChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel);
}
}

View File

@ -0,0 +1,33 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.mistral.response;
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import java.nio.charset.StandardCharsets;
import static org.mockito.Mockito.mock;
public class MistralErrorResponseTests extends ESTestCase {
public static final String ERROR_RESPONSE_JSON = """
{
"error": "A valid user token is required"
}
""";
public void testFromResponse() {
var errorResponse = MistralErrorResponse.fromResponse(
new HttpResult(mock(HttpResponse.class), ERROR_RESPONSE_JSON.getBytes(StandardCharsets.UTF_8))
);
assertNotNull(errorResponse);
assertEquals(ERROR_RESPONSE_JSON, errorResponse.getErrorMessage());
}
}

View File

@ -100,7 +100,7 @@ public class OpenAiResponseHandlerTests extends ESTestCase {
assertFalse(retryException.shouldRetry());
assertThat(
retryException.getCause().getMessage(),
containsString("Received an unsuccessful status code for request from inference entity id [id] status [400]")
containsString("Received a bad request status code for request from inference entity id [id] status [400]")
);
assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.BAD_REQUEST));
// 400 is not flagged as a content too large when the error message is different
@ -112,7 +112,7 @@ public class OpenAiResponseHandlerTests extends ESTestCase {
assertFalse(retryException.shouldRetry());
assertThat(
retryException.getCause().getMessage(),
containsString("Received an unsuccessful status code for request from inference entity id [id] status [400]")
containsString("Received a bad request status code for request from inference entity id [id] status [400]")
);
assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.BAD_REQUEST));
// 401

View File

@ -65,6 +65,7 @@ import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
@ -1153,14 +1154,14 @@ public class OpenAiServiceTests extends ESTestCase {
});
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
assertThat(json, is("""
assertThat(json, is(String.format(Locale.ROOT, """
{\
"error":{\
"code":"model_not_found",\
"message":"Received an unsuccessful status code for request from inference entity id [id] status \
"message":"Resource not found at [%s] for request from inference entity id [id] status \
[404]. Error message: [The model `gpt-4awero` does not exist or you do not have access to it.]",\
"type":"invalid_request_error"\
}}"""));
}}""", getUrl(webServer))));
} catch (IOException ex) {
throw new RuntimeException(ex);
}