Add Azure AI Rerank support (#129848)

* Add Azure AI Rerank support

* address comments

* address comments

* refactor azure ai studio service

* update rerank task settings test

* add provider for rerank
This commit is contained in:
Evgenii-Kazannik 2025-07-17 15:24:02 +02:00 committed by GitHub
parent f9eee6c216
commit d06b0c8c17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 2147 additions and 81 deletions

View File

@ -0,0 +1,5 @@
pr: 129848
summary: "[ML] Add Azure AI Rerank support to the Inference Plugin"
area: Machine Learning
type: enhancement
issues: []

View File

@ -341,6 +341,7 @@ public class TransportVersions {
public static final TransportVersion LOOKUP_JOIN_CCS = def(9_120_0_00);
public static final TransportVersion NODE_USAGE_STATS_FOR_THREAD_POOLS_IN_CLUSTER_INFO = def(9_121_0_00);
public static final TransportVersion ESQL_CATEGORIZE_OPTIONS = def(9_122_0_00);
public static final TransportVersion ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED = def(9_123_0_00);
/*
* STOP! READ THIS FIRST! No, really,

View File

@ -111,6 +111,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"azureaistudio",
"cohere",
"elasticsearch",
"googlevertexai",

View File

@ -50,6 +50,8 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.completion.Azure
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionTaskSettings;
@ -306,6 +308,17 @@ public class InferenceNamedWriteablesProvider {
AzureAiStudioChatCompletionTaskSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
AzureAiStudioRerankServiceSettings.NAME,
AzureAiStudioRerankServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, AzureAiStudioRerankTaskSettings.NAME, AzureAiStudioRerankTaskSettings::new)
);
}
private static void addAzureOpenAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

View File

@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.services.azureaistudio;
public class AzureAiStudioConstants {
public static final String EMBEDDINGS_URI_PATH = "/v1/embeddings";
public static final String COMPLETIONS_URI_PATH = "/v1/chat/completions";
public static final String RERANK_URI_PATH = "/v1/rerank";
// common service settings fields
public static final String TARGET_FIELD = "target";
@ -22,6 +23,10 @@ public class AzureAiStudioConstants {
public static final String DIMENSIONS_FIELD = "dimensions";
public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
// rerank task settings fields
public static final String DOCUMENTS_FIELD = "documents";
public static final String QUERY_FIELD = "query";
// embeddings task settings fields
public static final String USER_FIELD = "user";
@ -35,5 +40,9 @@ public class AzureAiStudioConstants {
public static final Double MIN_TEMPERATURE_TOP_P = 0.0;
public static final Double MAX_TEMPERATURE_TOP_P = 2.0;
// rerank task settings fields
public static final String RETURN_DOCUMENTS_FIELD = "return_documents";
public static final String TOP_N_FIELD = "top_n";
private AzureAiStudioConstants() {}
}

View File

@ -22,6 +22,9 @@ public final class AzureAiStudioProviderCapabilities {
// these providers have chat completion inference (all providers at the moment)
public static final List<AzureAiStudioProvider> chatCompletionProviders = List.of(AzureAiStudioProvider.values());
// these providers have rerank inference
public static final List<AzureAiStudioProvider> rerankProviders = List.of(AzureAiStudioProvider.COHERE);
// these providers allow token ("pay as you go") embeddings endpoints
public static final List<AzureAiStudioProvider> tokenEmbeddingsProviders = List.of(
AzureAiStudioProvider.OPENAI,
@ -31,6 +34,9 @@ public final class AzureAiStudioProviderCapabilities {
// these providers allow realtime embeddings endpoints (none at the moment)
public static final List<AzureAiStudioProvider> realtimeEmbeddingsProviders = List.of();
// these providers allow realtime rerank endpoints (none at the moment)
public static final List<AzureAiStudioProvider> realtimeRerankProviders = List.of();
// these providers allow token ("pay as you go") chat completion endpoints
public static final List<AzureAiStudioProvider> tokenChatCompletionProviders = List.of(
AzureAiStudioProvider.OPENAI,
@ -54,6 +60,9 @@ public final class AzureAiStudioProviderCapabilities {
case TEXT_EMBEDDING -> {
return embeddingProviders.contains(provider);
}
case RERANK -> {
return rerankProviders.contains(provider);
}
default -> {
return false;
}
@ -76,6 +85,11 @@ public final class AzureAiStudioProviderCapabilities {
? tokenEmbeddingsProviders.contains(provider)
: realtimeEmbeddingsProviders.contains(provider);
}
case RERANK -> {
return (endpointType == AzureAiStudioEndpointType.TOKEN)
? rerankProviders.contains(provider)
: realtimeRerankProviders.contains(provider);
}
default -> {
return false;
}

View File

@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.azureaistudio;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
import org.elasticsearch.xpack.inference.services.azureaistudio.request.AzureAiStudioRerankRequest;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.response.AzureAiStudioRerankResponseEntity;
import org.elasticsearch.xpack.inference.services.azureopenai.response.AzureMistralOpenAiExternalResponseHandler;
import java.util.function.Supplier;
public class AzureAiStudioRerankRequestManager extends AzureAiStudioRequestManager {
private static final Logger logger = LogManager.getLogger(AzureAiStudioRerankRequestManager.class);
private static final ResponseHandler HANDLER = createRerankHandler();
private final AzureAiStudioRerankModel model;
public AzureAiStudioRerankRequestManager(AzureAiStudioRerankModel model, ThreadPool threadPool) {
super(threadPool, model);
this.model = model;
}
@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestRerankFunction,
ActionListener<InferenceServiceResults> listener
) {
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
AzureAiStudioRerankRequest request = new AzureAiStudioRerankRequest(
model,
rerankInput.getQuery(),
rerankInput.getChunks(),
rerankInput.getReturnDocuments(),
rerankInput.getTopN()
);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestRerankFunction, listener));
}
private static ResponseHandler createRerankHandler() {
// This currently covers response handling for Azure AI Studio
return new AzureMistralOpenAiExternalResponseHandler(
"azure ai studio rerank",
new AzureAiStudioRerankResponseEntity(),
ErrorMessageResponseEntity::fromResponse,
true
);
}
}

View File

@ -44,6 +44,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.completion.Azure
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -71,10 +72,10 @@ import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFie
public class AzureAiStudioService extends SenderService {
static final String NAME = "azureaistudio";
public static final String NAME = "azureaistudio";
private static final String SERVICE_NAME = "Azure AI Studio";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.RERANK);
private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
InputType.INGEST,
@ -270,8 +271,9 @@ public class AzureAiStudioService extends SenderService {
ConfigurationParseContext context
) {
if (taskType == TaskType.TEXT_EMBEDDING) {
var embeddingsModel = new AzureAiStudioEmbeddingsModel(
AzureAiStudioModel model;
switch (taskType) {
case TEXT_EMBEDDING -> model = new AzureAiStudioEmbeddingsModel(
inferenceEntityId,
taskType,
NAME,
@ -281,16 +283,7 @@ public class AzureAiStudioService extends SenderService {
secretSettings,
context
);
checkProviderAndEndpointTypeForTask(
TaskType.TEXT_EMBEDDING,
embeddingsModel.getServiceSettings().provider(),
embeddingsModel.getServiceSettings().endpointType()
);
return embeddingsModel;
}
if (taskType == TaskType.COMPLETION) {
var completionModel = new AzureAiStudioChatCompletionModel(
case COMPLETION -> model = new AzureAiStudioChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
@ -299,15 +292,12 @@ public class AzureAiStudioService extends SenderService {
secretSettings,
context
);
checkProviderAndEndpointTypeForTask(
TaskType.COMPLETION,
completionModel.getServiceSettings().provider(),
completionModel.getServiceSettings().endpointType()
);
return completionModel;
case RERANK -> model = new AzureAiStudioRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
}
throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
final var azureAiStudioServiceSettings = (AzureAiStudioServiceSettings) model.getServiceSettings();
checkProviderAndEndpointTypeForTask(taskType, azureAiStudioServiceSettings.provider(), azureAiStudioServiceSettings.endpointType());
return model;
}
private AzureAiStudioModel createModelFromPersistent(

View File

@ -13,8 +13,10 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioChatCompletionRequestManager;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioRerankRequestManager;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
import java.util.Map;
import java.util.Objects;
@ -49,4 +51,12 @@ public class AzureAiStudioActionCreator implements AzureAiStudioActionVisitor {
var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio embeddings");
return new SenderExecutableAction(sender, requestManager, errorMessage);
}
@Override
public ExecutableAction create(AzureAiStudioRerankModel rerankModel, Map<String, Object> taskSettings) {
var overriddenModel = AzureAiStudioRerankModel.of(rerankModel, taskSettings);
var requestManager = new AzureAiStudioRerankRequestManager(overriddenModel, serviceComponents.threadPool());
var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio rerank");
return new SenderExecutableAction(sender, requestManager, errorMessage);
}
}

View File

@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.services.azureaistudio.action;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
import java.util.Map;
@ -17,4 +18,6 @@ public interface AzureAiStudioActionVisitor {
ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings);
ExecutableAction create(AzureAiStudioChatCompletionModel completionModel, Map<String, Object> taskSettings);
ExecutableAction create(AzureAiStudioRerankModel rerankModel, Map<String, Object> taskSettings);
}

View File

@ -0,0 +1,74 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.azureaistudio.request;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;
public class AzureAiStudioRerankRequest extends AzureAiStudioRequest {
private final String query;
private final List<String> input;
private final Boolean returnDocuments;
private final Integer topN;
private final AzureAiStudioRerankModel rerankModel;
public AzureAiStudioRerankRequest(
AzureAiStudioRerankModel model,
String query,
List<String> input,
@Nullable Boolean returnDocuments,
@Nullable Integer topN
) {
super(model);
this.rerankModel = Objects.requireNonNull(model);
this.query = query;
this.input = Objects.requireNonNull(input);
this.returnDocuments = returnDocuments;
this.topN = topN;
}
@Override
public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(this.uri);
ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(createRequestEntity()).getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
setAuthHeader(httpPost, rerankModel);
return new HttpRequest(httpPost, getInferenceEntityId());
}
@Override
public Request truncate() {
// Not applicable for rerank, only used in text embedding requests
return this;
}
@Override
public boolean[] getTruncationInfo() {
// Not applicable for rerank, only used in text embedding requests
return null;
}
private AzureAiStudioRerankRequestEntity createRequestEntity() {
return new AzureAiStudioRerankRequestEntity(query, input, returnDocuments, topN, rerankModel.getTaskSettings());
}
}

View File

@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.azureaistudio.request;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.DOCUMENTS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.QUERY_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
public record AzureAiStudioRerankRequestEntity(
String query,
List<String> input,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
AzureAiStudioRerankTaskSettings taskSettings
) implements ToXContentObject {
public AzureAiStudioRerankRequestEntity {
Objects.requireNonNull(query);
Objects.requireNonNull(input);
Objects.requireNonNull(taskSettings);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(DOCUMENTS_FIELD, input);
builder.field(QUERY_FIELD, query);
if (returnDocuments != null) {
builder.field(RETURN_DOCUMENTS_FIELD, returnDocuments);
} else if (taskSettings.returnDocuments() != null) {
builder.field(RETURN_DOCUMENTS_FIELD, taskSettings.returnDocuments());
}
if (topN != null) {
builder.field(TOP_N_FIELD, topN);
} else if (taskSettings.topN() != null) {
builder.field(TOP_N_FIELD, taskSettings.topN());
}
builder.endObject();
return builder;
}
}

View File

@ -0,0 +1,95 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureaistudio.action.AzureAiStudioActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RERANK_URI_PATH;
public class AzureAiStudioRerankModel extends AzureAiStudioModel {
public static AzureAiStudioRerankModel of(AzureAiStudioRerankModel model, Map<String, Object> taskSettings) {
if (taskSettings == null || taskSettings.isEmpty()) {
return model;
}
final var requestTaskSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(taskSettings);
final var taskSettingToUse = AzureAiStudioRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings);
return new AzureAiStudioRerankModel(model, taskSettingToUse);
}
public AzureAiStudioRerankModel(
String inferenceEntityId,
AzureAiStudioRerankServiceSettings serviceSettings,
AzureAiStudioRerankTaskSettings taskSettings,
DefaultSecretSettings secrets
) {
super(
new ModelConfigurations(inferenceEntityId, TaskType.RERANK, AzureAiStudioService.NAME, serviceSettings, taskSettings),
new ModelSecrets(secrets)
);
}
public AzureAiStudioRerankModel(
String inferenceEntityId,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
this(
inferenceEntityId,
AzureAiStudioRerankServiceSettings.fromMap(serviceSettings, context),
AzureAiStudioRerankTaskSettings.fromMap(taskSettings),
DefaultSecretSettings.fromMap(secrets)
);
}
public AzureAiStudioRerankModel(AzureAiStudioRerankModel model, AzureAiStudioRerankTaskSettings taskSettings) {
super(model, taskSettings, model.getServiceSettings().rateLimitSettings());
}
@Override
public AzureAiStudioRerankServiceSettings getServiceSettings() {
return (AzureAiStudioRerankServiceSettings) super.getServiceSettings();
}
@Override
public AzureAiStudioRerankTaskSettings getTaskSettings() {
return (AzureAiStudioRerankTaskSettings) super.getTaskSettings();
}
@Override
public DefaultSecretSettings getSecretSettings() {
return super.getSecretSettings();
}
@Override
protected URI getEndpointUri() throws URISyntaxException {
return new URI(this.target + RERANK_URI_PATH);
}
@Override
public ExecutableAction accept(AzureAiStudioActionVisitor creator, Map<String, Object> taskSettings) {
return creator.create(this, taskSettings);
}
}

View File

@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
public record AzureAiStudioRerankRequestTaskSettings(@Nullable Boolean returnDocuments, @Nullable Integer topN) {
public static final AzureAiStudioRerankRequestTaskSettings EMPTY_SETTINGS = new AzureAiStudioRerankRequestTaskSettings(null, null);
/**
* Extracts the task settings from a map. All settings are considered optional and the absence of a setting
* does not throw an error.
*
* @param map the settings received from a request
* @return a {@link AzureAiStudioRerankRequestTaskSettings}
*/
public static AzureAiStudioRerankRequestTaskSettings fromMap(Map<String, Object> map) {
if (map.isEmpty()) {
return AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS;
}
final var validationException = new ValidationException();
final var returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS_FIELD, validationException);
final var topN = extractOptionalPositiveInteger(map, TOP_N_FIELD, ModelConfigurations.TASK_SETTINGS, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
return new AzureAiStudioRerankRequestTaskSettings(returnDocuments, topN);
}
}

View File

@ -0,0 +1,123 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioServiceSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
public class AzureAiStudioRerankServiceSettings extends AzureAiStudioServiceSettings {
public static final String NAME = "azure_ai_studio_rerank_service_settings";
public static AzureAiStudioRerankServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
final var validationException = new ValidationException();
final var settings = rerankSettingsFromMap(map, validationException, context);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
return new AzureAiStudioRerankServiceSettings(settings);
}
private static AzureAiStudioRerankServiceSettings.AzureAiStudioRerankCommonFields rerankSettingsFromMap(
Map<String, Object> map,
ValidationException validationException,
ConfigurationParseContext context
) {
final var baseSettings = AzureAiStudioServiceSettings.fromMap(map, validationException, context);
return new AzureAiStudioRerankServiceSettings.AzureAiStudioRerankCommonFields(baseSettings);
}
private record AzureAiStudioRerankCommonFields(BaseAzureAiStudioCommonFields baseCommonFields) {}
public AzureAiStudioRerankServiceSettings(
String target,
AzureAiStudioProvider provider,
AzureAiStudioEndpointType endpointType,
@Nullable RateLimitSettings rateLimitSettings
) {
super(target, provider, endpointType, rateLimitSettings);
}
public AzureAiStudioRerankServiceSettings(StreamInput in) throws IOException {
super(in);
}
private AzureAiStudioRerankServiceSettings(AzureAiStudioRerankServiceSettings.AzureAiStudioRerankCommonFields fields) {
this(
fields.baseCommonFields.target(),
fields.baseCommonFields.provider(),
fields.baseCommonFields.endpointType(),
fields.baseCommonFields.rateLimitSettings()
);
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
super.addXContentFields(builder, params);
builder.endObject();
return builder;
}
@Override
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, ToXContent.Params params) throws IOException {
super.addExposedXContentFields(builder, params);
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AzureAiStudioRerankServiceSettings that = (AzureAiStudioRerankServiceSettings) o;
return Objects.equals(target, that.target)
&& Objects.equals(provider, that.provider)
&& Objects.equals(endpointType, that.endpointType)
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
}
@Override
public int hashCode() {
return Objects.hash(target, provider, endpointType, rateLimitSettings);
}
}

View File

@ -0,0 +1,149 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
/**
* Defines the rerank task settings for the AzureAiStudio service.
*/
public class AzureAiStudioRerankTaskSettings implements TaskSettings {
public static final String NAME = "azure_ai_studio_rerank_task_settings";
public static AzureAiStudioRerankTaskSettings fromMap(Map<String, Object> map) {
final var validationException = new ValidationException();
final var returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS_FIELD, validationException);
final var topN = extractOptionalPositiveInteger(map, TOP_N_FIELD, ModelConfigurations.TASK_SETTINGS, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
return new AzureAiStudioRerankTaskSettings(returnDocuments, topN);
}
/**
* Creates a new {@link AzureAiStudioRerankTaskSettings} object by overriding the values in originalSettings with the ones
* passed in via requestSettings if the fields are not null.
* @param originalSettings the original {@link AzureAiStudioRerankTaskSettings} from the inference entity configuration from storage
* @param requestSettings the {@link AzureAiStudioRerankTaskSettings} from the request
* @return a new {@link AzureAiStudioRerankTaskSettings}
*/
public static AzureAiStudioRerankTaskSettings of(
AzureAiStudioRerankTaskSettings originalSettings,
AzureAiStudioRerankRequestTaskSettings requestSettings
) {
final var returnDocuments = requestSettings.returnDocuments() == null
? originalSettings.returnDocuments()
: requestSettings.returnDocuments();
final var topN = requestSettings.topN() == null ? originalSettings.topN() : requestSettings.topN();
return new AzureAiStudioRerankTaskSettings(returnDocuments, topN);
}
public AzureAiStudioRerankTaskSettings(@Nullable Boolean returnDocuments, @Nullable Integer topN) {
this.returnDocuments = returnDocuments;
this.topN = topN;
}
public AzureAiStudioRerankTaskSettings(StreamInput in) throws IOException {
this.returnDocuments = in.readOptionalBoolean();
this.topN = in.readOptionalVInt();
}
private final Boolean returnDocuments;
private final Integer topN;
public Boolean returnDocuments() {
return returnDocuments;
}
public Integer topN() {
return topN;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED;
}
@Override
public boolean isEmpty() {
return returnDocuments == null && topN == null;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalBoolean(returnDocuments);
out.writeOptionalVInt(topN);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (returnDocuments != null) {
builder.field(RETURN_DOCUMENTS_FIELD, returnDocuments);
}
if (topN != null) {
builder.field(TOP_N_FIELD, topN);
}
builder.endObject();
return builder;
}
@Override
public String toString() {
return "AzureAiStudioRerankTaskSettings{" + ", returnDocuments=" + returnDocuments + ", topN=" + topN + '}';
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AzureAiStudioRerankTaskSettings that = (AzureAiStudioRerankTaskSettings) o;
return Objects.equals(returnDocuments, that.returnDocuments) && Objects.equals(topN, that.topN);
}
@Override
public int hashCode() {
return Objects.hash(returnDocuments, topN);
}
@Override
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
AzureAiStudioRerankRequestTaskSettings requestSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(newSettings));
return of(this, requestSettings);
}
}

View File

@ -0,0 +1,128 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.azureaistudio.response;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.BaseResponseEntity;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class AzureAiStudioRerankResponseEntity extends BaseResponseEntity {
/**
* Parses the AzureAiStudio Search rerank json response.
* For a request like:
*
* <pre>
* <code>
* {
* "model": "rerank-v3.5",
* "query": "What is the capital of the United States?",
* "top_n": 2,
* "documents": ["Carson City is the capital city of the American state of Nevada.",
* "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean."]
* }
* </code>
* </pre>
*
* The response would look like:
*
* <pre>
* <code>
* {
* "id": "ff2feb42-5d3a-45d7-ba29-c3dabf59988b",
* "results": [
* {
* "document": {
* "text": "Carson City is the capital city of the American state of Nevada."
* },
* "index": 0,
* "relevance_score": 0.1728413
* },
* {
* "document": {
* "text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean."
* },
* "index": 1,
* "relevance_score": 0.031005697
* }
* ],
* "meta": {
* "api_version": {
* "version": "1"
* },
* "billed_units": {
* "search_units": 1
* }
* }
* }
* </code>
* </pre>
*/
@Override
protected InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
final var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
var rerankResult = RerankResult.PARSER.apply(jsonParser, null);
return new RankedDocsResults(rerankResult.entries.stream().map(RerankResultEntry::toRankedDoc).toList());
}
}
record RerankResult(List<RerankResultEntry> entries) {
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<RerankResult, Void> PARSER = new ConstructingObjectParser<>(
RerankResult.class.getSimpleName(),
true,
args -> new RerankResult((List<RerankResultEntry>) args[0])
);
static {
PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("results"));
}
}
record RerankResultEntry(Float relevanceScore, Integer index, @Nullable ObjectParser document) {
public static final ConstructingObjectParser<RerankResultEntry, Void> PARSER = new ConstructingObjectParser<>(
RerankResultEntry.class.getSimpleName(),
args -> new RerankResultEntry((Float) args[0], (Integer) args[1], (ObjectParser) args[2])
);
static {
PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
PARSER.declareInt(constructorArg(), new ParseField("index"));
PARSER.declareObject(optionalConstructorArg(), ObjectParser.PARSER::apply, new ParseField("document"));
}
public RankedDocsResults.RankedDoc toRankedDoc() {
return new RankedDocsResults.RankedDoc(index, relevanceScore, document == null ? null : document.text);
}
}
record ObjectParser(String text) {
public static final ConstructingObjectParser<ObjectParser, Void> PARSER = new ConstructingObjectParser<>(
ObjectParser.class.getSimpleName(),
args -> new AzureAiStudioRerankResponseEntity.ObjectParser((String) args[0])
);
static {
PARSER.declareString(optionalConstructorArg(), new ParseField("text"));
}
}
}

View File

@ -38,6 +38,7 @@ import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@ -54,6 +55,10 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.Azure
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettingsTests;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettingsTests;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModelTests;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankServiceSettingsTests;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettingsTests;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
import org.hamcrest.CoreMatchers;
import org.hamcrest.Matchers;
@ -219,6 +224,33 @@ public class AzureAiStudioServiceTests extends ESTestCase {
}
}
public void testParseRequestConfig_CreatesAnAzureAiStudioRerankModel() throws IOException {
try (var service = createService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
var rerankModel = (AzureAiStudioRerankModel) model;
assertThat(rerankModel.getServiceSettings().target(), is("http://target.local"));
assertThat(rerankModel.getServiceSettings().provider(), is(AzureAiStudioProvider.COHERE));
assertThat(rerankModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
assertThat(rerankModel.getSecretSettings().apiKey().toString(), is("secret"));
assertNull(rerankModel.getTaskSettings().returnDocuments());
assertNull(rerankModel.getTaskSettings().topN());
}, exception -> fail("Unexpected exception: " + exception));
service.parseRequestConfig(
"id",
TaskType.RERANK,
getRequestConfigMap(
getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
getRerankTaskSettingsMap(null, null),
getSecretSettingsMap("secret")
),
modelVerificationListener
);
}
}
public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
try (var service = createService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
@ -441,6 +473,80 @@ public class AzureAiStudioServiceTests extends ESTestCase {
}
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankServiceSettingsMap() throws IOException {
try (var service = createService()) {
var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
serviceSettings.put("extra_key", "value");
var config = getRequestConfigMap(serviceSettings, getRerankTaskSettingsMap(null, null), getSecretSettingsMap("secret"));
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
model -> fail("Expected exception, but got model: " + model),
exception -> {
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
assertThat(
exception.getMessage(),
is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
);
}
);
service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
}
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankTaskSettingsMap() throws IOException {
try (var service = createService()) {
var taskSettings = getRerankTaskSettingsMap(null, null);
taskSettings.put("extra_key", "value");
var config = getRequestConfigMap(
getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
taskSettings,
getSecretSettingsMap("secret")
);
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
model -> fail("Expected exception, but got model: " + model),
exception -> {
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
assertThat(
exception.getMessage(),
is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
);
}
);
service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
}
}
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankSecretSettingsMap() throws IOException {
try (var service = createService()) {
var secretSettings = getSecretSettingsMap("secret");
secretSettings.put("extra_key", "value");
var config = getRequestConfigMap(
getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
getRerankTaskSettingsMap(null, null),
secretSettings
);
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
model -> fail("Expected exception, but got model: " + model),
exception -> {
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
assertThat(
exception.getMessage(),
is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
);
}
);
service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
}
}
public void testParseRequestConfig_ThrowsWhenProviderIsNotValidForEmbeddings() throws IOException {
try (var service = createService()) {
var serviceSettings = getEmbeddingsServiceSettingsMap("http://target.local", "databricks", "token", null, null, null, null);
@ -505,6 +611,45 @@ public class AzureAiStudioServiceTests extends ESTestCase {
}
}
public void testParseRequestConfig_ThrowsWhenProviderIsNotValidForRerank() throws IOException {
try (var service = createService()) {
var serviceSettings = getRerankServiceSettingsMap("http://target.local", "databricks", "token");
var config = getRequestConfigMap(serviceSettings, getRerankTaskSettingsMap(null, null), getSecretSettingsMap("secret"));
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
model -> fail("Expected exception, but got model: " + model),
exception -> {
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
assertThat(exception.getMessage(), is("The [rerank] task type for provider [databricks] is not available"));
}
);
service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
}
}
public void testParseRequestConfig_ThrowsWhenEndpointTypeIsNotValidForRerankProvider() throws IOException {
try (var service = createService()) {
var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "realtime");
var config = getRequestConfigMap(serviceSettings, getRerankTaskSettingsMap(null, null), getSecretSettingsMap("secret"));
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
model -> fail("Expected exception, but got model: " + model),
exception -> {
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
assertThat(
exception.getMessage(),
is("The [realtime] endpoint type with [rerank] task type for provider [cohere] is not available")
);
}
);
service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
}
}
public void testParsePersistedConfig_CreatesAnAzureAiStudioEmbeddingsModel() throws IOException {
try (var service = createService()) {
var config = getPersistedConfigMap(
@ -603,6 +748,27 @@ public class AzureAiStudioServiceTests extends ESTestCase {
}
}
public void testParsePersistedConfig_CreatesAnAzureAiStudioRerankModel() throws IOException {
try (var service = createService()) {
var config = getPersistedConfigMap(
getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
getRerankTaskSettingsMap(true, 2),
getSecretSettingsMap("secret")
);
var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
var chatCompletionModel = (AzureAiStudioRerankModel) model;
assertThat(chatCompletionModel.getServiceSettings().target(), is("http://target.local"));
assertThat(chatCompletionModel.getServiceSettings().provider(), is(AzureAiStudioProvider.COHERE));
assertThat(chatCompletionModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
assertThat(chatCompletionModel.getTaskSettings().returnDocuments(), is(true));
assertThat(chatCompletionModel.getTaskSettings().topN(), is(2));
}
}
public void testParsePersistedConfig_ThrowsUnsupportedModelType() throws IOException {
try (var service = createService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
@ -747,6 +913,48 @@ public class AzureAiStudioServiceTests extends ESTestCase {
}
}
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankServiceSettingsMap() throws IOException {
try (var service = createService()) {
var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
serviceSettings.put("extra_key", "value");
var taskSettings = getRerankTaskSettingsMap(true, 2);
var secretSettings = getSecretSettingsMap("secret");
var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
}
}
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankTaskSettingsMap() throws IOException {
try (var service = createService()) {
var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
var taskSettings = getRerankTaskSettingsMap(true, 2);
taskSettings.put("extra_key", "value");
var secretSettings = getSecretSettingsMap("secret");
var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
}
}
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankSecretSettingsMap() throws IOException {
try (var service = createService()) {
var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
var taskSettings = getRerankTaskSettingsMap(true, 2);
var secretSettings = getSecretSettingsMap("secret");
secretSettings.put("extra_key", "value");
var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
}
}
public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() throws IOException {
try (var service = createService()) {
var config = getPersistedConfigMap(
@ -842,6 +1050,27 @@ public class AzureAiStudioServiceTests extends ESTestCase {
}
}
public void testParsePersistedConfig_WithoutSecretsCreatesRerankModel() throws IOException {
try (var service = createService()) {
var config = getPersistedConfigMap(
getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
getRerankTaskSettingsMap(true, 2),
Map.of()
);
var model = service.parsePersistedConfig("id", TaskType.RERANK, config.config());
assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
var rerankModel = (AzureAiStudioRerankModel) model;
assertThat(rerankModel.getServiceSettings().target(), is("http://target.local"));
assertThat(rerankModel.getServiceSettings().provider(), is(AzureAiStudioProvider.COHERE));
assertThat(rerankModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
assertThat(rerankModel.getTaskSettings().returnDocuments(), is(true));
assertThat(rerankModel.getTaskSettings().topN(), is(2));
}
}
public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
@ -1184,6 +1413,47 @@ public class AzureAiStudioServiceTests extends ESTestCase {
}
}
public void testInfer_WithRerankModel() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testRerankTokenResponseJson));
var model = AzureAiStudioRerankModelTests.createModel(
"id",
getUrl(webServer),
AzureAiStudioProvider.COHERE,
AzureAiStudioEndpointType.TOKEN,
"apikey"
);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.infer(
model,
"query",
false,
2,
List.of("abc"),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);
var result = listener.actionGet(TIMEOUT);
assertThat(result, CoreMatchers.instanceOf(RankedDocsResults.class));
var rankedDocsResults = (RankedDocsResults) result;
var rankedDocs = rankedDocsResults.getRankedDocs();
assertThat(rankedDocs.size(), is(2));
assertThat(rankedDocs.get(0).relevanceScore(), is(0.1111111F));
assertThat(rankedDocs.get(0).index(), is(0));
assertThat(rankedDocs.get(1).relevanceScore(), is(0.2222222F));
assertThat(rankedDocs.get(1).index(), is(1));
}
}
public void testInfer_UnauthorisedResponse() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
@ -1320,7 +1590,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
{
"service": "azureaistudio",
"name": "Azure AI Studio",
"task_types": ["text_embedding", "completion"],
"task_types": ["text_embedding", "rerank", "completion"],
"configurations": {
"dimensions": {
"description": "The number of dimensions the resulting embeddings should have. For more information refer to https://learn.microsoft.com/en-us/azure/ai-studio/reference/reference-model-inference-embeddings.",
@ -1338,7 +1608,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding", "completion"]
"supported_task_types": ["text_embedding", "rerank", "completion"]
},
"provider": {
"description": "The model provider for your deployment.",
@ -1347,7 +1617,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding", "completion"]
"supported_task_types": ["text_embedding", "rerank", "completion"]
},
"api_key": {
"description": "API Key for the provider you're connecting to.",
@ -1356,7 +1626,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
"sensitive": true,
"updatable": true,
"type": "str",
"supported_task_types": ["text_embedding", "completion"]
"supported_task_types": ["text_embedding", "rerank", "completion"]
},
"rate_limit.requests_per_minute": {
"description": "Minimize the number of rate limit errors.",
@ -1365,7 +1635,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
"sensitive": false,
"updatable": false,
"type": "int",
"supported_task_types": ["text_embedding", "completion"]
"supported_task_types": ["text_embedding", "rerank", "completion"]
},
"target": {
"description": "The target URL of your Azure AI Studio model deployment.",
@ -1374,7 +1644,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding", "completion"]
"supported_task_types": ["text_embedding", "rerank", "completion"]
}
}
}
@ -1462,6 +1732,10 @@ public class AzureAiStudioServiceTests extends ESTestCase {
return AzureAiStudioChatCompletionServiceSettingsTests.createRequestSettingsMap(target, provider, endpointType);
}
private static HashMap<String, Object> getRerankServiceSettingsMap(String target, String provider, String endpointType) {
return AzureAiStudioRerankServiceSettingsTests.createRequestSettingsMap(target, provider, endpointType);
}
public static Map<String, Object> getChatCompletionTaskSettingsMap(
@Nullable Double temperature,
@Nullable Double topP,
@ -1471,6 +1745,10 @@ public class AzureAiStudioServiceTests extends ESTestCase {
return AzureAiStudioChatCompletionTaskSettingsTests.getTaskSettingsMap(temperature, topP, doSample, maxNewTokens);
}
public static Map<String, Object> getRerankTaskSettingsMap(@Nullable Boolean returnDocuments, @Nullable Integer topN) {
return AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap(returnDocuments, topN);
}
private static Map<String, Object> getSecretSettingsMap(String apiKey) {
return new HashMap<>(Map.of(API_KEY_FIELD, apiKey));
}
@ -1520,4 +1798,28 @@ public class AzureAiStudioServiceTests extends ESTestCase {
}
}
""";
private static final String testRerankTokenResponseJson = """
{
"id": "ff2feb42-5d3a-45d7-ba29-c3dabf59988b",
"results": [
{
"index": 0,
"relevance_score": 0.1111111
},
{
"index": 1,
"relevance_score": 0.2222222
}
],
"meta": {
"api_version": {
"version": "1"
},
"billed_units": {
"search_units": 1
}
}
}
""";
}

View File

@ -20,6 +20,7 @@ import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
import org.elasticsearch.xpack.inference.InputTypeTests;
import org.elasticsearch.xpack.inference.common.TruncatorTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@ -27,6 +28,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInpu
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
@ -34,6 +36,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEnd
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModelTests;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModelTests;
import org.junit.After;
import org.junit.Before;
@ -78,31 +81,20 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
}
public void testEmbeddingsRequestAction() throws IOException {
var senderFactory = new HttpRequestSender.Factory(
final var senderFactory = new HttpRequestSender.Factory(
ServiceComponentsTests.createWithEmptySettings(threadPool),
clientManager,
mockClusterServiceEmpty()
);
var timeoutSettings = buildSettingsWithRetryFields(
TimeValue.timeValueMillis(1),
TimeValue.timeValueMinutes(1),
TimeValue.timeValueSeconds(0)
);
var serviceComponents = new ServiceComponents(
threadPool,
mock(ThrottlerManager.class),
timeoutSettings,
TruncatorTests.createTruncator()
);
final var serviceComponents = getServiceComponents();
try (var sender = createSender(senderFactory)) {
sender.start();
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingsTokenResponseJson));
var model = AzureAiStudioEmbeddingsModelTests.createModel(
final var model = AzureAiStudioEmbeddingsModelTests.createModel(
"id",
"http://will-be-replaced.local",
AzureAiStudioProvider.OPENAI,
@ -111,21 +103,18 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
);
model.setURI(getUrl(webServer));
var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
var action = creator.create(model, Map.of());
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
var inputType = InputTypeTests.randomSearchAndIngestWithNull();
final var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
final var action = creator.create(model, Map.of());
final PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
final var inputType = InputTypeTests.randomSearchAndIngestWithNull();
action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var result = listener.actionGet(TIMEOUT);
final var result = listener.actionGet(TIMEOUT);
assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F }))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
assertThat(webServer.requests().get(0).getHeader(API_KEY_HEADER), equalTo("apikey"));
assertWebServerRequest(API_KEY_HEADER, "apikey");
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
final var requestMap = entityAsMap(webServer.requests().get(0).getBody());
assertThat(requestMap.size(), is(InputType.isSpecified(inputType) ? 2 : 1));
assertThat(requestMap.get("input"), is(List.of("abc")));
if (InputType.isSpecified(inputType)) {
@ -136,27 +125,15 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
}
public void testChatCompletionRequestAction() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
var timeoutSettings = buildSettingsWithRetryFields(
TimeValue.timeValueMillis(1),
TimeValue.timeValueMinutes(1),
TimeValue.timeValueSeconds(0)
);
var serviceComponents = new ServiceComponents(
threadPool,
mock(ThrottlerManager.class),
timeoutSettings,
TruncatorTests.createTruncator()
);
final var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
final var serviceComponents = getServiceComponents();
try (var sender = createSender(senderFactory)) {
sender.start();
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testCompletionTokenResponseJson));
var webserverUrl = getUrl(webServer);
var model = AzureAiStudioChatCompletionModelTests.createModel(
final var webserverUrl = getUrl(webServer);
final var model = AzureAiStudioChatCompletionModelTests.createModel(
"id",
"http://will-be-replaced.local",
AzureAiStudioProvider.COHERE,
@ -165,30 +142,101 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
);
model.setURI(webserverUrl);
var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
var action = creator.create(model, Map.of());
final var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
final var action = creator.create(model, Map.of());
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
final PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var result = listener.actionGet(TIMEOUT);
final var result = listener.actionGet(TIMEOUT);
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test input string"))));
assertThat(webServer.requests(), hasSize(1));
MockRequest request = webServer.requests().get(0);
assertWebServerRequest(HttpHeaders.AUTHORIZATION, "apikey");
assertNull(request.getUri().getQuery());
assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("apikey"));
var requestMap = entityAsMap(request.getBody());
final MockRequest request = webServer.requests().get(0);
final var requestMap = entityAsMap(request.getBody());
assertThat(requestMap.size(), is(1));
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc"))));
}
}
private static String testEmbeddingsTokenResponseJson = """
public void testRerankRequestAction() throws IOException {
final var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
final var serviceComponents = getServiceComponents();
try (var sender = createSender(senderFactory)) {
sender.start();
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testRerankTokenResponseJson));
final var webserverUrl = getUrl(webServer);
final var model = AzureAiStudioRerankModelTests.createModel(
"id",
"http://will-be-replaced.local",
AzureAiStudioProvider.COHERE,
AzureAiStudioEndpointType.TOKEN,
"apikey"
);
model.setURI(webserverUrl);
final var topN = 2;
final var returnDocuments = false;
final var query = "query";
final var documents = List.of("document 1", "document 2", "document 3");
final var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
final var action = creator.create(model, Map.of());
final PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(
new QueryAndDocsInputs(query, documents, returnDocuments, topN, false),
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);
final var result = listener.actionGet(TIMEOUT);
assertThat(
result.asMap(),
equalTo(
RankedDocsResultsTests.buildExpectationRerank(
List.of(
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", 0.1111111f)),
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 0.2222222f))
)
)
)
);
assertWebServerRequest(HttpHeaders.AUTHORIZATION, "apikey");
final var requestMap = entityAsMap(webServer.requests().get(0).getBody());
assertThat(requestMap.size(), is(4));
assertThat(requestMap.get("documents"), is(documents));
assertThat(requestMap.get("query"), is(query));
assertThat(requestMap.get("top_n"), is(topN));
assertThat(requestMap.get("return_documents"), is(returnDocuments));
}
}
private void assertWebServerRequest(String authorization, String authorizationHeaderValue) {
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
assertThat(webServer.requests().get(0).getHeader(authorization), equalTo(authorizationHeaderValue));
}
private ServiceComponents getServiceComponents() {
final var timeoutSettings = buildSettingsWithRetryFields(
TimeValue.timeValueMillis(1),
TimeValue.timeValueMinutes(1),
TimeValue.timeValueSeconds(0)
);
return new ServiceComponents(threadPool, mock(ThrottlerManager.class), timeoutSettings, TruncatorTests.createTruncator());
}
private final String testEmbeddingsTokenResponseJson = """
{
"object": "list",
"data": [
@ -209,7 +257,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
}
""";
private static String testCompletionTokenResponseJson = """
private final String testCompletionTokenResponseJson = """
{
"choices": [
{
@ -233,4 +281,27 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
}
}""";
private final String testRerankTokenResponseJson = """
{
"id": "ff2feb42-5d3a-45d7-ba29-c3dabf59988b",
"results": [
{
"index": 0,
"relevance_score": 0.1111111
},
{
"index": 1,
"relevance_score": 0.2222222
}
],
"meta": {
"api_version": {
"version": "1"
},
"billed_units": {
"search_units": 1
}
}
}
""";
}

View File

@ -0,0 +1,65 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.azureaistudio.request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace;
public class AzureAiStudioRerankRequestEntityTests extends ESTestCase {
private static final String INPUT = "texts";
private static final String QUERY = "query";
private static final Boolean RETURN_DOCUMENTS = false;
private static final Integer TOP_N = 8;
public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
final var entity = new AzureAiStudioRerankRequestEntity(
QUERY,
List.of(INPUT),
Boolean.TRUE,
TOP_N,
new AzureAiStudioRerankTaskSettings(RETURN_DOCUMENTS, TOP_N)
);
final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
final String xContentResult = Strings.toString(builder);
final String expected = """
{"documents":["texts"],
"query":"query",
"return_documents":true,
"top_n":8}""";
assertEquals(stripWhitespace(expected), xContentResult);
}
public void testXContent_WritesMinimalFields() throws IOException {
final var entity = new AzureAiStudioRerankRequestEntity(
QUERY,
List.of(INPUT),
null,
null,
new AzureAiStudioRerankTaskSettings(null, null)
);
final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
final String xContentResult = Strings.toString(builder);
final String expected = """
{"documents":["texts"],"query":"query"}""";
assertEquals(stripWhitespace(expected), xContentResult);
}
}

View File

@ -0,0 +1,159 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.azureaistudio.request;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModelTests;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
import static org.elasticsearch.xpack.inference.services.azureopenai.request.AzureOpenAiUtils.API_KEY_HEADER;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
public class AzureAiStudioRerankRequestTests extends ESTestCase {
private static final String TARGET_URI = "http://testtarget.local";
private static final String INPUT = "documents";
private static final String QUERY = "query";
private static final Integer TOP_N = 2;
public void testCreateRequest_WithCohereProviderTokenEndpoint_NoParams() throws IOException {
final var input = randomAlphaOfLength(3);
final var query = randomAlphaOfLength(3);
final var apikey = randomAlphaOfLength(3);
final var request = createRequest(TARGET_URI, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, apikey, query, input);
final var httpPost = getHttpPost(request, apikey);
final var requestMap = entityAsMap(httpPost.getEntity().getContent());
assertThat(requestMap, aMapWithSize(2));
assertThat(requestMap.get(QUERY), is(query));
assertThat(requestMap.get(INPUT), is(List.of(input)));
}
public void testCreateRequest_WithCohereProviderTokenEndpoint_WithTopNParam() throws IOException {
final var input = randomAlphaOfLength(3);
final var query = randomAlphaOfLength(3);
final var apikey = randomAlphaOfLength(3);
final var request = createRequest(
TARGET_URI,
AzureAiStudioProvider.COHERE,
AzureAiStudioEndpointType.TOKEN,
apikey,
null,
TOP_N,
query,
input
);
final var httpPost = getHttpPost(request, apikey);
final var requestMap = entityAsMap(httpPost.getEntity().getContent());
assertThat(requestMap, aMapWithSize(3));
assertThat(requestMap.get(QUERY), is(query));
assertThat(requestMap.get(INPUT), is(List.of(input)));
assertThat(requestMap.get(TOP_N_FIELD), is(TOP_N));
}
public void testCreateRequest_WithCohereProviderTokenEndpoint_WithReturnDocumentsParam() throws IOException {
final var input = randomAlphaOfLength(3);
final var query = randomAlphaOfLength(3);
final var apikey = randomAlphaOfLength(3);
final var request = createRequest(
TARGET_URI,
AzureAiStudioProvider.COHERE,
AzureAiStudioEndpointType.TOKEN,
apikey,
true,
null,
query,
input
);
final var httpPost = getHttpPost(request, apikey);
final var requestMap = entityAsMap(httpPost.getEntity().getContent());
assertThat(requestMap, aMapWithSize(3));
assertThat(requestMap.get(QUERY), is(query));
assertThat(requestMap.get(INPUT), is(List.of(input)));
assertThat(requestMap.get(RETURN_DOCUMENTS_FIELD), is(true));
}
private HttpPost getHttpPost(AzureAiStudioRerankRequest request, String apikey) {
final var httpRequest = request.createHttpRequest();
final var httpPost = validateRequestUrlAndContentType(httpRequest, TARGET_URI + "/v1/rerank");
validateRequestApiKey(httpPost, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, apikey);
return httpPost;
}
private HttpPost validateRequestUrlAndContentType(HttpRequest request, String expectedUrl) {
assertThat(request.httpRequestBase(), instanceOf(HttpPost.class));
final var httpPost = (HttpPost) request.httpRequestBase();
assertThat(httpPost.getURI().toString(), is(expectedUrl));
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
return httpPost;
}
private void validateRequestApiKey(
HttpPost httpPost,
AzureAiStudioProvider provider,
AzureAiStudioEndpointType endpointType,
String apiKey
) {
if (endpointType == AzureAiStudioEndpointType.TOKEN) {
if (provider == AzureAiStudioProvider.OPENAI) {
assertThat(httpPost.getLastHeader(API_KEY_HEADER).getValue(), is(apiKey));
} else {
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(apiKey));
}
} else {
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + apiKey));
}
}
public static AzureAiStudioRerankRequest createRequest(
String target,
AzureAiStudioProvider provider,
AzureAiStudioEndpointType endpointType,
String apiKey,
String query,
String input
) {
return createRequest(target, provider, endpointType, apiKey, null, null, query, input);
}
public static AzureAiStudioRerankRequest createRequest(
String target,
AzureAiStudioProvider provider,
AzureAiStudioEndpointType endpointType,
String apiKey,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
String query,
String input
) {
final var model = AzureAiStudioRerankModelTests.createModel(
"id",
target,
provider,
endpointType,
apiKey,
returnDocuments,
topN,
null
);
return new AzureAiStudioRerankRequest(model, query, List.of(input), returnDocuments, topN);
}
}

View File

@ -0,0 +1,130 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.net.URISyntaxException;
import static org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.sameInstance;
public class AzureAiStudioRerankModelTests extends ESTestCase {
private static final String MODEL_ID = "id";
private static final String TARGET_URI = "http://testtarget.local";
private static final String API_KEY = "apikey";
private static final Integer TOP_N = 1;
private static final Integer TOP_N_OVERRIDE = 2;
public void testOverrideWith_OverridesWithoutValues() {
final var model = createModel(
MODEL_ID,
TARGET_URI,
AzureAiStudioProvider.COHERE,
AzureAiStudioEndpointType.TOKEN,
API_KEY,
true,
TOP_N,
null
);
final var requestTaskSettingsMap = getTaskSettingsMap(null, null);
final var overriddenModel = AzureAiStudioRerankModel.of(model, requestTaskSettingsMap);
assertThat(overriddenModel, sameInstance(overriddenModel));
}
public void testOverrideWith_returnDocuments() {
final var model = createModel(
MODEL_ID,
TARGET_URI,
AzureAiStudioProvider.COHERE,
AzureAiStudioEndpointType.TOKEN,
API_KEY,
true,
null,
null
);
final var requestTaskSettings = AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap(false, null);
final var overriddenModel = AzureAiStudioRerankModel.of(model, requestTaskSettings);
assertThat(
overriddenModel,
is(createModel(MODEL_ID, TARGET_URI, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, API_KEY, false, null, null))
);
}
public void testOverrideWith_topN() {
final var model = createModel(
MODEL_ID,
TARGET_URI,
AzureAiStudioProvider.COHERE,
AzureAiStudioEndpointType.TOKEN,
API_KEY,
null,
TOP_N,
null
);
final var requestTaskSettings = AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap(null, TOP_N_OVERRIDE);
final var overriddenModel = AzureAiStudioRerankModel.of(model, requestTaskSettings);
assertThat(
overriddenModel,
is(
createModel(
MODEL_ID,
TARGET_URI,
AzureAiStudioProvider.COHERE,
AzureAiStudioEndpointType.TOKEN,
API_KEY,
null,
TOP_N_OVERRIDE,
null
)
)
);
}
public void testSetsProperUrlForCohereTokenModel() throws URISyntaxException {
final var model = createModel(MODEL_ID, TARGET_URI, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, API_KEY);
assertThat(model.getEndpointUri().toString(), is(TARGET_URI + "/v1/rerank"));
}
public static AzureAiStudioRerankModel createModel(
String id,
String target,
AzureAiStudioProvider provider,
AzureAiStudioEndpointType endpointType,
String apiKey
) {
return createModel(id, target, provider, endpointType, apiKey, null, null, null);
}
public static AzureAiStudioRerankModel createModel(
String id,
String target,
AzureAiStudioProvider provider,
AzureAiStudioEndpointType endpointType,
String apiKey,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
@Nullable RateLimitSettings rateLimitSettings
) {
return new AzureAiStudioRerankModel(
id,
new AzureAiStudioRerankServiceSettings(target, provider, endpointType, rateLimitSettings),
new AzureAiStudioRerankTaskSettings(returnDocuments, topN),
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
);
}
}

View File

@ -0,0 +1,83 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.test.ESTestCase;
import org.hamcrest.MatcherAssert;
import java.util.HashMap;
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
public class AzureAiStudioRerankRequestTaskSettingsTests extends ESTestCase {
private static final String INVALID_FIELD_TYPE_STRING = "invalid";
private static final boolean RETURN_DOCUMENTS = true;
private static final int TOP_N = 2;
public void testFromMap_ReturnsEmptySettings_WhenTheMapIsEmpty() {
assertThat(
AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of())),
is(AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS)
);
}
public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() {
assertThat(
AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "model"))),
is(AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS)
);
}
public void testFromMap_ReturnsReturnDocuments() {
assertThat(
AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(RETURN_DOCUMENTS_FIELD, RETURN_DOCUMENTS))),
is(new AzureAiStudioRerankRequestTaskSettings(RETURN_DOCUMENTS, null))
);
}
public void testFromMap_ReturnsTopN() {
assertThat(
AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(TOP_N_FIELD, TOP_N))),
is(new AzureAiStudioRerankRequestTaskSettings(null, TOP_N))
);
}
public void testFromMap_ReturnDocumentsIsInvalidValue_ThrowsValidationException() {
assertThrowsValidationExceptionIfStringValueProvidedFor(RETURN_DOCUMENTS_FIELD);
}
public void testFromMap_TopNIsInvalidValue_ThrowsValidationException() {
assertThrowsValidationExceptionIfStringValueProvidedFor(TOP_N_FIELD);
}
private void assertThrowsValidationExceptionIfStringValueProvidedFor(String field) {
final var thrownException = expectThrows(
ValidationException.class,
() -> AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(field, INVALID_FIELD_TYPE_STRING)))
);
MatcherAssert.assertThat(
thrownException.getMessage(),
containsString(
Strings.format(
"field ["
+ field
+ "] is not of the expected type. The value ["
+ INVALID_FIELD_TYPE_STRING
+ "] cannot be converted to a "
)
)
);
}
}

View File

@ -0,0 +1,123 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
import org.hamcrest.CoreMatchers;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.ENDPOINT_TYPE_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.PROVIDER_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TARGET_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType.TOKEN;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider.COHERE;
import static org.hamcrest.Matchers.is;
public class AzureAiStudioRerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase<AzureAiStudioRerankServiceSettings> {
private static final String TARGET_URI = "http://testtarget.local";
public void testFromMap_Request_CreatesSettingsCorrectly() {
final var serviceSettings = AzureAiStudioRerankServiceSettings.fromMap(
createRequestSettingsMap(TARGET_URI, COHERE.name(), TOKEN.name()),
ConfigurationParseContext.REQUEST
);
assertThat(serviceSettings, is(new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, null)));
}
public void testFromMap_RequestWithRateLimit_CreatesSettingsCorrectly() {
final var settingsMap = createRequestSettingsMap(TARGET_URI, COHERE.name(), TOKEN.name());
settingsMap.put(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3)));
final var serviceSettings = AzureAiStudioRerankServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST);
assertThat(serviceSettings, is(new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, new RateLimitSettings(3))));
}
public void testFromMap_Persistent_CreatesSettingsCorrectly() {
final var serviceSettings = AzureAiStudioRerankServiceSettings.fromMap(
createRequestSettingsMap(TARGET_URI, COHERE.name(), TOKEN.name()),
ConfigurationParseContext.PERSISTENT
);
assertThat(serviceSettings, is(new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, null)));
}
public void testToXContent_WritesAllValues() throws IOException {
final var settings = new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, new RateLimitSettings(3));
final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
settings.toXContent(builder, null);
final String xContentResult = Strings.toString(builder);
assertThat(xContentResult, CoreMatchers.is("""
{"target":"http://testtarget.local","provider":"cohere","endpoint_type":"token",""" + """
"rate_limit":{"requests_per_minute":3}}"""));
}
public void testToFilteredXContent_WritesAllValues() throws IOException {
final var settings = new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, new RateLimitSettings(3));
final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
final var filteredXContent = settings.getFilteredXContentObject();
filteredXContent.toXContent(builder, null);
final String xContentResult = Strings.toString(builder);
assertThat(xContentResult, CoreMatchers.is("""
{"target":"http://testtarget.local","provider":"cohere","endpoint_type":"token",""" + """
"rate_limit":{"requests_per_minute":3}}"""));
}
public static HashMap<String, Object> createRequestSettingsMap(String target, String provider, String endpointType) {
return new HashMap<>(Map.of(TARGET_FIELD, target, PROVIDER_FIELD, provider, ENDPOINT_TYPE_FIELD, endpointType));
}
@Override
protected Writeable.Reader<AzureAiStudioRerankServiceSettings> instanceReader() {
return AzureAiStudioRerankServiceSettings::new;
}
@Override
protected AzureAiStudioRerankServiceSettings createTestInstance() {
return createRandom();
}
@Override
protected AzureAiStudioRerankServiceSettings mutateInstance(AzureAiStudioRerankServiceSettings instance) throws IOException {
return randomValueOtherThan(instance, AzureAiStudioRerankServiceSettingsTests::createRandom);
}
@Override
protected AzureAiStudioRerankServiceSettings mutateInstanceForVersion(
AzureAiStudioRerankServiceSettings instance,
TransportVersion version
) {
return instance;
}
private static AzureAiStudioRerankServiceSettings createRandom() {
return new AzureAiStudioRerankServiceSettings(
randomAlphaOfLength(10),
randomFrom(AzureAiStudioProvider.values()),
randomFrom(AzureAiStudioEndpointType.values()),
RateLimitSettingsTests.createRandom()
);
}
}

View File

@ -0,0 +1,230 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.hamcrest.MatcherAssert;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
public class AzureAiStudioRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase<AzureAiStudioRerankTaskSettings> {
private static final String INVALID_FIELD_TYPE_STRING = "invalid";
public void testIsEmpty() {
final var randomSettings = createRandom();
final var stringRep = Strings.toString(randomSettings);
assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}"));
}
public void testUpdatedTaskSettings_WithAllValues() {
final AzureAiStudioRerankTaskSettings initialSettings = createRandom();
AzureAiStudioRerankTaskSettings newSettings = createRandom(initialSettings);
assertUpdateSettings(newSettings, initialSettings);
}
public void testUpdatedTaskSettings_WithReturnDocumentsValue() {
final AzureAiStudioRerankTaskSettings initialSettings = createRandom();
AzureAiStudioRerankTaskSettings newSettings = createRandom(initialSettings);
assertUpdateSettings(newSettings, initialSettings);
}
public void testUpdatedTaskSettings_WithTopNValue() {
final AzureAiStudioRerankTaskSettings initialSettings = createRandom();
AzureAiStudioRerankTaskSettings newSettings = createRandom(initialSettings);
assertUpdateSettings(newSettings, initialSettings);
}
public void testUpdatedTaskSettings_WithNoValues() {
AzureAiStudioRerankTaskSettings initialSettings = createRandom();
final AzureAiStudioRerankTaskSettings newSettings = new AzureAiStudioRerankTaskSettings(null, null);
assertUpdateSettings(newSettings, initialSettings);
}
private void assertUpdateSettings(AzureAiStudioRerankTaskSettings newSettings, AzureAiStudioRerankTaskSettings initialSettings) {
final var settingsMap = new HashMap<String, Object>();
if (newSettings.returnDocuments() != null) settingsMap.put(RETURN_DOCUMENTS_FIELD, newSettings.returnDocuments());
if (newSettings.topN() != null) settingsMap.put(TOP_N_FIELD, newSettings.topN());
final AzureAiStudioRerankTaskSettings updatedSettings = (AzureAiStudioRerankTaskSettings) initialSettings.updatedTaskSettings(
Collections.unmodifiableMap(settingsMap)
);
assertEquals(
newSettings.returnDocuments() == null ? initialSettings.returnDocuments() : newSettings.returnDocuments(),
updatedSettings.returnDocuments()
);
assertEquals(newSettings.topN() == null ? initialSettings.topN() : newSettings.topN(), updatedSettings.topN());
}
public void testFromMap_AllValues() {
assertEquals(new AzureAiStudioRerankTaskSettings(true, 2), AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, 2)));
}
public void testFromMap_ReturnDocuments() {
assertEquals(
new AzureAiStudioRerankTaskSettings(true, null),
AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, null))
);
}
public void testFromMap_TopN() {
assertEquals(new AzureAiStudioRerankTaskSettings(null, 2), AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(null, 2)));
}
public void testFromMap_ReturnDocumentsIsInvalidValue_ThrowsValidationException() {
getTaskSettingsMap(true, 2).put(RETURN_DOCUMENTS_FIELD, INVALID_FIELD_TYPE_STRING);
assertThrowsValidationExceptionIfStringValueProvidedFor(RETURN_DOCUMENTS_FIELD);
}
public void testFromMap_TopNIsInvalidValue_ThrowsValidationException() {
getTaskSettingsMap(true, 2).put(TOP_N_FIELD, INVALID_FIELD_TYPE_STRING);
assertThrowsValidationExceptionIfStringValueProvidedFor(TOP_N_FIELD);
}
public void testFromMap_WithNoValues_DoesNotThrowException() {
final var taskMap = AzureAiStudioRerankTaskSettings.fromMap(new HashMap<>(Map.of()));
assertNull(taskMap.returnDocuments());
assertNull(taskMap.topN());
}
public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() {
final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, 2));
final var overrideSettings = AzureAiStudioRerankTaskSettings.of(settings, AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS);
MatcherAssert.assertThat(overrideSettings, is(settings));
}
public void testOverrideWith_UsesReturnDocumentsOverride() {
final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, null));
final var overrideSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(getTaskSettingsMap(false, null));
final var overriddenTaskSettings = AzureAiStudioRerankTaskSettings.of(settings, overrideSettings);
MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureAiStudioRerankTaskSettings(false, null)));
}
public void testOverrideWith_UsesTopNOverride() {
final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(null, 2));
final var overrideSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(getTaskSettingsMap(null, 1));
final var overriddenTaskSettings = AzureAiStudioRerankTaskSettings.of(settings, overrideSettings);
MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureAiStudioRerankTaskSettings(null, 1)));
}
public void testOverrideWith_UsesAllParametersOverride() {
final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(false, 2));
final var overrideSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(getTaskSettingsMap(true, 1));
final var overriddenTaskSettings = AzureAiStudioRerankTaskSettings.of(settings, overrideSettings);
MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureAiStudioRerankTaskSettings(true, 1)));
}
public void testToXContent_WithoutParameters() throws IOException {
assertThat(getXContentResult(null, null), is("{}"));
}
public void testToXContent_WithReturnDocumentsParameter() throws IOException {
assertThat(getXContentResult(true, null), is("""
{"return_documents":true}"""));
}
public void testToXContent_WithTopNParameter() throws IOException {
assertThat(getXContentResult(null, 2), is("""
{"top_n":2}"""));
}
public void testToXContent_WithParameters() throws IOException {
assertThat(getXContentResult(true, 2), is("""
{"return_documents":true,"top_n":2}"""));
}
private String getXContentResult(Boolean returnDocuments, Integer topN) throws IOException {
final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(returnDocuments, topN));
final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
settings.toXContent(builder, null);
return Strings.toString(builder);
}
public static Map<String, Object> getTaskSettingsMap(@Nullable Boolean returnDocuments, @Nullable Integer topN) {
final var map = new HashMap<String, Object>();
if (returnDocuments != null) {
map.put(RETURN_DOCUMENTS_FIELD, returnDocuments);
}
if (topN != null) {
map.put(TOP_N_FIELD, topN);
}
return map;
}
@Override
protected Writeable.Reader<AzureAiStudioRerankTaskSettings> instanceReader() {
return AzureAiStudioRerankTaskSettings::new;
}
@Override
protected AzureAiStudioRerankTaskSettings createTestInstance() {
return createRandom();
}
@Override
protected AzureAiStudioRerankTaskSettings mutateInstance(AzureAiStudioRerankTaskSettings instance) throws IOException {
return randomValueOtherThan(instance, AzureAiStudioRerankTaskSettingsTests::createRandom);
}
@Override
protected AzureAiStudioRerankTaskSettings mutateInstanceForVersion(AzureAiStudioRerankTaskSettings instance, TransportVersion version) {
return instance;
}
private static AzureAiStudioRerankTaskSettings createRandom() {
return new AzureAiStudioRerankTaskSettings(
randomFrom(new Boolean[] { null, randomBoolean() }),
randomFrom(new Integer[] { null, randomNonNegativeInt() })
);
}
private static AzureAiStudioRerankTaskSettings createRandom(AzureAiStudioRerankTaskSettings settings) {
return new AzureAiStudioRerankTaskSettings(
randomValueOtherThan(settings.returnDocuments(), () -> randomFrom(new Boolean[] { null, randomBoolean() })),
randomValueOtherThan(settings.topN(), () -> randomFrom(new Integer[] { null, randomNonNegativeInt() }))
);
}
private void assertThrowsValidationExceptionIfStringValueProvidedFor(String field) {
final var thrownException = expectThrows(
ValidationException.class,
() -> AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(field, INVALID_FIELD_TYPE_STRING)))
);
MatcherAssert.assertThat(
thrownException.getMessage(),
containsString(
Strings.format(
"field ["
+ field
+ "] is not of the expected type. The value ["
+ INVALID_FIELD_TYPE_STRING
+ "] cannot be converted to a "
)
)
);
}
}

View File

@ -0,0 +1,113 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.azureaistudio.response;
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
public class AzureAiStudioRerankResponseEntityTests extends ESTestCase {
public void testResponse_WithDocuments() throws IOException {
final String responseJson = getResponseJsonWithDocuments();
final var parsedResults = getParsedResults(responseJson);
final var expectedResults = List.of(
new RankedDocsResults.RankedDoc(0, 0.1111111F, "test text one"),
new RankedDocsResults.RankedDoc(1, 0.2222222F, "test text two")
);
assertThat(parsedResults.getRankedDocs(), is(expectedResults));
}
public void testResponse_NoDocuments() throws IOException {
final String responseJson = getResponseJsonNoDocuments();
final var parsedResults = getParsedResults(responseJson);
final var expectedResults = List.of(
new RankedDocsResults.RankedDoc(0, 0.1111111F, null),
new RankedDocsResults.RankedDoc(1, 0.2222222F, null)
);
assertThat(parsedResults.getRankedDocs(), is(expectedResults));
}
private RankedDocsResults getParsedResults(String responseJson) throws IOException {
final var entity = new AzureAiStudioRerankResponseEntity();
return (RankedDocsResults) entity.apply(
mock(Request.class),
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
);
}
private String getResponseJsonWithDocuments() {
return """
{
"id": "222e59de-c712-40cb-ae87-ecd402d0d2f1",
"results": [
{
"document": {
"text": "test text one"
},
"index": 0,
"relevance_score": 0.1111111
},
{
"document": {
"text": "test text two"
},
"index": 1,
"relevance_score": 0.2222222
}
],
"meta": {
"api_version": {
"version": "1"
},
"billed_units": {
"search_units": 1
}
}
}
""";
}
private String getResponseJsonNoDocuments() {
return """
{
"id": "222e59de-c712-40cb-ae87-ecd402d0d2f1",
"results": [
{
"index": 0,
"relevance_score": 0.1111111
},
{
"index": 1,
"relevance_score": 0.2222222
}
],
"meta": {
"api_version": {
"version": "1"
},
"billed_units": {
"search_units": 1
}
}
}
""";
}
}