Add Hugging Face Chat Completion support to Inference Plugin (#127254)
* Add Hugging Face Chat Completion support to Inference Plugin * Add support for streaming chat completion task for HuggingFace * [CI] Auto commit changes from spotless * Add support for non-streaming completion task for HuggingFace * Remove RequestManager for HF Chat Completion Task * Refactored Hugging Face Completion Service Settings, removed Request Manager, added Unit Tests * Refactored Hugging Face Action Creator, added Unit Tests * Add Hugging Face Server Test * [CI] Auto commit changes from spotless * Removed parameters from media type for Chat Completion Request and unit tests * Removed OpenAI default URL in HuggingFaceService's configuration, fixed formatting in InferenceGetServicesIT * Refactor error message handling in HuggingFaceActionCreator and HuggingFaceService * Update minimal supported version and add Hugging Face transport version constants * Made modelId field optional in HuggingFaceChatCompletionModel, updated unit tests * Removed max input tokens field from HuggingFaceChatCompletionServiceSettings, fixed unit tests * Removed if statement checking TransportVersion for HuggingFaceChatCompletionServiceSettings constructor with StreamInput param * Removed getFirst() method calls for backport compatibility * Made HuggingFaceChatCompletionServiceSettingsTests extend AbstractBWCWireSerializationTestCase for future serialization testing * Refactored tests to use stripWhitespace method for readability * Refactored javadoc for HuggingFaceService * Renamed HF chat completion TransportVersion constant names * Added random string generation in unit test * Refactored javadocs for HuggingFace requests * Refactored tests to reduce duplication * Added changelog file * Add HuggingFaceChatCompletionResponseHandler and associated tests * Refactor error handling in HuggingFaceServiceTests to standardize error response codes and types * Refactor HuggingFace error handling to improve response structure and add streaming support * Allowing null function name for hugging face models --------- Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co> Co-authored-by: Jonathan Buttner <jonathan.buttner@elastic.co>
This commit is contained in:
parent
54f26680ea
commit
d1ad917855
|
@ -0,0 +1,5 @@
|
||||||
|
pr: 127254
|
||||||
|
summary: "[ML] Add HuggingFace Chat Completion support to the Inference Plugin"
|
||||||
|
area: Machine Learning
|
||||||
|
type: enhancement
|
||||||
|
issues: []
|
|
@ -175,6 +175,7 @@ public class TransportVersions {
|
||||||
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
|
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
|
||||||
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29);
|
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29);
|
||||||
public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION_8_19 = def(8_841_0_30);
|
public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION_8_19 = def(8_841_0_30);
|
||||||
|
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_31);
|
||||||
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
|
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
|
||||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
|
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
|
||||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
|
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
|
||||||
|
@ -255,6 +256,7 @@ public class TransportVersions {
|
||||||
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);
|
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);
|
||||||
public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00);
|
public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00);
|
||||||
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
|
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
|
||||||
|
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_078_0_00);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* STOP! READ THIS FIRST! No, really,
|
* STOP! READ THIS FIRST! No, really,
|
||||||
|
|
|
@ -123,7 +123,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
||||||
|
|
||||||
public void testGetServicesWithCompletionTaskType() throws IOException {
|
public void testGetServicesWithCompletionTaskType() throws IOException {
|
||||||
List<Object> services = getServices(TaskType.COMPLETION);
|
List<Object> services = getServices(TaskType.COMPLETION);
|
||||||
assertThat(services.size(), equalTo(10));
|
assertThat(services.size(), equalTo(11));
|
||||||
|
|
||||||
var providers = providers(services);
|
var providers = providers(services);
|
||||||
|
|
||||||
|
@ -140,7 +140,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
||||||
"deepseek",
|
"deepseek",
|
||||||
"googleaistudio",
|
"googleaistudio",
|
||||||
"openai",
|
"openai",
|
||||||
"streaming_completion_test_service"
|
"streaming_completion_test_service",
|
||||||
|
"hugging_face"
|
||||||
).toArray()
|
).toArray()
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
@ -148,11 +149,14 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
||||||
|
|
||||||
public void testGetServicesWithChatCompletionTaskType() throws IOException {
|
public void testGetServicesWithChatCompletionTaskType() throws IOException {
|
||||||
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
|
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
|
||||||
assertThat(services.size(), equalTo(4));
|
assertThat(services.size(), equalTo(5));
|
||||||
|
|
||||||
var providers = providers(services);
|
var providers = providers(services);
|
||||||
|
|
||||||
assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray()));
|
assertThat(
|
||||||
|
providers,
|
||||||
|
containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face").toArray())
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
|
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
|
||||||
|
|
|
@ -78,6 +78,7 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.Goog
|
||||||
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
|
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
|
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
|
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
|
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
|
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
|
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
|
||||||
|
@ -357,6 +358,13 @@ public class InferenceNamedWriteablesProvider {
|
||||||
namedWriteables.add(
|
namedWriteables.add(
|
||||||
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
|
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
|
||||||
);
|
);
|
||||||
|
namedWriteables.add(
|
||||||
|
new NamedWriteableRegistry.Entry(
|
||||||
|
ServiceSettings.class,
|
||||||
|
HuggingFaceChatCompletionServiceSettings.NAME,
|
||||||
|
HuggingFaceChatCompletionServiceSettings::new
|
||||||
|
)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
|
private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
|
||||||
|
|
|
@ -0,0 +1,171 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface;
|
||||||
|
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.rest.RestStatus;
|
||||||
|
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.xcontent.ParseField;
|
||||||
|
import org.elasticsearch.xcontent.XContentFactory;
|
||||||
|
import org.elasticsearch.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||||
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
|
||||||
|
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceErrorResponseEntity;
|
||||||
|
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
|
||||||
|
|
||||||
|
import java.util.Locale;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
import static org.elasticsearch.core.Strings.format;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles streaming chat completion responses and error parsing for Hugging Face inference endpoints.
|
||||||
|
* Adapts the OpenAI handler to support Hugging Face's simpler error schema with fields like "message" and "http_status_code".
|
||||||
|
*/
|
||||||
|
public class HuggingFaceChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
|
||||||
|
|
||||||
|
private static final String HUGGING_FACE_ERROR = "hugging_face_error";
|
||||||
|
|
||||||
|
public HuggingFaceChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
|
||||||
|
super(requestType, parseFunction, HuggingFaceErrorResponseEntity::fromResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
|
||||||
|
assert request.isStreaming() : "Only streaming requests support this format";
|
||||||
|
var responseStatusCode = result.response().getStatusLine().getStatusCode();
|
||||||
|
if (request.isStreaming()) {
|
||||||
|
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
|
||||||
|
var restStatus = toRestStatus(responseStatusCode);
|
||||||
|
return errorResponse instanceof HuggingFaceErrorResponseEntity
|
||||||
|
? new UnifiedChatCompletionException(
|
||||||
|
restStatus,
|
||||||
|
errorMessage,
|
||||||
|
HUGGING_FACE_ERROR,
|
||||||
|
restStatus.name().toLowerCase(Locale.ROOT)
|
||||||
|
)
|
||||||
|
: new UnifiedChatCompletionException(
|
||||||
|
restStatus,
|
||||||
|
errorMessage,
|
||||||
|
createErrorType(errorResponse),
|
||||||
|
restStatus.name().toLowerCase(Locale.ROOT)
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
return super.buildError(message, request, result, errorResponse);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Exception buildMidStreamError(Request request, String message, Exception e) {
|
||||||
|
var errorResponse = StreamingHuggingFaceErrorResponseEntity.fromString(message);
|
||||||
|
if (errorResponse instanceof StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) {
|
||||||
|
return new UnifiedChatCompletionException(
|
||||||
|
RestStatus.INTERNAL_SERVER_ERROR,
|
||||||
|
format(
|
||||||
|
"%s for request from inference entity id [%s]. Error message: [%s]",
|
||||||
|
SERVER_ERROR_OBJECT,
|
||||||
|
request.getInferenceEntityId(),
|
||||||
|
errorResponse.getErrorMessage()
|
||||||
|
),
|
||||||
|
HUGGING_FACE_ERROR,
|
||||||
|
extractErrorCode(streamingHuggingFaceErrorResponseEntity)
|
||||||
|
);
|
||||||
|
} else if (e != null) {
|
||||||
|
return UnifiedChatCompletionException.fromThrowable(e);
|
||||||
|
} else {
|
||||||
|
return new UnifiedChatCompletionException(
|
||||||
|
RestStatus.INTERNAL_SERVER_ERROR,
|
||||||
|
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
|
||||||
|
createErrorType(errorResponse),
|
||||||
|
"stream_error"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) {
|
||||||
|
return streamingHuggingFaceErrorResponseEntity.httpStatusCode() != null
|
||||||
|
? String.valueOf(streamingHuggingFaceErrorResponseEntity.httpStatusCode())
|
||||||
|
: null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a structured error response specifically for streaming operations
|
||||||
|
* using HuggingFace APIs. This is separate from non-streaming error responses,
|
||||||
|
* which are handled by {@link HuggingFaceErrorResponseEntity}.
|
||||||
|
* An example error response for failed field validation for streaming operation would look like
|
||||||
|
* <code>
|
||||||
|
* {
|
||||||
|
* "error": "Input validation error: cannot compile regex from schema",
|
||||||
|
* "http_status_code": 422
|
||||||
|
* }
|
||||||
|
* </code>
|
||||||
|
*/
|
||||||
|
private static class StreamingHuggingFaceErrorResponseEntity extends ErrorResponse {
|
||||||
|
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
|
||||||
|
HUGGING_FACE_ERROR,
|
||||||
|
true,
|
||||||
|
args -> Optional.ofNullable((StreamingHuggingFaceErrorResponseEntity) args[0])
|
||||||
|
);
|
||||||
|
private static final ConstructingObjectParser<StreamingHuggingFaceErrorResponseEntity, Void> ERROR_BODY_PARSER =
|
||||||
|
new ConstructingObjectParser<>(
|
||||||
|
HUGGING_FACE_ERROR,
|
||||||
|
true,
|
||||||
|
args -> new StreamingHuggingFaceErrorResponseEntity(args[0] != null ? (String) args[0] : "unknown", (Integer) args[1])
|
||||||
|
);
|
||||||
|
|
||||||
|
static {
|
||||||
|
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message"));
|
||||||
|
ERROR_BODY_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField("http_status_code"));
|
||||||
|
|
||||||
|
ERROR_PARSER.declareObjectOrNull(
|
||||||
|
ConstructingObjectParser.optionalConstructorArg(),
|
||||||
|
ERROR_BODY_PARSER,
|
||||||
|
null,
|
||||||
|
new ParseField("error")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parses a streaming HuggingFace error response from a JSON string.
|
||||||
|
*
|
||||||
|
* @param response the raw JSON string representing an error
|
||||||
|
* @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails
|
||||||
|
*/
|
||||||
|
private static ErrorResponse fromString(String response) {
|
||||||
|
try (
|
||||||
|
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
|
||||||
|
.createParser(XContentParserConfiguration.EMPTY, response)
|
||||||
|
) {
|
||||||
|
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
|
||||||
|
} catch (Exception e) {
|
||||||
|
// swallow the error
|
||||||
|
}
|
||||||
|
|
||||||
|
return ErrorResponse.UNDEFINED_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Nullable
|
||||||
|
private final Integer httpStatusCode;
|
||||||
|
|
||||||
|
StreamingHuggingFaceErrorResponseEntity(String errorMessage, @Nullable Integer httpStatusCode) {
|
||||||
|
super(errorMessage);
|
||||||
|
this.httpStatusCode = httpStatusCode;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Nullable
|
||||||
|
public Integer httpStatusCode() {
|
||||||
|
return httpStatusCode;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -9,17 +9,18 @@ package org.elasticsearch.xpack.inference.services.huggingface;
|
||||||
|
|
||||||
import org.elasticsearch.common.settings.SecureString;
|
import org.elasticsearch.common.settings.SecureString;
|
||||||
import org.elasticsearch.core.Nullable;
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.inference.Model;
|
|
||||||
import org.elasticsearch.inference.ModelConfigurations;
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
import org.elasticsearch.inference.ModelSecrets;
|
import org.elasticsearch.inference.ModelSecrets;
|
||||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||||
|
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
|
||||||
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
|
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
|
||||||
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
|
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
|
||||||
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||||
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
public abstract class HuggingFaceModel extends Model {
|
public abstract class HuggingFaceModel extends RateLimitGroupingModel {
|
||||||
private final HuggingFaceRateLimitServiceSettings rateLimitServiceSettings;
|
private final HuggingFaceRateLimitServiceSettings rateLimitServiceSettings;
|
||||||
private final SecureString apiKey;
|
private final SecureString apiKey;
|
||||||
|
|
||||||
|
@ -38,6 +39,16 @@ public abstract class HuggingFaceModel extends Model {
|
||||||
return rateLimitServiceSettings;
|
return rateLimitServiceSettings;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int rateLimitGroupingHash() {
|
||||||
|
return Objects.hash(rateLimitServiceSettings.uri(), apiKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RateLimitSettings rateLimitSettings() {
|
||||||
|
return rateLimitServiceSettings.rateLimitSettings();
|
||||||
|
}
|
||||||
|
|
||||||
public SecureString apiKey() {
|
public SecureString apiKey() {
|
||||||
return apiKey;
|
return apiKey;
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.BaseRequestManager
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
|
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
|
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.request.HuggingFaceInferenceRequest;
|
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
@ -64,7 +64,7 @@ public class HuggingFaceRequestManager extends BaseRequestManager {
|
||||||
) {
|
) {
|
||||||
List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
|
List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
|
||||||
var truncatedInput = truncate(docsInput, model.getTokenLimit());
|
var truncatedInput = truncate(docsInput, model.getTokenLimit());
|
||||||
var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);
|
var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model);
|
||||||
|
|
||||||
execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
|
execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,15 +26,21 @@ import org.elasticsearch.inference.TaskType;
|
||||||
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
|
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
|
||||||
import org.elasticsearch.rest.RestStatus;
|
import org.elasticsearch.rest.RestStatus;
|
||||||
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
|
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
|
||||||
|
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||||
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator;
|
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
|
||||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||||
|
|
||||||
|
@ -42,16 +48,29 @@ import java.util.EnumSet;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
|
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
|
||||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
|
||||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This class is responsible for managing the Hugging Face inference service.
|
||||||
|
* It manages model creation, as well as chunked, non-chunked, and unified completion inference.
|
||||||
|
*/
|
||||||
public class HuggingFaceService extends HuggingFaceBaseService {
|
public class HuggingFaceService extends HuggingFaceBaseService {
|
||||||
public static final String NAME = "hugging_face";
|
public static final String NAME = "hugging_face";
|
||||||
|
|
||||||
private static final String SERVICE_NAME = "Hugging Face";
|
private static final String SERVICE_NAME = "Hugging Face";
|
||||||
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING);
|
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(
|
||||||
|
TaskType.TEXT_EMBEDDING,
|
||||||
|
TaskType.SPARSE_EMBEDDING,
|
||||||
|
TaskType.COMPLETION,
|
||||||
|
TaskType.CHAT_COMPLETION
|
||||||
|
);
|
||||||
|
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new HuggingFaceChatCompletionResponseHandler(
|
||||||
|
"hugging face chat completion",
|
||||||
|
OpenAiChatCompletionResponseEntity::fromResponse
|
||||||
|
);
|
||||||
|
|
||||||
public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
|
public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
|
||||||
super(factory, serviceComponents);
|
super(factory, serviceComponents);
|
||||||
|
@ -78,6 +97,14 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
||||||
context
|
context
|
||||||
);
|
);
|
||||||
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
|
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
|
||||||
|
case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel(
|
||||||
|
inferenceEntityId,
|
||||||
|
taskType,
|
||||||
|
NAME,
|
||||||
|
serviceSettings,
|
||||||
|
secretSettings,
|
||||||
|
context
|
||||||
|
);
|
||||||
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
|
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -139,7 +166,29 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
||||||
TimeValue timeout,
|
TimeValue timeout,
|
||||||
ActionListener<InferenceServiceResults> listener
|
ActionListener<InferenceServiceResults> listener
|
||||||
) {
|
) {
|
||||||
throwUnsupportedUnifiedCompletionOperation(NAME);
|
if (model instanceof HuggingFaceChatCompletionModel == false) {
|
||||||
|
listener.onFailure(createInvalidModelException(model));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
HuggingFaceChatCompletionModel huggingFaceChatCompletionModel = (HuggingFaceChatCompletionModel) model;
|
||||||
|
var overriddenModel = HuggingFaceChatCompletionModel.of(huggingFaceChatCompletionModel, inputs.getRequest());
|
||||||
|
var manager = new GenericRequestManager<>(
|
||||||
|
getServiceComponents().threadPool(),
|
||||||
|
overriddenModel,
|
||||||
|
UNIFIED_CHAT_COMPLETION_HANDLER,
|
||||||
|
unifiedChatInput -> new HuggingFaceUnifiedChatCompletionRequest(unifiedChatInput, overriddenModel),
|
||||||
|
UnifiedChatInput.class
|
||||||
|
);
|
||||||
|
var errorMessage = HuggingFaceActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId());
|
||||||
|
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
|
||||||
|
|
||||||
|
action.execute(inputs, timeout, listener);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Set<TaskType> supportedStreamingTasks() {
|
||||||
|
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -149,7 +198,7 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public EnumSet<TaskType> supportedTaskTypes() {
|
public EnumSet<TaskType> supportedTaskTypes() {
|
||||||
return supportedTaskTypes;
|
return SUPPORTED_TASK_TYPES;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -167,14 +216,15 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
||||||
return configuration.getOrCompute();
|
return configuration.getOrCompute();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Configuration() {}
|
||||||
|
|
||||||
private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
|
private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
|
||||||
() -> {
|
() -> {
|
||||||
var configurationMap = new HashMap<String, SettingsConfiguration>();
|
var configurationMap = new HashMap<String, SettingsConfiguration>();
|
||||||
|
|
||||||
configurationMap.put(
|
configurationMap.put(
|
||||||
URL,
|
URL,
|
||||||
new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue("https://api.openai.com/v1/embeddings")
|
new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.")
|
||||||
.setDescription("The URL endpoint to use for the requests.")
|
|
||||||
.setLabel("URL")
|
.setLabel("URL")
|
||||||
.setRequired(true)
|
.setRequired(true)
|
||||||
.setSensitive(false)
|
.setSensitive(false)
|
||||||
|
@ -183,12 +233,12 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
||||||
.build()
|
.build()
|
||||||
);
|
);
|
||||||
|
|
||||||
configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes));
|
configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));
|
||||||
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes));
|
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));
|
||||||
|
|
||||||
return new InferenceServiceConfiguration.Builder().setService(NAME)
|
return new InferenceServiceConfiguration.Builder().setService(NAME)
|
||||||
.setName(SERVICE_NAME)
|
.setName(SERVICE_NAME)
|
||||||
.setTaskTypes(supportedTaskTypes)
|
.setTaskTypes(SUPPORTED_TASK_TYPES)
|
||||||
.setConfigurations(configurationMap)
|
.setConfigurations(configurationMap)
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,16 +7,26 @@
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.services.huggingface.action;
|
package org.elasticsearch.xpack.inference.services.huggingface.action;
|
||||||
|
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||||
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
|
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
|
||||||
|
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRequestManager;
|
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRequestManager;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler;
|
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity;
|
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity;
|
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity;
|
||||||
|
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
|
||||||
|
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
|
||||||
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
|
@ -26,6 +36,13 @@ import static org.elasticsearch.core.Strings.format;
|
||||||
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the hugging face model type.
|
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the hugging face model type.
|
||||||
*/
|
*/
|
||||||
public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
|
public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
|
||||||
|
|
||||||
|
public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions";
|
||||||
|
static final String USER_ROLE = "user";
|
||||||
|
static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
|
||||||
|
"hugging face completion",
|
||||||
|
OpenAiChatCompletionResponseEntity::fromResponse
|
||||||
|
);
|
||||||
private final Sender sender;
|
private final Sender sender;
|
||||||
private final ServiceComponents serviceComponents;
|
private final ServiceComponents serviceComponents;
|
||||||
|
|
||||||
|
@ -46,11 +63,7 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
|
||||||
serviceComponents.truncator(),
|
serviceComponents.truncator(),
|
||||||
serviceComponents.threadPool()
|
serviceComponents.threadPool()
|
||||||
);
|
);
|
||||||
var errorMessage = format(
|
var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId());
|
||||||
"Failed to send Hugging Face %s request from inference entity id [%s]",
|
|
||||||
"text embeddings",
|
|
||||||
model.getInferenceEntityId()
|
|
||||||
);
|
|
||||||
return new SenderExecutableAction(sender, requestCreator, errorMessage);
|
return new SenderExecutableAction(sender, requestCreator, errorMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,11 +76,25 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
|
||||||
serviceComponents.truncator(),
|
serviceComponents.truncator(),
|
||||||
serviceComponents.threadPool()
|
serviceComponents.threadPool()
|
||||||
);
|
);
|
||||||
var errorMessage = format(
|
var errorMessage = buildErrorMessage(TaskType.SPARSE_EMBEDDING, model.getInferenceEntityId());
|
||||||
"Failed to send Hugging Face %s request from inference entity id [%s]",
|
|
||||||
"ELSER",
|
|
||||||
model.getInferenceEntityId()
|
|
||||||
);
|
|
||||||
return new SenderExecutableAction(sender, requestCreator, errorMessage);
|
return new SenderExecutableAction(sender, requestCreator, errorMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ExecutableAction create(HuggingFaceChatCompletionModel model) {
|
||||||
|
var manager = new GenericRequestManager<>(
|
||||||
|
serviceComponents.threadPool(),
|
||||||
|
model,
|
||||||
|
COMPLETION_HANDLER,
|
||||||
|
inputs -> new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
|
||||||
|
ChatCompletionInput.class
|
||||||
|
);
|
||||||
|
|
||||||
|
var errorMessage = buildErrorMessage(TaskType.COMPLETION, model.getInferenceEntityId());
|
||||||
|
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static String buildErrorMessage(TaskType requestType, String inferenceId) {
|
||||||
|
return format("Failed to send Hugging Face %s request from inference entity id [%s]", requestType.toString(), inferenceId);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
package org.elasticsearch.xpack.inference.services.huggingface.action;
|
package org.elasticsearch.xpack.inference.services.huggingface.action;
|
||||||
|
|
||||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
||||||
|
|
||||||
|
@ -15,4 +16,6 @@ public interface HuggingFaceActionVisitor {
|
||||||
ExecutableAction create(HuggingFaceEmbeddingsModel model);
|
ExecutableAction create(HuggingFaceEmbeddingsModel model);
|
||||||
|
|
||||||
ExecutableAction create(HuggingFaceElserModel model);
|
ExecutableAction create(HuggingFaceElserModel model);
|
||||||
|
|
||||||
|
ExecutableAction create(HuggingFaceChatCompletionModel model);
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,102 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.completion;
|
||||||
|
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.ModelSecrets;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||||
|
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
|
||||||
|
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class HuggingFaceChatCompletionModel extends HuggingFaceModel {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new {@link HuggingFaceChatCompletionModel} by copying properties from an existing model,
|
||||||
|
* replacing the {@code modelId} in the service settings with the one from the given {@link UnifiedCompletionRequest},
|
||||||
|
* if present. If the request does not specify a model ID, the original value is retained.
|
||||||
|
*
|
||||||
|
* @param model the original model to copy from
|
||||||
|
* @param request the request potentially containing an overridden model ID
|
||||||
|
* @return a new {@link HuggingFaceChatCompletionModel} with updated service settings
|
||||||
|
*/
|
||||||
|
public static HuggingFaceChatCompletionModel of(HuggingFaceChatCompletionModel model, UnifiedCompletionRequest request) {
|
||||||
|
var originalModelServiceSettings = model.getServiceSettings();
|
||||||
|
var overriddenServiceSettings = new HuggingFaceChatCompletionServiceSettings(
|
||||||
|
request.model() != null ? request.model() : originalModelServiceSettings.modelId(),
|
||||||
|
originalModelServiceSettings.uri(),
|
||||||
|
originalModelServiceSettings.rateLimitSettings()
|
||||||
|
);
|
||||||
|
|
||||||
|
return new HuggingFaceChatCompletionModel(
|
||||||
|
model.getInferenceEntityId(),
|
||||||
|
model.getTaskType(),
|
||||||
|
model.getConfigurations().getService(),
|
||||||
|
overriddenServiceSettings,
|
||||||
|
model.getSecretSettings()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public HuggingFaceChatCompletionModel(
|
||||||
|
String inferenceEntityId,
|
||||||
|
TaskType taskType,
|
||||||
|
String service,
|
||||||
|
Map<String, Object> serviceSettings,
|
||||||
|
@Nullable Map<String, Object> secrets,
|
||||||
|
ConfigurationParseContext context
|
||||||
|
) {
|
||||||
|
this(
|
||||||
|
inferenceEntityId,
|
||||||
|
taskType,
|
||||||
|
service,
|
||||||
|
HuggingFaceChatCompletionServiceSettings.fromMap(serviceSettings, context),
|
||||||
|
DefaultSecretSettings.fromMap(secrets)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
HuggingFaceChatCompletionModel(
|
||||||
|
String inferenceEntityId,
|
||||||
|
TaskType taskType,
|
||||||
|
String service,
|
||||||
|
HuggingFaceChatCompletionServiceSettings serviceSettings,
|
||||||
|
@Nullable DefaultSecretSettings secretSettings
|
||||||
|
) {
|
||||||
|
super(
|
||||||
|
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings),
|
||||||
|
new ModelSecrets(secretSettings),
|
||||||
|
serviceSettings,
|
||||||
|
secretSettings
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public HuggingFaceChatCompletionServiceSettings getServiceSettings() {
|
||||||
|
return (HuggingFaceChatCompletionServiceSettings) super.getServiceSettings();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DefaultSecretSettings getSecretSettings() {
|
||||||
|
return (DefaultSecretSettings) super.getSecretSettings();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ExecutableAction accept(HuggingFaceActionVisitor creator) {
|
||||||
|
return creator.create(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Integer getTokenLimit() {
|
||||||
|
throw new UnsupportedOperationException("Token Limit for chat completion is sent in request and not retrieved from the model");
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,172 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.completion;
|
||||||
|
|
||||||
|
import org.elasticsearch.TransportVersion;
|
||||||
|
import org.elasticsearch.TransportVersions;
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.ServiceSettings;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
|
||||||
|
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
|
||||||
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Settings for the Hugging Face chat completion service.
|
||||||
|
* <p>
|
||||||
|
* This class contains the settings required to configure a Hugging Face chat completion service, including the model ID, URL, maximum input
|
||||||
|
* tokens, and rate limit settings.
|
||||||
|
* </p>
|
||||||
|
*/
|
||||||
|
public class HuggingFaceChatCompletionServiceSettings extends FilteredXContentObject
|
||||||
|
implements
|
||||||
|
ServiceSettings,
|
||||||
|
HuggingFaceRateLimitServiceSettings {
|
||||||
|
|
||||||
|
public static final String NAME = "hugging_face_completion_service_settings";
|
||||||
|
// At the time of writing HuggingFace hasn't posted the default rate limit for inference endpoints so the value his is only a guess
|
||||||
|
// 3000 requests per minute
|
||||||
|
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new instance of {@link HuggingFaceChatCompletionServiceSettings} from a map of settings.
|
||||||
|
* @param map the map of settings
|
||||||
|
* @param context the context for parsing the settings
|
||||||
|
* @return a new instance of {@link HuggingFaceChatCompletionServiceSettings}
|
||||||
|
*/
|
||||||
|
public static HuggingFaceChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
|
||||||
|
ValidationException validationException = new ValidationException();
|
||||||
|
|
||||||
|
String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||||
|
|
||||||
|
var uri = extractUri(map, URL, validationException);
|
||||||
|
|
||||||
|
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
|
||||||
|
map,
|
||||||
|
DEFAULT_RATE_LIMIT_SETTINGS,
|
||||||
|
validationException,
|
||||||
|
HuggingFaceService.NAME,
|
||||||
|
context
|
||||||
|
);
|
||||||
|
|
||||||
|
if (validationException.validationErrors().isEmpty() == false) {
|
||||||
|
throw validationException;
|
||||||
|
}
|
||||||
|
return new HuggingFaceChatCompletionServiceSettings(modelId, uri, rateLimitSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final String modelId;
|
||||||
|
private final URI uri;
|
||||||
|
private final RateLimitSettings rateLimitSettings;
|
||||||
|
|
||||||
|
public HuggingFaceChatCompletionServiceSettings(@Nullable String modelId, String url, @Nullable RateLimitSettings rateLimitSettings) {
|
||||||
|
this(modelId, createUri(url), rateLimitSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
public HuggingFaceChatCompletionServiceSettings(@Nullable String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) {
|
||||||
|
this.modelId = modelId;
|
||||||
|
this.uri = uri;
|
||||||
|
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new instance of {@link HuggingFaceChatCompletionServiceSettings} from a stream input.
|
||||||
|
* @param in the stream input
|
||||||
|
* @throws IOException if an I/O error occurs
|
||||||
|
*/
|
||||||
|
public HuggingFaceChatCompletionServiceSettings(StreamInput in) throws IOException {
|
||||||
|
this.modelId = in.readOptionalString();
|
||||||
|
this.uri = createUri(in.readString());
|
||||||
|
this.rateLimitSettings = new RateLimitSettings(in);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RateLimitSettings rateLimitSettings() {
|
||||||
|
return rateLimitSettings;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public URI uri() {
|
||||||
|
return uri;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String modelId() {
|
||||||
|
return modelId;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
toXContentFragmentOfExposedFields(builder, params);
|
||||||
|
builder.endObject();
|
||||||
|
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
if (modelId != null) {
|
||||||
|
builder.field(MODEL_ID, modelId);
|
||||||
|
}
|
||||||
|
builder.field(URL, uri.toString());
|
||||||
|
rateLimitSettings.toXContent(builder, params);
|
||||||
|
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getWriteableName() {
|
||||||
|
return NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TransportVersion getMinimalSupportedVersion() {
|
||||||
|
return TransportVersions.ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
|
out.writeOptionalString(modelId);
|
||||||
|
out.writeString(uri.toString());
|
||||||
|
rateLimitSettings.writeTo(out);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object object) {
|
||||||
|
if (this == object) return true;
|
||||||
|
if (object == null || getClass() != object.getClass()) return false;
|
||||||
|
HuggingFaceChatCompletionServiceSettings that = (HuggingFaceChatCompletionServiceSettings) object;
|
||||||
|
return Objects.equals(modelId, that.modelId)
|
||||||
|
&& Objects.equals(uri, that.uri)
|
||||||
|
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(modelId, uri, rateLimitSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,88 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.request.completion;
|
||||||
|
|
||||||
|
import org.apache.http.HttpHeaders;
|
||||||
|
import org.apache.http.client.methods.HttpPost;
|
||||||
|
import org.apache.http.entity.ByteArrayEntity;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||||
|
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceAccount;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||||
|
|
||||||
|
import java.net.URI;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This class is responsible for creating Hugging Face chat completions HTTP requests.
|
||||||
|
* It handles the preparation of the HTTP request with the necessary headers and body.
|
||||||
|
*/
|
||||||
|
public class HuggingFaceUnifiedChatCompletionRequest implements Request {
|
||||||
|
|
||||||
|
private final HuggingFaceAccount account;
|
||||||
|
private final HuggingFaceChatCompletionModel model;
|
||||||
|
private final UnifiedChatInput unifiedChatInput;
|
||||||
|
|
||||||
|
public HuggingFaceUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatInput, HuggingFaceChatCompletionModel model) {
|
||||||
|
this.account = HuggingFaceAccount.of(model);
|
||||||
|
this.model = Objects.requireNonNull(model);
|
||||||
|
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an HTTP request to the Hugging Face API for chat completions.
|
||||||
|
* The request includes the necessary headers and the input data as a JSON entity.
|
||||||
|
*
|
||||||
|
* @return an HttpRequest object containing the HTTP POST request
|
||||||
|
*/
|
||||||
|
public HttpRequest createHttpRequest() {
|
||||||
|
HttpPost httpPost = new HttpPost(getURI());
|
||||||
|
|
||||||
|
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
||||||
|
Strings.toString(new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, model)).getBytes(StandardCharsets.UTF_8)
|
||||||
|
);
|
||||||
|
httpPost.setEntity(byteEntity);
|
||||||
|
|
||||||
|
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
|
||||||
|
httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
|
||||||
|
|
||||||
|
return new HttpRequest(httpPost, getInferenceEntityId());
|
||||||
|
}
|
||||||
|
|
||||||
|
public URI getURI() {
|
||||||
|
return account.uri();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getInferenceEntityId() {
|
||||||
|
return model.getInferenceEntityId();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Request truncate() {
|
||||||
|
// Truncation is not applicable for chat completion requests
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean[] getTruncationInfo() {
|
||||||
|
// Truncation is not applicable for chat completion requests
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isStreaming() {
|
||||||
|
return unifiedChatInput.stream();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.request.completion;
|
||||||
|
|
||||||
|
import org.elasticsearch.xcontent.ToXContentObject;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||||
|
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
public class HuggingFaceUnifiedChatCompletionRequestEntity implements ToXContentObject {
|
||||||
|
|
||||||
|
private static final String MODEL_FIELD = "model";
|
||||||
|
private static final String MAX_TOKENS_FIELD = "max_tokens";
|
||||||
|
|
||||||
|
private final UnifiedChatInput unifiedChatInput;
|
||||||
|
private final HuggingFaceChatCompletionModel model;
|
||||||
|
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
|
||||||
|
|
||||||
|
public HuggingFaceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, HuggingFaceChatCompletionModel model) {
|
||||||
|
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
|
||||||
|
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
|
||||||
|
this.model = Objects.requireNonNull(model);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
unifiedRequestEntity.toXContent(builder, params);
|
||||||
|
|
||||||
|
if (model.getServiceSettings().modelId() != null) {
|
||||||
|
builder.field(MODEL_FIELD, model.getServiceSettings().modelId());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
|
||||||
|
builder.field(MAX_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.endObject();
|
||||||
|
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
}
|
|
@ -5,7 +5,7 @@
|
||||||
* 2.0.
|
* 2.0.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.services.huggingface.request;
|
package org.elasticsearch.xpack.inference.services.huggingface.request.embeddings;
|
||||||
|
|
||||||
import org.apache.http.HttpHeaders;
|
import org.apache.http.HttpHeaders;
|
||||||
import org.apache.http.client.methods.HttpPost;
|
import org.apache.http.client.methods.HttpPost;
|
||||||
|
@ -24,25 +24,35 @@ import java.util.Objects;
|
||||||
|
|
||||||
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
|
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
|
||||||
|
|
||||||
public class HuggingFaceInferenceRequest implements Request {
|
/**
|
||||||
|
* This class is responsible for creating Hugging Face embeddings HTTP requests.
|
||||||
|
* It handles the truncation of input data and prepares the HTTP request with the necessary headers and body.
|
||||||
|
*/
|
||||||
|
public class HuggingFaceEmbeddingsRequest implements Request {
|
||||||
|
|
||||||
private final Truncator truncator;
|
private final Truncator truncator;
|
||||||
private final HuggingFaceAccount account;
|
private final HuggingFaceAccount account;
|
||||||
private final Truncator.TruncationResult truncationResult;
|
private final Truncator.TruncationResult truncationResult;
|
||||||
private final HuggingFaceModel model;
|
private final HuggingFaceModel model;
|
||||||
|
|
||||||
public HuggingFaceInferenceRequest(Truncator truncator, Truncator.TruncationResult input, HuggingFaceModel model) {
|
public HuggingFaceEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, HuggingFaceModel model) {
|
||||||
this.truncator = Objects.requireNonNull(truncator);
|
this.truncator = Objects.requireNonNull(truncator);
|
||||||
this.account = HuggingFaceAccount.of(model);
|
this.account = HuggingFaceAccount.of(model);
|
||||||
this.truncationResult = Objects.requireNonNull(input);
|
this.truncationResult = Objects.requireNonNull(input);
|
||||||
this.model = Objects.requireNonNull(model);
|
this.model = Objects.requireNonNull(model);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an HTTP request to the Hugging Face API for embeddings.
|
||||||
|
* The request includes the necessary headers and the input data as a JSON entity.
|
||||||
|
*
|
||||||
|
* @return an HttpRequest object containing the HTTP POST request
|
||||||
|
*/
|
||||||
public HttpRequest createHttpRequest() {
|
public HttpRequest createHttpRequest() {
|
||||||
HttpPost httpPost = new HttpPost(account.uri());
|
HttpPost httpPost = new HttpPost(account.uri());
|
||||||
|
|
||||||
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
||||||
Strings.toString(new HuggingFaceInferenceRequestEntity(truncationResult.input())).getBytes(StandardCharsets.UTF_8)
|
Strings.toString(new HuggingFaceEmbeddingsRequestEntity(truncationResult.input())).getBytes(StandardCharsets.UTF_8)
|
||||||
);
|
);
|
||||||
httpPost.setEntity(byteEntity);
|
httpPost.setEntity(byteEntity);
|
||||||
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
|
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
|
||||||
|
@ -64,7 +74,7 @@ public class HuggingFaceInferenceRequest implements Request {
|
||||||
public Request truncate() {
|
public Request truncate() {
|
||||||
var truncateResult = truncator.truncate(truncationResult.input());
|
var truncateResult = truncator.truncate(truncationResult.input());
|
||||||
|
|
||||||
return new HuggingFaceInferenceRequest(truncator, truncateResult, model);
|
return new HuggingFaceEmbeddingsRequest(truncator, truncateResult, model);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
|
@ -5,7 +5,7 @@
|
||||||
* 2.0.
|
* 2.0.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.services.huggingface.request;
|
package org.elasticsearch.xpack.inference.services.huggingface.request.embeddings;
|
||||||
|
|
||||||
import org.elasticsearch.xcontent.ToXContentObject;
|
import org.elasticsearch.xcontent.ToXContentObject;
|
||||||
import org.elasticsearch.xcontent.XContentBuilder;
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
@ -14,11 +14,15 @@ import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
public record HuggingFaceInferenceRequestEntity(List<String> inputs) implements ToXContentObject {
|
/**
|
||||||
|
* This class represents the request entity for Hugging Face embeddings.
|
||||||
|
* It contains a list of input strings that will be used to generate embeddings.
|
||||||
|
*/
|
||||||
|
public record HuggingFaceEmbeddingsRequestEntity(List<String> inputs) implements ToXContentObject {
|
||||||
|
|
||||||
private static final String INPUTS_FIELD = "inputs";
|
private static final String INPUTS_FIELD = "inputs";
|
||||||
|
|
||||||
public HuggingFaceInferenceRequestEntity {
|
public HuggingFaceEmbeddingsRequestEntity {
|
||||||
Objects.requireNonNull(inputs);
|
Objects.requireNonNull(inputs);
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,9 @@ public class HuggingFaceErrorResponseEntity extends ErrorResponse {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
* Represents a structured error response specifically for non-streaming operations
|
||||||
|
* using HuggingFace APIs. This is separate from streaming error responses,
|
||||||
|
* which are handled by private nested HuggingFaceChatCompletionResponseHandler.StreamingHuggingFaceErrorResponseEntity.
|
||||||
* An example error response for invalid auth would look like
|
* An example error response for invalid auth would look like
|
||||||
* <code>
|
* <code>
|
||||||
* {
|
* {
|
||||||
|
|
|
@ -29,6 +29,7 @@ import java.util.Locale;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.concurrent.Flow;
|
import java.util.concurrent.Flow;
|
||||||
|
import java.util.function.Function;
|
||||||
|
|
||||||
import static org.elasticsearch.core.Strings.format;
|
import static org.elasticsearch.core.Strings.format;
|
||||||
|
|
||||||
|
@ -37,6 +38,14 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
|
||||||
super(requestType, parseFunction, OpenAiErrorResponse::fromResponse);
|
super(requestType, parseFunction, OpenAiErrorResponse::fromResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public OpenAiUnifiedChatCompletionResponseHandler(
|
||||||
|
String requestType,
|
||||||
|
ResponseParser parseFunction,
|
||||||
|
Function<HttpResult, ErrorResponse> errorParseFunction
|
||||||
|
) {
|
||||||
|
super(requestType, parseFunction, errorParseFunction);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
|
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
|
||||||
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
|
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
|
||||||
|
@ -59,7 +68,7 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
|
||||||
: new UnifiedChatCompletionException(
|
: new UnifiedChatCompletionException(
|
||||||
restStatus,
|
restStatus,
|
||||||
errorMessage,
|
errorMessage,
|
||||||
errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown",
|
createErrorType(errorResponse),
|
||||||
restStatus.name().toLowerCase(Locale.ROOT)
|
restStatus.name().toLowerCase(Locale.ROOT)
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
|
@ -67,7 +76,11 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Exception buildMidStreamError(Request request, String message, Exception e) {
|
protected static String createErrorType(ErrorResponse errorResponse) {
|
||||||
|
return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown";
|
||||||
|
}
|
||||||
|
|
||||||
|
protected Exception buildMidStreamError(Request request, String message, Exception e) {
|
||||||
var errorResponse = OpenAiErrorResponse.fromString(message);
|
var errorResponse = OpenAiErrorResponse.fromString(message);
|
||||||
if (errorResponse instanceof OpenAiErrorResponse oer) {
|
if (errorResponse instanceof OpenAiErrorResponse oer) {
|
||||||
return new UnifiedChatCompletionException(
|
return new UnifiedChatCompletionException(
|
||||||
|
@ -88,7 +101,7 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
|
||||||
return new UnifiedChatCompletionException(
|
return new UnifiedChatCompletionException(
|
||||||
RestStatus.INTERNAL_SERVER_ERROR,
|
RestStatus.INTERNAL_SERVER_ERROR,
|
||||||
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
|
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
|
||||||
errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown",
|
createErrorType(errorResponse),
|
||||||
"stream_error"
|
"stream_error"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -250,7 +250,7 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ARGUMENTS_FIELD));
|
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ARGUMENTS_FIELD));
|
||||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(NAME_FIELD));
|
PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(NAME_FIELD));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function parse(
|
public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function parse(
|
||||||
|
|
|
@ -0,0 +1,131 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface;
|
||||||
|
|
||||||
|
import org.apache.http.HttpResponse;
|
||||||
|
import org.apache.http.StatusLine;
|
||||||
|
import org.elasticsearch.common.bytes.BytesReference;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xcontent.XContentFactory;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
|
||||||
|
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
|
||||||
|
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
|
||||||
|
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
import static org.hamcrest.Matchers.isA;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
public class HuggingFaceChatCompletionResponseHandlerTests extends ESTestCase {
|
||||||
|
private final HuggingFaceChatCompletionResponseHandler responseHandler = new HuggingFaceChatCompletionResponseHandler(
|
||||||
|
"chat completions",
|
||||||
|
(a, b) -> mock()
|
||||||
|
);
|
||||||
|
|
||||||
|
public void testFailValidationWithAllFields() throws IOException {
|
||||||
|
var responseJson = """
|
||||||
|
{
|
||||||
|
"error": "a message",
|
||||||
|
"type": "validation"
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
var errorJson = invalidResponseJson(responseJson);
|
||||||
|
|
||||||
|
assertThat(errorJson, is("""
|
||||||
|
{"error":{"code":"bad_request","message":"Received a server error status code for request from \
|
||||||
|
inference entity id [id] status [500]. \
|
||||||
|
Error message: [a message]",\
|
||||||
|
"type":"hugging_face_error"}}"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFailValidationWithoutOptionalFields() throws IOException {
|
||||||
|
var responseJson = """
|
||||||
|
{
|
||||||
|
"error": "a message"
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
var errorJson = invalidResponseJson(responseJson);
|
||||||
|
|
||||||
|
assertThat(errorJson, is("""
|
||||||
|
{"error":{"code":"bad_request","message":"Received a server error status code for request from \
|
||||||
|
inference entity id [id] status [500]. \
|
||||||
|
Error message: [a message]","type":"hugging_face_error"}}"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFailValidationWithInvalidJson() throws IOException {
|
||||||
|
var responseJson = """
|
||||||
|
what? this isn't a json
|
||||||
|
""";
|
||||||
|
|
||||||
|
var errorJson = invalidResponseJson(responseJson);
|
||||||
|
|
||||||
|
assertThat(errorJson, is("""
|
||||||
|
{"error":{"code":"bad_request","message":"Received a server error status code for request from inference entity id [id] status\
|
||||||
|
[500]","type":"ErrorResponse"}}"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
private String invalidResponseJson(String responseJson) throws IOException {
|
||||||
|
var exception = invalidResponse(responseJson);
|
||||||
|
assertThat(exception, isA(RetryException.class));
|
||||||
|
assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class));
|
||||||
|
return toJson((UnifiedChatCompletionException) unwrapCause(exception));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Exception invalidResponse(String responseJson) {
|
||||||
|
return expectThrows(
|
||||||
|
RetryException.class,
|
||||||
|
() -> responseHandler.validateResponse(
|
||||||
|
mock(),
|
||||||
|
mock(),
|
||||||
|
mockRequest(),
|
||||||
|
new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8)),
|
||||||
|
true
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Request mockRequest() {
|
||||||
|
var request = mock(Request.class);
|
||||||
|
when(request.getInferenceEntityId()).thenReturn("id");
|
||||||
|
when(request.isStreaming()).thenReturn(true);
|
||||||
|
return request;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static HttpResponse mock500Response() {
|
||||||
|
int statusCode = 500;
|
||||||
|
var statusLine = mock(StatusLine.class);
|
||||||
|
when(statusLine.getStatusCode()).thenReturn(statusCode);
|
||||||
|
|
||||||
|
var response = mock(HttpResponse.class);
|
||||||
|
when(response.getStatusLine()).thenReturn(statusLine);
|
||||||
|
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String toJson(UnifiedChatCompletionException e) throws IOException {
|
||||||
|
try (var builder = XContentFactory.jsonBuilder()) {
|
||||||
|
e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
|
||||||
|
try {
|
||||||
|
xContent.toXContent(builder, EMPTY_PARAMS);
|
||||||
|
} catch (IOException ex) {
|
||||||
|
throw new RuntimeException(ex);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -12,6 +12,7 @@ package org.elasticsearch.xpack.inference.services.huggingface;
|
||||||
import org.apache.http.HttpHeaders;
|
import org.apache.http.HttpHeaders;
|
||||||
import org.elasticsearch.ElasticsearchStatusException;
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
import org.elasticsearch.action.ActionListener;
|
import org.elasticsearch.action.ActionListener;
|
||||||
|
import org.elasticsearch.action.support.ActionTestUtils;
|
||||||
import org.elasticsearch.action.support.PlainActionFuture;
|
import org.elasticsearch.action.support.PlainActionFuture;
|
||||||
import org.elasticsearch.common.ValidationException;
|
import org.elasticsearch.common.ValidationException;
|
||||||
import org.elasticsearch.common.bytes.BytesArray;
|
import org.elasticsearch.common.bytes.BytesArray;
|
||||||
|
@ -29,20 +30,29 @@ import org.elasticsearch.inference.Model;
|
||||||
import org.elasticsearch.inference.ModelConfigurations;
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
import org.elasticsearch.inference.SimilarityMeasure;
|
import org.elasticsearch.inference.SimilarityMeasure;
|
||||||
import org.elasticsearch.inference.TaskType;
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||||
|
import org.elasticsearch.rest.RestStatus;
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
import org.elasticsearch.test.http.MockResponse;
|
import org.elasticsearch.test.http.MockResponse;
|
||||||
import org.elasticsearch.test.http.MockWebServer;
|
import org.elasticsearch.test.http.MockWebServer;
|
||||||
import org.elasticsearch.threadpool.ThreadPool;
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
import org.elasticsearch.xcontent.ToXContent;
|
import org.elasticsearch.xcontent.ToXContent;
|
||||||
|
import org.elasticsearch.xcontent.XContentFactory;
|
||||||
import org.elasticsearch.xcontent.XContentType;
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||||
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
|
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
|
||||||
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
|
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
|
||||||
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
||||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||||
|
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettingsTests;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
|
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
||||||
|
@ -53,14 +63,19 @@ import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.EnumSet;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
|
||||||
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
|
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
|
||||||
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
|
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
|
||||||
|
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
|
||||||
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
|
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
|
||||||
|
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
|
||||||
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
|
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
|
||||||
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||||
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
||||||
|
@ -74,7 +89,12 @@ import static org.hamcrest.CoreMatchers.is;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.hamcrest.Matchers.hasSize;
|
import static org.hamcrest.Matchers.hasSize;
|
||||||
import static org.hamcrest.Matchers.instanceOf;
|
import static org.hamcrest.Matchers.instanceOf;
|
||||||
|
import static org.hamcrest.Matchers.isA;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.times;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
public class HuggingFaceServiceTests extends ESTestCase {
|
public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
|
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
|
||||||
|
@ -175,6 +195,438 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testParseRequestConfig_CreatesHuggingFaceChatCompletionsModel() throws IOException {
|
||||||
|
var url = "url";
|
||||||
|
var model = "model";
|
||||||
|
var secret = "secret";
|
||||||
|
|
||||||
|
try (var service = createHuggingFaceService()) {
|
||||||
|
ActionListener<Model> modelVerificationListener = ActionListener.wrap(m -> {
|
||||||
|
assertThat(m, instanceOf(HuggingFaceChatCompletionModel.class));
|
||||||
|
|
||||||
|
var completionsModel = (HuggingFaceChatCompletionModel) m;
|
||||||
|
|
||||||
|
assertThat(completionsModel.getServiceSettings().uri().toString(), is(url));
|
||||||
|
assertThat(completionsModel.getServiceSettings().modelId(), is(model));
|
||||||
|
assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret));
|
||||||
|
|
||||||
|
}, exception -> fail("Unexpected exception: " + exception));
|
||||||
|
|
||||||
|
service.parseRequestConfig(
|
||||||
|
"id",
|
||||||
|
TaskType.COMPLETION,
|
||||||
|
getRequestConfigMap(
|
||||||
|
HuggingFaceChatCompletionServiceSettingsTests.getServiceSettingsMap(url, model),
|
||||||
|
getSecretSettingsMap(secret)
|
||||||
|
),
|
||||||
|
modelVerificationListener
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testParseRequestConfig_CreatesHuggingFaceChatCompletionsModel_WithoutModelId() throws IOException {
|
||||||
|
var url = "url";
|
||||||
|
var secret = "secret";
|
||||||
|
|
||||||
|
try (var service = createHuggingFaceService()) {
|
||||||
|
ActionListener<Model> modelVerificationListener = ActionListener.wrap(m -> {
|
||||||
|
assertThat(m, instanceOf(HuggingFaceChatCompletionModel.class));
|
||||||
|
|
||||||
|
var completionsModel = (HuggingFaceChatCompletionModel) m;
|
||||||
|
|
||||||
|
assertThat(completionsModel.getServiceSettings().uri().toString(), is(url));
|
||||||
|
assertNull(completionsModel.getServiceSettings().modelId());
|
||||||
|
assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret));
|
||||||
|
|
||||||
|
}, exception -> fail("Unexpected exception: " + exception));
|
||||||
|
|
||||||
|
service.parseRequestConfig(
|
||||||
|
"id",
|
||||||
|
TaskType.COMPLETION,
|
||||||
|
getRequestConfigMap(getServiceSettingsMap(url), getSecretSettingsMap(secret)),
|
||||||
|
modelVerificationListener
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException {
|
||||||
|
var sender = mock(Sender.class);
|
||||||
|
|
||||||
|
var factory = mock(HttpRequestSender.Factory.class);
|
||||||
|
when(factory.createSender()).thenReturn(sender);
|
||||||
|
|
||||||
|
var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION);
|
||||||
|
|
||||||
|
try (var service = new HuggingFaceService(factory, createWithEmptySettings(threadPool))) {
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
service.infer(
|
||||||
|
mockModel,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
List.of(""),
|
||||||
|
false,
|
||||||
|
new HashMap<>(),
|
||||||
|
InputType.INGEST,
|
||||||
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||||
|
listener
|
||||||
|
);
|
||||||
|
|
||||||
|
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
|
||||||
|
assertThat(
|
||||||
|
thrownException.getMessage(),
|
||||||
|
is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
|
||||||
|
);
|
||||||
|
|
||||||
|
verify(factory, times(1)).createSender();
|
||||||
|
verify(sender, times(1)).start();
|
||||||
|
}
|
||||||
|
|
||||||
|
verify(sender, times(1)).close();
|
||||||
|
verifyNoMoreInteractions(factory);
|
||||||
|
verifyNoMoreInteractions(sender);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testUnifiedCompletionInfer() throws Exception {
|
||||||
|
// The escapes are because the streaming response must be on a single line
|
||||||
|
String responseJson = """
|
||||||
|
data: {\
|
||||||
|
"id":"12345",\
|
||||||
|
"object":"chat.completion.chunk",\
|
||||||
|
"created":123456789,\
|
||||||
|
"model":"gpt-4o-mini",\
|
||||||
|
"system_fingerprint": "123456789",\
|
||||||
|
"choices":[\
|
||||||
|
{\
|
||||||
|
"index":0,\
|
||||||
|
"delta":{\
|
||||||
|
"content":"hello, world"\
|
||||||
|
},\
|
||||||
|
"logprobs":null,\
|
||||||
|
"finish_reason":"stop"\
|
||||||
|
}\
|
||||||
|
],\
|
||||||
|
"usage":{\
|
||||||
|
"prompt_tokens": 16,\
|
||||||
|
"completion_tokens": 28,\
|
||||||
|
"total_tokens": 44,\
|
||||||
|
"prompt_tokens_details": {\
|
||||||
|
"cached_tokens": 0,\
|
||||||
|
"audio_tokens": 0\
|
||||||
|
},\
|
||||||
|
"completion_tokens_details": {\
|
||||||
|
"reasoning_tokens": 0,\
|
||||||
|
"audio_tokens": 0,\
|
||||||
|
"accepted_prediction_tokens": 0,\
|
||||||
|
"rejected_prediction_tokens": 0\
|
||||||
|
}\
|
||||||
|
}\
|
||||||
|
}
|
||||||
|
|
||||||
|
""";
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
|
||||||
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||||
|
var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
service.unifiedCompletionInfer(
|
||||||
|
model,
|
||||||
|
UnifiedCompletionRequest.of(
|
||||||
|
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
|
||||||
|
),
|
||||||
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||||
|
listener
|
||||||
|
);
|
||||||
|
|
||||||
|
var result = listener.actionGet(TIMEOUT);
|
||||||
|
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
|
||||||
|
{"id":"12345","choices":[{"delta":{"content":"hello, world"},"finish_reason":"stop","index":0}],""" + """
|
||||||
|
"model":"gpt-4o-mini","object":"chat.completion.chunk",""" + """
|
||||||
|
"usage":{"completion_tokens":28,"prompt_tokens":16,"total_tokens":44}}""");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testUnifiedCompletionNonStreamingError() throws Exception {
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"error": "Model not found."
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
|
||||||
|
|
||||||
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||||
|
var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
|
||||||
|
var latch = new CountDownLatch(1);
|
||||||
|
service.unifiedCompletionInfer(
|
||||||
|
model,
|
||||||
|
UnifiedCompletionRequest.of(
|
||||||
|
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
|
||||||
|
),
|
||||||
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||||
|
ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> {
|
||||||
|
try (var builder = XContentFactory.jsonBuilder()) {
|
||||||
|
var t = unwrapCause(e);
|
||||||
|
assertThat(t, isA(UnifiedChatCompletionException.class));
|
||||||
|
((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
|
||||||
|
try {
|
||||||
|
xContent.toXContent(builder, EMPTY_PARAMS);
|
||||||
|
} catch (IOException ex) {
|
||||||
|
throw new RuntimeException(ex);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
|
||||||
|
|
||||||
|
assertThat(json, is("""
|
||||||
|
{\
|
||||||
|
"error":{\
|
||||||
|
"code":"not_found",\
|
||||||
|
"message":"Received an unsuccessful status code for request from inference entity id [id] status \
|
||||||
|
[404]. Error message: [Model not found.]",\
|
||||||
|
"type":"hugging_face_error"\
|
||||||
|
}}"""));
|
||||||
|
} catch (IOException ex) {
|
||||||
|
throw new RuntimeException(ex);
|
||||||
|
}
|
||||||
|
}), latch::countDown)
|
||||||
|
);
|
||||||
|
assertTrue(latch.await(30, TimeUnit.SECONDS));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMidStreamUnifiedCompletionError() throws Exception {
|
||||||
|
String responseJson = """
|
||||||
|
event: error
|
||||||
|
data: {"error":{"message":"Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \
|
||||||
|
{\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported by Outlines.\\n\
|
||||||
|
If it should be supported, please open an issue.","http_status_code":422}}
|
||||||
|
|
||||||
|
""";
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
testStreamError("""
|
||||||
|
{\
|
||||||
|
"error":{\
|
||||||
|
"code":"422",\
|
||||||
|
"message":"Received an error response for request from inference entity id [id]. Error message: [Input validation error: \
|
||||||
|
cannot compile regex from schema: Unsupported JSON Schema structure {\\"id\\":\\"123\\"} \\nMake sure it is valid to the \
|
||||||
|
JSON Schema specification and check if it's supported by Outlines.\\nIf it should be supported, please open an issue.]",\
|
||||||
|
"type":"hugging_face_error"\
|
||||||
|
}}""");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMidStreamUnifiedCompletionErrorNoMessage() throws Exception {
|
||||||
|
String responseJson = """
|
||||||
|
event: error
|
||||||
|
data: {"error":{"http_status_code":422}}
|
||||||
|
|
||||||
|
""";
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
testStreamError("""
|
||||||
|
{\
|
||||||
|
"error":{\
|
||||||
|
"code":"422",\
|
||||||
|
"message":"Received an error response for request from inference entity id [id]. Error message: \
|
||||||
|
[unknown]",\
|
||||||
|
"type":"hugging_face_error"\
|
||||||
|
}}""");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMidStreamUnifiedCompletionErrorNoHttpStatusCode() throws Exception {
|
||||||
|
String responseJson = """
|
||||||
|
event: error
|
||||||
|
data: {"error":{"message":"Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \
|
||||||
|
{\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported by \
|
||||||
|
Outlines.\\nIf it should be supported, please open an issue."}}
|
||||||
|
|
||||||
|
""";
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
testStreamError("""
|
||||||
|
{\
|
||||||
|
"error":{\
|
||||||
|
"message":"Received an error response for request from inference entity id [id]. Error message: \
|
||||||
|
[Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \
|
||||||
|
{\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported\
|
||||||
|
by Outlines.\\nIf it should be supported, please open an issue.]",\
|
||||||
|
"type":"hugging_face_error"\
|
||||||
|
}}""");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMidStreamUnifiedCompletionErrorNoHttpStatusCodeNoMessage() throws Exception {
|
||||||
|
String responseJson = """
|
||||||
|
event: error
|
||||||
|
data: {"error":{}}
|
||||||
|
|
||||||
|
""";
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
testStreamError("""
|
||||||
|
{\
|
||||||
|
"error":{\
|
||||||
|
"message":"Received an error response for request from inference entity id [id]. Error message: \
|
||||||
|
[unknown]",\
|
||||||
|
"type":"hugging_face_error"\
|
||||||
|
}}""");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testUnifiedCompletionMalformedError() throws Exception {
|
||||||
|
String responseJson = """
|
||||||
|
data: { invalid json }
|
||||||
|
|
||||||
|
""";
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
testStreamError("""
|
||||||
|
{\
|
||||||
|
"error":{\
|
||||||
|
"code":"bad_request",\
|
||||||
|
"message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\
|
||||||
|
at [Source: (String)\\"{ invalid json }\\"; line: 1, column: 3]",\
|
||||||
|
"type":"x_content_parse_exception"\
|
||||||
|
}}""");
|
||||||
|
}
|
||||||
|
|
||||||
|
private void testStreamError(String expectedResponse) throws Exception {
|
||||||
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||||
|
var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
service.unifiedCompletionInfer(
|
||||||
|
model,
|
||||||
|
UnifiedCompletionRequest.of(
|
||||||
|
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
|
||||||
|
),
|
||||||
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||||
|
listener
|
||||||
|
);
|
||||||
|
|
||||||
|
var result = listener.actionGet(TIMEOUT);
|
||||||
|
|
||||||
|
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> {
|
||||||
|
e = unwrapCause(e);
|
||||||
|
assertThat(e, isA(UnifiedChatCompletionException.class));
|
||||||
|
try (var builder = XContentFactory.jsonBuilder()) {
|
||||||
|
((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
|
||||||
|
try {
|
||||||
|
xContent.toXContent(builder, EMPTY_PARAMS);
|
||||||
|
} catch (IOException ex) {
|
||||||
|
throw new RuntimeException(ex);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
|
||||||
|
|
||||||
|
assertThat(json, is(expectedResponse));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testInfer_StreamRequest() throws Exception {
|
||||||
|
String responseJson = """
|
||||||
|
data: {\
|
||||||
|
"id":"12345",\
|
||||||
|
"object":"chat.completion.chunk",\
|
||||||
|
"created":123456789,\
|
||||||
|
"model":"gpt-4o-mini",\
|
||||||
|
"system_fingerprint": "123456789",\
|
||||||
|
"choices":[\
|
||||||
|
{\
|
||||||
|
"index":0,\
|
||||||
|
"delta":{\
|
||||||
|
"content":"hello, world"\
|
||||||
|
},\
|
||||||
|
"logprobs":null,\
|
||||||
|
"finish_reason":null\
|
||||||
|
}\
|
||||||
|
]\
|
||||||
|
}
|
||||||
|
|
||||||
|
""";
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
|
||||||
|
streamCompletion().hasNoErrors().hasEvent("""
|
||||||
|
{"completion":[{"delta":"hello, world"}]}""");
|
||||||
|
}
|
||||||
|
|
||||||
|
private InferenceEventsAssertion streamCompletion() throws Exception {
|
||||||
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||||
|
var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model");
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
service.infer(
|
||||||
|
model,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
List.of("abc"),
|
||||||
|
true,
|
||||||
|
new HashMap<>(),
|
||||||
|
InputType.INGEST,
|
||||||
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||||
|
listener
|
||||||
|
);
|
||||||
|
|
||||||
|
return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testInfer_StreamRequest_ErrorResponse() throws Exception {
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"error": {
|
||||||
|
"message": "You didn't provide an API key..."
|
||||||
|
}
|
||||||
|
}""";
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));
|
||||||
|
|
||||||
|
var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion);
|
||||||
|
assertThat(e.status(), equalTo(RestStatus.UNAUTHORIZED));
|
||||||
|
assertThat(
|
||||||
|
e.getMessage(),
|
||||||
|
equalTo(
|
||||||
|
"Received an authentication error status code for request from inference entity id [id] status [401]. "
|
||||||
|
+ "Error message: [You didn't provide an API key...]"
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testInfer_StreamRequestRetry() throws Exception {
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(503).setBody("""
|
||||||
|
{
|
||||||
|
"error": {
|
||||||
|
"message": "server busy"
|
||||||
|
}
|
||||||
|
}"""));
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
|
||||||
|
data: {\
|
||||||
|
"id":"12345",\
|
||||||
|
"object":"chat.completion.chunk",\
|
||||||
|
"created":123456789,\
|
||||||
|
"model":"gpt-4o-mini",\
|
||||||
|
"system_fingerprint": "123456789",\
|
||||||
|
"choices":[\
|
||||||
|
{\
|
||||||
|
"index":0,\
|
||||||
|
"delta":{\
|
||||||
|
"content":"hello, world"\
|
||||||
|
},\
|
||||||
|
"logprobs":null,\
|
||||||
|
"finish_reason":null\
|
||||||
|
}\
|
||||||
|
]\
|
||||||
|
}
|
||||||
|
|
||||||
|
"""));
|
||||||
|
|
||||||
|
streamCompletion().hasNoErrors().hasEvent("""
|
||||||
|
{"completion":[{"delta":"hello, world"}]}""");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testSupportsStreaming() throws IOException {
|
||||||
|
try (var service = new HuggingFaceService(mock(), createWithEmptySettings(mock()))) {
|
||||||
|
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)));
|
||||||
|
assertFalse(service.canStream(TaskType.ANY));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
|
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
|
||||||
try (var service = createHuggingFaceService()) {
|
try (var service = createHuggingFaceService()) {
|
||||||
var config = getRequestConfigMap(getServiceSettingsMap("url"), getSecretSettingsMap("secret"));
|
var config = getRequestConfigMap(getServiceSettingsMap("url"), getSecretSettingsMap("secret"));
|
||||||
|
@ -258,6 +710,25 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws IOException {
|
||||||
|
try (var service = createHuggingFaceService()) {
|
||||||
|
var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), new HashMap<>(), getSecretSettingsMap("secret"));
|
||||||
|
|
||||||
|
var model = service.parsePersistedConfigWithSecrets(
|
||||||
|
"id",
|
||||||
|
TaskType.COMPLETION,
|
||||||
|
persistedConfig.config(),
|
||||||
|
persistedConfig.secrets()
|
||||||
|
);
|
||||||
|
|
||||||
|
assertThat(model, instanceOf(HuggingFaceChatCompletionModel.class));
|
||||||
|
|
||||||
|
var chatCompletionModel = (HuggingFaceChatCompletionModel) model;
|
||||||
|
assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is("url"));
|
||||||
|
assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is("secret"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
|
public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
|
||||||
try (var service = createHuggingFaceService()) {
|
try (var service = createHuggingFaceService()) {
|
||||||
var persistedConfig = getPersistedConfigMap(
|
var persistedConfig = getPersistedConfigMap(
|
||||||
|
@ -821,7 +1292,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
{
|
{
|
||||||
"service": "hugging_face",
|
"service": "hugging_face",
|
||||||
"name": "Hugging Face",
|
"name": "Hugging Face",
|
||||||
"task_types": ["text_embedding", "sparse_embedding"],
|
"task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"],
|
||||||
"configurations": {
|
"configurations": {
|
||||||
"api_key": {
|
"api_key": {
|
||||||
"description": "API Key for the provider you're connecting to.",
|
"description": "API Key for the provider you're connecting to.",
|
||||||
|
@ -830,7 +1301,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
"sensitive": true,
|
"sensitive": true,
|
||||||
"updatable": true,
|
"updatable": true,
|
||||||
"type": "str",
|
"type": "str",
|
||||||
"supported_task_types": ["text_embedding", "sparse_embedding"]
|
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
|
||||||
},
|
},
|
||||||
"rate_limit.requests_per_minute": {
|
"rate_limit.requests_per_minute": {
|
||||||
"description": "Minimize the number of rate limit errors.",
|
"description": "Minimize the number of rate limit errors.",
|
||||||
|
@ -839,17 +1310,16 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
"sensitive": false,
|
"sensitive": false,
|
||||||
"updatable": false,
|
"updatable": false,
|
||||||
"type": "int",
|
"type": "int",
|
||||||
"supported_task_types": ["text_embedding", "sparse_embedding"]
|
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
|
||||||
},
|
},
|
||||||
"url": {
|
"url": {
|
||||||
"default_value": "https://api.openai.com/v1/embeddings",
|
|
||||||
"description": "The URL endpoint to use for the requests.",
|
"description": "The URL endpoint to use for the requests.",
|
||||||
"label": "URL",
|
"label": "URL",
|
||||||
"required": true,
|
"required": true,
|
||||||
"sensitive": false,
|
"sensitive": false,
|
||||||
"updatable": false,
|
"updatable": false,
|
||||||
"type": "str",
|
"type": "str",
|
||||||
"supported_task_types": ["text_embedding", "sparse_embedding"]
|
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,10 +24,13 @@ import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsT
|
||||||
import org.elasticsearch.xpack.inference.InputTypeTests;
|
import org.elasticsearch.xpack.inference.InputTypeTests;
|
||||||
import org.elasticsearch.xpack.inference.common.TruncatorTests;
|
import org.elasticsearch.xpack.inference.common.TruncatorTests;
|
||||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
|
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;
|
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
|
@ -38,6 +41,7 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
|
||||||
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||||
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
||||||
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||||
|
@ -425,4 +429,107 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() throws IOException {
|
||||||
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
|
try (var sender = createSender(senderFactory)) {
|
||||||
|
sender.start();
|
||||||
|
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"object": "chat.completion",
|
||||||
|
"id": "",
|
||||||
|
"created": 1745855316,
|
||||||
|
"model": "/repository",
|
||||||
|
"system_fingerprint": "3.2.3-sha-a1f3ebe",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Hello there, how may I assist you today?"
|
||||||
|
},
|
||||||
|
"logprobs": null,
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 8,
|
||||||
|
"completion_tokens": 50,
|
||||||
|
"total_tokens": 58
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = createChatCompletionFuture(sender, createWithEmptySettings(threadPool));
|
||||||
|
|
||||||
|
var result = listener.actionGet(TIMEOUT);
|
||||||
|
|
||||||
|
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?"))));
|
||||||
|
|
||||||
|
assertChatCompletionRequest();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() throws IOException {
|
||||||
|
var settings = buildSettingsWithRetryFields(
|
||||||
|
TimeValue.timeValueMillis(1),
|
||||||
|
TimeValue.timeValueMinutes(1),
|
||||||
|
TimeValue.timeValueSeconds(0)
|
||||||
|
);
|
||||||
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
|
||||||
|
|
||||||
|
try (var sender = createSender(senderFactory)) {
|
||||||
|
sender.start();
|
||||||
|
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"invalid_field": "unexpected"
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = createChatCompletionFuture(
|
||||||
|
sender,
|
||||||
|
new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator())
|
||||||
|
);
|
||||||
|
|
||||||
|
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||||
|
assertThat(
|
||||||
|
thrownException.getMessage(),
|
||||||
|
is("Failed to send Hugging Face completion request from inference entity id " + "[id]. Cause: Required [choices]")
|
||||||
|
);
|
||||||
|
|
||||||
|
assertChatCompletionRequest();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private PlainActionFuture<InferenceServiceResults> createChatCompletionFuture(Sender sender, ServiceComponents threadPool) {
|
||||||
|
var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model");
|
||||||
|
var actionCreator = new HuggingFaceActionCreator(sender, threadPool);
|
||||||
|
var action = actionCreator.create(model);
|
||||||
|
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||||
|
return listener;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertChatCompletionRequest() throws IOException {
|
||||||
|
assertThat(webServer.requests(), hasSize(1));
|
||||||
|
assertNull(webServer.requests().get(0).getUri().getQuery());
|
||||||
|
assertThat(
|
||||||
|
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
|
||||||
|
equalTo(XContentType.JSON.mediaTypeWithoutParameters())
|
||||||
|
);
|
||||||
|
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
|
||||||
|
|
||||||
|
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
||||||
|
assertThat(requestMap.size(), is(4));
|
||||||
|
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello"))));
|
||||||
|
assertThat(requestMap.get("model"), is("model"));
|
||||||
|
assertThat(requestMap.get("n"), is(1));
|
||||||
|
assertThat(requestMap.get("stream"), is(false));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,243 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.action;
|
||||||
|
|
||||||
|
import org.apache.http.HttpHeaders;
|
||||||
|
import org.elasticsearch.ElasticsearchException;
|
||||||
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
|
import org.elasticsearch.action.ActionListener;
|
||||||
|
import org.elasticsearch.action.support.PlainActionFuture;
|
||||||
|
import org.elasticsearch.common.settings.Settings;
|
||||||
|
import org.elasticsearch.core.TimeValue;
|
||||||
|
import org.elasticsearch.inference.InferenceServiceResults;
|
||||||
|
import org.elasticsearch.rest.RestStatus;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.test.http.MockRequest;
|
||||||
|
import org.elasticsearch.test.http.MockResponse;
|
||||||
|
import org.elasticsearch.test.http.MockWebServer;
|
||||||
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||||
|
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||||
|
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||||
|
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
|
||||||
|
import org.junit.After;
|
||||||
|
import org.junit.Before;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
|
||||||
|
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||||
|
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
||||||
|
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
|
||||||
|
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||||
|
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
|
||||||
|
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator.COMPLETION_HANDLER;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator.USER_ROLE;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests.createCompletionModel;
|
||||||
|
import static org.hamcrest.Matchers.containsString;
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.hasSize;
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.Mockito.doAnswer;
|
||||||
|
import static org.mockito.Mockito.doThrow;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
|
||||||
|
public class HuggingFaceChatCompletionActionTests extends ESTestCase {
|
||||||
|
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
|
||||||
|
private final MockWebServer webServer = new MockWebServer();
|
||||||
|
private ThreadPool threadPool;
|
||||||
|
private HttpClientManager clientManager;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void init() throws Exception {
|
||||||
|
webServer.start();
|
||||||
|
threadPool = createThreadPool(inferenceUtilityPool());
|
||||||
|
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
@After
|
||||||
|
public void shutdown() throws IOException {
|
||||||
|
clientManager.close();
|
||||||
|
terminate(threadPool);
|
||||||
|
webServer.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testExecute_ReturnsSuccessfulResponse() throws IOException {
|
||||||
|
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
|
||||||
|
|
||||||
|
try (var sender = createSender(senderFactory)) {
|
||||||
|
sender.start();
|
||||||
|
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"id": "chatcmpl-123",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1677652288,
|
||||||
|
"model": "gpt-3.5-turbo-0125",
|
||||||
|
"system_fingerprint": "fp_44709d6fcb",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "result content"
|
||||||
|
},
|
||||||
|
"logprobs": null,
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 9,
|
||||||
|
"completion_tokens": 12,
|
||||||
|
"total_tokens": 21
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
|
||||||
|
var action = createAction(getUrl(webServer), sender);
|
||||||
|
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||||
|
|
||||||
|
var result = listener.actionGet(TIMEOUT);
|
||||||
|
|
||||||
|
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result content"))));
|
||||||
|
assertThat(webServer.requests(), hasSize(1));
|
||||||
|
|
||||||
|
MockRequest request = webServer.requests().get(0);
|
||||||
|
|
||||||
|
assertNull(request.getUri().getQuery());
|
||||||
|
assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()));
|
||||||
|
assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
|
||||||
|
|
||||||
|
var requestMap = entityAsMap(request.getBody());
|
||||||
|
assertThat(requestMap.size(), is(4));
|
||||||
|
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc"))));
|
||||||
|
assertThat(requestMap.get("model"), is("model"));
|
||||||
|
assertThat(requestMap.get("n"), is(1));
|
||||||
|
assertThat(requestMap.get("stream"), is(false));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOException {
|
||||||
|
try (var sender = mock(Sender.class)) {
|
||||||
|
var thrownException = expectThrows(IllegalArgumentException.class, () -> createAction("^^", sender));
|
||||||
|
assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testExecute_ThrowsElasticsearchException() {
|
||||||
|
var sender = mock(Sender.class);
|
||||||
|
doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
|
||||||
|
|
||||||
|
var action = createAction(getUrl(webServer), sender);
|
||||||
|
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||||
|
|
||||||
|
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||||
|
|
||||||
|
assertThat(thrownException.getMessage(), is("failed"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
|
||||||
|
var sender = mock(Sender.class);
|
||||||
|
|
||||||
|
doAnswer(invocation -> {
|
||||||
|
ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
|
||||||
|
listener.onFailure(new IllegalStateException("failed"));
|
||||||
|
|
||||||
|
return Void.TYPE;
|
||||||
|
}).when(sender).send(any(), any(), any(), any());
|
||||||
|
|
||||||
|
var action = createAction(getUrl(webServer), sender);
|
||||||
|
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||||
|
|
||||||
|
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||||
|
|
||||||
|
assertThat(thrownException.getMessage(), is("Failed to send hugging face chat completions request. Cause: failed"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
|
||||||
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
|
try (var sender = createSender(senderFactory)) {
|
||||||
|
sender.start();
|
||||||
|
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"id": "chatcmpl-123",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1677652288,
|
||||||
|
"model": "gpt-3.5-turbo-0613",
|
||||||
|
"system_fingerprint": "fp_44709d6fcb",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Hello there, how may I assist you today?"
|
||||||
|
},
|
||||||
|
"logprobs": null,
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 9,
|
||||||
|
"completion_tokens": 12,
|
||||||
|
"total_tokens": 21
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
|
||||||
|
var action = createAction(getUrl(webServer), sender);
|
||||||
|
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||||
|
|
||||||
|
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
|
||||||
|
|
||||||
|
assertThat(thrownException.getMessage(), is("hugging face chat completions only accepts 1 input"));
|
||||||
|
assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private ExecutableAction createAction(String url, Sender sender) {
|
||||||
|
var model = createCompletionModel(url, "secret", "model");
|
||||||
|
var manager = new GenericRequestManager<>(
|
||||||
|
threadPool,
|
||||||
|
model,
|
||||||
|
COMPLETION_HANDLER,
|
||||||
|
inputs -> new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
|
||||||
|
ChatCompletionInput.class
|
||||||
|
);
|
||||||
|
var errorMessage = constructFailedToSendRequestMessage("hugging face chat completions");
|
||||||
|
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "hugging face chat completions");
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,119 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.completion;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.settings.SecureString;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.containsString;
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
|
||||||
|
public class HuggingFaceChatCompletionModelTests extends ESTestCase {
|
||||||
|
|
||||||
|
public void testThrowsURISyntaxException_ForInvalidUrl() {
|
||||||
|
var thrownException = expectThrows(IllegalArgumentException.class, () -> createCompletionModel("^^", "secret", "id"));
|
||||||
|
assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static HuggingFaceChatCompletionModel createCompletionModel(String url, String apiKey, String modelId) {
|
||||||
|
return new HuggingFaceChatCompletionModel(
|
||||||
|
"id",
|
||||||
|
TaskType.COMPLETION,
|
||||||
|
"service",
|
||||||
|
new HuggingFaceChatCompletionServiceSettings(modelId, url, null),
|
||||||
|
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static HuggingFaceChatCompletionModel createChatCompletionModel(String url, String apiKey, String modelId) {
|
||||||
|
return new HuggingFaceChatCompletionModel(
|
||||||
|
"id",
|
||||||
|
TaskType.CHAT_COMPLETION,
|
||||||
|
"service",
|
||||||
|
new HuggingFaceChatCompletionServiceSettings(modelId, url, null),
|
||||||
|
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() {
|
||||||
|
var model = createCompletionModel("url", "api_key", "model_name");
|
||||||
|
var request = new UnifiedCompletionRequest(
|
||||||
|
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||||
|
"different_model",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null
|
||||||
|
);
|
||||||
|
|
||||||
|
var overriddenModel = HuggingFaceChatCompletionModel.of(model, request);
|
||||||
|
|
||||||
|
assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() {
|
||||||
|
var model = createCompletionModel("url", "api_key", null);
|
||||||
|
var request = new UnifiedCompletionRequest(
|
||||||
|
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||||
|
"different_model",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null
|
||||||
|
);
|
||||||
|
|
||||||
|
var overriddenModel = HuggingFaceChatCompletionModel.of(model, request);
|
||||||
|
|
||||||
|
assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() {
|
||||||
|
var model = createCompletionModel("url", "api_key", null);
|
||||||
|
var request = new UnifiedCompletionRequest(
|
||||||
|
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null
|
||||||
|
);
|
||||||
|
|
||||||
|
var overriddenModel = HuggingFaceChatCompletionModel.of(model, request);
|
||||||
|
|
||||||
|
assertNull(overriddenModel.getServiceSettings().modelId());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() {
|
||||||
|
var model = createCompletionModel("url", "api_key", "model_name");
|
||||||
|
var request = new UnifiedCompletionRequest(
|
||||||
|
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||||
|
null, // not overriding model
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null
|
||||||
|
);
|
||||||
|
|
||||||
|
var overriddenModel = HuggingFaceChatCompletionModel.of(model, request);
|
||||||
|
|
||||||
|
assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name"));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,271 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.completion;
|
||||||
|
|
||||||
|
import org.elasticsearch.TransportVersion;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xcontent.XContentFactory;
|
||||||
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
||||||
|
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||||
|
import org.elasticsearch.xpack.inference.services.ServiceFields;
|
||||||
|
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
||||||
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.containsString;
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
|
||||||
|
public class HuggingFaceChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase<
|
||||||
|
HuggingFaceChatCompletionServiceSettings> {
|
||||||
|
|
||||||
|
public static final String MODEL_ID = "some model";
|
||||||
|
public static final String CORRECT_URL = "https://www.elastic.co";
|
||||||
|
public static final int RATE_LIMIT = 2;
|
||||||
|
|
||||||
|
public void testFromMap_AllFields_Success() {
|
||||||
|
var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||||
|
new HashMap<>(
|
||||||
|
Map.of(
|
||||||
|
ServiceFields.MODEL_ID,
|
||||||
|
MODEL_ID,
|
||||||
|
ServiceFields.URL,
|
||||||
|
CORRECT_URL,
|
||||||
|
RateLimitSettings.FIELD_NAME,
|
||||||
|
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
|
||||||
|
)
|
||||||
|
),
|
||||||
|
ConfigurationParseContext.PERSISTENT
|
||||||
|
);
|
||||||
|
|
||||||
|
assertThat(
|
||||||
|
serviceSettings,
|
||||||
|
is(
|
||||||
|
new HuggingFaceChatCompletionServiceSettings(
|
||||||
|
MODEL_ID,
|
||||||
|
ServiceUtils.createUri(CORRECT_URL),
|
||||||
|
new RateLimitSettings(RATE_LIMIT)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromMap_MissingModelId_Success() {
|
||||||
|
var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||||
|
new HashMap<>(
|
||||||
|
Map.of(
|
||||||
|
ServiceFields.URL,
|
||||||
|
CORRECT_URL,
|
||||||
|
RateLimitSettings.FIELD_NAME,
|
||||||
|
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
|
||||||
|
)
|
||||||
|
),
|
||||||
|
ConfigurationParseContext.PERSISTENT
|
||||||
|
);
|
||||||
|
|
||||||
|
assertThat(
|
||||||
|
serviceSettings,
|
||||||
|
is(new HuggingFaceChatCompletionServiceSettings(null, ServiceUtils.createUri(CORRECT_URL), new RateLimitSettings(RATE_LIMIT)))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromMap_MissingRateLimit_Success() {
|
||||||
|
var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||||
|
new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)),
|
||||||
|
ConfigurationParseContext.PERSISTENT
|
||||||
|
);
|
||||||
|
|
||||||
|
assertThat(serviceSettings, is(new HuggingFaceChatCompletionServiceSettings(MODEL_ID, ServiceUtils.createUri(CORRECT_URL), null)));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromMap_MissingUrl_ThrowsException() {
|
||||||
|
var thrownException = expectThrows(
|
||||||
|
ValidationException.class,
|
||||||
|
() -> HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||||
|
new HashMap<>(
|
||||||
|
Map.of(
|
||||||
|
ServiceFields.MODEL_ID,
|
||||||
|
MODEL_ID,
|
||||||
|
RateLimitSettings.FIELD_NAME,
|
||||||
|
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
|
||||||
|
)
|
||||||
|
),
|
||||||
|
ConfigurationParseContext.PERSISTENT
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
assertThat(
|
||||||
|
thrownException.getMessage(),
|
||||||
|
containsString(
|
||||||
|
Strings.format("Validation Failed: 1: [service_settings] does not contain the required setting [url];", ServiceFields.URL)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromMap_EmptyUrl_ThrowsException() {
|
||||||
|
var thrownException = expectThrows(
|
||||||
|
ValidationException.class,
|
||||||
|
() -> HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||||
|
new HashMap<>(
|
||||||
|
Map.of(
|
||||||
|
ServiceFields.MODEL_ID,
|
||||||
|
MODEL_ID,
|
||||||
|
ServiceFields.URL,
|
||||||
|
"",
|
||||||
|
RateLimitSettings.FIELD_NAME,
|
||||||
|
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
|
||||||
|
)
|
||||||
|
),
|
||||||
|
ConfigurationParseContext.PERSISTENT
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
assertThat(
|
||||||
|
thrownException.getMessage(),
|
||||||
|
containsString(
|
||||||
|
Strings.format(
|
||||||
|
"Validation Failed: 1: [service_settings] Invalid value empty string. [%s] must be a non-empty string;",
|
||||||
|
ServiceFields.URL
|
||||||
|
)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromMap_InvalidUrl_ThrowsException() {
|
||||||
|
String invalidUrl = "https://www.elastic^^co";
|
||||||
|
var thrownException = expectThrows(
|
||||||
|
ValidationException.class,
|
||||||
|
() -> HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||||
|
new HashMap<>(
|
||||||
|
Map.of(
|
||||||
|
ServiceFields.MODEL_ID,
|
||||||
|
MODEL_ID,
|
||||||
|
ServiceFields.URL,
|
||||||
|
invalidUrl,
|
||||||
|
RateLimitSettings.FIELD_NAME,
|
||||||
|
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
|
||||||
|
)
|
||||||
|
),
|
||||||
|
ConfigurationParseContext.PERSISTENT
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
assertThat(
|
||||||
|
thrownException.getMessage(),
|
||||||
|
containsString(
|
||||||
|
Strings.format(
|
||||||
|
"Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]",
|
||||||
|
invalidUrl,
|
||||||
|
ServiceFields.URL
|
||||||
|
)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testToXContent_WritesAllValues() throws IOException {
|
||||||
|
var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||||
|
new HashMap<>(
|
||||||
|
Map.of(
|
||||||
|
ServiceFields.MODEL_ID,
|
||||||
|
MODEL_ID,
|
||||||
|
ServiceFields.URL,
|
||||||
|
CORRECT_URL,
|
||||||
|
RateLimitSettings.FIELD_NAME,
|
||||||
|
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
|
||||||
|
)
|
||||||
|
),
|
||||||
|
ConfigurationParseContext.PERSISTENT
|
||||||
|
);
|
||||||
|
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
serviceSettings.toXContent(builder, null);
|
||||||
|
String xContentResult = Strings.toString(builder);
|
||||||
|
var expected = XContentHelper.stripWhitespace("""
|
||||||
|
{
|
||||||
|
"model_id": "some model",
|
||||||
|
"url": "https://www.elastic.co",
|
||||||
|
"rate_limit": {
|
||||||
|
"requests_per_minute": 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""");
|
||||||
|
|
||||||
|
assertThat(xContentResult, is(expected));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException {
|
||||||
|
var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||||
|
new HashMap<>(Map.of(ServiceFields.URL, CORRECT_URL)),
|
||||||
|
ConfigurationParseContext.PERSISTENT
|
||||||
|
);
|
||||||
|
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
serviceSettings.toXContent(builder, null);
|
||||||
|
String xContentResult = Strings.toString(builder);
|
||||||
|
var expected = XContentHelper.stripWhitespace("""
|
||||||
|
{
|
||||||
|
"url": "https://www.elastic.co",
|
||||||
|
"rate_limit": {
|
||||||
|
"requests_per_minute": 3000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""");
|
||||||
|
assertThat(xContentResult, is(expected));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Writeable.Reader<HuggingFaceChatCompletionServiceSettings> instanceReader() {
|
||||||
|
return HuggingFaceChatCompletionServiceSettings::new;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected HuggingFaceChatCompletionServiceSettings createTestInstance() {
|
||||||
|
return createRandomWithNonNullUrl();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected HuggingFaceChatCompletionServiceSettings mutateInstance(HuggingFaceChatCompletionServiceSettings instance)
|
||||||
|
throws IOException {
|
||||||
|
return randomValueOtherThan(instance, HuggingFaceChatCompletionServiceSettingsTests::createRandomWithNonNullUrl);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected HuggingFaceChatCompletionServiceSettings mutateInstanceForVersion(
|
||||||
|
HuggingFaceChatCompletionServiceSettings instance,
|
||||||
|
TransportVersion version
|
||||||
|
) {
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static HuggingFaceChatCompletionServiceSettings createRandomWithNonNullUrl() {
|
||||||
|
return createRandom(randomAlphaOfLength(15));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static HuggingFaceChatCompletionServiceSettings createRandom(String url) {
|
||||||
|
var modelId = randomAlphaOfLength(8);
|
||||||
|
|
||||||
|
return new HuggingFaceChatCompletionServiceSettings(modelId, ServiceUtils.createUri(url), RateLimitSettingsTests.createRandom());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Map<String, Object> getServiceSettingsMap(String url, String model) {
|
||||||
|
var map = new HashMap<String, Object>();
|
||||||
|
|
||||||
|
map.put(ServiceFields.URL, url);
|
||||||
|
map.put(ServiceFields.MODEL_ID, model);
|
||||||
|
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
}
|
|
@ -12,16 +12,17 @@ import org.elasticsearch.test.ESTestCase;
|
||||||
import org.elasticsearch.xcontent.XContentBuilder;
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.xcontent.XContentFactory;
|
import org.elasticsearch.xcontent.XContentFactory;
|
||||||
import org.elasticsearch.xcontent.XContentType;
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequestEntity;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.hamcrest.CoreMatchers.is;
|
import static org.hamcrest.CoreMatchers.is;
|
||||||
|
|
||||||
public class HuggingFaceInferenceRequestEntityTests extends ESTestCase {
|
public class HuggingFaceEmbeddingsRequestEntityTests extends ESTestCase {
|
||||||
|
|
||||||
public void testXContent() throws IOException {
|
public void testXContent() throws IOException {
|
||||||
var entity = new HuggingFaceInferenceRequestEntity(List.of("abc"));
|
var entity = new HuggingFaceEmbeddingsRequestEntity(List.of("abc"));
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
entity.toXContent(builder, null);
|
entity.toXContent(builder, null);
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.xcontent.XContentType;
|
||||||
import org.elasticsearch.xpack.inference.common.Truncator;
|
import org.elasticsearch.xpack.inference.common.Truncator;
|
||||||
import org.elasticsearch.xpack.inference.common.TruncatorTests;
|
import org.elasticsearch.xpack.inference.common.TruncatorTests;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;
|
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
|
@ -25,7 +26,7 @@ import static org.hamcrest.Matchers.contains;
|
||||||
import static org.hamcrest.Matchers.instanceOf;
|
import static org.hamcrest.Matchers.instanceOf;
|
||||||
import static org.hamcrest.Matchers.is;
|
import static org.hamcrest.Matchers.is;
|
||||||
|
|
||||||
public class HuggingFaceInferenceRequestTests extends ESTestCase {
|
public class HuggingFaceEmbeddingsRequestTests extends ESTestCase {
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
public void testCreateRequest() throws URISyntaxException, IOException {
|
public void testCreateRequest() throws URISyntaxException, IOException {
|
||||||
var huggingFaceRequest = createRequest("www.google.com", "secret", "abc");
|
var huggingFaceRequest = createRequest("www.google.com", "secret", "abc");
|
||||||
|
@ -67,9 +68,9 @@ public class HuggingFaceInferenceRequestTests extends ESTestCase {
|
||||||
assertTrue(truncatedRequest.getTruncationInfo()[0]);
|
assertTrue(truncatedRequest.getTruncationInfo()[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static HuggingFaceInferenceRequest createRequest(String url, String apiKey, String input) throws URISyntaxException {
|
public static HuggingFaceEmbeddingsRequest createRequest(String url, String apiKey, String input) throws URISyntaxException {
|
||||||
|
|
||||||
return new HuggingFaceInferenceRequest(
|
return new HuggingFaceEmbeddingsRequest(
|
||||||
TruncatorTests.createTruncator(),
|
TruncatorTests.createTruncator(),
|
||||||
new Truncator.TruncationResult(List.of(input), new boolean[] { false }),
|
new Truncator.TruncationResult(List.of(input), new boolean[] { false }),
|
||||||
HuggingFaceEmbeddingsModelTests.createModel(url, apiKey)
|
HuggingFaceEmbeddingsModelTests.createModel(url, apiKey)
|
|
@ -0,0 +1,69 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.request;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xcontent.ToXContent;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xcontent.json.JsonXContent;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequestEntity;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests.createCompletionModel;
|
||||||
|
|
||||||
|
public class HuggingFaceUnifiedChatCompletionRequestEntityTests extends ESTestCase {
|
||||||
|
|
||||||
|
private static final String ROLE = "user";
|
||||||
|
|
||||||
|
public void testModelUserFieldsSerialization() throws IOException {
|
||||||
|
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
|
||||||
|
new UnifiedCompletionRequest.ContentString("Hello, world!"),
|
||||||
|
ROLE,
|
||||||
|
null,
|
||||||
|
null
|
||||||
|
);
|
||||||
|
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
|
||||||
|
messageList.add(message);
|
||||||
|
|
||||||
|
var unifiedRequest = UnifiedCompletionRequest.of(messageList);
|
||||||
|
|
||||||
|
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
|
||||||
|
HuggingFaceChatCompletionModel model = createCompletionModel("test-url", "api-key", "test-endpoint");
|
||||||
|
|
||||||
|
HuggingFaceUnifiedChatCompletionRequestEntity entity = new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, model);
|
||||||
|
|
||||||
|
XContentBuilder builder = JsonXContent.contentBuilder();
|
||||||
|
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||||
|
|
||||||
|
String jsonString = Strings.toString(builder);
|
||||||
|
String expectedJson = """
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"content": "Hello, world!",
|
||||||
|
"role": "user"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"model": "test-endpoint",
|
||||||
|
"n": 1,
|
||||||
|
"stream": true,
|
||||||
|
"stream_options": {
|
||||||
|
"include_usage": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
assertJsonEquals(jsonString, expectedJson);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,80 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.request;
|
||||||
|
|
||||||
|
import org.apache.http.client.methods.HttpPost;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||||
|
import static org.hamcrest.Matchers.aMapWithSize;
|
||||||
|
import static org.hamcrest.Matchers.instanceOf;
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
|
||||||
|
public class HuggingFaceUnifiedChatCompletionRequestTests extends ESTestCase {
|
||||||
|
|
||||||
|
public void testCreateRequest_WithStreaming() throws IOException {
|
||||||
|
var request = createRequest("url", "secret", randomAlphaOfLength(15), "model", true);
|
||||||
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||||
|
|
||||||
|
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||||
|
assertThat(requestMap.get("stream"), is(true));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testTruncate_DoesNotReduceInputTextSize() throws IOException {
|
||||||
|
String input = randomAlphaOfLength(5);
|
||||||
|
var request = createRequest("url", "secret", input, "model", true);
|
||||||
|
var truncatedRequest = request.truncate();
|
||||||
|
assertThat(request.getURI().toString(), is("url"));
|
||||||
|
|
||||||
|
var httpRequest = truncatedRequest.createHttpRequest();
|
||||||
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
|
||||||
|
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||||
|
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||||
|
assertThat(requestMap, aMapWithSize(5));
|
||||||
|
|
||||||
|
// We do not truncate for Hugging Face chat completions
|
||||||
|
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
|
||||||
|
assertThat(requestMap.get("model"), is("model"));
|
||||||
|
assertThat(requestMap.get("n"), is(1));
|
||||||
|
assertTrue((Boolean) requestMap.get("stream"));
|
||||||
|
assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true)));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testTruncationInfo_ReturnsNull() {
|
||||||
|
var request = createRequest("url", "secret", randomAlphaOfLength(5), "model", true);
|
||||||
|
assertNull(request.getTruncationInfo());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static HuggingFaceUnifiedChatCompletionRequest createRequest(String url, String apiKey, String input, @Nullable String model) {
|
||||||
|
return createRequest(url, apiKey, input, model, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static HuggingFaceUnifiedChatCompletionRequest createRequest(
|
||||||
|
@Nullable String url,
|
||||||
|
String apiKey,
|
||||||
|
String input,
|
||||||
|
@Nullable String model,
|
||||||
|
boolean stream
|
||||||
|
) {
|
||||||
|
var chatCompletionModel = HuggingFaceChatCompletionModelTests.createCompletionModel(url, apiKey, model);
|
||||||
|
return new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -19,6 +19,8 @@ import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatComple
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
|
||||||
public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase {
|
public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase {
|
||||||
|
|
||||||
public void testJsonLiteral() {
|
public void testJsonLiteral() {
|
||||||
|
@ -182,6 +184,73 @@ public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testJsonNullFunctionName() throws IOException {
|
||||||
|
String json = """
|
||||||
|
{
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"id": "",
|
||||||
|
"created": 1746800254,
|
||||||
|
"model": "/repository",
|
||||||
|
"system_fingerprint": "3.2.3-sha-a1f3ebe",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"id": "8f7c27be-6803-48e6-bba4-8cdcbcd2ff9a",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": null,
|
||||||
|
"arguments": " \\\""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"logprobs": null,
|
||||||
|
"finish_reason": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, json)) {
|
||||||
|
StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser
|
||||||
|
.parse(parser);
|
||||||
|
|
||||||
|
// Assertions to verify the parsed object
|
||||||
|
assertThat(chunk.id(), is(""));
|
||||||
|
assertThat(chunk.model(), is("/repository"));
|
||||||
|
assertThat(chunk.object(), is("chat.completion.chunk"));
|
||||||
|
assertNull(chunk.usage());
|
||||||
|
|
||||||
|
List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice> choices = chunk.choices();
|
||||||
|
assertThat(choices.size(), is(1));
|
||||||
|
|
||||||
|
// First choice assertions
|
||||||
|
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice firstChoice = choices.get(0);
|
||||||
|
assertNull(firstChoice.delta().content());
|
||||||
|
assertNull(firstChoice.delta().refusal());
|
||||||
|
assertThat(firstChoice.delta().role(), is("assistant"));
|
||||||
|
assertNull(firstChoice.finishReason());
|
||||||
|
assertThat(firstChoice.index(), is(0));
|
||||||
|
|
||||||
|
List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall> toolCalls = firstChoice.delta()
|
||||||
|
.toolCalls();
|
||||||
|
assertThat(toolCalls.size(), is(1));
|
||||||
|
|
||||||
|
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0);
|
||||||
|
assertThat(toolCall.index(), is(0));
|
||||||
|
assertThat(toolCall.id(), is("8f7c27be-6803-48e6-bba4-8cdcbcd2ff9a"));
|
||||||
|
assertThat(toolCall.type(), is("function"));
|
||||||
|
assertNull(toolCall.function().name());
|
||||||
|
assertThat(toolCall.function().arguments(), is(" \""));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public void testOpenAiUnifiedStreamingProcessorParsing() throws IOException {
|
public void testOpenAiUnifiedStreamingProcessorParsing() throws IOException {
|
||||||
// Generate random values for the JSON fields
|
// Generate random values for the JSON fields
|
||||||
int toolCallIndex = randomIntBetween(0, 10);
|
int toolCallIndex = randomIntBetween(0, 10);
|
||||||
|
|
Loading…
Reference in New Issue