[ML] Integrate OpenAi Chat Completion in SageMaker (#127767)

SageMaker now supports Completion and Chat Completion using the OpenAI
interfaces.

Additionally:
- Fixed a bug where the request timeout could be null; such requests now default to a 30s timeout
- Exposed existing OpenAi request/response parsing logic for reuse
This commit is contained in:
Pat Whelan 2025-05-27 14:50:10 -04:00 committed by GitHub
parent 13f3864ed9
commit 28307688f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 669 additions and 150 deletions

View File

@ -0,0 +1,5 @@
pr: 127767
summary: Integrate `OpenAi` Chat Completion in `SageMaker`
area: Machine Learning
type: enhancement
issues: []

View File

@ -180,6 +180,7 @@ public class TransportVersions {
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34);
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION_8_19 = def(8_841_0_37);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@ -264,6 +265,7 @@ public class TransportVersions {
public static final TransportVersion NODES_STATS_SUPPORTS_MULTI_PROJECT = def(9_079_0_00);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00);
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_DRY_RUN = def(9_081_0_00);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION = def(9_082_0_00);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

View File

@ -124,7 +124,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(11));
assertThat(services.size(), equalTo(12));
var providers = providers(services);
@ -142,7 +142,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
"googleaistudio",
"openai",
"streaming_completion_test_service",
"hugging_face"
"hugging_face",
"sagemaker"
).toArray()
)
);
@ -150,13 +151,15 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(5));
assertThat(services.size(), equalTo(6));
var providers = providers(services);
assertThat(
providers,
containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face").toArray())
containsInAnyOrder(
List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "sagemaker").toArray()
)
);
}

View File

@ -12,13 +12,12 @@ import org.apache.logging.log4j.Logger;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Stream;
/**
* Processor that delegates the {@link java.util.concurrent.Flow.Subscription} to the upstream {@link java.util.concurrent.Flow.Publisher}
@ -34,19 +33,13 @@ public abstract class DelegatingProcessor<T, R> implements Flow.Processor<T, R>
public static <ParsedChunk> Deque<ParsedChunk> parseEvent(
Deque<ServerSentEvent> item,
ParseChunkFunction<ParsedChunk> parseFunction,
XContentParserConfiguration parserConfig,
Logger logger
) throws Exception {
XContentParserConfiguration parserConfig
) {
var results = new ArrayDeque<ParsedChunk>(item.size());
for (ServerSentEvent event : item) {
if (event.hasData()) {
try {
var delta = parseFunction.apply(parserConfig, event);
delta.forEachRemaining(results::offer);
} catch (Exception e) {
logger.warn("Failed to parse event from inference provider: {}", event);
throw e;
}
var delta = parseFunction.apply(parserConfig, event);
delta.forEach(results::offer);
}
}
@ -55,7 +48,7 @@ public abstract class DelegatingProcessor<T, R> implements Flow.Processor<T, R>
@FunctionalInterface
public interface ParseChunkFunction<ParsedChunk> {
Iterator<ParsedChunk> apply(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException;
Stream<ParsedChunk> apply(XContentParserConfiguration parserConfig, ServerSentEvent event);
}
@Override

View File

@ -45,10 +45,12 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
private final boolean stream;
public UnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) {
Objects.requireNonNull(unifiedChatInput);
this(Objects.requireNonNull(unifiedChatInput).getRequest(), Objects.requireNonNull(unifiedChatInput).stream());
}
this.unifiedRequest = unifiedChatInput.getRequest();
this.stream = unifiedChatInput.stream();
public UnifiedChatCompletionRequestEntity(UnifiedCompletionRequest unifiedRequest, boolean stream) {
this.unifiedRequest = Objects.requireNonNull(unifiedRequest);
this.stream = stream;
}
@Override

View File

@ -9,8 +9,10 @@ package org.elasticsearch.xpack.inference.services.openai;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
@ -20,11 +22,10 @@ import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import java.io.IOException;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.Objects;
import java.util.function.Predicate;
import java.util.stream.Stream;
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
@ -113,7 +114,7 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSe
@Override
protected void next(Deque<ServerSentEvent> item) throws Exception {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig, log);
var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig);
if (results.isEmpty()) {
upstream().request(1);
@ -122,10 +123,9 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSe
}
}
private static Iterator<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event)
throws IOException {
public static Stream<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) {
if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
return Collections.emptyIterator();
return Stream.empty();
}
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
@ -167,11 +167,14 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSe
consumeUntilObjectEnd(parser); // end choices
return ""; // stopped
}).stream()
.filter(Objects::nonNull)
.filter(Predicate.not(String::isEmpty))
.map(StreamingChatCompletionResults.Result::new)
.iterator();
}).stream().filter(Objects::nonNull).filter(Predicate.not(String::isEmpty)).map(StreamingChatCompletionResults.Result::new);
} catch (IOException e) {
throw new ElasticsearchStatusException(
"Failed to parse event from inference provider: {}",
RestStatus.INTERNAL_SERVER_ERROR,
e,
event
);
}
}
}

View File

@ -50,7 +50,6 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e));
flow.subscribe(serverSentEventProcessor);
serverSentEventProcessor.subscribe(openAiProcessor);
return new StreamingUnifiedChatCompletionResults(openAiProcessor);
@ -81,6 +80,10 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
}
protected Exception buildMidStreamError(Request request, String message, Exception e) {
return buildMidStreamError(request.getInferenceEntityId(), message, e);
}
public static UnifiedChatCompletionException buildMidStreamError(String inferenceEntityId, String message, Exception e) {
var errorResponse = OpenAiErrorResponse.fromString(message);
if (errorResponse instanceof OpenAiErrorResponse oer) {
return new UnifiedChatCompletionException(
@ -88,7 +91,7 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
format(
"%s for request from inference entity id [%s]. Error message: [%s]",
SERVER_ERROR_OBJECT,
request.getInferenceEntityId(),
inferenceEntityId,
errorResponse.getErrorMessage()
),
oer.type(),
@ -100,7 +103,7 @@ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatComple
} else {
return new UnifiedChatCompletionException(
RestStatus.INTERNAL_SERVER_ERROR,
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId),
createErrorType(errorResponse),
"stream_error"
);

View File

@ -22,11 +22,10 @@ import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentE
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.function.BiFunction;
import java.util.stream.Stream;
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
@ -75,7 +74,7 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
} else if (event.hasData()) {
try {
var delta = parse(parserConfig, event);
delta.forEachRemaining(results::offer);
delta.forEach(results::offer);
} catch (Exception e) {
logger.warn("Failed to parse event from inference provider: {}", event);
throw errorParser.apply(event.data(), e);
@ -90,12 +89,12 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
}
}
private static Iterator<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(
public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(
XContentParserConfiguration parserConfig,
ServerSentEvent event
) throws IOException {
if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
return Collections.emptyIterator();
return Stream.empty();
}
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
@ -106,7 +105,7 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = ChatCompletionChunkParser.parse(jsonParser);
return Collections.singleton(chunk).iterator();
return Stream.of(chunk);
}
}

View File

@ -67,7 +67,11 @@ public class OpenAiChatCompletionResponseEntity {
*/
public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException {
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
return fromResponse(response.body());
}
/**
 * Parses a raw OpenAI chat completion response body into {@link ChatCompletionResults}.
 *
 * Overload that accepts the response bytes directly (rather than an HttpResult), so callers
 * outside the HTTP client path — e.g. SageMaker, which receives an AWS SDK response — can
 * reuse the same parsing logic.
 *
 * @param response the raw JSON response body bytes
 * @return the parsed completion results
 * @throws IOException if the bytes are not parseable as the expected JSON structure
 */
public static ChatCompletionResults fromResponse(byte[] response) throws IOException {
    try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response)) {
        return CompletionResult.PARSER.apply(p, null).toChatCompletionResults();
    }
}

View File

@ -47,6 +47,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.invalidMod
public class SageMakerService implements InferenceService {
public static final String NAME = "sagemaker";
private static final int DEFAULT_BATCH_SIZE = 256;
private static final TimeValue DEFAULT_TIMEOUT = TimeValue.THIRTY_SECONDS;
private final SageMakerModelBuilder modelBuilder;
private final SageMakerClient client;
private final SageMakerSchemas schemas;
@ -128,7 +129,7 @@ public class SageMakerService implements InferenceService {
boolean stream,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
if (model instanceof SageMakerModel == false) {
@ -148,7 +149,7 @@ public class SageMakerService implements InferenceService {
client.invokeStream(
regionAndSecrets,
request,
timeout,
timeout != null ? timeout : DEFAULT_TIMEOUT,
ActionListener.wrap(
response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)),
e -> listener.onFailure(schema.error(sageMakerModel, e))
@ -160,7 +161,7 @@ public class SageMakerService implements InferenceService {
client.invoke(
regionAndSecrets,
request,
timeout,
timeout != null ? timeout : DEFAULT_TIMEOUT,
ActionListener.wrap(
response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())),
e -> listener.onFailure(schema.error(sageMakerModel, e))
@ -201,7 +202,7 @@ public class SageMakerService implements InferenceService {
public void unifiedCompletionInfer(
Model model,
UnifiedCompletionRequest request,
TimeValue timeout,
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
if (model instanceof SageMakerModel == false) {
@ -217,7 +218,7 @@ public class SageMakerService implements InferenceService {
client.invokeStream(
regionAndSecrets,
sagemakerRequest,
timeout,
timeout != null ? timeout : DEFAULT_TIMEOUT,
ActionListener.wrap(
response -> listener.onResponse(schema.chatCompletionStreamResponse(sageMakerModel, response)),
e -> listener.onFailure(schema.chatCompletionError(sageMakerModel, e))
@ -235,7 +236,7 @@ public class SageMakerService implements InferenceService {
List<ChunkInferenceInput> input,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@Nullable TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
if (model instanceof SageMakerModel == false) {

View File

@ -12,10 +12,12 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiCompletionPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -39,7 +41,7 @@ public class SageMakerSchemas {
/*
* Add new model API to the register call.
*/
schemas = register(new OpenAiTextEmbeddingPayload());
schemas = register(new OpenAiTextEmbeddingPayload(), new OpenAiCompletionPayload());
streamSchemas = schemas.entrySet()
.stream()
@ -88,7 +90,16 @@ public class SageMakerSchemas {
)
),
schemas.values().stream().flatMap(SageMakerSchema::namedWriteables)
).toList();
)
// Dedupe based on Entry name: Payloads are allowed to declare the same Entry, but the Registry does not handle duplicates
.collect(
() -> new HashMap<String, NamedWriteableRegistry.Entry>(),
(map, entry) -> map.putIfAbsent(entry.name, entry),
Map::putAll
)
.values()
.stream()
.toList();
}
public SageMakerSchema schemaFor(SageMakerModel model) throws ElasticsearchStatusException {

View File

@ -20,6 +20,7 @@ import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
@ -66,16 +67,16 @@ public class SageMakerStreamSchema extends SageMakerSchema {
}
public InferenceServiceResults streamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) {
return streamResponse(model, response, payload::streamResponseBody, this::error);
return new StreamingChatCompletionResults(streamResponse(model, response, payload::streamResponseBody, this::error));
}
private InferenceServiceResults streamResponse(
private <T> Flow.Publisher<T> streamResponse(
SageMakerModel model,
SageMakerClient.SageMakerStream response,
CheckedBiFunction<SageMakerModel, SdkBytes, InferenceServiceResults.Result, Exception> parseFunction,
CheckedBiFunction<SageMakerModel, SdkBytes, T, Exception> parseFunction,
BiFunction<SageMakerModel, Exception, Exception> errorFunction
) {
return new StreamingChatCompletionResults(downstream -> {
return downstream -> {
response.responseStream().subscribe(new Flow.Subscriber<>() {
private volatile Flow.Subscription upstream;
@ -118,7 +119,7 @@ public class SageMakerStreamSchema extends SageMakerSchema {
downstream.onComplete();
}
});
});
};
}
public InvokeEndpointWithResponseStreamRequest chatCompletionStreamRequest(SageMakerModel model, UnifiedCompletionRequest request) {
@ -126,7 +127,9 @@ public class SageMakerStreamSchema extends SageMakerSchema {
}
public InferenceServiceResults chatCompletionStreamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) {
return streamResponse(model, response, payload::chatCompletionResponseBody, this::chatCompletionError);
return new StreamingUnifiedChatCompletionResults(
streamResponse(model, response, payload::chatCompletionResponseBody, this::chatCompletionError)
);
}
public UnifiedChatCompletionException chatCompletionError(SageMakerModel model, Exception e) {

View File

@ -9,9 +9,10 @@ package org.elasticsearch.xpack.inference.services.sagemaker.schema;
import software.amazon.awssdk.core.SdkBytes;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
@ -38,9 +39,9 @@ public interface SageMakerStreamSchemaPayload extends SageMakerSchemaPayload {
* This API would only be called for Completion task types. {@link #requestBytes(SageMakerModel, SageMakerInferenceRequest)} would
* handle the request translation for both streaming and non-streaming.
*/
InferenceServiceResults.Result streamResponseBody(SageMakerModel model, SdkBytes response) throws Exception;
StreamingChatCompletionResults.Results streamResponseBody(SageMakerModel model, SdkBytes response) throws Exception;
SdkBytes chatCompletionRequestBytes(SageMakerModel model, UnifiedCompletionRequest request) throws Exception;
InferenceServiceResults.Result chatCompletionResponseBody(SageMakerModel model, SdkBytes response) throws Exception;
StreamingUnifiedChatCompletionResults.Results chatCompletionResponseBody(SageMakerModel model, SdkBytes response) throws Exception;
}

View File

@ -0,0 +1,168 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.openai.OpenAiStreamingProcessor;
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedStreamingProcessor;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStreamSchemaPayload;
import java.util.ArrayDeque;
import java.util.Map;
import java.util.stream.Stream;
/**
 * SageMaker payload that speaks the OpenAI Completion / Chat Completion wire format.
 *
 * Translates inference requests into OpenAI-style JSON request bodies and parses
 * OpenAI-style responses (JSON for non-streaming, server-sent events for streaming)
 * back into inference results, reusing the existing OpenAI request/response parsing
 * and mid-stream error handling.
 */
public class OpenAiCompletionPayload implements SageMakerStreamSchemaPayload {

    private static final XContent jsonXContent = JsonXContent.jsonXContent;
    // Used for both Accept and Content-Type; deliberately without charset parameters.
    private static final String APPLICATION_JSON = jsonXContent.type().mediaTypeWithoutParameters();
    // Shared parser config; deprecation warnings from response JSON are logged, not fatal.
    private static final XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(
        LoggingDeprecationHandler.INSTANCE
    );
    // OpenAI request field carrying the end-user identifier from task settings.
    private static final String USER_FIELD = "user";
    // Role assigned to plain completion inputs when they are wrapped as chat messages.
    private static final String USER_ROLE = "user";
    private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
    // NOTE(review): ERROR_HANDLER is not referenced anywhere in this class; presumably it is
    // consumed by a collaborator via this payload elsewhere — TODO confirm. Its response
    // parser is a deliberate trap: SageMaker parses responses itself via the methods below,
    // so invoking the handler's parser indicates a wiring bug.
    private static final ResponseHandler ERROR_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler(
        "sagemaker openai chat completion",
        ((request, result) -> {
            assert false : "do not call this";
            throw new UnsupportedOperationException("SageMaker should not call this object's response parser.");
        })
    );

    /**
     * Serializes a unified chat completion request as an OpenAI JSON body.
     * Always sets stream=true: chat completion over SageMaker goes through the streaming API.
     */
    @Override
    public SdkBytes chatCompletionRequestBytes(SageMakerModel model, UnifiedCompletionRequest request) throws Exception {
        return completion(model, new UnifiedChatCompletionRequestEntity(request, true), request.maxCompletionTokens());
    }

    /**
     * Builds the OpenAI JSON request body: the request entity's fields plus, when present,
     * the task-settings {@code user} and the {@code max_completion_tokens} limit.
     *
     * @throws Exception if the model's task settings are not the OpenAI flavor
     *                   (via {@code createUnsupportedSchemaException})
     */
    private SdkBytes completion(SageMakerModel model, UnifiedChatCompletionRequestEntity requestEntity, @Nullable Long maxCompletionTokens)
        throws Exception {
        if (model.apiTaskSettings() instanceof SageMakerOpenAiTaskSettings apiTaskSettings) {
            return SdkBytes.fromUtf8String(Strings.toString((builder, params) -> {
                requestEntity.toXContent(builder, params);
                if (Strings.isNullOrEmpty(apiTaskSettings.user()) == false) {
                    builder.field(USER_FIELD, apiTaskSettings.user());
                }
                if (maxCompletionTokens != null) {
                    builder.field(MAX_COMPLETION_TOKENS_FIELD, maxCompletionTokens);
                }
                return builder;
            }));
        } else {
            throw createUnsupportedSchemaException(model);
        }
    }

    /**
     * Parses one batch of streamed response bytes into chat completion chunks.
     * Events typed "error", and events whose payload fails to parse, are surfaced as
     * mid-stream {@code UnifiedChatCompletionException}s built by the OpenAI handler.
     */
    @Override
    public StreamingUnifiedChatCompletionResults.Results chatCompletionResponseBody(SageMakerModel model, SdkBytes response) {
        var serverSentEvents = serverSentEvents(response);
        var results = serverSentEvents.flatMap(event -> {
            if ("error".equals(event.type())) {
                throw OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(model.getInferenceEntityId(), event.data(), null);
            } else {
                try {
                    return OpenAiUnifiedStreamingProcessor.parse(parserConfig, event);
                } catch (Exception e) {
                    throw OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(model.getInferenceEntityId(), event.data(), e);
                }
            }
        })
            // Collect into a Deque because the Results constructor consumes an ordered queue of chunks.
            .collect(
                () -> new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(),
                ArrayDeque::offer,
                ArrayDeque::addAll
            );
        return new StreamingUnifiedChatCompletionResults.Results(results);
    }

    /*
     * We should be safe to use ServerSentEventParser. It was built knowing Apache HTTP will have leftover bytes for us to manage,
     * but SageMaker uses Netty and (likely, hopefully) doesn't have that problem.
     */
    private Stream<ServerSentEvent> serverSentEvents(SdkBytes response) {
        return new ServerSentEventParser().parse(response.asByteArray()).stream().filter(ServerSentEvent::hasData);
    }

    /** Identifier used to select this payload in service settings. */
    @Override
    public String api() {
        return "openai";
    }

    /** Parses the OpenAI task settings (currently just the optional {@code user}) from a settings map. */
    @Override
    public SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
        return SageMakerOpenAiTaskSettings.fromMap(taskSettings, validationException);
    }

    /**
     * Registers the task-settings writeable. The embeddings payload declares the same Entry;
     * the schema registry deduplicates entries by name.
     */
    @Override
    public Stream<NamedWriteableRegistry.Entry> namedWriteables() {
        return Stream.of(
            new NamedWriteableRegistry.Entry(
                SageMakerStoredTaskSchema.class,
                SageMakerOpenAiTaskSettings.NAME,
                SageMakerOpenAiTaskSettings::new
            )
        );
    }

    @Override
    public String accept(SageMakerModel model) {
        return APPLICATION_JSON;
    }

    @Override
    public String contentType(SageMakerModel model) {
        return APPLICATION_JSON;
    }

    /**
     * Completion-task request body: wraps the raw inputs as chat messages with the "user"
     * role and honors the request's streaming flag. No max-token limit is applied here.
     */
    @Override
    public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception {
        return completion(
            model,
            new UnifiedChatCompletionRequestEntity(new UnifiedChatInput(request.input(), USER_ROLE, request.stream())),
            null
        );
    }

    /** Non-streaming completion response: parsed with the shared OpenAI response entity parser. */
    @Override
    public InferenceServiceResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
        return OpenAiChatCompletionResponseEntity.fromResponse(response.body().asByteArray());
    }

    /**
     * Streaming completion response: parses SSE bytes into result deltas.
     * NOTE(review): unlike {@link #chatCompletionResponseBody}, parse failures here are not
     * wrapped into a mid-stream error — presumably handled upstream; TODO confirm.
     */
    @Override
    public StreamingChatCompletionResults.Results streamResponseBody(SageMakerModel model, SdkBytes response) {
        var serverSentEvents = serverSentEvents(response);
        var results = serverSentEvents.flatMap(event -> OpenAiStreamingProcessor.parse(parserConfig, event))
            .collect(() -> new ArrayDeque<StreamingChatCompletionResults.Result>(), ArrayDeque::offer, ArrayDeque::addAll);
        return new StreamingChatCompletionResults.Results(results);
    }
}

View File

@ -40,7 +40,6 @@ import java.util.Map;
import java.util.stream.Stream;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
public class OpenAiTextEmbeddingPayload implements SageMakerSchemaPayload {
@ -64,14 +63,18 @@ public class OpenAiTextEmbeddingPayload implements SageMakerSchemaPayload {
@Override
public SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
return ApiTaskSettings.fromMap(taskSettings, validationException);
return SageMakerOpenAiTaskSettings.fromMap(taskSettings, validationException);
}
@Override
public Stream<NamedWriteableRegistry.Entry> namedWriteables() {
return Stream.of(
new NamedWriteableRegistry.Entry(SageMakerStoredServiceSchema.class, ApiServiceSettings.NAME, ApiServiceSettings::new),
new NamedWriteableRegistry.Entry(SageMakerStoredTaskSchema.class, ApiTaskSettings.NAME, ApiTaskSettings::new)
new NamedWriteableRegistry.Entry(
SageMakerStoredTaskSchema.class,
SageMakerOpenAiTaskSettings.NAME,
SageMakerOpenAiTaskSettings::new
)
);
}
@ -88,7 +91,7 @@ public class OpenAiTextEmbeddingPayload implements SageMakerSchemaPayload {
@Override
public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception {
if (model.apiServiceSettings() instanceof ApiServiceSettings apiServiceSettings
&& model.apiTaskSettings() instanceof ApiTaskSettings apiTaskSettings) {
&& model.apiTaskSettings() instanceof SageMakerOpenAiTaskSettings apiTaskSettings) {
try (var builder = JsonXContent.contentBuilder()) {
builder.startObject();
if (request.query() != null) {
@ -178,52 +181,4 @@ public class OpenAiTextEmbeddingPayload implements SageMakerSchemaPayload {
return new ApiServiceSettings(dimensions, false);
}
}
record ApiTaskSettings(@Nullable String user) implements SageMakerStoredTaskSchema {
private static final String NAME = "sagemaker_openai_text_embeddings_task_settings";
private static final String USER_FIELD = "user";
ApiTaskSettings(StreamInput in) throws IOException {
this(in.readOptionalString());
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_SAGEMAKER;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(user);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return user != null ? builder.field(USER_FIELD, user) : builder;
}
@Override
public boolean isEmpty() {
return user == null;
}
@Override
public ApiTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
var validationException = new ValidationException();
var newTaskSettings = fromMap(newSettings, validationException);
validationException.throwIfValidationErrorsExist();
return new ApiTaskSettings(newTaskSettings.user() != null ? newTaskSettings.user() : user);
}
static ApiTaskSettings fromMap(Map<String, Object> map, ValidationException exception) {
var user = extractOptionalString(map, USER_FIELD, ModelConfigurations.TASK_SETTINGS, exception);
return new ApiTaskSettings(user);
}
}
}

View File

@ -0,0 +1,71 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
import java.io.IOException;
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
/**
 * Task settings for SageMaker endpoints that speak the OpenAI API.
 *
 * Holds a single optional value: the OpenAI {@code user} identifier, which is
 * forwarded in request bodies when present. Serves both the completion and
 * text-embedding payloads.
 */
record SageMakerOpenAiTaskSettings(@Nullable String user) implements SageMakerStoredTaskSchema {
    static final String NAME = "sagemaker_openai_task_settings";
    private static final String USER_FIELD = "user";

    /** Wire deserialization: a single optional string. */
    SageMakerOpenAiTaskSettings(StreamInput in) throws IOException {
        this(in.readOptionalString());
    }

    /** Parses settings from a request map, recording any problems on {@code exception}. */
    static SageMakerOpenAiTaskSettings fromMap(Map<String, Object> map, ValidationException exception) {
        return new SageMakerOpenAiTaskSettings(extractOptionalString(map, USER_FIELD, ModelConfigurations.TASK_SETTINGS, exception));
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeOptionalString(user);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        if (user == null) {
            return builder;
        }
        return builder.field(USER_FIELD, user);
    }

    @Override
    public boolean isEmpty() {
        return user == null;
    }

    /**
     * Merges {@code newSettings} over the current settings: a user supplied in the update
     * wins, otherwise the existing value is kept. Throws if the new map fails validation.
     */
    @Override
    public SageMakerOpenAiTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
        var errors = new ValidationException();
        var parsed = fromMap(newSettings, errors);
        errors.throwIfValidationErrorsExist();
        var mergedUser = parsed.user() != null ? parsed.user() : user;
        return new SageMakerOpenAiTaskSettings(mergedUser);
    }
}

View File

@ -11,6 +11,7 @@ import software.amazon.awssdk.core.SdkBytes;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
@ -88,27 +89,33 @@ public abstract class SageMakerSchemaPayloadTestCase<T extends SageMakerSchemaPa
}
public final void testWithUnknownApiServiceSettings() {
SageMakerModel model = mock();
when(model.apiServiceSettings()).thenReturn(mock());
when(model.apiTaskSettings()).thenReturn(randomApiTaskSettings());
when(model.api()).thenReturn("serviceApi");
when(model.getTaskType()).thenReturn(TaskType.ANY);
// skip the test if we don't have SageMakerStoredServiceSchema for this payload
if (randomApiServiceSettings() != SageMakerStoredServiceSchema.NO_OP) {
SageMakerModel model = mock();
when(model.apiServiceSettings()).thenReturn(mock());
when(model.apiTaskSettings()).thenReturn(randomApiTaskSettings());
when(model.api()).thenReturn("serviceApi");
when(model.getTaskType()).thenReturn(TaskType.ANY);
var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest()));
var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest()));
assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [serviceApi] and task type [any]:"));
assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [serviceApi] and task type [any]:"));
}
}
public final void testWithUnknownApiTaskSettings() {
SageMakerModel model = mock();
when(model.apiServiceSettings()).thenReturn(randomApiServiceSettings());
when(model.apiTaskSettings()).thenReturn(mock());
when(model.api()).thenReturn("taskApi");
when(model.getTaskType()).thenReturn(TaskType.ANY);
// skip the test if we don't have SageMakerStoredTaskSchema for this payload
if (randomApiTaskSettings() != SageMakerStoredTaskSchema.NO_OP) {
SageMakerModel model = mock();
when(model.apiServiceSettings()).thenReturn(randomApiServiceSettings());
when(model.apiTaskSettings()).thenReturn(mock());
when(model.api()).thenReturn("taskApi");
when(model.getTaskType()).thenReturn(TaskType.ANY);
var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest()));
var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest()));
assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [taskApi] and task type [any]:"));
assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [taskApi] and task type [any]:"));
}
}
public final void testUpdate() throws IOException {
@ -131,12 +138,6 @@ public abstract class SageMakerSchemaPayloadTestCase<T extends SageMakerSchemaPa
});
assertTrue("Map should be empty now that we verified all updated keys and all initial keys", updatedSettings.isEmpty());
}
if (payload instanceof SageMakerStoredTaskSchema taskSchema) {
var otherTaskSettings = randomValueOtherThan(randomApiTaskSettings(), this::randomApiTaskSettings);
var otherTaskSettingsAsMap = toMap(otherTaskSettings);
taskSchema.updatedTaskSettings(otherTaskSettingsAsMap);
}
}
protected static SageMakerInferenceRequest randomRequest() {
@ -153,4 +154,8 @@ public abstract class SageMakerSchemaPayloadTestCase<T extends SageMakerSchemaPa
protected static void assertSdkBytes(SdkBytes sdkBytes, String expectedValue) {
assertThat(sdkBytes.asUtf8String(), equalTo(expectedValue));
}
protected static void assertJsonSdkBytes(SdkBytes sdkBytes, String expectedValue) throws IOException {
assertThat(sdkBytes.asUtf8String(), equalTo(XContentHelper.stripWhitespace(expectedValue)));
}
}

View File

@ -11,12 +11,12 @@ import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiCompletionPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.empty;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.anyString;
@ -41,15 +41,18 @@ public class SageMakerSchemasTests extends ESTestCase {
private static final SageMakerSchemas schemas = new SageMakerSchemas();
public void testSupportedTaskTypes() {
assertThat(schemas.supportedTaskTypes(), containsInAnyOrder(TaskType.TEXT_EMBEDDING));
assertThat(
schemas.supportedTaskTypes(),
containsInAnyOrder(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION)
);
}
public void testSupportedStreamingTasks() {
assertThat(schemas.supportedStreamingTasks(), empty());
assertThat(schemas.supportedStreamingTasks(), containsInAnyOrder(TaskType.COMPLETION, TaskType.CHAT_COMPLETION));
}
public void testSchemaFor() {
var payloads = Stream.of(new OpenAiTextEmbeddingPayload());
var payloads = Stream.of(new OpenAiTextEmbeddingPayload(), new OpenAiCompletionPayload());
payloads.forEach(payload -> {
payload.supportedTasks().forEach(taskType -> {
var model = mockModel(taskType, payload.api());
@ -59,7 +62,7 @@ public class SageMakerSchemasTests extends ESTestCase {
}
public void testStreamSchemaFor() {
var payloads = Stream.<SageMakerStreamSchemaPayload>of(/* For when we add support for streaming payloads */);
var payloads = Stream.<SageMakerStreamSchemaPayload>of(new OpenAiCompletionPayload());
payloads.forEach(payload -> {
payload.supportedTasks().forEach(taskType -> {
var model = mockModel(taskType, payload.api());
@ -77,10 +80,11 @@ public class SageMakerSchemasTests extends ESTestCase {
public void testMissingTaskTypeThrowsException() {
var knownPayload = new OpenAiTextEmbeddingPayload();
var unknownTaskType = TaskType.COMPLETION;
var unknownTaskType = TaskType.RERANK;
var knownModel = mockModel(unknownTaskType, knownPayload.api());
assertThrows(
"Task [completion] is not compatible for service [sagemaker] and api [openai]. Supported tasks: [text_embedding]",
"Task [rerank] is not compatible for service [sagemaker] and api [openai]. "
+ "Supported tasks: [text_embedding, completion, chat_completion]",
ElasticsearchStatusException.class,
() -> schemas.schemaFor(knownModel)
);
@ -105,7 +109,10 @@ public class SageMakerSchemasTests extends ESTestCase {
}
public void testNamedWriteables() {
var namedWriteables = Stream.of(new OpenAiTextEmbeddingPayload().namedWriteables());
var namedWriteables = Stream.of(
new OpenAiTextEmbeddingPayload().namedWriteables(),
new OpenAiCompletionPayload().namedWriteables()
);
var expectedNamedWriteables = Stream.concat(
namedWriteables.flatMap(names -> names.map(entry -> entry.name)),

View File

@ -0,0 +1,283 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayloadTestCase;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
import java.io.IOException;
import java.util.List;
import java.util.Set;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class OpenAiCompletionPayloadTests extends SageMakerSchemaPayloadTestCase<OpenAiCompletionPayload> {
@Override
protected OpenAiCompletionPayload payload() {
return new OpenAiCompletionPayload();
}
@Override
protected String expectedApi() {
return "openai";
}
@Override
protected Set<TaskType> expectedSupportedTaskTypes() {
return Set.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
}
@Override
protected SageMakerStoredServiceSchema randomApiServiceSettings() {
return SageMakerStoredServiceSchema.NO_OP;
}
@Override
protected SageMakerStoredTaskSchema randomApiTaskSettings() {
return SageMakerOpenAiTaskSettingsTests.randomApiTaskSettings();
}
public void testRequest() throws Exception {
var sdkByes = payload.requestBytes(mockModel("coolPerson"), request(false));
assertJsonSdkBytes(sdkByes, """
{
"messages": [
{
"content": "hello",
"role": "user"
}
],
"n": 1,
"stream": false,
"user": "coolPerson"
}""");
}
public void testRequestWithoutUser() throws Exception {
var sdkByes = payload.requestBytes(mockModel(null), request(false));
assertJsonSdkBytes(sdkByes, """
{
"messages": [
{
"content": "hello",
"role": "user"
}
],
"n": 1,
"stream": false
}""");
}
public void testStreamRequest() throws Exception {
var sdkByes = payload.requestBytes(mockModel("user"), request(true));
assertJsonSdkBytes(sdkByes, """
{
"messages":[
{
"content": "hello",
"role": "user"
}
],
"n": 1,
"stream": true,
"stream_options": {
"include_usage": true
},
"user": "user"
}""");
}
public void testStreamRequestWithoutUser() throws Exception {
var sdkByes = payload.requestBytes(mockModel(null), request(true));
assertJsonSdkBytes(sdkByes, """
{
"messages":[
{
"content": "hello",
"role": "user"
}
],
"n": 1,
"stream": true,
"stream_options": {
"include_usage": true
}
}""");
}
private SageMakerInferenceRequest request(boolean stream) {
return new SageMakerInferenceRequest(null, null, null, List.of("hello"), stream, InputType.UNSPECIFIED);
}
private SageMakerModel mockModel(String user) {
SageMakerModel model = mock();
when(model.apiTaskSettings()).thenReturn(new SageMakerOpenAiTaskSettings(user));
return model;
}
public void testResponse() throws Exception {
var responseJson = """
{
"id": "some-id",
"object": "chat.completion",
"created": 1705397787,
"model": "gpt-3.5-turbo-0613",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "result"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 46,
"completion_tokens": 39,
"total_tokens": 85
},
"system_fingerprint": null
}
""";
var chatCompletionResults = (ChatCompletionResults) payload.responseBody(
mockModel(),
InvokeEndpointResponse.builder().body(SdkBytes.fromUtf8String(responseJson)).build()
);
assertThat(chatCompletionResults.getResults().size(), is(1));
assertThat(chatCompletionResults.getResults().get(0).content(), is("result"));
}
public void testStreamResponse() throws Exception {
var responseJson = dataPayload("""
{
"id":"12345",
"object":"chat.completion.chunk",
"created":123456789,
"model":"gpt-4o-mini",
"system_fingerprint": "123456789",
"choices":[
{
"index":0,
"delta":{
"content":"test"
},
"logprobs":null,
"finish_reason":null
}
]
}
""");
var streamingResults = payload.streamResponseBody(mockModel(), responseJson);
assertThat(streamingResults.results().size(), is(1));
assertThat(streamingResults.results().iterator().next().delta(), is("test"));
}
private SdkBytes dataPayload(String json) throws IOException {
return SdkBytes.fromUtf8String("data: " + XContentHelper.stripWhitespace(json) + "\n\n");
}
private SageMakerModel mockModel() {
SageMakerModel model = mock();
when(model.apiTaskSettings()).thenReturn(randomApiTaskSettings());
return model;
}
public void testChatCompletionRequest() throws Exception {
var message = new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("Hello, world!"), "user", null, null);
var unifiedRequest = new UnifiedCompletionRequest(List.of(message), null, null, null, null, null, null, null);
var sdkBytes = payload.chatCompletionRequestBytes(mockModel("coolUser"), unifiedRequest);
assertJsonSdkBytes(sdkBytes, """
{
"messages": [
{
"content": "Hello, world!",
"role": "user"
}
],
"n": 1,
"stream": true,
"stream_options": {
"include_usage": true
},
"user": "coolUser"
}
""");
}
public void testChatCompletionResponse() throws Exception {
var responseJson = """
{
"id": "chunk1",
"choices": [
{
"delta": {
"content": "example_content",
"refusal": "example_refusal",
"role": "assistant",
"tool_calls": [
{
"index": 1,
"id": "tool1",
"function": {
"arguments": "example_arguments",
"name": "example_function"
},
"type": "function"
}
]
},
"finish_reason": "example_reason",
"index": 0
}
],
"model": "example_model",
"object": "example_object",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 5,
"total_tokens": 15
}
}
""";
var chatCompletionResponse = payload.chatCompletionResponseBody(mockModel(), dataPayload(responseJson));
XContentBuilder builder = JsonXContent.contentBuilder();
chatCompletionResponse.toXContentChunked(null).forEachRemaining(xContent -> {
try {
xContent.toXContent(builder, null);
} catch (IOException e) {
throw new RuntimeException(e);
}
});
assertEquals(XContentHelper.stripWhitespace(responseJson), Strings.toString(builder).trim());
}
}

View File

@ -65,7 +65,7 @@ public class OpenAiTextEmbeddingPayloadTests extends SageMakerSchemaPayloadTestC
public void testRequestWithSingleInput() throws Exception {
SageMakerModel model = mock();
when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(null, false));
when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings((String) null));
when(model.apiTaskSettings()).thenReturn(new SageMakerOpenAiTaskSettings((String) null));
var request = new SageMakerInferenceRequest(null, null, null, List.of("hello"), randomBoolean(), randomFrom(InputType.values()));
var sdkByes = payload.requestBytes(model, request);
@ -76,7 +76,7 @@ public class OpenAiTextEmbeddingPayloadTests extends SageMakerSchemaPayloadTestC
public void testRequestWithArrayInput() throws Exception {
SageMakerModel model = mock();
when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(null, false));
when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings((String) null));
when(model.apiTaskSettings()).thenReturn(new SageMakerOpenAiTaskSettings((String) null));
var request = new SageMakerInferenceRequest(
null,
null,
@ -94,7 +94,7 @@ public class OpenAiTextEmbeddingPayloadTests extends SageMakerSchemaPayloadTestC
public void testRequestWithDimensionsNotSetByUserIgnoreDimensions() throws Exception {
SageMakerModel model = mock();
when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(123, false));
when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings((String) null));
when(model.apiTaskSettings()).thenReturn(new SageMakerOpenAiTaskSettings((String) null));
var request = new SageMakerInferenceRequest(
null,
null,
@ -112,7 +112,7 @@ public class OpenAiTextEmbeddingPayloadTests extends SageMakerSchemaPayloadTestC
public void testRequestWithOptionals() throws Exception {
SageMakerModel model = mock();
when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(1234, true));
when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings("user"));
when(model.apiTaskSettings()).thenReturn(new SageMakerOpenAiTaskSettings("user"));
var request = new SageMakerInferenceRequest("query", null, null, List.of("hello"), randomBoolean(), randomFrom(InputType.values()));
var sdkByes = payload.requestBytes(model, request);

View File

@ -13,26 +13,26 @@ import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase;
import java.util.Map;
public class SageMakerOpenAiTaskSettingsTests extends InferenceSettingsTestCase<OpenAiTextEmbeddingPayload.ApiTaskSettings> {
public class SageMakerOpenAiTaskSettingsTests extends InferenceSettingsTestCase<SageMakerOpenAiTaskSettings> {
@Override
protected OpenAiTextEmbeddingPayload.ApiTaskSettings fromMutableMap(Map<String, Object> mutableMap) {
protected SageMakerOpenAiTaskSettings fromMutableMap(Map<String, Object> mutableMap) {
var validationException = new ValidationException();
var settings = OpenAiTextEmbeddingPayload.ApiTaskSettings.fromMap(mutableMap, validationException);
var settings = SageMakerOpenAiTaskSettings.fromMap(mutableMap, validationException);
validationException.throwIfValidationErrorsExist();
return settings;
}
@Override
protected Writeable.Reader<OpenAiTextEmbeddingPayload.ApiTaskSettings> instanceReader() {
return OpenAiTextEmbeddingPayload.ApiTaskSettings::new;
protected Writeable.Reader<SageMakerOpenAiTaskSettings> instanceReader() {
return SageMakerOpenAiTaskSettings::new;
}
@Override
protected OpenAiTextEmbeddingPayload.ApiTaskSettings createTestInstance() {
protected SageMakerOpenAiTaskSettings createTestInstance() {
return randomApiTaskSettings();
}
static OpenAiTextEmbeddingPayload.ApiTaskSettings randomApiTaskSettings() {
return new OpenAiTextEmbeddingPayload.ApiTaskSettings(randomOptionalString());
static SageMakerOpenAiTaskSettings randomApiTaskSettings() {
return new SageMakerOpenAiTaskSettings(randomOptionalString());
}
}