Add Mistral AI Chat Completion support to Inference Plugin (#128538)
* Add Mistral AI Chat Completion support to Inference Plugin
* Add changelog file
* Fix tests and typos
* Refactor Mistral chat completion integration and add tests
* Refactor Mistral error response handling and extract StreamingErrorResponse entity
* Add Mistral chat completion request and response tests
* Enhance error response documentation and clarify StreamingErrorResponse structure
* Refactor Mistral chat completion request handling and introduce skip stream options parameter
* Refactor MistralChatCompletionServiceSettings to include rateLimitSettings in equality checks
* Enhance MistralErrorResponse documentation with detailed error examples
* Add comment for Mistral-specific 422 validation error in OpenAiResponseHandler
* [CI] Auto commit changes from spotless
* Refactor OpenAiUnifiedChatCompletionRequestEntity to remove unused fields and streamline constructor
* Refactor UnifiedChatCompletionRequestEntity and UnifiedCompletionRequest to rename and update stream options parameter
* Refactor MistralChatCompletionRequestEntityTests to improve JSON assertion and remove unused imports
* Add unit tests for MistralUnifiedChatCompletionResponseHandler to validate error handling
* Add unit tests for MistralService
* Update expected service count in testGetServicesWithCompletionTaskType

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
This commit is contained in:
parent
3b217b19bd
commit
767d53fefa
@@ -0,0 +1,5 @@
pr: 128538
summary: "Added Mistral Chat Completion support to the Inference Plugin"
area: Machine Learning
type: enhancement
issues: []
@@ -189,6 +189,7 @@ public class TransportVersions {
     public static final TransportVersion DATA_STREAM_OPTIONS_API_REMOVE_INCLUDE_DEFAULTS_8_19 = def(8_841_0_41);
     public static final TransportVersion JOIN_ON_ALIASES_8_19 = def(8_841_0_42);
     public static final TransportVersion ILM_ADD_SKIP_SETTING_8_19 = def(8_841_0_43);
+    public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_44);
     public static final TransportVersion V_9_0_0 = def(9_000_0_09);
     public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
     public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -282,6 +283,7 @@ public class TransportVersions {
     public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES = def(9_087_0_00);
     public static final TransportVersion JOIN_ON_ALIASES = def(9_088_0_00);
     public static final TransportVersion ILM_ADD_SKIP_SETTING = def(9_089_0_00);
+    public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED = def(9_090_0_00);

     /*
      * STOP! READ THIS FIRST! No, really,
@@ -78,6 +78,14 @@ public record UnifiedCompletionRequest(
      * {@link #MAX_COMPLETION_TOKENS_FIELD}. Providers are expected to pass in their supported field name.
      */
     private static final String MAX_TOKENS_PARAM = "max_tokens_field";
+    /**
+     * Indicates whether to include the `stream_options` field in the JSON output.
+     * Some providers do not support this field. In such cases, this parameter should be set to "false",
+     * and the `stream_options` field will be excluded from the output.
+     * For providers that do support stream options, this parameter is left unset (default behavior),
+     * which implicitly includes the `stream_options` field in the output.
+     */
+    public static final String INCLUDE_STREAM_OPTIONS_PARAM = "include_stream_options";

     /**
      * Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
@@ -91,6 +99,23 @@ public record UnifiedCompletionRequest(
         );
     }

+    /**
+     * Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
+     * - Key: {@link #MODEL_FIELD}, Value: modelId
+     * - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #MAX_TOKENS_FIELD}
+     * - Key: {@link #INCLUDE_STREAM_OPTIONS_PARAM}, Value: "false"
+     */
+    public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Params params) {
+        return new DelegatingMapParams(
+            Map.ofEntries(
+                Map.entry(MODEL_ID_PARAM, modelId),
+                Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD),
+                Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString())
+            ),
+            params
+        );
+    }

     /**
      * Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
     * - Key: {@link #MODEL_FIELD}, Value: modelId
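Note (illustration, not part of the diff): the practical effect of withMaxTokensAndSkipStreamOptionsField is easiest to
see on the serialized request body. A minimal sketch of the two shapes, assuming a single user message; the model name
"mistral-small-latest" is a placeholder, not taken from this change:

    // default params: stream_options is included when streaming
    {"model": "mistral-small-latest", "messages": [{"role": "user", "content": "Hello"}], "stream": true,
     "stream_options": {"include_usage": true}}

    // with INCLUDE_STREAM_OPTIONS_PARAM set to "false" (the Mistral path): stream_options is omitted
    {"model": "mistral-small-latest", "messages": [{"role": "user", "content": "Hello"}], "stream": true}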
@@ -134,7 +134,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {

     public void testGetServicesWithCompletionTaskType() throws IOException {
         List<Object> services = getServices(TaskType.COMPLETION);
-        assertThat(services.size(), equalTo(13));
+        assertThat(services.size(), equalTo(14));

         var providers = providers(services);

@@ -154,7 +154,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
                     "openai",
                     "streaming_completion_test_service",
                     "hugging_face",
-                    "amazon_sagemaker"
+                    "amazon_sagemaker",
+                    "mistral"
                 ).toArray()
             )
         );
@@ -162,7 +163,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {

     public void testGetServicesWithChatCompletionTaskType() throws IOException {
         List<Object> services = getServices(TaskType.CHAT_COMPLETION);
-        assertThat(services.size(), equalTo(7));
+        assertThat(services.size(), equalTo(8));

         var providers = providers(services);

@@ -176,7 +177,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
                     "streaming_completion_test_service",
                     "hugging_face",
                     "amazon_sagemaker",
-                    "googlevertexai"
+                    "googlevertexai",
+                    "mistral"
                 ).toArray()
             )
         );
@@ -100,6 +100,7 @@ import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbedd
 import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
 import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings;
 import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
+import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings;
 import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
 import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
@@ -266,6 +267,13 @@ public class InferenceNamedWriteablesProvider {
                 MistralEmbeddingsServiceSettings::new
             )
         );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                ServiceSettings.class,
+                MistralChatCompletionServiceSettings.NAME,
+                MistralChatCompletionServiceSettings::new
+            )
+        );

         // note - no task settings for Mistral embeddings...
     }
@@ -21,12 +21,13 @@ import java.util.Objects;
 * A pattern is emerging in how external providers provide error responses.
 *
 * At a minimum, these return:
 * <pre><code>
 * {
 *     "error: {
 *         "message": "(error message)"
 *     }
 * }
 *
 * </code></pre>
 * Others may return additional information such as error codes specific to the service.
 *
 * This currently covers error handling for Azure AI Studio, however this pattern
@@ -0,0 +1,128 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.external.response.streaming;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;

import java.util.Objects;
import java.util.Optional;

/**
 * Represents an error response from a streaming inference service.
 * This class extends {@link ErrorResponse} and provides additional fields
 * specific to streaming errors, such as code, param, and type.
 * An example error response for a streaming service might look like:
 * <pre><code>
 * {
 *     "error": {
 *         "message": "Invalid input",
 *         "code": "400",
 *         "param": "input",
 *         "type": "invalid_request_error"
 *     }
 * }
 * </code></pre>
 * TODO: {@link ErrorMessageResponseEntity} is nearly identical to this, but doesn't parse as many fields. We must remove the duplication.
 */
public class StreamingErrorResponse extends ErrorResponse {
    private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
        "streaming_error",
        true,
        args -> Optional.ofNullable((StreamingErrorResponse) args[0])
    );
    private static final ConstructingObjectParser<StreamingErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
        "streaming_error",
        true,
        args -> new StreamingErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3])
    );

    static {
        ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message"));
        ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("code"));
        ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("param"));
        ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type"));

        ERROR_PARSER.declareObjectOrNull(
            ConstructingObjectParser.optionalConstructorArg(),
            ERROR_BODY_PARSER,
            null,
            new ParseField("error")
        );
    }

    /**
     * Standard error response parser. This can be overridden for those subclasses that
     * have a different error response structure.
     * @param response The error response as an HttpResult
     */
    public static ErrorResponse fromResponse(HttpResult response) {
        try (
            XContentParser parser = XContentFactory.xContent(XContentType.JSON)
                .createParser(XContentParserConfiguration.EMPTY, response.body())
        ) {
            return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
        } catch (Exception e) {
            // swallow the error
        }

        return ErrorResponse.UNDEFINED_ERROR;
    }

    /**
     * Standard error response parser. This can be overridden for those subclasses that
     * have a different error response structure.
     * @param response The error response as a string
     */
    public static ErrorResponse fromString(String response) {
        try (
            XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response)
        ) {
            return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
        } catch (Exception e) {
            // swallow the error
        }

        return ErrorResponse.UNDEFINED_ERROR;
    }

    @Nullable
    private final String code;
    @Nullable
    private final String param;
    private final String type;

    StreamingErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) {
        super(errorMessage);
        this.code = code;
        this.param = param;
        this.type = Objects.requireNonNull(type);
    }

    @Nullable
    public String code() {
        return code;
    }

    @Nullable
    public String param() {
        return param;
    }

    public String type() {
        return type;
    }
}
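Note (illustration, not part of the diff): given the parsers above, a streaming error payload can be turned back into a
typed object roughly as sketched below; the JSON literal is a made-up example matching the documented shape:

    // minimal sketch using the static parser shown above
    ErrorResponse parsed = StreamingErrorResponse.fromString("""
        {"error": {"message": "Invalid input", "code": "400", "param": "input", "type": "invalid_request_error"}}
        """);
    if (parsed instanceof StreamingErrorResponse streamingError) {
        // code/param/type are the extra fields parsed by this class; the message lives on the ErrorResponse base class
        String type = streamingError.type();   // "invalid_request_error"
        String code = streamingError.code();   // "400"
    }

If parsing fails, both fromResponse and fromString fall back to ErrorResponse.UNDEFINED_ERROR instead of throwing.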
@@ -15,6 +15,12 @@ import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
 import java.io.IOException;
 import java.util.Objects;

+import static org.elasticsearch.inference.UnifiedCompletionRequest.INCLUDE_STREAM_OPTIONS_PARAM;
+
+/**
+ * Represents a unified chat completion request entity.
+ * This class is used to convert the unified chat input into a format that can be serialized to XContent.
+ */
 public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {

     public static final String STREAM_FIELD = "stream";
@@ -42,7 +48,8 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
         builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);

         builder.field(STREAM_FIELD, stream);
-        if (stream) {
+        // If request is streamed and skip stream options parameter is not true, include stream options in the request.
+        if (stream && params.paramAsBoolean(INCLUDE_STREAM_OPTIONS_PARAM, true)) {
             builder.startObject(STREAM_OPTIONS_FIELD);
             builder.field(INCLUDE_USAGE_FIELD, true);
             builder.endObject();
@@ -0,0 +1,29 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.mistral;

import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponse;
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;

/**
 * Handles non-streaming completion responses for Mistral models, extending the OpenAI completion response handler.
 * This class is specifically designed to handle Mistral's error response format.
 */
public class MistralCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {

    /**
     * Constructs a MistralCompletionResponseHandler with the specified request type and response parser.
     *
     * @param requestType The type of request being handled (e.g., "mistral completions").
     * @param parseFunction The function to parse the response.
     */
    public MistralCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
        super(requestType, parseFunction, MistralErrorResponse::fromResponse);
    }
}
@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.inference.services.mistral;

 public class MistralConstants {
     public static final String API_EMBEDDINGS_PATH = "https://api.mistral.ai/v1/embeddings";
+    public static final String API_COMPLETIONS_PATH = "https://api.mistral.ai/v1/chat/completions";

     // note - there is no bounds information available from Mistral,
     // so we'll use a sane default here which is the same as Cohere's
@@ -18,4 +19,8 @@ public class MistralConstants {
     public static final String MODEL_FIELD = "model";
     public static final String INPUT_FIELD = "input";
     public static final String ENCODING_FORMAT_FIELD = "encoding_format";
+    public static final String MAX_TOKENS_FIELD = "max_tokens";
+    public static final String DETAIL_FIELD = "detail";
+    public static final String MSG_FIELD = "msg";
+    public static final String MESSAGE_FIELD = "message";
 }
@@ -22,7 +22,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
 import org.elasticsearch.xpack.inference.services.azureopenai.response.AzureMistralOpenAiExternalResponseHandler;
 import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
-import org.elasticsearch.xpack.inference.services.mistral.request.MistralEmbeddingsRequest;
+import org.elasticsearch.xpack.inference.services.mistral.request.embeddings.MistralEmbeddingsRequest;
 import org.elasticsearch.xpack.inference.services.mistral.response.MistralEmbeddingsResponseEntity;

 import java.util.List;
@@ -0,0 +1,68 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.mistral;

import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.net.URI;
import java.net.URISyntaxException;

/**
 * Represents a Mistral model that can be used for inference tasks.
 * This class extends RateLimitGroupingModel to handle rate limiting based on model and API key.
 */
public abstract class MistralModel extends RateLimitGroupingModel {
    protected String model;
    protected URI uri;
    protected RateLimitSettings rateLimitSettings;

    protected MistralModel(ModelConfigurations configurations, ModelSecrets secrets) {
        super(configurations, secrets);
    }

    protected MistralModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) {
        super(model, serviceSettings);
    }

    public String model() {
        return this.model;
    }

    public URI uri() {
        return this.uri;
    }

    @Override
    public RateLimitSettings rateLimitSettings() {
        return this.rateLimitSettings;
    }

    @Override
    public int rateLimitGroupingHash() {
        return 0;
    }

    // Needed for testing only
    public void setURI(String newUri) {
        try {
            this.uri = new URI(newUri);
        } catch (URISyntaxException e) {
            // swallow any error
        }
    }

    @Override
    public DefaultSecretSettings getSecretSettings() {
        return (DefaultSecretSettings) super.getSecretSettings();
    }
}
@@ -30,7 +30,10 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
 import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
 import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -39,8 +42,11 @@ import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.mistral.action.MistralActionCreator;
 import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
 import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest;
 import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

@@ -48,6 +54,7 @@ import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;

 import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
@@ -56,14 +63,26 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFrom
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
 import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD;

 /**
  * MistralService is an implementation of the SenderService that handles inference tasks
  * using Mistral models. It supports text embedding, completion, and chat completion tasks.
  * The service uses MistralActionCreator to create actions for executing inference requests.
  */
 public class MistralService extends SenderService {
     public static final String NAME = "mistral";

     private static final String SERVICE_NAME = "Mistral";
-    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING);
+    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
+        TaskType.TEXT_EMBEDDING,
+        TaskType.COMPLETION,
+        TaskType.CHAT_COMPLETION
+    );
+    private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new MistralUnifiedChatCompletionResponseHandler(
+        "mistral chat completions",
+        OpenAiChatCompletionResponseEntity::fromResponse
+    );

     public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
         super(factory, serviceComponents);
@@ -79,11 +98,16 @@ public class MistralService extends SenderService {
     ) {
         var actionCreator = new MistralActionCreator(getSender(), getServiceComponents());

-        if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) {
-            var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings);
-            action.execute(inputs, timeout, listener);
-        } else {
-            listener.onFailure(createInvalidModelException(model));
+        switch (model) {
+            case MistralEmbeddingsModel mistralEmbeddingsModel:
+                mistralEmbeddingsModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener);
+                break;
+            case MistralChatCompletionModel mistralChatCompletionModel:
+                mistralChatCompletionModel.accept(actionCreator).execute(inputs, timeout, listener);
+                break;
+            default:
+                listener.onFailure(createInvalidModelException(model));
+                break;
         }
     }

@@ -99,7 +123,24 @@ public class MistralService extends SenderService {
         TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
-        throwUnsupportedUnifiedCompletionOperation(NAME);
+        if (model instanceof MistralChatCompletionModel == false) {
+            listener.onFailure(createInvalidModelException(model));
+            return;
+        }
+
+        MistralChatCompletionModel mistralChatCompletionModel = (MistralChatCompletionModel) model;
+        var overriddenModel = MistralChatCompletionModel.of(mistralChatCompletionModel, inputs.getRequest());
+        var manager = new GenericRequestManager<>(
+            getServiceComponents().threadPool(),
+            overriddenModel,
+            UNIFIED_CHAT_COMPLETION_HANDLER,
+            unifiedChatInput -> new MistralChatCompletionRequest(unifiedChatInput, overriddenModel),
+            UnifiedChatInput.class
+        );
+        var errorMessage = MistralActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId());
+        var action = new SenderExecutableAction(getSender(), manager, errorMessage);
+
+        action.execute(inputs, timeout, listener);
     }

     @Override
@@ -162,7 +203,7 @@ public class MistralService extends SenderService {
             );
         }

-        MistralEmbeddingsModel model = createModel(
+        MistralModel model = createModel(
             modelId,
             taskType,
             serviceSettingsMap,
@@ -184,7 +225,7 @@ public class MistralService extends SenderService {
     }

     @Override
-    public Model parsePersistedConfigWithSecrets(
+    public MistralModel parsePersistedConfigWithSecrets(
         String modelId,
         TaskType taskType,
         Map<String, Object> config,
@@ -211,7 +252,7 @@ public class MistralService extends SenderService {
     }

     @Override
-    public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
+    public MistralModel parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
         Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

@@ -236,7 +277,12 @@ public class MistralService extends SenderService {
         return TransportVersions.V_8_15_0;
     }

-    private static MistralEmbeddingsModel createModel(
+    @Override
+    public Set<TaskType> supportedStreamingTasks() {
+        return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
+    }
+
+    private static MistralModel createModel(
         String modelId,
         TaskType taskType,
         Map<String, Object> serviceSettings,
@@ -246,23 +292,26 @@ public class MistralService extends SenderService {
         String failureMessage,
         ConfigurationParseContext context
     ) {
-        if (taskType == TaskType.TEXT_EMBEDDING) {
-            return new MistralEmbeddingsModel(
-                modelId,
-                taskType,
-                NAME,
-                serviceSettings,
-                taskSettings,
-                chunkingSettings,
-                secretSettings,
-                context
-            );
+        switch (taskType) {
+            case TEXT_EMBEDDING:
+                return new MistralEmbeddingsModel(
+                    modelId,
+                    taskType,
+                    NAME,
+                    serviceSettings,
+                    taskSettings,
+                    chunkingSettings,
+                    secretSettings,
+                    context
+                );
+            case CHAT_COMPLETION, COMPLETION:
+                return new MistralChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context);
+            default:
+                throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         }
-
-        throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
     }

-    private MistralEmbeddingsModel createModelFromPersistent(
+    private MistralModel createModelFromPersistent(
         String inferenceEntityId,
         TaskType taskType,
         Map<String, Object> serviceSettings,
@@ -284,7 +333,7 @@ public class MistralService extends SenderService {
     }

     @Override
-    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
+    public MistralEmbeddingsModel updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
         if (model instanceof MistralEmbeddingsModel embeddingsModel) {
             var serviceSettings = embeddingsModel.getServiceSettings();

@@ -304,6 +353,10 @@ public class MistralService extends SenderService {
         }
     }

+    /**
+     * Configuration class for the Mistral inference service.
+     * It provides the settings and configurations required for the service.
+     */
     public static class Configuration {
         public static InferenceServiceConfiguration get() {
             return configuration.getOrCompute();
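Note (illustration, not part of the diff): with these changes the "mistral" service accepts the completion and
chat_completion task types in addition to text_embedding. A hedged sketch of creating such an endpoint through the
inference API; the endpoint name, model name, and API key are placeholders, and the exact request shape should be
checked against the inference API documentation:

    PUT _inference/chat_completion/mistral-chat
    {
      "service": "mistral",
      "service_settings": {
        "model": "mistral-small-latest",
        "api_key": "<api key>"
      }
    }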
@@ -0,0 +1,51 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.mistral;

import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponse;
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;

import java.util.Locale;

/**
 * Handles streaming chat completion responses and error parsing for Mistral inference endpoints.
 * Adapts the OpenAI handler to support Mistral's error schema.
 */
public class MistralUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {

    private static final String MISTRAL_ERROR = "mistral_error";

    public MistralUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
        super(requestType, parseFunction, MistralErrorResponse::fromResponse);
    }

    @Override
    protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
        assert request.isStreaming() : "Only streaming requests support this format";
        var responseStatusCode = result.response().getStatusLine().getStatusCode();
        if (request.isStreaming()) {
            var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
            var restStatus = toRestStatus(responseStatusCode);
            return errorResponse instanceof MistralErrorResponse
                ? new UnifiedChatCompletionException(restStatus, errorMessage, MISTRAL_ERROR, restStatus.name().toLowerCase(Locale.ROOT))
                : new UnifiedChatCompletionException(
                    restStatus,
                    errorMessage,
                    createErrorType(errorResponse),
                    restStatus.name().toLowerCase(Locale.ROOT)
                );
        } else {
            return super.buildError(message, request, result, errorResponse);
        }
    }
}
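Note (illustration, not part of the diff): given the buildError override above, a streamed Mistral failure surfaces as
a UnifiedChatCompletionException roughly like the following; the status and message are placeholders, and the "not_found"
code simply mirrors restStatus.name().toLowerCase(Locale.ROOT) for a 404:

    // errorResponse instanceof MistralErrorResponse, HTTP 404 from the provider
    new UnifiedChatCompletionException(RestStatus.NOT_FOUND, "<error message>", "mistral_error", "not_found");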
@@ -7,19 +7,41 @@
 package org.elasticsearch.xpack.inference.services.mistral.action;

 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
 import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
 import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
 import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.mistral.MistralCompletionResponseHandler;
 import org.elasticsearch.xpack.inference.services.mistral.MistralEmbeddingsRequestManager;
 import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
 import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest;
 import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;

 import java.util.Map;
 import java.util.Objects;

-import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
+import static org.elasticsearch.core.Strings.format;

 /**
  * MistralActionCreator is responsible for creating executable actions for Mistral models.
  * It implements the MistralActionVisitor interface to provide specific implementations
  * for different types of Mistral models.
  */
 public class MistralActionCreator implements MistralActionVisitor {

     public static final String COMPLETION_ERROR_PREFIX = "Mistral completions";
     static final String USER_ROLE = "user";
     static final ResponseHandler COMPLETION_HANDLER = new MistralCompletionResponseHandler(
         "mistral completions",
         OpenAiChatCompletionResponseEntity::fromResponse
     );
     private final Sender sender;
     private final ServiceComponents serviceComponents;

@@ -35,7 +57,32 @@ public class MistralActionCreator implements MistralActionVisitor {
             serviceComponents.truncator(),
             serviceComponents.threadPool()
         );
-        var errorMessage = constructFailedToSendRequestMessage("Mistral embeddings");
+        var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, embeddingsModel.getInferenceEntityId());
         return new SenderExecutableAction(sender, requestManager, errorMessage);
     }

+    @Override
+    public ExecutableAction create(MistralChatCompletionModel chatCompletionModel) {
+        var manager = new GenericRequestManager<>(
+            serviceComponents.threadPool(),
+            chatCompletionModel,
+            COMPLETION_HANDLER,
+            inputs -> new MistralChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), chatCompletionModel),
+            ChatCompletionInput.class
+        );
+
+        var errorMessage = buildErrorMessage(TaskType.COMPLETION, chatCompletionModel.getInferenceEntityId());
+        return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX);
+    }
+
+    /**
+     * Builds an error message for Mistral actions.
+     *
+     * @param requestType The type of request (e.g., TEXT_EMBEDDING, COMPLETION).
+     * @param inferenceId The ID of the inference entity.
+     * @return A formatted error message.
+     */
+    public static String buildErrorMessage(TaskType requestType, String inferenceId) {
+        return format("Failed to send Mistral %s request from inference entity id [%s]", requestType.toString(), inferenceId);
+    }
 }
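Note (illustration, not part of the diff): buildErrorMessage simply interpolates the task type and endpoint id, so a
failed call against a hypothetical endpoint id "mistral-chat" produces a message of the form:

    "Failed to send Mistral <task type> request from inference entity id [mistral-chat]"

where <task type> is whatever TaskType.toString() returns for the request's task type.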
@@ -8,10 +8,33 @@
 package org.elasticsearch.xpack.inference.services.mistral.action;

 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
 import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;

 import java.util.Map;

+/**
+ * Interface for creating {@link ExecutableAction} instances for Mistral models.
+ * <p>
+ * This interface is used to create {@link ExecutableAction} instances for different types of Mistral models, such as
+ * {@link MistralEmbeddingsModel} and {@link MistralChatCompletionModel}.
+ */
 public interface MistralActionVisitor {

+    /**
+     * Creates an {@link ExecutableAction} for the given {@link MistralEmbeddingsModel}.
+     *
+     * @param embeddingsModel The model to create the action for.
+     * @param taskSettings The task settings to use.
+     * @return An {@link ExecutableAction} for the given model.
+     */
     ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings);

+    /**
+     * Creates an {@link ExecutableAction} for the given {@link MistralChatCompletionModel}.
+     *
+     * @param chatCompletionModel The model to create the action for.
+     * @return An {@link ExecutableAction} for the given model.
+     */
+    ExecutableAction create(MistralChatCompletionModel chatCompletionModel);
 }
@@ -0,0 +1,137 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.mistral.completion;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.mistral.MistralModel;
import org.elasticsearch.xpack.inference.services.mistral.action.MistralActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_COMPLETIONS_PATH;

/**
 * Represents a Mistral chat completion model.
 * This class extends RateLimitGroupingModel to handle rate limiting based on model and API key.
 */
public class MistralChatCompletionModel extends MistralModel {

    /**
     * Constructor for MistralChatCompletionModel.
     *
     * @param inferenceEntityId The unique identifier for the inference entity.
     * @param taskType The type of task this model is designed for.
     * @param service The name of the service this model belongs to.
     * @param serviceSettings The settings specific to the Mistral chat completion service.
     * @param secrets The secrets required for accessing the service.
     * @param context The context for parsing configuration settings.
     */
    public MistralChatCompletionModel(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        Map<String, Object> serviceSettings,
        @Nullable Map<String, Object> secrets,
        ConfigurationParseContext context
    ) {
        this(
            inferenceEntityId,
            taskType,
            service,
            MistralChatCompletionServiceSettings.fromMap(serviceSettings, context),
            DefaultSecretSettings.fromMap(secrets)
        );
    }

    /**
     * Creates a new MistralChatCompletionModel with overridden service settings.
     *
     * @param model The original MistralChatCompletionModel.
     * @param request The UnifiedCompletionRequest containing the model override.
     * @return A new MistralChatCompletionModel with the overridden model ID.
     */
    public static MistralChatCompletionModel of(MistralChatCompletionModel model, UnifiedCompletionRequest request) {
        if (request.model() == null) {
            // If no model is specified in the request, return the original model
            return model;
        }

        var originalModelServiceSettings = model.getServiceSettings();
        var overriddenServiceSettings = new MistralChatCompletionServiceSettings(
            request.model(),
            originalModelServiceSettings.rateLimitSettings()
        );

        return new MistralChatCompletionModel(
            model.getInferenceEntityId(),
            model.getTaskType(),
            model.getConfigurations().getService(),
            overriddenServiceSettings,
            model.getSecretSettings()
        );
    }

    public MistralChatCompletionModel(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        MistralChatCompletionServiceSettings serviceSettings,
        DefaultSecretSettings secrets
    ) {
        super(
            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings()),
            new ModelSecrets(secrets)
        );
        setPropertiesFromServiceSettings(serviceSettings);
    }

    private void setPropertiesFromServiceSettings(MistralChatCompletionServiceSettings serviceSettings) {
        this.model = serviceSettings.modelId();
        this.rateLimitSettings = serviceSettings.rateLimitSettings();
        setEndpointUrl();
    }

    @Override
    public int rateLimitGroupingHash() {
        return Objects.hash(model, getSecretSettings().apiKey());
    }

    private void setEndpointUrl() {
        try {
            this.uri = new URI(API_COMPLETIONS_PATH);
        } catch (URISyntaxException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public MistralChatCompletionServiceSettings getServiceSettings() {
        return (MistralChatCompletionServiceSettings) super.getServiceSettings();
    }

    /**
     * Accepts a visitor to create an executable action for this model.
     *
     * @param creator The visitor that creates the executable action.
     * @return An ExecutableAction that can be executed.
     */
    public ExecutableAction accept(MistralActionVisitor creator) {
        return creator.create(this);
    }
}
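Note (illustration, not part of the diff): the static of(...) factory above is what lets a per-request "model" value
override the one stored on the endpoint. A minimal sketch, where endpointModel and unifiedRequest are hypothetical
variables:

    // If the unified request carries a model override, the returned copy targets that model while keeping the
    // endpoint's rate limit settings and API key; otherwise the original model instance is returned unchanged.
    MistralChatCompletionModel effectiveModel = MistralChatCompletionModel.of(endpointModel, unifiedRequest);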
@@ -0,0 +1,129 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.mistral.completion;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD;

/**
 * Represents the settings for the Mistral chat completion service.
 * This class encapsulates the model ID and rate limit settings for the Mistral chat completion service.
 */
public class MistralChatCompletionServiceSettings extends FilteredXContentObject implements ServiceSettings {
    public static final String NAME = "mistral_completions_service_settings";

    private final String modelId;
    private final RateLimitSettings rateLimitSettings;

    // default for Mistral is 5 requests / sec
    // setting this to 240 (4 requests / sec) is a sane default for us
    protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(240);

    public static MistralChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
        ValidationException validationException = new ValidationException();

        String model = extractRequiredString(map, MODEL_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException);
        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
            map,
            DEFAULT_RATE_LIMIT_SETTINGS,
            validationException,
            MistralService.NAME,
            context
        );

        if (validationException.validationErrors().isEmpty() == false) {
            throw validationException;
        }

        return new MistralChatCompletionServiceSettings(model, rateLimitSettings);
    }

    public MistralChatCompletionServiceSettings(StreamInput in) throws IOException {
        this.modelId = in.readString();
        this.rateLimitSettings = new RateLimitSettings(in);
    }

    public MistralChatCompletionServiceSettings(String modelId, @Nullable RateLimitSettings rateLimitSettings) {
        this.modelId = modelId;
        this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED;
    }

    @Override
    public String modelId() {
        return this.modelId;
    }

    public RateLimitSettings rateLimitSettings() {
        return this.rateLimitSettings;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(modelId);
        rateLimitSettings.writeTo(out);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        this.toXContentFragmentOfExposedFields(builder, params);
        builder.endObject();
        return builder;
    }

    @Override
    protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
        builder.field(MODEL_FIELD, this.modelId);

        rateLimitSettings.toXContent(builder, params);

        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        MistralChatCompletionServiceSettings that = (MistralChatCompletionServiceSettings) o;
        return Objects.equals(modelId, that.modelId) && Objects.equals(rateLimitSettings, that.rateLimitSettings);
    }

    @Override
    public int hashCode() {
        return Objects.hash(modelId, rateLimitSettings);
    }

}
@@ -10,16 +10,15 @@ package org.elasticsearch.xpack.inference.services.mistral.embeddings;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
-import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.mistral.MistralModel;
 import org.elasticsearch.xpack.inference.services.mistral.action.MistralActionVisitor;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
-import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

 import java.net.URI;
 import java.net.URISyntaxException;
@@ -27,10 +26,11 @@ import java.util.Map;

 import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_EMBEDDINGS_PATH;

-public class MistralEmbeddingsModel extends Model {
-    protected String model;
-    protected URI uri;
-    protected RateLimitSettings rateLimitSettings;
+/**
+ * Represents a Mistral embeddings model.
+ * This class extends MistralModel to handle embeddings-specific settings and actions.
+ */
+public class MistralEmbeddingsModel extends MistralModel {

     public MistralEmbeddingsModel(
         String inferenceEntityId,
@@ -58,6 +58,20 @@ public class MistralEmbeddingsModel extends Model {
         setPropertiesFromServiceSettings(serviceSettings);
     }

+    private void setPropertiesFromServiceSettings(MistralEmbeddingsServiceSettings serviceSettings) {
+        this.model = serviceSettings.modelId();
+        this.rateLimitSettings = serviceSettings.rateLimitSettings();
+        setEndpointUrl();
+    }
+
+    private void setEndpointUrl() {
+        try {
+            this.uri = new URI(API_EMBEDDINGS_PATH);
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
     public MistralEmbeddingsModel(
         String inferenceEntityId,
         TaskType taskType,
@@ -74,51 +88,11 @@ public class MistralEmbeddingsModel extends Model {
         setPropertiesFromServiceSettings(serviceSettings);
     }

-    private void setPropertiesFromServiceSettings(MistralEmbeddingsServiceSettings serviceSettings) {
-        this.model = serviceSettings.modelId();
-        this.rateLimitSettings = serviceSettings.rateLimitSettings();
-        setEndpointUrl();
-    }
-
     @Override
     public MistralEmbeddingsServiceSettings getServiceSettings() {
         return (MistralEmbeddingsServiceSettings) super.getServiceSettings();
     }

-    public String model() {
-        return this.model;
-    }
-
-    public URI uri() {
-        return this.uri;
-    }
-
-    public RateLimitSettings rateLimitSettings() {
-        return this.rateLimitSettings;
-    }
-
-    private void setEndpointUrl() {
-        try {
-            this.uri = new URI(API_EMBEDDINGS_PATH);
-        } catch (URISyntaxException e) {
-            throw new RuntimeException(e);
-        }
-    }
-
-    // Needed for testing only
-    public void setURI(String newUri) {
-        try {
-            this.uri = new URI(newUri);
-        } catch (URISyntaxException e) {
-            // swallow any error
-        }
-    }
-
-    @Override
-    public DefaultSecretSettings getSecretSettings() {
-        return (DefaultSecretSettings) super.getSecretSettings();
-    }
-
     public ExecutableAction accept(MistralActionVisitor creator, Map<String, Object> taskSettings) {
         return creator.create(this, taskSettings);
     }
@@ -0,0 +1,82 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.mistral.request.completion;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

/**
 * Mistral Unified Chat Completion Request
 * This class is responsible for creating a request to the Mistral chat completion model.
 * It constructs an HTTP POST request with the necessary headers and body content.
 */
public class MistralChatCompletionRequest implements Request {

    private final MistralChatCompletionModel model;
    private final UnifiedChatInput chatInput;

    public MistralChatCompletionRequest(UnifiedChatInput chatInput, MistralChatCompletionModel model) {
        this.chatInput = Objects.requireNonNull(chatInput);
        this.model = Objects.requireNonNull(model);
    }

    @Override
    public HttpRequest createHttpRequest() {
        HttpPost httpPost = new HttpPost(model.uri());

        ByteArrayEntity byteEntity = new ByteArrayEntity(
            Strings.toString(new MistralChatCompletionRequestEntity(chatInput, model)).getBytes(StandardCharsets.UTF_8)
        );
        httpPost.setEntity(byteEntity);

        httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
        httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey()));

        return new HttpRequest(httpPost, getInferenceEntityId());
    }

    @Override
    public URI getURI() {
        return model.uri();
    }

    @Override
    public Request truncate() {
        // No truncation for Mistral chat completions
        return this;
    }

    @Override
    public boolean[] getTruncationInfo() {
        // No truncation for Mistral chat completions
        return null;
    }

    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }

    @Override
    public boolean isStreaming() {
        return chatInput.stream();
    }
}
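Note (illustration, not part of the diff): createHttpRequest above always targets the fixed completions URL from
MistralConstants and authenticates with a bearer token, so the outgoing call looks essentially like the following
sketch; the API key and body values are placeholders:

    POST https://api.mistral.ai/v1/chat/completions
    Content-Type: application/json
    Authorization: Bearer <api key>

    {"model": "<model id>", "messages": [...], "max_tokens": <n>, "stream": true}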
@@ -0,0 +1,44 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.mistral.request.completion;

import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;

import java.io.IOException;
import java.util.Objects;

/**
 * MistralChatCompletionRequestEntity is responsible for creating the request entity for Mistral chat completion.
 * It implements ToXContentObject to allow serialization to XContent format.
 */
public class MistralChatCompletionRequestEntity implements ToXContentObject {

    private final MistralChatCompletionModel model;
    private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;

    public MistralChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, MistralChatCompletionModel model) {
        this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
        this.model = Objects.requireNonNull(model);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        unifiedRequestEntity.toXContent(
            builder,
            UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(model.getServiceSettings().modelId(), params)
        );
        builder.endObject();
        return builder;
    }
}
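As a rough illustration of the kind of body the entity above produces (the exact field set comes from UnifiedChatCompletionRequestEntity and is not reproduced here), the sketch below hand-builds a minimal payload with XContentBuilder. The field names `model`, `messages`, and `stream` mirror the OpenAI-style schema and are shown only as an approximation; per the commit description, `stream_options` is deliberately omitted for Mistral.

import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;

public class ApproximatePayloadSketch implements ToXContentObject {
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field("model", "mistral-small-latest");  // stand-in for the model id from service settings
        builder.startArray("messages");
        builder.startObject();
        builder.field("role", "user");
        builder.field("content", "hello");
        builder.endObject();
        builder.endArray();
        builder.field("stream", true);                   // no stream_options field is written for Mistral
        builder.endObject();
        return builder;
    }

    public static void main(String[] args) {
        // Strings.toString drives toXContent with a JSON builder, as the request class above does.
        System.out.println(Strings.toString(new ApproximatePayloadSketch()));
    }
}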
@@ -5,7 +5,7 @@
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.mistral.request;
package org.elasticsearch.xpack.inference.services.mistral.request.embeddings;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
@@ -5,7 +5,7 @@
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.mistral.request;
package org.elasticsearch.xpack.inference.services.mistral.request.embeddings;

import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
@@ -0,0 +1,92 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.mistral.response;

import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;

import java.nio.charset.StandardCharsets;

/**
 * Represents an error response entity for Mistral inference services.
 * This class extends ErrorResponse and provides a method to create an instance
 * from an HttpResult, attempting to read the body as a UTF-8 string.
 * An example error response for Not Found error would look like:
 * <pre><code>
 * {
 *     "detail": "Not Found"
 * }
 * </code></pre>
 * An example error response for Bad Request error would look like:
 * <pre><code>
 * {
 *     "object": "error",
 *     "message": "Invalid model: wrong-model-name",
 *     "type": "invalid_model",
 *     "param": null,
 *     "code": "1500"
 * }
 * </code></pre>
 * An example error response for Unauthorized error would look like:
 * <pre><code>
 * {
 *     "message": "Unauthorized",
 *     "request_id": "ad95a2165083f20b490f8f78a14bb104"
 * }
 * </code></pre>
 * An example error response for Unprocessable Entity error would look like:
 * <pre><code>
 * {
 *     "object": "error",
 *     "message": {
 *         "detail": [
 *             {
 *                 "type": "greater_than_equal",
 *                 "loc": [
 *                     "body",
 *                     "max_tokens"
 *                 ],
 *                 "msg": "Input should be greater than or equal to 0",
 *                 "input": -10,
 *                 "ctx": {
 *                     "ge": 0
 *                 }
 *             }
 *         ]
 *     },
 *     "type": "invalid_request_error",
 *     "param": null,
 *     "code": null
 * }
 * </code></pre>
 */
public class MistralErrorResponse extends ErrorResponse {

    public MistralErrorResponse(String message) {
        super(message);
    }

    /**
     * Creates an ErrorResponse from the given HttpResult.
     * Attempts to read the body as a UTF-8 string and constructs a MistralErrorResponse.
     * If reading fails, returns a generic UNDEFINED_ERROR.
     *
     * @param response the HttpResult containing the error response
     * @return an ErrorResponse instance
     */
    public static ErrorResponse fromResponse(HttpResult response) {
        try {
            String errorMessage = new String(response.body(), StandardCharsets.UTF_8);
            return new MistralErrorResponse(errorMessage);
        } catch (Exception e) {
            // swallow the error
        }

        return ErrorResponse.UNDEFINED_ERROR;
    }
}
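The fromResponse method above boils down to a "best-effort read of the raw body, fall back to a sentinel" pattern. A standalone distillation of that pattern, independent of the ErrorResponse and HttpResult types (the sentinel string below is a placeholder, not the real UNDEFINED_ERROR value), might look like this:

import java.nio.charset.StandardCharsets;

public final class BodyToMessageSketch {
    private static final String UNDEFINED_ERROR = "unknown error";  // placeholder sentinel

    // Best-effort: decode the raw HTTP body as UTF-8; never let a decoding problem mask the original failure.
    public static String errorMessageFrom(byte[] responseBody) {
        try {
            return new String(responseBody, StandardCharsets.UTF_8);
        } catch (Exception e) {
            // swallow the error, mirroring the handler above
        }
        return UNDEFINED_ERROR;
    }
}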
@@ -24,6 +24,7 @@ import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentE
import java.util.concurrent.Flow;
import java.util.function.Function;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown;

public class OpenAiResponseHandler extends BaseResponseHandler {

@@ -40,6 +41,7 @@ public class OpenAiResponseHandler extends BaseResponseHandler {
    static final String REMAINING_TOKENS = "x-ratelimit-remaining-tokens";

    static final String CONTENT_TOO_LARGE_MESSAGE = "Please reduce your prompt; or completion length.";
    static final String VALIDATION_ERROR_MESSAGE = "Received an input validation error response";

    static final String OPENAI_SERVER_BUSY = "Received a server busy error status code";

@@ -86,11 +88,23 @@ public class OpenAiResponseHandler extends BaseResponseHandler {
            throw new RetryException(false, buildError(AUTHENTICATION, request, result));
        } else if (statusCode >= 300 && statusCode < 400) {
            throw new RetryException(false, buildError(REDIRECTION, request, result));
        } else if (statusCode == 422) {
            // OpenAI does not return 422 at the time of writing, but Mistral does and follows most of OpenAI's format.
            // TODO: Revisit this in the future to decouple OpenAI and Mistral error handling.
            throw new RetryException(false, buildError(VALIDATION_ERROR_MESSAGE, request, result));
        } else if (statusCode == 400) {
            throw new RetryException(false, buildError(BAD_REQUEST, request, result));
        } else if (statusCode == 404) {
            throw new RetryException(false, buildError(resourceNotFoundError(request), request, result));
        } else {
            throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
        }
    }

    private static String resourceNotFoundError(Request request) {
        return format("Resource not found at [%s]", request.getURI());
    }

    protected RetryException buildExceptionHandling429(Request request, HttpResult result) {
        return new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result));
    }
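The status-code branching added above can be summarized as the small mapping below. This is an illustrative reduction written for this write-up, not the handler itself; the returned strings are placeholders rather than the real error-message constants.

public final class StatusCodeMappingSketch {
    // Illustrative mapping of the non-retryable status codes handled above.
    public static String classify(int statusCode) {
        if (statusCode == 401) {
            return "authentication error";
        } else if (statusCode >= 300 && statusCode < 400) {
            return "redirection";
        } else if (statusCode == 422) {
            return "input validation error";  // returned by Mistral; OpenAI does not use 422 at the time of writing
        } else if (statusCode == 400) {
            return "bad request";
        } else if (statusCode == 404) {
            return "resource not found";
        }
        return "unsuccessful";
    }
}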
@@ -7,15 +7,8 @@
package org.elasticsearch.xpack.inference.services.openai;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;

@@ -24,18 +17,21 @@ import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.StreamingErrorResponse;

import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Flow;
import java.util.function.Function;

import static org.elasticsearch.core.Strings.format;

/**
 * Handles streaming chat completion responses and error parsing for OpenAI inference endpoints.
 * This handler is designed to work with the unified OpenAI chat completion API.
 */
public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
    public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
        super(requestType, parseFunction, OpenAiErrorResponse::fromResponse);
        super(requestType, parseFunction, StreamingErrorResponse::fromResponse);
    }

    public OpenAiUnifiedChatCompletionResponseHandler(

@@ -62,7 +58,7 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
        if (request.isStreaming()) {
            var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
            var restStatus = toRestStatus(responseStatusCode);
            return errorResponse instanceof OpenAiErrorResponse oer
            return errorResponse instanceof StreamingErrorResponse oer
                ? new UnifiedChatCompletionException(restStatus, errorMessage, oer.type(), oer.code(), oer.param())
                : new UnifiedChatCompletionException(
                    restStatus,

@@ -84,8 +80,8 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
    }

    public static UnifiedChatCompletionException buildMidStreamError(String inferenceEntityId, String message, Exception e) {
        var errorResponse = OpenAiErrorResponse.fromString(message);
        if (errorResponse instanceof OpenAiErrorResponse oer) {
        var errorResponse = StreamingErrorResponse.fromString(message);
        if (errorResponse instanceof StreamingErrorResponse oer) {
            return new UnifiedChatCompletionException(
                RestStatus.INTERNAL_SERVER_ERROR,
                format(

@@ -109,85 +105,4 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
            );
        }
    }

    private static class OpenAiErrorResponse extends ErrorResponse {
        private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
            "open_ai_error",
            true,
            args -> Optional.ofNullable((OpenAiErrorResponse) args[0])
        );
        private static final ConstructingObjectParser<OpenAiErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
            "open_ai_error",
            true,
            args -> new OpenAiErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3])
        );

        static {
            ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message"));
            ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("code"));
            ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("param"));
            ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type"));

            ERROR_PARSER.declareObjectOrNull(
                ConstructingObjectParser.optionalConstructorArg(),
                ERROR_BODY_PARSER,
                null,
                new ParseField("error")
            );
        }

        private static ErrorResponse fromResponse(HttpResult response) {
            try (
                XContentParser parser = XContentFactory.xContent(XContentType.JSON)
                    .createParser(XContentParserConfiguration.EMPTY, response.body())
            ) {
                return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
            } catch (Exception e) {
                // swallow the error
            }

            return ErrorResponse.UNDEFINED_ERROR;
        }

        private static ErrorResponse fromString(String response) {
            try (
                XContentParser parser = XContentFactory.xContent(XContentType.JSON)
                    .createParser(XContentParserConfiguration.EMPTY, response)
            ) {
                return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
            } catch (Exception e) {
                // swallow the error
            }

            return ErrorResponse.UNDEFINED_ERROR;
        }

        @Nullable
        private final String code;
        @Nullable
        private final String param;
        private final String type;

        OpenAiErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) {
            super(errorMessage);
            this.code = code;
            this.param = param;
            this.type = Objects.requireNonNull(type);
        }

        @Nullable
        public String code() {
            return code;
        }

        @Nullable
        public String param() {
            return param;
        }

        public String type() {
            return type;
        }
    }

}
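The switch from the private OpenAiErrorResponse class to the shared StreamingErrorResponse keeps the calling pattern the same: the handler narrows a generic error object via an instanceof pattern and only then reads the structured type, code, and param fields. A minimal sketch of that narrowing, using stand-in types written for this write-up (the real classes are ErrorResponse and StreamingErrorResponse elsewhere in the codebase):

public final class TypedErrorNarrowingSketch {
    // Stand-in types for illustration only.
    static class BasicError {
        final String message;
        BasicError(String message) { this.message = message; }
    }

    static class TypedError extends BasicError {
        final String type;
        final String code;
        TypedError(String message, String type, String code) {
            super(message);
            this.type = type;
            this.code = code;
        }
    }

    static String describe(BasicError error) {
        // Same shape as the handler above: prefer the typed fields when they are available.
        return error instanceof TypedError typed
            ? typed.type + " (" + typed.code + "): " + typed.message
            : "unknown error: " + error.message;
    }
}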
@@ -21,15 +21,10 @@ import java.util.Objects;
public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObject {

    public static final String USER_FIELD = "user";
    private static final String MODEL_FIELD = "model";
    private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";

    private final UnifiedChatInput unifiedChatInput;
    private final OpenAiChatCompletionModel model;
    private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;

    public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) {
        this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
        this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
        this.model = Objects.requireNonNull(model);
    }
|
|
|
@ -54,6 +54,7 @@ import static org.elasticsearch.common.Strings.format;
|
|||
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
|
||||
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
@ -273,8 +274,10 @@ public class DeepSeekServiceTests extends ESTestCase {
|
|||
assertThat(
|
||||
e.getMessage(),
|
||||
equalTo(
|
||||
"Received an unsuccessful status code for request from inference entity id [inference-id] status"
|
||||
+ " [404]. Error message: [The model `deepseek-not-chat` does not exist or you do not have access to it.]"
|
||||
"Resource not found at ["
|
||||
+ getUrl(webServer)
|
||||
+ "] for request from inference entity id [inference-id]"
|
||||
+ " status [404]. Error message: [The model `deepseek-not-chat` does not exist or you do not have access to it.]"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -66,6 +66,7 @@ import java.io.IOException;
|
|||
import java.util.EnumSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
@ -377,15 +378,14 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
|||
}
|
||||
});
|
||||
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
|
||||
|
||||
assertThat(json, is("""
|
||||
assertThat(json, is(String.format(Locale.ROOT, """
|
||||
{\
|
||||
"error":{\
|
||||
"code":"not_found",\
|
||||
"message":"Received an unsuccessful status code for request from inference entity id [id] status \
|
||||
"message":"Resource not found at [%s] for request from inference entity id [id] status \
|
||||
[404]. Error message: [Model not found.]",\
|
||||
"type":"hugging_face_error"\
|
||||
}}"""));
|
||||
}}""", getUrl(webServer))));
|
||||
} catch (IOException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import org.apache.http.HttpHeaders;
|
|||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.ActionTestUtils;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
|
@ -29,39 +30,53 @@ import org.elasticsearch.inference.Model;
|
|||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.SimilarityMeasure;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.test.http.MockResponse;
|
||||
import org.elasticsearch.test.http.MockWebServer;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xcontent.ToXContent;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
|
||||
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
||||
import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
|
||||
import org.hamcrest.CoreMatchers;
|
||||
import org.hamcrest.Matcher;
|
||||
import org.hamcrest.Matchers;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.EnumSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
|
||||
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
|
||||
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
|
||||
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
|
||||
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
|
||||
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
|
||||
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||
|
@ -72,12 +87,14 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
|||
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
||||
import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_KEY_FIELD;
|
||||
import static org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettingsTests.getServiceSettingsMap;
|
||||
import static org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettingsTests.createRequestSettingsMap;
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.isA;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
@ -118,11 +135,7 @@ public class MistralServiceTests extends ESTestCase {
|
|||
service.parseRequestConfig(
|
||||
"id",
|
||||
TaskType.TEXT_EMBEDDING,
|
||||
getRequestConfigMap(
|
||||
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
|
||||
getEmbeddingsTaskSettingsMap(),
|
||||
getSecretSettingsMap("secret")
|
||||
),
|
||||
getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), getSecretSettingsMap("secret")),
|
||||
modelVerificationListener
|
||||
);
|
||||
}
|
||||
|
@ -144,8 +157,8 @@ public class MistralServiceTests extends ESTestCase {
|
|||
"id",
|
||||
TaskType.TEXT_EMBEDDING,
|
||||
getRequestConfigMap(
|
||||
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
|
||||
getEmbeddingsTaskSettingsMap(),
|
||||
getEmbeddingsServiceSettingsMap(null, null),
|
||||
getTaskSettingsMap(),
|
||||
createRandomChunkingSettingsMap(),
|
||||
getSecretSettingsMap("secret")
|
||||
),
|
||||
|
@ -169,16 +182,289 @@ public class MistralServiceTests extends ESTestCase {
|
|||
service.parseRequestConfig(
|
||||
"id",
|
||||
TaskType.TEXT_EMBEDDING,
|
||||
getRequestConfigMap(
|
||||
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
|
||||
getEmbeddingsTaskSettingsMap(),
|
||||
getSecretSettingsMap("secret")
|
||||
),
|
||||
getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), getSecretSettingsMap("secret")),
|
||||
modelVerificationListener
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_CreatesChatCompletionsModel() throws IOException {
|
||||
var model = "model";
|
||||
var secret = "secret";
|
||||
|
||||
try (var service = createService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(m -> {
|
||||
assertThat(m, instanceOf(MistralChatCompletionModel.class));
|
||||
|
||||
var completionsModel = (MistralChatCompletionModel) m;
|
||||
|
||||
assertThat(completionsModel.getServiceSettings().modelId(), is(model));
|
||||
assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret));
|
||||
|
||||
}, exception -> fail("Unexpected exception: " + exception));
|
||||
|
||||
service.parseRequestConfig(
|
||||
"id",
|
||||
TaskType.COMPLETION,
|
||||
getRequestConfigMap(getServiceSettingsMap(model), getSecretSettingsMap(secret)),
|
||||
modelVerificationListener
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsException_WithoutModelId() throws IOException {
|
||||
var secret = "secret";
|
||||
|
||||
try (var service = createService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(m -> {
|
||||
assertThat(m, instanceOf(MistralChatCompletionModel.class));
|
||||
|
||||
var completionsModel = (MistralChatCompletionModel) m;
|
||||
|
||||
assertNull(completionsModel.getServiceSettings().modelId());
|
||||
assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret));
|
||||
|
||||
}, exception -> {
|
||||
assertThat(exception, instanceOf(ValidationException.class));
|
||||
assertThat(
|
||||
exception.getMessage(),
|
||||
is("Validation Failed: 1: [service_settings] does not contain the required setting [model];")
|
||||
);
|
||||
});
|
||||
|
||||
service.parseRequestConfig(
|
||||
"id",
|
||||
TaskType.COMPLETION,
|
||||
getRequestConfigMap(Collections.emptyMap(), getSecretSettingsMap(secret)),
|
||||
modelVerificationListener
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException {
|
||||
var sender = mock(Sender.class);
|
||||
|
||||
var factory = mock(HttpRequestSender.Factory.class);
|
||||
when(factory.createSender()).thenReturn(sender);
|
||||
|
||||
var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION);
|
||||
|
||||
try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) {
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.infer(
|
||||
mockModel,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
List.of(""),
|
||||
false,
|
||||
new HashMap<>(),
|
||||
InputType.INGEST,
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||
listener
|
||||
);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
|
||||
assertThat(
|
||||
thrownException.getMessage(),
|
||||
is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
|
||||
);
|
||||
|
||||
verify(factory, times(1)).createSender();
|
||||
verify(sender, times(1)).start();
|
||||
}
|
||||
|
||||
verify(sender, times(1)).close();
|
||||
verifyNoMoreInteractions(factory);
|
||||
verifyNoMoreInteractions(sender);
|
||||
}
|
||||
|
||||
public void testUnifiedCompletionInfer() throws Exception {
|
||||
// The escapes are because the streaming response must be on a single line
|
||||
String responseJson = """
|
||||
data: {\
|
||||
"id": "37d683fc0b3949b880529fb20973aca7",\
|
||||
"object": "chat.completion.chunk",\
|
||||
"created": 1749032579,\
|
||||
"model": "mistral-small-latest",\
|
||||
"choices": [\
|
||||
{\
|
||||
"index": 0,\
|
||||
"delta": {\
|
||||
"content": "Cho"\
|
||||
},\
|
||||
"finish_reason": "length",\
|
||||
"logprobs": null\
|
||||
}\
|
||||
],\
|
||||
"usage": {\
|
||||
"prompt_tokens": 10,\
|
||||
"total_tokens": 11,\
|
||||
"completion_tokens": 1\
|
||||
}\
|
||||
}
|
||||
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||
var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.unifiedCompletionInfer(
|
||||
model,
|
||||
UnifiedCompletionRequest.of(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
|
||||
),
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||
listener
|
||||
);
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(XContentHelper.stripWhitespace("""
|
||||
{
|
||||
"id": "37d683fc0b3949b880529fb20973aca7",
|
||||
"choices": [{
|
||||
"delta": {
|
||||
"content": "Cho"
|
||||
},
|
||||
"finish_reason": "length",
|
||||
"index": 0
|
||||
}
|
||||
],
|
||||
"model": "mistral-small-latest",
|
||||
"object": "chat.completion.chunk",
|
||||
"usage": {
|
||||
"completion_tokens": 1,
|
||||
"prompt_tokens": 10,
|
||||
"total_tokens": 11
|
||||
}
|
||||
}
|
||||
"""));
|
||||
}
|
||||
}
|
||||
|
||||
public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception {
|
||||
String responseJson = """
|
||||
{
|
||||
"detail": "Not Found"
|
||||
}
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
|
||||
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||
var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
|
||||
var latch = new CountDownLatch(1);
|
||||
service.unifiedCompletionInfer(
|
||||
model,
|
||||
UnifiedCompletionRequest.of(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
|
||||
),
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||
ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> {
|
||||
try (var builder = XContentFactory.jsonBuilder()) {
|
||||
var t = unwrapCause(e);
|
||||
assertThat(t, isA(UnifiedChatCompletionException.class));
|
||||
((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
|
||||
try {
|
||||
xContent.toXContent(builder, EMPTY_PARAMS);
|
||||
} catch (IOException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
});
|
||||
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
|
||||
assertThat(json, is(String.format(Locale.ROOT, XContentHelper.stripWhitespace("""
|
||||
{
|
||||
"error" : {
|
||||
"code" : "not_found",
|
||||
"message" : "Resource not found at [%s] for request from inference entity id [id] status \
|
||||
[404]. Error message: [{\\n \\"detail\\": \\"Not Found\\"\\n}\\n]",
|
||||
"type" : "mistral_error"
|
||||
}
|
||||
}"""), getUrl(webServer))));
|
||||
} catch (IOException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
}), latch::countDown)
|
||||
);
|
||||
assertTrue(latch.await(30, TimeUnit.SECONDS));
|
||||
}
|
||||
}
|
||||
|
||||
public void testInfer_StreamRequest() throws Exception {
|
||||
String responseJson = """
|
||||
data: {\
|
||||
"id":"12345",\
|
||||
"object":"chat.completion.chunk",\
|
||||
"created":123456789,\
|
||||
"model":"gpt-4o-mini",\
|
||||
"system_fingerprint": "123456789",\
|
||||
"choices":[\
|
||||
{\
|
||||
"index":0,\
|
||||
"delta":{\
|
||||
"content":"hello, world"\
|
||||
},\
|
||||
"logprobs":null,\
|
||||
"finish_reason":null\
|
||||
}\
|
||||
]\
|
||||
}
|
||||
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
streamCompletion().hasNoErrors().hasEvent("""
|
||||
{"completion":[{"delta":"hello, world"}]}""");
|
||||
}
|
||||
|
||||
private InferenceEventsAssertion streamCompletion() throws Exception {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||
var model = MistralChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model");
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.infer(
|
||||
model,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
List.of("abc"),
|
||||
true,
|
||||
new HashMap<>(),
|
||||
InputType.INGEST,
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||
listener
|
||||
);
|
||||
|
||||
return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
|
||||
}
|
||||
}
|
||||
|
||||
public void testInfer_StreamRequest_ErrorResponse() {
|
||||
String responseJson = """
|
||||
{
|
||||
"message": "Unauthorized",
|
||||
"request_id": "ad95a2165083f20b490f8f78a14bb104"
|
||||
}""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));
|
||||
|
||||
var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion);
|
||||
assertThat(e.status(), equalTo(RestStatus.UNAUTHORIZED));
|
||||
assertThat(e.getMessage(), equalTo("""
|
||||
Received an authentication error status code for request from inference entity id [id] status [401]. Error message: [{
|
||||
"message": "Unauthorized",
|
||||
"request_id": "ad95a2165083f20b490f8f78a14bb104"
|
||||
}]"""));
|
||||
}
|
||||
|
||||
public void testSupportsStreaming() throws IOException {
|
||||
try (var service = new MistralService(mock(), createWithEmptySettings(mock()))) {
|
||||
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)));
|
||||
assertFalse(service.canStream(TaskType.ANY));
|
||||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
|
||||
try (var service = createService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
|
@ -192,24 +478,37 @@ public class MistralServiceTests extends ESTestCase {
|
|||
service.parseRequestConfig(
|
||||
"id",
|
||||
TaskType.SPARSE_EMBEDDING,
|
||||
getRequestConfigMap(
|
||||
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
|
||||
getEmbeddingsTaskSettingsMap(),
|
||||
getSecretSettingsMap("secret")
|
||||
),
|
||||
getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), getSecretSettingsMap("secret")),
|
||||
modelVerificationListener
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
|
||||
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig(
|
||||
getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), getSecretSettingsMap("secret")),
|
||||
TaskType.TEXT_EMBEDDING
|
||||
);
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig_Completion() throws IOException {
|
||||
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig(
|
||||
getRequestConfigMap(getServiceSettingsMap("mistral-completion"), getSecretSettingsMap("secret")),
|
||||
TaskType.COMPLETION
|
||||
);
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig_ChatCompletion() throws IOException {
|
||||
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig(
|
||||
getRequestConfigMap(getServiceSettingsMap("mistral-chat-completion"), getSecretSettingsMap("secret")),
|
||||
TaskType.CHAT_COMPLETION
|
||||
);
|
||||
}
|
||||
|
||||
private void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig(Map<String, Object> secret, TaskType chatCompletion)
|
||||
throws IOException {
|
||||
try (var service = createService()) {
|
||||
var config = getRequestConfigMap(
|
||||
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
|
||||
getEmbeddingsTaskSettingsMap(),
|
||||
getSecretSettingsMap("secret")
|
||||
);
|
||||
config.put("extra_key", "value");
|
||||
secret.put("extra_key", "value");
|
||||
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
model -> fail("Expected exception, but got model: " + model),
|
||||
|
@ -222,20 +521,40 @@ public class MistralServiceTests extends ESTestCase {
|
|||
}
|
||||
);
|
||||
|
||||
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
|
||||
service.parseRequestConfig("id", chatCompletion, secret, modelVerificationListener);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingTaskSettingsMap() throws IOException {
|
||||
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap(
|
||||
getEmbeddingsServiceSettingsMap(null, null),
|
||||
TaskType.TEXT_EMBEDDING
|
||||
);
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInCompletionTaskSettingsMap() throws IOException {
|
||||
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap(
|
||||
getServiceSettingsMap("mistral-completion"),
|
||||
TaskType.COMPLETION
|
||||
);
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionTaskSettingsMap() throws IOException {
|
||||
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap(
|
||||
getServiceSettingsMap("mistral-chat-completion"),
|
||||
TaskType.CHAT_COMPLETION
|
||||
);
|
||||
}
|
||||
|
||||
private void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap(
|
||||
Map<String, Object> serviceSettingsMap,
|
||||
TaskType chatCompletion
|
||||
) throws IOException {
|
||||
try (var service = createService()) {
|
||||
var taskSettings = new HashMap<String, Object>();
|
||||
taskSettings.put("extra_key", "value");
|
||||
|
||||
var config = getRequestConfigMap(
|
||||
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
|
||||
taskSettings,
|
||||
getSecretSettingsMap("secret")
|
||||
);
|
||||
var config = getRequestConfigMap(serviceSettingsMap, taskSettings, getSecretSettingsMap("secret"));
|
||||
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
model -> fail("Expected exception, but got model: " + model),
|
||||
|
@ -248,7 +567,7 @@ public class MistralServiceTests extends ESTestCase {
|
|||
}
|
||||
);
|
||||
|
||||
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
|
||||
service.parseRequestConfig("id", chatCompletion, config, modelVerificationListener);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -257,11 +576,7 @@ public class MistralServiceTests extends ESTestCase {
|
|||
var secretSettings = getSecretSettingsMap("secret");
|
||||
secretSettings.put("extra_key", "value");
|
||||
|
||||
var config = getRequestConfigMap(
|
||||
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
|
||||
getEmbeddingsTaskSettingsMap(),
|
||||
secretSettings
|
||||
);
|
||||
var config = getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), secretSettings);
|
||||
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
model -> fail("Expected exception, but got model: " + model),
|
||||
|
@ -278,11 +593,42 @@ public class MistralServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInCompletionSecretSettingsMap() throws IOException {
|
||||
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap("mistral-completion", TaskType.COMPLETION);
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionSecretSettingsMap() throws IOException {
|
||||
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap("mistral-chat-completion", TaskType.CHAT_COMPLETION);
|
||||
}
|
||||
|
||||
private void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap(String modelId, TaskType chatCompletion)
|
||||
throws IOException {
|
||||
try (var service = createService()) {
|
||||
var secretSettings = getSecretSettingsMap("secret");
|
||||
secretSettings.put("extra_key", "value");
|
||||
|
||||
var config = getRequestConfigMap(getServiceSettingsMap(modelId), secretSettings);
|
||||
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
model -> fail("Expected exception, but got model: " + model),
|
||||
exception -> {
|
||||
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||
assertThat(
|
||||
exception.getMessage(),
|
||||
is("Configuration contains settings [{extra_key=value}] unknown to the [mistral] service")
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
service.parseRequestConfig("id", chatCompletion, config, modelVerificationListener);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_CreatesAMistralEmbeddingsModel() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var config = getPersistedConfigMap(
|
||||
getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null),
|
||||
getEmbeddingsTaskSettingsMap(),
|
||||
getEmbeddingsServiceSettingsMap(1024, 512),
|
||||
getTaskSettingsMap(),
|
||||
getSecretSettingsMap("secret")
|
||||
);
|
||||
|
||||
|
@ -298,11 +644,33 @@ public class MistralServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_CreatesAMistralCompletionModel() throws IOException {
|
||||
testParsePersistedConfig_CreatesAMistralModel("mistral-completion", TaskType.COMPLETION);
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_CreatesAMistralChatCompletionModel() throws IOException {
|
||||
testParsePersistedConfig_CreatesAMistralModel("mistral-chat-completion", TaskType.CHAT_COMPLETION);
|
||||
}
|
||||
|
||||
private void testParsePersistedConfig_CreatesAMistralModel(String modelId, TaskType chatCompletion) throws IOException {
|
||||
try (var service = createService()) {
|
||||
var config = getPersistedConfigMap(getServiceSettingsMap(modelId), getTaskSettingsMap(), getSecretSettingsMap("secret"));
|
||||
|
||||
var model = service.parsePersistedConfigWithSecrets("id", chatCompletion, config.config(), config.secrets());
|
||||
|
||||
assertThat(model, instanceOf(MistralChatCompletionModel.class));
|
||||
|
||||
var embeddingsModel = (MistralChatCompletionModel) model;
|
||||
assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId));
|
||||
assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
|
||||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_CreatesAMistralEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var config = getPersistedConfigMap(
|
||||
getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null),
|
||||
getEmbeddingsTaskSettingsMap(),
|
||||
getEmbeddingsServiceSettingsMap(1024, 512),
|
||||
getTaskSettingsMap(),
|
||||
createRandomChunkingSettingsMap(),
|
||||
getSecretSettingsMap("secret")
|
||||
);
|
||||
|
@ -323,8 +691,8 @@ public class MistralServiceTests extends ESTestCase {
|
|||
public void testParsePersistedConfig_CreatesAMistralEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var config = getPersistedConfigMap(
|
||||
getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null),
|
||||
getEmbeddingsTaskSettingsMap(),
|
||||
getEmbeddingsServiceSettingsMap(1024, 512),
|
||||
getTaskSettingsMap(),
|
||||
getSecretSettingsMap("secret")
|
||||
);
|
||||
|
||||
|
@ -354,11 +722,7 @@ public class MistralServiceTests extends ESTestCase {
|
|||
service.parseRequestConfig(
|
||||
"id",
|
||||
TaskType.SPARSE_EMBEDDING,
|
||||
getRequestConfigMap(
|
||||
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
|
||||
getEmbeddingsTaskSettingsMap(),
|
||||
getSecretSettingsMap("secret")
|
||||
),
|
||||
getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), getSecretSettingsMap("secret")),
|
||||
modelVerificationListener
|
||||
);
|
||||
}
|
||||
|
@ -367,8 +731,8 @@ public class MistralServiceTests extends ESTestCase {
|
|||
public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var config = getPersistedConfigMap(
|
||||
getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null),
|
||||
getEmbeddingsTaskSettingsMap(),
|
||||
getEmbeddingsServiceSettingsMap(null, null),
|
||||
getTaskSettingsMap(),
|
||||
getSecretSettingsMap("secret")
|
||||
);
|
||||
|
||||
|
@ -384,38 +748,92 @@ public class MistralServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
|
||||
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfigEmbeddings() throws IOException {
|
||||
testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig(
|
||||
getEmbeddingsServiceSettingsMap(1024, 512),
|
||||
TaskType.TEXT_EMBEDDING,
|
||||
instanceOf(MistralEmbeddingsModel.class)
|
||||
);
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfigCompletion() throws IOException {
|
||||
testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig(
|
||||
getServiceSettingsMap("mistral-completion"),
|
||||
TaskType.COMPLETION,
|
||||
instanceOf(MistralChatCompletionModel.class)
|
||||
);
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfigChatCompletion() throws IOException {
|
||||
testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig(
|
||||
getServiceSettingsMap("mistral-chat-completion"),
|
||||
TaskType.CHAT_COMPLETION,
|
||||
instanceOf(MistralChatCompletionModel.class)
|
||||
);
|
||||
}
|
||||
|
||||
private void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig(
|
||||
Map<String, Object> serviceSettingsMap,
|
||||
TaskType chatCompletion,
|
||||
Matcher<MistralModel> matcher
|
||||
) throws IOException {
|
||||
try (var service = createService()) {
|
||||
var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null);
|
||||
var taskSettings = getEmbeddingsTaskSettingsMap();
|
||||
var taskSettings = getTaskSettingsMap();
|
||||
var secretSettings = getSecretSettingsMap("secret");
|
||||
var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
|
||||
var config = getPersistedConfigMap(serviceSettingsMap, taskSettings, secretSettings);
|
||||
config.config().put("extra_key", "value");
|
||||
|
||||
var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets());
|
||||
var model = service.parsePersistedConfigWithSecrets("id", chatCompletion, config.config(), config.secrets());
|
||||
|
||||
assertThat(model, instanceOf(MistralEmbeddingsModel.class));
|
||||
assertThat(model, matcher);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInEmbeddingServiceSettingsMap() throws IOException {
|
||||
testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInServiceSettingsMap(
|
||||
getEmbeddingsServiceSettingsMap(1024, 512),
|
||||
TaskType.TEXT_EMBEDDING,
|
||||
instanceOf(MistralEmbeddingsModel.class)
|
||||
);
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInCompletionServiceSettingsMap() throws IOException {
|
||||
testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInServiceSettingsMap(
|
||||
getServiceSettingsMap("mistral-completion"),
|
||||
TaskType.COMPLETION,
|
||||
instanceOf(MistralChatCompletionModel.class)
|
||||
);
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInChatCompletionServiceSettingsMap() throws IOException {
|
||||
testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInServiceSettingsMap(
|
||||
getServiceSettingsMap("mistral-chat-completion"),
|
||||
TaskType.CHAT_COMPLETION,
|
||||
instanceOf(MistralChatCompletionModel.class)
|
||||
);
|
||||
}
|
||||
|
||||
private void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInServiceSettingsMap(
|
||||
Map<String, Object> serviceSettingsMap,
|
||||
TaskType chatCompletion,
|
||||
Matcher<MistralModel> matcher
|
||||
) throws IOException {
|
||||
try (var service = createService()) {
|
||||
var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null);
|
||||
serviceSettings.put("extra_key", "value");
|
||||
serviceSettingsMap.put("extra_key", "value");
|
||||
|
||||
var taskSettings = getEmbeddingsTaskSettingsMap();
|
||||
var taskSettings = getTaskSettingsMap();
|
||||
var secretSettings = getSecretSettingsMap("secret");
|
||||
var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
|
||||
var config = getPersistedConfigMap(serviceSettingsMap, taskSettings, secretSettings);
|
||||
|
||||
var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets());
|
||||
var model = service.parsePersistedConfigWithSecrets("id", chatCompletion, config.config(), config.secrets());
|
||||
|
||||
assertThat(model, instanceOf(MistralEmbeddingsModel.class));
|
||||
assertThat(model, matcher);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbeddingTaskSettingsMap() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null);
|
||||
var serviceSettings = getEmbeddingsServiceSettingsMap(1024, 512);
|
||||
var taskSettings = new HashMap<String, Object>();
|
||||
taskSettings.put("extra_key", "value");
|
||||
|
||||
|
@ -429,27 +847,50 @@ public class MistralServiceTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException {
|
||||
testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsSecretSettingsMap(
|
||||
getEmbeddingsServiceSettingsMap(1024, 512),
|
||||
TaskType.TEXT_EMBEDDING,
|
||||
instanceOf(MistralEmbeddingsModel.class)
|
||||
);
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInCompletionSecretSettingsMap() throws IOException {
|
||||
testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsSecretSettingsMap(
|
||||
getServiceSettingsMap("mistral-completion"),
|
||||
TaskType.COMPLETION,
|
||||
instanceOf(MistralChatCompletionModel.class)
|
||||
);
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInChatCompletionSecretSettingsMap() throws IOException {
|
||||
testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsSecretSettingsMap(
|
||||
getServiceSettingsMap("mistral-chat-completion"),
|
||||
TaskType.CHAT_COMPLETION,
|
||||
instanceOf(MistralChatCompletionModel.class)
|
||||
);
|
||||
}
|
||||
|
||||
private void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsSecretSettingsMap(
|
||||
Map<String, Object> serviceSettingsMap,
|
||||
TaskType chatCompletion,
|
||||
Matcher<MistralModel> matcher
|
||||
) throws IOException {
|
||||
try (var service = createService()) {
|
||||
var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null);
|
||||
var taskSettings = getEmbeddingsTaskSettingsMap();
|
||||
var taskSettings = getTaskSettingsMap();
|
||||
var secretSettings = getSecretSettingsMap("secret");
|
||||
secretSettings.put("extra_key", "value");
|
||||
|
||||
var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
|
||||
var config = getPersistedConfigMap(serviceSettingsMap, taskSettings, secretSettings);
|
||||
|
||||
var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets());
|
||||
var model = service.parsePersistedConfigWithSecrets("id", chatCompletion, config.config(), config.secrets());
|
||||
|
||||
assertThat(model, instanceOf(MistralEmbeddingsModel.class));
|
||||
assertThat(model, matcher);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var config = getPersistedConfigMap(
|
||||
getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null),
|
||||
getEmbeddingsTaskSettingsMap(),
|
||||
Map.of()
|
||||
);
|
||||
var config = getPersistedConfigMap(getEmbeddingsServiceSettingsMap(1024, 512), getTaskSettingsMap(), Map.of());
|
||||
|
||||
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config());
|
||||
|
||||
|
@ -465,8 +906,8 @@ public class MistralServiceTests extends ESTestCase {
|
|||
public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var config = getPersistedConfigMap(
|
||||
getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null),
|
||||
getEmbeddingsTaskSettingsMap(),
|
||||
getEmbeddingsServiceSettingsMap(1024, 512),
|
||||
getTaskSettingsMap(),
|
||||
createRandomChunkingSettingsMap(),
|
||||
Map.of()
|
||||
);
|
||||
|
@ -485,11 +926,7 @@ public class MistralServiceTests extends ESTestCase {
|
|||
|
||||
public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var config = getPersistedConfigMap(
|
||||
getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null),
|
||||
getEmbeddingsTaskSettingsMap(),
|
||||
Map.of()
|
||||
);
|
||||
var config = getPersistedConfigMap(getEmbeddingsServiceSettingsMap(1024, 512), getTaskSettingsMap(), Map.of());
|
||||
|
||||
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config());
|
||||
|
||||
|
@ -780,7 +1217,7 @@ public class MistralServiceTests extends ESTestCase {
|
|||
{
|
||||
"service": "mistral",
|
||||
"name": "Mistral",
|
||||
"task_types": ["text_embedding"],
|
||||
"task_types": ["text_embedding", "completion", "chat_completion"],
|
||||
"configurations": {
|
||||
"api_key": {
|
||||
"description": "API Key for the provider you're connecting to.",
|
||||
|
@ -789,7 +1226,7 @@ public class MistralServiceTests extends ESTestCase {
|
|||
"sensitive": true,
|
||||
"updatable": true,
|
||||
"type": "str",
|
||||
"supported_task_types": ["text_embedding"]
|
||||
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
|
||||
},
|
||||
"model": {
|
||||
"description": "Refer to the Mistral models documentation for the list of available text embedding models.",
|
||||
|
@ -798,7 +1235,7 @@ public class MistralServiceTests extends ESTestCase {
|
|||
"sensitive": false,
|
||||
"updatable": false,
|
||||
"type": "str",
|
||||
"supported_task_types": ["text_embedding"]
|
||||
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
|
||||
},
|
||||
"rate_limit.requests_per_minute": {
|
||||
"description": "Minimize the number of rate limit errors.",
|
||||
|
@ -807,7 +1244,7 @@ public class MistralServiceTests extends ESTestCase {
|
|||
"sensitive": false,
|
||||
"updatable": false,
|
||||
"type": "int",
|
||||
"supported_task_types": ["text_embedding"]
|
||||
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
|
||||
},
|
||||
"max_input_tokens": {
|
||||
"description": "Allows you to specify the maximum number of tokens per input.",
|
||||
|
@ -816,7 +1253,7 @@ public class MistralServiceTests extends ESTestCase {
|
|||
"sensitive": false,
|
||||
"updatable": false,
|
||||
"type": "int",
|
||||
"supported_task_types": ["text_embedding"]
|
||||
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -868,16 +1305,19 @@ public class MistralServiceTests extends ESTestCase {
|
|||
);
|
||||
}
|
||||
|
||||
private static Map<String, Object> getEmbeddingsServiceSettingsMap(
|
||||
String model,
|
||||
@Nullable Integer dimensions,
|
||||
@Nullable Integer maxTokens,
|
||||
@Nullable SimilarityMeasure similarityMeasure
|
||||
) {
|
||||
return createRequestSettingsMap(model, dimensions, maxTokens, similarityMeasure);
|
||||
private Map<String, Object> getRequestConfigMap(Map<String, Object> serviceSettings, Map<String, Object> secretSettings) {
|
||||
var builtServiceSettings = new HashMap<>();
|
||||
builtServiceSettings.putAll(serviceSettings);
|
||||
builtServiceSettings.putAll(secretSettings);
|
||||
|
||||
return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings));
|
||||
}
|
||||
|
||||
private static Map<String, Object> getEmbeddingsTaskSettingsMap() {
|
||||
private static Map<String, Object> getEmbeddingsServiceSettingsMap(@Nullable Integer dimensions, @Nullable Integer maxTokens) {
|
||||
return createRequestSettingsMap("mistral-embed", dimensions, maxTokens, null);
|
||||
}
|
||||
|
||||
private static Map<String, Object> getTaskSettingsMap() {
|
||||
// no task settings for Mistral embeddings
|
||||
return Map.of();
|
||||
}
|
||||
|
@ -886,25 +1326,4 @@ public class MistralServiceTests extends ESTestCase {
|
|||
return new HashMap<>(Map.of(API_KEY_FIELD, apiKey));
|
||||
}
|
||||
|
||||
private static final String testEmbeddingResultJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": [
|
||||
0.0123,
|
||||
-0.0123
|
||||
]
|
||||
}
|
||||
],
|
||||
"model": "text-embedding-ada-002-v2",
|
||||
"usage": {
|
||||
"prompt_tokens": 8,
|
||||
"total_tokens": 8
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.mistral;
|
||||
|
||||
import org.apache.http.HttpResponse;
|
||||
import org.apache.http.StatusLine;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
|
||||
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.isA;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class MistralUnifiedChatCompletionResponseHandlerTests extends ESTestCase {
|
||||
private final MistralUnifiedChatCompletionResponseHandler responseHandler = new MistralUnifiedChatCompletionResponseHandler(
|
||||
"chat completions",
|
||||
(a, b) -> mock()
|
||||
);
|
||||
|
||||
public void testFailNotFound() throws IOException {
|
||||
var responseJson = XContentHelper.stripWhitespace("""
|
||||
{
|
||||
"detail": "Not Found"
|
||||
}
|
||||
""");
|
||||
|
||||
var errorJson = invalidResponseJson(responseJson, 404);
|
||||
|
||||
assertThat(errorJson, is(XContentHelper.stripWhitespace("""
|
||||
{
|
||||
"error" : {
|
||||
"code" : "not_found",
|
||||
"message" : "Resource not found at [https://api.mistral.ai/v1/chat/completions] for request from inference entity id [id] \
|
||||
status [404]. Error message: [{\\"detail\\":\\"Not Found\\"}]",
|
||||
"type" : "mistral_error"
|
||||
}
|
||||
}""")));
|
||||
}
|
||||
|
||||
public void testFailUnauthorized() throws IOException {
|
||||
var responseJson = XContentHelper.stripWhitespace("""
|
||||
{
|
||||
"message": "Unauthorized",
|
||||
"request_id": "a580d263fb1521778782b22104efb415"
|
||||
}
|
||||
""");
|
||||
|
||||
var errorJson = invalidResponseJson(responseJson, 401);
|
||||
|
||||
assertThat(errorJson, is(XContentHelper.stripWhitespace("""
|
||||
{
|
||||
"error" : {
|
||||
"code" : "unauthorized",
|
||||
"message" : "Received an authentication error status code for request from inference entity id [id] status [401]. Error \
|
||||
message: [{\\"message\\":\\"Unauthorized\\",\\"request_id\\":\\"a580d263fb1521778782b22104efb415\\"}]",
|
||||
"type" : "mistral_error"
|
||||
}
|
||||
}""")));
|
||||
}
|
||||
|
||||
public void testFailBadRequest() throws IOException {
|
||||
var responseJson = XContentHelper.stripWhitespace("""
|
||||
{
|
||||
"object": "error",
|
||||
"message": "Invalid model: mistral-small-l2atest",
|
||||
"type": "invalid_model",
|
||||
"param": null,
|
||||
"code": "1500"
|
||||
}
|
||||
""");
|
||||
|
||||
var errorJson = invalidResponseJson(responseJson, 400);
|
||||
|
||||
assertThat(errorJson, is(XContentHelper.stripWhitespace("""
|
||||
{
|
||||
"error" : {
|
||||
"code" : "bad_request",
|
||||
"message" : "Received a bad request status code for request from inference entity id [id] status [400]. Error message: \
|
||||
[{\\"object\\":\\"error\\",\\"message\\":\\"Invalid model: mistral-small-l2atest\\",\\"type\\":\\"invalid_model\\",\\"par\
|
||||
am\\":null,\\"code\\":\\"1500\\"}]",
|
||||
"type" : "mistral_error"
|
||||
}
|
||||
}""")));
|
||||
}
|
||||
|
||||
private String invalidResponseJson(String responseJson, int statusCode) throws IOException {
|
||||
var exception = invalidResponse(responseJson, statusCode);
|
||||
assertThat(exception, isA(RetryException.class));
|
||||
assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class));
|
||||
return toJson((UnifiedChatCompletionException) unwrapCause(exception));
|
||||
}
|
||||
|
||||
private Exception invalidResponse(String responseJson, int statusCode) {
|
||||
return expectThrows(
|
||||
RetryException.class,
|
||||
() -> responseHandler.validateResponse(
|
||||
mock(),
|
||||
mock(),
|
||||
mockRequest(),
|
||||
new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)),
|
||||
true
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static Request mockRequest() throws URISyntaxException {
|
||||
var request = mock(Request.class);
|
||||
when(request.getInferenceEntityId()).thenReturn("id");
|
||||
when(request.isStreaming()).thenReturn(true);
|
||||
when(request.getURI()).thenReturn(new URI("https://api.mistral.ai/v1/chat/completions"));
|
||||
return request;
|
||||
}
|
||||
|
||||
private static HttpResponse mockErrorResponse(int statusCode) {
|
||||
var statusLine = mock(StatusLine.class);
|
||||
when(statusLine.getStatusCode()).thenReturn(statusCode);
|
||||
|
||||
var response = mock(HttpResponse.class);
|
||||
when(response.getStatusLine()).thenReturn(statusLine);
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
private String toJson(UnifiedChatCompletionException e) throws IOException {
|
||||
try (var builder = XContentFactory.jsonBuilder()) {
|
||||
e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
|
||||
try {
|
||||
xContent.toXContent(builder, EMPTY_PARAMS);
|
||||
} catch (IOException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
});
|
||||
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,175 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.mistral.action;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.test.http.MockResponse;
|
||||
import org.elasticsearch.test.http.MockWebServer;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.inference.common.TruncatorTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
|
||||
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
|
||||
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
|
||||
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
|
||||
import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public class MistralActionCreatorTests extends ESTestCase {
|
||||
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
|
||||
private final MockWebServer webServer = new MockWebServer();
|
||||
private ThreadPool threadPool;
|
||||
private HttpClientManager clientManager;
|
||||
|
||||
@Before
|
||||
public void init() throws Exception {
|
||||
webServer.start();
|
||||
threadPool = createThreadPool(inferenceUtilityPool());
|
||||
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
|
||||
}
|
||||
|
||||
@After
|
||||
public void shutdown() throws IOException {
|
||||
clientManager.close();
|
||||
terminate(threadPool);
|
||||
webServer.close();
|
||||
}
|
||||
|
||||
public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "chat.completion",
|
||||
"id": "",
|
||||
"created": 1745855316,
|
||||
"model": "/repository",
|
||||
"system_fingerprint": "3.2.3-sha-a1f3ebe",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello there, how may I assist you today?"
|
||||
},
|
||||
"logprobs": null,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 8,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 58
|
||||
}
|
||||
}
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = createChatCompletionFuture(sender, createWithEmptySettings(threadPool));
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
|
||||
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?"))));
|
||||
|
||||
assertChatCompletionRequest();
|
||||
}
|
||||
}
|
||||
|
||||
public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() throws IOException {
|
||||
var settings = buildSettingsWithRetryFields(
|
||||
TimeValue.timeValueMillis(1),
|
||||
TimeValue.timeValueMinutes(1),
|
||||
TimeValue.timeValueSeconds(0)
|
||||
);
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
{
|
||||
"invalid_field": "unexpected"
|
||||
}
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = createChatCompletionFuture(
|
||||
sender,
|
||||
new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator())
|
||||
);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
assertThat(
|
||||
thrownException.getMessage(),
|
||||
is("Failed to send Mistral completion request from inference entity id " + "[id]. Cause: Required [choices]")
|
||||
);
|
||||
|
||||
assertChatCompletionRequest();
|
||||
}
|
||||
}
|
||||
|
||||
private PlainActionFuture<InferenceServiceResults> createChatCompletionFuture(Sender sender, ServiceComponents threadPool) {
|
||||
var model = MistralChatCompletionModelTests.createCompletionModel("secret", "model");
|
||||
model.setURI(getUrl(webServer));
|
||||
var actionCreator = new MistralActionCreator(sender, threadPool);
|
||||
var action = actionCreator.create(model);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
return listener;
|
||||
}
|
||||
|
||||
private void assertChatCompletionRequest() throws IOException {
|
||||
assertThat(webServer.requests(), hasSize(1));
|
||||
assertNull(webServer.requests().get(0).getUri().getQuery());
|
||||
assertThat(
|
||||
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
|
||||
equalTo(XContentType.JSON.mediaTypeWithoutParameters())
|
||||
);
|
||||
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
|
||||
|
||||
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
||||
assertThat(requestMap.size(), is(4));
|
||||
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello"))));
|
||||
assertThat(requestMap.get("model"), is("model"));
|
||||
assertThat(requestMap.get("n"), is(1));
|
||||
assertThat(requestMap.get("stream"), is(false));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,236 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.mistral.action;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.test.http.MockRequest;
|
||||
import org.elasticsearch.test.http.MockResponse;
|
||||
import org.elasticsearch.test.http.MockWebServer;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
|
||||
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
||||
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
|
||||
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
||||
import static org.elasticsearch.xpack.inference.services.mistral.action.MistralActionCreator.COMPLETION_HANDLER;
|
||||
import static org.elasticsearch.xpack.inference.services.mistral.action.MistralActionCreator.USER_ROLE;
|
||||
import static org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests.createCompletionModel;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public class MistralChatCompletionActionTests extends ESTestCase {
|
||||
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
|
||||
private final MockWebServer webServer = new MockWebServer();
|
||||
private ThreadPool threadPool;
|
||||
private HttpClientManager clientManager;
|
||||
|
||||
@Before
|
||||
public void init() throws Exception {
|
||||
webServer.start();
|
||||
threadPool = createThreadPool(inferenceUtilityPool());
|
||||
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
|
||||
}
|
||||
|
||||
@After
|
||||
public void shutdown() throws IOException {
|
||||
clientManager.close();
|
||||
terminate(threadPool);
|
||||
webServer.close();
|
||||
}
|
||||
|
||||
public void testExecute_ReturnsSuccessfulResponse() throws IOException {
|
||||
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
{
|
||||
"id": "9d80f26810ac4e9582f927fcf0512ec7",
|
||||
"object": "chat.completion",
|
||||
"created": 1748596419,
|
||||
"model": "mistral-small-latest",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"tool_calls": null,
|
||||
"content": "result content"
|
||||
},
|
||||
"finish_reason": "length",
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"total_tokens": 11,
|
||||
"completion_tokens": 1
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var action = createAction(getUrl(webServer), sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
|
||||
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result content"))));
|
||||
assertThat(webServer.requests(), hasSize(1));
|
||||
|
||||
MockRequest request = webServer.requests().get(0);
|
||||
|
||||
assertNull(request.getUri().getQuery());
|
||||
assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()));
|
||||
assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
|
||||
|
||||
var requestMap = entityAsMap(request.getBody());
|
||||
assertThat(requestMap.size(), is(4));
|
||||
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc"))));
|
||||
assertThat(requestMap.get("model"), is("model"));
|
||||
assertThat(requestMap.get("n"), is(1));
|
||||
assertThat(requestMap.get("stream"), is(false));
|
||||
}
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsElasticsearchException() {
|
||||
var sender = mock(Sender.class);
|
||||
doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
|
||||
|
||||
var action = createAction(getUrl(webServer), sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
assertThat(thrownException.getMessage(), is("failed"));
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
|
||||
var sender = mock(Sender.class);
|
||||
|
||||
doAnswer(invocation -> {
|
||||
ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
|
||||
listener.onFailure(new IllegalStateException("failed"));
|
||||
|
||||
return Void.TYPE;
|
||||
}).when(sender).send(any(), any(), any(), any());
|
||||
|
||||
var action = createAction(getUrl(webServer), sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
assertThat(thrownException.getMessage(), is("Failed to send mistral chat completions request. Cause: failed"));
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
{
|
||||
"id": "9d80f26810ac4e9582f927fcf0512ec7",
|
||||
"object": "chat.completion",
|
||||
"created": 1748596419,
|
||||
"model": "mistral-small-latest",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"tool_calls": null,
|
||||
"content": "result content"
|
||||
},
|
||||
"finish_reason": "length",
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"total_tokens": 11,
|
||||
"completion_tokens": 1
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var action = createAction(getUrl(webServer), sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
assertThat(thrownException.getMessage(), is("mistral chat completions only accepts 1 input"));
|
||||
assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
|
||||
}
|
||||
}
|
||||
|
||||
private ExecutableAction createAction(String url, Sender sender) {
|
||||
var model = createCompletionModel("secret", "model");
|
||||
model.setURI(url);
|
||||
var manager = new GenericRequestManager<>(
|
||||
threadPool,
|
||||
model,
|
||||
COMPLETION_HANDLER,
|
||||
inputs -> new MistralChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
|
||||
ChatCompletionInput.class
|
||||
);
|
||||
var errorMessage = constructFailedToSendRequestMessage("mistral chat completions");
|
||||
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "mistral chat completions");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,137 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.mistral.completion;
|
||||
|
||||
import org.elasticsearch.common.settings.SecureString;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class MistralChatCompletionModelTests extends ESTestCase {
|
||||
|
||||
public static MistralChatCompletionModel createCompletionModel(String apiKey, String modelId) {
|
||||
return new MistralChatCompletionModel(
|
||||
"id",
|
||||
TaskType.COMPLETION,
|
||||
"service",
|
||||
new MistralChatCompletionServiceSettings(modelId, null),
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
}
|
||||
|
||||
public static MistralChatCompletionModel createCompletionModel(String url, String apiKey, String modelId) {
|
||||
MistralChatCompletionModel mistralChatCompletionModel = new MistralChatCompletionModel(
|
||||
"id",
|
||||
TaskType.COMPLETION,
|
||||
"service",
|
||||
new MistralChatCompletionServiceSettings(modelId, null),
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
mistralChatCompletionModel.setURI(url);
|
||||
return mistralChatCompletionModel;
|
||||
}
|
||||
|
||||
public static MistralChatCompletionModel createChatCompletionModel(String apiKey, String modelId) {
|
||||
return new MistralChatCompletionModel(
|
||||
"id",
|
||||
TaskType.CHAT_COMPLETION,
|
||||
"service",
|
||||
new MistralChatCompletionServiceSettings(modelId, null),
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
}
|
||||
|
||||
public static MistralChatCompletionModel createChatCompletionModel(String url, String apiKey, String modelId) {
|
||||
MistralChatCompletionModel mistralChatCompletionModel = new MistralChatCompletionModel(
|
||||
"id",
|
||||
TaskType.CHAT_COMPLETION,
|
||||
"service",
|
||||
new MistralChatCompletionServiceSettings(modelId, null),
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
mistralChatCompletionModel.setURI(url);
|
||||
return mistralChatCompletionModel;
|
||||
}
|
||||
|
||||
public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() {
|
||||
var model = createCompletionModel("api_key", "model_name");
|
||||
var request = new UnifiedCompletionRequest(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||
"different_model",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
);
|
||||
|
||||
var overriddenModel = MistralChatCompletionModel.of(model, request);
|
||||
|
||||
assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
|
||||
}
|
||||
|
||||
public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() {
|
||||
var model = createCompletionModel("api_key", null);
|
||||
var request = new UnifiedCompletionRequest(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||
"different_model",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
);
|
||||
|
||||
var overriddenModel = MistralChatCompletionModel.of(model, request);
|
||||
|
||||
assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
|
||||
}
|
||||
|
||||
public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() {
|
||||
var model = createCompletionModel("api_key", null);
|
||||
var request = new UnifiedCompletionRequest(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
);
|
||||
|
||||
var overriddenModel = MistralChatCompletionModel.of(model, request);
|
||||
|
||||
assertNull(overriddenModel.getServiceSettings().modelId());
|
||||
}
|
||||
|
||||
public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() {
|
||||
var model = createCompletionModel("api_key", "model_name");
|
||||
var request = new UnifiedCompletionRequest(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||
null, // not overriding model
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
);
|
||||
|
||||
var overriddenModel = MistralChatCompletionModel.of(model, request);
|
||||
|
||||
assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name"));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,171 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.mistral.completion;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.MistralConstants;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class MistralChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase<MistralChatCompletionServiceSettings> {
|
||||
|
||||
public static final String MODEL_ID = "some model";
|
||||
public static final int RATE_LIMIT = 2;
|
||||
|
||||
public void testFromMap_AllFields_Success() {
|
||||
var serviceSettings = MistralChatCompletionServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
MistralConstants.MODEL_FIELD,
|
||||
MODEL_ID,
|
||||
RateLimitSettings.FIELD_NAME,
|
||||
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
);
|
||||
|
||||
assertThat(
|
||||
serviceSettings,
|
||||
is(
|
||||
new MistralChatCompletionServiceSettings(
|
||||
MODEL_ID,
|
||||
|
||||
new RateLimitSettings(RATE_LIMIT)
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_MissingModelId_ThrowsException() {
|
||||
var thrownException = expectThrows(
|
||||
ValidationException.class,
|
||||
() -> MistralChatCompletionServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)))
|
||||
),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
)
|
||||
);
|
||||
|
||||
assertThat(
|
||||
thrownException.getMessage(),
|
||||
containsString("Validation Failed: 1: [service_settings] does not contain the required setting [model];")
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_MissingRateLimit_Success() {
|
||||
var serviceSettings = MistralChatCompletionServiceSettings.fromMap(
|
||||
new HashMap<>(Map.of(MistralConstants.MODEL_FIELD, MODEL_ID)),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
);
|
||||
|
||||
assertThat(serviceSettings, is(new MistralChatCompletionServiceSettings(MODEL_ID, null)));
|
||||
}
|
||||
|
||||
public void testToXContent_WritesAllValues() throws IOException {
|
||||
var serviceSettings = MistralChatCompletionServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
MistralConstants.MODEL_FIELD,
|
||||
MODEL_ID,
|
||||
RateLimitSettings.FIELD_NAME,
|
||||
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
serviceSettings.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
var expected = XContentHelper.stripWhitespace("""
|
||||
{
|
||||
"model": "some model",
|
||||
"rate_limit": {
|
||||
"requests_per_minute": 2
|
||||
}
|
||||
}
|
||||
""");
|
||||
|
||||
assertThat(xContentResult, is(expected));
|
||||
}
|
||||
|
||||
public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException {
|
||||
var serviceSettings = MistralChatCompletionServiceSettings.fromMap(
|
||||
new HashMap<>(Map.of(MistralConstants.MODEL_FIELD, MODEL_ID)),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
serviceSettings.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
var expected = XContentHelper.stripWhitespace("""
|
||||
{
|
||||
"model": "some model",
|
||||
"rate_limit": {
|
||||
"requests_per_minute": 240
|
||||
}
|
||||
}
|
||||
""");
|
||||
assertThat(xContentResult, is(expected));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<MistralChatCompletionServiceSettings> instanceReader() {
|
||||
return MistralChatCompletionServiceSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MistralChatCompletionServiceSettings createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MistralChatCompletionServiceSettings mutateInstance(MistralChatCompletionServiceSettings instance) throws IOException {
|
||||
return randomValueOtherThan(instance, MistralChatCompletionServiceSettingsTests::createRandom);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MistralChatCompletionServiceSettings mutateInstanceForVersion(
|
||||
MistralChatCompletionServiceSettings instance,
|
||||
TransportVersion version
|
||||
) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
private static MistralChatCompletionServiceSettings createRandom() {
|
||||
var modelId = randomAlphaOfLength(8);
|
||||
|
||||
return new MistralChatCompletionServiceSettings(modelId, RateLimitSettingsTests.createRandom());
|
||||
}
|
||||
|
||||
public static Map<String, Object> getServiceSettingsMap(String model) {
|
||||
var map = new HashMap<String, Object>();
|
||||
|
||||
map.put(MistralConstants.MODEL_FIELD, model);
|
||||
|
||||
return map;
|
||||
}
|
||||
}
|
|
@ -12,6 +12,7 @@ import org.elasticsearch.test.ESTestCase;
|
|||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.request.embeddings.MistralEmbeddingsRequestEntity;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
|
|
@ -16,6 +16,7 @@ import org.elasticsearch.xpack.inference.common.TruncatorTests;
|
|||
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.MistralConstants;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.request.embeddings.MistralEmbeddingsRequest;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.mistral.request.completion;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.ToXContent;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests.createCompletionModel;
|
||||
|
||||
public class MistralChatCompletionRequestEntityTests extends ESTestCase {
|
||||
|
||||
private static final String ROLE = "user";
|
||||
|
||||
public void testModelUserFieldsSerialization() throws IOException {
|
||||
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
|
||||
new UnifiedCompletionRequest.ContentString("Hello, world!"),
|
||||
ROLE,
|
||||
null,
|
||||
null
|
||||
);
|
||||
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
|
||||
messageList.add(message);
|
||||
|
||||
var unifiedRequest = UnifiedCompletionRequest.of(messageList);
|
||||
|
||||
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
|
||||
MistralChatCompletionModel model = createCompletionModel("api-key", "test-endpoint");
|
||||
|
||||
MistralChatCompletionRequestEntity entity = new MistralChatCompletionRequestEntity(unifiedChatInput, model);
|
||||
|
||||
XContentBuilder builder = JsonXContent.contentBuilder();
|
||||
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
String expectedJson = """
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "Hello, world!",
|
||||
"role": "user"
|
||||
}
|
||||
],
|
||||
"model": "test-endpoint",
|
||||
"n": 1,
|
||||
"stream": true
|
||||
}
|
||||
""";
|
||||
assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.mistral.request.completion;
|
||||
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.MistralConstants;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||
import static org.hamcrest.Matchers.aMapWithSize;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class MistralChatCompletionRequestTests extends ESTestCase {
|
||||
|
||||
public void testCreateRequest_WithStreaming() throws IOException {
|
||||
var request = createRequest("secret", randomAlphaOfLength(15), "model", true);
|
||||
var httpRequest = request.createHttpRequest();
|
||||
|
||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
assertThat(requestMap.get("stream"), is(true));
|
||||
}
|
||||
|
||||
public void testTruncate_DoesNotReduceInputTextSize() throws IOException {
|
||||
String input = randomAlphaOfLength(5);
|
||||
var request = createRequest("secret", input, "model", true);
|
||||
var truncatedRequest = request.truncate();
|
||||
assertThat(request.getURI().toString(), is(MistralConstants.API_COMPLETIONS_PATH));
|
||||
|
||||
var httpRequest = truncatedRequest.createHttpRequest();
|
||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
assertThat(requestMap, aMapWithSize(4));
|
||||
|
||||
// We do not truncate for Mistral chat completions
|
||||
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
|
||||
assertThat(requestMap.get("model"), is("model"));
|
||||
assertThat(requestMap.get("n"), is(1));
|
||||
assertTrue((Boolean) requestMap.get("stream"));
|
||||
assertNull(requestMap.get("stream_options")); // Mistral does not use stream options
|
||||
}
|
||||
|
||||
public void testTruncationInfo_ReturnsNull() {
|
||||
var request = createRequest("secret", randomAlphaOfLength(5), "model", true);
|
||||
assertNull(request.getTruncationInfo());
|
||||
}
|
||||
|
||||
public static MistralChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model) {
|
||||
return createRequest(apiKey, input, model, false);
|
||||
}
|
||||
|
||||
public static MistralChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model, boolean stream) {
|
||||
var chatCompletionModel = MistralChatCompletionModelTests.createCompletionModel(apiKey, model);
|
||||
return new MistralChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.mistral.response;
|
||||
|
||||
import org.apache.http.HttpResponse;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public class MistralErrorResponseTests extends ESTestCase {
|
||||
|
||||
public static final String ERROR_RESPONSE_JSON = """
|
||||
{
|
||||
"error": "A valid user token is required"
|
||||
}
|
||||
""";
|
||||
|
||||
public void testFromResponse() {
|
||||
var errorResponse = MistralErrorResponse.fromResponse(
|
||||
new HttpResult(mock(HttpResponse.class), ERROR_RESPONSE_JSON.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
assertNotNull(errorResponse);
|
||||
assertEquals(ERROR_RESPONSE_JSON, errorResponse.getErrorMessage());
|
||||
}
|
||||
}
|
|
@ -100,7 +100,7 @@ public class OpenAiResponseHandlerTests extends ESTestCase {
|
|||
assertFalse(retryException.shouldRetry());
|
||||
assertThat(
|
||||
retryException.getCause().getMessage(),
|
||||
containsString("Received an unsuccessful status code for request from inference entity id [id] status [400]")
|
||||
containsString("Received a bad request status code for request from inference entity id [id] status [400]")
|
||||
);
|
||||
assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.BAD_REQUEST));
|
||||
// 400 is not flagged as a content too large when the error message is different
|
||||
|
@ -112,7 +112,7 @@ public class OpenAiResponseHandlerTests extends ESTestCase {
|
|||
assertFalse(retryException.shouldRetry());
|
||||
assertThat(
|
||||
retryException.getCause().getMessage(),
|
||||
containsString("Received an unsuccessful status code for request from inference entity id [id] status [400]")
|
||||
containsString("Received a bad request status code for request from inference entity id [id] status [400]")
|
||||
);
|
||||
assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.BAD_REQUEST));
|
||||
// 401
|
||||
|
|
|
@ -65,6 +65,7 @@ import java.util.Arrays;
|
|||
import java.util.EnumSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
|
@ -1153,14 +1154,14 @@ public class OpenAiServiceTests extends ESTestCase {
|
|||
});
|
||||
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
|
||||
|
||||
assertThat(json, is("""
|
||||
assertThat(json, is(String.format(Locale.ROOT, """
|
||||
{\
|
||||
"error":{\
|
||||
"code":"model_not_found",\
|
||||
"message":"Received an unsuccessful status code for request from inference entity id [id] status \
|
||||
"message":"Resource not found at [%s] for request from inference entity id [id] status \
|
||||
[404]. Error message: [The model `gpt-4awero` does not exist or you do not have access to it.]",\
|
||||
"type":"invalid_request_error"\
|
||||
}}"""));
|
||||
}}""", getUrl(webServer))));
|
||||
} catch (IOException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
|
|