feat: VoyageAI integration (#122134)

* VoyageAI embeddings and rerank:
 - embeddings works, tested
 - initial rerank code

What's missing:
 - unit and integration tests
 - rerank request/response mapping and verification

* VoyageAI embeddings and rerank:
 - embeddings works, tested
 - rerank works, tested (https://www.elastic.co/search-labs/blog/elasticsearch-cohere-rerank)

What's missing:
 - unit and integration tests

* VoyageAI embeddings and rerank:
 - embeddings works, tested
 - rerank works, tested (https://www.elastic.co/search-labs/blog/elasticsearch-cohere-rerank)

What's missing:
 - unit and integration tests

* VoyageAI embeddings and rerank:
 - embeddings works, tested
 - rerank works, tested (https://www.elastic.co/search-labs/blog/elasticsearch-cohere-rerank)

What's missing:
 - unit and integration tests

* Adding initial tests
Moving dimensions to ServiceSettings

* Correcting the TransportVersions.java

* Correcting due to comments

* Adding BIT support

* Initial tests

* More tests

* More tests/corrections

* Removing warnings

* Further tests

* Transport version correction

* Adding changelog and correcting TransportVersions

* Spotless tests

* Changes due to the comments

* Changes due to the comments

* Correcting QA tests

* Correcting QA tests

---------

Co-authored-by: Jonathan Buttner <jonathan.buttner@elastic.co>
Co-authored-by: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com>
This commit is contained in:
fzowl 2025-02-20 22:11:58 +01:00 committed by GitHub
parent 171a3b93f9
commit 521f8554c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
54 changed files with 8140 additions and 5 deletions

View File

@ -0,0 +1,5 @@
pr: 122134
summary: Adding integration for VoyageAI embeddings and rerank models
area: Machine Learning
type: enhancement
issues: []

View File

@ -180,6 +180,7 @@ public class TransportVersions {
public static final TransportVersion REMOVE_ALL_APPLICABLE_SELECTOR_BACKPORT_8_19 = def(8_841_0_02);
public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE_BACKPORT_8_19 = def(8_841_0_03);
public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS_BACKPORT_8_19 = def(8_841_0_04);
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@ -199,7 +200,7 @@ public class TransportVersions {
public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS = def(9_011_0_00);
public static final TransportVersion REMOVE_REPOSITORY_CONFLICT_MESSAGE = def(9_012_0_00);
public static final TransportVersion RERANKER_FAILURES_ALLOWED = def(9_013_0_00);
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED = def(9_014_0_00);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

View File

@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
@SuppressWarnings("unchecked")
public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(19));
assertThat(services.size(), equalTo(20));
String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
@ -53,6 +53,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
"test_reranking_service",
"test_service",
"text_embedding_test_service",
"voyageai",
"watsonxai"
).toArray(),
providers
@ -62,7 +63,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
@SuppressWarnings("unchecked")
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
assertThat(services.size(), equalTo(14));
assertThat(services.size(), equalTo(15));
String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
@ -85,6 +86,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
"mistral",
"openai",
"text_embedding_test_service",
"voyageai",
"watsonxai"
).toArray(),
providers
@ -94,7 +96,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
@SuppressWarnings("unchecked")
public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(6));
assertThat(services.size(), equalTo(7));
String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
@ -103,7 +105,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
}
assertArrayEquals(
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service").toArray(),
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai")
.toArray(),
providers
);
}

View File

@ -90,6 +90,11 @@ import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCo
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;
import java.util.ArrayList;
import java.util.List;
@ -142,6 +147,7 @@ public class InferenceNamedWriteablesProvider {
addEisNamedWriteables(namedWriteables);
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);
addVoyageAINamedWriteables(namedWriteables);
addUnifiedNamedWriteables(namedWriteables);
@ -626,6 +632,28 @@ public class InferenceNamedWriteablesProvider {
);
}
private static void addVoyageAINamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIServiceSettings.NAME, VoyageAIServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
VoyageAIEmbeddingsServiceSettings.NAME,
VoyageAIEmbeddingsServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIEmbeddingsTaskSettings.NAME, VoyageAIEmbeddingsTaskSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIRerankServiceSettings.NAME, VoyageAIRerankServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIRerankTaskSettings.NAME, VoyageAIRerankTaskSettings::new)
);
}
private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(

View File

@ -128,6 +128,7 @@ import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
import java.util.ArrayList;
@ -359,6 +360,7 @@ public class InferencePlugin extends Plugin
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}

View File

@ -0,0 +1,52 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.action.voyageai;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIRerankRequestManager;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
/**
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the voyageai model type.
*/
public class VoyageAIActionCreator implements VoyageAIActionVisitor {
private final Sender sender;
private final ServiceComponents serviceComponents;
public VoyageAIActionCreator(Sender sender, ServiceComponents serviceComponents) {
this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents);
}
@Override
public ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = VoyageAIEmbeddingsModel.of(model, taskSettings, inputType);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI embeddings");
var requestCreator = VoyageAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}
@Override
public ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = VoyageAIRerankModel.of(model, taskSettings);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI rerank");
var requestCreator = VoyageAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}
}

View File

@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.action.voyageai;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
import java.util.Map;
public interface VoyageAIActionVisitor {
ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings);
}

View File

@ -0,0 +1,57 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest;
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
public class VoyageAIEmbeddingsRequestManager extends VoyageAIRequestManager {
private static final Logger logger = LogManager.getLogger(VoyageAIEmbeddingsRequestManager.class);
private static final ResponseHandler HANDLER = createEmbeddingsHandler();
private static ResponseHandler createEmbeddingsHandler() {
return new VoyageAIResponseHandler("voyageai text embedding", VoyageAIEmbeddingsResponseEntity::fromResponse);
}
public static VoyageAIEmbeddingsRequestManager of(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
return new VoyageAIEmbeddingsRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
}
private final VoyageAIEmbeddingsModel model;
private VoyageAIEmbeddingsRequestManager(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
super(threadPool, model);
this.model = Objects.requireNonNull(model);
}
@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(docsInput, model);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}

View File

@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http.sender;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;
import java.util.Map;
import java.util.Objects;
abstract class VoyageAIRequestManager extends BaseRequestManager {
private static final String DEFAULT_MODEL_FAMILY = "default_model_family";
private static final Map<String, String> MODEL_TO_MODEL_FAMILY = Map.of(
"voyage-multimodal-3",
"embed_multimodal",
"voyage-3-large",
"embed_large",
"voyage-code-3",
"embed_large",
"voyage-3",
"embed_medium",
"voyage-3-lite",
"embed_small",
"voyage-finance-2",
"embed_large",
"voyage-law-2",
"embed_large",
"voyage-code-2",
"embed_large",
"rerank-2",
"rerank_large",
"rerank-2-lite",
"rerank_small"
);
protected VoyageAIRequestManager(ThreadPool threadPool, VoyageAIModel model) {
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
}
record RateLimitGrouping(int apiKeyHash) {
public static RateLimitGrouping of(VoyageAIModel model) {
Objects.requireNonNull(model);
String modelId = model.getServiceSettings().modelId();
String modelFamily = MODEL_TO_MODEL_FAMILY.getOrDefault(modelId, DEFAULT_MODEL_FAMILY);
return new RateLimitGrouping(modelFamily.hashCode());
}
}
}

View File

@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRerankRequest;
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIRerankResponseEntity;
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
import java.util.Objects;
import java.util.function.Supplier;
public class VoyageAIRerankRequestManager extends VoyageAIRequestManager {
private static final Logger logger = LogManager.getLogger(VoyageAIRerankRequestManager.class);
private static final ResponseHandler HANDLER = createVoyageAIResponseHandler();
private static ResponseHandler createVoyageAIResponseHandler() {
return new VoyageAIResponseHandler("voyageai rerank", (request, response) -> VoyageAIRerankResponseEntity.fromResponse(response));
}
public static VoyageAIRerankRequestManager of(VoyageAIRerankModel model, ThreadPool threadPool) {
return new VoyageAIRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
}
private final VoyageAIRerankModel model;
private VoyageAIRerankRequestManager(VoyageAIRerankModel model, ThreadPool threadPool) {
super(threadPool, model);
this.model = model;
}
@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
VoyageAIRerankRequest request = new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}

View File

@ -0,0 +1,87 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.voyageai;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;
public class VoyageAIEmbeddingsRequest extends VoyageAIRequest {
private final VoyageAIAccount account;
private final List<String> input;
private final VoyageAIEmbeddingsServiceSettings serviceSettings;
private final VoyageAIEmbeddingsTaskSettings taskSettings;
private final String model;
private final String inferenceEntityId;
public VoyageAIEmbeddingsRequest(List<String> input, VoyageAIEmbeddingsModel embeddingsModel) {
Objects.requireNonNull(embeddingsModel);
account = VoyageAIAccount.of(embeddingsModel);
this.input = Objects.requireNonNull(input);
serviceSettings = embeddingsModel.getServiceSettings();
taskSettings = embeddingsModel.getTaskSettings();
model = embeddingsModel.getServiceSettings().getCommonSettings().modelId();
inferenceEntityId = embeddingsModel.getInferenceEntityId();
}
@Override
public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(account.uri());
ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new VoyageAIEmbeddingsRequestEntity(input, serviceSettings, taskSettings, model))
.getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);
decorateWithHeaders(httpPost, account);
return new HttpRequest(httpPost, getInferenceEntityId());
}
@Override
public String getInferenceEntityId() {
return inferenceEntityId;
}
@Override
public URI getURI() {
return account.uri();
}
@Override
public Request truncate() {
return this;
}
@Override
public boolean[] getTruncationInfo() {
return null;
}
public VoyageAIEmbeddingsTaskSettings getTaskSettings() {
return taskSettings;
}
public VoyageAIEmbeddingsServiceSettings getServiceSettings() {
return serviceSettings;
}
}

View File

@ -0,0 +1,83 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.voyageai;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings.invalidInputTypeMessage;
public record VoyageAIEmbeddingsRequestEntity(
List<String> input,
VoyageAIEmbeddingsServiceSettings serviceSettings,
VoyageAIEmbeddingsTaskSettings taskSettings,
String model
) implements ToXContentObject {
private static final String DOCUMENT = "document";
private static final String QUERY = "query";
private static final String INPUT_FIELD = "input";
private static final String MODEL_FIELD = "model";
public static final String INPUT_TYPE_FIELD = "input_type";
public static final String TRUNCATION_FIELD = "truncation";
public static final String OUTPUT_DIMENSION = "output_dimension";
static final String OUTPUT_DTYPE_FIELD = "output_dtype";
public VoyageAIEmbeddingsRequestEntity {
Objects.requireNonNull(input);
Objects.requireNonNull(model);
Objects.requireNonNull(taskSettings);
Objects.requireNonNull(serviceSettings);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(INPUT_FIELD, input);
builder.field(MODEL_FIELD, model);
var inputType = convertToString(taskSettings.getInputType());
if (inputType != null) {
builder.field(INPUT_TYPE_FIELD, inputType);
}
if (taskSettings.getTruncation() != null) {
builder.field(TRUNCATION_FIELD, taskSettings.getTruncation());
}
if (serviceSettings.dimensions() != null) {
builder.field(OUTPUT_DIMENSION, serviceSettings.dimensions());
}
if (serviceSettings.getEmbeddingType() != null) {
builder.field(OUTPUT_DTYPE_FIELD, serviceSettings.getEmbeddingType().toRequestString());
}
builder.endObject();
return builder;
}
static String convertToString(InputType inputType) {
return switch (inputType) {
case null -> null;
case INGEST -> DOCUMENT;
case SEARCH -> QUERY;
default -> {
assert false : invalidInputTypeMessage(inputType);
yield null;
}
};
}
}

View File

@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.voyageai;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount;
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
public abstract class VoyageAIRequest implements Request {
public static void decorateWithHeaders(HttpPost request, VoyageAIAccount account) {
request.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
request.setHeader(createAuthBearerHeader(account.apiKey()));
request.setHeader(VoyageAIUtils.createRequestSourceHeader());
}
}

View File

@ -0,0 +1,77 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.voyageai;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;
public class VoyageAIRerankRequest extends VoyageAIRequest {
private final VoyageAIAccount account;
private final String query;
private final List<String> input;
private final VoyageAIRerankTaskSettings taskSettings;
private final String model;
private final String inferenceEntityId;
public VoyageAIRerankRequest(String query, List<String> input, VoyageAIRerankModel model) {
Objects.requireNonNull(model);
this.account = VoyageAIAccount.of(model);
this.input = Objects.requireNonNull(input);
this.query = Objects.requireNonNull(query);
taskSettings = model.getTaskSettings();
this.model = model.getServiceSettings().modelId();
inferenceEntityId = model.getInferenceEntityId();
}
@Override
public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(account.uri());
ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new VoyageAIRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);
decorateWithHeaders(httpPost, account);
return new HttpRequest(httpPost, getInferenceEntityId());
}
@Override
public String getInferenceEntityId() {
return inferenceEntityId;
}
@Override
public URI getURI() {
return account.uri();
}
@Override
public Request truncate() {
return this;
}
@Override
public boolean[] getTruncationInfo() {
return null;
}
}

View File

@ -0,0 +1,63 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.voyageai;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
public record VoyageAIRerankRequestEntity(String model, String query, List<String> documents, VoyageAIRerankTaskSettings taskSettings)
implements
ToXContentObject {
private static final String DOCUMENTS_FIELD = "documents";
private static final String QUERY_FIELD = "query";
private static final String MODEL_FIELD = "model";
public static final String TRUNCATION_FIELD = "truncation";
public static final String RETURN_DOCUMENTS_FIELD = "return_documents";
public VoyageAIRerankRequestEntity {
Objects.requireNonNull(query);
Objects.requireNonNull(documents);
Objects.requireNonNull(model);
Objects.requireNonNull(taskSettings);
}
public VoyageAIRerankRequestEntity(String query, List<String> input, VoyageAIRerankTaskSettings taskSettings, String model) {
this(model, query, input, taskSettings != null ? taskSettings : VoyageAIRerankTaskSettings.EMPTY_SETTINGS);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(MODEL_FIELD, model);
builder.field(QUERY_FIELD, query);
builder.field(DOCUMENTS_FIELD, documents);
if (taskSettings.getDoesReturnDocuments() != null) {
builder.field(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments());
}
if (taskSettings.getTopKDocumentsOnly() != null) {
builder.field(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, taskSettings.getTopKDocumentsOnly());
}
if (taskSettings.getTruncation() != null) {
builder.field(TRUNCATION_FIELD, taskSettings.getTruncation());
}
builder.endObject();
return builder;
}
}

View File

@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.voyageai;
import org.apache.http.Header;
import org.apache.http.message.BasicHeader;
public class VoyageAIUtils {
public static final String HOST = "api.voyageai.com";
public static final String VERSION_1 = "v1";
public static final String EMBEDDINGS_PATH = "embeddings";
public static final String RERANK_PATH = "rerank";
public static final String REQUEST_SOURCE_HEADER = "Request-Source";
public static final String ELASTIC_REQUEST_SOURCE = "unspecified:elasticsearch";
public static Header createRequestSourceHeader() {
return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE);
}
private VoyageAIUtils() {}
}

View File

@ -0,0 +1,197 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*
* this file was contributed to by a generative AI
*/
package org.elasticsearch.xpack.inference.external.response.voyageai;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType.toLowerCase;
public class VoyageAIEmbeddingsResponseEntity {
private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes();
private static String supportedEmbeddingTypes() {
String[] validTypes = new String[] {
toLowerCase(VoyageAIEmbeddingType.FLOAT),
toLowerCase(VoyageAIEmbeddingType.INT8),
toLowerCase(VoyageAIEmbeddingType.BIT) };
Arrays.sort(validTypes);
return String.join(", ", validTypes);
}
record EmbeddingInt8Result(List<EmbeddingInt8ResultEntry> entries) {
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<EmbeddingInt8Result, Void> PARSER = new ConstructingObjectParser<>(
EmbeddingInt8Result.class.getSimpleName(),
true,
args -> new EmbeddingInt8Result((List<EmbeddingInt8ResultEntry>) args[0])
);
static {
PARSER.declareObjectArray(constructorArg(), EmbeddingInt8ResultEntry.PARSER::apply, new ParseField("data"));
}
}
record EmbeddingInt8ResultEntry(Integer index, List<Integer> embedding) {
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<EmbeddingInt8ResultEntry, Void> PARSER = new ConstructingObjectParser<>(
EmbeddingInt8ResultEntry.class.getSimpleName(),
true,
args -> new EmbeddingInt8ResultEntry((Integer) args[0], (List<Integer>) args[1])
);
static {
PARSER.declareInt(constructorArg(), new ParseField("index"));
PARSER.declareIntArray(constructorArg(), new ParseField("embedding"));
}
private static void checkByteBounds(Integer value) {
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte");
}
}
public InferenceByteEmbedding toInferenceByteEmbedding() {
embedding.forEach(EmbeddingInt8ResultEntry::checkByteBounds);
return InferenceByteEmbedding.of(embedding.stream().map(Integer::byteValue).toList());
}
}
record EmbeddingFloatResult(List<EmbeddingFloatResultEntry> entries) {
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<EmbeddingFloatResult, Void> PARSER = new ConstructingObjectParser<>(
EmbeddingFloatResult.class.getSimpleName(),
true,
args -> new EmbeddingFloatResult((List<EmbeddingFloatResultEntry>) args[0])
);
static {
PARSER.declareObjectArray(constructorArg(), EmbeddingFloatResultEntry.PARSER::apply, new ParseField("data"));
}
}
record EmbeddingFloatResultEntry(Integer index, List<Float> embedding) {
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<EmbeddingFloatResultEntry, Void> PARSER = new ConstructingObjectParser<>(
EmbeddingFloatResultEntry.class.getSimpleName(),
true,
args -> new EmbeddingFloatResultEntry((Integer) args[0], (List<Float>) args[1])
);
static {
PARSER.declareInt(constructorArg(), new ParseField("index"));
PARSER.declareFloatArray(constructorArg(), new ParseField("embedding"));
}
public InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding toInferenceFloatEmbedding() {
return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embedding);
}
}
/**
* Parses the VoyageAI json response.
* For a request like:
*
* <pre>
* <code>
* {
* "input": [
* "Sample text 1",
* "Sample text 2"
* ],
* "model": "voyage-3-large"
* }
* </code>
* </pre>
*
* The response would look like:
*
* <pre>
* <code>
* {
* "object": "list",
* "data": [
* {
* "object": "embedding",
* "embedding": [
* -0.009327292,
* -0.0028842222,
* ],
* "index": 0
* },
* {
* "object": "embedding",
* "embedding": [ ... ],
* "index": 1
* }
* ],
* "model": "voyage-3-large",
* "usage": {
* "total_tokens": 10
* }
* }
* </code>
* </pre>
*/
public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
VoyageAIEmbeddingType embeddingType = ((VoyageAIEmbeddingsRequest) request).getServiceSettings().getEmbeddingType();
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
if (embeddingType == null || embeddingType == VoyageAIEmbeddingType.FLOAT) {
var embeddingResult = EmbeddingFloatResult.PARSER.apply(jsonParser, null);
List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding> embeddingList = embeddingResult.entries.stream()
.map(EmbeddingFloatResultEntry::toInferenceFloatEmbedding)
.toList();
return new InferenceTextEmbeddingFloatResults(embeddingList);
} else if (embeddingType == VoyageAIEmbeddingType.INT8) {
var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null);
List<InferenceByteEmbedding> embeddingList = embeddingResult.entries.stream()
.map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding)
.toList();
return new InferenceTextEmbeddingByteResults(embeddingList);
} else if (embeddingType == VoyageAIEmbeddingType.BIT || embeddingType == VoyageAIEmbeddingType.BINARY) {
var embeddingResult = EmbeddingInt8Result.PARSER.apply(jsonParser, null);
List<InferenceByteEmbedding> embeddingList = embeddingResult.entries.stream()
.map(EmbeddingInt8ResultEntry::toInferenceByteEmbedding)
.toList();
return new InferenceTextEmbeddingBitResults(embeddingList);
} else {
throw new IllegalArgumentException(
"Illegal embedding_type value: " + embeddingType + ". Supported types are: " + VALID_EMBEDDING_TYPES_STRING
);
}
}
}
private VoyageAIEmbeddingsResponseEntity() {}
}

View File

@ -0,0 +1,46 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.response.voyageai;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
public class VoyageAIErrorResponseEntity extends ErrorResponse {
private VoyageAIErrorResponseEntity(String errorMessage) {
super(errorMessage);
}
/**
* Parse an HTTP response into a VoyageAIErrorResponseEntity
*
* @param response The error response
* @return An error entity if the response is JSON with a `detail` field containing the error message
* or null if the response does not contain the message field
*/
public static ErrorResponse fromResponse(HttpResult response) {
try (
XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response.body())
) {
var responseMap = jsonParser.map();
var message = (String) responseMap.get("detail");
if (message != null) {
return new VoyageAIErrorResponseEntity(message);
}
} catch (Exception e) {
// swallow the error
}
return ErrorResponse.UNDEFINED_ERROR;
}
}

View File

@ -0,0 +1,112 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*
* this file was contributed to by a generative AI
*/
package org.elasticsearch.xpack.inference.external.response.voyageai;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class VoyageAIRerankResponseEntity {
private static final Logger logger = LogManager.getLogger(VoyageAIRerankResponseEntity.class);
record RerankResult(List<RerankResultEntry> entries) {
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<RerankResult, Void> PARSER = new ConstructingObjectParser<>(
RerankResult.class.getSimpleName(),
true,
args -> new RerankResult((List<RerankResultEntry>) args[0])
);
static {
PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("data"));
}
}
record RerankResultEntry(Float relevanceScore, Integer index, @Nullable String document) {
public static final ConstructingObjectParser<RerankResultEntry, Void> PARSER = new ConstructingObjectParser<>(
RerankResultEntry.class.getSimpleName(),
args -> new RerankResultEntry((Float) args[0], (Integer) args[1], (String) args[2])
);
static {
PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
PARSER.declareInt(constructorArg(), new ParseField("index"));
PARSER.declareString(optionalConstructorArg(), new ParseField("document"));
}
public RankedDocsResults.RankedDoc toRankedDoc() {
return new RankedDocsResults.RankedDoc(index, relevanceScore, document);
}
}
/**
* Parses the VoyageAI ranked response.
* For a request like:
* "model": "rerank-2",
* "query": "What is the capital of the United States?",
* "top_k": 2,
* "documents": ["Carson City is the capital city of the American state of Nevada.",
* "The Commonwealth of the Northern Mariana ... Its capital is Saipan.",
* "Washington, D.C. (also known as simply Washington or D.C., ... It is a federal district.",
* "Capital punishment (the death penalty) ... As of 2017, capital punishment is legal in 30 of the 50 states."]
* <p>
* The response will look like (without whitespace):
* {
* "object": "list",
* "data": [
* {
* "relevance_score": 0.4375,
* "index": 0
* },
* {
* "relevance_score": 0.421875,
* "index": 1
* }
* ],
* "model": "rerank-2",
* "usage": {
* "total_tokens": 26
* }
* }
* @param response the http response from VoyageAI
* @return the parsed response
* @throws IOException if there is an error parsing the response
*/
public static InferenceServiceResults fromResponse(HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
var rerankResult = RerankResult.PARSER.apply(jsonParser, null);
return new RankedDocsResults(rerankResult.entries.stream().map(RerankResultEntry::toRankedDoc).toList());
}
}
private VoyageAIRerankResponseEntity() {}
}

View File

@ -0,0 +1,35 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.voyageai;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Objects;
public record VoyageAIAccount(URI uri, SecureString apiKey) {
public static VoyageAIAccount of(VoyageAIModel model) {
try {
var uri = model.buildUri();
return new VoyageAIAccount(uri, model.apiKey());
} catch (URISyntaxException e) {
// using bad request here so that potentially sensitive URL information does not get logged
throw new ElasticsearchStatusException("Failed to construct VoyageAI URL", RestStatus.BAD_REQUEST, e);
}
}
public VoyageAIAccount {
Objects.requireNonNull(uri);
Objects.requireNonNull(apiKey);
}
}

View File

@ -0,0 +1,62 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.voyageai;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIErrorResponseEntity;
/**
* Defines how to handle various errors returned from the VoyageAI integration.
*
*/
public class VoyageAIResponseHandler extends BaseResponseHandler {
static final String VALIDATION_ERROR_MESSAGE = "Received an input validation error response";
static final String PAYMENT_ERROR_MESSAGE = "Payment required";
public VoyageAIResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, VoyageAIErrorResponseEntity::fromResponse);
}
/**
* Validates the status code throws an RetryException if not in the range [200, 300).
*
* @param request The http request
* @param result The http response and body
* @throws RetryException Throws if status code is {@code >= 300 or < 200 }
*/
@Override
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
if (result.isSuccessfulResponse()) {
return;
}
// handle error codes
int statusCode = result.response().getStatusLine().getStatusCode();
if (statusCode == 500) {
throw new RetryException(true, buildError(SERVER_ERROR, request, result));
} else if (statusCode > 500) {
throw new RetryException(false, buildError(SERVER_ERROR, request, result));
} else if (statusCode == 429) {
throw new RetryException(true, buildError(RATE_LIMIT, request, result));
} else if (statusCode == 400 || statusCode == 422) {
throw new RetryException(false, buildError(VALIDATION_ERROR_MESSAGE, request, result));
} else if (statusCode == 401) {
throw new RetryException(false, buildError(AUTHENTICATION, request, result));
} else if (statusCode == 402) {
throw new RetryException(false, buildError(PAYMENT_ERROR_MESSAGE, request, result));
} else if (statusCode >= 300 && statusCode < 400) {
throw new RetryException(false, buildError(REDIRECTION, request, result));
} else {
throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
}
}
}

View File

@ -0,0 +1,94 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionVisitor;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.Objects;
public abstract class VoyageAIModel extends Model {
private final SecureString apiKey;
private final VoyageAIRateLimitServiceSettings rateLimitServiceSettings;
protected final URI uri;
public VoyageAIModel(
ModelConfigurations configurations,
ModelSecrets secrets,
@Nullable ApiKeySecrets apiKeySecrets,
VoyageAIRateLimitServiceSettings rateLimitServiceSettings
) {
this(configurations, secrets, apiKeySecrets, rateLimitServiceSettings, null);
}
public VoyageAIModel(
ModelConfigurations configurations,
ModelSecrets secrets,
@Nullable ApiKeySecrets apiKeySecrets,
VoyageAIRateLimitServiceSettings rateLimitServiceSettings,
String url
) {
super(configurations, secrets);
this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings);
this.apiKey = ServiceUtils.apiKey(apiKeySecrets);
this.uri = url == null ? null : URI.create(url);
}
protected VoyageAIModel(VoyageAIModel model, TaskSettings taskSettings) {
super(model, taskSettings);
this.rateLimitServiceSettings = model.rateLimitServiceSettings();
this.apiKey = model.apiKey();
this.uri = model.uri;
}
protected VoyageAIModel(VoyageAIModel model, ServiceSettings serviceSettings) {
super(model, serviceSettings);
this.rateLimitServiceSettings = model.rateLimitServiceSettings();
this.apiKey = model.apiKey();
this.uri = model.uri;
}
public SecureString apiKey() {
return apiKey;
}
public VoyageAIRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}
public abstract ExecutableAction accept(VoyageAIActionVisitor creator, Map<String, Object> taskSettings, InputType inputType);
public URI uri() {
return uri;
}
public URI buildUri() throws URISyntaxException {
if (uri == null) {
return buildRequestUri();
}
return uri;
}
protected abstract URI buildRequestUri() throws URISyntaxException;
}

View File

@ -0,0 +1,15 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
public interface VoyageAIRateLimitServiceSettings {
RateLimitSettings rateLimitSettings();
}

View File

@ -0,0 +1,397 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
public class VoyageAIService extends SenderService {
public static final String NAME = "voyageai";
private static final String SERVICE_NAME = "Voyage AI";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK);
private static final Integer DEFAULT_BATCH_SIZE = 7;
private static final Map<String, Integer> MODEL_BATCH_SIZES = Map.of(
"voyage-multimodal-3",
7,
"voyage-3-large",
7,
"voyage-code-3",
7,
"voyage-3",
10,
"voyage-3-lite",
30,
"voyage-finance-2",
7,
"voyage-law-2",
7,
"voyage-code-2",
7,
"voyage-2",
72,
"voyage-02",
72
);
public VoyageAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
}
@Override
public String name() {
return NAME;
}
@Override
public void parseRequestConfig(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> config,
ActionListener<Model> parsedModelListener
) {
try {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
);
}
VoyageAIModel model = createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
throwIfNotEmptyMap(config, NAME);
throwIfNotEmptyMap(serviceSettingsMap, NAME);
throwIfNotEmptyMap(taskSettingsMap, NAME);
parsedModelListener.onResponse(model);
} catch (Exception e) {
parsedModelListener.onFailure(e);
}
}
private static VoyageAIModel createModelFromPersistent(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage
) {
return createModel(
inferenceEntityId,
taskType,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
private static VoyageAIModel createModel(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
case TEXT_EMBEDDING -> new VoyageAIEmbeddingsModel(
inferenceEntityId,
NAME,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
context
);
case RERANK -> new VoyageAIRerankModel(inferenceEntityId, NAME, serviceSettings, taskSettings, secretSettings, context);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
@Override
public VoyageAIModel parsePersistedConfigWithSecrets(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> config,
Map<String, Object> secrets
) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
return createModelFromPersistent(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
}
@Override
public VoyageAIModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
return createModelFromPersistent(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
}
@Override
public InferenceServiceConfiguration getConfiguration() {
return Configuration.get();
}
@Override
public EnumSet<TaskType> supportedTaskTypes() {
return supportedTaskTypes;
}
@Override
protected void doUnifiedCompletionInfer(
Model model,
UnifiedChatInput inputs,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
throwUnsupportedUnifiedCompletionOperation(NAME);
}
@Override
public void doInfer(
Model model,
InferenceInputs inputs,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
if (model instanceof VoyageAIModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}
VoyageAIModel voyageaiModel = (VoyageAIModel) model;
var actionCreator = new VoyageAIActionCreator(getSender(), getServiceComponents());
var action = voyageaiModel.accept(actionCreator, taskSettings, inputType);
action.execute(inputs, timeout, listener);
}
@Override
protected void doChunkedInfer(
Model model,
DocumentsOnlyInput inputs,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
if (model instanceof VoyageAIModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}
VoyageAIModel voyageaiModel = (VoyageAIModel) model;
var actionCreator = new VoyageAIActionCreator(getSender(), getServiceComponents());
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
getBatchSize(voyageaiModel),
EmbeddingRequestChunker.EmbeddingType.fromDenseVectorElementType(model.getServiceSettings().elementType()),
voyageaiModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);
for (var request : batchedRequests) {
var action = voyageaiModel.accept(actionCreator, taskSettings, inputType);
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
}
}
private static int getBatchSize(VoyageAIModel model) {
return MODEL_BATCH_SIZES.getOrDefault(model.getServiceSettings().modelId(), DEFAULT_BATCH_SIZE);
}
/**
* For text embedding models get the embedding size and
* update the service settings.
*
* @param model The new model
* @param listener The listener
*/
@Override
public void checkModelConfig(Model model, ActionListener<Model> listener) {
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
}
@Override
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
if (model instanceof VoyageAIEmbeddingsModel embeddingsModel) {
var serviceSettings = embeddingsModel.getServiceSettings();
var similarityFromModel = serviceSettings.similarity();
var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel;
var maxInputTokens = serviceSettings.maxInputTokens();
var dimensionSetByUser = serviceSettings.dimensionsSetByUser();
var updatedServiceSettings = new VoyageAIEmbeddingsServiceSettings(
new VoyageAIServiceSettings(
serviceSettings.getCommonSettings().modelId(),
serviceSettings.getCommonSettings().rateLimitSettings()
),
serviceSettings.getEmbeddingType(),
similarityToUse,
embeddingSize,
maxInputTokens,
dimensionSetByUser
);
return new VoyageAIEmbeddingsModel(embeddingsModel, updatedServiceSettings);
} else {
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
}
}
/**
* Return the default similarity measure for the embedding type.
* VoyageAI embeddings are normalized to unit vectors therefore Dot
* Product similarity can be used and is the default for all VoyageAI
* models.
*
* @return The default similarity.
*/
static SimilarityMeasure defaultSimilarity() {
return SimilarityMeasure.DOT_PRODUCT;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED;
}
public static class Configuration {
public static InferenceServiceConfiguration get() {
return configuration.getOrCompute();
}
private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
() -> {
var configurationMap = new HashMap<String, SettingsConfiguration>();
configurationMap.put(
MODEL_ID,
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
"The name of the model to use for the inference task."
)
.setLabel("Model ID")
.setRequired(true)
.setSensitive(false)
.setUpdatable(false)
.setType(SettingsConfigurationFieldType.STRING)
.build()
);
configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes));
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes));
return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(SERVICE_NAME)
.setTaskTypes(supportedTaskTypes)
.setConfigurations(configurationMap)
.build();
}
);
}
}

View File

@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai;
public class VoyageAIServiceFields {
public static final String TRUNCATION = "truncation";
}

View File

@ -0,0 +1,132 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
public class VoyageAIServiceSettings extends FilteredXContentObject implements ServiceSettings, VoyageAIRateLimitServiceSettings {
public static final String NAME = "voyageai_service_settings";
public static final String MODEL_ID = "model_id";
private static final Logger logger = LogManager.getLogger(VoyageAIServiceSettings.class);
// See https://docs.voyageai.com/docs/rate-limits
public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(2_000);
public static VoyageAIServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
VoyageAIService.NAME,
context
);
String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
return new VoyageAIServiceSettings(modelId, rateLimitSettings);
}
private final String modelId;
private final RateLimitSettings rateLimitSettings;
public VoyageAIServiceSettings(String modelId, @Nullable RateLimitSettings rateLimitSettings) {
this.modelId = Objects.requireNonNull(modelId);
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
}
public VoyageAIServiceSettings(StreamInput in) throws IOException {
modelId = in.readString();
rateLimitSettings = new RateLimitSettings(in);
}
@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
}
@Override
public String modelId() {
return modelId;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
toXContentFragment(builder, params);
builder.endObject();
return builder;
}
public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException {
return toXContentFragmentOfExposedFields(builder, params);
}
@Override
public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
builder.field(MODEL_ID, modelId);
rateLimitSettings.toXContent(builder, params);
return builder;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
rateLimitSettings.writeTo(out);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
VoyageAIServiceSettings that = (VoyageAIServiceSettings) o;
return Objects.equals(modelId, that.modelId) && Objects.equals(rateLimitSettings, that.rateLimitSettings);
}
@Override
public int hashCode() {
return Objects.hash(modelId, rateLimitSettings);
}
}

View File

@ -0,0 +1,114 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai.embeddings;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.Locale;
import java.util.Map;
/**
* Defines the type of embedding that the VoyageAI api should return for a request.
*
* <p>
* <a href="https://docs.voyageai.com/reference/embeddings-api">See api docs for details.</a>
* </p>
*/
public enum VoyageAIEmbeddingType {
/**
* Use this when you want to get back the default float embeddings. Valid for all models.
*/
FLOAT(DenseVectorFieldMapper.ElementType.FLOAT, RequestConstants.FLOAT),
/**
* Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
*/
INT8(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8),
/**
* This is a synonym for INT8
*/
BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8),
/**
* Use this when you want to get back binary embeddings. Valid only for v3 models.
*/
BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BINARY),
/**
* This is a synonym for BIT
*/
BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BINARY);
private static final class RequestConstants {
private static final String FLOAT = "float";
private static final String INT8 = "int8";
private static final String BINARY = "binary";
}
private static final Map<DenseVectorFieldMapper.ElementType, VoyageAIEmbeddingType> ELEMENT_TYPE_TO_VOYAGE_EMBEDDING = Map.of(
DenseVectorFieldMapper.ElementType.FLOAT,
FLOAT,
DenseVectorFieldMapper.ElementType.BYTE,
BYTE,
DenseVectorFieldMapper.ElementType.BIT,
BIT
);
static final EnumSet<DenseVectorFieldMapper.ElementType> SUPPORTED_ELEMENT_TYPES = EnumSet.copyOf(
ELEMENT_TYPE_TO_VOYAGE_EMBEDDING.keySet()
);
private final DenseVectorFieldMapper.ElementType elementType;
private final String requestString;
VoyageAIEmbeddingType(DenseVectorFieldMapper.ElementType elementType, String requestString) {
this.elementType = elementType;
this.requestString = requestString;
}
@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}
public String toRequestString() {
return requestString;
}
public static String toLowerCase(VoyageAIEmbeddingType type) {
return type.toString().toLowerCase(Locale.ROOT);
}
public static VoyageAIEmbeddingType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}
public static VoyageAIEmbeddingType fromElementType(DenseVectorFieldMapper.ElementType elementType) {
var embedding = ELEMENT_TYPE_TO_VOYAGE_EMBEDDING.get(elementType);
if (embedding == null) {
var validElementTypes = SUPPORTED_ELEMENT_TYPES.stream()
.map(value -> value.toString().toLowerCase(Locale.ROOT))
.toArray(String[]::new);
Arrays.sort(validElementTypes);
throw new IllegalArgumentException(
Strings.format(
"Element type [%s] does not map to a VoyageAI embedding value, must be one of [%s]",
elementType,
String.join(", ", validElementTypes)
)
);
}
return embedding;
}
public DenseVectorFieldMapper.ElementType toElementType() {
return elementType;
}
}

View File

@ -0,0 +1,127 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai.embeddings;
import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionVisitor;
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import static org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils.HOST;
public class VoyageAIEmbeddingsModel extends VoyageAIModel {
public static VoyageAIEmbeddingsModel of(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var requestTaskSettings = VoyageAIEmbeddingsTaskSettings.fromMap(taskSettings);
return new VoyageAIEmbeddingsModel(
model,
VoyageAIEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings, inputType)
);
}
public VoyageAIEmbeddingsModel(
String inferenceId,
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
this(
inferenceId,
service,
VoyageAIEmbeddingsServiceSettings.fromMap(serviceSettings, context),
VoyageAIEmbeddingsTaskSettings.fromMap(taskSettings),
chunkingSettings,
DefaultSecretSettings.fromMap(secrets)
);
}
// should only be used for testing
VoyageAIEmbeddingsModel(
String modelId,
String service,
VoyageAIEmbeddingsServiceSettings serviceSettings,
VoyageAIEmbeddingsTaskSettings taskSettings,
ChunkingSettings chunkingSettings,
@Nullable DefaultSecretSettings secretSettings
) {
super(
new ModelConfigurations(modelId, TaskType.TEXT_EMBEDDING, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secretSettings),
secretSettings,
serviceSettings.getCommonSettings()
);
}
VoyageAIEmbeddingsModel(
String modelId,
String service,
String url,
VoyageAIEmbeddingsServiceSettings serviceSettings,
VoyageAIEmbeddingsTaskSettings taskSettings,
ChunkingSettings chunkingSettings,
@Nullable DefaultSecretSettings secretSettings
) {
super(
new ModelConfigurations(modelId, TaskType.TEXT_EMBEDDING, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secretSettings),
secretSettings,
serviceSettings.getCommonSettings(),
url
);
}
private VoyageAIEmbeddingsModel(VoyageAIEmbeddingsModel model, VoyageAIEmbeddingsTaskSettings taskSettings) {
super(model, taskSettings);
}
public VoyageAIEmbeddingsModel(VoyageAIEmbeddingsModel model, VoyageAIEmbeddingsServiceSettings serviceSettings) {
super(model, serviceSettings);
}
@Override
public VoyageAIEmbeddingsServiceSettings getServiceSettings() {
return (VoyageAIEmbeddingsServiceSettings) super.getServiceSettings();
}
@Override
public VoyageAIEmbeddingsTaskSettings getTaskSettings() {
return (VoyageAIEmbeddingsTaskSettings) super.getTaskSettings();
}
@Override
public DefaultSecretSettings getSecretSettings() {
return (DefaultSecretSettings) super.getSecretSettings();
}
@Override
public ExecutableAction accept(VoyageAIActionVisitor visitor, Map<String, Object> taskSettings, InputType inputType) {
return visitor.create(this, taskSettings, inputType);
}
protected URI buildRequestUri() throws URISyntaxException {
return new URIBuilder().setScheme("https")
.setHost(HOST)
.setPathSegments(VoyageAIUtils.VERSION_1, VoyageAIUtils.EMBEDDINGS_PATH)
.build();
}
}

View File

@ -0,0 +1,259 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai.embeddings;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
import java.io.IOException;
import java.util.EnumSet;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
public class VoyageAIEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings {
public static final String NAME = "voyageai_embeddings_service_settings";
static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
public static final VoyageAIEmbeddingsServiceSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsServiceSettings(
null,
null,
null,
null,
null,
false
);
public static final String EMBEDDING_TYPE = "embedding_type";
public static VoyageAIEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
return switch (context) {
case REQUEST -> fromRequestMap(map, context);
case PERSISTENT -> fromPersistentMap(map, context);
};
}
private static VoyageAIEmbeddingsServiceSettings fromRequestMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
var commonServiceSettings = VoyageAIServiceSettings.fromMap(map, context);
VoyageAIEmbeddingType embeddingTypes = parseEmbeddingType(map, context, validationException);
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
return new VoyageAIEmbeddingsServiceSettings(commonServiceSettings, embeddingTypes, similarity, dims, maxInputTokens, dims != null);
}
private static VoyageAIEmbeddingsServiceSettings fromPersistentMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
var commonServiceSettings = VoyageAIServiceSettings.fromMap(map, context);
VoyageAIEmbeddingType embeddingTypes = parseEmbeddingType(map, context, validationException);
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
Boolean dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class);
if (dimensionsSetByUser == null) {
dimensionsSetByUser = Boolean.FALSE;
}
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
return new VoyageAIEmbeddingsServiceSettings(
commonServiceSettings,
embeddingTypes,
similarity,
dims,
maxInputTokens,
dimensionsSetByUser
);
}
static VoyageAIEmbeddingType parseEmbeddingType(
Map<String, Object> map,
ConfigurationParseContext context,
ValidationException validationException
) {
return switch (context) {
case REQUEST, PERSISTENT -> Objects.requireNonNullElse(
extractOptionalEnum(
map,
EMBEDDING_TYPE,
ModelConfigurations.SERVICE_SETTINGS,
VoyageAIEmbeddingType::fromString,
EnumSet.allOf(VoyageAIEmbeddingType.class),
validationException
),
VoyageAIEmbeddingType.FLOAT
);
};
}
private final VoyageAIServiceSettings commonSettings;
private final VoyageAIEmbeddingType embeddingType;
private final SimilarityMeasure similarity;
private final Integer dimensions;
private final Integer maxInputTokens;
private final boolean dimensionsSetByUser;
public VoyageAIEmbeddingsServiceSettings(
VoyageAIServiceSettings commonSettings,
@Nullable VoyageAIEmbeddingType embeddingType,
@Nullable SimilarityMeasure similarity,
@Nullable Integer dimensions,
@Nullable Integer maxInputTokens,
boolean dimensionsSetByUser
) {
this.commonSettings = commonSettings;
this.similarity = similarity;
this.dimensions = dimensions;
this.maxInputTokens = maxInputTokens;
this.embeddingType = embeddingType;
this.dimensionsSetByUser = dimensionsSetByUser;
}
public VoyageAIEmbeddingsServiceSettings(StreamInput in) throws IOException {
this.commonSettings = new VoyageAIServiceSettings(in);
this.similarity = in.readOptionalEnum(SimilarityMeasure.class);
this.dimensions = in.readOptionalVInt();
this.maxInputTokens = in.readOptionalVInt();
this.embeddingType = Objects.requireNonNullElse(in.readOptionalEnum(VoyageAIEmbeddingType.class), VoyageAIEmbeddingType.FLOAT);
this.dimensionsSetByUser = in.readBoolean();
}
public VoyageAIServiceSettings getCommonSettings() {
return commonSettings;
}
@Override
public SimilarityMeasure similarity() {
return similarity;
}
@Override
public Integer dimensions() {
return dimensions;
}
public Integer maxInputTokens() {
return maxInputTokens;
}
@Override
public String modelId() {
return commonSettings.modelId();
}
public VoyageAIEmbeddingType getEmbeddingType() {
return embeddingType;
}
@Override
public DenseVectorFieldMapper.ElementType elementType() {
return embeddingType == null ? DenseVectorFieldMapper.ElementType.FLOAT : embeddingType.toElementType();
}
@Override
public Boolean dimensionsSetByUser() {
return this.dimensionsSetByUser;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder = commonSettings.toXContentFragment(builder, params);
if (similarity != null) {
builder.field(SIMILARITY, similarity);
}
if (dimensions != null) {
builder.field(DIMENSIONS, dimensions);
}
if (maxInputTokens != null) {
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
}
if (embeddingType != null) {
builder.field(EMBEDDING_TYPE, embeddingType);
}
builder.endObject();
return builder;
}
@Override
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
commonSettings.toXContentFragmentOfExposedFields(builder, params);
return builder;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
commonSettings.writeTo(out);
out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion()));
out.writeOptionalVInt(dimensions);
out.writeOptionalVInt(maxInputTokens);
out.writeOptionalEnum(embeddingType);
out.writeBoolean(dimensionsSetByUser);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
VoyageAIEmbeddingsServiceSettings that = (VoyageAIEmbeddingsServiceSettings) o;
return Objects.equals(commonSettings, that.commonSettings)
&& Objects.equals(similarity, that.similarity)
&& Objects.equals(dimensions, that.dimensions)
&& Objects.equals(maxInputTokens, that.maxInputTokens)
&& Objects.equals(embeddingType, that.embeddingType)
&& Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser);
}
@Override
public int hashCode() {
return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens, embeddingType, dimensionsSetByUser);
}
}

View File

@ -0,0 +1,202 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai.embeddings;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields.TRUNCATION;
/**
* Defines the task settings for the voyageai text embeddings service.
*
* <p>
* <a href="https://docs.voyageai.com/docs/embeddings">See api docs for details.</a>
* </p>
*/
public class VoyageAIEmbeddingsTaskSettings implements TaskSettings {
public static final String NAME = "voyageai_embeddings_task_settings";
public static final VoyageAIEmbeddingsTaskSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsTaskSettings(null, null);
static final String INPUT_TYPE = "input_type";
static final EnumSet<InputType> VALID_REQUEST_VALUES = EnumSet.of(InputType.INGEST, InputType.SEARCH);
public static VoyageAIEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
if (map == null || map.isEmpty()) {
return EMPTY_SETTINGS;
}
ValidationException validationException = new ValidationException();
InputType inputType = extractOptionalEnum(
map,
INPUT_TYPE,
ModelConfigurations.TASK_SETTINGS,
InputType::fromString,
VALID_REQUEST_VALUES,
validationException
);
Boolean truncation = extractOptionalBoolean(map, TRUNCATION, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
return new VoyageAIEmbeddingsTaskSettings(inputType, truncation);
}
/**
* Creates a new {@link VoyageAIEmbeddingsTaskSettings} by preferring non-null fields from the provided parameters.
* For the input type, preference is given to requestInputType if it is not null and not UNSPECIFIED.
* Then preference is given to the requestTaskSettings and finally to originalSettings even if the value is null.
* Similarly, for the truncation field preference is given to requestTaskSettings if it is not null and then to
* originalSettings.
* @param originalSettings the settings stored as part of the inference entity configuration
* @param requestTaskSettings the settings passed in within the task_settings field of the request
* @param requestInputType the input type passed in the request parameters
* @return a constructed {@link VoyageAIEmbeddingsTaskSettings}
*/
public static VoyageAIEmbeddingsTaskSettings of(
VoyageAIEmbeddingsTaskSettings originalSettings,
VoyageAIEmbeddingsTaskSettings requestTaskSettings,
InputType requestInputType
) {
var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings, requestInputType);
var truncationToUse = getValidTruncation(originalSettings, requestTaskSettings);
return new VoyageAIEmbeddingsTaskSettings(inputTypeToUse, truncationToUse);
}
private static InputType getValidInputType(
VoyageAIEmbeddingsTaskSettings originalSettings,
VoyageAIEmbeddingsTaskSettings requestTaskSettings,
InputType requestInputType
) {
InputType inputTypeToUse = originalSettings.inputType;
if (VALID_REQUEST_VALUES.contains(requestInputType)) {
inputTypeToUse = requestInputType;
} else if (requestTaskSettings.inputType != null) {
inputTypeToUse = requestTaskSettings.inputType;
}
return inputTypeToUse;
}
private static Boolean getValidTruncation(
VoyageAIEmbeddingsTaskSettings originalSettings,
VoyageAIEmbeddingsTaskSettings requestTaskSettings
) {
return requestTaskSettings.getTruncation() == null ? originalSettings.truncation : requestTaskSettings.getTruncation();
}
private final InputType inputType;
private final Boolean truncation;
public VoyageAIEmbeddingsTaskSettings(StreamInput in) throws IOException {
this(in.readOptionalEnum(InputType.class), in.readOptionalBoolean());
}
public VoyageAIEmbeddingsTaskSettings(@Nullable InputType inputType, @Nullable Boolean truncation) {
validateInputType(inputType);
this.inputType = inputType;
this.truncation = truncation;
}
private static void validateInputType(InputType inputType) {
if (inputType == null) {
return;
}
assert VALID_REQUEST_VALUES.contains(inputType) : invalidInputTypeMessage(inputType);
}
@Override
public boolean isEmpty() {
return inputType == null && truncation == null;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (inputType != null) {
builder.field(INPUT_TYPE, inputType);
}
if (truncation != null) {
builder.field(TRUNCATION, truncation);
}
builder.endObject();
return builder;
}
public InputType getInputType() {
return inputType;
}
public Boolean getTruncation() {
return truncation;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalEnum(inputType);
out.writeOptionalBoolean(truncation);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
VoyageAIEmbeddingsTaskSettings that = (VoyageAIEmbeddingsTaskSettings) o;
return Objects.equals(inputType, that.inputType) && Objects.equals(truncation, that.truncation);
}
@Override
public int hashCode() {
return Objects.hash(inputType, truncation);
}
public static String invalidInputTypeMessage(InputType inputType) {
return Strings.format("received invalid input type value [%s]", inputType.toString());
}
@Override
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
VoyageAIEmbeddingsTaskSettings updatedSettings = VoyageAIEmbeddingsTaskSettings.fromMap(new HashMap<>(newSettings));
return of(this, updatedSettings, updatedSettings.inputType != null ? updatedSettings.inputType : this.inputType);
}
}

View File

@ -0,0 +1,122 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai.rerank;
import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.voyageai.VoyageAIActionVisitor;
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import static org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils.HOST;
public class VoyageAIRerankModel extends VoyageAIModel {
public static VoyageAIRerankModel of(VoyageAIRerankModel model, Map<String, Object> taskSettings) {
var requestTaskSettings = VoyageAIRerankTaskSettings.fromMap(taskSettings);
return new VoyageAIRerankModel(model, VoyageAIRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings));
}
public VoyageAIRerankModel(
String inferenceId,
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
this(
inferenceId,
service,
VoyageAIRerankServiceSettings.fromMap(serviceSettings, context),
VoyageAIRerankTaskSettings.fromMap(taskSettings),
DefaultSecretSettings.fromMap(secrets)
);
}
// should only be used for testing
VoyageAIRerankModel(
String modelId,
String service,
VoyageAIRerankServiceSettings serviceSettings,
VoyageAIRerankTaskSettings taskSettings,
@Nullable DefaultSecretSettings secretSettings
) {
this(modelId, service, null, serviceSettings, taskSettings, secretSettings);
}
VoyageAIRerankModel(
String modelId,
String service,
String url,
VoyageAIRerankServiceSettings serviceSettings,
VoyageAIRerankTaskSettings taskSettings,
@Nullable DefaultSecretSettings secretSettings
) {
super(
new ModelConfigurations(modelId, TaskType.RERANK, service, serviceSettings, taskSettings),
new ModelSecrets(secretSettings),
secretSettings,
serviceSettings.getCommonSettings(),
url
);
}
private VoyageAIRerankModel(VoyageAIRerankModel model, VoyageAIRerankTaskSettings taskSettings) {
super(model, taskSettings);
}
public VoyageAIRerankModel(VoyageAIRerankModel model, VoyageAIRerankServiceSettings serviceSettings) {
super(model, serviceSettings);
}
@Override
public VoyageAIRerankServiceSettings getServiceSettings() {
return (VoyageAIRerankServiceSettings) super.getServiceSettings();
}
@Override
public VoyageAIRerankTaskSettings getTaskSettings() {
return (VoyageAIRerankTaskSettings) super.getTaskSettings();
}
@Override
public DefaultSecretSettings getSecretSettings() {
return (DefaultSecretSettings) super.getSecretSettings();
}
/**
* Accepts a visitor to create an executable action. The returned action will not return documents in the response.
* @param visitor _
* @param taskSettings _
* @param inputType ignored for rerank task
* @return the rerank action
*/
@Override
public ExecutableAction accept(VoyageAIActionVisitor visitor, Map<String, Object> taskSettings, InputType inputType) {
return visitor.create(this, taskSettings);
}
@Override
protected URI buildRequestUri() throws URISyntaxException {
return new URIBuilder().setScheme("https")
.setHost(HOST)
.setPathSegments(VoyageAIUtils.VERSION_1, VoyageAIUtils.RERANK_PATH)
.build();
}
}

View File

@ -0,0 +1,113 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai.rerank;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
public class VoyageAIRerankServiceSettings extends FilteredXContentObject implements ServiceSettings, VoyageAIRateLimitServiceSettings {
public static final String NAME = "voyageai_rerank_service_settings";
private static final Logger logger = LogManager.getLogger(VoyageAIRerankServiceSettings.class);
public static VoyageAIRerankServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
var commonServiceSettings = VoyageAIServiceSettings.fromMap(map, context);
return new VoyageAIRerankServiceSettings(commonServiceSettings);
}
private final VoyageAIServiceSettings commonSettings;
public VoyageAIRerankServiceSettings(VoyageAIServiceSettings commonSettings) {
this.commonSettings = commonSettings;
}
public VoyageAIRerankServiceSettings(StreamInput in) throws IOException {
this.commonSettings = new VoyageAIServiceSettings(in);
}
public VoyageAIServiceSettings getCommonSettings() {
return commonSettings;
}
@Override
public String modelId() {
return commonSettings.modelId();
}
@Override
public RateLimitSettings rateLimitSettings() {
return commonSettings.rateLimitSettings();
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder = commonSettings.toXContentFragment(builder, params);
builder.endObject();
return builder;
}
@Override
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
commonSettings.toXContentFragmentOfExposedFields(builder, params);
return builder;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
commonSettings.writeTo(out);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
VoyageAIRerankServiceSettings that = (VoyageAIRerankServiceSettings) o;
return Objects.equals(commonSettings, that.commonSettings);
}
@Override
public int hashCode() {
return Objects.hash(commonSettings);
}
}

View File

@ -0,0 +1,184 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai.rerank;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields.TRUNCATION;
/**
* Defines the task settings for the VoyageAI rerank service.
*
*/
public class VoyageAIRerankTaskSettings implements TaskSettings {
public static final String NAME = "voyageai_rerank_task_settings";
public static final String RETURN_DOCUMENTS = "return_documents";
public static final String TOP_K_DOCS_ONLY = "top_k";
public static final VoyageAIRerankTaskSettings EMPTY_SETTINGS = new VoyageAIRerankTaskSettings(null, null, null);
public static VoyageAIRerankTaskSettings fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();
if (map == null || map.isEmpty()) {
return EMPTY_SETTINGS;
}
Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException);
Integer topKDocumentsOnly = extractOptionalPositiveInteger(
map,
TOP_K_DOCS_ONLY,
ModelConfigurations.TASK_SETTINGS,
validationException
);
Boolean truncation = extractOptionalBoolean(map, TRUNCATION, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
return of(topKDocumentsOnly, returnDocuments, truncation);
}
/**
* Creates a new {@link VoyageAIRerankTaskSettings} by preferring non-null fields from the request settings over the original settings.
*
* @param originalSettings the settings stored as part of the inference entity configuration
* @param requestTaskSettings the settings passed in within the task_settings field of the request
* @return a constructed {@link VoyageAIRerankTaskSettings}
*/
public static VoyageAIRerankTaskSettings of(
VoyageAIRerankTaskSettings originalSettings,
VoyageAIRerankTaskSettings requestTaskSettings
) {
return new VoyageAIRerankTaskSettings(
requestTaskSettings.getTopKDocumentsOnly() != null
? requestTaskSettings.getTopKDocumentsOnly()
: originalSettings.getTopKDocumentsOnly(),
requestTaskSettings.getReturnDocuments() != null
? requestTaskSettings.getReturnDocuments()
: originalSettings.getReturnDocuments(),
requestTaskSettings.getTruncation() != null ? requestTaskSettings.getTruncation() : originalSettings.getTruncation()
);
}
public static VoyageAIRerankTaskSettings of(Integer topKDocumentsOnly, Boolean returnDocuments, Boolean truncation) {
return new VoyageAIRerankTaskSettings(topKDocumentsOnly, returnDocuments, truncation);
}
private final Integer topKDocumentsOnly;
private final Boolean returnDocuments;
private final Boolean truncation;
public VoyageAIRerankTaskSettings(StreamInput in) throws IOException {
this(in.readOptionalInt(), in.readOptionalBoolean(), in.readOptionalBoolean());
}
public VoyageAIRerankTaskSettings(
@Nullable Integer topKDocumentsOnly,
@Nullable Boolean doReturnDocuments,
@Nullable Boolean truncation
) {
this.topKDocumentsOnly = topKDocumentsOnly;
this.returnDocuments = doReturnDocuments;
this.truncation = truncation;
}
@Override
public boolean isEmpty() {
return topKDocumentsOnly == null && returnDocuments == null && truncation == null;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (topKDocumentsOnly != null) {
builder.field(TOP_K_DOCS_ONLY, topKDocumentsOnly);
}
if (returnDocuments != null) {
builder.field(RETURN_DOCUMENTS, returnDocuments);
}
if (truncation != null) {
builder.field(TRUNCATION, truncation);
}
builder.endObject();
return builder;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.VOYAGE_AI_INTEGRATION_ADDED;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInt(topKDocumentsOnly);
out.writeOptionalBoolean(returnDocuments);
out.writeOptionalBoolean(truncation);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
VoyageAIRerankTaskSettings that = (VoyageAIRerankTaskSettings) o;
return Objects.equals(topKDocumentsOnly, that.topKDocumentsOnly)
&& Objects.equals(returnDocuments, that.returnDocuments)
&& Objects.equals(truncation, that.truncation);
}
@Override
public int hashCode() {
return Objects.hash(truncation, returnDocuments, topKDocumentsOnly);
}
public Integer getTopKDocumentsOnly() {
return topKDocumentsOnly;
}
public Boolean getDoesReturnDocuments() {
return returnDocuments;
}
public Boolean getReturnDocuments() {
return returnDocuments;
}
public Boolean getTruncation() {
return truncation;
}
@Override
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
VoyageAIRerankTaskSettings updatedSettings = VoyageAIRerankTaskSettings.fromMap(new HashMap<>(newSettings));
return VoyageAIRerankTaskSettings.of(this, updatedSettings);
}
}

View File

@ -0,0 +1,145 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.action.voyageai;
import org.apache.http.HttpHeaders;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettingsTests;
import org.hamcrest.MatcherAssert;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
public class VoyageAIActionCreatorTests extends ESTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
private HttpClientManager clientManager;
@Before
public void init() throws Exception {
webServer.start();
threadPool = createThreadPool(inferenceUtilityPool());
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
}
@After
public void shutdown() throws IOException {
clientManager.close();
terminate(threadPool);
webServer.close();
}
public void testCreate_VoyageAIEmbeddingsModel() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
{
"object": "list",
"data": [{
"object": "embedding",
"embedding": [
0.123,
-0.123
],
"index": 0
}],
"model": "voyage-3-large",
"usage": {
"total_tokens": 123
}
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var model = VoyageAIEmbeddingsModelTests.createModel(
getUrl(webServer),
"secret",
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true),
1024,
1024,
"model",
VoyageAIEmbeddingType.FLOAT
);
var actionCreator = new VoyageAIActionCreator(sender, createWithEmptySettings(threadPool));
var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH);
var action = actionCreator.create(model, overriddenTaskSettings, InputType.UNSPECIFIED);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var result = listener.actionGet(TIMEOUT);
MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F }))));
MatcherAssert.assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().getFirst().getUri().getQuery());
MatcherAssert.assertThat(
webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE),
equalTo(XContentType.JSON.mediaType())
);
MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
var requestMap = entityAsMap(webServer.requests().getFirst().getBody());
MatcherAssert.assertThat(
requestMap,
is(
Map.of(
"output_dtype",
"float",
"truncation",
true,
"input_type",
"query",
"output_dimension",
1024,
"input",
List.of("abc"),
"model",
"model"
)
)
);
}
}
}

View File

@ -0,0 +1,413 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.action.voyageai;
import org.apache.http.HttpHeaders;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIUtils;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
import org.hamcrest.MatcherAssert;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationBinary;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationByte;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
public class VoyageAIEmbeddingsActionTests extends ESTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
private HttpClientManager clientManager;
@Before
public void init() throws Exception {
webServer.start();
threadPool = createThreadPool(inferenceUtilityPool());
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
}
@After
public void shutdown() throws IOException {
clientManager.close();
terminate(threadPool);
webServer.close();
}
public void testExecute_ReturnsSuccessfulResponse() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
sender.start();
String responseJson = """
{
"object": "list",
"data": [{
"object": "embedding",
"embedding": [
0.123,
-0.123
],
"index": 0
}],
"model": "voyage-3-large",
"usage": {
"total_tokens": 123
}
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var action = createAction(
getUrl(webServer),
"secret",
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true),
"model",
VoyageAIEmbeddingType.FLOAT,
sender
);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var result = listener.actionGet(TIMEOUT);
MatcherAssert.assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F }))));
MatcherAssert.assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().getFirst().getUri().getQuery());
MatcherAssert.assertThat(
webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE),
equalTo(XContentType.JSON.mediaType())
);
MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
MatcherAssert.assertThat(
webServer.requests().getFirst().getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER),
equalTo(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
);
var requestMap = entityAsMap(webServer.requests().getFirst().getBody());
MatcherAssert.assertThat(
requestMap,
equalTo(
Map.of(
"input",
List.of("abc"),
"model",
"model",
"input_type",
"document",
"output_dtype",
"float",
"truncation",
true,
"output_dimension",
1024
)
)
);
}
}
public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
sender.start();
String responseJson = """
{
"object": "list",
"data": [{
"object": "embedding",
"embedding": [
0,
-1
],
"index": 0
}],
"model": "voyage-3-large",
"usage": {
"total_tokens": 123
}
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var action = createAction(
getUrl(webServer),
"secret",
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true),
"model",
VoyageAIEmbeddingType.INT8,
sender
);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var result = listener.actionGet(TIMEOUT);
assertEquals(buildExpectationByte(List.of(new byte[] { 0, -1 })), result.asMap());
MatcherAssert.assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().getFirst().getUri().getQuery());
MatcherAssert.assertThat(
webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE),
equalTo(XContentType.JSON.mediaType())
);
MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
MatcherAssert.assertThat(
webServer.requests().getFirst().getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER),
equalTo(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
);
var requestMap = entityAsMap(webServer.requests().getFirst().getBody());
MatcherAssert.assertThat(
requestMap,
is(
Map.of(
"input",
List.of("abc"),
"model",
"model",
"input_type",
"document",
"output_dtype",
"int8",
"truncation",
true,
"output_dimension",
1024
)
)
);
}
}
public void testExecute_ReturnsSuccessfulResponse_ForBinaryResponseType() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var sender = HttpRequestSenderTests.createSender(senderFactory)) {
sender.start();
String responseJson = """
{
"object": "list",
"data": [{
"object": "embedding",
"embedding": [
0,
-1
],
"index": 0
}],
"model": "voyage-3-large",
"usage": {
"total_tokens": 123
}
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var action = createAction(
getUrl(webServer),
"secret",
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true),
"model",
VoyageAIEmbeddingType.BINARY,
sender
);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var result = listener.actionGet(TIMEOUT);
assertEquals(buildExpectationBinary(List.of(new byte[] { 0, -1 })), result.asMap());
MatcherAssert.assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().getFirst().getUri().getQuery());
MatcherAssert.assertThat(
webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE),
equalTo(XContentType.JSON.mediaType())
);
MatcherAssert.assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
MatcherAssert.assertThat(
webServer.requests().getFirst().getHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER),
equalTo(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
);
var requestMap = entityAsMap(webServer.requests().getFirst().getBody());
MatcherAssert.assertThat(
requestMap,
is(
Map.of(
"input",
List.of("abc"),
"model",
"model",
"input_type",
"document",
"output_dtype",
"binary",
"truncation",
true,
"output_dimension",
1024
)
)
);
}
}
public void testExecute_ThrowsElasticsearchException() {
var sender = mock(Sender.class);
doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
var action = createAction(getUrl(webServer), "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
MatcherAssert.assertThat(thrownException.getMessage(), is("failed"));
}
public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
var sender = mock(Sender.class);
doAnswer(invocation -> {
@SuppressWarnings("unchecked")
ActionListener<HttpResult> listener = (ActionListener<HttpResult>) invocation.getArguments()[2];
listener.onFailure(new IllegalStateException("failed"));
return Void.TYPE;
}).when(sender).send(any(), any(), any(), any());
var action = createAction(getUrl(webServer), "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
MatcherAssert.assertThat(
thrownException.getMessage(),
is(format("Failed to send VoyageAI embeddings request to [%s]", getUrl(webServer)))
);
}
public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled_WhenUrlIsNull() {
var sender = mock(Sender.class);
doAnswer(invocation -> {
@SuppressWarnings("unchecked")
ActionListener<HttpResult> listener = (ActionListener<HttpResult>) invocation.getArguments()[2];
listener.onFailure(new IllegalStateException("failed"));
return Void.TYPE;
}).when(sender).send(any(), any(), any(), any());
var action = createAction(null, "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send VoyageAI embeddings request"));
}
public void testExecute_ThrowsException() {
var sender = mock(Sender.class);
doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any());
var action = createAction(getUrl(webServer), "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
MatcherAssert.assertThat(
thrownException.getMessage(),
is(format("Failed to send VoyageAI embeddings request to [%s]", getUrl(webServer)))
);
}
public void testExecute_ThrowsExceptionWithNullUrl() {
var sender = mock(Sender.class);
doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any());
var action = createAction(null, "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model", null, sender);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
MatcherAssert.assertThat(thrownException.getMessage(), is("Failed to send VoyageAI embeddings request"));
}
private ExecutableAction createAction(
String url,
String apiKey,
VoyageAIEmbeddingsTaskSettings taskSettings,
@Nullable String modelName,
@Nullable VoyageAIEmbeddingType embeddingType,
Sender sender
) {
var model = VoyageAIEmbeddingsModelTests.createModel(url, apiKey, taskSettings, 1024, 1024, modelName, embeddingType);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "VoyageAI embeddings");
var requestCreator = VoyageAIEmbeddingsRequestManager.of(model, threadPool);
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}
}

View File

@ -0,0 +1,173 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.voyageai;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
import org.hamcrest.MatcherAssert;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.CoreMatchers.is;
public class VoyageAIEmbeddingsRequestEntityTests extends ESTestCase {
public void testXContent_WritesAllFields_ServiceSettingsDefined() throws IOException {
var entity = new VoyageAIEmbeddingsRequestEntity(
List.of("abc"),
VoyageAIEmbeddingsServiceSettings.fromMap(
new HashMap<>(
Map.of(
ServiceFields.URL,
"https://www.abc.com",
ServiceFields.SIMILARITY,
SimilarityMeasure.DOT_PRODUCT.toString(),
ServiceFields.DIMENSIONS,
2048,
ServiceFields.MAX_INPUT_TOKENS,
512,
VoyageAIServiceSettings.MODEL_ID,
"model",
VoyageAIEmbeddingsServiceSettings.EMBEDDING_TYPE,
"float"
)
),
ConfigurationParseContext.PERSISTENT
),
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
"model"
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
MatcherAssert.assertThat(xContentResult, is("""
{"input":["abc"],"model":"model","input_type":"document","output_dimension":2048,"output_dtype":"float"}"""));
}
public void testXContent_WritesAllFields_ServiceSettingsDefined_Int8() throws IOException {
var entity = new VoyageAIEmbeddingsRequestEntity(
List.of("abc"),
VoyageAIEmbeddingsServiceSettings.fromMap(
new HashMap<>(
Map.of(
ServiceFields.URL,
"https://www.abc.com",
ServiceFields.SIMILARITY,
SimilarityMeasure.DOT_PRODUCT.toString(),
ServiceFields.DIMENSIONS,
2048,
ServiceFields.MAX_INPUT_TOKENS,
512,
VoyageAIServiceSettings.MODEL_ID,
"model",
VoyageAIEmbeddingsServiceSettings.EMBEDDING_TYPE,
"int8"
)
),
ConfigurationParseContext.PERSISTENT
),
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
"model"
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
MatcherAssert.assertThat(xContentResult, is("""
{"input":["abc"],"model":"model","input_type":"document","output_dimension":2048,"output_dtype":"int8"}"""));
}
public void testXContent_WritesAllFields_ServiceSettingsDefined_Binary() throws IOException {
var entity = new VoyageAIEmbeddingsRequestEntity(
List.of("abc"),
VoyageAIEmbeddingsServiceSettings.fromMap(
new HashMap<>(
Map.of(
ServiceFields.URL,
"https://www.abc.com",
ServiceFields.SIMILARITY,
SimilarityMeasure.DOT_PRODUCT.toString(),
ServiceFields.DIMENSIONS,
2048,
ServiceFields.MAX_INPUT_TOKENS,
512,
VoyageAIServiceSettings.MODEL_ID,
"model",
VoyageAIEmbeddingsServiceSettings.EMBEDDING_TYPE,
"binary"
)
),
ConfigurationParseContext.PERSISTENT
),
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
"model"
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
MatcherAssert.assertThat(xContentResult, is("""
{"input":["abc"],"model":"model","input_type":"document","output_dimension":2048,"output_dtype":"binary"}"""));
}
public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
var entity = new VoyageAIEmbeddingsRequestEntity(
List.of("abc"),
VoyageAIEmbeddingsServiceSettings.EMPTY_SETTINGS,
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
"model"
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
MatcherAssert.assertThat(xContentResult, is("""
{"input":["abc"],"model":"model","input_type":"document"}"""));
}
public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException {
var entity = new VoyageAIEmbeddingsRequestEntity(
List.of("abc"),
VoyageAIEmbeddingsServiceSettings.EMPTY_SETTINGS,
VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
"model"
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
MatcherAssert.assertThat(xContentResult, is("""
{"input":["abc"],"model":"model"}"""));
}
public void testConvertToString_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() {
var thrownException = expectThrows(
AssertionError.class,
() -> VoyageAIEmbeddingsRequestEntity.convertToString(InputType.UNSPECIFIED)
);
MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]"));
}
}

View File

@ -0,0 +1,215 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.voyageai;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
import org.hamcrest.MatcherAssert;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
public class VoyageAIEmbeddingsRequestTests extends ESTestCase {
public void testCreateRequest_UrlDefined() throws IOException {
var request = createRequest(
List.of("abc"),
VoyageAIEmbeddingsModelTests.createModel("url", "secret", VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, "model")
);
var httpRequest = request.createHttpRequest();
MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
var httpPost = (HttpPost) httpRequest.httpRequestBase();
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
MatcherAssert.assertThat(
httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(),
is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
);
var requestMap = entityAsMap(httpPost.getEntity().getContent());
MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "model", "output_dtype", "float")));
}
public void testCreateRequest_AllOptionsDefined() throws IOException {
var request = createRequest(
List.of("abc"),
VoyageAIEmbeddingsModelTests.createModel(
"url",
"secret",
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
null,
null,
"model"
)
);
var httpRequest = request.createHttpRequest();
MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
var httpPost = (HttpPost) httpRequest.httpRequestBase();
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
MatcherAssert.assertThat(
httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(),
is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
);
var requestMap = entityAsMap(httpPost.getEntity().getContent());
MatcherAssert.assertThat(
requestMap,
is(Map.of("input", List.of("abc"), "model", "model", "input_type", "document", "output_dtype", "float"))
);
}
public void testCreateRequest_DimensionDefined() throws IOException {
var request = createRequest(
List.of("abc"),
VoyageAIEmbeddingsModelTests.createModel(
"url",
"secret",
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
null,
2048,
"model"
)
);
var httpRequest = request.createHttpRequest();
MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
var httpPost = (HttpPost) httpRequest.httpRequestBase();
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
MatcherAssert.assertThat(
httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(),
is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
);
var requestMap = entityAsMap(httpPost.getEntity().getContent());
MatcherAssert.assertThat(
requestMap,
is(
Map.of(
"input",
List.of("abc"),
"model",
"model",
"input_type",
"document",
"output_dtype",
"float",
"output_dimension",
2048
)
)
);
}
public void testCreateRequest_EmbeddingTypeDefined() throws IOException {
var request = createRequest(
List.of("abc"),
VoyageAIEmbeddingsModelTests.createModel(
"url",
"secret",
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null),
null,
2048,
"model",
VoyageAIEmbeddingType.BYTE
)
);
var httpRequest = request.createHttpRequest();
MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
var httpPost = (HttpPost) httpRequest.httpRequestBase();
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
MatcherAssert.assertThat(
httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(),
is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
);
var requestMap = entityAsMap(httpPost.getEntity().getContent());
MatcherAssert.assertThat(
requestMap,
is(
Map.of(
"input",
List.of("abc"),
"model",
"model",
"input_type",
"document",
"output_dtype",
"int8",
"output_dimension",
2048
)
)
);
}
public void testCreateRequest_InputTypeSearch() throws IOException {
var request = createRequest(
List.of("abc"),
VoyageAIEmbeddingsModelTests.createModel(
"url",
"secret",
new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null),
null,
null,
"model"
)
);
var httpRequest = request.createHttpRequest();
MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
var httpPost = (HttpPost) httpRequest.httpRequestBase();
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
MatcherAssert.assertThat(
httpPost.getLastHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(),
is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE)
);
var requestMap = entityAsMap(httpPost.getEntity().getContent());
MatcherAssert.assertThat(
requestMap,
is(Map.of("input", List.of("abc"), "model", "model", "input_type", "query", "output_dtype", "float"))
);
}
public static VoyageAIEmbeddingsRequest createRequest(List<String> input, VoyageAIEmbeddingsModel model) {
return new VoyageAIEmbeddingsRequest(input, model);
}
}

View File

@ -0,0 +1,36 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.voyageai;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIAccount;
import java.net.URI;
import static org.hamcrest.Matchers.is;
public class VoyageAIRequestTests extends ESTestCase {
public void testDecorateWithHeaders() {
var request = new HttpPost("http://www.abc.com");
VoyageAIRequest.decorateWithHeaders(
request,
new VoyageAIAccount(URI.create("http://www.abc.com"), new SecureString(new char[] { 'a', 'b', 'c' }))
);
assertThat(request.getFirstHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
assertThat(request.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer abc"));
assertThat(request.getFirstHeader(VoyageAIUtils.REQUEST_SOURCE_HEADER).getValue(), is(VoyageAIUtils.ELASTIC_REQUEST_SOURCE));
}
}

View File

@ -0,0 +1,186 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.voyageai;
import org.elasticsearch.common.Strings;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
public class VoyageAIRerankRequestEntityTests extends ESTestCase {
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined() throws IOException {
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, null, null), "model");
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
{
"model": "model",
"query": "query",
"documents": [
"abc"
],
"top_k": 8
}
"""));
}
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsTrue() throws IOException {
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, true, null), "model");
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
{
"model": "model",
"query": "query",
"documents": [
"abc"
],
"return_documents": true,
"top_k": 8
}
"""));
}
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsFalse() throws IOException {
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, null), "model");
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
{
"model": "model",
"query": "query",
"documents": [
"abc"
],
"return_documents": false,
"top_k": 8
}
"""));
}
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationTrue() throws IOException {
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, true), "model");
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
{
"model": "model",
"query": "query",
"documents": [
"abc"
],
"return_documents": false,
"top_k": 8,
"truncation": true
}
"""));
}
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationFalse() throws IOException {
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, false), "model");
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
{
"model": "model",
"query": "query",
"documents": [
"abc"
],
"return_documents": false,
"top_k": 8,
"truncation": false
}
"""));
}
public void testXContent_SingleRequest_DoesNotWriteTopKIfNull() throws IOException {
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), null, "model");
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
{
"model": "model",
"query": "query",
"documents": [
"abc"
]
}
"""));
}
public void testXContent_MultipleRequests_WritesModelAndTopKIfDefined() throws IOException {
var entity = new VoyageAIRerankRequestEntity(
"query",
List.of("abc", "def"),
new VoyageAIRerankTaskSettings(8, null, null),
"model"
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
{
"model": "model",
"query": "query",
"documents": [
"abc",
"def"
],
"top_k": 8
}
"""));
}
public void testXContent_MultipleRequests_DoesNotWriteTopKIfNull() throws IOException {
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc", "def"), null, "model");
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
{
"model": "model",
"query": "query",
"documents": [
"abc",
"def"
]
}
"""));
}
}

View File

@ -0,0 +1,110 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.voyageai;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModelTests;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.sameInstance;
public class VoyageAIRerankRequestTests extends ESTestCase {
private static final String API_KEY = "foo";
public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException {
var input = "input";
var query = "query";
var modelId = "model";
var request = createRequest(query, input, modelId, null);
var httpRequest = request.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
var httpPost = (HttpPost) httpRequest.httpRequestBase();
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY));
var requestMap = entityAsMap(httpPost.getEntity().getContent());
assertThat(requestMap, aMapWithSize(3));
assertThat(requestMap.get("documents"), is(List.of(input)));
assertThat(requestMap.get("query"), is(query));
assertThat(requestMap.get("model"), is(modelId));
}
public void testCreateRequest_WithTopNSet() throws IOException {
var input = "input";
var query = "query";
var topK = 1;
var modelId = "model";
var request = createRequest(query, input, modelId, topK);
var httpRequest = request.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
var httpPost = (HttpPost) httpRequest.httpRequestBase();
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY));
var requestMap = entityAsMap(httpPost.getEntity().getContent());
assertThat(requestMap, aMapWithSize(4));
assertThat(requestMap.get("documents"), is(List.of(input)));
assertThat(requestMap.get("query"), is(query));
assertThat(requestMap.get("top_k"), is(topK));
assertThat(requestMap.get("model"), is(modelId));
}
public void testCreateRequest_WithModelSet() throws IOException {
var input = "input";
var query = "query";
var modelId = "model";
var request = createRequest(query, input, modelId, null);
var httpRequest = request.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
var httpPost = (HttpPost) httpRequest.httpRequestBase();
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY));
var requestMap = entityAsMap(httpPost.getEntity().getContent());
assertThat(requestMap, aMapWithSize(3));
assertThat(requestMap.get("documents"), is(List.of(input)));
assertThat(requestMap.get("query"), is(query));
assertThat(requestMap.get("model"), is(modelId));
}
public void testTruncate_DoesNotTruncate() {
var request = createRequest("query", "input", "null", null);
var truncatedRequest = request.truncate();
assertThat(truncatedRequest, sameInstance(request));
}
private static VoyageAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topK) {
var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, topK);
return new VoyageAIRerankRequest(query, List.of(input), rerankModel);
}
}

View File

@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.voyageai;
import org.elasticsearch.test.ESTestCase;
import static org.hamcrest.Matchers.is;
public class VoyageAIUtilsTests extends ESTestCase {
public void testCreateRequestSourceHeader() {
var requestSourceHeader = VoyageAIUtils.createRequestSourceHeader();
assertThat(requestSourceHeader.getName(), is("Request-Source"));
assertThat(requestSourceHeader.getValue(), is("unspecified:elasticsearch"));
}
}

View File

@ -0,0 +1,432 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.response.voyageai;
import org.apache.http.HttpResponse;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentParseException;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests.createModel;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
public class VoyageAIEmbeddingsResponseEntityTests extends ESTestCase {
public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
String responseJson = """
{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [
0.014539449,
-0.015288644
]
}
],
"model": "voyage-3-large",
"usage": {
"total_tokens": 8
}
}
""";
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
List.of("abc", "def"),
createModel("url", "api_key", null, "voyage-3-large")
);
InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse(
request,
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(),
is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.014539449F, -0.015288644F })))
);
}
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
String responseJson = """
{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [
0.014539449,
-0.015288644
]
},
{
"object": "embedding",
"index": 1,
"embedding": [
0.0123,
-0.0123
]
}
],
"model": "voyage-3-large",
"usage": {
"total_tokens": 8
}
}
""";
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
List.of("abc", "def"),
createModel("url", "api_key", null, "voyage-3-large")
);
InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse(
request,
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(),
is(
List.of(
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.014539449F, -0.015288644F }),
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123F, -0.0123F })
)
)
);
}
public void testFromResponse_FailsWhenDataFieldIsNotPresent() {
String responseJson = """
{
"object": "list",
"not_data": [
{
"object": "embedding",
"index": 0,
"embedding": [
0.014539449,
-0.015288644
]
}
],
"model": "voyage-3-large",
"usage": {
"total_tokens": 8
}
}
""";
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
List.of("abc", "def"),
createModel("url", "api_key", null, "voyage-3-large")
);
var thrownException = expectThrows(
java.lang.IllegalArgumentException.class,
() -> VoyageAIEmbeddingsResponseEntity.fromResponse(
request,
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
)
);
assertThat(thrownException.getMessage(), is("Required [data]"));
}
public void testFromResponse_FailsWhenDataFieldNotAnArray() {
String responseJson = """
{
"object": "list",
"data": {
"test": {
"object": "embedding",
"index": 0,
"embedding": [
0.014539449,
-0.015288644
]
}
},
"model": "voyage-3-large",
"usage": {
"total_tokens": 8
}
}
""";
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
List.of("abc", "def"),
createModel("url", "api_key", null, "voyage-3-large")
);
var thrownException = expectThrows(
XContentParseException.class,
() -> VoyageAIEmbeddingsResponseEntity.fromResponse(
request,
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
)
);
assertThat(thrownException.getMessage(), containsString("[EmbeddingFloatResult] failed to parse field [data]"));
}
public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() {
String responseJson = """
{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embeddingzzz": [
0.014539449,
-0.015288644
]
}
],
"model": "voyage-3-large",
"usage": {
"total_tokens": 8
}
}
""";
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
List.of("abc", "def"),
createModel("url", "api_key", null, "voyage-3-large")
);
var thrownException = expectThrows(
XContentParseException.class,
() -> VoyageAIEmbeddingsResponseEntity.fromResponse(
request,
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
)
);
assertThat(thrownException.getMessage(), containsString("[EmbeddingFloatResult] failed to parse field [data]"));
}
public void testFromResponse_FailsWhenEmbeddingValueIsAString() {
String responseJson = """
{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [
"abc"
]
}
],
"model": "voyage-3-large",
"usage": {
"total_tokens": 8
}
}
""";
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
List.of("abc", "def"),
createModel("url", "api_key", null, "voyage-3-large")
);
var thrownException = expectThrows(
XContentParseException.class,
() -> VoyageAIEmbeddingsResponseEntity.fromResponse(
request,
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
)
);
assertThat(thrownException.getMessage(), is("[8:15] [EmbeddingFloatResult] failed to parse field [data]"));
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOException {
String responseJson = """
{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [
1
]
}
],
"model": "voyage-3-large",
"usage": {
"total_tokens": 8
}
}
""";
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
List.of("abc", "def"),
createModel("url", "api_key", null, "voyage-3-large")
);
InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse(
request,
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(),
is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 1.0F })))
);
}
public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException {
String responseJson = """
{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [
40294967295
]
}
],
"model": "voyage-3-large",
"usage": {
"total_tokens": 8
}
}
""";
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
List.of("abc", "def"),
createModel("url", "api_key", null, "voyage-3-large")
);
InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse(
request,
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertThat(
((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(),
is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 4.0294965E10F })))
);
}
public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() {
String responseJson = """
{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [
{}
]
}
],
"model": "voyage-3-large",
"usage": {
"total_tokens": 8
}
}
""";
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
List.of("abc", "def"),
createModel("url", "api_key", null, "voyage-3-large")
);
var thrownException = expectThrows(
XContentParseException.class,
() -> VoyageAIEmbeddingsResponseEntity.fromResponse(
request,
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
)
);
assertThat(thrownException.getMessage(), is("[8:15] [EmbeddingFloatResult] failed to parse field [data]"));
}
public void testFieldsInDifferentOrderServer() throws IOException {
// The fields of the objects in the data array are reordered
String response = """
{
"object": "list",
"model": "voyage-3-large",
"data": [
{
"embedding": [
-0.9,
0.5,
0.3
],
"index": 0,
"object": "embedding"
},
{
"index": 0,
"embedding": [
0.1,
0.5
],
"object": "embedding"
},
{
"object": "embedding",
"index": 0,
"embedding": [
0.5,
0.5
]
}
],
"usage": {
"total_tokens": 0
}
}""";
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(
List.of("abc", "def"),
createModel("url", "api_key", null, "voyage-3-large")
);
InferenceServiceResults parsedResults = VoyageAIEmbeddingsResponseEntity.fromResponse(
request,
new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8))
);
assertThat(parsedResults, instanceOf(InferenceTextEmbeddingFloatResults.class));
assertThat(
((InferenceTextEmbeddingFloatResults) parsedResults).embeddings(),
is(
List.of(
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { -0.9F, 0.5F, 0.3F }),
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.1F, 0.5F }),
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.5F, 0.5F })
)
)
);
}
}

View File

@ -0,0 +1,51 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.response.voyageai;
import org.apache.http.HttpResponse;
import org.elasticsearch.common.Strings;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.hamcrest.MatcherAssert;
import java.nio.charset.StandardCharsets;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
public class VoyageAIErrorResponseEntityTests extends ESTestCase {
public void testFromResponse() {
String message = "\"input\" length 2049 is larger than the largest allowed size 2048";
String escapedMessage = message.replace("\\", "\\\\").replace("\"", "\\\"");
String responseJson = Strings.format("""
{
"detail": "%s"
}
""", escapedMessage);
ErrorResponse errorResponse = VoyageAIErrorResponseEntity.fromResponse(
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
assertNotNull(errorResponse);
MatcherAssert.assertThat(errorResponse.getErrorMessage(), is(message));
}
public void testFromResponse_noMessage() {
String responseJson = """
{
"error": "abc"
}
""";
ErrorResponse errorResponse = VoyageAIErrorResponseEntity.fromResponse(
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(errorResponse, is(ErrorResponse.UNDEFINED_ERROR));
}
}

View File

@ -0,0 +1,173 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.response.voyageai;
import org.apache.http.HttpResponse;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.hamcrest.MatcherAssert;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
public class VoyageAIRerankResponseEntityTests extends ESTestCase {
public void testResponseLiteral() throws IOException {
String responseLiteral = """
{
"object": "list",
"model": "model",
"data": [
{
"index": 2,
"relevance_score": 0.98005307
},
{
"index": 3,
"relevance_score": 0.27904198
},
{
"index": 0,
"relevance_score": 0.10194652
}
],
"usage": {
"total_tokens": 15
}
}
""";
InferenceServiceResults parsedResults = VoyageAIRerankResponseEntity.fromResponse(
new HttpResult(mock(HttpResponse.class), responseLiteral.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
List<RankedDocsResults.RankedDoc> expected = responseLiteralDocs();
for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) {
assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index());
}
}
public void testGeneratedResponse() throws IOException {
int numDocs = randomIntBetween(1, 10);
List<RankedDocsResults.RankedDoc> expected = new ArrayList<>(numDocs);
StringBuilder responseBuilder = new StringBuilder();
responseBuilder.append("{");
responseBuilder.append("\"model\": \"model\",");
responseBuilder.append("\"object\": \"list\",");
responseBuilder.append("\"data\": [");
List<Integer> indices = linear(numDocs);
List<Float> scores = linearFloats(numDocs);
for (int i = 0; i < numDocs; i++) {
int index = indices.remove(randomInt(indices.size() - 1));
responseBuilder.append("{");
responseBuilder.append("\"index\":").append(index).append(",");
responseBuilder.append("\"relevance_score\":").append(scores.get(i).toString()).append("}");
expected.add(new RankedDocsResults.RankedDoc(index, scores.get(i), null));
if (i < numDocs - 1) {
responseBuilder.append(",");
}
}
responseBuilder.append("],");
responseBuilder.append("\"usage\": {");
responseBuilder.append("\"total_tokens\": 15}");
responseBuilder.append("}");
InferenceServiceResults parsedResults = VoyageAIRerankResponseEntity.fromResponse(
new HttpResult(mock(HttpResponse.class), responseBuilder.toString().getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) {
assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index());
}
}
private ArrayList<RankedDocsResults.RankedDoc> responseLiteralDocs() {
var list = new ArrayList<RankedDocsResults.RankedDoc>();
list.add(new RankedDocsResults.RankedDoc(2, 0.98005307F, null));
list.add(new RankedDocsResults.RankedDoc(3, 0.27904198F, null));
list.add(new RankedDocsResults.RankedDoc(0, 0.10194652F, null));
return list;
}
public void testResponseLiteralWithDocuments() throws IOException {
String responseLiteralWithDocuments = """
{
"object": "list",
"model": "model",
"data": [
{
"document": "Washington, D.C..",
"index": 2,
"relevance_score": 0.98005307
},
{
"document": "Capital punishment has existed in the United States since beforethe United States was a country. ",
"index": 3,
"relevance_score": 0.27904198
},
{
"document": "Carson City is the capital city of the American state of Nevada.",
"index": 0,
"relevance_score": 0.10194652
}
],
"usage": {
"total_tokens": 15
}
}
""";
InferenceServiceResults parsedResults = VoyageAIRerankResponseEntity.fromResponse(
new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8))
);
MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class));
MatcherAssert.assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText));
}
private final List<RankedDocsResults.RankedDoc> responseLiteralDocsWithText = List.of(
new RankedDocsResults.RankedDoc(2, 0.98005307F, "Washington, D.C.."),
new RankedDocsResults.RankedDoc(
3,
0.27904198F,
"Capital punishment has existed in the United States since beforethe United States was a country. "
),
new RankedDocsResults.RankedDoc(0, 0.10194652F, "Carson City is the capital city of the American state of Nevada.")
);
private ArrayList<Integer> linear(int n) {
ArrayList<Integer> list = new ArrayList<>();
for (int i = 0; i <= n; i++) {
list.add(i);
}
return list;
}
// creates a list of doubles of monotonically decreasing magnitude
private ArrayList<Float> linearFloats(int n) {
ArrayList<Float> list = new ArrayList<>();
float startValue = 1.0f;
float decrement = startValue / n + 1;
for (int i = 0; i <= n; i++) {
list.add(startValue - (i * decrement));
}
return list;
}
}

View File

@ -0,0 +1,138 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.voyageai;
import org.apache.http.Header;
import org.apache.http.HeaderElement;
import org.apache.http.HttpResponse;
import org.apache.http.StatusLine;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.hamcrest.MatcherAssert;
import java.nio.charset.StandardCharsets;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.core.Is.is;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class VoyageAIResponseHandlerTests extends ESTestCase {
public void testCheckForFailureStatusCode_DoesNotThrowForStatusCodesBetween200And299() {
callCheckForFailureStatusCode(randomIntBetween(200, 299), "id");
}
public void testCheckForFailureStatusCode_ThrowsFor503() {
var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(503, "id"));
assertFalse(exception.shouldRetry());
MatcherAssert.assertThat(
exception.getCause().getMessage(),
containsString("Received a server error status code for request from inference entity id [id] status [503]")
);
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST));
}
public void testCheckForFailureStatusCode_ThrowsFor500_WithShouldRetryTrue() {
var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(500, "id"));
assertTrue(exception.shouldRetry());
MatcherAssert.assertThat(
exception.getCause().getMessage(),
containsString("Received a server error status code for request from inference entity id [id] status [500]")
);
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST));
}
public void testCheckForFailureStatusCode_ThrowsFor429_WithShouldRetryTrue() {
var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(429, "id"));
assertTrue(exception.shouldRetry());
MatcherAssert.assertThat(
exception.getCause().getMessage(),
containsString("Received a rate limit status code for request from inference entity id [id] status [429]")
);
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.TOO_MANY_REQUESTS));
}
public void testCheckForFailureStatusCode_ThrowsFor400() {
var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(400, "id"));
assertFalse(exception.shouldRetry());
MatcherAssert.assertThat(
exception.getCause().getMessage(),
containsString("Received an input validation error response for request from inference entity id [id] status [400]")
);
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST));
}
public void testCheckForFailureStatusCode_ThrowsFor400_InputsTooLarge() {
var exception = expectThrows(
RetryException.class,
() -> callCheckForFailureStatusCode(400, "\"input\" length 2049 is larger than the largest allowed size 2048", "id")
);
assertFalse(exception.shouldRetry());
MatcherAssert.assertThat(
exception.getCause().getMessage(),
containsString("Received an input validation error response for request from inference entity id [id] status [400]")
);
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST));
}
public void testCheckForFailureStatusCode_ThrowsFor401() {
var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(401, "inferenceEntityId"));
assertFalse(exception.shouldRetry());
MatcherAssert.assertThat(
exception.getCause().getMessage(),
containsString(
"Received an authentication error status code for request from inference entity id [inferenceEntityId] status [401]"
)
);
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.UNAUTHORIZED));
}
public void testCheckForFailureStatusCode_ThrowsFor402() {
var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(402, "inferenceEntityId"));
assertFalse(exception.shouldRetry());
MatcherAssert.assertThat(exception.getCause().getMessage(), containsString("Payment required"));
MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.PAYMENT_REQUIRED));
}
private static void callCheckForFailureStatusCode(int statusCode, String modelId) {
callCheckForFailureStatusCode(statusCode, null, modelId);
}
private static void callCheckForFailureStatusCode(int statusCode, @Nullable String errorMessage, String modelId) {
var statusLine = mock(StatusLine.class);
when(statusLine.getStatusCode()).thenReturn(statusCode);
var httpResponse = mock(HttpResponse.class);
when(httpResponse.getStatusLine()).thenReturn(statusLine);
var header = mock(Header.class);
when(header.getElements()).thenReturn(new HeaderElement[] {});
when(httpResponse.getFirstHeader(anyString())).thenReturn(header);
String escapedErrorMessage = errorMessage != null ? errorMessage.replace("\\", "\\\\").replace("\"", "\\\"") : errorMessage;
String responseJson = Strings.format("""
{
"detail": "%s"
}
""", escapedErrorMessage);
var mockRequest = mock(Request.class);
when(mockRequest.getInferenceEntityId()).thenReturn(modelId);
var httpResult = new HttpResult(httpResponse, errorMessage == null ? new byte[] {} : responseJson.getBytes(StandardCharsets.UTF_8));
var handler = new VoyageAIResponseHandler("", (request, result) -> null);
handler.checkForFailureStatusCode(mockRequest, httpResult);
}
}

View File

@ -146,4 +146,7 @@ public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase<I
);
}
public static Map<String, Object> buildExpectationBinary(List<byte[]> embeddings) {
return Map.of("text_embedding_bits", embeddings.stream().map(InferenceByteEmbedding::new).toList());
}
}

View File

@ -0,0 +1,112 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
import org.hamcrest.MatcherAssert;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.is;
public class VoyageAIServiceSettingsTests extends AbstractWireSerializingTestCase<VoyageAIServiceSettings> {
public static VoyageAIServiceSettings createRandomWithNonNullUrl() {
return createRandom();
}
/**
* The created settings can have a url set to null.
*/
public static VoyageAIServiceSettings createRandom() {
var model = randomAlphaOfLength(15);
return new VoyageAIServiceSettings(model, RateLimitSettingsTests.createRandom());
}
public void testFromMap() {
var model = "model";
var serviceSettings = VoyageAIServiceSettings.fromMap(
new HashMap<>(Map.of(VoyageAIServiceSettings.MODEL_ID, model)),
ConfigurationParseContext.REQUEST
);
MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(model, null)));
}
public void testFromMap_WithRateLimit() {
var model = "model";
var serviceSettings = VoyageAIServiceSettings.fromMap(
new HashMap<>(
Map.of(
VoyageAIServiceSettings.MODEL_ID,
model,
RateLimitSettings.FIELD_NAME,
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3))
)
),
ConfigurationParseContext.REQUEST
);
MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(model, new RateLimitSettings(3))));
}
public void testFromMap_WhenUsingModelId() {
var model = "model";
var serviceSettings = VoyageAIServiceSettings.fromMap(
new HashMap<>(Map.of(VoyageAIServiceSettings.MODEL_ID, model)),
ConfigurationParseContext.PERSISTENT
);
MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(model, null)));
}
public void testXContent_WritesModelId() throws IOException {
var entity = new VoyageAIServiceSettings("model", new RateLimitSettings(1));
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, is("""
{"model_id":"model","rate_limit":{"requests_per_minute":1}}"""));
}
@Override
protected Writeable.Reader<VoyageAIServiceSettings> instanceReader() {
return VoyageAIServiceSettings::new;
}
@Override
protected VoyageAIServiceSettings createTestInstance() {
return createRandomWithNonNullUrl();
}
@Override
protected VoyageAIServiceSettings mutateInstance(VoyageAIServiceSettings instance) throws IOException {
return randomValueOtherThan(instance, VoyageAIServiceSettingsTests::createRandom);
}
public static Map<String, Object> getServiceSettingsMap(String model) {
var map = new HashMap<String, Object>();
map.put(VoyageAIServiceSettings.MODEL_ID, model);
return map;
}
}

View File

@ -0,0 +1,209 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai.embeddings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
import org.hamcrest.MatcherAssert;
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettingsTests.getTaskSettingsMap;
import static org.hamcrest.Matchers.is;
public class VoyageAIEmbeddingsModelTests extends ESTestCase {
public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEmpty_AndInputTypeIsInvalid() {
var model = createModel("url", "api_key", null, null, "model");
var overriddenModel = VoyageAIEmbeddingsModel.of(model, Map.of(), InputType.UNSPECIFIED);
MatcherAssert.assertThat(overriddenModel, is(model));
}
public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull_AndInputTypeIsInvalid() {
var model = createModel("url", "api_key", null, null, "model");
var overriddenModel = VoyageAIEmbeddingsModel.of(model, null, InputType.UNSPECIFIED);
MatcherAssert.assertThat(overriddenModel, is(model));
}
public void testOverrideWith_SetsInputTypeToIngest_WhenTheFieldIsNullInModelTaskSettings_AndNullInRequestTaskSettings() {
var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model");
var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.INGEST);
var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model");
MatcherAssert.assertThat(overriddenModel, is(expectedModel));
}
public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingStoredTaskSettings() {
var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model");
var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.SEARCH);
var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null), null, null, "model");
MatcherAssert.assertThat(overriddenModel, is(expectedModel));
}
public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingRequestTaskSettings() {
var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model");
var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.INGEST), InputType.SEARCH);
var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null), null, null, "model");
MatcherAssert.assertThat(overriddenModel, is(expectedModel));
}
public void testOverrideWith_OverridesInputType_WithRequestTaskSettingsSearch_WhenRequestInputTypeIsInvalid() {
var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model");
var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.SEARCH), InputType.UNSPECIFIED);
var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, null), null, null, "model");
MatcherAssert.assertThat(overriddenModel, is(expectedModel));
}
public void testOverrideWith_DoesNotSetInputType_FromRequest_IfInputTypeIsInvalid() {
var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model");
var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED);
var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings((InputType) null, null), null, null, "model");
MatcherAssert.assertThat(overriddenModel, is(expectedModel));
}
public void testOverrideWith_DoesNotSetInputType_WhenRequestTaskSettingsIsNull_AndRequestInputTypeIsInvalid() {
var model = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model");
var overriddenModel = VoyageAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED);
var expectedModel = createModel("url", "api_key", new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, null), null, null, "model");
MatcherAssert.assertThat(overriddenModel, is(expectedModel));
}
public static VoyageAIEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit, @Nullable String model) {
return createModel(url, apiKey, VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, null, model);
}
public static VoyageAIEmbeddingsModel createModel(
String url,
String apiKey,
@Nullable Integer tokenLimit,
@Nullable Integer dimensions,
String model
) {
return createModel(url, apiKey, VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, dimensions, model);
}
public static VoyageAIEmbeddingsModel createModel(
String url,
String apiKey,
VoyageAIEmbeddingsTaskSettings taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Integer tokenLimit,
@Nullable Integer dimensions,
String model
) {
return new VoyageAIEmbeddingsModel(
"id",
"service",
url,
new VoyageAIEmbeddingsServiceSettings(
new VoyageAIServiceSettings(model, null),
VoyageAIEmbeddingType.FLOAT,
SimilarityMeasure.DOT_PRODUCT,
dimensions,
tokenLimit,
false
),
taskSettings,
chunkingSettings,
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
}
public static VoyageAIEmbeddingsModel createModel(
String url,
String apiKey,
VoyageAIEmbeddingsTaskSettings taskSettings,
@Nullable Integer tokenLimit,
@Nullable Integer dimensions,
String model
) {
return new VoyageAIEmbeddingsModel(
"id",
"service",
url,
new VoyageAIEmbeddingsServiceSettings(
new VoyageAIServiceSettings(model, null),
VoyageAIEmbeddingType.FLOAT,
SimilarityMeasure.DOT_PRODUCT,
dimensions,
tokenLimit,
false
),
taskSettings,
null,
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
}
public static VoyageAIEmbeddingsModel createModel(
String url,
String apiKey,
VoyageAIEmbeddingsTaskSettings taskSettings,
@Nullable Integer tokenLimit,
@Nullable Integer dimensions,
String model,
VoyageAIEmbeddingType embeddingType
) {
return new VoyageAIEmbeddingsModel(
"id",
"service",
url,
new VoyageAIEmbeddingsServiceSettings(
new VoyageAIServiceSettings(model, null),
embeddingType,
SimilarityMeasure.DOT_PRODUCT,
dimensions,
tokenLimit,
false
),
taskSettings,
null,
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
}
public static VoyageAIEmbeddingsModel createModel(
String url,
String apiKey,
VoyageAIEmbeddingsTaskSettings taskSettings,
@Nullable Integer tokenLimit,
@Nullable Integer dimensions,
String model,
@Nullable SimilarityMeasure similarityMeasure
) {
return new VoyageAIEmbeddingsModel(
"id",
"service",
url,
new VoyageAIEmbeddingsServiceSettings(
new VoyageAIServiceSettings(model, null),
VoyageAIEmbeddingType.FLOAT,
similarityMeasure,
dimensions,
tokenLimit,
false
),
taskSettings,
null,
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
}
}

View File

@ -0,0 +1,327 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai.embeddings;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettingsTests;
import org.hamcrest.MatcherAssert;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER;
import static org.hamcrest.Matchers.is;
public class VoyageAIEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase<VoyageAIEmbeddingsServiceSettings> {
public static VoyageAIEmbeddingsServiceSettings createRandom() {
SimilarityMeasure similarityMeasure = SimilarityMeasure.DOT_PRODUCT;
Integer dims = 1024;
Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256);
Boolean dimensionSetByUser = randomBoolean();
var commonSettings = VoyageAIServiceSettingsTests.createRandom();
return new VoyageAIEmbeddingsServiceSettings(
commonSettings,
VoyageAIEmbeddingType.FLOAT,
similarityMeasure,
dims,
maxInputTokens,
dimensionSetByUser
);
}
public void testFromMap() {
var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
var dims = 1536;
var maxInputTokens = 512;
var model = "model";
var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap(
new HashMap<>(
Map.of(
ServiceFields.SIMILARITY,
similarity,
ServiceFields.DIMENSIONS,
dims,
ServiceFields.MAX_INPUT_TOKENS,
maxInputTokens,
VoyageAIServiceSettings.MODEL_ID,
model
)
),
ConfigurationParseContext.PERSISTENT
);
MatcherAssert.assertThat(
serviceSettings,
is(
new VoyageAIEmbeddingsServiceSettings(
new VoyageAIServiceSettings(model, null),
VoyageAIEmbeddingType.FLOAT,
SimilarityMeasure.DOT_PRODUCT,
dims,
maxInputTokens,
false
)
)
);
}
public void testFromMap_WithModelId() {
var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
var maxInputTokens = 512;
var model = "model";
var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap(
new HashMap<>(
Map.of(
ServiceFields.SIMILARITY,
similarity,
ServiceFields.MAX_INPUT_TOKENS,
maxInputTokens,
VoyageAIServiceSettings.MODEL_ID,
model
)
),
ConfigurationParseContext.REQUEST
);
MatcherAssert.assertThat(
serviceSettings,
is(
new VoyageAIEmbeddingsServiceSettings(
new VoyageAIServiceSettings(model, null),
VoyageAIEmbeddingType.FLOAT,
SimilarityMeasure.DOT_PRODUCT,
null,
maxInputTokens,
false
)
)
);
}
public void testFromMap_WithModelId_WithDimensions() {
var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
var dims = 1536;
var maxInputTokens = 512;
var model = "model";
var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap(
new HashMap<>(
Map.of(
ServiceFields.SIMILARITY,
similarity,
ServiceFields.DIMENSIONS,
dims,
ServiceFields.MAX_INPUT_TOKENS,
maxInputTokens,
VoyageAIServiceSettings.MODEL_ID,
model
)
),
ConfigurationParseContext.REQUEST
);
MatcherAssert.assertThat(
serviceSettings,
is(
new VoyageAIEmbeddingsServiceSettings(
new VoyageAIServiceSettings(model, null),
VoyageAIEmbeddingType.FLOAT,
SimilarityMeasure.DOT_PRODUCT,
dims,
maxInputTokens,
true
)
)
);
}
public void testFromMap_DimensionsSetByUserIsFalseInRequestContext() {
var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
var maxInputTokens = 512;
var model = "model";
var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap(
new HashMap<>(
Map.of(
ServiceFields.SIMILARITY,
similarity,
DIMENSIONS_SET_BY_USER,
true,
ServiceFields.MAX_INPUT_TOKENS,
maxInputTokens,
VoyageAIServiceSettings.MODEL_ID,
model
)
),
ConfigurationParseContext.REQUEST
);
MatcherAssert.assertThat(
serviceSettings,
is(
new VoyageAIEmbeddingsServiceSettings(
new VoyageAIServiceSettings(model, null),
VoyageAIEmbeddingType.FLOAT,
SimilarityMeasure.DOT_PRODUCT,
null,
maxInputTokens,
false
)
)
);
}
public void testFromMap_DimensionsSetByUserIsSetInPersistentContext() {
var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
var maxInputTokens = 512;
var model = "model";
var dimensionsSetByUser = randomBoolean();
var serviceSettings = VoyageAIEmbeddingsServiceSettings.fromMap(
new HashMap<>(
Map.of(
ServiceFields.SIMILARITY,
similarity,
DIMENSIONS_SET_BY_USER,
dimensionsSetByUser,
ServiceFields.MAX_INPUT_TOKENS,
maxInputTokens,
VoyageAIServiceSettings.MODEL_ID,
model
)
),
ConfigurationParseContext.PERSISTENT
);
MatcherAssert.assertThat(
serviceSettings,
is(
new VoyageAIEmbeddingsServiceSettings(
new VoyageAIServiceSettings(model, null),
VoyageAIEmbeddingType.FLOAT,
SimilarityMeasure.DOT_PRODUCT,
null,
maxInputTokens,
dimensionsSetByUser
)
)
);
}
public void testFromMap_InvalidSimilarity_ThrowsError() {
var similarity = "by_size";
var thrownException = expectThrows(
ValidationException.class,
() -> VoyageAIEmbeddingsServiceSettings.fromMap(
new HashMap<>(Map.of(VoyageAIServiceSettings.MODEL_ID, "model", ServiceFields.SIMILARITY, similarity)),
ConfigurationParseContext.PERSISTENT
)
);
MatcherAssert.assertThat(
thrownException.getMessage(),
is(
"Validation Failed: 1: [service_settings] Invalid value [by_size] received. [similarity] "
+ "must be one of [cosine, dot_product, l2_norm];"
)
);
}
@SuppressWarnings("checkstyle:LineLength")
public void testToXContent_WritesAllValues() throws IOException {
var serviceSettings = new VoyageAIEmbeddingsServiceSettings(
new VoyageAIServiceSettings("model", new RateLimitSettings(3)),
VoyageAIEmbeddingType.FLOAT,
SimilarityMeasure.COSINE,
5,
10,
false
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
serviceSettings.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(
xContentResult,
is(
"""
{"model_id":"model","""
+ """
"rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10,"embedding_type":"float"}"""
)
);
}
@SuppressWarnings("checkstyle:LineLength")
public void testToXContent_WritesAllValues_DimensionSetByUser() throws IOException {
var serviceSettings = new VoyageAIEmbeddingsServiceSettings(
new VoyageAIServiceSettings("model", new RateLimitSettings(3)),
VoyageAIEmbeddingType.FLOAT,
SimilarityMeasure.COSINE,
5,
10,
true
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
serviceSettings.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(
xContentResult,
is(
"""
{"model_id":"model","""
+ """
"rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10,"embedding_type":"float"}"""
)
);
}
@Override
protected Writeable.Reader<VoyageAIEmbeddingsServiceSettings> instanceReader() {
return VoyageAIEmbeddingsServiceSettings::new;
}
@Override
protected VoyageAIEmbeddingsServiceSettings createTestInstance() {
return createRandom();
}
@Override
protected VoyageAIEmbeddingsServiceSettings mutateInstance(VoyageAIEmbeddingsServiceSettings instance) throws IOException {
return randomValueOtherThan(instance, VoyageAIEmbeddingsServiceSettingsTests::createRandom);
}
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables());
return new NamedWriteableRegistry(entries);
}
public static Map<String, Object> getServiceSettingsMap(String model) {
return new HashMap<>(VoyageAIServiceSettingsTests.getServiceSettingsMap(model));
}
}

View File

@ -0,0 +1,218 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai.embeddings;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields;
import org.hamcrest.MatcherAssert;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithIngestAndSearch;
import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings.VALID_REQUEST_VALUES;
import static org.hamcrest.Matchers.is;
public class VoyageAIEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase<VoyageAIEmbeddingsTaskSettings> {
public static VoyageAIEmbeddingsTaskSettings createRandom() {
var inputType = randomBoolean() ? randomWithIngestAndSearch() : null;
var truncation = randomBoolean();
return new VoyageAIEmbeddingsTaskSettings(inputType, truncation);
}
public void testIsEmpty() {
var randomSettings = createRandom();
var stringRep = Strings.toString(randomSettings);
assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}"));
}
public void testUpdatedTaskSettings_NotUpdated_UseInitialSettings() {
var initialSettings = createRandom();
var newSettings = new VoyageAIEmbeddingsTaskSettings((InputType) null, null);
Map<String, Object> newSettingsMap = new HashMap<>();
VoyageAIEmbeddingsTaskSettings updatedSettings = (VoyageAIEmbeddingsTaskSettings) initialSettings.updatedTaskSettings(
Collections.unmodifiableMap(newSettingsMap)
);
assertEquals(initialSettings.getInputType(), updatedSettings.getInputType());
}
public void testUpdatedTaskSettings_Updated_UseNewSettings() {
var initialSettings = createRandom();
var newSettings = new VoyageAIEmbeddingsTaskSettings(randomWithIngestAndSearch(), randomBoolean());
Map<String, Object> newSettingsMap = new HashMap<>();
newSettingsMap.put(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, newSettings.getInputType().toString());
VoyageAIEmbeddingsTaskSettings updatedSettings = (VoyageAIEmbeddingsTaskSettings) initialSettings.updatedTaskSettings(
Collections.unmodifiableMap(newSettingsMap)
);
assertEquals(newSettings.getInputType(), updatedSettings.getInputType());
}
public void testFromMap_CreatesEmptySettings_WhenAllFieldsAreNull() {
MatcherAssert.assertThat(
VoyageAIEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of())),
is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))
);
}
public void testFromMap_CreatesEmptySettings_WhenMapIsNull() {
MatcherAssert.assertThat(
VoyageAIEmbeddingsTaskSettings.fromMap(null),
is(new VoyageAIEmbeddingsTaskSettings((InputType) null, null))
);
}
public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() {
MatcherAssert.assertThat(
VoyageAIEmbeddingsTaskSettings.fromMap(
new HashMap<>(
Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString(), VoyageAIServiceFields.TRUNCATION, false)
)
),
is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, false))
);
}
public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() {
var exception = expectThrows(
ValidationException.class,
() -> VoyageAIEmbeddingsTaskSettings.fromMap(
new HashMap<>(Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, "abc", VoyageAIServiceFields.TRUNCATION, false))
)
);
MatcherAssert.assertThat(
exception.getMessage(),
is(
Strings.format(
"Validation Failed: 1: [task_settings] Invalid value [abc] received. [input_type] must be one of [%s];",
getValidValuesSortedAndCombined(VALID_REQUEST_VALUES)
)
)
);
}
public void testFromMap_ReturnsFailure_WhenTruncationIsInvalid() {
var exception = expectThrows(
ValidationException.class,
() -> VoyageAIEmbeddingsTaskSettings.fromMap(
new HashMap<>(
Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString(), VoyageAIServiceFields.TRUNCATION, "abc")
)
)
);
MatcherAssert.assertThat(
exception.getMessage(),
is("Validation Failed: 1: field [truncation] is not of the expected type. The value [abc] cannot be converted to a [Boolean];")
);
}
public void testFromMap_ReturnsFailure_WhenInputTypeIsUnspecified() {
var exception = expectThrows(
ValidationException.class,
() -> VoyageAIEmbeddingsTaskSettings.fromMap(
new HashMap<>(Map.of(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.UNSPECIFIED.toString()))
)
);
MatcherAssert.assertThat(
exception.getMessage(),
is(
Strings.format(
"Validation Failed: 1: [task_settings] Invalid value [unspecified] received. [input_type] must be one of [%s];",
getValidValuesSortedAndCombined(VALID_REQUEST_VALUES)
)
)
);
}
private static <E extends Enum<E>> String getValidValuesSortedAndCombined(EnumSet<E> validValues) {
var validValuesAsStrings = validValues.stream().map(value -> value.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new);
Arrays.sort(validValuesAsStrings);
return String.join(", ", validValuesAsStrings);
}
public void testXContent_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() {
var thrownException = expectThrows(AssertionError.class, () -> new VoyageAIEmbeddingsTaskSettings(InputType.UNSPECIFIED, null));
MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]"));
}
public void testOf_KeepsOriginalValuesWhenRequestSettingsAreNull_AndRequestInputTypeIsInvalid() {
var taskSettings = new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, false);
var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettings.of(
taskSettings,
VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
InputType.UNSPECIFIED
);
MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings));
}
public void testOf_UsesRequestTaskSettings() {
var taskSettings = new VoyageAIEmbeddingsTaskSettings((InputType) null, null);
var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettings.of(
taskSettings,
new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true),
InputType.UNSPECIFIED
);
MatcherAssert.assertThat(overriddenTaskSettings, is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true)));
}
public void testOf_UsesRequestTaskSettings_AndRequestInputType() {
var taskSettings = new VoyageAIEmbeddingsTaskSettings(InputType.SEARCH, true);
var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettings.of(
taskSettings,
new VoyageAIEmbeddingsTaskSettings((InputType) null, null),
InputType.INGEST
);
MatcherAssert.assertThat(overriddenTaskSettings, is(new VoyageAIEmbeddingsTaskSettings(InputType.INGEST, true)));
}
@Override
protected Writeable.Reader<VoyageAIEmbeddingsTaskSettings> instanceReader() {
return VoyageAIEmbeddingsTaskSettings::new;
}
@Override
protected VoyageAIEmbeddingsTaskSettings createTestInstance() {
return createRandom();
}
@Override
protected VoyageAIEmbeddingsTaskSettings mutateInstance(VoyageAIEmbeddingsTaskSettings instance) throws IOException {
return randomValueOtherThan(instance, VoyageAIEmbeddingsTaskSettingsTests::createRandom);
}
public static Map<String, Object> getTaskSettingsMapEmpty() {
return new HashMap<>();
}
public static Map<String, Object> getTaskSettingsMap(@Nullable InputType inputType) {
var map = new HashMap<String, Object>();
if (inputType != null) {
map.put(VoyageAIEmbeddingsTaskSettings.INPUT_TYPE, inputType.toString());
}
return map;
}
}

View File

@ -0,0 +1,96 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai.rerank;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
public class VoyageAIRerankModelTests {
public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nullable Integer topK, @Nullable Boolean truncation) {
return new VoyageAIRerankModel(
"id",
"service",
ESTestCase.randomAlphaOfLength(10),
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)),
new VoyageAIRerankTaskSettings(topK, null, truncation),
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
}
public static VoyageAIRerankModel createModel(String apiKey, String modelId, @Nullable Integer topK) {
return new VoyageAIRerankModel(
"id",
"service",
ESTestCase.randomAlphaOfLength(10),
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)),
new VoyageAIRerankTaskSettings(topK, null, null),
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
}
public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topK) {
return new VoyageAIRerankModel(
"id",
"service",
ESTestCase.randomAlphaOfLength(10),
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)),
new VoyageAIRerankTaskSettings(topK, null, null),
new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8))
);
}
public static VoyageAIRerankModel createModel(String modelId, @Nullable Integer topK, Boolean returnDocuments, Boolean truncation) {
return new VoyageAIRerankModel(
"id",
"service",
ESTestCase.randomAlphaOfLength(10),
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)),
new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation),
new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8))
);
}
public static VoyageAIRerankModel createModel(
String url,
String modelId,
@Nullable Integer topK,
Boolean returnDocuments,
Boolean truncation
) {
return new VoyageAIRerankModel(
"id",
"service",
url,
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)),
new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation),
new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8))
);
}
public static VoyageAIRerankModel createModel(
String url,
String apiKey,
String modelId,
@Nullable Integer topK,
Boolean returnDocuments,
Boolean truncation
) {
return new VoyageAIRerankModel(
"id",
"service",
url,
new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(modelId, null)),
new VoyageAIRerankTaskSettings(topK, returnDocuments, truncation),
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
}
}

View File

@ -0,0 +1,78 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai.rerank;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettingsTests;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
public class VoyageAIRerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase<VoyageAIRerankServiceSettings> {
public static VoyageAIRerankServiceSettings createRandom() {
return new VoyageAIRerankServiceSettings(
new VoyageAIServiceSettings(randomAlphaOfLength(10), RateLimitSettingsTests.createRandom())
);
}
public void testToXContent_WritesAllValues() throws IOException {
var url = "http://www.abc.com";
var model = "model";
var serviceSettings = new VoyageAIRerankServiceSettings(new VoyageAIServiceSettings(model, null));
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
serviceSettings.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
{
"model_id":"model",
"rate_limit": {
"requests_per_minute": 2000
}
}
"""));
}
@Override
protected Writeable.Reader<VoyageAIRerankServiceSettings> instanceReader() {
return VoyageAIRerankServiceSettings::new;
}
@Override
protected VoyageAIRerankServiceSettings createTestInstance() {
return createRandom();
}
@Override
protected VoyageAIRerankServiceSettings mutateInstance(VoyageAIRerankServiceSettings instance) throws IOException {
return randomValueOtherThan(instance, VoyageAIRerankServiceSettingsTests::createRandom);
}
@Override
protected VoyageAIRerankServiceSettings mutateInstanceForVersion(VoyageAIRerankServiceSettings instance, TransportVersion version) {
return instance;
}
public static Map<String, Object> getServiceSettingsMap(@Nullable String model) {
return new HashMap<>(VoyageAIServiceSettingsTests.getServiceSettingsMap(model));
}
}

View File

@ -0,0 +1,162 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.voyageai.rerank;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceFields;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.containsString;
public class VoyageAIRerankTaskSettingsTests extends AbstractWireSerializingTestCase<VoyageAIRerankTaskSettings> {
public static VoyageAIRerankTaskSettings createRandom() {
var returnDocuments = randomBoolean() ? randomBoolean() : null;
var topNDocsOnly = randomBoolean() ? randomIntBetween(1, 10) : null;
var truncation = randomBoolean() ? randomBoolean() : null;
return new VoyageAIRerankTaskSettings(topNDocsOnly, returnDocuments, truncation);
}
public void testFromMap_WithInvalidTruncation_ThrowsValidationException() {
Map<String, Object> taskMap = Map.of(
VoyageAIRerankTaskSettings.RETURN_DOCUMENTS,
true,
VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY,
5,
VoyageAIServiceFields.TRUNCATION,
"invalid"
);
var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap)));
assertThat(thrownException.getMessage(), containsString("field [truncation] is not of the expected type"));
}
public void testFromMap_WithValidValues_ReturnsSettings() {
Map<String, Object> taskMap = Map.of(
VoyageAIRerankTaskSettings.RETURN_DOCUMENTS,
true,
VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY,
5,
VoyageAIServiceFields.TRUNCATION,
true
);
var settings = VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap));
assertTrue(settings.getReturnDocuments());
assertEquals(5, settings.getTopKDocumentsOnly().intValue());
assertTrue(settings.getTruncation());
}
public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() {
var settings = VoyageAIRerankTaskSettings.fromMap(Map.of());
assertNull(settings.getReturnDocuments());
assertNull(settings.getTopKDocumentsOnly());
assertNull(settings.getTruncation());
}
public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() {
Map<String, Object> taskMap = Map.of(
VoyageAIRerankTaskSettings.RETURN_DOCUMENTS,
"invalid",
VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY,
5
);
var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap)));
assertThat(thrownException.getMessage(), containsString("field [return_documents] is not of the expected type"));
}
public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() {
Map<String, Object> taskMap = Map.of(
VoyageAIRerankTaskSettings.RETURN_DOCUMENTS,
true,
VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY,
"invalid"
);
var thrownException = expectThrows(ValidationException.class, () -> VoyageAIRerankTaskSettings.fromMap(new HashMap<>(taskMap)));
assertThat(
thrownException.getMessage(),
containsString("field [top_k] is not of the expected type. The value [invalid] cannot be converted to a [Integer];")
);
}
public void testUpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() {
var initialSettings = new VoyageAIRerankTaskSettings(5, true, true);
VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of());
assertEquals(initialSettings, updatedSettings);
}
public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() {
var initialSettings = new VoyageAIRerankTaskSettings(5, true, true);
Map<String, Object> newSettings = Map.of(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, false);
VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
assertFalse(updatedSettings.getReturnDocuments());
assertTrue(updatedSettings.getTruncation());
assertEquals(initialSettings.getTopKDocumentsOnly(), updatedSettings.getTopKDocumentsOnly());
}
public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() {
var initialSettings = new VoyageAIRerankTaskSettings(5, true, true);
Map<String, Object> newSettings = Map.of(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, 7);
VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
assertTrue(updatedSettings.getTruncation());
assertEquals(7, updatedSettings.getTopKDocumentsOnly().intValue());
assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments());
}
public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() {
var initialSettings = new VoyageAIRerankTaskSettings(5, true, true);
Map<String, Object> newSettings = Map.of(
VoyageAIRerankTaskSettings.RETURN_DOCUMENTS,
false,
VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY,
7
);
VoyageAIRerankTaskSettings updatedSettings = (VoyageAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
assertTrue(updatedSettings.getTruncation());
assertFalse(updatedSettings.getReturnDocuments());
assertEquals(7, updatedSettings.getTopKDocumentsOnly().intValue());
}
@Override
protected Writeable.Reader<VoyageAIRerankTaskSettings> instanceReader() {
return VoyageAIRerankTaskSettings::new;
}
@Override
protected VoyageAIRerankTaskSettings createTestInstance() {
return createRandom();
}
@Override
protected VoyageAIRerankTaskSettings mutateInstance(VoyageAIRerankTaskSettings instance) throws IOException {
return randomValueOtherThan(instance, VoyageAIRerankTaskSettingsTests::createRandom);
}
public static Map<String, Object> getTaskSettingsMapEmpty() {
return new HashMap<>();
}
public static Map<String, Object> getTaskSettingsMap(@Nullable Integer topNDocumentsOnly, Boolean returnDocuments) {
var map = new HashMap<String, Object>();
if (topNDocumentsOnly != null) {
map.put(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, topNDocumentsOnly.toString());
}
if (returnDocuments != null) {
map.put(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments.toString());
}
return map;
}
}