Add Hugging Face Chat Completion support to Inference Plugin (#127254)
* Add Hugging Face Chat Completion support to Inference Plugin * Add support for streaming chat completion task for HuggingFace * [CI] Auto commit changes from spotless * Add support for non-streaming completion task for HuggingFace * Remove RequestManager for HF Chat Completion Task * Refactored Hugging Face Completion Service Settings, removed Request Manager, added Unit Tests * Refactored Hugging Face Action Creator, added Unit Tests * Add Hugging Face Server Test * [CI] Auto commit changes from spotless * Removed parameters from media type for Chat Completion Request and unit tests * Removed OpenAI default URL in HuggingFaceService's configuration, fixed formatting in InferenceGetServicesIT * Refactor error message handling in HuggingFaceActionCreator and HuggingFaceService * Update minimal supported version and add Hugging Face transport version constants * Made modelId field optional in HuggingFaceChatCompletionModel, updated unit tests * Removed max input tokens field from HuggingFaceChatCompletionServiceSettings, fixed unit tests * Removed if statement checking TransportVersion for HuggingFaceChatCompletionServiceSettings constructor with StreamInput param * Removed getFirst() method calls for backport compatibility * Made HuggingFaceChatCompletionServiceSettingsTests extend AbstractBWCWireSerializationTestCase for future serialization testing * Refactored tests to use stripWhitespace method for readability * Refactored javadoc for HuggingFaceService * Renamed HF chat completion TransportVersion constant names * Added random string generation in unit test * Refactored javadocs for HuggingFace requests * Refactored tests to reduce duplication * Added changelog file * Add HuggingFaceChatCompletionResponseHandler and associated tests * Refactor error handling in HuggingFaceServiceTests to standardize error response codes and types * Refactor HuggingFace error handling to improve response structure and add streaming support * Allowing null function name for hugging face models --------- Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co> Co-authored-by: Jonathan Buttner <jonathan.buttner@elastic.co>
This commit is contained in:
parent
54f26680ea
commit
d1ad917855
|
@ -0,0 +1,5 @@
|
|||
pr: 127254
|
||||
summary: "[ML] Add HuggingFace Chat Completion support to the Inference Plugin"
|
||||
area: Machine Learning
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -175,6 +175,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
|
||||
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29);
|
||||
public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION_8_19 = def(8_841_0_30);
|
||||
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_31);
|
||||
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
|
||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
|
||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
|
||||
|
@ -255,6 +256,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);
|
||||
public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00);
|
||||
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
|
||||
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_078_0_00);
|
||||
|
||||
/*
|
||||
* STOP! READ THIS FIRST! No, really,
|
||||
|
|
|
@ -123,7 +123,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
|
||||
public void testGetServicesWithCompletionTaskType() throws IOException {
|
||||
List<Object> services = getServices(TaskType.COMPLETION);
|
||||
assertThat(services.size(), equalTo(10));
|
||||
assertThat(services.size(), equalTo(11));
|
||||
|
||||
var providers = providers(services);
|
||||
|
||||
|
@ -140,7 +140,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
"deepseek",
|
||||
"googleaistudio",
|
||||
"openai",
|
||||
"streaming_completion_test_service"
|
||||
"streaming_completion_test_service",
|
||||
"hugging_face"
|
||||
).toArray()
|
||||
)
|
||||
);
|
||||
|
@ -148,11 +149,14 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
|
||||
public void testGetServicesWithChatCompletionTaskType() throws IOException {
|
||||
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
|
||||
assertThat(services.size(), equalTo(4));
|
||||
assertThat(services.size(), equalTo(5));
|
||||
|
||||
var providers = providers(services);
|
||||
|
||||
assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray()));
|
||||
assertThat(
|
||||
providers,
|
||||
containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face").toArray())
|
||||
);
|
||||
}
|
||||
|
||||
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
|
||||
|
|
|
@ -78,6 +78,7 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.Goog
|
|||
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
|
||||
|
@ -357,6 +358,13 @@ public class InferenceNamedWriteablesProvider {
|
|||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
|
||||
);
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(
|
||||
ServiceSettings.class,
|
||||
HuggingFaceChatCompletionServiceSettings.NAME,
|
||||
HuggingFaceChatCompletionServiceSettings::new
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
|
||||
|
|
|
@ -0,0 +1,171 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface;
|
||||
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.xcontent.ParseField;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentParser;
|
||||
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceErrorResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
|
||||
|
||||
import java.util.Locale;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.core.Strings.format;
|
||||
|
||||
/**
|
||||
* Handles streaming chat completion responses and error parsing for Hugging Face inference endpoints.
|
||||
* Adapts the OpenAI handler to support Hugging Face's simpler error schema with fields like "message" and "http_status_code".
|
||||
*/
|
||||
public class HuggingFaceChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
|
||||
|
||||
private static final String HUGGING_FACE_ERROR = "hugging_face_error";
|
||||
|
||||
public HuggingFaceChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
|
||||
super(requestType, parseFunction, HuggingFaceErrorResponseEntity::fromResponse);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
|
||||
assert request.isStreaming() : "Only streaming requests support this format";
|
||||
var responseStatusCode = result.response().getStatusLine().getStatusCode();
|
||||
if (request.isStreaming()) {
|
||||
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
|
||||
var restStatus = toRestStatus(responseStatusCode);
|
||||
return errorResponse instanceof HuggingFaceErrorResponseEntity
|
||||
? new UnifiedChatCompletionException(
|
||||
restStatus,
|
||||
errorMessage,
|
||||
HUGGING_FACE_ERROR,
|
||||
restStatus.name().toLowerCase(Locale.ROOT)
|
||||
)
|
||||
: new UnifiedChatCompletionException(
|
||||
restStatus,
|
||||
errorMessage,
|
||||
createErrorType(errorResponse),
|
||||
restStatus.name().toLowerCase(Locale.ROOT)
|
||||
);
|
||||
} else {
|
||||
return super.buildError(message, request, result, errorResponse);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Exception buildMidStreamError(Request request, String message, Exception e) {
|
||||
var errorResponse = StreamingHuggingFaceErrorResponseEntity.fromString(message);
|
||||
if (errorResponse instanceof StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) {
|
||||
return new UnifiedChatCompletionException(
|
||||
RestStatus.INTERNAL_SERVER_ERROR,
|
||||
format(
|
||||
"%s for request from inference entity id [%s]. Error message: [%s]",
|
||||
SERVER_ERROR_OBJECT,
|
||||
request.getInferenceEntityId(),
|
||||
errorResponse.getErrorMessage()
|
||||
),
|
||||
HUGGING_FACE_ERROR,
|
||||
extractErrorCode(streamingHuggingFaceErrorResponseEntity)
|
||||
);
|
||||
} else if (e != null) {
|
||||
return UnifiedChatCompletionException.fromThrowable(e);
|
||||
} else {
|
||||
return new UnifiedChatCompletionException(
|
||||
RestStatus.INTERNAL_SERVER_ERROR,
|
||||
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
|
||||
createErrorType(errorResponse),
|
||||
"stream_error"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) {
|
||||
return streamingHuggingFaceErrorResponseEntity.httpStatusCode() != null
|
||||
? String.valueOf(streamingHuggingFaceErrorResponseEntity.httpStatusCode())
|
||||
: null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a structured error response specifically for streaming operations
|
||||
* using HuggingFace APIs. This is separate from non-streaming error responses,
|
||||
* which are handled by {@link HuggingFaceErrorResponseEntity}.
|
||||
* An example error response for failed field validation for streaming operation would look like
|
||||
* <code>
|
||||
* {
|
||||
* "error": "Input validation error: cannot compile regex from schema",
|
||||
* "http_status_code": 422
|
||||
* }
|
||||
* </code>
|
||||
*/
|
||||
private static class StreamingHuggingFaceErrorResponseEntity extends ErrorResponse {
|
||||
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
|
||||
HUGGING_FACE_ERROR,
|
||||
true,
|
||||
args -> Optional.ofNullable((StreamingHuggingFaceErrorResponseEntity) args[0])
|
||||
);
|
||||
private static final ConstructingObjectParser<StreamingHuggingFaceErrorResponseEntity, Void> ERROR_BODY_PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
HUGGING_FACE_ERROR,
|
||||
true,
|
||||
args -> new StreamingHuggingFaceErrorResponseEntity(args[0] != null ? (String) args[0] : "unknown", (Integer) args[1])
|
||||
);
|
||||
|
||||
static {
|
||||
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message"));
|
||||
ERROR_BODY_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField("http_status_code"));
|
||||
|
||||
ERROR_PARSER.declareObjectOrNull(
|
||||
ConstructingObjectParser.optionalConstructorArg(),
|
||||
ERROR_BODY_PARSER,
|
||||
null,
|
||||
new ParseField("error")
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses a streaming HuggingFace error response from a JSON string.
|
||||
*
|
||||
* @param response the raw JSON string representing an error
|
||||
* @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails
|
||||
*/
|
||||
private static ErrorResponse fromString(String response) {
|
||||
try (
|
||||
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
|
||||
.createParser(XContentParserConfiguration.EMPTY, response)
|
||||
) {
|
||||
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
|
||||
} catch (Exception e) {
|
||||
// swallow the error
|
||||
}
|
||||
|
||||
return ErrorResponse.UNDEFINED_ERROR;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private final Integer httpStatusCode;
|
||||
|
||||
StreamingHuggingFaceErrorResponseEntity(String errorMessage, @Nullable Integer httpStatusCode) {
|
||||
super(errorMessage);
|
||||
this.httpStatusCode = httpStatusCode;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public Integer httpStatusCode() {
|
||||
return httpStatusCode;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
|
@ -9,17 +9,18 @@ package org.elasticsearch.xpack.inference.services.huggingface;
|
|||
|
||||
import org.elasticsearch.common.settings.SecureString;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
|
||||
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
public abstract class HuggingFaceModel extends Model {
|
||||
public abstract class HuggingFaceModel extends RateLimitGroupingModel {
|
||||
private final HuggingFaceRateLimitServiceSettings rateLimitServiceSettings;
|
||||
private final SecureString apiKey;
|
||||
|
||||
|
@ -38,6 +39,16 @@ public abstract class HuggingFaceModel extends Model {
|
|||
return rateLimitServiceSettings;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int rateLimitGroupingHash() {
|
||||
return Objects.hash(rateLimitServiceSettings.uri(), apiKey);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RateLimitSettings rateLimitSettings() {
|
||||
return rateLimitServiceSettings.rateLimitSettings();
|
||||
}
|
||||
|
||||
public SecureString apiKey() {
|
||||
return apiKey;
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.BaseRequestManager
|
|||
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.request.HuggingFaceInferenceRequest;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
@ -64,7 +64,7 @@ public class HuggingFaceRequestManager extends BaseRequestManager {
|
|||
) {
|
||||
List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
|
||||
var truncatedInput = truncate(docsInput, model.getTokenLimit());
|
||||
var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);
|
||||
var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model);
|
||||
|
||||
execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
|
||||
}
|
||||
|
|
|
@ -26,15 +26,21 @@ import org.elasticsearch.inference.TaskType;
|
|||
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
|
||||
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
|
||||
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
|
||||
|
@ -42,16 +48,29 @@ import java.util.EnumSet;
|
|||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
|
||||
|
||||
/**
|
||||
* This class is responsible for managing the Hugging Face inference service.
|
||||
* It manages model creation, as well as chunked, non-chunked, and unified completion inference.
|
||||
*/
|
||||
public class HuggingFaceService extends HuggingFaceBaseService {
|
||||
public static final String NAME = "hugging_face";
|
||||
|
||||
private static final String SERVICE_NAME = "Hugging Face";
|
||||
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING);
|
||||
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(
|
||||
TaskType.TEXT_EMBEDDING,
|
||||
TaskType.SPARSE_EMBEDDING,
|
||||
TaskType.COMPLETION,
|
||||
TaskType.CHAT_COMPLETION
|
||||
);
|
||||
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new HuggingFaceChatCompletionResponseHandler(
|
||||
"hugging face chat completion",
|
||||
OpenAiChatCompletionResponseEntity::fromResponse
|
||||
);
|
||||
|
||||
public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
|
||||
super(factory, serviceComponents);
|
||||
|
@ -78,6 +97,14 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
|||
context
|
||||
);
|
||||
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
|
||||
case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
NAME,
|
||||
serviceSettings,
|
||||
secretSettings,
|
||||
context
|
||||
);
|
||||
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
|
||||
};
|
||||
}
|
||||
|
@ -139,7 +166,29 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
|||
TimeValue timeout,
|
||||
ActionListener<InferenceServiceResults> listener
|
||||
) {
|
||||
throwUnsupportedUnifiedCompletionOperation(NAME);
|
||||
if (model instanceof HuggingFaceChatCompletionModel == false) {
|
||||
listener.onFailure(createInvalidModelException(model));
|
||||
return;
|
||||
}
|
||||
|
||||
HuggingFaceChatCompletionModel huggingFaceChatCompletionModel = (HuggingFaceChatCompletionModel) model;
|
||||
var overriddenModel = HuggingFaceChatCompletionModel.of(huggingFaceChatCompletionModel, inputs.getRequest());
|
||||
var manager = new GenericRequestManager<>(
|
||||
getServiceComponents().threadPool(),
|
||||
overriddenModel,
|
||||
UNIFIED_CHAT_COMPLETION_HANDLER,
|
||||
unifiedChatInput -> new HuggingFaceUnifiedChatCompletionRequest(unifiedChatInput, overriddenModel),
|
||||
UnifiedChatInput.class
|
||||
);
|
||||
var errorMessage = HuggingFaceActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId());
|
||||
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
|
||||
|
||||
action.execute(inputs, timeout, listener);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<TaskType> supportedStreamingTasks() {
|
||||
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -149,7 +198,7 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
|||
|
||||
@Override
|
||||
public EnumSet<TaskType> supportedTaskTypes() {
|
||||
return supportedTaskTypes;
|
||||
return SUPPORTED_TASK_TYPES;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -167,14 +216,15 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
|||
return configuration.getOrCompute();
|
||||
}
|
||||
|
||||
private Configuration() {}
|
||||
|
||||
private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
|
||||
() -> {
|
||||
var configurationMap = new HashMap<String, SettingsConfiguration>();
|
||||
|
||||
configurationMap.put(
|
||||
URL,
|
||||
new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue("https://api.openai.com/v1/embeddings")
|
||||
.setDescription("The URL endpoint to use for the requests.")
|
||||
new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.")
|
||||
.setLabel("URL")
|
||||
.setRequired(true)
|
||||
.setSensitive(false)
|
||||
|
@ -183,12 +233,12 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
|||
.build()
|
||||
);
|
||||
|
||||
configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes));
|
||||
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes));
|
||||
configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));
|
||||
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));
|
||||
|
||||
return new InferenceServiceConfiguration.Builder().setService(NAME)
|
||||
.setName(SERVICE_NAME)
|
||||
.setTaskTypes(supportedTaskTypes)
|
||||
.setTaskTypes(SUPPORTED_TASK_TYPES)
|
||||
.setConfigurations(configurationMap)
|
||||
.build();
|
||||
}
|
||||
|
|
|
@ -7,16 +7,26 @@
|
|||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.action;
|
||||
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRequestManager;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
|
@ -26,6 +36,13 @@ import static org.elasticsearch.core.Strings.format;
|
|||
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the hugging face model type.
|
||||
*/
|
||||
public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
|
||||
|
||||
public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions";
|
||||
static final String USER_ROLE = "user";
|
||||
static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
|
||||
"hugging face completion",
|
||||
OpenAiChatCompletionResponseEntity::fromResponse
|
||||
);
|
||||
private final Sender sender;
|
||||
private final ServiceComponents serviceComponents;
|
||||
|
||||
|
@ -46,11 +63,7 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
|
|||
serviceComponents.truncator(),
|
||||
serviceComponents.threadPool()
|
||||
);
|
||||
var errorMessage = format(
|
||||
"Failed to send Hugging Face %s request from inference entity id [%s]",
|
||||
"text embeddings",
|
||||
model.getInferenceEntityId()
|
||||
);
|
||||
var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId());
|
||||
return new SenderExecutableAction(sender, requestCreator, errorMessage);
|
||||
}
|
||||
|
||||
|
@ -63,11 +76,25 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
|
|||
serviceComponents.truncator(),
|
||||
serviceComponents.threadPool()
|
||||
);
|
||||
var errorMessage = format(
|
||||
"Failed to send Hugging Face %s request from inference entity id [%s]",
|
||||
"ELSER",
|
||||
model.getInferenceEntityId()
|
||||
);
|
||||
var errorMessage = buildErrorMessage(TaskType.SPARSE_EMBEDDING, model.getInferenceEntityId());
|
||||
return new SenderExecutableAction(sender, requestCreator, errorMessage);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ExecutableAction create(HuggingFaceChatCompletionModel model) {
|
||||
var manager = new GenericRequestManager<>(
|
||||
serviceComponents.threadPool(),
|
||||
model,
|
||||
COMPLETION_HANDLER,
|
||||
inputs -> new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
|
||||
ChatCompletionInput.class
|
||||
);
|
||||
|
||||
var errorMessage = buildErrorMessage(TaskType.COMPLETION, model.getInferenceEntityId());
|
||||
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX);
|
||||
}
|
||||
|
||||
public static String buildErrorMessage(TaskType requestType, String inferenceId) {
|
||||
return format("Failed to send Hugging Face %s request from inference entity id [%s]", requestType.toString(), inferenceId);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
package org.elasticsearch.xpack.inference.services.huggingface.action;
|
||||
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
||||
|
||||
|
@ -15,4 +16,6 @@ public interface HuggingFaceActionVisitor {
|
|||
ExecutableAction create(HuggingFaceEmbeddingsModel model);
|
||||
|
||||
ExecutableAction create(HuggingFaceElserModel model);
|
||||
|
||||
ExecutableAction create(HuggingFaceChatCompletionModel model);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.completion;
|
||||
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public class HuggingFaceChatCompletionModel extends HuggingFaceModel {
|
||||
|
||||
/**
|
||||
* Creates a new {@link HuggingFaceChatCompletionModel} by copying properties from an existing model,
|
||||
* replacing the {@code modelId} in the service settings with the one from the given {@link UnifiedCompletionRequest},
|
||||
* if present. If the request does not specify a model ID, the original value is retained.
|
||||
*
|
||||
* @param model the original model to copy from
|
||||
* @param request the request potentially containing an overridden model ID
|
||||
* @return a new {@link HuggingFaceChatCompletionModel} with updated service settings
|
||||
*/
|
||||
public static HuggingFaceChatCompletionModel of(HuggingFaceChatCompletionModel model, UnifiedCompletionRequest request) {
|
||||
var originalModelServiceSettings = model.getServiceSettings();
|
||||
var overriddenServiceSettings = new HuggingFaceChatCompletionServiceSettings(
|
||||
request.model() != null ? request.model() : originalModelServiceSettings.modelId(),
|
||||
originalModelServiceSettings.uri(),
|
||||
originalModelServiceSettings.rateLimitSettings()
|
||||
);
|
||||
|
||||
return new HuggingFaceChatCompletionModel(
|
||||
model.getInferenceEntityId(),
|
||||
model.getTaskType(),
|
||||
model.getConfigurations().getService(),
|
||||
overriddenServiceSettings,
|
||||
model.getSecretSettings()
|
||||
);
|
||||
}
|
||||
|
||||
public HuggingFaceChatCompletionModel(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
String service,
|
||||
Map<String, Object> serviceSettings,
|
||||
@Nullable Map<String, Object> secrets,
|
||||
ConfigurationParseContext context
|
||||
) {
|
||||
this(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
service,
|
||||
HuggingFaceChatCompletionServiceSettings.fromMap(serviceSettings, context),
|
||||
DefaultSecretSettings.fromMap(secrets)
|
||||
);
|
||||
}
|
||||
|
||||
HuggingFaceChatCompletionModel(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
String service,
|
||||
HuggingFaceChatCompletionServiceSettings serviceSettings,
|
||||
@Nullable DefaultSecretSettings secretSettings
|
||||
) {
|
||||
super(
|
||||
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings),
|
||||
new ModelSecrets(secretSettings),
|
||||
serviceSettings,
|
||||
secretSettings
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public HuggingFaceChatCompletionServiceSettings getServiceSettings() {
|
||||
return (HuggingFaceChatCompletionServiceSettings) super.getServiceSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DefaultSecretSettings getSecretSettings() {
|
||||
return (DefaultSecretSettings) super.getSecretSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ExecutableAction accept(HuggingFaceActionVisitor creator) {
|
||||
return creator.create(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getTokenLimit() {
|
||||
throw new UnsupportedOperationException("Token Limit for chat completion is sent in request and not retrieved from the model");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,172 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.completion;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ServiceSettings;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
|
||||
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
|
||||
import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri;
|
||||
|
||||
/**
|
||||
* Settings for the Hugging Face chat completion service.
|
||||
* <p>
|
||||
* This class contains the settings required to configure a Hugging Face chat completion service, including the model ID, URL, maximum input
|
||||
* tokens, and rate limit settings.
|
||||
* </p>
|
||||
*/
|
||||
public class HuggingFaceChatCompletionServiceSettings extends FilteredXContentObject
|
||||
implements
|
||||
ServiceSettings,
|
||||
HuggingFaceRateLimitServiceSettings {
|
||||
|
||||
public static final String NAME = "hugging_face_completion_service_settings";
|
||||
// At the time of writing HuggingFace hasn't posted the default rate limit for inference endpoints so the value his is only a guess
|
||||
// 3000 requests per minute
|
||||
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
|
||||
|
||||
/**
|
||||
* Creates a new instance of {@link HuggingFaceChatCompletionServiceSettings} from a map of settings.
|
||||
* @param map the map of settings
|
||||
* @param context the context for parsing the settings
|
||||
* @return a new instance of {@link HuggingFaceChatCompletionServiceSettings}
|
||||
*/
|
||||
public static HuggingFaceChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
|
||||
ValidationException validationException = new ValidationException();
|
||||
|
||||
String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||
|
||||
var uri = extractUri(map, URL, validationException);
|
||||
|
||||
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
|
||||
map,
|
||||
DEFAULT_RATE_LIMIT_SETTINGS,
|
||||
validationException,
|
||||
HuggingFaceService.NAME,
|
||||
context
|
||||
);
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
return new HuggingFaceChatCompletionServiceSettings(modelId, uri, rateLimitSettings);
|
||||
}
|
||||
|
||||
private final String modelId;
|
||||
private final URI uri;
|
||||
private final RateLimitSettings rateLimitSettings;
|
||||
|
||||
public HuggingFaceChatCompletionServiceSettings(@Nullable String modelId, String url, @Nullable RateLimitSettings rateLimitSettings) {
|
||||
this(modelId, createUri(url), rateLimitSettings);
|
||||
}
|
||||
|
||||
public HuggingFaceChatCompletionServiceSettings(@Nullable String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) {
|
||||
this.modelId = modelId;
|
||||
this.uri = uri;
|
||||
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new instance of {@link HuggingFaceChatCompletionServiceSettings} from a stream input.
|
||||
* @param in the stream input
|
||||
* @throws IOException if an I/O error occurs
|
||||
*/
|
||||
public HuggingFaceChatCompletionServiceSettings(StreamInput in) throws IOException {
|
||||
this.modelId = in.readOptionalString();
|
||||
this.uri = createUri(in.readString());
|
||||
this.rateLimitSettings = new RateLimitSettings(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RateLimitSettings rateLimitSettings() {
|
||||
return rateLimitSettings;
|
||||
}
|
||||
|
||||
@Override
|
||||
public URI uri() {
|
||||
return uri;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String modelId() {
|
||||
return modelId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
toXContentFragmentOfExposedFields(builder, params);
|
||||
builder.endObject();
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
|
||||
if (modelId != null) {
|
||||
builder.field(MODEL_ID, modelId);
|
||||
}
|
||||
builder.field(URL, uri.toString());
|
||||
rateLimitSettings.toXContent(builder, params);
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeOptionalString(modelId);
|
||||
out.writeString(uri.toString());
|
||||
rateLimitSettings.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (this == object) return true;
|
||||
if (object == null || getClass() != object.getClass()) return false;
|
||||
HuggingFaceChatCompletionServiceSettings that = (HuggingFaceChatCompletionServiceSettings) object;
|
||||
return Objects.equals(modelId, that.modelId)
|
||||
&& Objects.equals(uri, that.uri)
|
||||
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(modelId, uri, rateLimitSettings);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.request.completion;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.apache.http.entity.ByteArrayEntity;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceAccount;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||
|
||||
import java.net.URI;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
|
||||
|
||||
/**
|
||||
* This class is responsible for creating Hugging Face chat completions HTTP requests.
|
||||
* It handles the preparation of the HTTP request with the necessary headers and body.
|
||||
*/
|
||||
public class HuggingFaceUnifiedChatCompletionRequest implements Request {
|
||||
|
||||
private final HuggingFaceAccount account;
|
||||
private final HuggingFaceChatCompletionModel model;
|
||||
private final UnifiedChatInput unifiedChatInput;
|
||||
|
||||
public HuggingFaceUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatInput, HuggingFaceChatCompletionModel model) {
|
||||
this.account = HuggingFaceAccount.of(model);
|
||||
this.model = Objects.requireNonNull(model);
|
||||
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an HTTP request to the Hugging Face API for chat completions.
|
||||
* The request includes the necessary headers and the input data as a JSON entity.
|
||||
*
|
||||
* @return an HttpRequest object containing the HTTP POST request
|
||||
*/
|
||||
public HttpRequest createHttpRequest() {
|
||||
HttpPost httpPost = new HttpPost(getURI());
|
||||
|
||||
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
||||
Strings.toString(new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, model)).getBytes(StandardCharsets.UTF_8)
|
||||
);
|
||||
httpPost.setEntity(byteEntity);
|
||||
|
||||
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
|
||||
httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
|
||||
|
||||
return new HttpRequest(httpPost, getInferenceEntityId());
|
||||
}
|
||||
|
||||
public URI getURI() {
|
||||
return account.uri();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getInferenceEntityId() {
|
||||
return model.getInferenceEntityId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Request truncate() {
|
||||
// Truncation is not applicable for chat completion requests
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean[] getTruncationInfo() {
|
||||
// Truncation is not applicable for chat completion requests
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isStreaming() {
|
||||
return unifiedChatInput.stream();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.request.completion;
|
||||
|
||||
import org.elasticsearch.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
public class HuggingFaceUnifiedChatCompletionRequestEntity implements ToXContentObject {
|
||||
|
||||
private static final String MODEL_FIELD = "model";
|
||||
private static final String MAX_TOKENS_FIELD = "max_tokens";
|
||||
|
||||
private final UnifiedChatInput unifiedChatInput;
|
||||
private final HuggingFaceChatCompletionModel model;
|
||||
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
|
||||
|
||||
public HuggingFaceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, HuggingFaceChatCompletionModel model) {
|
||||
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
|
||||
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
|
||||
this.model = Objects.requireNonNull(model);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
unifiedRequestEntity.toXContent(builder, params);
|
||||
|
||||
if (model.getServiceSettings().modelId() != null) {
|
||||
builder.field(MODEL_FIELD, model.getServiceSettings().modelId());
|
||||
}
|
||||
|
||||
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
|
||||
builder.field(MAX_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
|
||||
}
|
||||
|
||||
builder.endObject();
|
||||
|
||||
return builder;
|
||||
}
|
||||
}
|
|
@ -5,7 +5,7 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.request;
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.request.embeddings;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
|
@ -24,25 +24,35 @@ import java.util.Objects;
|
|||
|
||||
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
|
||||
|
||||
public class HuggingFaceInferenceRequest implements Request {
|
||||
/**
|
||||
* This class is responsible for creating Hugging Face embeddings HTTP requests.
|
||||
* It handles the truncation of input data and prepares the HTTP request with the necessary headers and body.
|
||||
*/
|
||||
public class HuggingFaceEmbeddingsRequest implements Request {
|
||||
|
||||
private final Truncator truncator;
|
||||
private final HuggingFaceAccount account;
|
||||
private final Truncator.TruncationResult truncationResult;
|
||||
private final HuggingFaceModel model;
|
||||
|
||||
public HuggingFaceInferenceRequest(Truncator truncator, Truncator.TruncationResult input, HuggingFaceModel model) {
|
||||
public HuggingFaceEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, HuggingFaceModel model) {
|
||||
this.truncator = Objects.requireNonNull(truncator);
|
||||
this.account = HuggingFaceAccount.of(model);
|
||||
this.truncationResult = Objects.requireNonNull(input);
|
||||
this.model = Objects.requireNonNull(model);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an HTTP request to the Hugging Face API for embeddings.
|
||||
* The request includes the necessary headers and the input data as a JSON entity.
|
||||
*
|
||||
* @return an HttpRequest object containing the HTTP POST request
|
||||
*/
|
||||
public HttpRequest createHttpRequest() {
|
||||
HttpPost httpPost = new HttpPost(account.uri());
|
||||
|
||||
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
||||
Strings.toString(new HuggingFaceInferenceRequestEntity(truncationResult.input())).getBytes(StandardCharsets.UTF_8)
|
||||
Strings.toString(new HuggingFaceEmbeddingsRequestEntity(truncationResult.input())).getBytes(StandardCharsets.UTF_8)
|
||||
);
|
||||
httpPost.setEntity(byteEntity);
|
||||
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
|
||||
|
@ -64,7 +74,7 @@ public class HuggingFaceInferenceRequest implements Request {
|
|||
public Request truncate() {
|
||||
var truncateResult = truncator.truncate(truncationResult.input());
|
||||
|
||||
return new HuggingFaceInferenceRequest(truncator, truncateResult, model);
|
||||
return new HuggingFaceEmbeddingsRequest(truncator, truncateResult, model);
|
||||
}
|
||||
|
||||
@Override
|
|
@ -5,7 +5,7 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.request;
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.request.embeddings;
|
||||
|
||||
import org.elasticsearch.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
@ -14,11 +14,15 @@ import java.io.IOException;
|
|||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public record HuggingFaceInferenceRequestEntity(List<String> inputs) implements ToXContentObject {
|
||||
/**
|
||||
* This class represents the request entity for Hugging Face embeddings.
|
||||
* It contains a list of input strings that will be used to generate embeddings.
|
||||
*/
|
||||
public record HuggingFaceEmbeddingsRequestEntity(List<String> inputs) implements ToXContentObject {
|
||||
|
||||
private static final String INPUTS_FIELD = "inputs";
|
||||
|
||||
public HuggingFaceInferenceRequestEntity {
|
||||
public HuggingFaceEmbeddingsRequestEntity {
|
||||
Objects.requireNonNull(inputs);
|
||||
}
|
||||
|
|
@ -21,6 +21,9 @@ public class HuggingFaceErrorResponseEntity extends ErrorResponse {
|
|||
}
|
||||
|
||||
/**
|
||||
* Represents a structured error response specifically for non-streaming operations
|
||||
* using HuggingFace APIs. This is separate from streaming error responses,
|
||||
* which are handled by private nested HuggingFaceChatCompletionResponseHandler.StreamingHuggingFaceErrorResponseEntity.
|
||||
* An example error response for invalid auth would look like
|
||||
* <code>
|
||||
* {
|
||||
|
|
|
@ -29,6 +29,7 @@ import java.util.Locale;
|
|||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.Flow;
|
||||
import java.util.function.Function;
|
||||
|
||||
import static org.elasticsearch.core.Strings.format;
|
||||
|
||||
|
@ -37,6 +38,14 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
|
|||
super(requestType, parseFunction, OpenAiErrorResponse::fromResponse);
|
||||
}
|
||||
|
||||
public OpenAiUnifiedChatCompletionResponseHandler(
|
||||
String requestType,
|
||||
ResponseParser parseFunction,
|
||||
Function<HttpResult, ErrorResponse> errorParseFunction
|
||||
) {
|
||||
super(requestType, parseFunction, errorParseFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
|
||||
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
|
||||
|
@ -59,7 +68,7 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
|
|||
: new UnifiedChatCompletionException(
|
||||
restStatus,
|
||||
errorMessage,
|
||||
errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown",
|
||||
createErrorType(errorResponse),
|
||||
restStatus.name().toLowerCase(Locale.ROOT)
|
||||
);
|
||||
} else {
|
||||
|
@ -67,7 +76,11 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
|
|||
}
|
||||
}
|
||||
|
||||
private static Exception buildMidStreamError(Request request, String message, Exception e) {
|
||||
protected static String createErrorType(ErrorResponse errorResponse) {
|
||||
return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown";
|
||||
}
|
||||
|
||||
protected Exception buildMidStreamError(Request request, String message, Exception e) {
|
||||
var errorResponse = OpenAiErrorResponse.fromString(message);
|
||||
if (errorResponse instanceof OpenAiErrorResponse oer) {
|
||||
return new UnifiedChatCompletionException(
|
||||
|
@ -88,7 +101,7 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
|
|||
return new UnifiedChatCompletionException(
|
||||
RestStatus.INTERNAL_SERVER_ERROR,
|
||||
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
|
||||
errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown",
|
||||
createErrorType(errorResponse),
|
||||
"stream_error"
|
||||
);
|
||||
}
|
||||
|
|
|
@ -250,7 +250,7 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
|
|||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ARGUMENTS_FIELD));
|
||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(NAME_FIELD));
|
||||
PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(NAME_FIELD));
|
||||
}
|
||||
|
||||
public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function parse(
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface;
|
||||
|
||||
import org.apache.http.HttpResponse;
|
||||
import org.apache.http.StatusLine;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
|
||||
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.isA;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class HuggingFaceChatCompletionResponseHandlerTests extends ESTestCase {
|
||||
private final HuggingFaceChatCompletionResponseHandler responseHandler = new HuggingFaceChatCompletionResponseHandler(
|
||||
"chat completions",
|
||||
(a, b) -> mock()
|
||||
);
|
||||
|
||||
public void testFailValidationWithAllFields() throws IOException {
|
||||
var responseJson = """
|
||||
{
|
||||
"error": "a message",
|
||||
"type": "validation"
|
||||
}
|
||||
""";
|
||||
|
||||
var errorJson = invalidResponseJson(responseJson);
|
||||
|
||||
assertThat(errorJson, is("""
|
||||
{"error":{"code":"bad_request","message":"Received a server error status code for request from \
|
||||
inference entity id [id] status [500]. \
|
||||
Error message: [a message]",\
|
||||
"type":"hugging_face_error"}}"""));
|
||||
}
|
||||
|
||||
public void testFailValidationWithoutOptionalFields() throws IOException {
|
||||
var responseJson = """
|
||||
{
|
||||
"error": "a message"
|
||||
}
|
||||
""";
|
||||
|
||||
var errorJson = invalidResponseJson(responseJson);
|
||||
|
||||
assertThat(errorJson, is("""
|
||||
{"error":{"code":"bad_request","message":"Received a server error status code for request from \
|
||||
inference entity id [id] status [500]. \
|
||||
Error message: [a message]","type":"hugging_face_error"}}"""));
|
||||
}
|
||||
|
||||
public void testFailValidationWithInvalidJson() throws IOException {
|
||||
var responseJson = """
|
||||
what? this isn't a json
|
||||
""";
|
||||
|
||||
var errorJson = invalidResponseJson(responseJson);
|
||||
|
||||
assertThat(errorJson, is("""
|
||||
{"error":{"code":"bad_request","message":"Received a server error status code for request from inference entity id [id] status\
|
||||
[500]","type":"ErrorResponse"}}"""));
|
||||
}
|
||||
|
||||
private String invalidResponseJson(String responseJson) throws IOException {
|
||||
var exception = invalidResponse(responseJson);
|
||||
assertThat(exception, isA(RetryException.class));
|
||||
assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class));
|
||||
return toJson((UnifiedChatCompletionException) unwrapCause(exception));
|
||||
}
|
||||
|
||||
private Exception invalidResponse(String responseJson) {
|
||||
return expectThrows(
|
||||
RetryException.class,
|
||||
() -> responseHandler.validateResponse(
|
||||
mock(),
|
||||
mock(),
|
||||
mockRequest(),
|
||||
new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8)),
|
||||
true
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static Request mockRequest() {
|
||||
var request = mock(Request.class);
|
||||
when(request.getInferenceEntityId()).thenReturn("id");
|
||||
when(request.isStreaming()).thenReturn(true);
|
||||
return request;
|
||||
}
|
||||
|
||||
private static HttpResponse mock500Response() {
|
||||
int statusCode = 500;
|
||||
var statusLine = mock(StatusLine.class);
|
||||
when(statusLine.getStatusCode()).thenReturn(statusCode);
|
||||
|
||||
var response = mock(HttpResponse.class);
|
||||
when(response.getStatusLine()).thenReturn(statusLine);
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
private String toJson(UnifiedChatCompletionException e) throws IOException {
|
||||
try (var builder = XContentFactory.jsonBuilder()) {
|
||||
e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
|
||||
try {
|
||||
xContent.toXContent(builder, EMPTY_PARAMS);
|
||||
} catch (IOException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
});
|
||||
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -12,6 +12,7 @@ package org.elasticsearch.xpack.inference.services.huggingface;
|
|||
import org.apache.http.HttpHeaders;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.ActionTestUtils;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
|
@ -29,20 +30,29 @@ import org.elasticsearch.inference.Model;
|
|||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.SimilarityMeasure;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.test.http.MockResponse;
|
||||
import org.elasticsearch.test.http.MockWebServer;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xcontent.ToXContent;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
|
||||
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
|
||||
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettingsTests;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
||||
|
@ -53,14 +63,19 @@ import org.junit.After;
|
|||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.EnumSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
|
||||
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
|
||||
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
|
||||
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
|
||||
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
|
||||
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
|
||||
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
|
||||
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
||||
|
@ -74,7 +89,12 @@ import static org.hamcrest.CoreMatchers.is;
|
|||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.isA;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class HuggingFaceServiceTests extends ESTestCase {
|
||||
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
|
||||
|
@ -175,6 +195,438 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_CreatesHuggingFaceChatCompletionsModel() throws IOException {
|
||||
var url = "url";
|
||||
var model = "model";
|
||||
var secret = "secret";
|
||||
|
||||
try (var service = createHuggingFaceService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(m -> {
|
||||
assertThat(m, instanceOf(HuggingFaceChatCompletionModel.class));
|
||||
|
||||
var completionsModel = (HuggingFaceChatCompletionModel) m;
|
||||
|
||||
assertThat(completionsModel.getServiceSettings().uri().toString(), is(url));
|
||||
assertThat(completionsModel.getServiceSettings().modelId(), is(model));
|
||||
assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret));
|
||||
|
||||
}, exception -> fail("Unexpected exception: " + exception));
|
||||
|
||||
service.parseRequestConfig(
|
||||
"id",
|
||||
TaskType.COMPLETION,
|
||||
getRequestConfigMap(
|
||||
HuggingFaceChatCompletionServiceSettingsTests.getServiceSettingsMap(url, model),
|
||||
getSecretSettingsMap(secret)
|
||||
),
|
||||
modelVerificationListener
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_CreatesHuggingFaceChatCompletionsModel_WithoutModelId() throws IOException {
|
||||
var url = "url";
|
||||
var secret = "secret";
|
||||
|
||||
try (var service = createHuggingFaceService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(m -> {
|
||||
assertThat(m, instanceOf(HuggingFaceChatCompletionModel.class));
|
||||
|
||||
var completionsModel = (HuggingFaceChatCompletionModel) m;
|
||||
|
||||
assertThat(completionsModel.getServiceSettings().uri().toString(), is(url));
|
||||
assertNull(completionsModel.getServiceSettings().modelId());
|
||||
assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret));
|
||||
|
||||
}, exception -> fail("Unexpected exception: " + exception));
|
||||
|
||||
service.parseRequestConfig(
|
||||
"id",
|
||||
TaskType.COMPLETION,
|
||||
getRequestConfigMap(getServiceSettingsMap(url), getSecretSettingsMap(secret)),
|
||||
modelVerificationListener
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException {
|
||||
var sender = mock(Sender.class);
|
||||
|
||||
var factory = mock(HttpRequestSender.Factory.class);
|
||||
when(factory.createSender()).thenReturn(sender);
|
||||
|
||||
var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION);
|
||||
|
||||
try (var service = new HuggingFaceService(factory, createWithEmptySettings(threadPool))) {
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.infer(
|
||||
mockModel,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
List.of(""),
|
||||
false,
|
||||
new HashMap<>(),
|
||||
InputType.INGEST,
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||
listener
|
||||
);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
|
||||
assertThat(
|
||||
thrownException.getMessage(),
|
||||
is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
|
||||
);
|
||||
|
||||
verify(factory, times(1)).createSender();
|
||||
verify(sender, times(1)).start();
|
||||
}
|
||||
|
||||
verify(sender, times(1)).close();
|
||||
verifyNoMoreInteractions(factory);
|
||||
verifyNoMoreInteractions(sender);
|
||||
}
|
||||
|
||||
public void testUnifiedCompletionInfer() throws Exception {
|
||||
// The escapes are because the streaming response must be on a single line
|
||||
String responseJson = """
|
||||
data: {\
|
||||
"id":"12345",\
|
||||
"object":"chat.completion.chunk",\
|
||||
"created":123456789,\
|
||||
"model":"gpt-4o-mini",\
|
||||
"system_fingerprint": "123456789",\
|
||||
"choices":[\
|
||||
{\
|
||||
"index":0,\
|
||||
"delta":{\
|
||||
"content":"hello, world"\
|
||||
},\
|
||||
"logprobs":null,\
|
||||
"finish_reason":"stop"\
|
||||
}\
|
||||
],\
|
||||
"usage":{\
|
||||
"prompt_tokens": 16,\
|
||||
"completion_tokens": 28,\
|
||||
"total_tokens": 44,\
|
||||
"prompt_tokens_details": {\
|
||||
"cached_tokens": 0,\
|
||||
"audio_tokens": 0\
|
||||
},\
|
||||
"completion_tokens_details": {\
|
||||
"reasoning_tokens": 0,\
|
||||
"audio_tokens": 0,\
|
||||
"accepted_prediction_tokens": 0,\
|
||||
"rejected_prediction_tokens": 0\
|
||||
}\
|
||||
}\
|
||||
}
|
||||
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||
var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.unifiedCompletionInfer(
|
||||
model,
|
||||
UnifiedCompletionRequest.of(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
|
||||
),
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||
listener
|
||||
);
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
|
||||
{"id":"12345","choices":[{"delta":{"content":"hello, world"},"finish_reason":"stop","index":0}],""" + """
|
||||
"model":"gpt-4o-mini","object":"chat.completion.chunk",""" + """
|
||||
"usage":{"completion_tokens":28,"prompt_tokens":16,"total_tokens":44}}""");
|
||||
}
|
||||
}
|
||||
|
||||
public void testUnifiedCompletionNonStreamingError() throws Exception {
|
||||
String responseJson = """
|
||||
{
|
||||
"error": "Model not found."
|
||||
}
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
|
||||
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||
var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
|
||||
var latch = new CountDownLatch(1);
|
||||
service.unifiedCompletionInfer(
|
||||
model,
|
||||
UnifiedCompletionRequest.of(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
|
||||
),
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||
ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> {
|
||||
try (var builder = XContentFactory.jsonBuilder()) {
|
||||
var t = unwrapCause(e);
|
||||
assertThat(t, isA(UnifiedChatCompletionException.class));
|
||||
((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
|
||||
try {
|
||||
xContent.toXContent(builder, EMPTY_PARAMS);
|
||||
} catch (IOException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
});
|
||||
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
|
||||
|
||||
assertThat(json, is("""
|
||||
{\
|
||||
"error":{\
|
||||
"code":"not_found",\
|
||||
"message":"Received an unsuccessful status code for request from inference entity id [id] status \
|
||||
[404]. Error message: [Model not found.]",\
|
||||
"type":"hugging_face_error"\
|
||||
}}"""));
|
||||
} catch (IOException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
}), latch::countDown)
|
||||
);
|
||||
assertTrue(latch.await(30, TimeUnit.SECONDS));
|
||||
}
|
||||
}
|
||||
|
||||
public void testMidStreamUnifiedCompletionError() throws Exception {
|
||||
String responseJson = """
|
||||
event: error
|
||||
data: {"error":{"message":"Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \
|
||||
{\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported by Outlines.\\n\
|
||||
If it should be supported, please open an issue.","http_status_code":422}}
|
||||
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
testStreamError("""
|
||||
{\
|
||||
"error":{\
|
||||
"code":"422",\
|
||||
"message":"Received an error response for request from inference entity id [id]. Error message: [Input validation error: \
|
||||
cannot compile regex from schema: Unsupported JSON Schema structure {\\"id\\":\\"123\\"} \\nMake sure it is valid to the \
|
||||
JSON Schema specification and check if it's supported by Outlines.\\nIf it should be supported, please open an issue.]",\
|
||||
"type":"hugging_face_error"\
|
||||
}}""");
|
||||
}
|
||||
|
||||
public void testMidStreamUnifiedCompletionErrorNoMessage() throws Exception {
|
||||
String responseJson = """
|
||||
event: error
|
||||
data: {"error":{"http_status_code":422}}
|
||||
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
testStreamError("""
|
||||
{\
|
||||
"error":{\
|
||||
"code":"422",\
|
||||
"message":"Received an error response for request from inference entity id [id]. Error message: \
|
||||
[unknown]",\
|
||||
"type":"hugging_face_error"\
|
||||
}}""");
|
||||
}
|
||||
|
||||
public void testMidStreamUnifiedCompletionErrorNoHttpStatusCode() throws Exception {
|
||||
String responseJson = """
|
||||
event: error
|
||||
data: {"error":{"message":"Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \
|
||||
{\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported by \
|
||||
Outlines.\\nIf it should be supported, please open an issue."}}
|
||||
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
testStreamError("""
|
||||
{\
|
||||
"error":{\
|
||||
"message":"Received an error response for request from inference entity id [id]. Error message: \
|
||||
[Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \
|
||||
{\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported\
|
||||
by Outlines.\\nIf it should be supported, please open an issue.]",\
|
||||
"type":"hugging_face_error"\
|
||||
}}""");
|
||||
}
|
||||
|
||||
public void testMidStreamUnifiedCompletionErrorNoHttpStatusCodeNoMessage() throws Exception {
|
||||
String responseJson = """
|
||||
event: error
|
||||
data: {"error":{}}
|
||||
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
testStreamError("""
|
||||
{\
|
||||
"error":{\
|
||||
"message":"Received an error response for request from inference entity id [id]. Error message: \
|
||||
[unknown]",\
|
||||
"type":"hugging_face_error"\
|
||||
}}""");
|
||||
}
|
||||
|
||||
public void testUnifiedCompletionMalformedError() throws Exception {
|
||||
String responseJson = """
|
||||
data: { invalid json }
|
||||
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
testStreamError("""
|
||||
{\
|
||||
"error":{\
|
||||
"code":"bad_request",\
|
||||
"message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\
|
||||
at [Source: (String)\\"{ invalid json }\\"; line: 1, column: 3]",\
|
||||
"type":"x_content_parse_exception"\
|
||||
}}""");
|
||||
}
|
||||
|
||||
private void testStreamError(String expectedResponse) throws Exception {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||
var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.unifiedCompletionInfer(
|
||||
model,
|
||||
UnifiedCompletionRequest.of(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
|
||||
),
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||
listener
|
||||
);
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
|
||||
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> {
|
||||
e = unwrapCause(e);
|
||||
assertThat(e, isA(UnifiedChatCompletionException.class));
|
||||
try (var builder = XContentFactory.jsonBuilder()) {
|
||||
((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
|
||||
try {
|
||||
xContent.toXContent(builder, EMPTY_PARAMS);
|
||||
} catch (IOException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
});
|
||||
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
|
||||
|
||||
assertThat(json, is(expectedResponse));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
public void testInfer_StreamRequest() throws Exception {
|
||||
String responseJson = """
|
||||
data: {\
|
||||
"id":"12345",\
|
||||
"object":"chat.completion.chunk",\
|
||||
"created":123456789,\
|
||||
"model":"gpt-4o-mini",\
|
||||
"system_fingerprint": "123456789",\
|
||||
"choices":[\
|
||||
{\
|
||||
"index":0,\
|
||||
"delta":{\
|
||||
"content":"hello, world"\
|
||||
},\
|
||||
"logprobs":null,\
|
||||
"finish_reason":null\
|
||||
}\
|
||||
]\
|
||||
}
|
||||
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
streamCompletion().hasNoErrors().hasEvent("""
|
||||
{"completion":[{"delta":"hello, world"}]}""");
|
||||
}
|
||||
|
||||
private InferenceEventsAssertion streamCompletion() throws Exception {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||
var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model");
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.infer(
|
||||
model,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
List.of("abc"),
|
||||
true,
|
||||
new HashMap<>(),
|
||||
InputType.INGEST,
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||
listener
|
||||
);
|
||||
|
||||
return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
|
||||
}
|
||||
}
|
||||
|
||||
public void testInfer_StreamRequest_ErrorResponse() throws Exception {
|
||||
String responseJson = """
|
||||
{
|
||||
"error": {
|
||||
"message": "You didn't provide an API key..."
|
||||
}
|
||||
}""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));
|
||||
|
||||
var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion);
|
||||
assertThat(e.status(), equalTo(RestStatus.UNAUTHORIZED));
|
||||
assertThat(
|
||||
e.getMessage(),
|
||||
equalTo(
|
||||
"Received an authentication error status code for request from inference entity id [id] status [401]. "
|
||||
+ "Error message: [You didn't provide an API key...]"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testInfer_StreamRequestRetry() throws Exception {
|
||||
webServer.enqueue(new MockResponse().setResponseCode(503).setBody("""
|
||||
{
|
||||
"error": {
|
||||
"message": "server busy"
|
||||
}
|
||||
}"""));
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
|
||||
data: {\
|
||||
"id":"12345",\
|
||||
"object":"chat.completion.chunk",\
|
||||
"created":123456789,\
|
||||
"model":"gpt-4o-mini",\
|
||||
"system_fingerprint": "123456789",\
|
||||
"choices":[\
|
||||
{\
|
||||
"index":0,\
|
||||
"delta":{\
|
||||
"content":"hello, world"\
|
||||
},\
|
||||
"logprobs":null,\
|
||||
"finish_reason":null\
|
||||
}\
|
||||
]\
|
||||
}
|
||||
|
||||
"""));
|
||||
|
||||
streamCompletion().hasNoErrors().hasEvent("""
|
||||
{"completion":[{"delta":"hello, world"}]}""");
|
||||
}
|
||||
|
||||
public void testSupportsStreaming() throws IOException {
|
||||
try (var service = new HuggingFaceService(mock(), createWithEmptySettings(mock()))) {
|
||||
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)));
|
||||
assertFalse(service.canStream(TaskType.ANY));
|
||||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
|
||||
try (var service = createHuggingFaceService()) {
|
||||
var config = getRequestConfigMap(getServiceSettingsMap("url"), getSecretSettingsMap("secret"));
|
||||
|
@ -258,6 +710,25 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws IOException {
|
||||
try (var service = createHuggingFaceService()) {
|
||||
var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), new HashMap<>(), getSecretSettingsMap("secret"));
|
||||
|
||||
var model = service.parsePersistedConfigWithSecrets(
|
||||
"id",
|
||||
TaskType.COMPLETION,
|
||||
persistedConfig.config(),
|
||||
persistedConfig.secrets()
|
||||
);
|
||||
|
||||
assertThat(model, instanceOf(HuggingFaceChatCompletionModel.class));
|
||||
|
||||
var chatCompletionModel = (HuggingFaceChatCompletionModel) model;
|
||||
assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is("url"));
|
||||
assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is("secret"));
|
||||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
|
||||
try (var service = createHuggingFaceService()) {
|
||||
var persistedConfig = getPersistedConfigMap(
|
||||
|
@ -821,7 +1292,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
|||
{
|
||||
"service": "hugging_face",
|
||||
"name": "Hugging Face",
|
||||
"task_types": ["text_embedding", "sparse_embedding"],
|
||||
"task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"],
|
||||
"configurations": {
|
||||
"api_key": {
|
||||
"description": "API Key for the provider you're connecting to.",
|
||||
|
@ -830,7 +1301,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
|||
"sensitive": true,
|
||||
"updatable": true,
|
||||
"type": "str",
|
||||
"supported_task_types": ["text_embedding", "sparse_embedding"]
|
||||
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
|
||||
},
|
||||
"rate_limit.requests_per_minute": {
|
||||
"description": "Minimize the number of rate limit errors.",
|
||||
|
@ -839,17 +1310,16 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
|||
"sensitive": false,
|
||||
"updatable": false,
|
||||
"type": "int",
|
||||
"supported_task_types": ["text_embedding", "sparse_embedding"]
|
||||
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
|
||||
},
|
||||
"url": {
|
||||
"default_value": "https://api.openai.com/v1/embeddings",
|
||||
"description": "The URL endpoint to use for the requests.",
|
||||
"label": "URL",
|
||||
"required": true,
|
||||
"sensitive": false,
|
||||
"updatable": false,
|
||||
"type": "str",
|
||||
"supported_task_types": ["text_embedding", "sparse_embedding"]
|
||||
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,10 +24,13 @@ import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsT
|
|||
import org.elasticsearch.xpack.inference.InputTypeTests;
|
||||
import org.elasticsearch.xpack.inference.common.TruncatorTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;
|
||||
import org.junit.After;
|
||||
|
@ -38,6 +41,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
|
||||
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||
|
@ -425,4 +429,107 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
|
|||
|
||||
}
|
||||
}
|
||||
|
||||
public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "chat.completion",
|
||||
"id": "",
|
||||
"created": 1745855316,
|
||||
"model": "/repository",
|
||||
"system_fingerprint": "3.2.3-sha-a1f3ebe",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello there, how may I assist you today?"
|
||||
},
|
||||
"logprobs": null,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 8,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 58
|
||||
}
|
||||
}
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = createChatCompletionFuture(sender, createWithEmptySettings(threadPool));
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
|
||||
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?"))));
|
||||
|
||||
assertChatCompletionRequest();
|
||||
}
|
||||
}
|
||||
|
||||
public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() throws IOException {
|
||||
var settings = buildSettingsWithRetryFields(
|
||||
TimeValue.timeValueMillis(1),
|
||||
TimeValue.timeValueMinutes(1),
|
||||
TimeValue.timeValueSeconds(0)
|
||||
);
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
{
|
||||
"invalid_field": "unexpected"
|
||||
}
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = createChatCompletionFuture(
|
||||
sender,
|
||||
new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator())
|
||||
);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
assertThat(
|
||||
thrownException.getMessage(),
|
||||
is("Failed to send Hugging Face completion request from inference entity id " + "[id]. Cause: Required [choices]")
|
||||
);
|
||||
|
||||
assertChatCompletionRequest();
|
||||
}
|
||||
}
|
||||
|
||||
private PlainActionFuture<InferenceServiceResults> createChatCompletionFuture(Sender sender, ServiceComponents threadPool) {
|
||||
var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model");
|
||||
var actionCreator = new HuggingFaceActionCreator(sender, threadPool);
|
||||
var action = actionCreator.create(model);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
return listener;
|
||||
}
|
||||
|
||||
private void assertChatCompletionRequest() throws IOException {
|
||||
assertThat(webServer.requests(), hasSize(1));
|
||||
assertNull(webServer.requests().get(0).getUri().getQuery());
|
||||
assertThat(
|
||||
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
|
||||
equalTo(XContentType.JSON.mediaTypeWithoutParameters())
|
||||
);
|
||||
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
|
||||
|
||||
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
||||
assertThat(requestMap.size(), is(4));
|
||||
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello"))));
|
||||
assertThat(requestMap.get("model"), is("model"));
|
||||
assertThat(requestMap.get("n"), is(1));
|
||||
assertThat(requestMap.get("stream"), is(false));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,243 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.action;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.test.http.MockRequest;
|
||||
import org.elasticsearch.test.http.MockResponse;
|
||||
import org.elasticsearch.test.http.MockWebServer;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
|
||||
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
||||
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
|
||||
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
||||
import static org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator.COMPLETION_HANDLER;
|
||||
import static org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator.USER_ROLE;
|
||||
import static org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests.createCompletionModel;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public class HuggingFaceChatCompletionActionTests extends ESTestCase {
|
||||
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
|
||||
private final MockWebServer webServer = new MockWebServer();
|
||||
private ThreadPool threadPool;
|
||||
private HttpClientManager clientManager;
|
||||
|
||||
@Before
|
||||
public void init() throws Exception {
|
||||
webServer.start();
|
||||
threadPool = createThreadPool(inferenceUtilityPool());
|
||||
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
|
||||
}
|
||||
|
||||
@After
|
||||
public void shutdown() throws IOException {
|
||||
clientManager.close();
|
||||
terminate(threadPool);
|
||||
webServer.close();
|
||||
}
|
||||
|
||||
public void testExecute_ReturnsSuccessfulResponse() throws IOException {
|
||||
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
{
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1677652288,
|
||||
"model": "gpt-3.5-turbo-0125",
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "result content"
|
||||
},
|
||||
"logprobs": null,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 9,
|
||||
"completion_tokens": 12,
|
||||
"total_tokens": 21
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var action = createAction(getUrl(webServer), sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
|
||||
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result content"))));
|
||||
assertThat(webServer.requests(), hasSize(1));
|
||||
|
||||
MockRequest request = webServer.requests().get(0);
|
||||
|
||||
assertNull(request.getUri().getQuery());
|
||||
assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()));
|
||||
assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
|
||||
|
||||
var requestMap = entityAsMap(request.getBody());
|
||||
assertThat(requestMap.size(), is(4));
|
||||
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc"))));
|
||||
assertThat(requestMap.get("model"), is("model"));
|
||||
assertThat(requestMap.get("n"), is(1));
|
||||
assertThat(requestMap.get("stream"), is(false));
|
||||
}
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOException {
|
||||
try (var sender = mock(Sender.class)) {
|
||||
var thrownException = expectThrows(IllegalArgumentException.class, () -> createAction("^^", sender));
|
||||
assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]"));
|
||||
}
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsElasticsearchException() {
|
||||
var sender = mock(Sender.class);
|
||||
doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
|
||||
|
||||
var action = createAction(getUrl(webServer), sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
assertThat(thrownException.getMessage(), is("failed"));
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
|
||||
var sender = mock(Sender.class);
|
||||
|
||||
doAnswer(invocation -> {
|
||||
ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
|
||||
listener.onFailure(new IllegalStateException("failed"));
|
||||
|
||||
return Void.TYPE;
|
||||
}).when(sender).send(any(), any(), any(), any());
|
||||
|
||||
var action = createAction(getUrl(webServer), sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
assertThat(thrownException.getMessage(), is("Failed to send hugging face chat completions request. Cause: failed"));
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
{
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1677652288,
|
||||
"model": "gpt-3.5-turbo-0613",
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello there, how may I assist you today?"
|
||||
},
|
||||
"logprobs": null,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 9,
|
||||
"completion_tokens": 12,
|
||||
"total_tokens": 21
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var action = createAction(getUrl(webServer), sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
assertThat(thrownException.getMessage(), is("hugging face chat completions only accepts 1 input"));
|
||||
assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
|
||||
}
|
||||
}
|
||||
|
||||
private ExecutableAction createAction(String url, Sender sender) {
|
||||
var model = createCompletionModel(url, "secret", "model");
|
||||
var manager = new GenericRequestManager<>(
|
||||
threadPool,
|
||||
model,
|
||||
COMPLETION_HANDLER,
|
||||
inputs -> new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
|
||||
ChatCompletionInput.class
|
||||
);
|
||||
var errorMessage = constructFailedToSendRequestMessage("hugging face chat completions");
|
||||
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "hugging face chat completions");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.completion;
|
||||
|
||||
import org.elasticsearch.common.settings.SecureString;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class HuggingFaceChatCompletionModelTests extends ESTestCase {
|
||||
|
||||
public void testThrowsURISyntaxException_ForInvalidUrl() {
|
||||
var thrownException = expectThrows(IllegalArgumentException.class, () -> createCompletionModel("^^", "secret", "id"));
|
||||
assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]"));
|
||||
}
|
||||
|
||||
public static HuggingFaceChatCompletionModel createCompletionModel(String url, String apiKey, String modelId) {
|
||||
return new HuggingFaceChatCompletionModel(
|
||||
"id",
|
||||
TaskType.COMPLETION,
|
||||
"service",
|
||||
new HuggingFaceChatCompletionServiceSettings(modelId, url, null),
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
}
|
||||
|
||||
public static HuggingFaceChatCompletionModel createChatCompletionModel(String url, String apiKey, String modelId) {
|
||||
return new HuggingFaceChatCompletionModel(
|
||||
"id",
|
||||
TaskType.CHAT_COMPLETION,
|
||||
"service",
|
||||
new HuggingFaceChatCompletionServiceSettings(modelId, url, null),
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
}
|
||||
|
||||
public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() {
|
||||
var model = createCompletionModel("url", "api_key", "model_name");
|
||||
var request = new UnifiedCompletionRequest(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||
"different_model",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
);
|
||||
|
||||
var overriddenModel = HuggingFaceChatCompletionModel.of(model, request);
|
||||
|
||||
assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
|
||||
}
|
||||
|
||||
public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() {
|
||||
var model = createCompletionModel("url", "api_key", null);
|
||||
var request = new UnifiedCompletionRequest(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||
"different_model",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
);
|
||||
|
||||
var overriddenModel = HuggingFaceChatCompletionModel.of(model, request);
|
||||
|
||||
assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
|
||||
}
|
||||
|
||||
public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() {
|
||||
var model = createCompletionModel("url", "api_key", null);
|
||||
var request = new UnifiedCompletionRequest(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
);
|
||||
|
||||
var overriddenModel = HuggingFaceChatCompletionModel.of(model, request);
|
||||
|
||||
assertNull(overriddenModel.getServiceSettings().modelId());
|
||||
}
|
||||
|
||||
public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() {
|
||||
var model = createCompletionModel("url", "api_key", "model_name");
|
||||
var request = new UnifiedCompletionRequest(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||
null, // not overriding model
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
);
|
||||
|
||||
var overriddenModel = HuggingFaceChatCompletionModel.of(model, request);
|
||||
|
||||
assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name"));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,271 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.completion;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceFields;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class HuggingFaceChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase<
|
||||
HuggingFaceChatCompletionServiceSettings> {
|
||||
|
||||
public static final String MODEL_ID = "some model";
|
||||
public static final String CORRECT_URL = "https://www.elastic.co";
|
||||
public static final int RATE_LIMIT = 2;
|
||||
|
||||
public void testFromMap_AllFields_Success() {
|
||||
var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
ServiceFields.MODEL_ID,
|
||||
MODEL_ID,
|
||||
ServiceFields.URL,
|
||||
CORRECT_URL,
|
||||
RateLimitSettings.FIELD_NAME,
|
||||
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
);
|
||||
|
||||
assertThat(
|
||||
serviceSettings,
|
||||
is(
|
||||
new HuggingFaceChatCompletionServiceSettings(
|
||||
MODEL_ID,
|
||||
ServiceUtils.createUri(CORRECT_URL),
|
||||
new RateLimitSettings(RATE_LIMIT)
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_MissingModelId_Success() {
|
||||
var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
ServiceFields.URL,
|
||||
CORRECT_URL,
|
||||
RateLimitSettings.FIELD_NAME,
|
||||
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
);
|
||||
|
||||
assertThat(
|
||||
serviceSettings,
|
||||
is(new HuggingFaceChatCompletionServiceSettings(null, ServiceUtils.createUri(CORRECT_URL), new RateLimitSettings(RATE_LIMIT)))
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_MissingRateLimit_Success() {
|
||||
var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||
new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
);
|
||||
|
||||
assertThat(serviceSettings, is(new HuggingFaceChatCompletionServiceSettings(MODEL_ID, ServiceUtils.createUri(CORRECT_URL), null)));
|
||||
}
|
||||
|
||||
public void testFromMap_MissingUrl_ThrowsException() {
|
||||
var thrownException = expectThrows(
|
||||
ValidationException.class,
|
||||
() -> HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
ServiceFields.MODEL_ID,
|
||||
MODEL_ID,
|
||||
RateLimitSettings.FIELD_NAME,
|
||||
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
)
|
||||
);
|
||||
|
||||
assertThat(
|
||||
thrownException.getMessage(),
|
||||
containsString(
|
||||
Strings.format("Validation Failed: 1: [service_settings] does not contain the required setting [url];", ServiceFields.URL)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_EmptyUrl_ThrowsException() {
|
||||
var thrownException = expectThrows(
|
||||
ValidationException.class,
|
||||
() -> HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
ServiceFields.MODEL_ID,
|
||||
MODEL_ID,
|
||||
ServiceFields.URL,
|
||||
"",
|
||||
RateLimitSettings.FIELD_NAME,
|
||||
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
)
|
||||
);
|
||||
|
||||
assertThat(
|
||||
thrownException.getMessage(),
|
||||
containsString(
|
||||
Strings.format(
|
||||
"Validation Failed: 1: [service_settings] Invalid value empty string. [%s] must be a non-empty string;",
|
||||
ServiceFields.URL
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_InvalidUrl_ThrowsException() {
|
||||
String invalidUrl = "https://www.elastic^^co";
|
||||
var thrownException = expectThrows(
|
||||
ValidationException.class,
|
||||
() -> HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
ServiceFields.MODEL_ID,
|
||||
MODEL_ID,
|
||||
ServiceFields.URL,
|
||||
invalidUrl,
|
||||
RateLimitSettings.FIELD_NAME,
|
||||
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
)
|
||||
);
|
||||
|
||||
assertThat(
|
||||
thrownException.getMessage(),
|
||||
containsString(
|
||||
Strings.format(
|
||||
"Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]",
|
||||
invalidUrl,
|
||||
ServiceFields.URL
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testToXContent_WritesAllValues() throws IOException {
|
||||
var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
ServiceFields.MODEL_ID,
|
||||
MODEL_ID,
|
||||
ServiceFields.URL,
|
||||
CORRECT_URL,
|
||||
RateLimitSettings.FIELD_NAME,
|
||||
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
serviceSettings.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
var expected = XContentHelper.stripWhitespace("""
|
||||
{
|
||||
"model_id": "some model",
|
||||
"url": "https://www.elastic.co",
|
||||
"rate_limit": {
|
||||
"requests_per_minute": 2
|
||||
}
|
||||
}
|
||||
""");
|
||||
|
||||
assertThat(xContentResult, is(expected));
|
||||
}
|
||||
|
||||
public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException {
|
||||
var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap(
|
||||
new HashMap<>(Map.of(ServiceFields.URL, CORRECT_URL)),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
serviceSettings.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
var expected = XContentHelper.stripWhitespace("""
|
||||
{
|
||||
"url": "https://www.elastic.co",
|
||||
"rate_limit": {
|
||||
"requests_per_minute": 3000
|
||||
}
|
||||
}
|
||||
""");
|
||||
assertThat(xContentResult, is(expected));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<HuggingFaceChatCompletionServiceSettings> instanceReader() {
|
||||
return HuggingFaceChatCompletionServiceSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected HuggingFaceChatCompletionServiceSettings createTestInstance() {
|
||||
return createRandomWithNonNullUrl();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected HuggingFaceChatCompletionServiceSettings mutateInstance(HuggingFaceChatCompletionServiceSettings instance)
|
||||
throws IOException {
|
||||
return randomValueOtherThan(instance, HuggingFaceChatCompletionServiceSettingsTests::createRandomWithNonNullUrl);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected HuggingFaceChatCompletionServiceSettings mutateInstanceForVersion(
|
||||
HuggingFaceChatCompletionServiceSettings instance,
|
||||
TransportVersion version
|
||||
) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
private static HuggingFaceChatCompletionServiceSettings createRandomWithNonNullUrl() {
|
||||
return createRandom(randomAlphaOfLength(15));
|
||||
}
|
||||
|
||||
private static HuggingFaceChatCompletionServiceSettings createRandom(String url) {
|
||||
var modelId = randomAlphaOfLength(8);
|
||||
|
||||
return new HuggingFaceChatCompletionServiceSettings(modelId, ServiceUtils.createUri(url), RateLimitSettingsTests.createRandom());
|
||||
}
|
||||
|
||||
public static Map<String, Object> getServiceSettingsMap(String url, String model) {
|
||||
var map = new HashMap<String, Object>();
|
||||
|
||||
map.put(ServiceFields.URL, url);
|
||||
map.put(ServiceFields.MODEL_ID, model);
|
||||
|
||||
return map;
|
||||
}
|
||||
}
|
|
@ -12,16 +12,17 @@ import org.elasticsearch.test.ESTestCase;
|
|||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequestEntity;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
|
||||
public class HuggingFaceInferenceRequestEntityTests extends ESTestCase {
|
||||
public class HuggingFaceEmbeddingsRequestEntityTests extends ESTestCase {
|
||||
|
||||
public void testXContent() throws IOException {
|
||||
var entity = new HuggingFaceInferenceRequestEntity(List.of("abc"));
|
||||
var entity = new HuggingFaceEmbeddingsRequestEntity(List.of("abc"));
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.xcontent.XContentType;
|
|||
import org.elasticsearch.xpack.inference.common.Truncator;
|
||||
import org.elasticsearch.xpack.inference.common.TruncatorTests;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
|
@ -25,7 +26,7 @@ import static org.hamcrest.Matchers.contains;
|
|||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class HuggingFaceInferenceRequestTests extends ESTestCase {
|
||||
public class HuggingFaceEmbeddingsRequestTests extends ESTestCase {
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testCreateRequest() throws URISyntaxException, IOException {
|
||||
var huggingFaceRequest = createRequest("www.google.com", "secret", "abc");
|
||||
|
@ -67,9 +68,9 @@ public class HuggingFaceInferenceRequestTests extends ESTestCase {
|
|||
assertTrue(truncatedRequest.getTruncationInfo()[0]);
|
||||
}
|
||||
|
||||
public static HuggingFaceInferenceRequest createRequest(String url, String apiKey, String input) throws URISyntaxException {
|
||||
public static HuggingFaceEmbeddingsRequest createRequest(String url, String apiKey, String input) throws URISyntaxException {
|
||||
|
||||
return new HuggingFaceInferenceRequest(
|
||||
return new HuggingFaceEmbeddingsRequest(
|
||||
TruncatorTests.createTruncator(),
|
||||
new Truncator.TruncationResult(List.of(input), new boolean[] { false }),
|
||||
HuggingFaceEmbeddingsModelTests.createModel(url, apiKey)
|
|
@ -0,0 +1,69 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.request;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.ToXContent;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequestEntity;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals;
|
||||
import static org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests.createCompletionModel;
|
||||
|
||||
public class HuggingFaceUnifiedChatCompletionRequestEntityTests extends ESTestCase {
|
||||
|
||||
private static final String ROLE = "user";
|
||||
|
||||
public void testModelUserFieldsSerialization() throws IOException {
|
||||
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
|
||||
new UnifiedCompletionRequest.ContentString("Hello, world!"),
|
||||
ROLE,
|
||||
null,
|
||||
null
|
||||
);
|
||||
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
|
||||
messageList.add(message);
|
||||
|
||||
var unifiedRequest = UnifiedCompletionRequest.of(messageList);
|
||||
|
||||
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
|
||||
HuggingFaceChatCompletionModel model = createCompletionModel("test-url", "api-key", "test-endpoint");
|
||||
|
||||
HuggingFaceUnifiedChatCompletionRequestEntity entity = new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, model);
|
||||
|
||||
XContentBuilder builder = JsonXContent.contentBuilder();
|
||||
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
|
||||
String jsonString = Strings.toString(builder);
|
||||
String expectedJson = """
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "Hello, world!",
|
||||
"role": "user"
|
||||
}
|
||||
],
|
||||
"model": "test-endpoint",
|
||||
"n": 1,
|
||||
"stream": true,
|
||||
"stream_options": {
|
||||
"include_usage": true
|
||||
}
|
||||
}
|
||||
""";
|
||||
assertJsonEquals(jsonString, expectedJson);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.request;
|
||||
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||
import static org.hamcrest.Matchers.aMapWithSize;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class HuggingFaceUnifiedChatCompletionRequestTests extends ESTestCase {
|
||||
|
||||
public void testCreateRequest_WithStreaming() throws IOException {
|
||||
var request = createRequest("url", "secret", randomAlphaOfLength(15), "model", true);
|
||||
var httpRequest = request.createHttpRequest();
|
||||
|
||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
assertThat(requestMap.get("stream"), is(true));
|
||||
}
|
||||
|
||||
public void testTruncate_DoesNotReduceInputTextSize() throws IOException {
|
||||
String input = randomAlphaOfLength(5);
|
||||
var request = createRequest("url", "secret", input, "model", true);
|
||||
var truncatedRequest = request.truncate();
|
||||
assertThat(request.getURI().toString(), is("url"));
|
||||
|
||||
var httpRequest = truncatedRequest.createHttpRequest();
|
||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
assertThat(requestMap, aMapWithSize(5));
|
||||
|
||||
// We do not truncate for Hugging Face chat completions
|
||||
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
|
||||
assertThat(requestMap.get("model"), is("model"));
|
||||
assertThat(requestMap.get("n"), is(1));
|
||||
assertTrue((Boolean) requestMap.get("stream"));
|
||||
assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true)));
|
||||
}
|
||||
|
||||
public void testTruncationInfo_ReturnsNull() {
|
||||
var request = createRequest("url", "secret", randomAlphaOfLength(5), "model", true);
|
||||
assertNull(request.getTruncationInfo());
|
||||
}
|
||||
|
||||
public static HuggingFaceUnifiedChatCompletionRequest createRequest(String url, String apiKey, String input, @Nullable String model) {
|
||||
return createRequest(url, apiKey, input, model, false);
|
||||
}
|
||||
|
||||
public static HuggingFaceUnifiedChatCompletionRequest createRequest(
|
||||
@Nullable String url,
|
||||
String apiKey,
|
||||
String input,
|
||||
@Nullable String model,
|
||||
boolean stream
|
||||
) {
|
||||
var chatCompletionModel = HuggingFaceChatCompletionModelTests.createCompletionModel(url, apiKey, model);
|
||||
return new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel);
|
||||
}
|
||||
|
||||
}
|
|
@ -19,6 +19,8 @@ import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatComple
|
|||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase {
|
||||
|
||||
public void testJsonLiteral() {
|
||||
|
@ -182,6 +184,73 @@ public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testJsonNullFunctionName() throws IOException {
|
||||
String json = """
|
||||
{
|
||||
"object": "chat.completion.chunk",
|
||||
"id": "",
|
||||
"created": 1746800254,
|
||||
"model": "/repository",
|
||||
"system_fingerprint": "3.2.3-sha-a1f3ebe",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"id": "8f7c27be-6803-48e6-bba4-8cdcbcd2ff9a",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": null,
|
||||
"arguments": " \\\""
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"logprobs": null,
|
||||
"finish_reason": null
|
||||
}
|
||||
],
|
||||
"usage": null
|
||||
}
|
||||
""";
|
||||
|
||||
try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, json)) {
|
||||
StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser
|
||||
.parse(parser);
|
||||
|
||||
// Assertions to verify the parsed object
|
||||
assertThat(chunk.id(), is(""));
|
||||
assertThat(chunk.model(), is("/repository"));
|
||||
assertThat(chunk.object(), is("chat.completion.chunk"));
|
||||
assertNull(chunk.usage());
|
||||
|
||||
List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice> choices = chunk.choices();
|
||||
assertThat(choices.size(), is(1));
|
||||
|
||||
// First choice assertions
|
||||
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice firstChoice = choices.get(0);
|
||||
assertNull(firstChoice.delta().content());
|
||||
assertNull(firstChoice.delta().refusal());
|
||||
assertThat(firstChoice.delta().role(), is("assistant"));
|
||||
assertNull(firstChoice.finishReason());
|
||||
assertThat(firstChoice.index(), is(0));
|
||||
|
||||
List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall> toolCalls = firstChoice.delta()
|
||||
.toolCalls();
|
||||
assertThat(toolCalls.size(), is(1));
|
||||
|
||||
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0);
|
||||
assertThat(toolCall.index(), is(0));
|
||||
assertThat(toolCall.id(), is("8f7c27be-6803-48e6-bba4-8cdcbcd2ff9a"));
|
||||
assertThat(toolCall.type(), is("function"));
|
||||
assertNull(toolCall.function().name());
|
||||
assertThat(toolCall.function().arguments(), is(" \""));
|
||||
}
|
||||
}
|
||||
|
||||
public void testOpenAiUnifiedStreamingProcessorParsing() throws IOException {
|
||||
// Generate random values for the JSON fields
|
||||
int toolCallIndex = randomIntBetween(0, 10);
|
||||
|
|
Loading…
Reference in New Issue