Add Hugging Face Chat Completion support to Inference Plugin (#127254)

* Add Hugging Face Chat Completion support to Inference Plugin

* Add support for streaming chat completion task for HuggingFace

* [CI] Auto commit changes from spotless

* Add support for non-streaming completion task for HuggingFace

* Remove RequestManager for HF Chat Completion Task

* Refactored Hugging Face Completion Service Settings, removed Request Manager, added Unit Tests

* Refactored Hugging Face Action Creator, added Unit Tests

* Add Hugging Face Server Test

* [CI] Auto commit changes from spotless

* Removed parameters from media type for Chat Completion Request and unit tests

* Removed OpenAI default URL in HuggingFaceService's configuration, fixed formatting in InferenceGetServicesIT

* Refactor error message handling in HuggingFaceActionCreator and HuggingFaceService

* Update minimal supported version and add Hugging Face transport version constants

* Made modelId field optional in HuggingFaceChatCompletionModel, updated unit tests

* Removed max input tokens field from HuggingFaceChatCompletionServiceSettings, fixed unit tests

* Removed if statement checking TransportVersion for HuggingFaceChatCompletionServiceSettings constructor with StreamInput param

* Removed getFirst() method calls for backport compatibility

* Made HuggingFaceChatCompletionServiceSettingsTests extend AbstractBWCWireSerializationTestCase for future serialization testing

* Refactored tests to use stripWhitespace method for readability

* Refactored javadoc for HuggingFaceService

* Renamed HF chat completion TransportVersion constant names

* Added random string generation in unit test

* Refactored javadocs for HuggingFace requests

* Refactored tests to reduce duplication

* Added changelog file

* Add HuggingFaceChatCompletionResponseHandler and associated tests

* Refactor error handling in HuggingFaceServiceTests to standardize error response codes and types

* Refactor HuggingFace error handling to improve response structure and add streaming support

* Allowing null function name for hugging face models

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Co-authored-by: Jonathan Buttner <jonathan.buttner@elastic.co>
This commit is contained in:
Jan-Kazlouski-elastic 2025-05-19 19:37:19 +03:00 committed by GitHub
parent 54f26680ea
commit d1ad917855
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 2334 additions and 49 deletions

View File

@ -0,0 +1,5 @@
pr: 127254
summary: "[ML] Add HuggingFace Chat Completion support to the Inference Plugin"
area: Machine Learning
type: enhancement
issues: []

View File

@ -175,6 +175,7 @@ public class TransportVersions {
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29);
public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION_8_19 = def(8_841_0_30);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_31);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@ -255,6 +256,7 @@ public class TransportVersions {
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);
public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00);
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_078_0_00);
/*
* STOP! READ THIS FIRST! No, really,

View File

@ -123,7 +123,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(10));
assertThat(services.size(), equalTo(11));
var providers = providers(services);
@ -140,7 +140,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
"deepseek",
"googleaistudio",
"openai",
"streaming_completion_test_service"
"streaming_completion_test_service",
"hugging_face"
).toArray()
)
);
@ -148,11 +149,14 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(4));
assertThat(services.size(), equalTo(5));
var providers = providers(services);
assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray()));
assertThat(
providers,
containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face").toArray())
);
}
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {

View File

@ -78,6 +78,7 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.Goog
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
@ -357,6 +358,13 @@ public class InferenceNamedWriteablesProvider {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
HuggingFaceChatCompletionServiceSettings.NAME,
HuggingFaceChatCompletionServiceSettings::new
)
);
}
private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {

View File

@ -0,0 +1,171 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceErrorResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
import java.util.Locale;
import java.util.Optional;
import static org.elasticsearch.core.Strings.format;
/**
* Handles streaming chat completion responses and error parsing for Hugging Face inference endpoints.
* Adapts the OpenAI handler to support Hugging Face's simpler error schema with fields like "message" and "http_status_code".
*/
public class HuggingFaceChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
private static final String HUGGING_FACE_ERROR = "hugging_face_error";
public HuggingFaceChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, HuggingFaceErrorResponseEntity::fromResponse);
}
@Override
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
assert request.isStreaming() : "Only streaming requests support this format";
var responseStatusCode = result.response().getStatusLine().getStatusCode();
if (request.isStreaming()) {
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
var restStatus = toRestStatus(responseStatusCode);
return errorResponse instanceof HuggingFaceErrorResponseEntity
? new UnifiedChatCompletionException(
restStatus,
errorMessage,
HUGGING_FACE_ERROR,
restStatus.name().toLowerCase(Locale.ROOT)
)
: new UnifiedChatCompletionException(
restStatus,
errorMessage,
createErrorType(errorResponse),
restStatus.name().toLowerCase(Locale.ROOT)
);
} else {
return super.buildError(message, request, result, errorResponse);
}
}
@Override
protected Exception buildMidStreamError(Request request, String message, Exception e) {
var errorResponse = StreamingHuggingFaceErrorResponseEntity.fromString(message);
if (errorResponse instanceof StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) {
return new UnifiedChatCompletionException(
RestStatus.INTERNAL_SERVER_ERROR,
format(
"%s for request from inference entity id [%s]. Error message: [%s]",
SERVER_ERROR_OBJECT,
request.getInferenceEntityId(),
errorResponse.getErrorMessage()
),
HUGGING_FACE_ERROR,
extractErrorCode(streamingHuggingFaceErrorResponseEntity)
);
} else if (e != null) {
return UnifiedChatCompletionException.fromThrowable(e);
} else {
return new UnifiedChatCompletionException(
RestStatus.INTERNAL_SERVER_ERROR,
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
createErrorType(errorResponse),
"stream_error"
);
}
}
private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) {
return streamingHuggingFaceErrorResponseEntity.httpStatusCode() != null
? String.valueOf(streamingHuggingFaceErrorResponseEntity.httpStatusCode())
: null;
}
/**
* Represents a structured error response specifically for streaming operations
* using HuggingFace APIs. This is separate from non-streaming error responses,
* which are handled by {@link HuggingFaceErrorResponseEntity}.
* An example error response for failed field validation for streaming operation would look like
* <code>
* {
* "error": "Input validation error: cannot compile regex from schema",
* "http_status_code": 422
* }
* </code>
*/
private static class StreamingHuggingFaceErrorResponseEntity extends ErrorResponse {
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
HUGGING_FACE_ERROR,
true,
args -> Optional.ofNullable((StreamingHuggingFaceErrorResponseEntity) args[0])
);
private static final ConstructingObjectParser<StreamingHuggingFaceErrorResponseEntity, Void> ERROR_BODY_PARSER =
new ConstructingObjectParser<>(
HUGGING_FACE_ERROR,
true,
args -> new StreamingHuggingFaceErrorResponseEntity(args[0] != null ? (String) args[0] : "unknown", (Integer) args[1])
);
static {
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message"));
ERROR_BODY_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField("http_status_code"));
ERROR_PARSER.declareObjectOrNull(
ConstructingObjectParser.optionalConstructorArg(),
ERROR_BODY_PARSER,
null,
new ParseField("error")
);
}
/**
* Parses a streaming HuggingFace error response from a JSON string.
*
* @param response the raw JSON string representing an error
* @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails
*/
private static ErrorResponse fromString(String response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response)
) {
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
} catch (Exception e) {
// swallow the error
}
return ErrorResponse.UNDEFINED_ERROR;
}
@Nullable
private final Integer httpStatusCode;
StreamingHuggingFaceErrorResponseEntity(String errorMessage, @Nullable Integer httpStatusCode) {
super(errorMessage);
this.httpStatusCode = httpStatusCode;
}
@Nullable
public Integer httpStatusCode() {
return httpStatusCode;
}
}
}

View File

@ -9,17 +9,18 @@ package org.elasticsearch.xpack.inference.services.huggingface;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.util.Objects;
public abstract class HuggingFaceModel extends Model {
public abstract class HuggingFaceModel extends RateLimitGroupingModel {
private final HuggingFaceRateLimitServiceSettings rateLimitServiceSettings;
private final SecureString apiKey;
@ -38,6 +39,16 @@ public abstract class HuggingFaceModel extends Model {
return rateLimitServiceSettings;
}
@Override
public int rateLimitGroupingHash() {
return Objects.hash(rateLimitServiceSettings.uri(), apiKey);
}
@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitServiceSettings.rateLimitSettings();
}
public SecureString apiKey() {
return apiKey;
}

View File

@ -19,7 +19,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.BaseRequestManager
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.services.huggingface.request.HuggingFaceInferenceRequest;
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest;
import java.util.List;
import java.util.Objects;
@ -64,7 +64,7 @@ public class HuggingFaceRequestManager extends BaseRequestManager {
) {
List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
var truncatedInput = truncate(docsInput, model.getTokenLimit());
var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);
var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model);
execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
}

View File

@ -26,15 +26,21 @@ import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -42,16 +48,29 @@ import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
/**
* This class is responsible for managing the Hugging Face inference service.
* It manages model creation, as well as chunked, non-chunked, and unified completion inference.
*/
public class HuggingFaceService extends HuggingFaceBaseService {
public static final String NAME = "hugging_face";
private static final String SERVICE_NAME = "Hugging Face";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING);
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.SPARSE_EMBEDDING,
TaskType.COMPLETION,
TaskType.CHAT_COMPLETION
);
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new HuggingFaceChatCompletionResponseHandler(
"hugging face chat completion",
OpenAiChatCompletionResponseEntity::fromResponse
);
public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
@ -78,6 +97,14 @@ public class HuggingFaceService extends HuggingFaceBaseService {
context
);
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
secretSettings,
context
);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
@ -139,7 +166,29 @@ public class HuggingFaceService extends HuggingFaceBaseService {
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
throwUnsupportedUnifiedCompletionOperation(NAME);
if (model instanceof HuggingFaceChatCompletionModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}
HuggingFaceChatCompletionModel huggingFaceChatCompletionModel = (HuggingFaceChatCompletionModel) model;
var overriddenModel = HuggingFaceChatCompletionModel.of(huggingFaceChatCompletionModel, inputs.getRequest());
var manager = new GenericRequestManager<>(
getServiceComponents().threadPool(),
overriddenModel,
UNIFIED_CHAT_COMPLETION_HANDLER,
unifiedChatInput -> new HuggingFaceUnifiedChatCompletionRequest(unifiedChatInput, overriddenModel),
UnifiedChatInput.class
);
var errorMessage = HuggingFaceActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId());
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
action.execute(inputs, timeout, listener);
}
@Override
public Set<TaskType> supportedStreamingTasks() {
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
}
@Override
@ -149,7 +198,7 @@ public class HuggingFaceService extends HuggingFaceBaseService {
@Override
public EnumSet<TaskType> supportedTaskTypes() {
return supportedTaskTypes;
return SUPPORTED_TASK_TYPES;
}
@Override
@ -167,14 +216,15 @@ public class HuggingFaceService extends HuggingFaceBaseService {
return configuration.getOrCompute();
}
private Configuration() {}
private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
() -> {
var configurationMap = new HashMap<String, SettingsConfiguration>();
configurationMap.put(
URL,
new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue("https://api.openai.com/v1/embeddings")
.setDescription("The URL endpoint to use for the requests.")
new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.")
.setLabel("URL")
.setRequired(true)
.setSensitive(false)
@ -183,12 +233,12 @@ public class HuggingFaceService extends HuggingFaceBaseService {
.build()
);
configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes));
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes));
configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));
return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(SERVICE_NAME)
.setTaskTypes(supportedTaskTypes)
.setTaskTypes(SUPPORTED_TASK_TYPES)
.setConfigurations(configurationMap)
.build();
}

View File

@ -7,16 +7,26 @@
package org.elasticsearch.xpack.inference.services.huggingface.action;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRequestManager;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity;
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import java.util.Objects;
@ -26,6 +36,13 @@ import static org.elasticsearch.core.Strings.format;
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the hugging face model type.
*/
public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions";
static final String USER_ROLE = "user";
static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
"hugging face completion",
OpenAiChatCompletionResponseEntity::fromResponse
);
private final Sender sender;
private final ServiceComponents serviceComponents;
@ -46,11 +63,7 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var errorMessage = format(
"Failed to send Hugging Face %s request from inference entity id [%s]",
"text embeddings",
model.getInferenceEntityId()
);
var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId());
return new SenderExecutableAction(sender, requestCreator, errorMessage);
}
@ -63,11 +76,25 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
serviceComponents.truncator(),
serviceComponents.threadPool()
);
var errorMessage = format(
"Failed to send Hugging Face %s request from inference entity id [%s]",
"ELSER",
model.getInferenceEntityId()
);
var errorMessage = buildErrorMessage(TaskType.SPARSE_EMBEDDING, model.getInferenceEntityId());
return new SenderExecutableAction(sender, requestCreator, errorMessage);
}
@Override
public ExecutableAction create(HuggingFaceChatCompletionModel model) {
var manager = new GenericRequestManager<>(
serviceComponents.threadPool(),
model,
COMPLETION_HANDLER,
inputs -> new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
ChatCompletionInput.class
);
var errorMessage = buildErrorMessage(TaskType.COMPLETION, model.getInferenceEntityId());
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX);
}
public static String buildErrorMessage(TaskType requestType, String inferenceId) {
return format("Failed to send Hugging Face %s request from inference entity id [%s]", requestType.toString(), inferenceId);
}
}

View File

@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.huggingface.action;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
@ -15,4 +16,6 @@ public interface HuggingFaceActionVisitor {
ExecutableAction create(HuggingFaceEmbeddingsModel model);
ExecutableAction create(HuggingFaceElserModel model);
ExecutableAction create(HuggingFaceChatCompletionModel model);
}

View File

@ -0,0 +1,102 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.completion;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import java.util.Map;
public class HuggingFaceChatCompletionModel extends HuggingFaceModel {
/**
* Creates a new {@link HuggingFaceChatCompletionModel} by copying properties from an existing model,
* replacing the {@code modelId} in the service settings with the one from the given {@link UnifiedCompletionRequest},
* if present. If the request does not specify a model ID, the original value is retained.
*
* @param model the original model to copy from
* @param request the request potentially containing an overridden model ID
* @return a new {@link HuggingFaceChatCompletionModel} with updated service settings
*/
public static HuggingFaceChatCompletionModel of(HuggingFaceChatCompletionModel model, UnifiedCompletionRequest request) {
var originalModelServiceSettings = model.getServiceSettings();
var overriddenServiceSettings = new HuggingFaceChatCompletionServiceSettings(
request.model() != null ? request.model() : originalModelServiceSettings.modelId(),
originalModelServiceSettings.uri(),
originalModelServiceSettings.rateLimitSettings()
);
return new HuggingFaceChatCompletionModel(
model.getInferenceEntityId(),
model.getTaskType(),
model.getConfigurations().getService(),
overriddenServiceSettings,
model.getSecretSettings()
);
}
public HuggingFaceChatCompletionModel(
String inferenceEntityId,
TaskType taskType,
String service,
Map<String, Object> serviceSettings,
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
this(
inferenceEntityId,
taskType,
service,
HuggingFaceChatCompletionServiceSettings.fromMap(serviceSettings, context),
DefaultSecretSettings.fromMap(secrets)
);
}
HuggingFaceChatCompletionModel(
String inferenceEntityId,
TaskType taskType,
String service,
HuggingFaceChatCompletionServiceSettings serviceSettings,
@Nullable DefaultSecretSettings secretSettings
) {
super(
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings),
new ModelSecrets(secretSettings),
serviceSettings,
secretSettings
);
}
@Override
public HuggingFaceChatCompletionServiceSettings getServiceSettings() {
return (HuggingFaceChatCompletionServiceSettings) super.getServiceSettings();
}
@Override
public DefaultSecretSettings getSecretSettings() {
return (DefaultSecretSettings) super.getSecretSettings();
}
@Override
public ExecutableAction accept(HuggingFaceActionVisitor creator) {
return creator.create(this);
}
@Override
public Integer getTokenLimit() {
throw new UnsupportedOperationException("Token Limit for chat completion is sent in request and not retrieved from the model");
}
}

View File

@ -0,0 +1,172 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.completion;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.io.IOException;
import java.net.URI;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri;
/**
* Settings for the Hugging Face chat completion service.
* <p>
* This class contains the settings required to configure a Hugging Face chat completion service, including the model ID, URL, maximum input
* tokens, and rate limit settings.
* </p>
*/
public class HuggingFaceChatCompletionServiceSettings extends FilteredXContentObject
implements
ServiceSettings,
HuggingFaceRateLimitServiceSettings {
public static final String NAME = "hugging_face_completion_service_settings";
// At the time of writing HuggingFace hasn't posted the default rate limit for inference endpoints so the value his is only a guess
// 3000 requests per minute
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
/**
* Creates a new instance of {@link HuggingFaceChatCompletionServiceSettings} from a map of settings.
* @param map the map of settings
* @param context the context for parsing the settings
* @return a new instance of {@link HuggingFaceChatCompletionServiceSettings}
*/
public static HuggingFaceChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
var uri = extractUri(map, URL, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
HuggingFaceService.NAME,
context
);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
return new HuggingFaceChatCompletionServiceSettings(modelId, uri, rateLimitSettings);
}
private final String modelId;
private final URI uri;
private final RateLimitSettings rateLimitSettings;
public HuggingFaceChatCompletionServiceSettings(@Nullable String modelId, String url, @Nullable RateLimitSettings rateLimitSettings) {
this(modelId, createUri(url), rateLimitSettings);
}
public HuggingFaceChatCompletionServiceSettings(@Nullable String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) {
this.modelId = modelId;
this.uri = uri;
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
}
/**
* Creates a new instance of {@link HuggingFaceChatCompletionServiceSettings} from a stream input.
* @param in the stream input
* @throws IOException if an I/O error occurs
*/
public HuggingFaceChatCompletionServiceSettings(StreamInput in) throws IOException {
this.modelId = in.readOptionalString();
this.uri = createUri(in.readString());
this.rateLimitSettings = new RateLimitSettings(in);
}
@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
}
@Override
public URI uri() {
return uri;
}
@Override
public String modelId() {
return modelId;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
toXContentFragmentOfExposedFields(builder, params);
builder.endObject();
return builder;
}
@Override
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
if (modelId != null) {
builder.field(MODEL_ID, modelId);
}
builder.field(URL, uri.toString());
rateLimitSettings.toXContent(builder, params);
return builder;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(modelId);
out.writeString(uri.toString());
rateLimitSettings.writeTo(out);
}
@Override
public boolean equals(Object object) {
if (this == object) return true;
if (object == null || getClass() != object.getClass()) return false;
HuggingFaceChatCompletionServiceSettings that = (HuggingFaceChatCompletionServiceSettings) object;
return Objects.equals(modelId, that.modelId)
&& Objects.equals(uri, that.uri)
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
}
@Override
public int hashCode() {
return Objects.hash(modelId, uri, rateLimitSettings);
}
}

View File

@ -0,0 +1,88 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.request.completion;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceAccount;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
/**
* This class is responsible for creating Hugging Face chat completions HTTP requests.
* It handles the preparation of the HTTP request with the necessary headers and body.
*/
public class HuggingFaceUnifiedChatCompletionRequest implements Request {
private final HuggingFaceAccount account;
private final HuggingFaceChatCompletionModel model;
private final UnifiedChatInput unifiedChatInput;
public HuggingFaceUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatInput, HuggingFaceChatCompletionModel model) {
this.account = HuggingFaceAccount.of(model);
this.model = Objects.requireNonNull(model);
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
}
/**
* Creates an HTTP request to the Hugging Face API for chat completions.
* The request includes the necessary headers and the input data as a JSON entity.
*
* @return an HttpRequest object containing the HTTP POST request
*/
public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(getURI());
ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, model)).getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
return new HttpRequest(httpPost, getInferenceEntityId());
}
public URI getURI() {
return account.uri();
}
@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
}
@Override
public Request truncate() {
// Truncation is not applicable for chat completion requests
return this;
}
@Override
public boolean[] getTruncationInfo() {
// Truncation is not applicable for chat completion requests
return null;
}
@Override
public boolean isStreaming() {
return unifiedChatInput.stream();
}
}

View File

@ -0,0 +1,51 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.request.completion;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
import java.io.IOException;
import java.util.Objects;
public class HuggingFaceUnifiedChatCompletionRequestEntity implements ToXContentObject {
private static final String MODEL_FIELD = "model";
private static final String MAX_TOKENS_FIELD = "max_tokens";
private final UnifiedChatInput unifiedChatInput;
private final HuggingFaceChatCompletionModel model;
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
public HuggingFaceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, HuggingFaceChatCompletionModel model) {
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
this.model = Objects.requireNonNull(model);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
unifiedRequestEntity.toXContent(builder, params);
if (model.getServiceSettings().modelId() != null) {
builder.field(MODEL_FIELD, model.getServiceSettings().modelId());
}
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
builder.field(MAX_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
}
builder.endObject();
return builder;
}
}

View File

@ -5,7 +5,7 @@
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.request;
package org.elasticsearch.xpack.inference.services.huggingface.request.embeddings;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
@ -24,25 +24,35 @@ import java.util.Objects;
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
public class HuggingFaceInferenceRequest implements Request {
/**
* This class is responsible for creating Hugging Face embeddings HTTP requests.
* It handles the truncation of input data and prepares the HTTP request with the necessary headers and body.
*/
public class HuggingFaceEmbeddingsRequest implements Request {
private final Truncator truncator;
private final HuggingFaceAccount account;
private final Truncator.TruncationResult truncationResult;
private final HuggingFaceModel model;
public HuggingFaceInferenceRequest(Truncator truncator, Truncator.TruncationResult input, HuggingFaceModel model) {
public HuggingFaceEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, HuggingFaceModel model) {
this.truncator = Objects.requireNonNull(truncator);
this.account = HuggingFaceAccount.of(model);
this.truncationResult = Objects.requireNonNull(input);
this.model = Objects.requireNonNull(model);
}
/**
* Creates an HTTP request to the Hugging Face API for embeddings.
* The request includes the necessary headers and the input data as a JSON entity.
*
* @return an HttpRequest object containing the HTTP POST request
*/
public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(account.uri());
ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new HuggingFaceInferenceRequestEntity(truncationResult.input())).getBytes(StandardCharsets.UTF_8)
Strings.toString(new HuggingFaceEmbeddingsRequestEntity(truncationResult.input())).getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
@ -64,7 +74,7 @@ public class HuggingFaceInferenceRequest implements Request {
public Request truncate() {
var truncateResult = truncator.truncate(truncationResult.input());
return new HuggingFaceInferenceRequest(truncator, truncateResult, model);
return new HuggingFaceEmbeddingsRequest(truncator, truncateResult, model);
}
@Override

View File

@ -5,7 +5,7 @@
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.request;
package org.elasticsearch.xpack.inference.services.huggingface.request.embeddings;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
@ -14,11 +14,15 @@ import java.io.IOException;
import java.util.List;
import java.util.Objects;
public record HuggingFaceInferenceRequestEntity(List<String> inputs) implements ToXContentObject {
/**
* This class represents the request entity for Hugging Face embeddings.
* It contains a list of input strings that will be used to generate embeddings.
*/
public record HuggingFaceEmbeddingsRequestEntity(List<String> inputs) implements ToXContentObject {
private static final String INPUTS_FIELD = "inputs";
public HuggingFaceInferenceRequestEntity {
public HuggingFaceEmbeddingsRequestEntity {
Objects.requireNonNull(inputs);
}

View File

@ -21,6 +21,9 @@ public class HuggingFaceErrorResponseEntity extends ErrorResponse {
}
/**
* Represents a structured error response specifically for non-streaming operations
* using HuggingFace APIs. This is separate from streaming error responses,
* which are handled by private nested HuggingFaceChatCompletionResponseHandler.StreamingHuggingFaceErrorResponseEntity.
* An example error response for invalid auth would look like
* <code>
* {

View File

@ -29,6 +29,7 @@ import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Flow;
import java.util.function.Function;
import static org.elasticsearch.core.Strings.format;
@ -37,6 +38,14 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
super(requestType, parseFunction, OpenAiErrorResponse::fromResponse);
}
public OpenAiUnifiedChatCompletionResponseHandler(
String requestType,
ResponseParser parseFunction,
Function<HttpResult, ErrorResponse> errorParseFunction
) {
super(requestType, parseFunction, errorParseFunction);
}
@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
@ -59,7 +68,7 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
: new UnifiedChatCompletionException(
restStatus,
errorMessage,
errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown",
createErrorType(errorResponse),
restStatus.name().toLowerCase(Locale.ROOT)
);
} else {
@ -67,7 +76,11 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
}
}
private static Exception buildMidStreamError(Request request, String message, Exception e) {
protected static String createErrorType(ErrorResponse errorResponse) {
return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown";
}
protected Exception buildMidStreamError(Request request, String message, Exception e) {
var errorResponse = OpenAiErrorResponse.fromString(message);
if (errorResponse instanceof OpenAiErrorResponse oer) {
return new UnifiedChatCompletionException(
@ -88,7 +101,7 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
return new UnifiedChatCompletionException(
RestStatus.INTERNAL_SERVER_ERROR,
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown",
createErrorType(errorResponse),
"stream_error"
);
}

View File

@ -250,7 +250,7 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
static {
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ARGUMENTS_FIELD));
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(NAME_FIELD));
PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(NAME_FIELD));
}
public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function parse(

View File

@ -0,0 +1,131 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface;
import org.apache.http.HttpResponse;
import org.apache.http.StatusLine;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class HuggingFaceChatCompletionResponseHandlerTests extends ESTestCase {
private final HuggingFaceChatCompletionResponseHandler responseHandler = new HuggingFaceChatCompletionResponseHandler(
"chat completions",
(a, b) -> mock()
);
public void testFailValidationWithAllFields() throws IOException {
var responseJson = """
{
"error": "a message",
"type": "validation"
}
""";
var errorJson = invalidResponseJson(responseJson);
assertThat(errorJson, is("""
{"error":{"code":"bad_request","message":"Received a server error status code for request from \
inference entity id [id] status [500]. \
Error message: [a message]",\
"type":"hugging_face_error"}}"""));
}
public void testFailValidationWithoutOptionalFields() throws IOException {
var responseJson = """
{
"error": "a message"
}
""";
var errorJson = invalidResponseJson(responseJson);
assertThat(errorJson, is("""
{"error":{"code":"bad_request","message":"Received a server error status code for request from \
inference entity id [id] status [500]. \
Error message: [a message]","type":"hugging_face_error"}}"""));
}
public void testFailValidationWithInvalidJson() throws IOException {
var responseJson = """
what? this isn't a json
""";
var errorJson = invalidResponseJson(responseJson);
assertThat(errorJson, is("""
{"error":{"code":"bad_request","message":"Received a server error status code for request from inference entity id [id] status\
[500]","type":"ErrorResponse"}}"""));
}
private String invalidResponseJson(String responseJson) throws IOException {
var exception = invalidResponse(responseJson);
assertThat(exception, isA(RetryException.class));
assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class));
return toJson((UnifiedChatCompletionException) unwrapCause(exception));
}
private Exception invalidResponse(String responseJson) {
return expectThrows(
RetryException.class,
() -> responseHandler.validateResponse(
mock(),
mock(),
mockRequest(),
new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8)),
true
)
);
}
private static Request mockRequest() {
var request = mock(Request.class);
when(request.getInferenceEntityId()).thenReturn("id");
when(request.isStreaming()).thenReturn(true);
return request;
}
private static HttpResponse mock500Response() {
int statusCode = 500;
var statusLine = mock(StatusLine.class);
when(statusLine.getStatusCode()).thenReturn(statusCode);
var response = mock(HttpResponse.class);
when(response.getStatusLine()).thenReturn(statusLine);
return response;
}
private String toJson(UnifiedChatCompletionException e) throws IOException {
try (var builder = XContentFactory.jsonBuilder()) {
e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
try {
xContent.toXContent(builder, EMPTY_PARAMS);
} catch (IOException ex) {
throw new RuntimeException(ex);
}
});
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
}
}
}

View File

@ -12,6 +12,7 @@ package org.elasticsearch.xpack.inference.services.huggingface;
import org.apache.http.HttpHeaders;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.bytes.BytesArray;
@ -29,20 +30,29 @@ import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettingsTests;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
@ -53,14 +63,19 @@ import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
@ -74,7 +89,12 @@ import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
public class HuggingFaceServiceTests extends ESTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
@ -175,6 +195,438 @@ public class HuggingFaceServiceTests extends ESTestCase {
}
}
public void testParseRequestConfig_CreatesHuggingFaceChatCompletionsModel() throws IOException {
var url = "url";
var model = "model";
var secret = "secret";
try (var service = createHuggingFaceService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(m -> {
assertThat(m, instanceOf(HuggingFaceChatCompletionModel.class));
var completionsModel = (HuggingFaceChatCompletionModel) m;
assertThat(completionsModel.getServiceSettings().uri().toString(), is(url));
assertThat(completionsModel.getServiceSettings().modelId(), is(model));
assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret));
}, exception -> fail("Unexpected exception: " + exception));
service.parseRequestConfig(
"id",
TaskType.COMPLETION,
getRequestConfigMap(
HuggingFaceChatCompletionServiceSettingsTests.getServiceSettingsMap(url, model),
getSecretSettingsMap(secret)
),
modelVerificationListener
);
}
}
public void testParseRequestConfig_CreatesHuggingFaceChatCompletionsModel_WithoutModelId() throws IOException {
var url = "url";
var secret = "secret";
try (var service = createHuggingFaceService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(m -> {
assertThat(m, instanceOf(HuggingFaceChatCompletionModel.class));
var completionsModel = (HuggingFaceChatCompletionModel) m;
assertThat(completionsModel.getServiceSettings().uri().toString(), is(url));
assertNull(completionsModel.getServiceSettings().modelId());
assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret));
}, exception -> fail("Unexpected exception: " + exception));
service.parseRequestConfig(
"id",
TaskType.COMPLETION,
getRequestConfigMap(getServiceSettingsMap(url), getSecretSettingsMap(secret)),
modelVerificationListener
);
}
}
public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException {
var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender()).thenReturn(sender);
var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION);
try (var service = new HuggingFaceService(factory, createWithEmptySettings(threadPool))) {
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
thrownException.getMessage(),
is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
);
verify(factory, times(1)).createSender();
verify(sender, times(1)).start();
}
verify(sender, times(1)).close();
verifyNoMoreInteractions(factory);
verifyNoMoreInteractions(sender);
}
public void testUnifiedCompletionInfer() throws Exception {
// The escapes are because the streaming response must be on a single line
String responseJson = """
data: {\
"id":"12345",\
"object":"chat.completion.chunk",\
"created":123456789,\
"model":"gpt-4o-mini",\
"system_fingerprint": "123456789",\
"choices":[\
{\
"index":0,\
"delta":{\
"content":"hello, world"\
},\
"logprobs":null,\
"finish_reason":"stop"\
}\
],\
"usage":{\
"prompt_tokens": 16,\
"completion_tokens": 28,\
"total_tokens": 44,\
"prompt_tokens_details": {\
"cached_tokens": 0,\
"audio_tokens": 0\
},\
"completion_tokens_details": {\
"reasoning_tokens": 0,\
"audio_tokens": 0,\
"accepted_prediction_tokens": 0,\
"rejected_prediction_tokens": 0\
}\
}\
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.unifiedCompletionInfer(
model,
UnifiedCompletionRequest.of(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
),
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);
var result = listener.actionGet(TIMEOUT);
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
{"id":"12345","choices":[{"delta":{"content":"hello, world"},"finish_reason":"stop","index":0}],""" + """
"model":"gpt-4o-mini","object":"chat.completion.chunk",""" + """
"usage":{"completion_tokens":28,"prompt_tokens":16,"total_tokens":44}}""");
}
}
public void testUnifiedCompletionNonStreamingError() throws Exception {
String responseJson = """
{
"error": "Model not found."
}
""";
webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
var latch = new CountDownLatch(1);
service.unifiedCompletionInfer(
model,
UnifiedCompletionRequest.of(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
),
InferenceAction.Request.DEFAULT_TIMEOUT,
ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> {
try (var builder = XContentFactory.jsonBuilder()) {
var t = unwrapCause(e);
assertThat(t, isA(UnifiedChatCompletionException.class));
((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
try {
xContent.toXContent(builder, EMPTY_PARAMS);
} catch (IOException ex) {
throw new RuntimeException(ex);
}
});
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
assertThat(json, is("""
{\
"error":{\
"code":"not_found",\
"message":"Received an unsuccessful status code for request from inference entity id [id] status \
[404]. Error message: [Model not found.]",\
"type":"hugging_face_error"\
}}"""));
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}), latch::countDown)
);
assertTrue(latch.await(30, TimeUnit.SECONDS));
}
}
public void testMidStreamUnifiedCompletionError() throws Exception {
String responseJson = """
event: error
data: {"error":{"message":"Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \
{\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported by Outlines.\\n\
If it should be supported, please open an issue.","http_status_code":422}}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
testStreamError("""
{\
"error":{\
"code":"422",\
"message":"Received an error response for request from inference entity id [id]. Error message: [Input validation error: \
cannot compile regex from schema: Unsupported JSON Schema structure {\\"id\\":\\"123\\"} \\nMake sure it is valid to the \
JSON Schema specification and check if it's supported by Outlines.\\nIf it should be supported, please open an issue.]",\
"type":"hugging_face_error"\
}}""");
}
public void testMidStreamUnifiedCompletionErrorNoMessage() throws Exception {
String responseJson = """
event: error
data: {"error":{"http_status_code":422}}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
testStreamError("""
{\
"error":{\
"code":"422",\
"message":"Received an error response for request from inference entity id [id]. Error message: \
[unknown]",\
"type":"hugging_face_error"\
}}""");
}
public void testMidStreamUnifiedCompletionErrorNoHttpStatusCode() throws Exception {
String responseJson = """
event: error
data: {"error":{"message":"Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \
{\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported by \
Outlines.\\nIf it should be supported, please open an issue."}}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
testStreamError("""
{\
"error":{\
"message":"Received an error response for request from inference entity id [id]. Error message: \
[Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \
{\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported\
by Outlines.\\nIf it should be supported, please open an issue.]",\
"type":"hugging_face_error"\
}}""");
}
public void testMidStreamUnifiedCompletionErrorNoHttpStatusCodeNoMessage() throws Exception {
String responseJson = """
event: error
data: {"error":{}}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
testStreamError("""
{\
"error":{\
"message":"Received an error response for request from inference entity id [id]. Error message: \
[unknown]",\
"type":"hugging_face_error"\
}}""");
}
public void testUnifiedCompletionMalformedError() throws Exception {
String responseJson = """
data: { invalid json }
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
testStreamError("""
{\
"error":{\
"code":"bad_request",\
"message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\
at [Source: (String)\\"{ invalid json }\\"; line: 1, column: 3]",\
"type":"x_content_parse_exception"\
}}""");
}
private void testStreamError(String expectedResponse) throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.unifiedCompletionInfer(
model,
UnifiedCompletionRequest.of(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
),
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);
var result = listener.actionGet(TIMEOUT);
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> {
e = unwrapCause(e);
assertThat(e, isA(UnifiedChatCompletionException.class));
try (var builder = XContentFactory.jsonBuilder()) {
((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
try {
xContent.toXContent(builder, EMPTY_PARAMS);
} catch (IOException ex) {
throw new RuntimeException(ex);
}
});
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
assertThat(json, is(expectedResponse));
}
});
}
}
public void testInfer_StreamRequest() throws Exception {
String responseJson = """
data: {\
"id":"12345",\
"object":"chat.completion.chunk",\
"created":123456789,\
"model":"gpt-4o-mini",\
"system_fingerprint": "123456789",\
"choices":[\
{\
"index":0,\
"delta":{\
"content":"hello, world"\
},\
"logprobs":null,\
"finish_reason":null\
}\
]\
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
streamCompletion().hasNoErrors().hasEvent("""
{"completion":[{"delta":"hello, world"}]}""");
}
private InferenceEventsAssertion streamCompletion() throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model");
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.infer(
model,
null,
null,
null,
List.of("abc"),
true,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);
return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
}
}
public void testInfer_StreamRequest_ErrorResponse() throws Exception {
String responseJson = """
{
"error": {
"message": "You didn't provide an API key..."
}
}""";
webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));
var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion);
assertThat(e.status(), equalTo(RestStatus.UNAUTHORIZED));
assertThat(
e.getMessage(),
equalTo(
"Received an authentication error status code for request from inference entity id [id] status [401]. "
+ "Error message: [You didn't provide an API key...]"
)
);
}
public void testInfer_StreamRequestRetry() throws Exception {
webServer.enqueue(new MockResponse().setResponseCode(503).setBody("""
{
"error": {
"message": "server busy"
}
}"""));
webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
data: {\
"id":"12345",\
"object":"chat.completion.chunk",\
"created":123456789,\
"model":"gpt-4o-mini",\
"system_fingerprint": "123456789",\
"choices":[\
{\
"index":0,\
"delta":{\
"content":"hello, world"\
},\
"logprobs":null,\
"finish_reason":null\
}\
]\
}
"""));
streamCompletion().hasNoErrors().hasEvent("""
{"completion":[{"delta":"hello, world"}]}""");
}
public void testSupportsStreaming() throws IOException {
try (var service = new HuggingFaceService(mock(), createWithEmptySettings(mock()))) {
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)));
assertFalse(service.canStream(TaskType.ANY));
}
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
try (var service = createHuggingFaceService()) {
var config = getRequestConfigMap(getServiceSettingsMap("url"), getSecretSettingsMap("secret"));
@ -258,6 +710,25 @@ public class HuggingFaceServiceTests extends ESTestCase {
}
}
public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws IOException {
try (var service = createHuggingFaceService()) {
var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), new HashMap<>(), getSecretSettingsMap("secret"));
var model = service.parsePersistedConfigWithSecrets(
"id",
TaskType.COMPLETION,
persistedConfig.config(),
persistedConfig.secrets()
);
assertThat(model, instanceOf(HuggingFaceChatCompletionModel.class));
var chatCompletionModel = (HuggingFaceChatCompletionModel) model;
assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is("url"));
assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is("secret"));
}
}
public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
try (var service = createHuggingFaceService()) {
var persistedConfig = getPersistedConfigMap(
@ -821,7 +1292,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
{
"service": "hugging_face",
"name": "Hugging Face",
"task_types": ["text_embedding", "sparse_embedding"],
"task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"],
"configurations": {
"api_key": {
"description": "API Key for the provider you're connecting to.",
@ -830,7 +1301,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
"sensitive": true,
"updatable": true,
"type": "str",
"supported_task_types": ["text_embedding", "sparse_embedding"]
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
},
"rate_limit.requests_per_minute": {
"description": "Minimize the number of rate limit errors.",
@ -839,17 +1310,16 @@ public class HuggingFaceServiceTests extends ESTestCase {
"sensitive": false,
"updatable": false,
"type": "int",
"supported_task_types": ["text_embedding", "sparse_embedding"]
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
},
"url": {
"default_value": "https://api.openai.com/v1/embeddings",
"description": "The URL endpoint to use for the requests.",
"label": "URL",
"required": true,
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding", "sparse_embedding"]
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
}
}
}

View File

@ -24,10 +24,13 @@ import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsT
import org.elasticsearch.xpack.inference.InputTypeTests;
import org.elasticsearch.xpack.inference.common.TruncatorTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;
import org.junit.After;
@ -38,6 +41,7 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
@ -425,4 +429,107 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
}
}
public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
{
"object": "chat.completion",
"id": "",
"created": 1745855316,
"model": "/repository",
"system_fingerprint": "3.2.3-sha-a1f3ebe",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello there, how may I assist you today?"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 8,
"completion_tokens": 50,
"total_tokens": 58
}
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
PlainActionFuture<InferenceServiceResults> listener = createChatCompletionFuture(sender, createWithEmptySettings(threadPool));
var result = listener.actionGet(TIMEOUT);
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?"))));
assertChatCompletionRequest();
}
}
public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() throws IOException {
var settings = buildSettingsWithRetryFields(
TimeValue.timeValueMillis(1),
TimeValue.timeValueMinutes(1),
TimeValue.timeValueSeconds(0)
);
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
{
"invalid_field": "unexpected"
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
PlainActionFuture<InferenceServiceResults> listener = createChatCompletionFuture(
sender,
new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator())
);
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
thrownException.getMessage(),
is("Failed to send Hugging Face completion request from inference entity id " + "[id]. Cause: Required [choices]")
);
assertChatCompletionRequest();
}
}
private PlainActionFuture<InferenceServiceResults> createChatCompletionFuture(Sender sender, ServiceComponents threadPool) {
var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model");
var actionCreator = new HuggingFaceActionCreator(sender, threadPool);
var action = actionCreator.create(model);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
return listener;
}
private void assertChatCompletionRequest() throws IOException {
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
equalTo(XContentType.JSON.mediaTypeWithoutParameters())
);
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
assertThat(requestMap.size(), is(4));
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello"))));
assertThat(requestMap.get("model"), is("model"));
assertThat(requestMap.get("n"), is(1));
assertThat(requestMap.get("stream"), is(false));
}
}

View File

@ -0,0 +1,243 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.action;
import org.apache.http.HttpHeaders;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockRequest;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator.COMPLETION_HANDLER;
import static org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator.USER_ROLE;
import static org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests.createCompletionModel;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
public class HuggingFaceChatCompletionActionTests extends ESTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
private HttpClientManager clientManager;
@Before
public void init() throws Exception {
webServer.start();
threadPool = createThreadPool(inferenceUtilityPool());
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
}
@After
public void shutdown() throws IOException {
clientManager.close();
terminate(threadPool);
webServer.close();
}
public void testExecute_ReturnsSuccessfulResponse() throws IOException {
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-3.5-turbo-0125",
"system_fingerprint": "fp_44709d6fcb",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "result content"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var action = createAction(getUrl(webServer), sender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var result = listener.actionGet(TIMEOUT);
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result content"))));
assertThat(webServer.requests(), hasSize(1));
MockRequest request = webServer.requests().get(0);
assertNull(request.getUri().getQuery());
assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()));
assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
var requestMap = entityAsMap(request.getBody());
assertThat(requestMap.size(), is(4));
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc"))));
assertThat(requestMap.get("model"), is("model"));
assertThat(requestMap.get("n"), is(1));
assertThat(requestMap.get("stream"), is(false));
}
}
public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOException {
try (var sender = mock(Sender.class)) {
var thrownException = expectThrows(IllegalArgumentException.class, () -> createAction("^^", sender));
assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]"));
}
}
public void testExecute_ThrowsElasticsearchException() {
var sender = mock(Sender.class);
doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
var action = createAction(getUrl(webServer), sender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(thrownException.getMessage(), is("failed"));
}
public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
var sender = mock(Sender.class);
doAnswer(invocation -> {
ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
listener.onFailure(new IllegalStateException("failed"));
return Void.TYPE;
}).when(sender).send(any(), any(), any(), any());
var action = createAction(getUrl(webServer), sender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(thrownException.getMessage(), is("Failed to send hugging face chat completions request. Cause: failed"));
}
public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-3.5-turbo-0613",
"system_fingerprint": "fp_44709d6fcb",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello there, how may I assist you today?"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var action = createAction(getUrl(webServer), sender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(thrownException.getMessage(), is("hugging face chat completions only accepts 1 input"));
assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
}
}
private ExecutableAction createAction(String url, Sender sender) {
var model = createCompletionModel(url, "secret", "model");
var manager = new GenericRequestManager<>(
threadPool,
model,
COMPLETION_HANDLER,
inputs -> new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
ChatCompletionInput.class
);
var errorMessage = constructFailedToSendRequestMessage("hugging face chat completions");
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "hugging face chat completions");
}
}

View File

@ -0,0 +1,119 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.completion;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import java.util.List;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
public class HuggingFaceChatCompletionModelTests extends ESTestCase {
public void testThrowsURISyntaxException_ForInvalidUrl() {
var thrownException = expectThrows(IllegalArgumentException.class, () -> createCompletionModel("^^", "secret", "id"));
assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]"));
}
public static HuggingFaceChatCompletionModel createCompletionModel(String url, String apiKey, String modelId) {
return new HuggingFaceChatCompletionModel(
"id",
TaskType.COMPLETION,
"service",
new HuggingFaceChatCompletionServiceSettings(modelId, url, null),
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
}
public static HuggingFaceChatCompletionModel createChatCompletionModel(String url, String apiKey, String modelId) {
return new HuggingFaceChatCompletionModel(
"id",
TaskType.CHAT_COMPLETION,
"service",
new HuggingFaceChatCompletionServiceSettings(modelId, url, null),
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
}
public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() {
var model = createCompletionModel("url", "api_key", "model_name");
var request = new UnifiedCompletionRequest(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
"different_model",
null,
null,
null,
null,
null,
null
);
var overriddenModel = HuggingFaceChatCompletionModel.of(model, request);
assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
}
public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() {
var model = createCompletionModel("url", "api_key", null);
var request = new UnifiedCompletionRequest(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
"different_model",
null,
null,
null,
null,
null,
null
);
var overriddenModel = HuggingFaceChatCompletionModel.of(model, request);
assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
}
public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() {
var model = createCompletionModel("url", "api_key", null);
var request = new UnifiedCompletionRequest(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
null,
null,
null,
null,
null,
null,
null
);
var overriddenModel = HuggingFaceChatCompletionModel.of(model, request);
assertNull(overriddenModel.getServiceSettings().modelId());
}
public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() {
var model = createCompletionModel("url", "api_key", "model_name");
var request = new UnifiedCompletionRequest(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
null, // not overriding model
null,
null,
null,
null,
null,
null
);
var overriddenModel = HuggingFaceChatCompletionModel.of(model, request);
assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name"));
}
}

View File

@ -0,0 +1,271 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.completion;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
public class HuggingFaceChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase<
HuggingFaceChatCompletionServiceSettings> {
public static final String MODEL_ID = "some model";
public static final String CORRECT_URL = "https://www.elastic.co";
public static final int RATE_LIMIT = 2;
public void testFromMap_AllFields_Success() {
var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap(
new HashMap<>(
Map.of(
ServiceFields.MODEL_ID,
MODEL_ID,
ServiceFields.URL,
CORRECT_URL,
RateLimitSettings.FIELD_NAME,
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
)
),
ConfigurationParseContext.PERSISTENT
);
assertThat(
serviceSettings,
is(
new HuggingFaceChatCompletionServiceSettings(
MODEL_ID,
ServiceUtils.createUri(CORRECT_URL),
new RateLimitSettings(RATE_LIMIT)
)
)
);
}
public void testFromMap_MissingModelId_Success() {
var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap(
new HashMap<>(
Map.of(
ServiceFields.URL,
CORRECT_URL,
RateLimitSettings.FIELD_NAME,
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
)
),
ConfigurationParseContext.PERSISTENT
);
assertThat(
serviceSettings,
is(new HuggingFaceChatCompletionServiceSettings(null, ServiceUtils.createUri(CORRECT_URL), new RateLimitSettings(RATE_LIMIT)))
);
}
public void testFromMap_MissingRateLimit_Success() {
var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap(
new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)),
ConfigurationParseContext.PERSISTENT
);
assertThat(serviceSettings, is(new HuggingFaceChatCompletionServiceSettings(MODEL_ID, ServiceUtils.createUri(CORRECT_URL), null)));
}
public void testFromMap_MissingUrl_ThrowsException() {
var thrownException = expectThrows(
ValidationException.class,
() -> HuggingFaceChatCompletionServiceSettings.fromMap(
new HashMap<>(
Map.of(
ServiceFields.MODEL_ID,
MODEL_ID,
RateLimitSettings.FIELD_NAME,
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
)
),
ConfigurationParseContext.PERSISTENT
)
);
assertThat(
thrownException.getMessage(),
containsString(
Strings.format("Validation Failed: 1: [service_settings] does not contain the required setting [url];", ServiceFields.URL)
)
);
}
public void testFromMap_EmptyUrl_ThrowsException() {
var thrownException = expectThrows(
ValidationException.class,
() -> HuggingFaceChatCompletionServiceSettings.fromMap(
new HashMap<>(
Map.of(
ServiceFields.MODEL_ID,
MODEL_ID,
ServiceFields.URL,
"",
RateLimitSettings.FIELD_NAME,
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
)
),
ConfigurationParseContext.PERSISTENT
)
);
assertThat(
thrownException.getMessage(),
containsString(
Strings.format(
"Validation Failed: 1: [service_settings] Invalid value empty string. [%s] must be a non-empty string;",
ServiceFields.URL
)
)
);
}
public void testFromMap_InvalidUrl_ThrowsException() {
String invalidUrl = "https://www.elastic^^co";
var thrownException = expectThrows(
ValidationException.class,
() -> HuggingFaceChatCompletionServiceSettings.fromMap(
new HashMap<>(
Map.of(
ServiceFields.MODEL_ID,
MODEL_ID,
ServiceFields.URL,
invalidUrl,
RateLimitSettings.FIELD_NAME,
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
)
),
ConfigurationParseContext.PERSISTENT
)
);
assertThat(
thrownException.getMessage(),
containsString(
Strings.format(
"Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]",
invalidUrl,
ServiceFields.URL
)
)
);
}
public void testToXContent_WritesAllValues() throws IOException {
var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap(
new HashMap<>(
Map.of(
ServiceFields.MODEL_ID,
MODEL_ID,
ServiceFields.URL,
CORRECT_URL,
RateLimitSettings.FIELD_NAME,
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
)
),
ConfigurationParseContext.PERSISTENT
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
serviceSettings.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
var expected = XContentHelper.stripWhitespace("""
{
"model_id": "some model",
"url": "https://www.elastic.co",
"rate_limit": {
"requests_per_minute": 2
}
}
""");
assertThat(xContentResult, is(expected));
}
public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException {
var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap(
new HashMap<>(Map.of(ServiceFields.URL, CORRECT_URL)),
ConfigurationParseContext.PERSISTENT
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
serviceSettings.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
var expected = XContentHelper.stripWhitespace("""
{
"url": "https://www.elastic.co",
"rate_limit": {
"requests_per_minute": 3000
}
}
""");
assertThat(xContentResult, is(expected));
}
@Override
protected Writeable.Reader<HuggingFaceChatCompletionServiceSettings> instanceReader() {
return HuggingFaceChatCompletionServiceSettings::new;
}
@Override
protected HuggingFaceChatCompletionServiceSettings createTestInstance() {
return createRandomWithNonNullUrl();
}
@Override
protected HuggingFaceChatCompletionServiceSettings mutateInstance(HuggingFaceChatCompletionServiceSettings instance)
throws IOException {
return randomValueOtherThan(instance, HuggingFaceChatCompletionServiceSettingsTests::createRandomWithNonNullUrl);
}
@Override
protected HuggingFaceChatCompletionServiceSettings mutateInstanceForVersion(
HuggingFaceChatCompletionServiceSettings instance,
TransportVersion version
) {
return instance;
}
private static HuggingFaceChatCompletionServiceSettings createRandomWithNonNullUrl() {
return createRandom(randomAlphaOfLength(15));
}
private static HuggingFaceChatCompletionServiceSettings createRandom(String url) {
var modelId = randomAlphaOfLength(8);
return new HuggingFaceChatCompletionServiceSettings(modelId, ServiceUtils.createUri(url), RateLimitSettingsTests.createRandom());
}
public static Map<String, Object> getServiceSettingsMap(String url, String model) {
var map = new HashMap<String, Object>();
map.put(ServiceFields.URL, url);
map.put(ServiceFields.MODEL_ID, model);
return map;
}
}

View File

@ -12,16 +12,17 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequestEntity;
import java.io.IOException;
import java.util.List;
import static org.hamcrest.CoreMatchers.is;
public class HuggingFaceInferenceRequestEntityTests extends ESTestCase {
public class HuggingFaceEmbeddingsRequestEntityTests extends ESTestCase {
public void testXContent() throws IOException {
var entity = new HuggingFaceInferenceRequestEntity(List.of("abc"));
var entity = new HuggingFaceEmbeddingsRequestEntity(List.of("abc"));
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);

View File

@ -14,6 +14,7 @@ import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.common.TruncatorTests;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest;
import java.io.IOException;
import java.net.URI;
@ -25,7 +26,7 @@ import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
public class HuggingFaceInferenceRequestTests extends ESTestCase {
public class HuggingFaceEmbeddingsRequestTests extends ESTestCase {
@SuppressWarnings("unchecked")
public void testCreateRequest() throws URISyntaxException, IOException {
var huggingFaceRequest = createRequest("www.google.com", "secret", "abc");
@ -67,9 +68,9 @@ public class HuggingFaceInferenceRequestTests extends ESTestCase {
assertTrue(truncatedRequest.getTruncationInfo()[0]);
}
public static HuggingFaceInferenceRequest createRequest(String url, String apiKey, String input) throws URISyntaxException {
public static HuggingFaceEmbeddingsRequest createRequest(String url, String apiKey, String input) throws URISyntaxException {
return new HuggingFaceInferenceRequest(
return new HuggingFaceEmbeddingsRequest(
TruncatorTests.createTruncator(),
new Truncator.TruncationResult(List.of(input), new boolean[] { false }),
HuggingFaceEmbeddingsModelTests.createModel(url, apiKey)

View File

@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequestEntity;
import java.io.IOException;
import java.util.ArrayList;
import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals;
import static org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests.createCompletionModel;
public class HuggingFaceUnifiedChatCompletionRequestEntityTests extends ESTestCase {
private static final String ROLE = "user";
public void testModelUserFieldsSerialization() throws IOException {
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
null,
null
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
messageList.add(message);
var unifiedRequest = UnifiedCompletionRequest.of(messageList);
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
HuggingFaceChatCompletionModel model = createCompletionModel("test-url", "api-key", "test-endpoint");
HuggingFaceUnifiedChatCompletionRequestEntity entity = new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, model);
XContentBuilder builder = JsonXContent.contentBuilder();
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
String jsonString = Strings.toString(builder);
String expectedJson = """
{
"messages": [
{
"content": "Hello, world!",
"role": "user"
}
],
"model": "test-endpoint",
"n": 1,
"stream": true,
"stream_options": {
"include_usage": true
}
}
""";
assertJsonEquals(jsonString, expectedJson);
}
}

View File

@ -0,0 +1,80 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.request;
import org.apache.http.client.methods.HttpPost;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests;
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
public class HuggingFaceUnifiedChatCompletionRequestTests extends ESTestCase {
public void testCreateRequest_WithStreaming() throws IOException {
var request = createRequest("url", "secret", randomAlphaOfLength(15), "model", true);
var httpRequest = request.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
var httpPost = (HttpPost) httpRequest.httpRequestBase();
var requestMap = entityAsMap(httpPost.getEntity().getContent());
assertThat(requestMap.get("stream"), is(true));
}
public void testTruncate_DoesNotReduceInputTextSize() throws IOException {
String input = randomAlphaOfLength(5);
var request = createRequest("url", "secret", input, "model", true);
var truncatedRequest = request.truncate();
assertThat(request.getURI().toString(), is("url"));
var httpRequest = truncatedRequest.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
var httpPost = (HttpPost) httpRequest.httpRequestBase();
var requestMap = entityAsMap(httpPost.getEntity().getContent());
assertThat(requestMap, aMapWithSize(5));
// We do not truncate for Hugging Face chat completions
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
assertThat(requestMap.get("model"), is("model"));
assertThat(requestMap.get("n"), is(1));
assertTrue((Boolean) requestMap.get("stream"));
assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true)));
}
public void testTruncationInfo_ReturnsNull() {
var request = createRequest("url", "secret", randomAlphaOfLength(5), "model", true);
assertNull(request.getTruncationInfo());
}
public static HuggingFaceUnifiedChatCompletionRequest createRequest(String url, String apiKey, String input, @Nullable String model) {
return createRequest(url, apiKey, input, model, false);
}
public static HuggingFaceUnifiedChatCompletionRequest createRequest(
@Nullable String url,
String apiKey,
String input,
@Nullable String model,
boolean stream
) {
var chatCompletionModel = HuggingFaceChatCompletionModelTests.createCompletionModel(url, apiKey, model);
return new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel);
}
}

View File

@ -19,6 +19,8 @@ import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatComple
import java.io.IOException;
import java.util.List;
import static org.hamcrest.Matchers.is;
public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase {
public void testJsonLiteral() {
@ -182,6 +184,73 @@ public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase {
}
}
public void testJsonNullFunctionName() throws IOException {
String json = """
{
"object": "chat.completion.chunk",
"id": "",
"created": 1746800254,
"model": "/repository",
"system_fingerprint": "3.2.3-sha-a1f3ebe",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"tool_calls": [
{
"index": 0,
"id": "8f7c27be-6803-48e6-bba4-8cdcbcd2ff9a",
"type": "function",
"function": {
"name": null,
"arguments": " \\\""
}
}
]
},
"logprobs": null,
"finish_reason": null
}
],
"usage": null
}
""";
try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, json)) {
StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser
.parse(parser);
// Assertions to verify the parsed object
assertThat(chunk.id(), is(""));
assertThat(chunk.model(), is("/repository"));
assertThat(chunk.object(), is("chat.completion.chunk"));
assertNull(chunk.usage());
List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice> choices = chunk.choices();
assertThat(choices.size(), is(1));
// First choice assertions
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice firstChoice = choices.get(0);
assertNull(firstChoice.delta().content());
assertNull(firstChoice.delta().refusal());
assertThat(firstChoice.delta().role(), is("assistant"));
assertNull(firstChoice.finishReason());
assertThat(firstChoice.index(), is(0));
List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall> toolCalls = firstChoice.delta()
.toolCalls();
assertThat(toolCalls.size(), is(1));
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0);
assertThat(toolCall.index(), is(0));
assertThat(toolCall.id(), is("8f7c27be-6803-48e6-bba4-8cdcbcd2ff9a"));
assertThat(toolCall.type(), is("function"));
assertNull(toolCall.function().name());
assertThat(toolCall.function().arguments(), is(" \""));
}
}
public void testOpenAiUnifiedStreamingProcessorParsing() throws IOException {
// Generate random values for the JSON fields
int toolCallIndex = randomIntBetween(0, 10);