[ML] Integrate with DeepSeek API (#122218)

Integrating for Chat Completion and Completion task types, both calling
the chat completion API for DeepSeek.
This commit is contained in:
Pat Whelan 2025-03-12 10:24:39 -04:00 committed by GitHub
parent d553455092
commit 9f89a3b318
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1089 additions and 11 deletions

View File

@ -0,0 +1,5 @@
pr: 122218
summary: Integrate with `DeepSeek` API
area: Machine Learning
type: enhancement
issues: []

View File

@ -147,6 +147,7 @@ public class TransportVersions {
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06);
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR_8_19 = def(8_841_0_07);
public static final TransportVersion INFERENCE_CONTEXT_8_X = def(8_841_0_08);
public static final TransportVersion ML_INFERENCE_DEEPSEEK_8_19 = def(8_841_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@ -183,6 +184,7 @@ public class TransportVersions {
public static final TransportVersion ESQL_SERIALIZE_BLOCK_TYPE_CODE = def(9_026_0_00);
public static final TransportVersion ESQL_THREAD_NAME_IN_DRIVER_PROFILE = def(9_027_0_00);
public static final TransportVersion INFERENCE_CONTEXT = def(9_028_0_00);
public static final TransportVersion ML_INFERENCE_DEEPSEEK = def(9_029_00_0);
/*
* STOP! READ THIS FIRST! No, really,

View File

@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
@SuppressWarnings("unchecked")
public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(20));
assertThat(services.size(), equalTo(21));
String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
@ -41,6 +41,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"elastic",
"elasticsearch",
"googleaistudio",
@ -114,7 +115,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
@SuppressWarnings("unchecked")
public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(9));
assertThat(services.size(), equalTo(10));
String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
@ -130,6 +131,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"googleaistudio",
"openai",
"streaming_completion_test_service"
@ -141,7 +143,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
@SuppressWarnings("unchecked")
public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(3));
assertThat(services.size(), equalTo(4));
String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
@ -149,7 +151,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
providers[i] = (String) serviceConfig.get("service");
}
assertArrayEquals(List.of("elastic", "openai", "streaming_completion_test_service").toArray(), providers);
assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
}
@SuppressWarnings("unchecked")

View File

@ -58,6 +58,7 @@ import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbedd
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
@ -153,6 +154,7 @@ public class InferenceNamedWriteablesProvider {
addUnifiedNamedWriteables(namedWriteables);
namedWriteables.addAll(StreamingTaskManager.namedWriteables());
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
return namedWriteables;
}

View File

@ -116,6 +116,7 @@ import org.elasticsearch.xpack.inference.services.anthropic.AnthropicService;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
@ -362,6 +363,7 @@ public class InferencePlugin extends Plugin
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}

View File

@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.deepseek;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
/**
 * HTTP request for the DeepSeek chat completions endpoint.
 *
 * <p>Serializes the unified chat input as a JSON body, appending the {@code model}
 * (falling back to the configured model when the request does not specify one) and,
 * when present, {@code max_tokens}. Authentication is a bearer token taken from the
 * model's secret settings.
 */
public class DeepSeekChatCompletionRequest implements Request {

    private static final String MODEL_FIELD = "model";
    private static final String MAX_TOKENS = "max_tokens";

    private final DeepSeekChatCompletionModel model;
    private final UnifiedChatInput chatInput;

    public DeepSeekChatCompletionRequest(UnifiedChatInput unifiedChatInput, DeepSeekChatCompletionModel model) {
        this.chatInput = Objects.requireNonNull(unifiedChatInput);
        this.model = Objects.requireNonNull(model);
    }

    @Override
    public HttpRequest createHttpRequest() {
        var post = new HttpPost(model.uri());
        post.setEntity(createEntity());
        post.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
        post.setHeader(createAuthBearerHeader(model.apiKey()));
        return new HttpRequest(post, getInferenceEntityId());
    }

    /** Builds the JSON request body; wraps serialization failures in an ElasticsearchException. */
    private ByteArrayEntity createEntity() {
        var request = chatInput.getRequest();
        // Prefer the model id supplied on the request, otherwise fall back to the endpoint's configured model.
        var modelId = Objects.requireNonNullElseGet(request.model(), model::model);
        try (var builder = JsonXContent.contentBuilder()) {
            builder.startObject();
            new UnifiedChatCompletionRequestEntity(chatInput).toXContent(builder, ToXContent.EMPTY_PARAMS);
            builder.field(MODEL_FIELD, modelId);
            var maxTokens = request.maxCompletionTokens();
            if (maxTokens != null) {
                // DeepSeek uses OpenAI's legacy "max_tokens" name rather than "max_completion_tokens".
                builder.field(MAX_TOKENS, maxTokens);
            }
            builder.endObject();
            return new ByteArrayEntity(Strings.toString(builder).getBytes(StandardCharsets.UTF_8));
        } catch (IOException e) {
            throw new ElasticsearchException("Failed to serialize request payload.", e);
        }
    }

    @Override
    public URI getURI() {
        return model.uri();
    }

    @Override
    public Request truncate() {
        // Chat requests are never truncated; return this request unchanged.
        return this;
    }

    @Override
    public boolean[] getTruncationInfo() {
        return null;
    }

    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }

    @Override
    public boolean isStreaming() {
        return chatInput.stream();
    }
}

View File

@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.deepseek.DeepSeekChatCompletionRequest;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import java.util.Objects;
import java.util.function.Supplier;
import static org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs.createUnsupportedTypeException;
/**
 * Request manager for DeepSeek. Both the chat-completion and plain completion task
 * types are routed to DeepSeek's chat completion API; plain completion inputs are
 * first adapted into a unified chat input with the "user" role.
 */
public class DeepSeekRequestManager extends BaseRequestManager {

    private static final Logger log = LogManager.getLogger(DeepSeekRequestManager.class);

    // DeepSeek's responses are OpenAI-compatible, so the OpenAI handlers are reused.
    private static final ResponseHandler CHAT_COMPLETION_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler(
        "deepseek chat completion",
        OpenAiChatCompletionResponseEntity::fromResponse
    );
    private static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
        "deepseek completion",
        OpenAiChatCompletionResponseEntity::fromResponse
    );

    private final DeepSeekChatCompletionModel model;

    public DeepSeekRequestManager(DeepSeekChatCompletionModel model, ThreadPool threadPool) {
        super(threadPool, model.getInferenceEntityId(), model.rateLimitGroup(), model.rateLimitSettings());
        this.model = Objects.requireNonNull(model);
    }

    @Override
    public void execute(
        InferenceInputs inferenceInputs,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        if (inferenceInputs instanceof UnifiedChatInput unifiedInput) {
            send(unifiedInput, CHAT_COMPLETION_HANDLER, requestSender, hasRequestCompletedFunction, listener);
        } else if (inferenceInputs instanceof ChatCompletionInput completionInput) {
            // Adapt a plain completion request into the unified chat shape before sending.
            var asUnified = new UnifiedChatInput(completionInput.getInputs(), "user", completionInput.stream());
            send(asUnified, COMPLETION_HANDLER, requestSender, hasRequestCompletedFunction, listener);
        } else {
            throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class);
        }
    }

    /** Builds the DeepSeek request and hands it off for execution with the given response handler. */
    private void send(
        UnifiedChatInput input,
        ResponseHandler handler,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        var request = new DeepSeekChatCompletionRequest(input, model);
        execute(new ExecutableInferenceRequest(requestSender, log, request, handler, hasRequestCompletedFunction, listener));
    }
}

View File

@ -21,12 +21,15 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec
public static final String USER_FIELD = "user";
private static final String MODEL_FIELD = "model";
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
private final UnifiedChatInput unifiedChatInput;
private final OpenAiChatCompletionModel model;
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) {
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
this.model = Objects.requireNonNull(model);
}
@ -41,6 +44,10 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec
builder.field(USER_FIELD, model.getTaskSettings().user());
}
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
}
builder.endObject();
return builder;

View File

@ -17,12 +17,15 @@ import java.util.Objects;
public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implements ToXContentObject {
private static final String MODEL_FIELD = "model";
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
private final UnifiedChatInput unifiedChatInput;
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
private final String modelId;
public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) {
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
this.modelId = Objects.requireNonNull(modelId);
}
@ -31,6 +34,11 @@ public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implement
builder.startObject();
unifiedRequestEntity.toXContent(builder, params);
builder.field(MODEL_FIELD, modelId);
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
}
builder.endObject();
return builder;

View File

@ -32,7 +32,6 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
public static final String MESSAGES_FIELD = "messages";
private static final String ROLE_FIELD = "role";
private static final String CONTENT_FIELD = "content";
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
private static final String STOP_FIELD = "stop";
private static final String TEMPERATURE_FIELD = "temperature";
private static final String TOOL_CHOICE_FIELD = "tool_choice";
@ -104,10 +103,6 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
}
builder.endArray();
if (unifiedRequest.maxCompletionTokens() != null) {
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens());
}
// Underlying providers expect OpenAI to only return 1 possible choice.
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);

View File

@ -0,0 +1,200 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.deepseek;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.io.IOException;
import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
/**
 * Design notes:
 * This provider tries to match the OpenAI API, so we'll design around that as well.
 *
 * Task Type:
 * - Chat Completion
 *
 * Service Settings:
 * - api_key
 * - model
 * - url
 *
 * Task Settings:
 * - nothing?
 *
 * Rate Limiting:
 * - The website claims to want unlimited, so we're setting it as MAX_INT per minute?
 */
public class DeepSeekChatCompletionModel extends Model {
    // Per-node rate limit group and settings, limiting the outbound requests this node can make to INTEGER.MAX_VALUE per minute.
    private static final Object RATE_LIMIT_GROUP = new Object();
    private static final RateLimitSettings RATE_LIMIT_SETTINGS = new RateLimitSettings(Integer.MAX_VALUE);
    // Used when no url is configured in the service settings.
    private static final URI DEFAULT_URI = URI.create("https://api.deepseek.com/chat/completions");
    private final DeepSeekServiceSettings serviceSettings;
    private final DefaultSecretSettings secretSettings;

    /** Named writeables required to (de)serialize this model's service settings over the transport layer. */
    public static List<NamedWriteableRegistry.Entry> namedWriteables() {
        return List.of(new NamedWriteableRegistry.Entry(ServiceSettings.class, DeepSeekServiceSettings.NAME, DeepSeekServiceSettings::new));
    }

    /**
     * Creates a model from a user-supplied inference endpoint configuration.
     * Requires {@code model_id} and {@code api_key}; {@code url} is optional and
     * defaults to the DeepSeek chat completions endpoint.
     *
     * @throws ValidationException if any required field is missing or malformed
     */
    public static DeepSeekChatCompletionModel createFromNewInput(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        Map<String, Object> serviceSettingsMap
    ) {
        var validationException = new ValidationException();
        var model = extractRequiredString(serviceSettingsMap, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
        var uri = createOptionalUri(
            extractOptionalString(serviceSettingsMap, URL, ModelConfigurations.SERVICE_SETTINGS, validationException)
        );
        // "api_key" mirrors the key DefaultSecretSettings uses — NOTE(review): consider referencing its constant if one is exposed.
        var secureApiToken = extractRequiredSecureString(
            serviceSettingsMap,
            "api_key",
            ModelConfigurations.SERVICE_SETTINGS,
            validationException
        );
        if (validationException.validationErrors().isEmpty() == false) {
            throw validationException;
        }
        var serviceSettings = new DeepSeekServiceSettings(model, uri);
        var taskSettings = new EmptyTaskSettings();
        var secretSettings = new DefaultSecretSettings(secureApiToken);
        var modelConfigurations = new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
        return new DeepSeekChatCompletionModel(serviceSettings, secretSettings, modelConfigurations, new ModelSecrets(secretSettings));
    }

    /**
     * Recreates a model from persisted configuration and secrets (e.g. read back from the
     * inference index). Unlike {@link #createFromNewInput}, the api key comes from the
     * separate secrets map.
     *
     * @throws ValidationException if any required field is missing or malformed
     */
    public static DeepSeekChatCompletionModel readFromStorage(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        Map<String, Object> serviceSettingsMap,
        Map<String, Object> secrets
    ) {
        var validationException = new ValidationException();
        var model = extractRequiredString(serviceSettingsMap, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
        // Use the shared URL constant for consistency with createFromNewInput and toXContent
        // (previously a bare "url" literal, which risks drifting from the serialized field name).
        var uri = createOptionalUri(
            extractOptionalString(serviceSettingsMap, URL, ModelConfigurations.SERVICE_SETTINGS, validationException)
        );
        if (validationException.validationErrors().isEmpty() == false) {
            throw validationException;
        }
        var serviceSettings = new DeepSeekServiceSettings(model, uri);
        var taskSettings = new EmptyTaskSettings();
        var secretSettings = DefaultSecretSettings.fromMap(secrets);
        var modelConfigurations = new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
        return new DeepSeekChatCompletionModel(serviceSettings, secretSettings, modelConfigurations, new ModelSecrets(secretSettings));
    }

    private DeepSeekChatCompletionModel(
        DeepSeekServiceSettings serviceSettings,
        DefaultSecretSettings secretSettings,
        ModelConfigurations configurations,
        ModelSecrets secrets
    ) {
        super(configurations, secrets);
        this.serviceSettings = serviceSettings;
        this.secretSettings = secretSettings;
    }

    /** The bearer token used to authenticate against the DeepSeek API. */
    public SecureString apiKey() {
        return secretSettings.apiKey();
    }

    /** The DeepSeek model id to request, e.g. the value stored under {@code model_id}. */
    public String model() {
        return serviceSettings.modelId();
    }

    /** The configured endpoint, or the default DeepSeek chat completions URI when none was set. */
    public URI uri() {
        return serviceSettings.uri() != null ? serviceSettings.uri() : DEFAULT_URI;
    }

    /** All DeepSeek endpoints on this node share one rate limit group. */
    public Object rateLimitGroup() {
        return RATE_LIMIT_GROUP;
    }

    public RateLimitSettings rateLimitSettings() {
        return RATE_LIMIT_SETTINGS;
    }

    /**
     * Service settings holding the required model id and an optional endpoint override.
     * Serialized both over the wire (named writeable) and as XContent for storage/display.
     */
    private record DeepSeekServiceSettings(String modelId, URI uri) implements ServiceSettings {
        private static final String NAME = "deep_seek_service_settings";

        DeepSeekServiceSettings {
            Objects.requireNonNull(modelId);
        }

        DeepSeekServiceSettings(StreamInput in) throws IOException {
            this(in.readString(), in.readOptional(url -> URI.create(url.readString())));
        }

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.ML_INFERENCE_DEEPSEEK;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeString(modelId);
            // Matches the optional read in the StreamInput constructor.
            out.writeOptionalString(uri != null ? uri.toString() : null);
        }

        @Override
        public ToXContentObject getFilteredXContentObject() {
            // Nothing here is sensitive, so the filtered view is the full object.
            return this;
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            builder.startObject();
            builder.field(MODEL_ID, modelId);
            if (uri != null) {
                builder.field(URL, uri.toString());
            }
            return builder.endObject();
        }
    }
}

View File

@ -0,0 +1,233 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.deepseek;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.DeepSeekRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
/**
 * Inference service for DeepSeek. Supports the completion and chat-completion task
 * types, both of which are fulfilled by DeepSeek's chat completion API; streaming is
 * available for chat completion only. Chunked inference is not supported.
 */
public class DeepSeekService extends SenderService {

    private static final String NAME = "deepseek";
    private static final String CHAT_COMPLETION_ERROR_PREFIX = "deepseek chat completions";
    private static final String COMPLETION_ERROR_PREFIX = "deepseek completions";
    private static final String SERVICE_NAME = "DeepSeek";
    // The task types exposed via the _inference/_services API
    private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_SERVICES_API = EnumSet.of(
        TaskType.COMPLETION,
        TaskType.CHAT_COMPLETION
    );

    // Lazily built, cached service configuration returned by getConfiguration().
    private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> CONFIGURATION = new LazyInitializable<>(
        DeepSeekService::buildConfiguration
    );

    public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
        super(factory, serviceComponents);
    }

    @Override
    protected void doInfer(
        Model model,
        InferenceInputs inputs,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        executeRequest(model, inputs, timeout, COMPLETION_ERROR_PREFIX, listener);
    }

    @Override
    protected void doUnifiedCompletionInfer(
        Model model,
        UnifiedChatInput inputs,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        executeRequest(model, inputs, timeout, CHAT_COMPLETION_ERROR_PREFIX, listener);
    }

    /**
     * Shared execution path for both task types: builds a request manager for the
     * DeepSeek model and dispatches through a SenderExecutableAction. Fails the
     * listener when the model is not a DeepSeek model.
     */
    private void executeRequest(
        Model model,
        InferenceInputs inputs,
        TimeValue timeout,
        String errorPrefix,
        ActionListener<InferenceServiceResults> listener
    ) {
        if (model instanceof DeepSeekChatCompletionModel deepSeekModel) {
            var manager = new DeepSeekRequestManager(deepSeekModel, getServiceComponents().threadPool());
            var failureMessage = constructFailedToSendRequestMessage(errorPrefix);
            new SenderExecutableAction(getSender(), manager, failureMessage).execute(inputs, timeout, listener);
        } else {
            listener.onFailure(createInvalidModelException(model));
        }
    }

    @Override
    protected void doChunkedInfer(
        Model model,
        DocumentsOnlyInput inputs,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<List<ChunkedInference>> listener
    ) {
        listener.onFailure(new UnsupportedOperationException(Strings.format("The %s service only supports unified completion", NAME)));
    }

    @Override
    public String name() {
        return NAME;
    }

    @Override
    public void parseRequestConfig(
        String modelId,
        TaskType taskType,
        Map<String, Object> config,
        ActionListener<Model> parsedModelListener
    ) {
        ActionListener.completeWith(parsedModelListener, () -> {
            var serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
            try {
                return DeepSeekChatCompletionModel.createFromNewInput(modelId, taskType, NAME, serviceSettingsMap);
            } finally {
                // Reject configs containing settings this service does not recognize.
                throwIfNotEmptyMap(serviceSettingsMap, NAME);
            }
        });
    }

    @Override
    public Model parsePersistedConfigWithSecrets(
        String modelId,
        TaskType taskType,
        Map<String, Object> config,
        Map<String, Object> secrets
    ) {
        var serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
        var secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
        return DeepSeekChatCompletionModel.readFromStorage(modelId, taskType, NAME, serviceSettingsMap, secretSettingsMap);
    }

    @Override
    public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
        // With no separate secrets map, the secret settings are expected inside the config itself.
        return parsePersistedConfigWithSecrets(modelId, taskType, config, config);
    }

    @Override
    public InferenceServiceConfiguration getConfiguration() {
        return CONFIGURATION.getOrCompute();
    }

    @Override
    public EnumSet<TaskType> supportedTaskTypes() {
        return SUPPORTED_TASK_TYPES_FOR_SERVICES_API;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.ML_INFERENCE_DEEPSEEK;
    }

    @Override
    public Set<TaskType> supportedStreamingTasks() {
        return EnumSet.of(TaskType.CHAT_COMPLETION);
    }

    @Override
    public void checkModelConfig(Model model, ActionListener<Model> listener) {
        // TODO: Remove this function once all services have been updated to use the new model validators
        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
    }

    /** Builds the service configuration advertised via the _inference/_services API. */
    private static InferenceServiceConfiguration buildConfiguration() {
        var settings = new HashMap<String, SettingsConfiguration>();
        settings.put(
            MODEL_ID,
            new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDescription(
                "The name of the model to use for the inference task."
            )
                .setLabel("Model ID")
                .setRequired(true)
                .setSensitive(false)
                .setUpdatable(false)
                .setType(SettingsConfigurationFieldType.STRING)
                .build()
        );
        settings.putAll(
            DefaultSecretSettings.toSettingsConfigurationWithDescription(
                "The DeepSeek API authentication key. For more details about generating DeepSeek API keys, "
                    + "refer to https://api-docs.deepseek.com.",
                SUPPORTED_TASK_TYPES_FOR_SERVICES_API
            )
        );
        settings.put(
            URL,
            new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDefaultValue(
                "https://api.deepseek.com/chat/completions"
            )
                .setDescription("The URL endpoint to use for the requests.")
                .setLabel("URL")
                .setRequired(false)
                .setSensitive(false)
                .setUpdatable(false)
                .setType(SettingsConfigurationFieldType.STRING)
                .build()
        );
        return new InferenceServiceConfiguration.Builder().setService(NAME)
            .setName(SERVICE_NAME)
            .setTaskTypes(SUPPORTED_TASK_TYPES_FOR_SERVICES_API)
            .setConfigurations(settings)
            .build();
    }
}

View File

@ -0,0 +1,441 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.deepseek;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
import static org.elasticsearch.action.support.ActionTestUtils.assertNoFailureListener;
import static org.elasticsearch.action.support.ActionTestUtils.assertNoSuccessListener;
import static org.elasticsearch.common.Strings.format;
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.isA;
import static org.mockito.Mockito.mock;
/**
 * Unit tests for the DeepSeek inference service, covering request/persisted configuration
 * parsing, unified (chat) completion, non-streaming and streaming completion, and error
 * handling. A local {@link MockWebServer} stands in for the DeepSeek API endpoint.
 */
public class DeepSeekServiceTests extends ESTestCase {
    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);

    private final MockWebServer webServer = new MockWebServer();
    private ThreadPool threadPool;
    private HttpClientManager clientManager;

    @Before
    public void init() throws Exception {
        webServer.start();
        threadPool = createThreadPool(inferenceUtilityPool());
        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
    }

    @After
    public void shutdown() throws IOException {
        clientManager.close();
        terminate(threadPool);
        webServer.close();
    }

    /** A complete config (api_key, model_id, url) parses into a DeepSeekChatCompletionModel with all fields populated. */
    public void testParseRequestConfig() throws IOException, URISyntaxException {
        parseRequestConfig(format("""
            {
                "service_settings": {
                    "api_key": "12345",
                    "model_id": "some-cool-model",
                    "url": "%s"
                }
            }
            """, webServer.getUri(null).toString()), assertNoFailureListener(model -> {
            if (model instanceof DeepSeekChatCompletionModel deepSeekModel) {
                assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray()));
                assertThat(deepSeekModel.model(), equalTo("some-cool-model"));
                assertThat(deepSeekModel.uri(), equalTo(webServer.getUri(null)));
            } else {
                fail("Expected DeepSeekModel, found " + (model != null ? model.getClass().getSimpleName() : "null"));
            }
        }));
    }

    /** Missing api_key must fail validation with a descriptive message. */
    public void testParseRequestConfigWithoutApiKey() throws IOException {
        parseRequestConfig("""
            {
                "service_settings": {
                    "model_id": "some-cool-model"
                }
            }
            """, assertNoSuccessListener(e -> {
            // Assert the exception type first so an unexpected exception fails the test.
            // Previously the message check was guarded by `if (e instanceof ValidationException)`,
            // which silently passed when a different exception type arrived.
            assertThat(e, isA(ValidationException.class));
            assertThat(
                e.getMessage(),
                equalTo("Validation Failed: 1: [service_settings] does not contain the required setting [api_key];")
            );
        }));
    }

    /** Missing model_id must fail validation with a descriptive message. */
    public void testParseRequestConfigWithoutModel() throws IOException {
        parseRequestConfig("""
            {
                "service_settings": {
                    "api_key": "1234"
                }
            }
            """, assertNoSuccessListener(e -> {
            // Same hardening as testParseRequestConfigWithoutApiKey: type-check before message-check
            // so a non-ValidationException cannot make this test pass vacuously.
            assertThat(e, isA(ValidationException.class));
            assertThat(
                e.getMessage(),
                equalTo("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];")
            );
        }));
    }

    /** Unknown settings are rejected with a message naming the unrecognized keys. */
    public void testParseRequestConfigWithExtraSettings() throws IOException {
        parseRequestConfig(
            """
                {
                    "service_settings": {
                        "api_key": "12345",
                        "model_id": "some-cool-model",
                        "so": "extra"
                    }
                }
                """,
            assertNoSuccessListener(
                e -> assertThat(
                    e.getMessage(),
                    equalTo("Model configuration contains settings [{so=extra}] unknown to the [deepseek] service")
                )
            )
        );
    }

    /** Persisted configs carry the api_key in secret_settings rather than service_settings. */
    public void testParsePersistedConfig() throws IOException {
        var deepSeekModel = parsePersistedConfig("""
            {
                "service_settings": {
                    "model_id": "some-cool-model"
                },
                "secret_settings": {
                    "api_key": "12345"
                }
            }
            """);
        assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray()));
        assertThat(deepSeekModel.model(), equalTo("some-cool-model"));
    }

    /** An optional url in service_settings is parsed into the model's URI. */
    public void testParsePersistedConfigWithUrl() throws IOException {
        var deepSeekModel = parsePersistedConfig("""
            {
                "service_settings": {
                    "model_id": "some-cool-model",
                    "url": "http://localhost:989"
                },
                "secret_settings": {
                    "api_key": "12345"
                }
            }
            """);
        assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray()));
        assertThat(deepSeekModel.model(), equalTo("some-cool-model"));
        assertThat(deepSeekModel.uri(), equalTo(URI.create("http://localhost:989")));
    }

    public void testParsePersistedConfigWithoutApiKey() {
        // NOTE(review): the String argument to assertThrows is only the failure description,
        // not an assertion on the exception message — consider asserting e.getMessage() explicitly.
        assertThrows(
            "Validation Failed: 1: [secret_settings] does not contain the required setting [api_key];",
            ValidationException.class,
            () -> parsePersistedConfig("""
                {
                    "service_settings": {
                        "model_id": "some-cool-model"
                    },
                    "secret_settings": {
                    }
                }
                """)
        );
    }

    public void testParsePersistedConfigWithoutModel() {
        // NOTE(review): the String argument is only a failure description (and mentions [model]
        // where the actual validation likely reports [model_id]) — verify and assert the real message.
        assertThrows(
            "Validation Failed: 1: [service_settings] does not contain the required setting [model];",
            ValidationException.class,
            () -> parsePersistedConfig("""
                {
                    "service_settings": {
                    },
                    "secret_settings": {
                        "api_key": "12345"
                    }
                }
                """)
        );
    }

    public void testParsePersistedConfigWithoutServiceSettings() {
        // NOTE(review): the description string reads like a ValidationException message while the
        // expected type is ElasticsearchStatusException — confirm which exception/message is intended.
        assertThrows(
            "Validation Failed: 1: [service_settings] does not contain the required setting [model];",
            ElasticsearchStatusException.class,
            () -> parsePersistedConfig("""
                {
                    "secret_settings": {
                        "api_key": "12345"
                    }
                }
                """)
        );
    }

    /** Unified (chat_completion) streaming: SSE chunks from the server surface as unified-format events. */
    public void testDoUnifiedInfer() throws Exception {
        webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
            data: {"choices": [{"delta": {"content": "hello, world", "role": "assistant"}, "finish_reason": null, "index": 0, \
            "logprobs": null}], "created": 1718345013, "id": "12345", "model": "deepseek-chat", \
            "object": "chat.completion.chunk", "system_fingerprint": "fp_1234"}

            data: [DONE]

            """));
        doUnifiedCompletionInfer().hasNoErrors().hasEvent("""
            {"id":"12345","choices":[{"delta":{"content":"hello, world","role":"assistant"},"index":0}],""" + """
            "model":"deepseek-chat","object":"chat.completion.chunk"}""");
    }

    /** Non-streaming COMPLETION task: a full chat.completion response maps to ChatCompletionResults. */
    public void testDoInfer() throws Exception {
        webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
            {"choices": [{"message": {"content": "hello, world", "role": "assistant"}, "finish_reason": "stop", "index": 0, \
            "logprobs": null}], "created": 1718345013, "id": "12345", "model": "deepseek-chat", \
            "object": "chat.completion", "system_fingerprint": "fp_1234"}"""));
        try (var service = createService()) {
            var model = createModel(service, TaskType.COMPLETION);
            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
            service.infer(model, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
            var result = listener.actionGet(TIMEOUT);
            assertThat(result, isA(ChatCompletionResults.class));
            var completionResults = (ChatCompletionResults) result;
            assertThat(
                completionResults.results().stream().map(ChatCompletionResults.Result::predictedValue).toList(),
                equalTo(List.of("hello, world"))
            );
        }
    }

    /** Streaming COMPLETION task: SSE chunks surface in the legacy completion event format. */
    public void testDoInferStream() throws Exception {
        webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
            data: {"choices": [{"delta": {"content": "hello, world", "role": "assistant"}, "finish_reason": null, "index": 0, \
            "logprobs": null}], "created": 1718345013, "id": "12345", "model": "deepseek-chat", \
            "object": "chat.completion.chunk", "system_fingerprint": "fp_1234"}

            data: [DONE]

            """));
        try (var service = createService()) {
            var model = createModel(service, TaskType.COMPLETION);
            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
            service.infer(model, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
            InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream().hasNoErrors().hasEvent("""
                {"completion":[{"delta":"hello, world"}]}""");
        }
    }

    /** A non-2xx response to a unified completion becomes a UnifiedChatCompletionException with status and message. */
    public void testUnifiedCompletionError() {
        String responseJson = """
            {
                "error": {
                    "message": "The model `deepseek-not-chat` does not exist or you do not have access to it.",
                    "type": "invalid_request_error",
                    "param": null,
                    "code": "model_not_found"
                }
            }""";
        webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
        var e = assertThrows(UnifiedChatCompletionException.class, this::doUnifiedCompletionInfer);
        assertThat(
            e.getMessage(),
            equalTo(
                "Received an unsuccessful status code for request from inference entity id [inference-id] status"
                    + " [404]. Error message: [The model `deepseek-not-chat` does not exist or you do not have access to it.]"
            )
        );
    }

    /**
     * Shared driver for mid-stream error scenarios: runs a unified completion against whatever the
     * web server has enqueued and asserts the stream finishes with no events and a single error whose
     * XContent rendering matches {@code expectedResponse}.
     */
    private void testStreamError(String expectedResponse) throws Exception {
        try (var service = createService()) {
            var model = createModel(service, TaskType.CHAT_COMPLETION);
            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
            service.unifiedCompletionInfer(
                model,
                UnifiedCompletionRequest.of(
                    List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
                ),
                InferenceAction.Request.DEFAULT_TIMEOUT,
                listener
            );
            var result = listener.actionGet(TIMEOUT);
            InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> {
                e = unwrapCause(e);
                assertThat(e, isA(UnifiedChatCompletionException.class));
                // Render the exception as JSON and compare against the expected wire format.
                try (var builder = XContentFactory.jsonBuilder()) {
                    ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
                        try {
                            xContent.toXContent(builder, EMPTY_PARAMS);
                        } catch (IOException ex) {
                            throw new RuntimeException(ex);
                        }
                    });
                    var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
                    assertThat(json, is(expectedResponse));
                }
            });
        }
    }

    /** An `event: error` frame mid-stream is converted into a structured error payload. */
    public void testMidStreamUnifiedCompletionError() throws Exception {
        String responseJson = """
            event: error
            data: { "error": { "message": "Timed out waiting for more data", "type": "timeout" } }

            """;
        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
        testStreamError("""
            {\
            "error":{\
            "message":"Received an error response for request from inference entity id [inference-id]. Error message: \
            [Timed out waiting for more data]",\
            "type":"timeout"\
            }}""");
    }

    /** Malformed JSON in a stream frame surfaces as an x_content_parse_exception error payload. */
    public void testUnifiedCompletionMalformedError() throws Exception {
        String responseJson = """
            data: { invalid json }

            """;
        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
        testStreamError("""
            {\
            "error":{\
            "code":"bad_request",\
            "message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\
             at [Source: (String)\\"{ invalid json }\\"; line: 1, column: 3]",\
            "type":"x_content_parse_exception"\
            }}""");
    }

    /** Chunked inference (embeddings-style) is unsupported and must always fail. */
    public void testDoChunkedInferAlwaysFails() throws IOException {
        try (var service = createService()) {
            service.doChunkedInfer(mock(), mock(), Map.of(), InputType.UNSPECIFIED, TIMEOUT, assertNoSuccessListener(e -> {
                assertThat(e, isA(UnsupportedOperationException.class));
                assertThat(e.getMessage(), equalTo("The deepseek service only supports unified completion"));
            }));
        }
    }

    private DeepSeekService createService() {
        return new DeepSeekService(
            HttpRequestSenderTests.createSenderFactory(threadPool, clientManager),
            createWithEmptySettings(threadPool)
        );
    }

    private void parseRequestConfig(String json, ActionListener<Model> listener) throws IOException {
        try (var service = createService()) {
            service.parseRequestConfig("inference-id", TaskType.CHAT_COMPLETION, map(json), listener);
        }
    }

    private Map<String, Object> map(String json) throws IOException {
        try (
            var parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, json.getBytes(StandardCharsets.UTF_8))
        ) {
            return parser.map();
        }
    }

    private DeepSeekChatCompletionModel parsePersistedConfig(String json) throws IOException {
        try (var service = createService()) {
            var model = service.parsePersistedConfig("inference-id", TaskType.CHAT_COMPLETION, map(json));
            assertThat(model, isA(DeepSeekChatCompletionModel.class));
            return (DeepSeekChatCompletionModel) model;
        }
    }

    /** Runs a unified completion against the mock server and returns the event assertion helper. */
    private InferenceEventsAssertion doUnifiedCompletionInfer() throws Exception {
        try (var service = createService()) {
            var model = createModel(service, TaskType.CHAT_COMPLETION);
            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
            service.unifiedCompletionInfer(
                model,
                UnifiedCompletionRequest.of(
                    List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
                ),
                TIMEOUT,
                listener
            );
            return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
        }
    }

    /** Builds a model pointed at the local mock server for the given task type. */
    private DeepSeekChatCompletionModel createModel(DeepSeekService service, TaskType taskType) throws URISyntaxException, IOException {
        var model = service.parsePersistedConfig("inference-id", taskType, map(Strings.format("""
            {
                "service_settings": {
                    "model_id": "some-cool-model",
                    "url": "%s"
                },
                "secret_settings": {
                    "api_key": "12345"
                }
            }
            """, webServer.getUri(null).toString())));
        assertThat(model, isA(DeepSeekChatCompletionModel.class));
        return (DeepSeekChatCompletionModel) model;
    }
}