feat: VoyageAI integration (#122134)
* VoyageAI embeddings and rerank: - embeddings works, tested - initial rerank code What's missing: - unit and integration tests - rerank request/response mapping and verification * VoyageAI embeddings and rerank: - embeddings works, tested - rerank works, tested (https://www.elastic.co/search-labs/blog/elasticsearch-cohere-rerank) What's missing: - unit and integration tests * VoyageAI embeddings and rerank: - embeddings works, tested - rerank works, tested (https://www.elastic.co/search-labs/blog/elasticsearch-cohere-rerank) What's missing: - unit and integration tests * VoyageAI embeddings and rerank: - embeddings works, tested - rerank works, tested (https://www.elastic.co/search-labs/blog/elasticsearch-cohere-rerank) What's missing: - unit and integration tests * Adding initial tests Moving dimensions to ServiceSettings * Correcting the TransportVersions.java * Correcting due to comments * Adding BIT support * Initial tests * More tests * More tests/corrections * Removing warnings * Further tests * Transport version correction * Adding changelog and correcting TransportVersions * Spotless tests * Changes due to the comments * Changes due to the comments * Correcting QA tests * Correcting QA tests --------- Co-authored-by: Jonathan Buttner <jonathan.buttner@elastic.co> Co-authored-by: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com>
This commit is contained in:
parent
171a3b93f9
commit
521f8554c3
|
@ -0,0 +1,5 @@
|
|||
pr: 122134
|
||||
summary: Adding integration for VoyageAI embeddings and rerank models
|
||||
area: Machine Learning
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -180,6 +180,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion REMOVE_ALL_APPLICABLE_SELECTOR_BACKPORT_8_19 = def(8_841_0_02);
|
||||
public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE_BACKPORT_8_19 = def(8_841_0_03);
|
||||
public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS_BACKPORT_8_19 = def(8_841_0_04);
|
||||
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05);
|
||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
|
||||
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
|
||||
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
|
||||
|
@ -199,7 +200,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS = def(9_011_0_00);
|
||||
public static final TransportVersion REMOVE_REPOSITORY_CONFLICT_MESSAGE = def(9_012_0_00);
|
||||
public static final TransportVersion RERANKER_FAILURES_ALLOWED = def(9_013_0_00);
|
||||
|
||||
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED = def(9_014_0_00);
|
||||
/*
|
||||
* STOP! READ THIS FIRST! No, really,
|
||||
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
|
||||
|
|
|
@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
@SuppressWarnings("unchecked")
|
||||
public void testGetServicesWithoutTaskType() throws IOException {
|
||||
List<Object> services = getAllServices();
|
||||
assertThat(services.size(), equalTo(19));
|
||||
assertThat(services.size(), equalTo(20));
|
||||
|
||||
String[] providers = new String[services.size()];
|
||||
for (int i = 0; i < services.size(); i++) {
|
||||
|
@ -53,6 +53,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
"test_reranking_service",
|
||||
"test_service",
|
||||
"text_embedding_test_service",
|
||||
"voyageai",
|
||||
"watsonxai"
|
||||
).toArray(),
|
||||
providers
|
||||
|
@ -62,7 +63,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
@SuppressWarnings("unchecked")
|
||||
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
|
||||
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
|
||||
assertThat(services.size(), equalTo(14));
|
||||
assertThat(services.size(), equalTo(15));
|
||||
|
||||
String[] providers = new String[services.size()];
|
||||
for (int i = 0; i < services.size(); i++) {
|
||||
|
@ -85,6 +86,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
"mistral",
|
||||
"openai",
|
||||
"text_embedding_test_service",
|
||||
"voyageai",
|
||||
"watsonxai"
|
||||
).toArray(),
|
||||
providers
|
||||
|
@ -94,7 +96,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
@SuppressWarnings("unchecked")
|
||||
public void testGetServicesWithRerankTaskType() throws IOException {
|
||||
List<Object> services = getServices(TaskType.RERANK);
|
||||
assertThat(services.size(), equalTo(6));
|
||||
assertThat(services.size(), equalTo(7));
|
||||
|
||||
String[] providers = new String[services.size()];
|
||||
for (int i = 0; i < services.size(); i++) {
|
||||
|
@ -103,7 +105,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
}
|
||||
|
||||
assertArrayEquals(
|
||||
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service").toArray(),
|
||||
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai")
|
||||
.toArray(),
|
||||
providers
|
||||
);
|
||||
}
|
||||
|
|
|
@ -90,6 +90,11 @@ import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCo
|
|||
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
@ -142,6 +147,7 @@ public class InferenceNamedWriteablesProvider {
|
|||
addEisNamedWriteables(namedWriteables);
|
||||
addAlibabaCloudSearchNamedWriteables(namedWriteables);
|
||||
addJinaAINamedWriteables(namedWriteables);
|
||||
addVoyageAINamedWriteables(namedWriteables);
|
||||
|
||||
addUnifiedNamedWriteables(namedWriteables);
|
||||
|
||||
|
@ -626,6 +632,28 @@ public class InferenceNamedWriteablesProvider {
|
|||
);
|
||||
}
|
||||
|
||||
private static void addVoyageAINamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIServiceSettings.NAME, VoyageAIServiceSettings::new)
|
||||
);
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(
|
||||
ServiceSettings.class,
|
||||
VoyageAIEmbeddingsServiceSettings.NAME,
|
||||
VoyageAIEmbeddingsServiceSettings::new
|
||||
)
|
||||
);
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIEmbeddingsTaskSettings.NAME, VoyageAIEmbeddingsTaskSettings::new)
|
||||
);
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIRerankServiceSettings.NAME, VoyageAIRerankServiceSettings::new)
|
||||
);
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIRerankTaskSettings.NAME, VoyageAIRerankTaskSettings::new)
|
||||
);
|
||||
}
|
||||
|
||||
private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(
|
||||
|
|
|
@ -128,6 +128,7 @@ import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
|
|||
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
|
||||
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
|
||||
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -359,6 +360,7 @@ public class InferencePlugin extends Plugin
|
|||
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
|
||||
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
|
||||
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
|
||||
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
|
||||
ElasticsearchInternalService::new
|
||||
);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.action.voyageai;
|
||||
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIRerankRequestManager;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
|
||||
|
||||
/**
|
||||
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the voyageai model type.
|
||||
*/
|
||||
public class VoyageAIActionCreator implements VoyageAIActionVisitor {
|
||||
private final Sender sender;
|
||||
private final ServiceComponents serviceComponents;
|
||||
|
||||
public VoyageAIActionCreator(Sender sender, ServiceComponents serviceComponents) {
|
||||
this.sender = Objects.requireNonNull(sender);
|
||||
this.serviceComponents = Objects.requireNonNull(serviceComponents);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
|
||||
var overriddenModel = VoyageAIEmbeddingsModel.of(model, taskSettings, inputType);
|
||||
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI embeddings");
|
||||
var requestCreator = VoyageAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
|
||||
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings) {
|
||||
var overriddenModel = VoyageAIRerankModel.of(model, taskSettings);
|
||||
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI rerank");
|
||||
var requestCreator = VoyageAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
|
||||
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.action.voyageai;
|
||||
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public interface VoyageAIActionVisitor {
|
||||
ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
|
||||
|
||||
ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings);
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.http.sender;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest;
|
||||
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIEmbeddingsResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
public class VoyageAIEmbeddingsRequestManager extends VoyageAIRequestManager {
|
||||
private static final Logger logger = LogManager.getLogger(VoyageAIEmbeddingsRequestManager.class);
|
||||
private static final ResponseHandler HANDLER = createEmbeddingsHandler();
|
||||
|
||||
private static ResponseHandler createEmbeddingsHandler() {
|
||||
return new VoyageAIResponseHandler("voyageai text embedding", VoyageAIEmbeddingsResponseEntity::fromResponse);
|
||||
}
|
||||
|
||||
public static VoyageAIEmbeddingsRequestManager of(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
|
||||
return new VoyageAIEmbeddingsRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
|
||||
}
|
||||
|
||||
private final VoyageAIEmbeddingsModel model;
|
||||
|
||||
private VoyageAIEmbeddingsRequestManager(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
|
||||
super(threadPool, model);
|
||||
this.model = Objects.requireNonNull(model);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(
|
||||
InferenceInputs inferenceInputs,
|
||||
RequestSender requestSender,
|
||||
Supplier<Boolean> hasRequestCompletedFunction,
|
||||
ActionListener<InferenceServiceResults> listener
|
||||
) {
|
||||
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
|
||||
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(docsInput, model);
|
||||
|
||||
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.http.sender;
|
||||
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
abstract class VoyageAIRequestManager extends BaseRequestManager {
|
||||
private static final String DEFAULT_MODEL_FAMILY = "default_model_family";
|
||||
private static final Map<String, String> MODEL_TO_MODEL_FAMILY = Map.of(
|
||||
"voyage-multimodal-3",
|
||||
"embed_multimodal",
|
||||
"voyage-3-large",
|
||||
"embed_large",
|
||||
"voyage-code-3",
|
||||
"embed_large",
|
||||
"voyage-3",
|
||||
"embed_medium",
|
||||
"voyage-3-lite",
|
||||
"embed_small",
|
||||
"voyage-finance-2",
|
||||
"embed_large",
|
||||
"voyage-law-2",
|
||||
"embed_large",
|
||||
"voyage-code-2",
|
||||
"embed_large",
|
||||
"rerank-2",
|
||||
"rerank_large",
|
||||
"rerank-2-lite",
|
||||
"rerank_small"
|
||||
);
|
||||
|
||||
protected VoyageAIRequestManager(ThreadPool threadPool, VoyageAIModel model) {
|
||||
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
|
||||
}
|
||||
|
||||
record RateLimitGrouping(int apiKeyHash) {
|
||||
public static RateLimitGrouping of(VoyageAIModel model) {
|
||||
Objects.requireNonNull(model);
|
||||
String modelId = model.getServiceSettings().modelId();
|
||||
String modelFamily = MODEL_TO_MODEL_FAMILY.getOrDefault(modelId, DEFAULT_MODEL_FAMILY);
|
||||
|
||||
return new RateLimitGrouping(modelFamily.hashCode());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.http.sender;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRerankRequest;
|
||||
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIRerankResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
public class VoyageAIRerankRequestManager extends VoyageAIRequestManager {
|
||||
private static final Logger logger = LogManager.getLogger(VoyageAIRerankRequestManager.class);
|
||||
private static final ResponseHandler HANDLER = createVoyageAIResponseHandler();
|
||||
|
||||
private static ResponseHandler createVoyageAIResponseHandler() {
|
||||
return new VoyageAIResponseHandler("voyageai rerank", (request, response) -> VoyageAIRerankResponseEntity.fromResponse(response));
|
||||
}
|
||||
|
||||
public static VoyageAIRerankRequestManager of(VoyageAIRerankModel model, ThreadPool threadPool) {
|
||||
return new VoyageAIRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
|
||||
}
|
||||
|
||||
private final VoyageAIRerankModel model;
|
||||
|
||||
private VoyageAIRerankRequestManager(VoyageAIRerankModel model, ThreadPool threadPool) {
|
||||
super(threadPool, model);
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(
|
||||
InferenceInputs inferenceInputs,
|
||||
RequestSender requestSender,
|
||||
Supplier<Boolean> hasRequestCompletedFunction,
|
||||
ActionListener<InferenceServiceResults> listener
|
||||
) {
|
||||
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
|
||||
VoyageAIRerankRequest request = new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
|
||||
|
||||
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.request.voyageai;
|
||||
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.apache.http.entity.ByteArrayEntity;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
|
||||
|
||||
import java.net.URI;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public class VoyageAIEmbeddingsRequest extends VoyageAIRequest {
|
||||
|
||||
private final VoyageAIAccount account;
|
||||
private final List<String> input;
|
||||
private final VoyageAIEmbeddingsServiceSettings serviceSettings;
|
||||
private final VoyageAIEmbeddingsTaskSettings taskSettings;
|
||||
private final String model;
|
||||
private final String inferenceEntityId;
|
||||
|
||||
public VoyageAIEmbeddingsRequest(List<String> input, VoyageAIEmbeddingsModel embeddingsModel) {
|
||||
Objects.requireNonNull(embeddingsModel);
|
||||
|
||||
account = VoyageAIAccount.of(embeddingsModel);
|
||||
this.input = Objects.requireNonNull(input);
|
||||
serviceSettings = embeddingsModel.getServiceSettings();
|
||||
taskSettings = embeddingsModel.getTaskSettings();
|
||||
model = embeddingsModel.getServiceSettings().getCommonSettings().modelId();
|
||||
inferenceEntityId = embeddingsModel.getInferenceEntityId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public HttpRequest createHttpRequest() {
|
||||
HttpPost httpPost = new HttpPost(account.uri());
|
||||
|
||||
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
||||
Strings.toString(new VoyageAIEmbeddingsRequestEntity(input, serviceSettings, taskSettings, model))
|
||||
.getBytes(StandardCharsets.UTF_8)
|
||||
);
|
||||
httpPost.setEntity(byteEntity);
|
||||
|
||||
decorateWithHeaders(httpPost, account);
|
||||
|
||||
return new HttpRequest(httpPost, getInferenceEntityId());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getInferenceEntityId() {
|
||||
return inferenceEntityId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public URI getURI() {
|
||||
return account.uri();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Request truncate() {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean[] getTruncationInfo() {
|
||||
return null;
|
||||
}
|
||||
|
||||
public VoyageAIEmbeddingsTaskSettings getTaskSettings() {
|
||||
return taskSettings;
|
||||
}
|
||||
|
||||
public VoyageAIEmbeddingsServiceSettings getServiceSettings() {
|
||||
return serviceSettings;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.request.voyageai;
|
||||
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings.invalidInputTypeMessage;
|
||||
|
||||
public record VoyageAIEmbeddingsRequestEntity(
|
||||
List<String> input,
|
||||
VoyageAIEmbeddingsServiceSettings serviceSettings,
|
||||
VoyageAIEmbeddingsTaskSettings taskSettings,
|
||||
String model
|
||||
) implements ToXContentObject {
|
||||
|
||||
private static final String DOCUMENT = "document";
|
||||
private static final String QUERY = "query";
|
||||
private static final String INPUT_FIELD = "input";
|
||||
private static final String MODEL_FIELD = "model";
|
||||
public static final String INPUT_TYPE_FIELD = "input_type";
|
||||
public static final String TRUNCATION_FIELD = "truncation";
|
||||
public static final String OUTPUT_DIMENSION = "output_dimension";
|
||||
static final String OUTPUT_DTYPE_FIELD = "output_dtype";
|
||||
|
||||
public VoyageAIEmbeddingsRequestEntity {
|
||||
Objects.requireNonNull(input);
|
||||
Objects.requireNonNull(model);
|
||||
Objects.requireNonNull(taskSettings);
|
||||
Objects.requireNonNull(serviceSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(INPUT_FIELD, input);
|
||||
builder.field(MODEL_FIELD, model);
|
||||
|
||||
var inputType = convertToString(taskSettings.getInputType());
|
||||
if (inputType != null) {
|
||||
builder.field(INPUT_TYPE_FIELD, inputType);
|
||||
}
|
||||
|
||||
if (taskSettings.getTruncation() != null) {
|
||||
builder.field(TRUNCATION_FIELD, taskSettings.getTruncation());
|
||||
}
|
||||
|
||||
if (serviceSettings.dimensions() != null) {
|
||||
builder.field(OUTPUT_DIMENSION, serviceSettings.dimensions());
|
||||
}
|
||||
|
||||
if (serviceSettings.getEmbeddingType() != null) {
|
||||
builder.field(OUTPUT_DTYPE_FIELD, serviceSettings.getEmbeddingType().toRequestString());
|
||||
}
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
static String convertToString(InputType inputType) {
|
||||
return switch (inputType) {
|
||||
case null -> null;
|
||||
case INGEST -> DOCUMENT;
|
||||
case SEARCH -> QUERY;
|
||||
default -> {
|
||||
assert false : invalidInputTypeMessage(inputType);
|
||||
yield null;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.request.voyageai;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
|
||||
|
||||
public abstract class VoyageAIRequest implements Request {
|
||||
|
||||
public static void decorateWithHeaders(HttpPost request, VoyageAIAccount account) {
|
||||
request.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
|
||||
request.setHeader(createAuthBearerHeader(account.apiKey()));
|
||||
request.setHeader(VoyageAIUtils.createRequestSourceHeader());
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.request.voyageai;
|
||||
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.apache.http.entity.ByteArrayEntity;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;
|
||||
|
||||
import java.net.URI;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public class VoyageAIRerankRequest extends VoyageAIRequest {
|
||||
|
||||
private final VoyageAIAccount account;
|
||||
private final String query;
|
||||
private final List<String> input;
|
||||
private final VoyageAIRerankTaskSettings taskSettings;
|
||||
private final String model;
|
||||
private final String inferenceEntityId;
|
||||
|
||||
public VoyageAIRerankRequest(String query, List<String> input, VoyageAIRerankModel model) {
|
||||
Objects.requireNonNull(model);
|
||||
|
||||
this.account = VoyageAIAccount.of(model);
|
||||
this.input = Objects.requireNonNull(input);
|
||||
this.query = Objects.requireNonNull(query);
|
||||
taskSettings = model.getTaskSettings();
|
||||
this.model = model.getServiceSettings().modelId();
|
||||
inferenceEntityId = model.getInferenceEntityId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public HttpRequest createHttpRequest() {
|
||||
HttpPost httpPost = new HttpPost(account.uri());
|
||||
|
||||
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
||||
Strings.toString(new VoyageAIRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8)
|
||||
);
|
||||
httpPost.setEntity(byteEntity);
|
||||
|
||||
decorateWithHeaders(httpPost, account);
|
||||
|
||||
return new HttpRequest(httpPost, getInferenceEntityId());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getInferenceEntityId() {
|
||||
return inferenceEntityId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public URI getURI() {
|
||||
return account.uri();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Request truncate() {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean[] getTruncationInfo() {
|
||||
return null;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.request.voyageai;
|
||||
|
||||
import org.elasticsearch.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public record VoyageAIRerankRequestEntity(String model, String query, List<String> documents, VoyageAIRerankTaskSettings taskSettings)
|
||||
implements
|
||||
ToXContentObject {
|
||||
|
||||
private static final String DOCUMENTS_FIELD = "documents";
|
||||
private static final String QUERY_FIELD = "query";
|
||||
private static final String MODEL_FIELD = "model";
|
||||
public static final String TRUNCATION_FIELD = "truncation";
|
||||
public static final String RETURN_DOCUMENTS_FIELD = "return_documents";
|
||||
|
||||
public VoyageAIRerankRequestEntity {
|
||||
Objects.requireNonNull(query);
|
||||
Objects.requireNonNull(documents);
|
||||
Objects.requireNonNull(model);
|
||||
Objects.requireNonNull(taskSettings);
|
||||
}
|
||||
|
||||
public VoyageAIRerankRequestEntity(String query, List<String> input, VoyageAIRerankTaskSettings taskSettings, String model) {
|
||||
this(model, query, input, taskSettings != null ? taskSettings : VoyageAIRerankTaskSettings.EMPTY_SETTINGS);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
|
||||
builder.field(MODEL_FIELD, model);
|
||||
builder.field(QUERY_FIELD, query);
|
||||
builder.field(DOCUMENTS_FIELD, documents);
|
||||
|
||||
if (taskSettings.getDoesReturnDocuments() != null) {
|
||||
builder.field(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments());
|
||||
}
|
||||
|
||||
if (taskSettings.getTopKDocumentsOnly() != null) {
|
||||
builder.field(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, taskSettings.getTopKDocumentsOnly());
|
||||
}
|
||||
|
||||
if (taskSettings.getTruncation() != null) {
|
||||
builder.field(TRUNCATION_FIELD, taskSettings.getTruncation());
|
||||
}
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.request.voyageai;
|
||||
|
||||
import org.apache.http.Header;
|
||||
import org.apache.http.message.BasicHeader;
|
||||
|
||||
public class VoyageAIUtils {
|
||||
public static final String HOST = "api.voyageai.com";
|
||||
public static final String VERSION_1 = "v1";
|
||||
public static final String EMBEDDINGS_PATH = "embeddings";
|
||||
public static final String RERANK_PATH = "rerank";
|
||||
public static final String REQUEST_SOURCE_HEADER = "Request-Source";
|
||||
public static final String ELASTIC_REQUEST_SOURCE = "unspecified:elasticsearch";
|
||||
|
||||
public static Header createRequestSourceHeader() {
|
||||
return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE);
|
||||
}
|
||||
|
||||
private VoyageAIUtils() {}
|
||||
}
|
|
@ -0,0 +1,197 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*
|
||||
* this file was contributed to by a generative AI
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.response.voyageai;
|
||||
|
||||
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.xcontent.ParseField;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentParser;
|
||||
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType.toLowerCase;
|
||||
|
||||
public class VoyageAIEmbeddingsResponseEntity {
|
||||
private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes();
|
||||
|
||||
private static String supportedEmbeddingTypes() {
|
||||
String[] validTypes = new String[] {
|
||||
toLowerCase(VoyageAIEmbeddingType.FLOAT),
|
||||
toLowerCase(VoyageAIEmbeddingType.INT8),
|
||||
toLowerCase(VoyageAIEmbeddingType.BIT) };
|
||||
Arrays.sort(validTypes);
|
||||
return String.join(", ", validTypes);
|
||||
}
|
||||
|
||||
record EmbeddingInt8Result(List<EmbeddingInt8ResultEntry> entries) {
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<EmbeddingInt8Result, Void> PARSER = new ConstructingObjectParser<>(
|
||||
EmbeddingInt8Result.class.getSimpleName(),
|
||||
true,
|
||||
args -> new EmbeddingInt8Result((List<EmbeddingInt8ResultEntry>) args[0])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareObjectArray(constructorArg(), EmbeddingInt8ResultEntry.PARSER::apply, new ParseField("data"));
|
||||
}
|
||||
}
|
||||
|
||||
record EmbeddingInt8ResultEntry(Integer index, List<Integer> embedding) {
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<EmbeddingInt8ResultEntry, Void> PARSER = new ConstructingObjectParser<>(
|
||||
EmbeddingInt8ResultEntry.class.getSimpleName(),
|
||||
true,
|
||||
args -> new EmbeddingInt8ResultEntry((Integer) args[0], (List<Integer>) args[1])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareInt(constructorArg(), new ParseField("index"));
|
||||
PARSER.declareIntArray(constructorArg(), new ParseField("embedding"));
|
||||
}
|
||||
|
||||
private static void checkByteBounds(Integer value) {
|
||||
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
|
||||
throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte");
|
||||
}
|
||||
}
|
||||
|
||||
public InferenceByteEmbedding toInferenceByteEmbedding() {
|
||||
embedding.forEach(EmbeddingInt8ResultEntry::checkByteBounds);
|
||||
return InferenceByteEmbedding.of(embedding.stream().map(Integer::byteValue).toList());
|
||||
}
|
||||
}
|
||||
|
||||
record EmbeddingFloatResult(List<EmbeddingFloatResultEntry> entries) {
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<EmbeddingFloatResult, Void> PARSER = new ConstructingObjectParser<>(
|
||||
EmbeddingFloatResult.class.getSimpleName(),
|
||||
true,
|
||||
args -> new EmbeddingFloatResult((List<EmbeddingFloatResultEntry>) args[0])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareObjectArray(constructorArg(), EmbeddingFloatResultEntry.PARSER::apply, new ParseField("data"));
|
||||
}
|
||||
}
|
||||
|
||||
record EmbeddingFloatResultEntry(Integer index, List<Float> embedding) {
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<EmbeddingFloatResultEntry, Void> PARSER = new ConstructingObjectParser<>(
|
||||
EmbeddingFloatResultEntry.class.getSimpleName(),
|
||||
true,
|
||||
args -> new EmbeddingFloatResultEntry((Integer) args[0], (List<Float>) args[1])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareInt(constructorArg(), new ParseField("index"));
|
||||
PARSER.declareFloatArray(constructorArg(), new ParseField("embedding"));
|
||||
}
|
||||
|
||||
public InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding toInferenceFloatEmbedding() {
|
||||
return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embedding);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses the VoyageAI json response.
|
||||
* For a request like:
|
||||
*
|
||||
* <pre>
|
||||
* <code>
|
||||
* {
|
||||
* "input": [
|
||||
* "Sample text 1",
|
||||
* "Sample text 2"
|
||||
* ],
|
||||
* "model": "voyage-3-large"
|
||||
* }
|
||||
* </code>
|
||||
* </pre>
|
||||
*
|
||||
* The response would look like:
|
||||
*
|
||||
* <pre>
|
||||
* <code>
|
||||
* {
|
||||
* "object": "list",
|
||||
* "data": [
|
||||
* {
|
||||
* "object": "embedding",
|
||||
* "embedding": [
|
||||
* -0.009327292,
|
||||
* -0.0028842222,
|
||||
* ],
|
||||
* "index": 0
|
||||
* },
|
||||
* {
|
||||
* "object": "embedding",
|
||||
* "embedding": [ ... ],
|
||||
* "index": 1
|
||||
* }
|
||||
* ],
|
||||
* "model": "voyage-3-large",
|
||||
* "usage": {
|
||||
* "total_tokens": 10
|
||||
* }
|
||||
* }
|
||||
* </code>
|
||||
* </pre>
|
||||
*/
|
||||
public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
|
||||
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
|
||||
VoyageAIEmbeddingType embeddingType = ((VoyageAIEmbeddingsRequest) request).getServiceSettings().getEmbeddingType();
|
||||
|
||||
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
|
||||
if (embeddingType == null || embeddingType == VoyageAIEmbeddingType.FLOAT) {
|
||||
var embeddingResult = EmbeddingFloatResult.PARSER.apply(jsonParser, null);
|
||||
|
||||
List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding> embeddingList = embeddingResult.entries.stream()
|
||||
.map(EmbeddingFloatResultEntry::toInferenceFloatEmbedding)
|
||||
.toList();
|
||||
return new InferenceTextEmbeddingFloatResults(embeddingList);
|
||||
} else if (embeddingType == VoyageAIEmbeddingType.INT8) {
|
||||
var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null);
|
||||
List<InferenceByteEmbedding> embeddingList = embeddingResult.entries.stream()
|
||||
.map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding)
|
||||
.toList();
|
||||
return new InferenceTextEmbeddingByteResults(embeddingList);
|
||||
} else if (embeddingType == VoyageAIEmbeddingType.BIT || embeddingType == VoyageAIEmbeddingType.BINARY) {
|
||||
var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null);
|
||||
List<InferenceByteEmbedding> embeddingList = embeddingResult.entries.stream()
|
||||
.map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding)
|
||||
.toList();
|
||||
return new InferenceTextEmbeddingBitResults(embeddingList);
|
||||
} else {
|
||||
throw new IllegalArgumentException(
|
||||
"Illegal embedding_type value: " + embeddingType + ". Supported types are: " + VALID_EMBEDDING_TYPES_STRING
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private VoyageAIEmbeddingsResponseEntity() {}
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.response.voyageai;
|
||||
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentParser;
|
||||
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
|
||||
|
||||
public class VoyageAIErrorResponseEntity extends ErrorResponse {
|
||||
|
||||
private VoyageAIErrorResponseEntity(String errorMessage) {
|
||||
super(errorMessage);
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse an HTTP response into a VoyageAIErrorResponseEntity
|
||||
*
|
||||
* @param response The error response
|
||||
* @return An error entity if the response is JSON with a `detail` field containing the error message
|
||||
* or null if the response does not contain the message field
|
||||
*/
|
||||
public static ErrorResponse fromResponse(HttpResult response) {
|
||||
try (
|
||||
XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
|
||||
.createParser(XContentParserConfiguration.EMPTY, response.body())
|
||||
) {
|
||||
var responseMap = jsonParser.map();
|
||||
var message = (String) responseMap.get("detail");
|
||||
if (message != null) {
|
||||
return new VoyageAIErrorResponseEntity(message);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
// swallow the error
|
||||
}
|
||||
|
||||
return ErrorResponse.UNDEFINED_ERROR;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*
|
||||
* this file was contributed to by a generative AI
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.response.voyageai;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.xcontent.ParseField;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentParser;
|
||||
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class VoyageAIRerankResponseEntity {
|
||||
|
||||
private static final Logger logger = LogManager.getLogger(VoyageAIRerankResponseEntity.class);
|
||||
|
||||
record RerankResult(List<RerankResultEntry> entries) {
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<RerankResult, Void> PARSER = new ConstructingObjectParser<>(
|
||||
RerankResult.class.getSimpleName(),
|
||||
true,
|
||||
args -> new RerankResult((List<RerankResultEntry>) args[0])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("data"));
|
||||
}
|
||||
}
|
||||
|
||||
record RerankResultEntry(Float relevanceScore, Integer index, @Nullable String document) {
|
||||
|
||||
public static final ConstructingObjectParser<RerankResultEntry, Void> PARSER = new ConstructingObjectParser<>(
|
||||
RerankResultEntry.class.getSimpleName(),
|
||||
args -> new RerankResultEntry((Float) args[0], (Integer) args[1], (String) args[2])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
|
||||
PARSER.declareInt(constructorArg(), new ParseField("index"));
|
||||
PARSER.declareString(optionalConstructorArg(), new ParseField("document"));
|
||||
}
|
||||
|
||||
public RankedDocsResults.RankedDoc toRankedDoc() {
|
||||
return new RankedDocsResults.RankedDoc(index, relevanceScore, document);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses the VoyageAI ranked response.
|
||||
* For a request like:
|
||||
* "model": "rerank-2",
|
||||
* "query": "What is the capital of the United States?",
|
||||
* "top_k": 2,
|
||||
* "documents": ["Carson City is the capital city of the American state of Nevada.",
|
||||
* "The Commonwealth of the Northern Mariana ... Its capital is Saipan.",
|
||||
* "Washington, D.C. (also known as simply Washington or D.C., ... It is a federal district.",
|
||||
* "Capital punishment (the death penalty) ... As of 2017, capital punishment is legal in 30 of the 50 states."]
|
||||
* <p>
|
||||
* The response will look like (without whitespace):
|
||||
* {
|
||||
* "object": "list",
|
||||
* "data": [
|
||||
* {
|
||||
* "relevance_score": 0.4375,
|
||||
* "index": 0
|
||||
* },
|
||||
* {
|
||||
* "relevance_score": 0.421875,
|
||||
* "index": 1
|
||||
* }
|
||||
* ],
|
||||
* "model": "rerank-2",
|
||||
* "usage": {
|
||||
* "total_tokens": 26
|
||||
* }
|
||||
* }
|
||||
* @param response the http response from VoyageAI
|
||||
* @return the parsed response
|
||||
* @throws IOException if there is an error parsing the response
|
||||
*/
|
||||
public static InferenceServiceResults fromResponse(HttpResult response) throws IOException {
|
||||
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
|
||||
|
||||
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
|
||||
var rerankResult = RerankResult.PARSER.apply(jsonParser, null);
|
||||
|
||||
return new RankedDocsResults(rerankResult.entries.stream().map(RerankResultEntry::toRankedDoc).toList());
|
||||
}
|
||||
}
|
||||
|
||||
private VoyageAIRerankResponseEntity() {}
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.voyageai;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.common.settings.SecureString;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;
|
||||
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.Objects;
|
||||
|
||||
public record VoyageAIAccount(URI uri, SecureString apiKey) {
|
||||
|
||||
public static VoyageAIAccount of(VoyageAIModel model) {
|
||||
try {
|
||||
var uri = model.buildUri();
|
||||
return new VoyageAIAccount(uri, model.apiKey());
|
||||
} catch (URISyntaxException e) {
|
||||
// using bad request here so that potentially sensitive URL information does not get logged
|
||||
throw new ElasticsearchStatusException("Failed to construct VoyageAI URL", RestStatus.BAD_REQUEST, e);
|
||||
}
|
||||
}
|
||||
|
||||
public VoyageAIAccount {
|
||||
Objects.requireNonNull(uri);
|
||||
Objects.requireNonNull(apiKey);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.voyageai;
|
||||
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIErrorResponseEntity;
|
||||
|
||||
/**
|
||||
* Defines how to handle various errors returned from the VoyageAI integration.
|
||||
*
|
||||
*/
|
||||
public class VoyageAIResponseHandler extends BaseResponseHandler {
|
||||
static final String VALIDATION_ERROR_MESSAGE = "Received an input validation error response";
|
||||
static final String PAYMENT_ERROR_MESSAGE = "Payment required";
|
||||
|
||||
public VoyageAIResponseHandler(String requestType, ResponseParser parseFunction) {
|
||||
super(requestType, parseFunction, VoyageAIErrorResponseEntity::fromResponse);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the status code throws an RetryException if not in the range [200, 300).
|
||||
*
|
||||
* @param request The http request
|
||||
* @param result The http response and body
|
||||
* @throws RetryException Throws if status code is {@code >= 300 or < 200 }
|
||||
*/
|
||||
@Override
|
||||
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
|
||||
if (result.isSuccessfulResponse()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// handle error codes
|
||||
int statusCode = result.response().getStatusLine().getStatusCode();
|
||||
if (statusCode == 500) {
|
||||
throw new RetryException(true, buildError(SERVER_ERROR, request, result));
|
||||
} else if (statusCode > 500) {
|
||||
throw new RetryException(false, buildError(SERVER_ERROR, request, result));
|
||||
} else if (statusCode == 429) {
|
||||
throw new RetryException(true, buildError(RATE_LIMIT, request, result));
|
||||
} else if (statusCode == 400 || statusCode == 422) {
|
||||
throw new RetryException(false, buildError(VALIDATION_ERROR_MESSAGE, request, result));
|
||||
} else if (statusCode == 401) {
|
||||
throw new RetryException(false, buildError(AUTHENTICATION, request, result));
|
||||
} else if (statusCode == 402) {
|
||||
throw new RetryException(false, buildError(PAYMENT_ERROR_MESSAGE, request, result));
|
||||
} else if (statusCode >= 300 && statusCode < 400) {
|
||||
throw new RetryException(false, buildError(REDIRECTION, request, result));
|
||||
} else {
|
||||
throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,94 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai;
|
||||
|
||||
import org.elasticsearch.common.settings.SecureString;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.ServiceSettings;
|
||||
import org.elasticsearch.inference.TaskSettings;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionVisitor;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
||||
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
|
||||
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public abstract class VoyageAIModel extends Model {
|
||||
private final SecureString apiKey;
|
||||
private final VoyageAIRateLimitServiceSettings rateLimitServiceSettings;
|
||||
protected final URI uri;
|
||||
|
||||
public VoyageAIModel(
|
||||
ModelConfigurations configurations,
|
||||
ModelSecrets secrets,
|
||||
@Nullable ApiKeySecrets apiKeySecrets,
|
||||
VoyageAIRateLimitServiceSettings rateLimitServiceSettings
|
||||
) {
|
||||
this(configurations, secrets, apiKeySecrets, rateLimitServiceSettings, null);
|
||||
}
|
||||
|
||||
public VoyageAIModel(
|
||||
ModelConfigurations configurations,
|
||||
ModelSecrets secrets,
|
||||
@Nullable ApiKeySecrets apiKeySecrets,
|
||||
VoyageAIRateLimitServiceSettings rateLimitServiceSettings,
|
||||
String url
|
||||
) {
|
||||
super(configurations, secrets);
|
||||
|
||||
this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings);
|
||||
this.apiKey = ServiceUtils.apiKey(apiKeySecrets);
|
||||
this.uri = url == null ? null : URI.create(url);
|
||||
}
|
||||
|
||||
protected VoyageAIModel(VoyageAIModel model, TaskSettings taskSettings) {
|
||||
super(model, taskSettings);
|
||||
|
||||
this.rateLimitServiceSettings = model.rateLimitServiceSettings();
|
||||
this.apiKey = model.apiKey();
|
||||
this.uri = model.uri;
|
||||
}
|
||||
|
||||
protected VoyageAIModel(VoyageAIModel model, ServiceSettings serviceSettings) {
|
||||
super(model, serviceSettings);
|
||||
|
||||
this.rateLimitServiceSettings = model.rateLimitServiceSettings();
|
||||
this.apiKey = model.apiKey();
|
||||
this.uri = model.uri;
|
||||
}
|
||||
|
||||
public SecureString apiKey() {
|
||||
return apiKey;
|
||||
}
|
||||
|
||||
public VoyageAIRateLimitServiceSettings rateLimitServiceSettings() {
|
||||
return rateLimitServiceSettings;
|
||||
}
|
||||
|
||||
public abstract ExecutableAction accept(VoyageAIActionVisitor creator, Map<String, Object> taskSettings, InputType inputType);
|
||||
|
||||
public URI uri() {
|
||||
return uri;
|
||||
}
|
||||
|
||||
public URI buildUri() throws URISyntaxException {
|
||||
if (uri == null) {
|
||||
return buildRequestUri();
|
||||
}
|
||||
return uri;
|
||||
}
|
||||
|
||||
protected abstract URI buildRequestUri() throws URISyntaxException;
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai;
|
||||
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
|
||||
public interface VoyageAIRateLimitServiceSettings {
|
||||
RateLimitSettings rateLimitSettings();
|
||||
|
||||
}
|
|
@ -0,0 +1,397 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.common.util.LazyInitializable;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.ChunkedInference;
|
||||
import org.elasticsearch.inference.ChunkingSettings;
|
||||
import org.elasticsearch.inference.InferenceServiceConfiguration;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.SettingsConfiguration;
|
||||
import org.elasticsearch.inference.SimilarityMeasure;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
|
||||
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
|
||||
import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionCreator;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.SenderService;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
|
||||
|
||||
import java.util.EnumSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
|
||||
|
||||
public class VoyageAIService extends SenderService {
|
||||
public static final String NAME = "voyageai";
|
||||
|
||||
private static final String SERVICE_NAME = "Voyage AI";
|
||||
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK);
|
||||
|
||||
private static final Integer DEFAULT_BATCH_SIZE = 7;
|
||||
private static final Map<String, Integer> MODEL_BATCH_SIZES = Map.of(
|
||||
"voyage-multimodal-3",
|
||||
7,
|
||||
"voyage-3-large",
|
||||
7,
|
||||
"voyage-code-3",
|
||||
7,
|
||||
"voyage-3",
|
||||
10,
|
||||
"voyage-3-lite",
|
||||
30,
|
||||
"voyage-finance-2",
|
||||
7,
|
||||
"voyage-law-2",
|
||||
7,
|
||||
"voyage-code-2",
|
||||
7,
|
||||
"voyage-2",
|
||||
72,
|
||||
"voyage-02",
|
||||
72
|
||||
);
|
||||
|
||||
public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
|
||||
super(factory, serviceComponents);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String name() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void parseRequestConfig(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
Map<String, Object> config,
|
||||
ActionListener<Model> parsedModelListener
|
||||
) {
|
||||
try {
|
||||
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
|
||||
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
|
||||
|
||||
ChunkingSettings chunkingSettings = null;
|
||||
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
|
||||
chunkingSettings = ChunkingSettingsBuilder.fromMap(
|
||||
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
|
||||
);
|
||||
}
|
||||
VoyageAIModel model = createModel(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
serviceSettingsMap,
|
||||
taskSettingsMap,
|
||||
chunkingSettings,
|
||||
serviceSettingsMap,
|
||||
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
|
||||
ConfigurationParseContext.REQUEST
|
||||
);
|
||||
|
||||
throwIfNotEmptyMap(config, NAME);
|
||||
throwIfNotEmptyMap(serviceSettingsMap, NAME);
|
||||
throwIfNotEmptyMap(taskSettingsMap, NAME);
|
||||
|
||||
parsedModelListener.onResponse(model);
|
||||
} catch (Exception e) {
|
||||
parsedModelListener.onFailure(e);
|
||||
}
|
||||
}
|
||||
|
||||
private static VoyageAIModel createModelFromPersistent(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
Map<String, Object> serviceSettings,
|
||||
Map<String, Object> taskSettings,
|
||||
ChunkingSettings chunkingSettings,
|
||||
@Nullable Map<String, Object> secretSettings,
|
||||
String failureMessage
|
||||
) {
|
||||
return createModel(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
serviceSettings,
|
||||
taskSettings,
|
||||
chunkingSettings,
|
||||
secretSettings,
|
||||
failureMessage,
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
);
|
||||
}
|
||||
|
||||
private static VoyageAIModel createModel(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
Map<String, Object> serviceSettings,
|
||||
Map<String, Object> taskSettings,
|
||||
ChunkingSettings chunkingSettings,
|
||||
@Nullable Map<String, Object> secretSettings,
|
||||
String failureMessage,
|
||||
ConfigurationParseContext context
|
||||
) {
|
||||
return switch (taskType) {
|
||||
case TEXT_EMBEDDING -> new VoyageAIEmbeddingsModel(
|
||||
inferenceEntityId,
|
||||
NAME,
|
||||
serviceSettings,
|
||||
taskSettings,
|
||||
chunkingSettings,
|
||||
secretSettings,
|
||||
context
|
||||
);
|
||||
case RERANK -> new VoyageAIRerankModel(inferenceEntityId, NAME, serviceSettings, taskSettings, secretSettings, context);
|
||||
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public VoyageAIModel parsePersistedConfigWithSecrets(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
Map<String, Object> config,
|
||||
Map<String, Object> secrets
|
||||
) {
|
||||
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
|
||||
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
|
||||
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
|
||||
|
||||
ChunkingSettings chunkingSettings = null;
|
||||
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
|
||||
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
|
||||
}
|
||||
|
||||
return createModelFromPersistent(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
serviceSettingsMap,
|
||||
taskSettingsMap,
|
||||
chunkingSettings,
|
||||
secretSettingsMap,
|
||||
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public VoyageAIModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
|
||||
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
|
||||
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
|
||||
|
||||
ChunkingSettings chunkingSettings = null;
|
||||
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
|
||||
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
|
||||
}
|
||||
|
||||
return createModelFromPersistent(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
serviceSettingsMap,
|
||||
taskSettingsMap,
|
||||
chunkingSettings,
|
||||
null,
|
||||
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public InferenceServiceConfiguration getConfiguration() {
|
||||
return Configuration.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EnumSet<TaskType> supportedTaskTypes() {
|
||||
return supportedTaskTypes;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doUnifiedCompletionInfer(
|
||||
Model model,
|
||||
UnifiedChatInput inputs,
|
||||
TimeValue timeout,
|
||||
ActionListener<InferenceServiceResults> listener
|
||||
) {
|
||||
throwUnsupportedUnifiedCompletionOperation(NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void doInfer(
|
||||
Model model,
|
||||
InferenceInputs inputs,
|
||||
Map<String, Object> taskSettings,
|
||||
InputType inputType,
|
||||
TimeValue timeout,
|
||||
ActionListener<InferenceServiceResults> listener
|
||||
) {
|
||||
if (model instanceof VoyageAIModel == false) {
|
||||
listener.onFailure(createInvalidModelException(model));
|
||||
return;
|
||||
}
|
||||
|
||||
VoyageAIModel voyageaiModel = (VoyageAIModel) model;
|
||||
var actionCreator = new VoyageAIActionCreator(getSender(), getServiceComponents());
|
||||
|
||||
var action = voyageaiModel.accept(actionCreator, taskSettings, inputType);
|
||||
action.execute(inputs, timeout, listener);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doChunkedInfer(
|
||||
Model model,
|
||||
DocumentsOnlyInput inputs,
|
||||
Map<String, Object> taskSettings,
|
||||
InputType inputType,
|
||||
TimeValue timeout,
|
||||
ActionListener<List<ChunkedInference>> listener
|
||||
) {
|
||||
if (model instanceof VoyageAIModel == false) {
|
||||
listener.onFailure(createInvalidModelException(model));
|
||||
return;
|
||||
}
|
||||
|
||||
VoyageAIModel voyageaiModel = (VoyageAIModel) model;
|
||||
var actionCreator = new VoyageAIActionCreator(getSender(), getServiceComponents());
|
||||
|
||||
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
|
||||
inputs.getInputs(),
|
||||
getBatchSize(voyageaiModel),
|
||||
EmbeddingRequestChunker.EmbeddingType.fromDenseVectorElementType(model.getServiceSettings().elementType()),
|
||||
voyageaiModel.getConfigurations().getChunkingSettings()
|
||||
).batchRequestsWithListeners(listener);
|
||||
|
||||
for (var request : batchedRequests) {
|
||||
var action = voyageaiModel.accept(actionCreator, taskSettings, inputType);
|
||||
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
|
||||
}
|
||||
}
|
||||
|
||||
private static int getBatchSize(VoyageAIModel model) {
|
||||
return MODEL_BATCH_SIZES.getOrDefault(model.getServiceSettings().modelId(), DEFAULT_BATCH_SIZE);
|
||||
}
|
||||
|
||||
/**
|
||||
* For text embedding models get the embedding size and
|
||||
* update the service settings.
|
||||
*
|
||||
* @param model The new model
|
||||
* @param listener The listener
|
||||
*/
|
||||
@Override
|
||||
public void checkModelConfig(Model model, ActionListener<Model> listener) {
|
||||
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
|
||||
if (model instanceof VoyageAIEmbeddingsModel embeddingsModel) {
|
||||
var serviceSettings = embeddingsModel.getServiceSettings();
|
||||
var similarityFromModel = serviceSettings.similarity();
|
||||
var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel;
|
||||
var maxInputTokens = serviceSettings.maxInputTokens();
|
||||
var dimensionSetByUser = serviceSettings.dimensionsSetByUser();
|
||||
|
||||
var updatedServiceSettings = new VoyageAIEmbeddingsServiceSettings(
|
||||
new VoyageAIServiceSettings(
|
||||
serviceSettings.getCommonSettings().modelId(),
|
||||
serviceSettings.getCommonSettings().rateLimitSettings()
|
||||
),
|
||||
serviceSettings.getEmbeddingType(),
|
||||
similarityToUse,
|
||||
embeddingSize,
|
||||
maxInputTokens,
|
||||
dimensionSetByUser
|
||||
);
|
||||
|
||||
return new VoyageAIEmbeddingsModel(embeddingsModel, updatedServiceSettings);
|
||||
} else {
|
||||
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the default similarity measure for the embedding type.
|
||||
* VoyageAI embeddings are normalized to unit vectors therefore Dot
|
||||
* Product similarity can be used and is the default for all VoyageAI
|
||||
* models.
|
||||
*
|
||||
* @return The default similarity.
|
||||
*/
|
||||
static SimilarityMeasure defaultSimilarity() {
|
||||
return SimilarityMeasure.DOT_PRODUCT;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED;
|
||||
}
|
||||
|
||||
public static class Configuration {
|
||||
public static InferenceServiceConfiguration get() {
|
||||
return configuration.getOrCompute();
|
||||
}
|
||||
|
||||
private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
|
||||
() -> {
|
||||
var configurationMap = new HashMap<String, SettingsConfiguration>();
|
||||
|
||||
configurationMap.put(
|
||||
MODEL_ID,
|
||||
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||
"The name of the model to use for the inference task."
|
||||
)
|
||||
.setLabel("Model ID")
|
||||
.setRequired(true)
|
||||
.setSensitive(false)
|
||||
.setUpdatable(false)
|
||||
.setType(SettingsConfigurationFieldType.STRING)
|
||||
.build()
|
||||
);
|
||||
|
||||
configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes));
|
||||
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes));
|
||||
|
||||
return new InferenceServiceConfiguration.Builder().setService(NAME)
|
||||
.setName(SERVICE_NAME)
|
||||
.setTaskTypes(supportedTaskTypes)
|
||||
.setConfigurations(configurationMap)
|
||||
.build();
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai;
|
||||
|
||||
public class VoyageAIServiceFields {
|
||||
public static final String TRUNCATION = "truncation";
|
||||
}
|
|
@ -0,0 +1,132 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ServiceSettings;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
|
||||
|
||||
public class VoyageAIServiceSettings extends FilteredXContentObject implements ServiceSettings, VoyageAIRateLimitServiceSettings {
|
||||
|
||||
public static final String NAME = "voyageai_service_settings";
|
||||
public static final String MODEL_ID = "model_id";
|
||||
private static final Logger logger = LogManager.getLogger(VoyageAIServiceSettings.class);
|
||||
// See https://docs.voyageai.com/docs/rate-limits
|
||||
public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(2_000);
|
||||
|
||||
public static VoyageAIServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
|
||||
ValidationException validationException = new ValidationException();
|
||||
|
||||
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
|
||||
map,
|
||||
DEFAULT_RATE_LIMIT_SETTINGS,
|
||||
validationException,
|
||||
VoyageAIService.NAME,
|
||||
context
|
||||
);
|
||||
|
||||
String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
|
||||
return new VoyageAIServiceSettings(modelId, rateLimitSettings);
|
||||
}
|
||||
|
||||
private final String modelId;
|
||||
private final RateLimitSettings rateLimitSettings;
|
||||
|
||||
public VoyageAIServiceSettings(String modelId, @Nullable RateLimitSettings rateLimitSettings) {
|
||||
this.modelId = Objects.requireNonNull(modelId);
|
||||
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
|
||||
}
|
||||
|
||||
public VoyageAIServiceSettings(StreamInput in) throws IOException {
|
||||
modelId = in.readString();
|
||||
rateLimitSettings = new RateLimitSettings(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RateLimitSettings rateLimitSettings() {
|
||||
return rateLimitSettings;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String modelId() {
|
||||
return modelId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
|
||||
toXContentFragment(builder, params);
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException {
|
||||
return toXContentFragmentOfExposedFields(builder, params);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.field(MODEL_ID, modelId);
|
||||
rateLimitSettings.toXContent(builder, params);
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(modelId);
|
||||
rateLimitSettings.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
VoyageAIServiceSettings that = (VoyageAIServiceSettings) o;
|
||||
return Objects.equals(modelId, that.modelId) && Objects.equals(rateLimitSettings, that.rateLimitSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(modelId, rateLimitSettings);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,114 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai.embeddings;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.EnumSet;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Defines the type of embedding that the VoyageAI api should return for a request.
|
||||
*
|
||||
* <p>
|
||||
* <a href="https://docs.voyageai.com/reference/embeddings-api">See api docs for details.</a>
|
||||
* </p>
|
||||
*/
|
||||
public enum VoyageAIEmbeddingType {
|
||||
/**
|
||||
* Use this when you want to get back the default float embeddings. Valid for all models.
|
||||
*/
|
||||
FLOAT(DenseVectorFieldMapper.ElementType.FLOAT, RequestConstants.FLOAT),
|
||||
/**
|
||||
* Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
|
||||
*/
|
||||
INT8(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8),
|
||||
/**
|
||||
* This is a synonym for INT8
|
||||
*/
|
||||
BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8),
|
||||
/**
|
||||
* Use this when you want to get back binary embeddings. Valid only for v3 models.
|
||||
*/
|
||||
BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BINARY),
|
||||
/**
|
||||
* This is a synonym for BIT
|
||||
*/
|
||||
BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BINARY);
|
||||
|
||||
private static final class RequestConstants {
|
||||
private static final String FLOAT = "float";
|
||||
private static final String INT8 = "int8";
|
||||
private static final String BINARY = "binary";
|
||||
}
|
||||
|
||||
private static final Map<DenseVectorFieldMapper.ElementType, VoyageAIEmbeddingType> ELEMENT_TYPE_TO_VOYAGE_EMBEDDING = Map.of(
|
||||
DenseVectorFieldMapper.ElementType.FLOAT,
|
||||
FLOAT,
|
||||
DenseVectorFieldMapper.ElementType.BYTE,
|
||||
BYTE,
|
||||
DenseVectorFieldMapper.ElementType.BIT,
|
||||
BIT
|
||||
);
|
||||
static final EnumSet<DenseVectorFieldMapper.ElementType> SUPPORTED_ELEMENT_TYPES = EnumSet.copyOf(
|
||||
ELEMENT_TYPE_TO_VOYAGE_EMBEDDING.keySet()
|
||||
);
|
||||
|
||||
private final DenseVectorFieldMapper.ElementType elementType;
|
||||
private final String requestString;
|
||||
|
||||
VoyageAIEmbeddingType(DenseVectorFieldMapper.ElementType elementType, String requestString) {
|
||||
this.elementType = elementType;
|
||||
this.requestString = requestString;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return name().toLowerCase(Locale.ROOT);
|
||||
}
|
||||
|
||||
public String toRequestString() {
|
||||
return requestString;
|
||||
}
|
||||
|
||||
public static String toLowerCase(VoyageAIEmbeddingType type) {
|
||||
return type.toString().toLowerCase(Locale.ROOT);
|
||||
}
|
||||
|
||||
public static VoyageAIEmbeddingType fromString(String name) {
|
||||
return valueOf(name.trim().toUpperCase(Locale.ROOT));
|
||||
}
|
||||
|
||||
public static VoyageAIEmbeddingType fromElementType(DenseVectorFieldMapper.ElementType elementType) {
|
||||
var embedding = ELEMENT_TYPE_TO_VOYAGE_EMBEDDING.get(elementType);
|
||||
|
||||
if (embedding == null) {
|
||||
var validElementTypes = SUPPORTED_ELEMENT_TYPES.stream()
|
||||
.map(value -> value.toString().toLowerCase(Locale.ROOT))
|
||||
.toArray(String[]::new);
|
||||
Arrays.sort(validElementTypes);
|
||||
|
||||
throw new IllegalArgumentException(
|
||||
Strings.format(
|
||||
"Element type [%s] does not map to a VoyageAI embedding value, must be one of [%s]",
|
||||
elementType,
|
||||
String.join(", ", validElementTypes)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
return embedding;
|
||||
}
|
||||
|
||||
public DenseVectorFieldMapper.ElementType toElementType() {
|
||||
return elementType;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,127 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai.embeddings;
|
||||
|
||||
import org.apache.http.client.utils.URIBuilder;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ChunkingSettings;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionVisitor;
|
||||
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;
|
||||
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils.HOST;
|
||||
|
||||
public class VoyageAIEmbeddingsModel extends VoyageAIModel {
|
||||
public static VoyageAIEmbeddingsModel of(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
|
||||
var requestTaskSettings = VoyageAIEmbeddingsTaskSettings.fromMap(taskSettings);
|
||||
return new VoyageAIEmbeddingsModel(
|
||||
model,
|
||||
VoyageAIEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings, inputType)
|
||||
);
|
||||
}
|
||||
|
||||
public VoyageAIEmbeddingsModel(
|
||||
String inferenceId,
|
||||
String service,
|
||||
Map<String, Object> serviceSettings,
|
||||
Map<String, Object> taskSettings,
|
||||
ChunkingSettings chunkingSettings,
|
||||
@Nullable Map<String, Object> secrets,
|
||||
ConfigurationParseContext context
|
||||
) {
|
||||
this(
|
||||
inferenceId,
|
||||
service,
|
||||
VoyageAIEmbeddingsServiceSettings.fromMap(serviceSettings, context),
|
||||
VoyageAIEmbeddingsTaskSettings.fromMap(taskSettings),
|
||||
chunkingSettings,
|
||||
DefaultSecretSettings.fromMap(secrets)
|
||||
);
|
||||
}
|
||||
|
||||
// should only be used for testing
|
||||
VoyageAIEmbeddingsModel(
|
||||
String modelId,
|
||||
String service,
|
||||
VoyageAIEmbeddingsServiceSettings serviceSettings,
|
||||
VoyageAIEmbeddingsTaskSettings taskSettings,
|
||||
ChunkingSettings chunkingSettings,
|
||||
@Nullable DefaultSecretSettings secretSettings
|
||||
) {
|
||||
super(
|
||||
new ModelConfigurations(modelId, TaskType.TEXT_EMBEDDING, service, serviceSettings, taskSettings, chunkingSettings),
|
||||
new ModelSecrets(secretSettings),
|
||||
secretSettings,
|
||||
serviceSettings.getCommonSettings()
|
||||
);
|
||||
}
|
||||
|
||||
VoyageAIEmbeddingsModel(
|
||||
String modelId,
|
||||
String service,
|
||||
String url,
|
||||
VoyageAIEmbeddingsServiceSettings serviceSettings,
|
||||
VoyageAIEmbeddingsTaskSettings taskSettings,
|
||||
ChunkingSettings chunkingSettings,
|
||||
@Nullable DefaultSecretSettings secretSettings
|
||||
) {
|
||||
super(
|
||||
new ModelConfigurations(modelId, TaskType.TEXT_EMBEDDING, service, serviceSettings, taskSettings, chunkingSettings),
|
||||
new ModelSecrets(secretSettings),
|
||||
secretSettings,
|
||||
serviceSettings.getCommonSettings(),
|
||||
url
|
||||
);
|
||||
}
|
||||
|
||||
private VoyageAIEmbeddingsModel(VoyageAIEmbeddingsModel model, VoyageAIEmbeddingsTaskSettings taskSettings) {
|
||||
super(model, taskSettings);
|
||||
}
|
||||
|
||||
public VoyageAIEmbeddingsModel(VoyageAIEmbeddingsModel model, VoyageAIEmbeddingsServiceSettings serviceSettings) {
|
||||
super(model, serviceSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public VoyageAIEmbeddingsServiceSettings getServiceSettings() {
|
||||
return (VoyageAIEmbeddingsServiceSettings) super.getServiceSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VoyageAIEmbeddingsTaskSettings getTaskSettings() {
|
||||
return (VoyageAIEmbeddingsTaskSettings) super.getTaskSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DefaultSecretSettings getSecretSettings() {
|
||||
return (DefaultSecretSettings) super.getSecretSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ExecutableAction accept(VoyageAIActionVisitor visitor, Map<String, Object> taskSettings, InputType inputType) {
|
||||
return visitor.create(this, taskSettings, inputType);
|
||||
}
|
||||
|
||||
protected URI buildRequestUri() throws URISyntaxException {
|
||||
return new URIBuilder().setScheme("https")
|
||||
.setHost(HOST)
|
||||
.setPathSegments(VoyageAIUtils.VERSION_1, VoyageAIUtils.EMBEDDINGS_PATH)
|
||||
.build();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,259 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai.embeddings;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ServiceSettings;
|
||||
import org.elasticsearch.inference.SimilarityMeasure;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.EnumSet;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
|
||||
|
||||
public class VoyageAIEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings {
|
||||
public static final String NAME = "voyageai_embeddings_service_settings";
|
||||
static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
|
||||
public static final VoyageAIEmbeddingsServiceSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsServiceSettings(
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
false
|
||||
);
|
||||
|
||||
public static final String EMBEDDING_TYPE = "embedding_type";
|
||||
|
||||
public static VoyageAIEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
|
||||
return switch (context) {
|
||||
case REQUEST -> fromRequestMap(map, context);
|
||||
case PERSISTENT -> fromPersistentMap(map, context);
|
||||
};
|
||||
}
|
||||
|
||||
private static VoyageAIEmbeddingsServiceSettings fromRequestMap(Map<String, Object> map, ConfigurationParseContext context) {
|
||||
ValidationException validationException = new ValidationException();
|
||||
var commonServiceSettings = VoyageAIServiceSettings.fromMap(map, context);
|
||||
|
||||
VoyageAIEmbeddingType embeddingTypes = parseEmbeddingType(map, context, validationException);
|
||||
|
||||
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
|
||||
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
|
||||
return new VoyageAIEmbeddingsServiceSettings(commonServiceSettings, embeddingTypes, similarity, dims, maxInputTokens, dims != null);
|
||||
}
|
||||
|
||||
private static VoyageAIEmbeddingsServiceSettings fromPersistentMap(Map<String, Object> map, ConfigurationParseContext context) {
|
||||
ValidationException validationException = new ValidationException();
|
||||
var commonServiceSettings = VoyageAIServiceSettings.fromMap(map, context);
|
||||
|
||||
VoyageAIEmbeddingType embeddingTypes = parseEmbeddingType(map, context, validationException);
|
||||
|
||||
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
|
||||
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
|
||||
|
||||
Boolean dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class);
|
||||
if (dimensionsSetByUser == null) {
|
||||
dimensionsSetByUser = Boolean.FALSE;
|
||||
}
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
|
||||
return new VoyageAIEmbeddingsServiceSettings(
|
||||
commonServiceSettings,
|
||||
embeddingTypes,
|
||||
similarity,
|
||||
dims,
|
||||
maxInputTokens,
|
||||
dimensionsSetByUser
|
||||
);
|
||||
}
|
||||
|
||||
static VoyageAIEmbeddingType parseEmbeddingType(
|
||||
Map<String, Object> map,
|
||||
ConfigurationParseContext context,
|
||||
ValidationException validationException
|
||||
) {
|
||||
return switch (context) {
|
||||
case REQUEST, PERSISTENT -> Objects.requireNonNullElse(
|
||||
extractOptionalEnum(
|
||||
map,
|
||||
EMBEDDING_TYPE,
|
||||
ModelConfigurations.SERVICE_SETTINGS,
|
||||
VoyageAIEmbeddingType::fromString,
|
||||
EnumSet.allOf(VoyageAIEmbeddingType.class),
|
||||
validationException
|
||||
),
|
||||
VoyageAIEmbeddingType.FLOAT
|
||||
);
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
private final VoyageAIServiceSettings commonSettings;
|
||||
private final VoyageAIEmbeddingType embeddingType;
|
||||
private final SimilarityMeasure similarity;
|
||||
private final Integer dimensions;
|
||||
private final Integer maxInputTokens;
|
||||
private final boolean dimensionsSetByUser;
|
||||
|
||||
public VoyageAIEmbeddingsServiceSettings(
|
||||
VoyageAIServiceSettings commonSettings,
|
||||
@Nullable VoyageAIEmbeddingType embeddingType,
|
||||
@Nullable SimilarityMeasure similarity,
|
||||
@Nullable Integer dimensions,
|
||||
@Nullable Integer maxInputTokens,
|
||||
boolean dimensionsSetByUser
|
||||
) {
|
||||
this.commonSettings = commonSettings;
|
||||
this.similarity = similarity;
|
||||
this.dimensions = dimensions;
|
||||
this.maxInputTokens = maxInputTokens;
|
||||
this.embeddingType = embeddingType;
|
||||
this.dimensionsSetByUser = dimensionsSetByUser;
|
||||
}
|
||||
|
||||
public VoyageAIEmbeddingsServiceSettings(StreamInput in) throws IOException {
|
||||
this.commonSettings = new VoyageAIServiceSettings(in);
|
||||
this.similarity = in.readOptionalEnum(SimilarityMeasure.class);
|
||||
this.dimensions = in.readOptionalVInt();
|
||||
this.maxInputTokens = in.readOptionalVInt();
|
||||
this.embeddingType = Objects.requireNonNullElse(in.readOptionalEnum(VoyageAIEmbeddingType.class), VoyageAIEmbeddingType.FLOAT);
|
||||
this.dimensionsSetByUser = in.readBoolean();
|
||||
}
|
||||
|
||||
public VoyageAIServiceSettings getCommonSettings() {
|
||||
return commonSettings;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SimilarityMeasure similarity() {
|
||||
return similarity;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer dimensions() {
|
||||
return dimensions;
|
||||
}
|
||||
|
||||
public Integer maxInputTokens() {
|
||||
return maxInputTokens;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String modelId() {
|
||||
return commonSettings.modelId();
|
||||
}
|
||||
|
||||
public VoyageAIEmbeddingType getEmbeddingType() {
|
||||
return embeddingType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DenseVectorFieldMapper.ElementType elementType() {
|
||||
return embeddingType == null ? DenseVectorFieldMapper.ElementType.FLOAT : embeddingType.toElementType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean dimensionsSetByUser() {
|
||||
return this.dimensionsSetByUser;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
|
||||
builder = commonSettings.toXContentFragment(builder, params);
|
||||
if (similarity != null) {
|
||||
builder.field(SIMILARITY, similarity);
|
||||
}
|
||||
if (dimensions != null) {
|
||||
builder.field(DIMENSIONS, dimensions);
|
||||
}
|
||||
if (maxInputTokens != null) {
|
||||
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
|
||||
}
|
||||
if (embeddingType != null) {
|
||||
builder.field(EMBEDDING_TYPE, embeddingType);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
|
||||
commonSettings.toXContentFragmentOfExposedFields(builder, params);
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
commonSettings.writeTo(out);
|
||||
out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion()));
|
||||
out.writeOptionalVInt(dimensions);
|
||||
out.writeOptionalVInt(maxInputTokens);
|
||||
out.writeOptionalEnum(embeddingType);
|
||||
out.writeBoolean(dimensionsSetByUser);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
VoyageAIEmbeddingsServiceSettings that = (VoyageAIEmbeddingsServiceSettings) o;
|
||||
return Objects.equals(commonSettings, that.commonSettings)
|
||||
&& Objects.equals(similarity, that.similarity)
|
||||
&& Objects.equals(dimensions, that.dimensions)
|
||||
&& Objects.equals(maxInputTokens, that.maxInputTokens)
|
||||
&& Objects.equals(embeddingType, that.embeddingType)
|
||||
&& Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens, embeddingType, dimensionsSetByUser);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,202 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai.embeddings;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.TaskSettings;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.EnumSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
|
||||
import static org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields.TRUNCATION;
|
||||
|
||||
/**
|
||||
* Defines the task settings for the voyageai text embeddings service.
|
||||
*
|
||||
* <p>
|
||||
* <a href="https://docs.voyageai.com/docs/embeddings">See api docs for details.</a>
|
||||
* </p>
|
||||
*/
|
||||
public class VoyageAIEmbeddingsTaskSettings implements TaskSettings {
|
||||
|
||||
public static final String NAME = "voyageai_embeddings_task_settings";
|
||||
public static final VoyageAIEmbeddingsTaskSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsTaskSettings(null, null);
|
||||
static final String INPUT_TYPE = "input_type";
|
||||
static final EnumSet<InputType> VALID_REQUEST_VALUES = EnumSet.of(InputType.INGEST, InputType.SEARCH);
|
||||
|
||||
public static VoyageAIEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
|
||||
if (map == null || map.isEmpty()) {
|
||||
return EMPTY_SETTINGS;
|
||||
}
|
||||
|
||||
ValidationException validationException = new ValidationException();
|
||||
|
||||
InputType inputType = extractOptionalEnum(
|
||||
map,
|
||||
INPUT_TYPE,
|
||||
ModelConfigurations.TASK_SETTINGS,
|
||||
InputType::fromString,
|
||||
VALID_REQUEST_VALUES,
|
||||
validationException
|
||||
);
|
||||
Boolean truncation = extractOptionalBoolean(map, TRUNCATION, validationException);
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
|
||||
return new VoyageAIEmbeddingsTaskSettings(inputType, truncation);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new {@link VoyageAIEmbeddingsTaskSettings} by preferring non-null fields from the provided parameters.
|
||||
* For the input type, preference is given to requestInputType if it is not null and not UNSPECIFIED.
|
||||
* Then preference is given to the requestTaskSettings and finally to originalSettings even if the value is null.
|
||||
* Similarly, for the truncation field preference is given to requestTaskSettings if it is not null and then to
|
||||
* originalSettings.
|
||||
* @param originalSettings the settings stored as part of the inference entity configuration
|
||||
* @param requestTaskSettings the settings passed in within the task_settings field of the request
|
||||
* @param requestInputType the input type passed in the request parameters
|
||||
* @return a constructed {@link VoyageAIEmbeddingsTaskSettings}
|
||||
*/
|
||||
public static VoyageAIEmbeddingsTaskSettings of(
|
||||
VoyageAIEmbeddingsTaskSettings originalSettings,
|
||||
VoyageAIEmbeddingsTaskSettings requestTaskSettings,
|
||||
InputType requestInputType
|
||||
) {
|
||||
var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings, requestInputType);
|
||||
var truncationToUse = getValidTruncation(originalSettings, requestTaskSettings);
|
||||
|
||||
return new VoyageAIEmbeddingsTaskSettings(inputTypeToUse, truncationToUse);
|
||||
}
|
||||
|
||||
private static InputType getValidInputType(
|
||||
VoyageAIEmbeddingsTaskSettings originalSettings,
|
||||
VoyageAIEmbeddingsTaskSettings requestTaskSettings,
|
||||
InputType requestInputType
|
||||
) {
|
||||
InputType inputTypeToUse = originalSettings.inputType;
|
||||
|
||||
if (VALID_REQUEST_VALUES.contains(requestInputType)) {
|
||||
inputTypeToUse = requestInputType;
|
||||
} else if (requestTaskSettings.inputType != null) {
|
||||
inputTypeToUse = requestTaskSettings.inputType;
|
||||
}
|
||||
|
||||
return inputTypeToUse;
|
||||
}
|
||||
|
||||
private static Boolean getValidTruncation(
|
||||
VoyageAIEmbeddingsTaskSettings originalSettings,
|
||||
VoyageAIEmbeddingsTaskSettings requestTaskSettings
|
||||
) {
|
||||
return requestTaskSettings.getTruncation() == null ? originalSettings.truncation : requestTaskSettings.getTruncation();
|
||||
}
|
||||
|
||||
private final InputType inputType;
|
||||
private final Boolean truncation;
|
||||
|
||||
public VoyageAIEmbeddingsTaskSettings(StreamInput in) throws IOException {
|
||||
this(in.readOptionalEnum(InputType.class), in.readOptionalBoolean());
|
||||
}
|
||||
|
||||
public VoyageAIEmbeddingsTaskSettings(@Nullable InputType inputType, @Nullable Boolean truncation) {
|
||||
validateInputType(inputType);
|
||||
this.inputType = inputType;
|
||||
this.truncation = truncation;
|
||||
}
|
||||
|
||||
private static void validateInputType(InputType inputType) {
|
||||
if (inputType == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
assert VALID_REQUEST_VALUES.contains(inputType) : invalidInputTypeMessage(inputType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isEmpty() {
|
||||
return inputType == null && truncation == null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (inputType != null) {
|
||||
builder.field(INPUT_TYPE, inputType);
|
||||
}
|
||||
|
||||
if (truncation != null) {
|
||||
builder.field(TRUNCATION, truncation);
|
||||
}
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
public InputType getInputType() {
|
||||
return inputType;
|
||||
}
|
||||
|
||||
public Boolean getTruncation() {
|
||||
return truncation;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeOptionalEnum(inputType);
|
||||
out.writeOptionalBoolean(truncation);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
VoyageAIEmbeddingsTaskSettings that = (VoyageAIEmbeddingsTaskSettings) o;
|
||||
return Objects.equals(inputType, that.inputType) && Objects.equals(truncation, that.truncation);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(inputType, truncation);
|
||||
}
|
||||
|
||||
public static String invalidInputTypeMessage(InputType inputType) {
|
||||
return Strings.format("received invalid input type value [%s]", inputType.toString());
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
|
||||
VoyageAIEmbeddingsTaskSettings updatedSettings = VoyageAIEmbeddingsTaskSettings.fromMap(new HashMap<>(newSettings));
|
||||
return of(this, updatedSettings, updatedSettings.inputType != null ? updatedSettings.inputType : this.inputType);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,122 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai.rerank;
|
||||
|
||||
import org.apache.http.client.utils.URIBuilder;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionVisitor;
|
||||
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;
|
||||
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils.HOST;
|
||||
|
||||
public class VoyageAIRerankModel extends VoyageAIModel {
|
||||
public static VoyageAIRerankModel of(VoyageAIRerankModel model, Map<String, Object> taskSettings) {
|
||||
var requestTaskSettings = VoyageAIRerankTaskSettings.fromMap(taskSettings);
|
||||
return new VoyageAIRerankModel(model, VoyageAIRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings));
|
||||
}
|
||||
|
||||
public VoyageAIRerankModel(
|
||||
String inferenceId,
|
||||
String service,
|
||||
Map<String, Object> serviceSettings,
|
||||
Map<String, Object> taskSettings,
|
||||
@Nullable Map<String, Object> secrets,
|
||||
ConfigurationParseContext context
|
||||
) {
|
||||
this(
|
||||
inferenceId,
|
||||
service,
|
||||
VoyageAIRerankServiceSettings.fromMap(serviceSettings, context),
|
||||
VoyageAIRerankTaskSettings.fromMap(taskSettings),
|
||||
DefaultSecretSettings.fromMap(secrets)
|
||||
);
|
||||
}
|
||||
|
||||
// should only be used for testing
|
||||
VoyageAIRerankModel(
|
||||
String modelId,
|
||||
String service,
|
||||
VoyageAIRerankServiceSettings serviceSettings,
|
||||
VoyageAIRerankTaskSettings taskSettings,
|
||||
@Nullable DefaultSecretSettings secretSettings
|
||||
) {
|
||||
this(modelId, service, null, serviceSettings, taskSettings, secretSettings);
|
||||
}
|
||||
|
||||
VoyageAIRerankModel(
|
||||
String modelId,
|
||||
String service,
|
||||
String url,
|
||||
VoyageAIRerankServiceSettings serviceSettings,
|
||||
VoyageAIRerankTaskSettings taskSettings,
|
||||
@Nullable DefaultSecretSettings secretSettings
|
||||
) {
|
||||
super(
|
||||
new ModelConfigurations(modelId, TaskType.RERANK, service, serviceSettings, taskSettings),
|
||||
new ModelSecrets(secretSettings),
|
||||
secretSettings,
|
||||
serviceSettings.getCommonSettings(),
|
||||
url
|
||||
);
|
||||
}
|
||||
|
||||
private VoyageAIRerankModel(VoyageAIRerankModel model, VoyageAIRerankTaskSettings taskSettings) {
|
||||
super(model, taskSettings);
|
||||
}
|
||||
|
||||
public VoyageAIRerankModel(VoyageAIRerankModel model, VoyageAIRerankServiceSettings serviceSettings) {
|
||||
super(model, serviceSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public VoyageAIRerankServiceSettings getServiceSettings() {
|
||||
return (VoyageAIRerankServiceSettings) super.getServiceSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VoyageAIRerankTaskSettings getTaskSettings() {
|
||||
return (VoyageAIRerankTaskSettings) super.getTaskSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DefaultSecretSettings getSecretSettings() {
|
||||
return (DefaultSecretSettings) super.getSecretSettings();
|
||||
}
|
||||
|
||||
/**
|
||||
* Accepts a visitor to create an executable action. The returned action will not return documents in the response.
|
||||
* @param visitor _
|
||||
* @param taskSettings _
|
||||
* @param inputType ignored for rerank task
|
||||
* @return the rerank action
|
||||
*/
|
||||
@Override
|
||||
public ExecutableAction accept(VoyageAIActionVisitor visitor, Map<String, Object> taskSettings, InputType inputType) {
|
||||
return visitor.create(this, taskSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected URI buildRequestUri() throws URISyntaxException {
|
||||
return new URIBuilder().setScheme("https")
|
||||
.setHost(HOST)
|
||||
.setPathSegments(VoyageAIUtils.VERSION_1, VoyageAIUtils.RERANK_PATH)
|
||||
.build();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai.rerank;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.inference.ServiceSettings;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIRateLimitServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class VoyageAIRerankServiceSettings extends FilteredXContentObject implements ServiceSettings, VoyageAIRateLimitServiceSettings {
|
||||
public static final String NAME = "voyageai_rerank_service_settings";
|
||||
|
||||
private static final Logger logger = LogManager.getLogger(VoyageAIRerankServiceSettings.class);
|
||||
|
||||
public static VoyageAIRerankServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
|
||||
ValidationException validationException = new ValidationException();
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
|
||||
var commonServiceSettings = VoyageAIServiceSettings.fromMap(map, context);
|
||||
|
||||
return new VoyageAIRerankServiceSettings(commonServiceSettings);
|
||||
}
|
||||
|
||||
private final VoyageAIServiceSettings commonSettings;
|
||||
|
||||
public VoyageAIRerankServiceSettings(VoyageAIServiceSettings commonSettings) {
|
||||
this.commonSettings = commonSettings;
|
||||
}
|
||||
|
||||
public VoyageAIRerankServiceSettings(StreamInput in) throws IOException {
|
||||
this.commonSettings = new VoyageAIServiceSettings(in);
|
||||
}
|
||||
|
||||
public VoyageAIServiceSettings getCommonSettings() {
|
||||
return commonSettings;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String modelId() {
|
||||
return commonSettings.modelId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public RateLimitSettings rateLimitSettings() {
|
||||
return commonSettings.rateLimitSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
|
||||
builder = commonSettings.toXContentFragment(builder, params);
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
|
||||
commonSettings.toXContentFragmentOfExposedFields(builder, params);
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
commonSettings.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
VoyageAIRerankServiceSettings that = (VoyageAIRerankServiceSettings) o;
|
||||
return Objects.equals(commonSettings, that.commonSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(commonSettings);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,184 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai.rerank;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.TaskSettings;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
|
||||
import static org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields.TRUNCATION;
|
||||
|
||||
/**
|
||||
* Defines the task settings for the VoyageAI rerank service.
|
||||
*
|
||||
*/
|
||||
public class VoyageAIRerankTaskSettings implements TaskSettings {
|
||||
|
||||
public static final String NAME = "voyageai_rerank_task_settings";
|
||||
public static final String RETURN_DOCUMENTS = "return_documents";
|
||||
public static final String TOP_K_DOCS_ONLY = "top_k";
|
||||
|
||||
public static final VoyageAIRerankTaskSettings EMPTY_SETTINGS = new VoyageAIRerankTaskSettings(null, null, null);
|
||||
|
||||
public static VoyageAIRerankTaskSettings fromMap(Map<String, Object> map) {
|
||||
ValidationException validationException = new ValidationException();
|
||||
|
||||
if (map == null || map.isEmpty()) {
|
||||
return EMPTY_SETTINGS;
|
||||
}
|
||||
|
||||
Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException);
|
||||
Integer topKDocumentsOnly = extractOptionalPositiveInteger(
|
||||
map,
|
||||
TOP_K_DOCS_ONLY,
|
||||
ModelConfigurations.TASK_SETTINGS,
|
||||
validationException
|
||||
);
|
||||
|
||||
Boolean truncation = extractOptionalBoolean(map, TRUNCATION, validationException);
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
|
||||
return of(topKDocumentsOnly, returnDocuments, truncation);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new {@link VoyageAIRerankTaskSettings} by preferring non-null fields from the request settings over the original settings.
|
||||
*
|
||||
* @param originalSettings the settings stored as part of the inference entity configuration
|
||||
* @param requestTaskSettings the settings passed in within the task_settings field of the request
|
||||
* @return a constructed {@link VoyageAIRerankTaskSettings}
|
||||
*/
|
||||
public static VoyageAIRerankTaskSettings of(
|
||||
VoyageAIRerankTaskSettings originalSettings,
|
||||
VoyageAIRerankTaskSettings requestTaskSettings
|
||||
) {
|
||||
return new VoyageAIRerankTaskSettings(
|
||||
requestTaskSettings.getTopKDocumentsOnly() != null
|
||||
? requestTaskSettings.getTopKDocumentsOnly()
|
||||
: originalSettings.getTopKDocumentsOnly(),
|
||||
requestTaskSettings.getReturnDocuments() != null
|
||||
? requestTaskSettings.getReturnDocuments()
|
||||
: originalSettings.getReturnDocuments(),
|
||||
requestTaskSettings.getTruncation() != null ? requestTaskSettings.getTruncation() : originalSettings.getTruncation()
|
||||
|
||||
);
|
||||
}
|
||||
|
||||
public static VoyageAIRerankTaskSettings of(Integer topKDocumentsOnly, Boolean returnDocuments, Boolean truncation) {
|
||||
return new VoyageAIRerankTaskSettings(topKDocumentsOnly, returnDocuments, truncation);
|
||||
}
|
||||
|
||||
private final Integer topKDocumentsOnly;
|
||||
private final Boolean returnDocuments;
|
||||
private final Boolean truncation;
|
||||
|
||||
public VoyageAIRerankTaskSettings(StreamInput in) throws IOException {
|
||||
this(in.readOptionalInt(), in.readOptionalBoolean(), in.readOptionalBoolean());
|
||||
}
|
||||
|
||||
public VoyageAIRerankTaskSettings(
|
||||
@Nullable Integer topKDocumentsOnly,
|
||||
@Nullable Boolean doReturnDocuments,
|
||||
@Nullable Boolean truncation
|
||||
) {
|
||||
this.topKDocumentsOnly = topKDocumentsOnly;
|
||||
this.returnDocuments = doReturnDocuments;
|
||||
this.truncation = truncation;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isEmpty() {
|
||||
return topKDocumentsOnly == null && returnDocuments == null && truncation == null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (topKDocumentsOnly != null) {
|
||||
builder.field(TOP_K_DOCS_ONLY, topKDocumentsOnly);
|
||||
}
|
||||
if (returnDocuments != null) {
|
||||
builder.field(RETURN_DOCUMENTS, returnDocuments);
|
||||
}
|
||||
if (truncation != null) {
|
||||
builder.field(TRUNCATION, truncation);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeOptionalInt(topKDocumentsOnly);
|
||||
out.writeOptionalBoolean(returnDocuments);
|
||||
out.writeOptionalBoolean(truncation);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
VoyageAIRerankTaskSettings that = (VoyageAIRerankTaskSettings) o;
|
||||
return Objects.equals(topKDocumentsOnly, that.topKDocumentsOnly)
|
||||
&& Objects.equals(returnDocuments, that.returnDocuments)
|
||||
&& Objects.equals(truncation, that.truncation);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(truncation, returnDocuments, topKDocumentsOnly);
|
||||
}
|
||||
|
||||
public Integer getTopKDocumentsOnly() {
|
||||
return topKDocumentsOnly;
|
||||
}
|
||||
|
||||
public Boolean getDoesReturnDocuments() {
|
||||
return returnDocuments;
|
||||
}
|
||||
|
||||
public Boolean getReturnDocuments() {
|
||||
return returnDocuments;
|
||||
}
|
||||
|
||||
public Boolean getTruncation() {
|
||||
return truncation;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
|
||||
VoyageAIRerankTaskSettings updatedSettings = VoyageAIRerankTaskSettings.fromMap(new HashMap<>(newSettings));
|
||||
return VoyageAIRerankTaskSettings.of(this, updatedSettings);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,145 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.action.voyageai;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.test.http.MockResponse;
|
||||
import org.elasticsearch.test.http.MockWebServer;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettingsTests;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
|
||||
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
|
||||
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public class VoyageAIActionCreatorTests extends ESTestCase {
|
||||
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
|
||||
private final MockWebServer webServer = new MockWebServer();
|
||||
private ThreadPool threadPool;
|
||||
private HttpClientManager clientManager;
|
||||
|
||||
@Before
|
||||
public void init() throws Exception {
|
||||
webServer.start();
|
||||
threadPool = createThreadPool(inferenceUtilityPool());
|
||||
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
|
||||
}
|
||||
|
||||
@After
|
||||
public void shutdown() throws IOException {
|
||||
clientManager.close();
|
||||
terminate(threadPool);
|
||||
webServer.close();
|
||||
}
|
||||
|
||||
public void testCreate_VoyageAIEmbeddingsModel() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"data": [{
|
||||
"object": "embedding",
|
||||
"embedding": [
|
||||
0.123,
|
||||
-0.123
|
||||
],
|
||||
"index": 0
|
||||
}],
|
||||
"model": "voyage-3-large",
|
||||
"usage": {
|
||||
"total_tokens": 123
|
||||
}
|
||||
}
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var model = VoyageAIEmbeddingsModelTests.createModel(
|
||||
getUrl(webServer),
|
||||
"secret",
|
||||
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true),
|
||||
1024,
|
||||
1024,
|
||||
"model",
|
||||
VoyageAIEmbeddingType.FLOAT
|
||||
);
|
||||
var actionCreator = new VoyageAIActionCreator(sender, createWithEmptySettings(threadPool));
|
||||
var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH);
|
||||
var action = actionCreator.create(model, overriddenTaskSettings, InputType.UNSPECIFIED);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
|
||||
MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F }))));
|
||||
MatcherAssert.assertThat(webServer.requests(), hasSize(1));
|
||||
assertNull(webServer.requests().getFirst().getUri().getQuery());
|
||||
MatcherAssert.assertThat(
|
||||
webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE),
|
||||
equalTo(XContentType.JSON.mediaType())
|
||||
);
|
||||
MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
|
||||
|
||||
var requestMap = entityAsMap(webServer.requests().getFirst().getBody());
|
||||
MatcherAssert.assertThat(
|
||||
requestMap,
|
||||
is(
|
||||
Map.of(
|
||||
"output_dtype",
|
||||
"float",
|
||||
"truncation",
|
||||
true,
|
||||
"input_type",
|
||||
"query",
|
||||
"output_dimension",
|
||||
1024,
|
||||
"input",
|
||||
List.of("abc"),
|
||||
"model",
|
||||
"model"
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,413 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.action.voyageai;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.test.http.MockResponse;
|
||||
import org.elasticsearch.test.http.MockWebServer;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager;
|
||||
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.core.Strings.format;
|
||||
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
||||
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
|
||||
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationBinary;
|
||||
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationByte;
|
||||
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public class VoyageAIEmbeddingsActionTests extends ESTestCase {
|
||||
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
|
||||
private final MockWebServer webServer = new MockWebServer();
|
||||
private ThreadPool threadPool;
|
||||
private HttpClientManager clientManager;
|
||||
|
||||
@Before
|
||||
public void init() throws Exception {
|
||||
webServer.start();
|
||||
threadPool = createThreadPool(inferenceUtilityPool());
|
||||
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
|
||||
}
|
||||
|
||||
@After
|
||||
public void shutdown() throws IOException {
|
||||
clientManager.close();
|
||||
terminate(threadPool);
|
||||
webServer.close();
|
||||
}
|
||||
|
||||
public void testExecute_ReturnsSuccessfulResponse() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"data": [{
|
||||
"object": "embedding",
|
||||
"embedding": [
|
||||
0.123,
|
||||
-0.123
|
||||
],
|
||||
"index": 0
|
||||
}],
|
||||
"model": "voyage-3-large",
|
||||
"usage": {
|
||||
"total_tokens": 123
|
||||
}
|
||||
}
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var action = createAction(
|
||||
getUrl(webServer),
|
||||
"secret",
|
||||
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true),
|
||||
"model",
|
||||
VoyageAIEmbeddingType.FLOAT,
|
||||
sender
|
||||
);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
|
||||
MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F }))));
|
||||
MatcherAssert.assertThat(webServer.requests(), hasSize(1));
|
||||
assertNull(webServer.requests().getFirst().getUri().getQuery());
|
||||
MatcherAssert.assertThat(
|
||||
webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE),
|
||||
equalTo(XContentType.JSON.mediaType())
|
||||
);
|
||||
MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
|
||||
MatcherAssert.assertThat(
|
||||
webServer.requests().getFirst().getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER),
|
||||
equalTo(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
|
||||
);
|
||||
|
||||
var requestMap = entityAsMap(webServer.requests().getFirst().getBody());
|
||||
MatcherAssert.assertThat(
|
||||
requestMap,
|
||||
equalTo(
|
||||
Map.of(
|
||||
"input",
|
||||
List.of("abc"),
|
||||
"model",
|
||||
"model",
|
||||
"input_type",
|
||||
"document",
|
||||
"output_dtype",
|
||||
"float",
|
||||
"truncation",
|
||||
true,
|
||||
"output_dimension",
|
||||
1024
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"data": [{
|
||||
"object": "embedding",
|
||||
"embedding": [
|
||||
0,
|
||||
-1
|
||||
],
|
||||
"index": 0
|
||||
}],
|
||||
"model": "voyage-3-large",
|
||||
"usage": {
|
||||
"total_tokens": 123
|
||||
}
|
||||
}
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var action = createAction(
|
||||
getUrl(webServer),
|
||||
"secret",
|
||||
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true),
|
||||
"model",
|
||||
VoyageAIEmbeddingType.INT8,
|
||||
sender
|
||||
);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
|
||||
assertEquals(buildExpectationByte(List.of(new byte[] { 0, -1 })), result.asMap());
|
||||
MatcherAssert.assertThat(webServer.requests(), hasSize(1));
|
||||
assertNull(webServer.requests().getFirst().getUri().getQuery());
|
||||
MatcherAssert.assertThat(
|
||||
webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE),
|
||||
equalTo(XContentType.JSON.mediaType())
|
||||
);
|
||||
MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
|
||||
MatcherAssert.assertThat(
|
||||
webServer.requests().getFirst().getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER),
|
||||
equalTo(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
|
||||
);
|
||||
|
||||
var requestMap = entityAsMap(webServer.requests().getFirst().getBody());
|
||||
MatcherAssert.assertThat(
|
||||
requestMap,
|
||||
is(
|
||||
Map.of(
|
||||
"input",
|
||||
List.of("abc"),
|
||||
"model",
|
||||
"model",
|
||||
"input_type",
|
||||
"document",
|
||||
"output_dtype",
|
||||
"int8",
|
||||
"truncation",
|
||||
true,
|
||||
"output_dimension",
|
||||
1024
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public void testExecute_ReturnsSuccessfulResponse_ForBinaryResponseType() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"data": [{
|
||||
"object": "embedding",
|
||||
"embedding": [
|
||||
0,
|
||||
-1
|
||||
],
|
||||
"index": 0
|
||||
}],
|
||||
"model": "voyage-3-large",
|
||||
"usage": {
|
||||
"total_tokens": 123
|
||||
}
|
||||
}
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var action = createAction(
|
||||
getUrl(webServer),
|
||||
"secret",
|
||||
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true),
|
||||
"model",
|
||||
VoyageAIEmbeddingType.BINARY,
|
||||
sender
|
||||
);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
|
||||
assertEquals(buildExpectationBinary(List.of(new byte[] { 0, -1 })), result.asMap());
|
||||
MatcherAssert.assertThat(webServer.requests(), hasSize(1));
|
||||
assertNull(webServer.requests().getFirst().getUri().getQuery());
|
||||
MatcherAssert.assertThat(
|
||||
webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE),
|
||||
equalTo(XContentType.JSON.mediaType())
|
||||
);
|
||||
MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
|
||||
MatcherAssert.assertThat(
|
||||
webServer.requests().getFirst().getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER),
|
||||
equalTo(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
|
||||
);
|
||||
|
||||
var requestMap = entityAsMap(webServer.requests().getFirst().getBody());
|
||||
MatcherAssert.assertThat(
|
||||
requestMap,
|
||||
is(
|
||||
Map.of(
|
||||
"input",
|
||||
List.of("abc"),
|
||||
"model",
|
||||
"model",
|
||||
"input_type",
|
||||
"document",
|
||||
"output_dtype",
|
||||
"binary",
|
||||
"truncation",
|
||||
true,
|
||||
"output_dimension",
|
||||
1024
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsElasticsearchException() {
|
||||
var sender = mock(Sender.class);
|
||||
doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
|
||||
|
||||
var action = createAction(getUrl(webServer), "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
MatcherAssert.assertThat(thrownException.getMessage(), is("failed"));
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
|
||||
var sender = mock(Sender.class);
|
||||
|
||||
doAnswer(invocation -> {
|
||||
@SuppressWarnings("unchecked")
|
||||
ActionListener<HttpResult> listener = (ActionListener<HttpResult>) invocation.getArguments()[2];
|
||||
listener.onFailure(new IllegalStateException("failed"));
|
||||
|
||||
return Void.TYPE;
|
||||
}).when(sender).send(any(), any(), any(), any());
|
||||
|
||||
var action = createAction(getUrl(webServer), "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
thrownException.getMessage(),
|
||||
is(format("Failed to send VoyageAI embeddings request to [%s]", getUrl(webServer)))
|
||||
);
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled_WhenUrlIsNull() {
|
||||
var sender = mock(Sender.class);
|
||||
|
||||
doAnswer(invocation -> {
|
||||
@SuppressWarnings("unchecked")
|
||||
ActionListener<HttpResult> listener = (ActionListener<HttpResult>) invocation.getArguments()[2];
|
||||
listener.onFailure(new IllegalStateException("failed"));
|
||||
|
||||
return Void.TYPE;
|
||||
}).when(sender).send(any(), any(), any(), any());
|
||||
|
||||
var action = createAction(null, "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send VoyageAI embeddings request"));
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsException() {
|
||||
var sender = mock(Sender.class);
|
||||
doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any());
|
||||
|
||||
var action = createAction(getUrl(webServer), "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
thrownException.getMessage(),
|
||||
is(format("Failed to send VoyageAI embeddings request to [%s]", getUrl(webServer)))
|
||||
);
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsExceptionWithNullUrl() {
|
||||
var sender = mock(Sender.class);
|
||||
doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any());
|
||||
|
||||
var action = createAction(null, "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send VoyageAI embeddings request"));
|
||||
}
|
||||
|
||||
private ExecutableAction createAction(
|
||||
String url,
|
||||
String apiKey,
|
||||
VoyageAIEmbeddingsTaskSettings taskSettings,
|
||||
@Nullable String modelName,
|
||||
@Nullable VoyageAIEmbeddingType embeddingType,
|
||||
Sender sender
|
||||
) {
|
||||
var model = VoyageAIEmbeddingsModelTests.createModel(url, apiKey, taskSettings, 1024, 1024, modelName, embeddingType);
|
||||
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "VoyageAI embeddings");
|
||||
var requestCreator = VoyageAIEmbeddingsRequestManager.of(model, threadPool);
|
||||
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,173 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.request.voyageai;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.inference.SimilarityMeasure;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceFields;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
|
||||
public class VoyageAIEmbeddingsRequestEntityTests extends ESTestCase {
|
||||
public void testXContent_WritesAllFields_ServiceSettingsDefined() throws IOException {
|
||||
var entity = new VoyageAIEmbeddingsRequestEntity(
|
||||
List.of("abc"),
|
||||
VoyageAIEmbeddingsServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
ServiceFields.URL,
|
||||
"https://www.abc.com",
|
||||
ServiceFields.SIMILARITY,
|
||||
SimilarityMeasure.DOT_PRODUCT.toString(),
|
||||
ServiceFields.DIMENSIONS,
|
||||
2048,
|
||||
ServiceFields.MAX_INPUT_TOKENS,
|
||||
512,
|
||||
VoyageAIServiceSettings.MODEL_ID,
|
||||
"model",
|
||||
VoyageAIEmbeddingsServiceSettings.EMBEDDING_TYPE,
|
||||
"float"
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
),
|
||||
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
|
||||
"model"
|
||||
);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
MatcherAssert.assertThat(xContentResult, is("""
|
||||
{"input":["abc"],"model":"model","input_type":"document","output_dimension":2048,"output_dtype":"float"}"""));
|
||||
}
|
||||
|
||||
public void testXContent_WritesAllFields_ServiceSettingsDefined_Int8() throws IOException {
|
||||
var entity = new VoyageAIEmbeddingsRequestEntity(
|
||||
List.of("abc"),
|
||||
VoyageAIEmbeddingsServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
ServiceFields.URL,
|
||||
"https://www.abc.com",
|
||||
ServiceFields.SIMILARITY,
|
||||
SimilarityMeasure.DOT_PRODUCT.toString(),
|
||||
ServiceFields.DIMENSIONS,
|
||||
2048,
|
||||
ServiceFields.MAX_INPUT_TOKENS,
|
||||
512,
|
||||
VoyageAIServiceSettings.MODEL_ID,
|
||||
"model",
|
||||
VoyageAIEmbeddingsServiceSettings.EMBEDDING_TYPE,
|
||||
"int8"
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
),
|
||||
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
|
||||
"model"
|
||||
);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
MatcherAssert.assertThat(xContentResult, is("""
|
||||
{"input":["abc"],"model":"model","input_type":"document","output_dimension":2048,"output_dtype":"int8"}"""));
|
||||
}
|
||||
|
||||
public void testXContent_WritesAllFields_ServiceSettingsDefined_Binary() throws IOException {
|
||||
var entity = new VoyageAIEmbeddingsRequestEntity(
|
||||
List.of("abc"),
|
||||
VoyageAIEmbeddingsServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
ServiceFields.URL,
|
||||
"https://www.abc.com",
|
||||
ServiceFields.SIMILARITY,
|
||||
SimilarityMeasure.DOT_PRODUCT.toString(),
|
||||
ServiceFields.DIMENSIONS,
|
||||
2048,
|
||||
ServiceFields.MAX_INPUT_TOKENS,
|
||||
512,
|
||||
VoyageAIServiceSettings.MODEL_ID,
|
||||
"model",
|
||||
VoyageAIEmbeddingsServiceSettings.EMBEDDING_TYPE,
|
||||
"binary"
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
),
|
||||
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
|
||||
"model"
|
||||
);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
MatcherAssert.assertThat(xContentResult, is("""
|
||||
{"input":["abc"],"model":"model","input_type":"document","output_dimension":2048,"output_dtype":"binary"}"""));
|
||||
}
|
||||
|
||||
public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
|
||||
var entity = new VoyageAIEmbeddingsRequestEntity(
|
||||
List.of("abc"),
|
||||
VoyageAIEmbeddingsServiceSettings.EMPTY_SETTINGS,
|
||||
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
|
||||
"model"
|
||||
);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
MatcherAssert.assertThat(xContentResult, is("""
|
||||
{"input":["abc"],"model":"model","input_type":"document"}"""));
|
||||
}
|
||||
|
||||
public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException {
|
||||
var entity = new VoyageAIEmbeddingsRequestEntity(
|
||||
List.of("abc"),
|
||||
VoyageAIEmbeddingsServiceSettings.EMPTY_SETTINGS,
|
||||
VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
|
||||
"model"
|
||||
);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
MatcherAssert.assertThat(xContentResult, is("""
|
||||
{"input":["abc"],"model":"model"}"""));
|
||||
}
|
||||
|
||||
public void testConvertToString_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() {
|
||||
var thrownException = expectThrows(
|
||||
AssertionError.class,
|
||||
() -> VoyageAIEmbeddingsRequestEntity.convertToString(InputType.UNSPECIFIED)
|
||||
);
|
||||
MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]"));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,215 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.request.voyageai;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class VoyageAIEmbeddingsRequestTests extends ESTestCase {
|
||||
public void testCreateRequest_UrlDefined() throws IOException {
|
||||
var request = createRequest(
|
||||
List.of("abc"),
|
||||
VoyageAIEmbeddingsModelTests.createModel("url", "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, "model")
|
||||
);
|
||||
|
||||
var httpRequest = request.createHttpRequest();
|
||||
MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
|
||||
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
|
||||
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
|
||||
MatcherAssert.assertThat(
|
||||
httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(),
|
||||
is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
|
||||
);
|
||||
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "model", "output_dtype", "float")));
|
||||
}
|
||||
|
||||
public void testCreateRequest_AllOptionsDefined() throws IOException {
|
||||
var request = createRequest(
|
||||
List.of("abc"),
|
||||
VoyageAIEmbeddingsModelTests.createModel(
|
||||
"url",
|
||||
"secret",
|
||||
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
|
||||
null,
|
||||
null,
|
||||
"model"
|
||||
)
|
||||
);
|
||||
|
||||
var httpRequest = request.createHttpRequest();
|
||||
MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
|
||||
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
|
||||
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
|
||||
MatcherAssert.assertThat(
|
||||
httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(),
|
||||
is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
|
||||
);
|
||||
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
MatcherAssert.assertThat(
|
||||
requestMap,
|
||||
is(Map.of("input", List.of("abc"), "model", "model", "input_type", "document", "output_dtype", "float"))
|
||||
);
|
||||
}
|
||||
|
||||
public void testCreateRequest_DimensionDefined() throws IOException {
|
||||
var request = createRequest(
|
||||
List.of("abc"),
|
||||
VoyageAIEmbeddingsModelTests.createModel(
|
||||
"url",
|
||||
"secret",
|
||||
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
|
||||
null,
|
||||
2048,
|
||||
"model"
|
||||
)
|
||||
);
|
||||
|
||||
var httpRequest = request.createHttpRequest();
|
||||
MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
|
||||
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
|
||||
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
|
||||
MatcherAssert.assertThat(
|
||||
httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(),
|
||||
is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
|
||||
);
|
||||
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
MatcherAssert.assertThat(
|
||||
requestMap,
|
||||
is(
|
||||
Map.of(
|
||||
"input",
|
||||
List.of("abc"),
|
||||
"model",
|
||||
"model",
|
||||
"input_type",
|
||||
"document",
|
||||
"output_dtype",
|
||||
"float",
|
||||
"output_dimension",
|
||||
2048
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testCreateRequest_EmbeddingTypeDefined() throws IOException {
|
||||
var request = createRequest(
|
||||
List.of("abc"),
|
||||
VoyageAIEmbeddingsModelTests.createModel(
|
||||
"url",
|
||||
"secret",
|
||||
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
|
||||
null,
|
||||
2048,
|
||||
"model",
|
||||
VoyageAIEmbeddingType.BYTE
|
||||
)
|
||||
);
|
||||
|
||||
var httpRequest = request.createHttpRequest();
|
||||
MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
|
||||
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
|
||||
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
|
||||
MatcherAssert.assertThat(
|
||||
httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(),
|
||||
is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
|
||||
);
|
||||
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
MatcherAssert.assertThat(
|
||||
requestMap,
|
||||
is(
|
||||
Map.of(
|
||||
"input",
|
||||
List.of("abc"),
|
||||
"model",
|
||||
"model",
|
||||
"input_type",
|
||||
"document",
|
||||
"output_dtype",
|
||||
"int8",
|
||||
"output_dimension",
|
||||
2048
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testCreateRequest_InputTypeSearch() throws IOException {
|
||||
var request = createRequest(
|
||||
List.of("abc"),
|
||||
VoyageAIEmbeddingsModelTests.createModel(
|
||||
"url",
|
||||
"secret",
|
||||
new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null),
|
||||
null,
|
||||
null,
|
||||
"model"
|
||||
)
|
||||
);
|
||||
|
||||
var httpRequest = request.createHttpRequest();
|
||||
MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
|
||||
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
|
||||
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
|
||||
MatcherAssert.assertThat(
|
||||
httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(),
|
||||
is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
|
||||
);
|
||||
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
MatcherAssert.assertThat(
|
||||
requestMap,
|
||||
is(Map.of("input", List.of("abc"), "model", "model", "input_type", "query", "output_dtype", "float"))
|
||||
);
|
||||
}
|
||||
|
||||
public static VoyageAIEmbeddingsRequest createRequest(List<String> input, VoyageAIEmbeddingsModel model) {
|
||||
return new VoyageAIEmbeddingsRequest(input, model);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.request.voyageai;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.elasticsearch.common.settings.SecureString;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount;
|
||||
|
||||
import java.net.URI;
|
||||
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class VoyageAIRequestTests extends ESTestCase {
|
||||
|
||||
public void testDecorateWithHeaders() {
|
||||
var request = new HttpPost("http://www.abc.com");
|
||||
|
||||
VoyageAIRequest.decorateWithHeaders(
|
||||
request,
|
||||
new VoyageAIAccount(URI.create("http://www.abc.com"), new SecureString(new char[] { 'a', 'b', 'c' }))
|
||||
);
|
||||
|
||||
assertThat(request.getFirstHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||
assertThat(request.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer abc"));
|
||||
assertThat(request.getFirstHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(), is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,186 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.request.voyageai;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
|
||||
|
||||
public class VoyageAIRerankRequestEntityTests extends ESTestCase {
|
||||
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined() throws IOException {
|
||||
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, null, null), "model");
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
||||
{
|
||||
"model": "model",
|
||||
"query": "query",
|
||||
"documents": [
|
||||
"abc"
|
||||
],
|
||||
"top_k": 8
|
||||
}
|
||||
"""));
|
||||
}
|
||||
|
||||
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsTrue() throws IOException {
|
||||
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, true, null), "model");
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
||||
{
|
||||
"model": "model",
|
||||
"query": "query",
|
||||
"documents": [
|
||||
"abc"
|
||||
],
|
||||
"return_documents": true,
|
||||
"top_k": 8
|
||||
}
|
||||
"""));
|
||||
}
|
||||
|
||||
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsFalse() throws IOException {
|
||||
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, null), "model");
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
||||
{
|
||||
"model": "model",
|
||||
"query": "query",
|
||||
"documents": [
|
||||
"abc"
|
||||
],
|
||||
"return_documents": false,
|
||||
"top_k": 8
|
||||
}
|
||||
"""));
|
||||
}
|
||||
|
||||
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationTrue() throws IOException {
|
||||
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, true), "model");
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
||||
{
|
||||
"model": "model",
|
||||
"query": "query",
|
||||
"documents": [
|
||||
"abc"
|
||||
],
|
||||
"return_documents": false,
|
||||
"top_k": 8,
|
||||
"truncation": true
|
||||
}
|
||||
"""));
|
||||
}
|
||||
|
||||
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationFalse() throws IOException {
|
||||
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, false), "model");
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
||||
{
|
||||
"model": "model",
|
||||
"query": "query",
|
||||
"documents": [
|
||||
"abc"
|
||||
],
|
||||
"return_documents": false,
|
||||
"top_k": 8,
|
||||
"truncation": false
|
||||
}
|
||||
"""));
|
||||
}
|
||||
|
||||
public void testXContent_SingleRequest_DoesNotWriteTopKIfNull() throws IOException {
|
||||
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), null, "model");
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
||||
{
|
||||
"model": "model",
|
||||
"query": "query",
|
||||
"documents": [
|
||||
"abc"
|
||||
]
|
||||
}
|
||||
"""));
|
||||
}
|
||||
|
||||
public void testXContent_MultipleRequests_WritesModelAndTopKIfDefined() throws IOException {
|
||||
var entity = new VoyageAIRerankRequestEntity(
|
||||
"query",
|
||||
List.of("abc", "def"),
|
||||
new VoyageAIRerankTaskSettings(8, null, null),
|
||||
"model"
|
||||
);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
||||
{
|
||||
"model": "model",
|
||||
"query": "query",
|
||||
"documents": [
|
||||
"abc",
|
||||
"def"
|
||||
],
|
||||
"top_k": 8
|
||||
}
|
||||
"""));
|
||||
}
|
||||
|
||||
public void testXContent_MultipleRequests_DoesNotWriteTopKIfNull() throws IOException {
|
||||
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc", "def"), null, "model");
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
||||
{
|
||||
"model": "model",
|
||||
"query": "query",
|
||||
"documents": [
|
||||
"abc",
|
||||
"def"
|
||||
]
|
||||
}
|
||||
"""));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,110 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.request.voyageai;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModelTests;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||
import static org.hamcrest.Matchers.aMapWithSize;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.sameInstance;
|
||||
|
||||
public class VoyageAIRerankRequestTests extends ESTestCase {
|
||||
|
||||
private static final String API_KEY = "foo";
|
||||
|
||||
public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException {
|
||||
var input = "input";
|
||||
var query = "query";
|
||||
var modelId = "model";
|
||||
|
||||
var request = createRequest(query, input, modelId, null);
|
||||
var httpRequest = request.createHttpRequest();
|
||||
|
||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
|
||||
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY));
|
||||
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
|
||||
assertThat(requestMap, aMapWithSize(3));
|
||||
assertThat(requestMap.get("documents"), is(List.of(input)));
|
||||
assertThat(requestMap.get("query"), is(query));
|
||||
assertThat(requestMap.get("model"), is(modelId));
|
||||
}
|
||||
|
||||
public void testCreateRequest_WithTopNSet() throws IOException {
|
||||
var input = "input";
|
||||
var query = "query";
|
||||
var topK = 1;
|
||||
var modelId = "model";
|
||||
|
||||
var request = createRequest(query, input, modelId, topK);
|
||||
var httpRequest = request.createHttpRequest();
|
||||
|
||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
|
||||
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY));
|
||||
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
|
||||
assertThat(requestMap, aMapWithSize(4));
|
||||
assertThat(requestMap.get("documents"), is(List.of(input)));
|
||||
assertThat(requestMap.get("query"), is(query));
|
||||
assertThat(requestMap.get("top_k"), is(topK));
|
||||
assertThat(requestMap.get("model"), is(modelId));
|
||||
}
|
||||
|
||||
public void testCreateRequest_WithModelSet() throws IOException {
|
||||
var input = "input";
|
||||
var query = "query";
|
||||
var modelId = "model";
|
||||
|
||||
var request = createRequest(query, input, modelId, null);
|
||||
var httpRequest = request.createHttpRequest();
|
||||
|
||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
|
||||
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY));
|
||||
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
|
||||
assertThat(requestMap, aMapWithSize(3));
|
||||
assertThat(requestMap.get("documents"), is(List.of(input)));
|
||||
assertThat(requestMap.get("query"), is(query));
|
||||
assertThat(requestMap.get("model"), is(modelId));
|
||||
}
|
||||
|
||||
public void testTruncate_DoesNotTruncate() {
|
||||
var request = createRequest("query", "input", "null", null);
|
||||
var truncatedRequest = request.truncate();
|
||||
|
||||
assertThat(truncatedRequest, sameInstance(request));
|
||||
}
|
||||
|
||||
private static VoyageAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topK) {
|
||||
var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, topK);
|
||||
return new VoyageAIRerankRequest(query, List.of(input), rerankModel);
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.request.voyageai;
|
||||
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class VoyageAIUtilsTests extends ESTestCase {
|
||||
|
||||
public void testCreateRequestSourceHeader() {
|
||||
var requestSourceHeader = VoyageAIUtils.createRequestSourceHeader();
|
||||
|
||||
assertThat(requestSourceHeader.getName(), is("Request-Source"));
|
||||
assertThat(requestSourceHeader.getValue(), is("unspecified:elasticsearch"));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,432 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.response.voyageai;
|
||||
|
||||
import org.apache.http.HttpResponse;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.XContentParseException;
|
||||
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests.createModel;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public class VoyageAIEmbeddingsResponseEntityTests extends ESTestCase {
|
||||
public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": [
|
||||
0.014539449,
|
||||
-0.015288644
|
||||
]
|
||||
}
|
||||
],
|
||||
"model": "voyage-3-large",
|
||||
"usage": {
|
||||
"total_tokens": 8
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
|
||||
List.of("abc", "def"),
|
||||
createModel("url", "api_key", null, "voyage-3-large")
|
||||
);
|
||||
|
||||
InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse(
|
||||
request,
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
|
||||
assertThat(
|
||||
((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(),
|
||||
is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.014539449F, -0.015288644F })))
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": [
|
||||
0.014539449,
|
||||
-0.015288644
|
||||
]
|
||||
},
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 1,
|
||||
"embedding": [
|
||||
0.0123,
|
||||
-0.0123
|
||||
]
|
||||
}
|
||||
],
|
||||
"model": "voyage-3-large",
|
||||
"usage": {
|
||||
"total_tokens": 8
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
|
||||
List.of("abc", "def"),
|
||||
createModel("url", "api_key", null, "voyage-3-large")
|
||||
);
|
||||
|
||||
InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse(
|
||||
request,
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
|
||||
assertThat(
|
||||
((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(),
|
||||
is(
|
||||
List.of(
|
||||
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.014539449F, -0.015288644F }),
|
||||
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123F, -0.0123F })
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromResponse_FailsWhenDataFieldIsNotPresent() {
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"not_data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": [
|
||||
0.014539449,
|
||||
-0.015288644
|
||||
]
|
||||
}
|
||||
],
|
||||
"model": "voyage-3-large",
|
||||
"usage": {
|
||||
"total_tokens": 8
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
|
||||
List.of("abc", "def"),
|
||||
createModel("url", "api_key", null, "voyage-3-large")
|
||||
);
|
||||
|
||||
var thrownException = expectThrows(
|
||||
java.lang.IllegalArgumentException.class,
|
||||
() -> VoyageAIEmbeddingsResponseEntity.fromResponse(
|
||||
request,
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
)
|
||||
);
|
||||
|
||||
assertThat(thrownException.getMessage(), is("Required [data]"));
|
||||
}
|
||||
|
||||
public void testFromResponse_FailsWhenDataFieldNotAnArray() {
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"data": {
|
||||
"test": {
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": [
|
||||
0.014539449,
|
||||
-0.015288644
|
||||
]
|
||||
}
|
||||
},
|
||||
"model": "voyage-3-large",
|
||||
"usage": {
|
||||
"total_tokens": 8
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
|
||||
List.of("abc", "def"),
|
||||
createModel("url", "api_key", null, "voyage-3-large")
|
||||
);
|
||||
|
||||
var thrownException = expectThrows(
|
||||
XContentParseException.class,
|
||||
() -> VoyageAIEmbeddingsResponseEntity.fromResponse(
|
||||
request,
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
)
|
||||
);
|
||||
|
||||
assertThat(thrownException.getMessage(), containsString("[EmbeddingFloatResult] failed to parse field [data]"));
|
||||
}
|
||||
|
||||
public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() {
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embeddingzzz": [
|
||||
0.014539449,
|
||||
-0.015288644
|
||||
]
|
||||
}
|
||||
],
|
||||
"model": "voyage-3-large",
|
||||
"usage": {
|
||||
"total_tokens": 8
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
|
||||
List.of("abc", "def"),
|
||||
createModel("url", "api_key", null, "voyage-3-large")
|
||||
);
|
||||
|
||||
var thrownException = expectThrows(
|
||||
XContentParseException.class,
|
||||
() -> VoyageAIEmbeddingsResponseEntity.fromResponse(
|
||||
request,
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
)
|
||||
);
|
||||
|
||||
assertThat(thrownException.getMessage(), containsString("[EmbeddingFloatResult] failed to parse field [data]"));
|
||||
}
|
||||
|
||||
public void testFromResponse_FailsWhenEmbeddingValueIsAString() {
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": [
|
||||
"abc"
|
||||
]
|
||||
}
|
||||
],
|
||||
"model": "voyage-3-large",
|
||||
"usage": {
|
||||
"total_tokens": 8
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
|
||||
List.of("abc", "def"),
|
||||
createModel("url", "api_key", null, "voyage-3-large")
|
||||
);
|
||||
|
||||
var thrownException = expectThrows(
|
||||
XContentParseException.class,
|
||||
() -> VoyageAIEmbeddingsResponseEntity.fromResponse(
|
||||
request,
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
)
|
||||
);
|
||||
|
||||
assertThat(thrownException.getMessage(), is("[8:15] [EmbeddingFloatResult] failed to parse field [data]"));
|
||||
}
|
||||
|
||||
public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOException {
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": [
|
||||
1
|
||||
]
|
||||
}
|
||||
],
|
||||
"model": "voyage-3-large",
|
||||
"usage": {
|
||||
"total_tokens": 8
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
|
||||
List.of("abc", "def"),
|
||||
createModel("url", "api_key", null, "voyage-3-large")
|
||||
);
|
||||
|
||||
InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse(
|
||||
request,
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
|
||||
assertThat(
|
||||
((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(),
|
||||
is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 1.0F })))
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException {
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": [
|
||||
40294967295
|
||||
]
|
||||
}
|
||||
],
|
||||
"model": "voyage-3-large",
|
||||
"usage": {
|
||||
"total_tokens": 8
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
|
||||
List.of("abc", "def"),
|
||||
createModel("url", "api_key", null, "voyage-3-large")
|
||||
);
|
||||
|
||||
InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse(
|
||||
request,
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
|
||||
assertThat(
|
||||
((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(),
|
||||
is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 4.0294965E10F })))
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() {
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": [
|
||||
{}
|
||||
]
|
||||
}
|
||||
],
|
||||
"model": "voyage-3-large",
|
||||
"usage": {
|
||||
"total_tokens": 8
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
|
||||
List.of("abc", "def"),
|
||||
createModel("url", "api_key", null, "voyage-3-large")
|
||||
);
|
||||
|
||||
var thrownException = expectThrows(
|
||||
XContentParseException.class,
|
||||
() -> VoyageAIEmbeddingsResponseEntity.fromResponse(
|
||||
request,
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
)
|
||||
);
|
||||
|
||||
assertThat(thrownException.getMessage(), is("[8:15] [EmbeddingFloatResult] failed to parse field [data]"));
|
||||
}
|
||||
|
||||
public void testFieldsInDifferentOrderServer() throws IOException {
|
||||
// The fields of the objects in the data array are reordered
|
||||
String response = """
|
||||
{
|
||||
"object": "list",
|
||||
"model": "voyage-3-large",
|
||||
"data": [
|
||||
{
|
||||
"embedding": [
|
||||
-0.9,
|
||||
0.5,
|
||||
0.3
|
||||
],
|
||||
"index": 0,
|
||||
"object": "embedding"
|
||||
},
|
||||
{
|
||||
"index": 0,
|
||||
"embedding": [
|
||||
0.1,
|
||||
0.5
|
||||
],
|
||||
"object": "embedding"
|
||||
},
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": [
|
||||
0.5,
|
||||
0.5
|
||||
]
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"total_tokens": 0
|
||||
}
|
||||
}""";
|
||||
|
||||
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
|
||||
List.of("abc", "def"),
|
||||
createModel("url", "api_key", null, "voyage-3-large")
|
||||
);
|
||||
|
||||
InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse(
|
||||
request,
|
||||
new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
|
||||
assertThat(parsedResults, instanceOf(InferenceTextEmbeddingFloatResults.class));
|
||||
|
||||
assertThat(
|
||||
((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(),
|
||||
is(
|
||||
List.of(
|
||||
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { -0.9F, 0.5F, 0.3F }),
|
||||
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.1F, 0.5F }),
|
||||
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.5F, 0.5F })
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.response.voyageai;
|
||||
|
||||
import org.apache.http.HttpResponse;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public class VoyageAIErrorResponseEntityTests extends ESTestCase {
|
||||
public void testFromResponse() {
|
||||
String message = "\"input\" length 2049 is larger than the largest allowed size 2048";
|
||||
String escapedMessage = message.replace("\\", "\\\\").replace("\"", "\\\"");
|
||||
String responseJson = Strings.format("""
|
||||
{
|
||||
"detail": "%s"
|
||||
}
|
||||
""", escapedMessage);
|
||||
|
||||
ErrorResponse errorResponse = VoyageAIErrorResponseEntity.fromResponse(
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
assertNotNull(errorResponse);
|
||||
MatcherAssert.assertThat(errorResponse.getErrorMessage(), is(message));
|
||||
}
|
||||
|
||||
public void testFromResponse_noMessage() {
|
||||
String responseJson = """
|
||||
{
|
||||
"error": "abc"
|
||||
}
|
||||
""";
|
||||
|
||||
ErrorResponse errorResponse = VoyageAIErrorResponseEntity.fromResponse(
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
MatcherAssert.assertThat(errorResponse, is(ErrorResponse.UNDEFINED_ERROR));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,173 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.response.voyageai;
|
||||
|
||||
import org.apache.http.HttpResponse;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public class VoyageAIRerankResponseEntityTests extends ESTestCase {
|
||||
|
||||
public void testResponseLiteral() throws IOException {
|
||||
String responseLiteral = """
|
||||
{
|
||||
"object": "list",
|
||||
"model": "model",
|
||||
"data": [
|
||||
{
|
||||
"index": 2,
|
||||
"relevance_score": 0.98005307
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"relevance_score": 0.27904198
|
||||
},
|
||||
{
|
||||
"index": 0,
|
||||
"relevance_score": 0.10194652
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"total_tokens": 15
|
||||
}
|
||||
}
|
||||
""";
|
||||
InferenceServiceResults parsedResults = VoyageAIRerankResponseEntity.fromResponse(
|
||||
new HttpResult(mock(HttpResponse.class), responseLiteral.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
|
||||
List<RankedDocsResults.RankedDoc> expected = responseLiteralDocs();
|
||||
for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) {
|
||||
assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index());
|
||||
}
|
||||
}
|
||||
|
||||
public void testGeneratedResponse() throws IOException {
|
||||
int numDocs = randomIntBetween(1, 10);
|
||||
|
||||
List<RankedDocsResults.RankedDoc> expected = new ArrayList<>(numDocs);
|
||||
StringBuilder responseBuilder = new StringBuilder();
|
||||
|
||||
responseBuilder.append("{");
|
||||
responseBuilder.append("\"model\": \"model\",");
|
||||
responseBuilder.append("\"object\": \"list\",");
|
||||
responseBuilder.append("\"data\": [");
|
||||
List<Integer> indices = linear(numDocs);
|
||||
List<Float> scores = linearFloats(numDocs);
|
||||
for (int i = 0; i < numDocs; i++) {
|
||||
int index = indices.remove(randomInt(indices.size() - 1));
|
||||
|
||||
responseBuilder.append("{");
|
||||
responseBuilder.append("\"index\":").append(index).append(",");
|
||||
responseBuilder.append("\"relevance_score\":").append(scores.get(i).toString()).append("}");
|
||||
expected.add(new RankedDocsResults.RankedDoc(index, scores.get(i), null));
|
||||
if (i < numDocs - 1) {
|
||||
responseBuilder.append(",");
|
||||
}
|
||||
}
|
||||
responseBuilder.append("],");
|
||||
responseBuilder.append("\"usage\": {");
|
||||
responseBuilder.append("\"total_tokens\": 15}");
|
||||
responseBuilder.append("}");
|
||||
|
||||
InferenceServiceResults parsedResults = VoyageAIRerankResponseEntity.fromResponse(
|
||||
new HttpResult(mock(HttpResponse.class), responseBuilder.toString().getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
|
||||
for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) {
|
||||
assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index());
|
||||
}
|
||||
}
|
||||
|
||||
private ArrayList<RankedDocsResults.RankedDoc> responseLiteralDocs() {
|
||||
var list = new ArrayList<RankedDocsResults.RankedDoc>();
|
||||
|
||||
list.add(new RankedDocsResults.RankedDoc(2, 0.98005307F, null));
|
||||
list.add(new RankedDocsResults.RankedDoc(3, 0.27904198F, null));
|
||||
list.add(new RankedDocsResults.RankedDoc(0, 0.10194652F, null));
|
||||
return list;
|
||||
}
|
||||
|
||||
public void testResponseLiteralWithDocuments() throws IOException {
|
||||
String responseLiteralWithDocuments = """
|
||||
{
|
||||
"object": "list",
|
||||
"model": "model",
|
||||
"data": [
|
||||
{
|
||||
"document": "Washington, D.C..",
|
||||
"index": 2,
|
||||
"relevance_score": 0.98005307
|
||||
},
|
||||
{
|
||||
"document": "Capital punishment has existed in the United States since beforethe United States was a country. ",
|
||||
"index": 3,
|
||||
"relevance_score": 0.27904198
|
||||
},
|
||||
{
|
||||
"document": "Carson City is the capital city of the American state of Nevada.",
|
||||
"index": 0,
|
||||
"relevance_score": 0.10194652
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"total_tokens": 15
|
||||
}
|
||||
}
|
||||
""";
|
||||
InferenceServiceResults parsedResults = VoyageAIRerankResponseEntity.fromResponse(
|
||||
new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
|
||||
MatcherAssert.assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText));
|
||||
}
|
||||
|
||||
private final List<RankedDocsResults.RankedDoc> responseLiteralDocsWithText = List.of(
|
||||
new RankedDocsResults.RankedDoc(2, 0.98005307F, "Washington, D.C.."),
|
||||
new RankedDocsResults.RankedDoc(
|
||||
3,
|
||||
0.27904198F,
|
||||
"Capital punishment has existed in the United States since beforethe United States was a country. "
|
||||
),
|
||||
new RankedDocsResults.RankedDoc(0, 0.10194652F, "Carson City is the capital city of the American state of Nevada.")
|
||||
);
|
||||
|
||||
private ArrayList<Integer> linear(int n) {
|
||||
ArrayList<Integer> list = new ArrayList<>();
|
||||
for (int i = 0; i <= n; i++) {
|
||||
list.add(i);
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
// creates a list of doubles of monotonically decreasing magnitude
|
||||
private ArrayList<Float> linearFloats(int n) {
|
||||
ArrayList<Float> list = new ArrayList<>();
|
||||
float startValue = 1.0f;
|
||||
float decrement = startValue / n + 1;
|
||||
for (int i = 0; i <= n; i++) {
|
||||
list.add(startValue - (i * decrement));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,138 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.voyageai;
|
||||
|
||||
import org.apache.http.Header;
|
||||
import org.apache.http.HeaderElement;
|
||||
import org.apache.http.HttpResponse;
|
||||
import org.apache.http.StatusLine;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.core.Is.is;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class VoyageAIResponseHandlerTests extends ESTestCase {
|
||||
public void testCheckForFailureStatusCode_DoesNotThrowForStatusCodesBetween200And299() {
|
||||
callCheckForFailureStatusCode(randomIntBetween(200, 299), "id");
|
||||
}
|
||||
|
||||
public void testCheckForFailureStatusCode_ThrowsFor503() {
|
||||
var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(503, "id"));
|
||||
assertFalse(exception.shouldRetry());
|
||||
MatcherAssert.assertThat(
|
||||
exception.getCause().getMessage(),
|
||||
containsString("Received a server error status code for request from inference entity id [id] status [503]")
|
||||
);
|
||||
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST));
|
||||
}
|
||||
|
||||
public void testCheckForFailureStatusCode_ThrowsFor500_WithShouldRetryTrue() {
|
||||
var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(500, "id"));
|
||||
assertTrue(exception.shouldRetry());
|
||||
MatcherAssert.assertThat(
|
||||
exception.getCause().getMessage(),
|
||||
containsString("Received a server error status code for request from inference entity id [id] status [500]")
|
||||
);
|
||||
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST));
|
||||
}
|
||||
|
||||
public void testCheckForFailureStatusCode_ThrowsFor429_WithShouldRetryTrue() {
|
||||
var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(429, "id"));
|
||||
assertTrue(exception.shouldRetry());
|
||||
MatcherAssert.assertThat(
|
||||
exception.getCause().getMessage(),
|
||||
containsString("Received a rate limit status code for request from inference entity id [id] status [429]")
|
||||
);
|
||||
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.TOO_MANY_REQUESTS));
|
||||
}
|
||||
|
||||
public void testCheckForFailureStatusCode_ThrowsFor400() {
|
||||
var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(400, "id"));
|
||||
assertFalse(exception.shouldRetry());
|
||||
MatcherAssert.assertThat(
|
||||
exception.getCause().getMessage(),
|
||||
containsString("Received an input validation error response for request from inference entity id [id] status [400]")
|
||||
);
|
||||
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST));
|
||||
}
|
||||
|
||||
public void testCheckForFailureStatusCode_ThrowsFor400_InputsTooLarge() {
|
||||
var exception = expectThrows(
|
||||
RetryException.class,
|
||||
() -> callCheckForFailureStatusCode(400, "\"input\" length 2049 is larger than the largest allowed size 2048", "id")
|
||||
);
|
||||
assertFalse(exception.shouldRetry());
|
||||
MatcherAssert.assertThat(
|
||||
exception.getCause().getMessage(),
|
||||
containsString("Received an input validation error response for request from inference entity id [id] status [400]")
|
||||
);
|
||||
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST));
|
||||
}
|
||||
|
||||
public void testCheckForFailureStatusCode_ThrowsFor401() {
|
||||
var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(401, "inferenceEntityId"));
|
||||
assertFalse(exception.shouldRetry());
|
||||
MatcherAssert.assertThat(
|
||||
exception.getCause().getMessage(),
|
||||
containsString(
|
||||
"Received an authentication error status code for request from inference entity id [inferenceEntityId] status [401]"
|
||||
)
|
||||
);
|
||||
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.UNAUTHORIZED));
|
||||
}
|
||||
|
||||
public void testCheckForFailureStatusCode_ThrowsFor402() {
|
||||
var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(402, "inferenceEntityId"));
|
||||
assertFalse(exception.shouldRetry());
|
||||
MatcherAssert.assertThat(exception.getCause().getMessage(), containsString("Payment required"));
|
||||
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.PAYMENT_REQUIRED));
|
||||
}
|
||||
|
||||
private static void callCheckForFailureStatusCode(int statusCode, String modelId) {
|
||||
callCheckForFailureStatusCode(statusCode, null, modelId);
|
||||
}
|
||||
|
||||
private static void callCheckForFailureStatusCode(int statusCode, @Nullable String errorMessage, String modelId) {
|
||||
var statusLine = mock(StatusLine.class);
|
||||
when(statusLine.getStatusCode()).thenReturn(statusCode);
|
||||
|
||||
var httpResponse = mock(HttpResponse.class);
|
||||
when(httpResponse.getStatusLine()).thenReturn(statusLine);
|
||||
var header = mock(Header.class);
|
||||
when(header.getElements()).thenReturn(new HeaderElement[] {});
|
||||
when(httpResponse.getFirstHeader(anyString())).thenReturn(header);
|
||||
|
||||
String escapedErrorMessage = errorMessage != null ? errorMessage.replace("\\", "\\\\").replace("\"", "\\\"") : errorMessage;
|
||||
|
||||
String responseJson = Strings.format("""
|
||||
{
|
||||
"detail": "%s"
|
||||
}
|
||||
""", escapedErrorMessage);
|
||||
|
||||
var mockRequest = mock(Request.class);
|
||||
when(mockRequest.getInferenceEntityId()).thenReturn(modelId);
|
||||
var httpResult = new HttpResult(httpResponse, errorMessage == null ? new byte[] {} : responseJson.getBytes(StandardCharsets.UTF_8));
|
||||
var handler = new VoyageAIResponseHandler("", (request, result) -> null);
|
||||
|
||||
handler.checkForFailureStatusCode(mockRequest, httpResult);
|
||||
}
|
||||
}
|
|
@ -146,4 +146,7 @@ public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase<I
|
|||
);
|
||||
}
|
||||
|
||||
public static Map<String, Object> buildExpectationBinary(List<byte[]> embeddings) {
|
||||
return Map.of("text_embedding_bits", embeddings.stream().map(InferenceByteEmbedding::new).toList());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class VoyageAIServiceSettingsTests extends AbstractWireSerializingTestCase<VoyageAIServiceSettings> {
|
||||
|
||||
public static VoyageAIServiceSettings createRandomWithNonNullUrl() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
/**
|
||||
* The created settings can have a url set to null.
|
||||
*/
|
||||
public static VoyageAIServiceSettings createRandom() {
|
||||
var model = randomAlphaOfLength(15);
|
||||
|
||||
return new VoyageAIServiceSettings(model, RateLimitSettingsTests.createRandom());
|
||||
}
|
||||
|
||||
public void testFromMap() {
|
||||
var model = "model";
|
||||
var serviceSettings = VoyageAIServiceSettings.fromMap(
|
||||
new HashMap<>(Map.of(VoyageAIServiceSettings.MODEL_ID, model)),
|
||||
ConfigurationParseContext.REQUEST
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(model, null)));
|
||||
}
|
||||
|
||||
public void testFromMap_WithRateLimit() {
|
||||
var model = "model";
|
||||
var serviceSettings = VoyageAIServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
VoyageAIServiceSettings.MODEL_ID,
|
||||
model,
|
||||
RateLimitSettings.FIELD_NAME,
|
||||
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3))
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.REQUEST
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(model, new RateLimitSettings(3))));
|
||||
}
|
||||
|
||||
public void testFromMap_WhenUsingModelId() {
|
||||
var model = "model";
|
||||
var serviceSettings = VoyageAIServiceSettings.fromMap(
|
||||
new HashMap<>(Map.of(VoyageAIServiceSettings.MODEL_ID, model)),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(model, null)));
|
||||
}
|
||||
|
||||
public void testXContent_WritesModelId() throws IOException {
|
||||
var entity = new VoyageAIServiceSettings("model", new RateLimitSettings(1));
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
assertThat(xContentResult, is("""
|
||||
{"model_id":"model","rate_limit":{"requests_per_minute":1}}"""));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<VoyageAIServiceSettings> instanceReader() {
|
||||
return VoyageAIServiceSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected VoyageAIServiceSettings createTestInstance() {
|
||||
return createRandomWithNonNullUrl();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected VoyageAIServiceSettings mutateInstance(VoyageAIServiceSettings instance) throws IOException {
|
||||
return randomValueOtherThan(instance, VoyageAIServiceSettingsTests::createRandom);
|
||||
}
|
||||
|
||||
public static Map<String, Object> getServiceSettingsMap(String model) {
|
||||
var map = new HashMap<String, Object>();
|
||||
|
||||
map.put(VoyageAIServiceSettings.MODEL_ID, model);
|
||||
|
||||
return map;
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,209 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai.embeddings;
|
||||
|
||||
import org.elasticsearch.common.settings.SecureString;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ChunkingSettings;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.inference.SimilarityMeasure;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class VoyageAIEmbeddingsModelTests extends ESTestCase {
|
||||
|
||||
public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEmpty_AndInputTypeIsInvalid() {
|
||||
var model = createModel("url", "api_key", null, null, "model");
|
||||
|
||||
var overriddenModel = VoyageAIEmbeddingsModel.of(model, Map.of(), InputType.UNSPECIFIED);
|
||||
MatcherAssert.assertThat(overriddenModel, is(model));
|
||||
}
|
||||
|
||||
public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull_AndInputTypeIsInvalid() {
|
||||
var model = createModel("url", "api_key", null, null, "model");
|
||||
|
||||
var overriddenModel = VoyageAIEmbeddingsModel.of(model, null, InputType.UNSPECIFIED);
|
||||
MatcherAssert.assertThat(overriddenModel, is(model));
|
||||
}
|
||||
|
||||
public void testOverrideWith_SetsInputTypeToIngest_WhenTheFieldIsNullInModelTaskSettings_AndNullInRequestTaskSettings() {
|
||||
var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model");
|
||||
|
||||
var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.INGEST);
|
||||
var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model");
|
||||
MatcherAssert.assertThat(overriddenModel, is(expectedModel));
|
||||
}
|
||||
|
||||
public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingStoredTaskSettings() {
|
||||
var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model");
|
||||
|
||||
var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.SEARCH);
|
||||
var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null), null, null, "model");
|
||||
MatcherAssert.assertThat(overriddenModel, is(expectedModel));
|
||||
}
|
||||
|
||||
public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingRequestTaskSettings() {
|
||||
var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model");
|
||||
|
||||
var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.INGEST), InputType.SEARCH);
|
||||
var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null), null, null, "model");
|
||||
MatcherAssert.assertThat(overriddenModel, is(expectedModel));
|
||||
}
|
||||
|
||||
public void testOverrideWith_OverridesInputType_WithRequestTaskSettingsSearch_WhenRequestInputTypeIsInvalid() {
|
||||
var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model");
|
||||
|
||||
var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.SEARCH), InputType.UNSPECIFIED);
|
||||
var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null), null, null, "model");
|
||||
MatcherAssert.assertThat(overriddenModel, is(expectedModel));
|
||||
}
|
||||
|
||||
public void testOverrideWith_DoesNotSetInputType_FromRequest_IfInputTypeIsInvalid() {
|
||||
var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model");
|
||||
|
||||
var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED);
|
||||
var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model");
|
||||
MatcherAssert.assertThat(overriddenModel, is(expectedModel));
|
||||
}
|
||||
|
||||
public void testOverrideWith_DoesNotSetInputType_WhenRequestTaskSettingsIsNull_AndRequestInputTypeIsInvalid() {
|
||||
var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model");
|
||||
|
||||
var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED);
|
||||
var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model");
|
||||
MatcherAssert.assertThat(overriddenModel, is(expectedModel));
|
||||
}
|
||||
|
||||
public static VoyageAIEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit, @Nullable String model) {
|
||||
return createModel(url, apiKey, VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, null, model);
|
||||
}
|
||||
|
||||
public static VoyageAIEmbeddingsModel createModel(
|
||||
String url,
|
||||
String apiKey,
|
||||
@Nullable Integer tokenLimit,
|
||||
@Nullable Integer dimensions,
|
||||
String model
|
||||
) {
|
||||
return createModel(url, apiKey, VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, dimensions, model);
|
||||
}
|
||||
|
||||
public static VoyageAIEmbeddingsModel createModel(
|
||||
String url,
|
||||
String apiKey,
|
||||
VoyageAIEmbeddingsTaskSettings taskSettings,
|
||||
ChunkingSettings chunkingSettings,
|
||||
@Nullable Integer tokenLimit,
|
||||
@Nullable Integer dimensions,
|
||||
String model
|
||||
) {
|
||||
return new VoyageAIEmbeddingsModel(
|
||||
"id",
|
||||
"service",
|
||||
url,
|
||||
new VoyageAIEmbeddingsServiceSettings(
|
||||
new VoyageAIServiceSettings(model, null),
|
||||
VoyageAIEmbeddingType.FLOAT,
|
||||
SimilarityMeasure.DOT_PRODUCT,
|
||||
dimensions,
|
||||
tokenLimit,
|
||||
false
|
||||
),
|
||||
taskSettings,
|
||||
chunkingSettings,
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
}
|
||||
|
||||
public static VoyageAIEmbeddingsModel createModel(
|
||||
String url,
|
||||
String apiKey,
|
||||
VoyageAIEmbeddingsTaskSettings taskSettings,
|
||||
@Nullable Integer tokenLimit,
|
||||
@Nullable Integer dimensions,
|
||||
String model
|
||||
) {
|
||||
return new VoyageAIEmbeddingsModel(
|
||||
"id",
|
||||
"service",
|
||||
url,
|
||||
new VoyageAIEmbeddingsServiceSettings(
|
||||
new VoyageAIServiceSettings(model, null),
|
||||
VoyageAIEmbeddingType.FLOAT,
|
||||
SimilarityMeasure.DOT_PRODUCT,
|
||||
dimensions,
|
||||
tokenLimit,
|
||||
false
|
||||
),
|
||||
taskSettings,
|
||||
null,
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
}
|
||||
|
||||
public static VoyageAIEmbeddingsModel createModel(
|
||||
String url,
|
||||
String apiKey,
|
||||
VoyageAIEmbeddingsTaskSettings taskSettings,
|
||||
@Nullable Integer tokenLimit,
|
||||
@Nullable Integer dimensions,
|
||||
String model,
|
||||
VoyageAIEmbeddingType embeddingType
|
||||
) {
|
||||
return new VoyageAIEmbeddingsModel(
|
||||
"id",
|
||||
"service",
|
||||
url,
|
||||
new VoyageAIEmbeddingsServiceSettings(
|
||||
new VoyageAIServiceSettings(model, null),
|
||||
embeddingType,
|
||||
SimilarityMeasure.DOT_PRODUCT,
|
||||
dimensions,
|
||||
tokenLimit,
|
||||
false
|
||||
),
|
||||
taskSettings,
|
||||
null,
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
}
|
||||
|
||||
public static VoyageAIEmbeddingsModel createModel(
|
||||
String url,
|
||||
String apiKey,
|
||||
VoyageAIEmbeddingsTaskSettings taskSettings,
|
||||
@Nullable Integer tokenLimit,
|
||||
@Nullable Integer dimensions,
|
||||
String model,
|
||||
@Nullable SimilarityMeasure similarityMeasure
|
||||
) {
|
||||
return new VoyageAIEmbeddingsModel(
|
||||
"id",
|
||||
"service",
|
||||
url,
|
||||
new VoyageAIEmbeddingsServiceSettings(
|
||||
new VoyageAIServiceSettings(model, null),
|
||||
VoyageAIEmbeddingType.FLOAT,
|
||||
similarityMeasure,
|
||||
dimensions,
|
||||
tokenLimit,
|
||||
false
|
||||
),
|
||||
taskSettings,
|
||||
null,
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,327 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai.embeddings;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.inference.SimilarityMeasure;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceFields;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettingsTests;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class VoyageAIEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase<VoyageAIEmbeddingsServiceSettings> {
|
||||
public static VoyageAIEmbeddingsServiceSettings createRandom() {
|
||||
SimilarityMeasure similarityMeasure = SimilarityMeasure.DOT_PRODUCT;
|
||||
Integer dims = 1024;
|
||||
Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256);
|
||||
Boolean dimensionSetByUser = randomBoolean();
|
||||
|
||||
var commonSettings = VoyageAIServiceSettingsTests.createRandom();
|
||||
|
||||
return new VoyageAIEmbeddingsServiceSettings(
|
||||
commonSettings,
|
||||
VoyageAIEmbeddingType.FLOAT,
|
||||
similarityMeasure,
|
||||
dims,
|
||||
maxInputTokens,
|
||||
dimensionSetByUser
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap() {
|
||||
var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
|
||||
var dims = 1536;
|
||||
var maxInputTokens = 512;
|
||||
var model = "model";
|
||||
var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
ServiceFields.SIMILARITY,
|
||||
similarity,
|
||||
ServiceFields.DIMENSIONS,
|
||||
dims,
|
||||
ServiceFields.MAX_INPUT_TOKENS,
|
||||
maxInputTokens,
|
||||
VoyageAIServiceSettings.MODEL_ID,
|
||||
model
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
serviceSettings,
|
||||
is(
|
||||
new VoyageAIEmbeddingsServiceSettings(
|
||||
new VoyageAIServiceSettings(model, null),
|
||||
VoyageAIEmbeddingType.FLOAT,
|
||||
SimilarityMeasure.DOT_PRODUCT,
|
||||
dims,
|
||||
maxInputTokens,
|
||||
false
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_WithModelId() {
|
||||
var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
|
||||
var maxInputTokens = 512;
|
||||
var model = "model";
|
||||
var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
ServiceFields.SIMILARITY,
|
||||
similarity,
|
||||
ServiceFields.MAX_INPUT_TOKENS,
|
||||
maxInputTokens,
|
||||
VoyageAIServiceSettings.MODEL_ID,
|
||||
model
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.REQUEST
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
serviceSettings,
|
||||
is(
|
||||
new VoyageAIEmbeddingsServiceSettings(
|
||||
new VoyageAIServiceSettings(model, null),
|
||||
VoyageAIEmbeddingType.FLOAT,
|
||||
SimilarityMeasure.DOT_PRODUCT,
|
||||
null,
|
||||
maxInputTokens,
|
||||
false
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_WithModelId_WithDimensions() {
|
||||
var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
|
||||
var dims = 1536;
|
||||
var maxInputTokens = 512;
|
||||
var model = "model";
|
||||
var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
ServiceFields.SIMILARITY,
|
||||
similarity,
|
||||
ServiceFields.DIMENSIONS,
|
||||
dims,
|
||||
ServiceFields.MAX_INPUT_TOKENS,
|
||||
maxInputTokens,
|
||||
VoyageAIServiceSettings.MODEL_ID,
|
||||
model
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.REQUEST
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
serviceSettings,
|
||||
is(
|
||||
new VoyageAIEmbeddingsServiceSettings(
|
||||
new VoyageAIServiceSettings(model, null),
|
||||
VoyageAIEmbeddingType.FLOAT,
|
||||
SimilarityMeasure.DOT_PRODUCT,
|
||||
dims,
|
||||
maxInputTokens,
|
||||
true
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_DimensionsSetByUserIsFalseInRequestContext() {
|
||||
var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
|
||||
var maxInputTokens = 512;
|
||||
var model = "model";
|
||||
var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
ServiceFields.SIMILARITY,
|
||||
similarity,
|
||||
DIMENSIONS_SET_BY_USER,
|
||||
true,
|
||||
ServiceFields.MAX_INPUT_TOKENS,
|
||||
maxInputTokens,
|
||||
VoyageAIServiceSettings.MODEL_ID,
|
||||
model
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.REQUEST
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
serviceSettings,
|
||||
is(
|
||||
new VoyageAIEmbeddingsServiceSettings(
|
||||
new VoyageAIServiceSettings(model, null),
|
||||
VoyageAIEmbeddingType.FLOAT,
|
||||
SimilarityMeasure.DOT_PRODUCT,
|
||||
null,
|
||||
maxInputTokens,
|
||||
false
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_DimensionsSetByUserIsSetInPersistentContext() {
|
||||
var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
|
||||
var maxInputTokens = 512;
|
||||
var model = "model";
|
||||
var dimensionsSetByUser = randomBoolean();
|
||||
var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(
|
||||
ServiceFields.SIMILARITY,
|
||||
similarity,
|
||||
DIMENSIONS_SET_BY_USER,
|
||||
dimensionsSetByUser,
|
||||
ServiceFields.MAX_INPUT_TOKENS,
|
||||
maxInputTokens,
|
||||
VoyageAIServiceSettings.MODEL_ID,
|
||||
model
|
||||
)
|
||||
),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
serviceSettings,
|
||||
is(
|
||||
new VoyageAIEmbeddingsServiceSettings(
|
||||
new VoyageAIServiceSettings(model, null),
|
||||
VoyageAIEmbeddingType.FLOAT,
|
||||
SimilarityMeasure.DOT_PRODUCT,
|
||||
null,
|
||||
maxInputTokens,
|
||||
dimensionsSetByUser
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_InvalidSimilarity_ThrowsError() {
|
||||
var similarity = "by_size";
|
||||
var thrownException = expectThrows(
|
||||
ValidationException.class,
|
||||
() -> VoyageAIEmbeddingsServiceSettings.fromMap(
|
||||
new HashMap<>(Map.of(VoyageAIServiceSettings.MODEL_ID, "model", ServiceFields.SIMILARITY, similarity)),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
)
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
thrownException.getMessage(),
|
||||
is(
|
||||
"Validation Failed: 1: [service_settings] Invalid value [by_size] received. [similarity] "
|
||||
+ "must be one of [cosine, dot_product, l2_norm];"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@SuppressWarnings("checkstyle:LineLength")
|
||||
public void testToXContent_WritesAllValues() throws IOException {
|
||||
var serviceSettings = new VoyageAIEmbeddingsServiceSettings(
|
||||
new VoyageAIServiceSettings("model", new RateLimitSettings(3)),
|
||||
VoyageAIEmbeddingType.FLOAT,
|
||||
SimilarityMeasure.COSINE,
|
||||
5,
|
||||
10,
|
||||
false
|
||||
);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
serviceSettings.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
assertThat(
|
||||
xContentResult,
|
||||
is(
|
||||
"""
|
||||
{"model_id":"model","""
|
||||
+ """
|
||||
"rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10,"embedding_type":"float"}"""
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@SuppressWarnings("checkstyle:LineLength")
|
||||
public void testToXContent_WritesAllValues_DimensionSetByUser() throws IOException {
|
||||
var serviceSettings = new VoyageAIEmbeddingsServiceSettings(
|
||||
new VoyageAIServiceSettings("model", new RateLimitSettings(3)),
|
||||
VoyageAIEmbeddingType.FLOAT,
|
||||
SimilarityMeasure.COSINE,
|
||||
5,
|
||||
10,
|
||||
true
|
||||
);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
serviceSettings.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
assertThat(
|
||||
xContentResult,
|
||||
is(
|
||||
"""
|
||||
{"model_id":"model","""
|
||||
+ """
|
||||
"rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10,"embedding_type":"float"}"""
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<VoyageAIEmbeddingsServiceSettings> instanceReader() {
|
||||
return VoyageAIEmbeddingsServiceSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected VoyageAIEmbeddingsServiceSettings createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected VoyageAIEmbeddingsServiceSettings mutateInstance(VoyageAIEmbeddingsServiceSettings instance) throws IOException {
|
||||
return randomValueOtherThan(instance, VoyageAIEmbeddingsServiceSettingsTests::createRandom);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
|
||||
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
|
||||
entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables());
|
||||
return new NamedWriteableRegistry(entries);
|
||||
}
|
||||
|
||||
public static Map<String, Object> getServiceSettingsMap(String model) {
|
||||
return new HashMap<>(VoyageAIServiceSettingsTests.getServiceSettingsMap(model));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,218 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai.embeddings;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.EnumSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithIngestAndSearch;
|
||||
import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings.VALID_REQUEST_VALUES;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class VoyageAIEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase<VoyageAIEmbeddingsTaskSettings> {
|
||||
|
||||
public static VoyageAIEmbeddingsTaskSettings createRandom() {
|
||||
var inputType = randomBoolean() ? randomWithIngestAndSearch() : null;
|
||||
var truncation = randomBoolean();
|
||||
|
||||
return new VoyageAIEmbeddingsTaskSettings(inputType, truncation);
|
||||
}
|
||||
|
||||
public void testIsEmpty() {
|
||||
var randomSettings = createRandom();
|
||||
var stringRep = Strings.toString(randomSettings);
|
||||
assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}"));
|
||||
}
|
||||
|
||||
public void testUpdatedTaskSettings_NotUpdated_UseInitialSettings() {
|
||||
var initialSettings = createRandom();
|
||||
var newSettings = new VoyageAIEmbeddingsTaskSettings((InputType) null, null);
|
||||
Map<String, Object> newSettingsMap = new HashMap<>();
|
||||
VoyageAIEmbeddingsTaskSettings updatedSettings = (VoyageAIEmbeddingsTaskSettings) initialSettings.updatedTaskSettings(
|
||||
Collections.unmodifiableMap(newSettingsMap)
|
||||
);
|
||||
assertEquals(initialSettings.getInputType(), updatedSettings.getInputType());
|
||||
}
|
||||
|
||||
public void testUpdatedTaskSettings_Updated_UseNewSettings() {
|
||||
var initialSettings = createRandom();
|
||||
var newSettings = new VoyageAIEmbeddingsTaskSettings(randomWithIngestAndSearch(), randomBoolean());
|
||||
Map<String, Object> newSettingsMap = new HashMap<>();
|
||||
newSettingsMap.put(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, newSettings.getInputType().toString());
|
||||
VoyageAIEmbeddingsTaskSettings updatedSettings = (VoyageAIEmbeddingsTaskSettings) initialSettings.updatedTaskSettings(
|
||||
Collections.unmodifiableMap(newSettingsMap)
|
||||
);
|
||||
assertEquals(newSettings.getInputType(), updatedSettings.getInputType());
|
||||
}
|
||||
|
||||
public void testFromMap_CreatesEmptySettings_WhenAllFieldsAreNull() {
|
||||
MatcherAssert.assertThat(
|
||||
VoyageAIEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of())),
|
||||
is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_CreatesEmptySettings_WhenMapIsNull() {
|
||||
MatcherAssert.assertThat(
|
||||
VoyageAIEmbeddingsTaskSettings.fromMap(null),
|
||||
is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() {
|
||||
MatcherAssert.assertThat(
|
||||
VoyageAIEmbeddingsTaskSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString(), VoyageAIServiceFields.TRUNCATION, false)
|
||||
)
|
||||
),
|
||||
is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, false))
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() {
|
||||
var exception = expectThrows(
|
||||
ValidationException.class,
|
||||
() -> VoyageAIEmbeddingsTaskSettings.fromMap(
|
||||
new HashMap<>(Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, "abc", VoyageAIServiceFields.TRUNCATION, false))
|
||||
)
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
exception.getMessage(),
|
||||
is(
|
||||
Strings.format(
|
||||
"Validation Failed: 1: [task_settings] Invalid value [abc] received. [input_type] must be one of [%s];",
|
||||
getValidValuesSortedAndCombined(VALID_REQUEST_VALUES)
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_ReturnsFailure_WhenTruncationIsInvalid() {
|
||||
var exception = expectThrows(
|
||||
ValidationException.class,
|
||||
() -> VoyageAIEmbeddingsTaskSettings.fromMap(
|
||||
new HashMap<>(
|
||||
Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString(), VoyageAIServiceFields.TRUNCATION, "abc")
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
exception.getMessage(),
|
||||
is("Validation Failed: 1: field [truncation] is not of the expected type. The value [abc] cannot be converted to a [Boolean];")
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_ReturnsFailure_WhenInputTypeIsUnspecified() {
|
||||
var exception = expectThrows(
|
||||
ValidationException.class,
|
||||
() -> VoyageAIEmbeddingsTaskSettings.fromMap(
|
||||
new HashMap<>(Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.UNSPECIFIED.toString()))
|
||||
)
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
exception.getMessage(),
|
||||
is(
|
||||
Strings.format(
|
||||
"Validation Failed: 1: [task_settings] Invalid value [unspecified] received. [input_type] must be one of [%s];",
|
||||
getValidValuesSortedAndCombined(VALID_REQUEST_VALUES)
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static <E extends Enum<E>> String getValidValuesSortedAndCombined(EnumSet<E> validValues) {
|
||||
var validValuesAsStrings = validValues.stream().map(value -> value.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new);
|
||||
Arrays.sort(validValuesAsStrings);
|
||||
|
||||
return String.join(", ", validValuesAsStrings);
|
||||
}
|
||||
|
||||
public void testXContent_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() {
|
||||
var thrownException = expectThrows(AssertionError.class, () -> new VoyageAIEmbeddingsTaskSettings(InputType.UNSPECIFIED, null));
|
||||
MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]"));
|
||||
}
|
||||
|
||||
public void testOf_KeepsOriginalValuesWhenRequestSettingsAreNull_AndRequestInputTypeIsInvalid() {
|
||||
var taskSettings = new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, false);
|
||||
var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettings.of(
|
||||
taskSettings,
|
||||
VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
|
||||
InputType.UNSPECIFIED
|
||||
);
|
||||
MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings));
|
||||
}
|
||||
|
||||
public void testOf_UsesRequestTaskSettings() {
|
||||
var taskSettings = new VoyageAIEmbeddingsTaskSettings((InputType) null, null);
|
||||
var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettings.of(
|
||||
taskSettings,
|
||||
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true),
|
||||
InputType.UNSPECIFIED
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(overriddenTaskSettings, is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true)));
|
||||
}
|
||||
|
||||
public void testOf_UsesRequestTaskSettings_AndRequestInputType() {
|
||||
var taskSettings = new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, true);
|
||||
var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettings.of(
|
||||
taskSettings,
|
||||
new VoyageAIEmbeddingsTaskSettings((InputType) null, null),
|
||||
InputType.INGEST
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(overriddenTaskSettings, is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true)));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<VoyageAIEmbeddingsTaskSettings> instanceReader() {
|
||||
return VoyageAIEmbeddingsTaskSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected VoyageAIEmbeddingsTaskSettings createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected VoyageAIEmbeddingsTaskSettings mutateInstance(VoyageAIEmbeddingsTaskSettings instance) throws IOException {
|
||||
return randomValueOtherThan(instance, VoyageAIEmbeddingsTaskSettingsTests::createRandom);
|
||||
}
|
||||
|
||||
public static Map<String, Object> getTaskSettingsMapEmpty() {
|
||||
return new HashMap<>();
|
||||
}
|
||||
|
||||
public static Map<String, Object> getTaskSettingsMap(@Nullable InputType inputType) {
|
||||
var map = new HashMap<String, Object>();
|
||||
|
||||
if (inputType != null) {
|
||||
map.put(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, inputType.toString());
|
||||
}
|
||||
|
||||
return map;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai.rerank;
|
||||
|
||||
import org.elasticsearch.common.settings.SecureString;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
|
||||
|
||||
public class VoyageAIRerankModelTests {
|
||||
public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nullable Integer topK, @Nullable Boolean truncation) {
|
||||
return new VoyageAIRerankModel(
|
||||
"id",
|
||||
"service",
|
||||
ESTestCase.randomAlphaOfLength(10),
|
||||
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)),
|
||||
new VoyageAIRerankTaskSettings(topK, null, truncation),
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
}
|
||||
|
||||
public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nullable Integer topK) {
|
||||
return new VoyageAIRerankModel(
|
||||
"id",
|
||||
"service",
|
||||
ESTestCase.randomAlphaOfLength(10),
|
||||
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)),
|
||||
new VoyageAIRerankTaskSettings(topK, null, null),
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
}
|
||||
|
||||
public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topK) {
|
||||
return new VoyageAIRerankModel(
|
||||
"id",
|
||||
"service",
|
||||
ESTestCase.randomAlphaOfLength(10),
|
||||
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)),
|
||||
new VoyageAIRerankTaskSettings(topK, null, null),
|
||||
new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8))
|
||||
);
|
||||
}
|
||||
|
||||
public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topK, Boolean returnDocuments, Boolean truncation) {
|
||||
return new VoyageAIRerankModel(
|
||||
"id",
|
||||
"service",
|
||||
ESTestCase.randomAlphaOfLength(10),
|
||||
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)),
|
||||
new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation),
|
||||
new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8))
|
||||
);
|
||||
}
|
||||
|
||||
public static VoyageAIRerankModel createModel(
|
||||
String url,
|
||||
String modelId,
|
||||
@Nullable Integer topK,
|
||||
Boolean returnDocuments,
|
||||
Boolean truncation
|
||||
) {
|
||||
return new VoyageAIRerankModel(
|
||||
"id",
|
||||
"service",
|
||||
url,
|
||||
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)),
|
||||
new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation),
|
||||
new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8))
|
||||
);
|
||||
}
|
||||
|
||||
public static VoyageAIRerankModel createModel(
|
||||
String url,
|
||||
String apiKey,
|
||||
String modelId,
|
||||
@Nullable Integer topK,
|
||||
Boolean returnDocuments,
|
||||
Boolean truncation
|
||||
) {
|
||||
return new VoyageAIRerankModel(
|
||||
"id",
|
||||
"service",
|
||||
url,
|
||||
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)),
|
||||
new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation),
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai.rerank;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettingsTests;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
|
||||
|
||||
public class VoyageAIRerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase<VoyageAIRerankServiceSettings> {
|
||||
public static VoyageAIRerankServiceSettings createRandom() {
|
||||
return new VoyageAIRerankServiceSettings(
|
||||
new VoyageAIServiceSettings(randomAlphaOfLength(10), RateLimitSettingsTests.createRandom())
|
||||
);
|
||||
}
|
||||
|
||||
public void testToXContent_WritesAllValues() throws IOException {
|
||||
var url = "http://www.abc.com";
|
||||
var model = "model";
|
||||
|
||||
var serviceSettings = new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(model, null));
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
serviceSettings.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
||||
{
|
||||
"model_id":"model",
|
||||
"rate_limit": {
|
||||
"requests_per_minute": 2000
|
||||
}
|
||||
}
|
||||
"""));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<VoyageAIRerankServiceSettings> instanceReader() {
|
||||
return VoyageAIRerankServiceSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected VoyageAIRerankServiceSettings createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected VoyageAIRerankServiceSettings mutateInstance(VoyageAIRerankServiceSettings instance) throws IOException {
|
||||
return randomValueOtherThan(instance, VoyageAIRerankServiceSettingsTests::createRandom);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected VoyageAIRerankServiceSettings mutateInstanceForVersion(VoyageAIRerankServiceSettings instance, TransportVersion version) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
public static Map<String, Object> getServiceSettingsMap(@Nullable String model) {
|
||||
return new HashMap<>(VoyageAIServiceSettingsTests.getServiceSettingsMap(model));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,162 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.voyageai.rerank;
|
||||
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
|
||||
public class VoyageAIRerankTaskSettingsTests extends AbstractWireSerializingTestCase<VoyageAIRerankTaskSettings> {
|
||||
|
||||
public static VoyageAIRerankTaskSettings createRandom() {
|
||||
var returnDocuments = randomBoolean() ? randomBoolean() : null;
|
||||
var topNDocsOnly = randomBoolean() ? randomIntBetween(1, 10) : null;
|
||||
var truncation = randomBoolean() ? randomBoolean() : null;
|
||||
|
||||
return new VoyageAIRerankTaskSettings(topNDocsOnly, returnDocuments, truncation);
|
||||
}
|
||||
|
||||
public void testFromMap_WithInvalidTruncation_ThrowsValidationException() {
|
||||
Map<String, Object> taskMap = Map.of(
|
||||
VoyageAIRerankTaskSettings.RETURN_DOCUMENTS,
|
||||
true,
|
||||
VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY,
|
||||
5,
|
||||
VoyageAIServiceFields.TRUNCATION,
|
||||
"invalid"
|
||||
);
|
||||
var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap)));
|
||||
assertThat(thrownException.getMessage(), containsString("field [truncation] is not of the expected type"));
|
||||
}
|
||||
|
||||
public void testFromMap_WithValidValues_ReturnsSettings() {
|
||||
Map<String, Object> taskMap = Map.of(
|
||||
VoyageAIRerankTaskSettings.RETURN_DOCUMENTS,
|
||||
true,
|
||||
VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY,
|
||||
5,
|
||||
VoyageAIServiceFields.TRUNCATION,
|
||||
true
|
||||
);
|
||||
var settings = VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap));
|
||||
assertTrue(settings.getReturnDocuments());
|
||||
assertEquals(5, settings.getTopKDocumentsOnly().intValue());
|
||||
assertTrue(settings.getTruncation());
|
||||
}
|
||||
|
||||
public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() {
|
||||
var settings = VoyageAIRerankTaskSettings.fromMap(Map.of());
|
||||
assertNull(settings.getReturnDocuments());
|
||||
assertNull(settings.getTopKDocumentsOnly());
|
||||
assertNull(settings.getTruncation());
|
||||
}
|
||||
|
||||
public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() {
|
||||
Map<String, Object> taskMap = Map.of(
|
||||
VoyageAIRerankTaskSettings.RETURN_DOCUMENTS,
|
||||
"invalid",
|
||||
VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY,
|
||||
5
|
||||
);
|
||||
var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap)));
|
||||
assertThat(thrownException.getMessage(), containsString("field [return_documents] is not of the expected type"));
|
||||
}
|
||||
|
||||
public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() {
|
||||
Map<String, Object> taskMap = Map.of(
|
||||
VoyageAIRerankTaskSettings.RETURN_DOCUMENTS,
|
||||
true,
|
||||
VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY,
|
||||
"invalid"
|
||||
);
|
||||
var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap)));
|
||||
assertThat(
|
||||
thrownException.getMessage(),
|
||||
containsString("field [top_k] is not of the expected type. The value [invalid] cannot be converted to a [Integer];")
|
||||
);
|
||||
}
|
||||
|
||||
public void testUpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() {
|
||||
var initialSettings = new VoyageAIRerankTaskSettings(5, true, true);
|
||||
VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of());
|
||||
assertEquals(initialSettings, updatedSettings);
|
||||
}
|
||||
|
||||
public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() {
|
||||
var initialSettings = new VoyageAIRerankTaskSettings(5, true, true);
|
||||
Map<String, Object> newSettings = Map.of(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, false);
|
||||
VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
|
||||
assertFalse(updatedSettings.getReturnDocuments());
|
||||
assertTrue(updatedSettings.getTruncation());
|
||||
assertEquals(initialSettings.getTopKDocumentsOnly(), updatedSettings.getTopKDocumentsOnly());
|
||||
}
|
||||
|
||||
public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() {
|
||||
var initialSettings = new VoyageAIRerankTaskSettings(5, true, true);
|
||||
Map<String, Object> newSettings = Map.of(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, 7);
|
||||
VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
|
||||
assertTrue(updatedSettings.getTruncation());
|
||||
assertEquals(7, updatedSettings.getTopKDocumentsOnly().intValue());
|
||||
assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments());
|
||||
}
|
||||
|
||||
public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() {
|
||||
var initialSettings = new VoyageAIRerankTaskSettings(5, true, true);
|
||||
Map<String, Object> newSettings = Map.of(
|
||||
VoyageAIRerankTaskSettings.RETURN_DOCUMENTS,
|
||||
false,
|
||||
VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY,
|
||||
7
|
||||
);
|
||||
VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
|
||||
assertTrue(updatedSettings.getTruncation());
|
||||
assertFalse(updatedSettings.getReturnDocuments());
|
||||
assertEquals(7, updatedSettings.getTopKDocumentsOnly().intValue());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<VoyageAIRerankTaskSettings> instanceReader() {
|
||||
return VoyageAIRerankTaskSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected VoyageAIRerankTaskSettings createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected VoyageAIRerankTaskSettings mutateInstance(VoyageAIRerankTaskSettings instance) throws IOException {
|
||||
return randomValueOtherThan(instance, VoyageAIRerankTaskSettingsTests::createRandom);
|
||||
}
|
||||
|
||||
public static Map<String, Object> getTaskSettingsMapEmpty() {
|
||||
return new HashMap<>();
|
||||
}
|
||||
|
||||
public static Map<String, Object> getTaskSettingsMap(@Nullable Integer topNDocumentsOnly, Boolean returnDocuments) {
|
||||
var map = new HashMap<String, Object>();
|
||||
|
||||
if (topNDocumentsOnly != null) {
|
||||
map.put(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, topNDocumentsOnly.toString());
|
||||
}
|
||||
|
||||
if (returnDocuments != null) {
|
||||
map.put(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments.toString());
|
||||
}
|
||||
|
||||
return map;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue