[ML] Integrate with DeepSeek API (#122218)
Integrating for Chat Completion and Completion task types, both calling the chat completion API for DeepSeek.
This commit is contained in:
parent
d553455092
commit
9f89a3b318
|
@ -0,0 +1,5 @@
|
|||
pr: 122218
|
||||
summary: Integrate with `DeepSeek` API
|
||||
area: Machine Learning
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -147,6 +147,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06);
|
||||
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR_8_19 = def(8_841_0_07);
|
||||
public static final TransportVersion INFERENCE_CONTEXT_8_X = def(8_841_0_08);
|
||||
public static final TransportVersion ML_INFERENCE_DEEPSEEK_8_19 = def(8_841_0_09);
|
||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
|
||||
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
|
||||
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
|
||||
|
@ -183,6 +184,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion ESQL_SERIALIZE_BLOCK_TYPE_CODE = def(9_026_0_00);
|
||||
public static final TransportVersion ESQL_THREAD_NAME_IN_DRIVER_PROFILE = def(9_027_0_00);
|
||||
public static final TransportVersion INFERENCE_CONTEXT = def(9_028_0_00);
|
||||
public static final TransportVersion ML_INFERENCE_DEEPSEEK = def(9_029_00_0);
|
||||
|
||||
/*
|
||||
* STOP! READ THIS FIRST! No, really,
|
||||
|
|
|
@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
@SuppressWarnings("unchecked")
|
||||
public void testGetServicesWithoutTaskType() throws IOException {
|
||||
List<Object> services = getAllServices();
|
||||
assertThat(services.size(), equalTo(20));
|
||||
assertThat(services.size(), equalTo(21));
|
||||
|
||||
String[] providers = new String[services.size()];
|
||||
for (int i = 0; i < services.size(); i++) {
|
||||
|
@ -41,6 +41,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
"azureaistudio",
|
||||
"azureopenai",
|
||||
"cohere",
|
||||
"deepseek",
|
||||
"elastic",
|
||||
"elasticsearch",
|
||||
"googleaistudio",
|
||||
|
@ -114,7 +115,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
@SuppressWarnings("unchecked")
|
||||
public void testGetServicesWithCompletionTaskType() throws IOException {
|
||||
List<Object> services = getServices(TaskType.COMPLETION);
|
||||
assertThat(services.size(), equalTo(9));
|
||||
assertThat(services.size(), equalTo(10));
|
||||
|
||||
String[] providers = new String[services.size()];
|
||||
for (int i = 0; i < services.size(); i++) {
|
||||
|
@ -130,6 +131,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
"azureaistudio",
|
||||
"azureopenai",
|
||||
"cohere",
|
||||
"deepseek",
|
||||
"googleaistudio",
|
||||
"openai",
|
||||
"streaming_completion_test_service"
|
||||
|
@ -141,7 +143,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
@SuppressWarnings("unchecked")
|
||||
public void testGetServicesWithChatCompletionTaskType() throws IOException {
|
||||
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
|
||||
assertThat(services.size(), equalTo(3));
|
||||
assertThat(services.size(), equalTo(4));
|
||||
|
||||
String[] providers = new String[services.size()];
|
||||
for (int i = 0; i < services.size(); i++) {
|
||||
|
@ -149,7 +151,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
providers[i] = (String) serviceConfig.get("service");
|
||||
}
|
||||
|
||||
assertArrayEquals(List.of("elastic", "openai", "streaming_completion_test_service").toArray(), providers);
|
||||
assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
|
|
|
@ -58,6 +58,7 @@ import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbedd
|
|||
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
|
||||
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
|
||||
|
@ -153,6 +154,7 @@ public class InferenceNamedWriteablesProvider {
|
|||
addUnifiedNamedWriteables(namedWriteables);
|
||||
|
||||
namedWriteables.addAll(StreamingTaskManager.namedWriteables());
|
||||
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
|
||||
|
||||
return namedWriteables;
|
||||
}
|
||||
|
|
|
@ -116,6 +116,7 @@ import org.elasticsearch.xpack.inference.services.anthropic.AnthropicService;
|
|||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
|
||||
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
|
||||
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
|
||||
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
|
||||
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
|
||||
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
|
||||
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
|
||||
|
@ -362,6 +363,7 @@ public class InferencePlugin extends Plugin
|
|||
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
|
||||
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
|
||||
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
|
||||
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
|
||||
ElasticsearchInternalService::new
|
||||
);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.deepseek;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.apache.http.entity.ByteArrayEntity;
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.xcontent.ToXContent;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
|
||||
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
|
||||
|
||||
public class DeepSeekChatCompletionRequest implements Request {
|
||||
private static final String MODEL_FIELD = "model";
|
||||
private static final String MAX_TOKENS = "max_tokens";
|
||||
|
||||
private final DeepSeekChatCompletionModel model;
|
||||
private final UnifiedChatInput unifiedChatInput;
|
||||
|
||||
public DeepSeekChatCompletionRequest(UnifiedChatInput unifiedChatInput, DeepSeekChatCompletionModel model) {
|
||||
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
|
||||
this.model = Objects.requireNonNull(model);
|
||||
}
|
||||
|
||||
@Override
|
||||
public HttpRequest createHttpRequest() {
|
||||
HttpPost httpPost = new HttpPost(model.uri());
|
||||
|
||||
httpPost.setEntity(createEntity());
|
||||
|
||||
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
|
||||
httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
|
||||
|
||||
return new HttpRequest(httpPost, getInferenceEntityId());
|
||||
}
|
||||
|
||||
private ByteArrayEntity createEntity() {
|
||||
var modelId = Objects.requireNonNullElseGet(unifiedChatInput.getRequest().model(), model::model);
|
||||
try (var builder = JsonXContent.contentBuilder()) {
|
||||
builder.startObject();
|
||||
new UnifiedChatCompletionRequestEntity(unifiedChatInput).toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
builder.field(MODEL_FIELD, modelId);
|
||||
|
||||
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
|
||||
builder.field(MAX_TOKENS, unifiedChatInput.getRequest().maxCompletionTokens());
|
||||
}
|
||||
|
||||
builder.endObject();
|
||||
return new ByteArrayEntity(Strings.toString(builder).getBytes(StandardCharsets.UTF_8));
|
||||
} catch (IOException e) {
|
||||
throw new ElasticsearchException("Failed to serialize request payload.", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public URI getURI() {
|
||||
return model.uri();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Request truncate() {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean[] getTruncationInfo() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getInferenceEntityId() {
|
||||
return model.getInferenceEntityId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isStreaming() {
|
||||
return unifiedChatInput.stream();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.http.sender;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.inference.external.deepseek.DeepSeekChatCompletionRequest;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs.createUnsupportedTypeException;
|
||||
|
||||
public class DeepSeekRequestManager extends BaseRequestManager {
|
||||
|
||||
private static final Logger logger = LogManager.getLogger(DeepSeekRequestManager.class);
|
||||
|
||||
private static final ResponseHandler CHAT_COMPLETION = createChatCompletionHandler();
|
||||
private static final ResponseHandler COMPLETION = createCompletionHandler();
|
||||
|
||||
private final DeepSeekChatCompletionModel model;
|
||||
|
||||
public DeepSeekRequestManager(DeepSeekChatCompletionModel model, ThreadPool threadPool) {
|
||||
super(threadPool, model.getInferenceEntityId(), model.rateLimitGroup(), model.rateLimitSettings());
|
||||
this.model = Objects.requireNonNull(model);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(
|
||||
InferenceInputs inferenceInputs,
|
||||
RequestSender requestSender,
|
||||
Supplier<Boolean> hasRequestCompletedFunction,
|
||||
ActionListener<InferenceServiceResults> listener
|
||||
) {
|
||||
switch (inferenceInputs) {
|
||||
case UnifiedChatInput uci -> execute(uci, requestSender, hasRequestCompletedFunction, listener);
|
||||
case ChatCompletionInput cci -> execute(cci, requestSender, hasRequestCompletedFunction, listener);
|
||||
default -> throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class);
|
||||
}
|
||||
}
|
||||
|
||||
private void execute(
|
||||
UnifiedChatInput inferenceInputs,
|
||||
RequestSender requestSender,
|
||||
Supplier<Boolean> hasRequestCompletedFunction,
|
||||
ActionListener<InferenceServiceResults> listener
|
||||
) {
|
||||
var request = new DeepSeekChatCompletionRequest(inferenceInputs, model);
|
||||
execute(new ExecutableInferenceRequest(requestSender, logger, request, CHAT_COMPLETION, hasRequestCompletedFunction, listener));
|
||||
}
|
||||
|
||||
private void execute(
|
||||
ChatCompletionInput inferenceInputs,
|
||||
RequestSender requestSender,
|
||||
Supplier<Boolean> hasRequestCompletedFunction,
|
||||
ActionListener<InferenceServiceResults> listener
|
||||
) {
|
||||
var unifiedInputs = new UnifiedChatInput(inferenceInputs.getInputs(), "user", inferenceInputs.stream());
|
||||
var request = new DeepSeekChatCompletionRequest(unifiedInputs, model);
|
||||
execute(new ExecutableInferenceRequest(requestSender, logger, request, COMPLETION, hasRequestCompletedFunction, listener));
|
||||
}
|
||||
|
||||
private static ResponseHandler createChatCompletionHandler() {
|
||||
return new OpenAiUnifiedChatCompletionResponseHandler("deepseek chat completion", OpenAiChatCompletionResponseEntity::fromResponse);
|
||||
}
|
||||
|
||||
private static ResponseHandler createCompletionHandler() {
|
||||
return new OpenAiChatCompletionResponseHandler("deepseek completion", OpenAiChatCompletionResponseEntity::fromResponse);
|
||||
}
|
||||
}
|
|
@ -21,12 +21,15 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec
|
|||
|
||||
public static final String USER_FIELD = "user";
|
||||
private static final String MODEL_FIELD = "model";
|
||||
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
|
||||
|
||||
private final UnifiedChatInput unifiedChatInput;
|
||||
private final OpenAiChatCompletionModel model;
|
||||
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
|
||||
|
||||
public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) {
|
||||
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
|
||||
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
|
||||
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
|
||||
this.model = Objects.requireNonNull(model);
|
||||
}
|
||||
|
||||
|
@ -41,6 +44,10 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec
|
|||
builder.field(USER_FIELD, model.getTaskSettings().user());
|
||||
}
|
||||
|
||||
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
|
||||
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
|
||||
}
|
||||
|
||||
builder.endObject();
|
||||
|
||||
return builder;
|
||||
|
|
|
@ -17,12 +17,15 @@ import java.util.Objects;
|
|||
|
||||
public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implements ToXContentObject {
|
||||
private static final String MODEL_FIELD = "model";
|
||||
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
|
||||
|
||||
private final UnifiedChatInput unifiedChatInput;
|
||||
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
|
||||
private final String modelId;
|
||||
|
||||
public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) {
|
||||
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
|
||||
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
|
||||
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
|
||||
this.modelId = Objects.requireNonNull(modelId);
|
||||
}
|
||||
|
||||
|
@ -31,6 +34,11 @@ public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implement
|
|||
builder.startObject();
|
||||
unifiedRequestEntity.toXContent(builder, params);
|
||||
builder.field(MODEL_FIELD, modelId);
|
||||
|
||||
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
|
||||
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
|
||||
}
|
||||
|
||||
builder.endObject();
|
||||
|
||||
return builder;
|
||||
|
|
|
@ -32,7 +32,6 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
|
|||
public static final String MESSAGES_FIELD = "messages";
|
||||
private static final String ROLE_FIELD = "role";
|
||||
private static final String CONTENT_FIELD = "content";
|
||||
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
|
||||
private static final String STOP_FIELD = "stop";
|
||||
private static final String TEMPERATURE_FIELD = "temperature";
|
||||
private static final String TOOL_CHOICE_FIELD = "tool_choice";
|
||||
|
@ -104,10 +103,6 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
|
|||
}
|
||||
builder.endArray();
|
||||
|
||||
if (unifiedRequest.maxCompletionTokens() != null) {
|
||||
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens());
|
||||
}
|
||||
|
||||
// Underlying providers expect OpenAI to only return 1 possible choice.
|
||||
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);
|
||||
|
||||
|
|
|
@ -0,0 +1,200 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.deepseek;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.settings.SecureString;
|
||||
import org.elasticsearch.inference.EmptyTaskSettings;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.ServiceSettings;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
|
||||
|
||||
/**
|
||||
* Design notes:
|
||||
* This provider tries to match the OpenAI, so we'll design around that as well.
|
||||
*
|
||||
* Task Type:
|
||||
* - Chat Completion
|
||||
*
|
||||
* Service Settings:
|
||||
* - api_key
|
||||
* - model
|
||||
* - url
|
||||
*
|
||||
* Task Settings:
|
||||
* - nothing?
|
||||
*
|
||||
* Rate Limiting:
|
||||
* - The website claims to want unlimited, so we're setting it as MAX_INT per minute?
|
||||
*/
|
||||
public class DeepSeekChatCompletionModel extends Model {
|
||||
// Per-node rate limit group and settings, limiting the outbound requests this node can make to INTEGER.MAX_VALUE per minute.
|
||||
private static final Object RATE_LIMIT_GROUP = new Object();
|
||||
private static final RateLimitSettings RATE_LIMIT_SETTINGS = new RateLimitSettings(Integer.MAX_VALUE);
|
||||
|
||||
private static final URI DEFAULT_URI = URI.create("https://api.deepseek.com/chat/completions");
|
||||
private final DeepSeekServiceSettings serviceSettings;
|
||||
private final DefaultSecretSettings secretSettings;
|
||||
|
||||
public static List<NamedWriteableRegistry.Entry> namedWriteables() {
|
||||
return List.of(new NamedWriteableRegistry.Entry(ServiceSettings.class, DeepSeekServiceSettings.NAME, DeepSeekServiceSettings::new));
|
||||
}
|
||||
|
||||
public static DeepSeekChatCompletionModel createFromNewInput(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
String service,
|
||||
Map<String, Object> serviceSettingsMap
|
||||
) {
|
||||
var validationException = new ValidationException();
|
||||
|
||||
var model = extractRequiredString(serviceSettingsMap, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||
var uri = createOptionalUri(
|
||||
extractOptionalString(serviceSettingsMap, URL, ModelConfigurations.SERVICE_SETTINGS, validationException)
|
||||
);
|
||||
var secureApiToken = extractRequiredSecureString(
|
||||
serviceSettingsMap,
|
||||
"api_key",
|
||||
ModelConfigurations.SERVICE_SETTINGS,
|
||||
validationException
|
||||
);
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
|
||||
var serviceSettings = new DeepSeekServiceSettings(model, uri);
|
||||
var taskSettings = new EmptyTaskSettings();
|
||||
var secretSettings = new DefaultSecretSettings(secureApiToken);
|
||||
var modelConfigurations = new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
|
||||
return new DeepSeekChatCompletionModel(serviceSettings, secretSettings, modelConfigurations, new ModelSecrets(secretSettings));
|
||||
}
|
||||
|
||||
public static DeepSeekChatCompletionModel readFromStorage(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
String service,
|
||||
Map<String, Object> serviceSettingsMap,
|
||||
Map<String, Object> secrets
|
||||
) {
|
||||
var validationException = new ValidationException();
|
||||
|
||||
var model = extractRequiredString(serviceSettingsMap, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||
var uri = createOptionalUri(
|
||||
extractOptionalString(serviceSettingsMap, "url", ModelConfigurations.SERVICE_SETTINGS, validationException)
|
||||
);
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
|
||||
var serviceSettings = new DeepSeekServiceSettings(model, uri);
|
||||
var taskSettings = new EmptyTaskSettings();
|
||||
var secretSettings = DefaultSecretSettings.fromMap(secrets);
|
||||
var modelConfigurations = new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
|
||||
return new DeepSeekChatCompletionModel(serviceSettings, secretSettings, modelConfigurations, new ModelSecrets(secretSettings));
|
||||
}
|
||||
|
||||
private DeepSeekChatCompletionModel(
|
||||
DeepSeekServiceSettings serviceSettings,
|
||||
DefaultSecretSettings secretSettings,
|
||||
ModelConfigurations configurations,
|
||||
ModelSecrets secrets
|
||||
) {
|
||||
super(configurations, secrets);
|
||||
this.serviceSettings = serviceSettings;
|
||||
this.secretSettings = secretSettings;
|
||||
}
|
||||
|
||||
public SecureString apiKey() {
|
||||
return secretSettings.apiKey();
|
||||
}
|
||||
|
||||
public String model() {
|
||||
return serviceSettings.modelId();
|
||||
}
|
||||
|
||||
public URI uri() {
|
||||
return serviceSettings.uri() != null ? serviceSettings.uri() : DEFAULT_URI;
|
||||
}
|
||||
|
||||
public Object rateLimitGroup() {
|
||||
return RATE_LIMIT_GROUP;
|
||||
}
|
||||
|
||||
public RateLimitSettings rateLimitSettings() {
|
||||
return RATE_LIMIT_SETTINGS;
|
||||
}
|
||||
|
||||
private record DeepSeekServiceSettings(String modelId, URI uri) implements ServiceSettings {
|
||||
private static final String NAME = "deep_seek_service_settings";
|
||||
|
||||
DeepSeekServiceSettings {
|
||||
Objects.requireNonNull(modelId);
|
||||
}
|
||||
|
||||
DeepSeekServiceSettings(StreamInput in) throws IOException {
|
||||
this(in.readString(), in.readOptional(url -> URI.create(url.readString())));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.ML_INFERENCE_DEEPSEEK;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(modelId);
|
||||
out.writeOptionalString(uri != null ? uri.toString() : null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ToXContentObject getFilteredXContentObject() {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(MODEL_ID, modelId);
|
||||
if (uri != null) {
|
||||
builder.field(URL, uri.toString());
|
||||
}
|
||||
return builder.endObject();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,233 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.deepseek;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.common.util.LazyInitializable;
|
||||
import org.elasticsearch.core.Strings;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.ChunkedInference;
|
||||
import org.elasticsearch.inference.InferenceServiceConfiguration;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.SettingsConfiguration;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
|
||||
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.DeepSeekRequestManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.services.SenderService;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
|
||||
|
||||
import java.util.EnumSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
|
||||
|
||||
public class DeepSeekService extends SenderService {
|
||||
private static final String NAME = "deepseek";
|
||||
private static final String CHAT_COMPLETION_ERROR_PREFIX = "deepseek chat completions";
|
||||
private static final String COMPLETION_ERROR_PREFIX = "deepseek completions";
|
||||
private static final String SERVICE_NAME = "DeepSeek";
|
||||
// The task types exposed via the _inference/_services API
|
||||
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_SERVICES_API = EnumSet.of(
|
||||
TaskType.COMPLETION,
|
||||
TaskType.CHAT_COMPLETION
|
||||
);
|
||||
|
||||
public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
|
||||
super(factory, serviceComponents);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doInfer(
|
||||
Model model,
|
||||
InferenceInputs inputs,
|
||||
Map<String, Object> taskSettings,
|
||||
InputType inputType,
|
||||
TimeValue timeout,
|
||||
ActionListener<InferenceServiceResults> listener
|
||||
) {
|
||||
doInfer(model, inputs, timeout, COMPLETION_ERROR_PREFIX, listener);
|
||||
}
|
||||
|
||||
private void doInfer(
|
||||
Model model,
|
||||
InferenceInputs inputs,
|
||||
TimeValue timeout,
|
||||
String errorPrefix,
|
||||
ActionListener<InferenceServiceResults> listener
|
||||
) {
|
||||
if (model instanceof DeepSeekChatCompletionModel deepSeekModel) {
|
||||
var requestCreator = new DeepSeekRequestManager(deepSeekModel, getServiceComponents().threadPool());
|
||||
var errorMessage = constructFailedToSendRequestMessage(errorPrefix);
|
||||
var action = new SenderExecutableAction(getSender(), requestCreator, errorMessage);
|
||||
action.execute(inputs, timeout, listener);
|
||||
} else {
|
||||
listener.onFailure(createInvalidModelException(model));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doUnifiedCompletionInfer(
|
||||
Model model,
|
||||
UnifiedChatInput inputs,
|
||||
TimeValue timeout,
|
||||
ActionListener<InferenceServiceResults> listener
|
||||
) {
|
||||
doInfer(model, inputs, timeout, CHAT_COMPLETION_ERROR_PREFIX, listener);
|
||||
}
|
||||
|
||||
/**
 * Chunked inference is not supported by this service; the listener is always failed with an
 * {@link UnsupportedOperationException}.
 */
@Override
protected void doChunkedInfer(
    Model model,
    DocumentsOnlyInput inputs,
    Map<String, Object> taskSettings,
    InputType inputType,
    TimeValue timeout,
    ActionListener<List<ChunkedInference>> listener
) {
    listener.onFailure(new UnsupportedOperationException(Strings.format("The %s service only supports unified completion", NAME)));
}
|
||||
|
||||
/** Returns the service id used to select this service in inference endpoint configurations. */
@Override
public String name() {
    return NAME;
}
|
||||
|
||||
/**
 * Parses a new inference endpoint configuration from a request body and completes
 * {@code parsedModelListener} with the resulting model, or with a failure if the
 * configuration is invalid.
 */
@Override
public void parseRequestConfig(
    String modelId,
    TaskType taskType,
    Map<String, Object> config,
    ActionListener<Model> parsedModelListener
) {
    ActionListener.completeWith(parsedModelListener, () -> {
        // removeFromMapOrThrowIfNull mutates config, removing the service_settings entry.
        var serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
        try {
            return DeepSeekChatCompletionModel.createFromNewInput(modelId, taskType, NAME, serviceSettingsMap);
        } finally {
            // Reject unknown settings. NOTE(review): if this throws, the exception replaces the
            // return value and would also mask an exception thrown by createFromNewInput.
            throwIfNotEmptyMap(serviceSettingsMap, NAME);
        }
    });
}
|
||||
|
||||
/**
 * Rebuilds a model from a persisted configuration plus a separate secrets map.
 * Both maps are mutated: the service_settings and secret_settings entries are removed.
 */
@Override
public Model parsePersistedConfigWithSecrets(
    String modelId,
    TaskType taskType,
    Map<String, Object> config,
    Map<String, Object> secrets
) {
    var serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
    var secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
    return DeepSeekChatCompletionModel.readFromStorage(modelId, taskType, NAME, serviceSettingsMap, secretSettingsMap);
}
|
||||
|
||||
/**
 * Rebuilds a model from a persisted configuration that carries its secret_settings entry inside
 * the same map, so {@code config} is passed as both the config and the secrets source.
 */
@Override
public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
    return parsePersistedConfigWithSecrets(modelId, taskType, config, config);
}
|
||||
|
||||
/** Returns the lazily-built, cached service configuration exposed via the _inference/_services API. */
@Override
public InferenceServiceConfiguration getConfiguration() {
    return Configuration.get();
}
|
||||
|
||||
/** Returns the task types this service supports: completion and chat_completion. */
@Override
public EnumSet<TaskType> supportedTaskTypes() {
    return SUPPORTED_TASK_TYPES_FOR_SERVICES_API;
}
|
||||
|
||||
/** Returns the transport version that introduced this service; older nodes cannot use it. */
@Override
public TransportVersion getMinimalSupportedVersion() {
    return TransportVersions.ML_INFERENCE_DEEPSEEK;
}
|
||||
|
||||
/** Streaming is only supported for the chat_completion task type. */
@Override
public Set<TaskType> supportedStreamingTasks() {
    return EnumSet.of(TaskType.CHAT_COMPLETION);
}
|
||||
|
||||
/**
 * Validates the model configuration using the generic model validator for its task type.
 */
@Override
public void checkModelConfig(Model model, ActionListener<Model> listener) {
    // TODO: Remove this function once all services have been updated to use the new model validators
    ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
}
|
||||
|
||||
/**
 * Lazily-built, cached description of this service's settings (model_id, api_key, url)
 * as returned from the _inference/_services API.
 */
private static class Configuration {
    public static InferenceServiceConfiguration get() {
        return configuration.getOrCompute();
    }

    private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
        () -> {
            var configurationMap = new HashMap<String, SettingsConfiguration>();

            // model_id: required, non-sensitive string identifying the DeepSeek model.
            configurationMap.put(
                MODEL_ID,
                new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDescription(
                    "The name of the model to use for the inference task."
                )
                    .setLabel("Model ID")
                    .setRequired(true)
                    .setSensitive(false)
                    .setUpdatable(false)
                    .setType(SettingsConfigurationFieldType.STRING)
                    .build()
            );

            // api_key: standard sensitive secret setting shared across services.
            configurationMap.putAll(
                DefaultSecretSettings.toSettingsConfigurationWithDescription(
                    "The DeepSeek API authentication key. For more details about generating DeepSeek API keys, "
                        + "refer to https://api-docs.deepseek.com.",
                    SUPPORTED_TASK_TYPES_FOR_SERVICES_API
                )
            );

            // url: optional override; defaults to the public DeepSeek chat completions endpoint.
            configurationMap.put(
                URL,
                new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDefaultValue(
                    "https://api.deepseek.com/chat/completions"
                )
                    .setDescription("The URL endpoint to use for the requests.")
                    .setLabel("URL")
                    .setRequired(false)
                    .setSensitive(false)
                    .setUpdatable(false)
                    .setType(SettingsConfigurationFieldType.STRING)
                    .build()
            );

            return new InferenceServiceConfiguration.Builder().setService(NAME)
                .setName(SERVICE_NAME)
                .setTaskTypes(SUPPORTED_TASK_TYPES_FOR_SERVICES_API)
                .setConfigurations(configurationMap)
                .build();
        }
    );
}
|
||||
}
|
|
@ -0,0 +1,441 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.deepseek;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.test.http.MockResponse;
|
||||
import org.elasticsearch.test.http.MockWebServer;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
|
||||
import static org.elasticsearch.action.support.ActionTestUtils.assertNoFailureListener;
|
||||
import static org.elasticsearch.action.support.ActionTestUtils.assertNoSuccessListener;
|
||||
import static org.elasticsearch.common.Strings.format;
|
||||
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
|
||||
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.isA;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public class DeepSeekServiceTests extends ESTestCase {
|
||||
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
|
||||
private final MockWebServer webServer = new MockWebServer();
|
||||
private ThreadPool threadPool;
|
||||
private HttpClientManager clientManager;
|
||||
|
||||
@Before
|
||||
public void init() throws Exception {
|
||||
webServer.start();
|
||||
threadPool = createThreadPool(inferenceUtilityPool());
|
||||
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
|
||||
}
|
||||
|
||||
@After
|
||||
public void shutdown() throws IOException {
|
||||
clientManager.close();
|
||||
terminate(threadPool);
|
||||
webServer.close();
|
||||
}
|
||||
|
||||
public void testParseRequestConfig() throws IOException, URISyntaxException {
|
||||
parseRequestConfig(format("""
|
||||
{
|
||||
"service_settings": {
|
||||
"api_key": "12345",
|
||||
"model_id": "some-cool-model",
|
||||
"url": "%s"
|
||||
}
|
||||
}
|
||||
""", webServer.getUri(null).toString()), assertNoFailureListener(model -> {
|
||||
if (model instanceof DeepSeekChatCompletionModel deepSeekModel) {
|
||||
assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray()));
|
||||
assertThat(deepSeekModel.model(), equalTo("some-cool-model"));
|
||||
assertThat(deepSeekModel.uri(), equalTo(webServer.getUri(null)));
|
||||
} else {
|
||||
fail("Expected DeepSeekModel, found " + (model != null ? model.getClass().getSimpleName() : "null"));
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
public void testParseRequestConfigWithoutApiKey() throws IOException {
|
||||
parseRequestConfig("""
|
||||
{
|
||||
"service_settings": {
|
||||
"model_id": "some-cool-model"
|
||||
}
|
||||
}
|
||||
""", assertNoSuccessListener(e -> {
|
||||
if (e instanceof ValidationException ve) {
|
||||
assertThat(
|
||||
ve.getMessage(),
|
||||
equalTo("Validation Failed: 1: [service_settings] does not contain the required setting [api_key];")
|
||||
);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
public void testParseRequestConfigWithoutModel() throws IOException {
|
||||
parseRequestConfig("""
|
||||
{
|
||||
"service_settings": {
|
||||
"api_key": "1234"
|
||||
}
|
||||
}
|
||||
""", assertNoSuccessListener(e -> {
|
||||
if (e instanceof ValidationException ve) {
|
||||
assertThat(
|
||||
ve.getMessage(),
|
||||
equalTo("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];")
|
||||
);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
public void testParseRequestConfigWithExtraSettings() throws IOException {
|
||||
parseRequestConfig(
|
||||
"""
|
||||
{
|
||||
"service_settings": {
|
||||
"api_key": "12345",
|
||||
"model_id": "some-cool-model",
|
||||
"so": "extra"
|
||||
}
|
||||
}
|
||||
""",
|
||||
assertNoSuccessListener(
|
||||
e -> assertThat(
|
||||
e.getMessage(),
|
||||
equalTo("Model configuration contains settings [{so=extra}] unknown to the [deepseek] service")
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig() throws IOException {
|
||||
var deepSeekModel = parsePersistedConfig("""
|
||||
{
|
||||
"service_settings": {
|
||||
"model_id": "some-cool-model"
|
||||
},
|
||||
"secret_settings": {
|
||||
"api_key": "12345"
|
||||
}
|
||||
}
|
||||
""");
|
||||
assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray()));
|
||||
assertThat(deepSeekModel.model(), equalTo("some-cool-model"));
|
||||
}
|
||||
|
||||
public void testParsePersistedConfigWithUrl() throws IOException {
|
||||
var deepSeekModel = parsePersistedConfig("""
|
||||
{
|
||||
"service_settings": {
|
||||
"model_id": "some-cool-model",
|
||||
"url": "http://localhost:989"
|
||||
},
|
||||
"secret_settings": {
|
||||
"api_key": "12345"
|
||||
}
|
||||
}
|
||||
""");
|
||||
assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray()));
|
||||
assertThat(deepSeekModel.model(), equalTo("some-cool-model"));
|
||||
assertThat(deepSeekModel.uri(), equalTo(URI.create("http://localhost:989")));
|
||||
}
|
||||
|
||||
public void testParsePersistedConfigWithoutApiKey() {
|
||||
assertThrows(
|
||||
"Validation Failed: 1: [secret_settings] does not contain the required setting [api_key];",
|
||||
ValidationException.class,
|
||||
() -> parsePersistedConfig("""
|
||||
{
|
||||
"service_settings": {
|
||||
"model_id": "some-cool-model"
|
||||
},
|
||||
"secret_settings": {
|
||||
}
|
||||
}
|
||||
""")
|
||||
);
|
||||
}
|
||||
|
||||
public void testParsePersistedConfigWithoutModel() {
|
||||
assertThrows(
|
||||
"Validation Failed: 1: [service_settings] does not contain the required setting [model];",
|
||||
ValidationException.class,
|
||||
() -> parsePersistedConfig("""
|
||||
{
|
||||
"service_settings": {
|
||||
},
|
||||
"secret_settings": {
|
||||
"api_key": "12345"
|
||||
}
|
||||
}
|
||||
""")
|
||||
);
|
||||
}
|
||||
|
||||
public void testParsePersistedConfigWithoutServiceSettings() {
|
||||
assertThrows(
|
||||
"Validation Failed: 1: [service_settings] does not contain the required setting [model];",
|
||||
ElasticsearchStatusException.class,
|
||||
() -> parsePersistedConfig("""
|
||||
{
|
||||
"secret_settings": {
|
||||
"api_key": "12345"
|
||||
}
|
||||
}
|
||||
""")
|
||||
);
|
||||
}
|
||||
|
||||
public void testDoUnifiedInfer() throws Exception {
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
|
||||
data: {"choices": [{"delta": {"content": "hello, world", "role": "assistant"}, "finish_reason": null, "index": 0, \
|
||||
"logprobs": null}], "created": 1718345013, "id": "12345", "model": "deepseek-chat", \
|
||||
"object": "chat.completion.chunk", "system_fingerprint": "fp_1234"}
|
||||
|
||||
data: [DONE]
|
||||
|
||||
"""));
|
||||
doUnifiedCompletionInfer().hasNoErrors().hasEvent("""
|
||||
{"id":"12345","choices":[{"delta":{"content":"hello, world","role":"assistant"},"index":0}],""" + """
|
||||
"model":"deepseek-chat","object":"chat.completion.chunk"}""");
|
||||
}
|
||||
|
||||
public void testDoInfer() throws Exception {
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
|
||||
{"choices": [{"message": {"content": "hello, world", "role": "assistant"}, "finish_reason": "stop", "index": 0, \
|
||||
"logprobs": null}], "created": 1718345013, "id": "12345", "model": "deepseek-chat", \
|
||||
"object": "chat.completion", "system_fingerprint": "fp_1234"}"""));
|
||||
try (var service = createService()) {
|
||||
var model = createModel(service, TaskType.COMPLETION);
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.infer(model, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
assertThat(result, isA(ChatCompletionResults.class));
|
||||
var completionResults = (ChatCompletionResults) result;
|
||||
assertThat(
|
||||
completionResults.results().stream().map(ChatCompletionResults.Result::predictedValue).toList(),
|
||||
equalTo(List.of("hello, world"))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public void testDoInferStream() throws Exception {
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
|
||||
data: {"choices": [{"delta": {"content": "hello, world", "role": "assistant"}, "finish_reason": null, "index": 0, \
|
||||
"logprobs": null}], "created": 1718345013, "id": "12345", "model": "deepseek-chat", \
|
||||
"object": "chat.completion.chunk", "system_fingerprint": "fp_1234"}
|
||||
|
||||
data: [DONE]
|
||||
|
||||
"""));
|
||||
try (var service = createService()) {
|
||||
var model = createModel(service, TaskType.COMPLETION);
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.infer(model, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
|
||||
InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream().hasNoErrors().hasEvent("""
|
||||
{"completion":[{"delta":"hello, world"}]}""");
|
||||
}
|
||||
}
|
||||
|
||||
public void testUnifiedCompletionError() {
|
||||
String responseJson = """
|
||||
{
|
||||
"error": {
|
||||
"message": "The model `deepseek-not-chat` does not exist or you do not have access to it.",
|
||||
"type": "invalid_request_error",
|
||||
"param": null,
|
||||
"code": "model_not_found"
|
||||
}
|
||||
}""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
|
||||
var e = assertThrows(UnifiedChatCompletionException.class, this::doUnifiedCompletionInfer);
|
||||
assertThat(
|
||||
e.getMessage(),
|
||||
equalTo(
|
||||
"Received an unsuccessful status code for request from inference entity id [inference-id] status"
|
||||
+ " [404]. Error message: [The model `deepseek-not-chat` does not exist or you do not have access to it.]"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private void testStreamError(String expectedResponse) throws Exception {
|
||||
try (var service = createService()) {
|
||||
var model = createModel(service, TaskType.CHAT_COMPLETION);
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.unifiedCompletionInfer(
|
||||
model,
|
||||
UnifiedCompletionRequest.of(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
|
||||
),
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||
listener
|
||||
);
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
|
||||
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> {
|
||||
e = unwrapCause(e);
|
||||
assertThat(e, isA(UnifiedChatCompletionException.class));
|
||||
try (var builder = XContentFactory.jsonBuilder()) {
|
||||
((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
|
||||
try {
|
||||
xContent.toXContent(builder, EMPTY_PARAMS);
|
||||
} catch (IOException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
});
|
||||
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
|
||||
|
||||
assertThat(json, is(expectedResponse));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
public void testMidStreamUnifiedCompletionError() throws Exception {
|
||||
String responseJson = """
|
||||
event: error
|
||||
data: { "error": { "message": "Timed out waiting for more data", "type": "timeout" } }
|
||||
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
testStreamError("""
|
||||
{\
|
||||
"error":{\
|
||||
"message":"Received an error response for request from inference entity id [inference-id]. Error message: \
|
||||
[Timed out waiting for more data]",\
|
||||
"type":"timeout"\
|
||||
}}""");
|
||||
}
|
||||
|
||||
public void testUnifiedCompletionMalformedError() throws Exception {
|
||||
String responseJson = """
|
||||
data: { invalid json }
|
||||
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
testStreamError("""
|
||||
{\
|
||||
"error":{\
|
||||
"code":"bad_request",\
|
||||
"message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\
|
||||
at [Source: (String)\\"{ invalid json }\\"; line: 1, column: 3]",\
|
||||
"type":"x_content_parse_exception"\
|
||||
}}""");
|
||||
}
|
||||
|
||||
public void testDoChunkedInferAlwaysFails() throws IOException {
|
||||
try (var service = createService()) {
|
||||
service.doChunkedInfer(mock(), mock(), Map.of(), InputType.UNSPECIFIED, TIMEOUT, assertNoSuccessListener(e -> {
|
||||
assertThat(e, isA(UnsupportedOperationException.class));
|
||||
assertThat(e.getMessage(), equalTo("The deepseek service only supports unified completion"));
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
private DeepSeekService createService() {
|
||||
return new DeepSeekService(
|
||||
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager),
|
||||
createWithEmptySettings(threadPool)
|
||||
);
|
||||
}
|
||||
|
||||
private void parseRequestConfig(String json, ActionListener<Model> listener) throws IOException {
|
||||
try (var service = createService()) {
|
||||
service.parseRequestConfig("inference-id", TaskType.CHAT_COMPLETION, map(json), listener);
|
||||
}
|
||||
}
|
||||
|
||||
private Map<String, Object> map(String json) throws IOException {
|
||||
try (
|
||||
var parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, json.getBytes(StandardCharsets.UTF_8))
|
||||
) {
|
||||
return parser.map();
|
||||
}
|
||||
}
|
||||
|
||||
private DeepSeekChatCompletionModel parsePersistedConfig(String json) throws IOException {
|
||||
try (var service = createService()) {
|
||||
var model = service.parsePersistedConfig("inference-id", TaskType.CHAT_COMPLETION, map(json));
|
||||
assertThat(model, isA(DeepSeekChatCompletionModel.class));
|
||||
return (DeepSeekChatCompletionModel) model;
|
||||
}
|
||||
}
|
||||
|
||||
private InferenceEventsAssertion doUnifiedCompletionInfer() throws Exception {
|
||||
try (var service = createService()) {
|
||||
var model = createModel(service, TaskType.CHAT_COMPLETION);
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.unifiedCompletionInfer(
|
||||
model,
|
||||
UnifiedCompletionRequest.of(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
|
||||
),
|
||||
TIMEOUT,
|
||||
listener
|
||||
);
|
||||
return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
|
||||
}
|
||||
}
|
||||
|
||||
private DeepSeekChatCompletionModel createModel(DeepSeekService service, TaskType taskType) throws URISyntaxException, IOException {
|
||||
var model = service.parsePersistedConfig("inference-id", taskType, map(Strings.format("""
|
||||
{
|
||||
"service_settings": {
|
||||
"model_id": "some-cool-model",
|
||||
"url": "%s"
|
||||
},
|
||||
"secret_settings": {
|
||||
"api_key": "12345"
|
||||
}
|
||||
}
|
||||
""", webServer.getUri(null).toString())));
|
||||
assertThat(model, isA(DeepSeekChatCompletionModel.class));
|
||||
return (DeepSeekChatCompletionModel) model;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue