Add Azure AI Rerank support (#129848)
* Add Azure AI Rerank support * address comments * address comments * refactor azure ai studio service * update rerank task settings test * add provider for rerank
This commit is contained in:
parent
f9eee6c216
commit
d06b0c8c17
|
@ -0,0 +1,5 @@
|
|||
pr: 129848
|
||||
summary: "[ML] Add Azure AI Rerank support to the Inference Plugin"
|
||||
area: Machine Learning
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -341,6 +341,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion LOOKUP_JOIN_CCS = def(9_120_0_00);
|
||||
public static final TransportVersion NODE_USAGE_STATS_FOR_THREAD_POOLS_IN_CLUSTER_INFO = def(9_121_0_00);
|
||||
public static final TransportVersion ESQL_CATEGORIZE_OPTIONS = def(9_122_0_00);
|
||||
public static final TransportVersion ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED = def(9_123_0_00);
|
||||
|
||||
/*
|
||||
* STOP! READ THIS FIRST! No, really,
|
||||
|
|
|
@ -111,6 +111,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
containsInAnyOrder(
|
||||
List.of(
|
||||
"alibabacloud-ai-search",
|
||||
"azureaistudio",
|
||||
"cohere",
|
||||
"elasticsearch",
|
||||
"googlevertexai",
|
||||
|
|
|
@ -50,6 +50,8 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.completion.Azure
|
|||
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionTaskSettings;
|
||||
|
@ -306,6 +308,17 @@ public class InferenceNamedWriteablesProvider {
|
|||
AzureAiStudioChatCompletionTaskSettings::new
|
||||
)
|
||||
);
|
||||
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(
|
||||
ServiceSettings.class,
|
||||
AzureAiStudioRerankServiceSettings.NAME,
|
||||
AzureAiStudioRerankServiceSettings::new
|
||||
)
|
||||
);
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(TaskSettings.class, AzureAiStudioRerankTaskSettings.NAME, AzureAiStudioRerankTaskSettings::new)
|
||||
);
|
||||
}
|
||||
|
||||
private static void addAzureOpenAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
|
||||
|
|
|
@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.services.azureaistudio;
|
|||
public class AzureAiStudioConstants {
|
||||
public static final String EMBEDDINGS_URI_PATH = "/v1/embeddings";
|
||||
public static final String COMPLETIONS_URI_PATH = "/v1/chat/completions";
|
||||
public static final String RERANK_URI_PATH = "/v1/rerank";
|
||||
|
||||
// common service settings fields
|
||||
public static final String TARGET_FIELD = "target";
|
||||
|
@ -22,6 +23,10 @@ public class AzureAiStudioConstants {
|
|||
public static final String DIMENSIONS_FIELD = "dimensions";
|
||||
public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
|
||||
|
||||
// rerank task settings fields
|
||||
public static final String DOCUMENTS_FIELD = "documents";
|
||||
public static final String QUERY_FIELD = "query";
|
||||
|
||||
// embeddings task settings fields
|
||||
public static final String USER_FIELD = "user";
|
||||
|
||||
|
@ -35,5 +40,9 @@ public class AzureAiStudioConstants {
|
|||
public static final Double MIN_TEMPERATURE_TOP_P = 0.0;
|
||||
public static final Double MAX_TEMPERATURE_TOP_P = 2.0;
|
||||
|
||||
// rerank task settings fields
|
||||
public static final String RETURN_DOCUMENTS_FIELD = "return_documents";
|
||||
public static final String TOP_N_FIELD = "top_n";
|
||||
|
||||
private AzureAiStudioConstants() {}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,9 @@ public final class AzureAiStudioProviderCapabilities {
|
|||
// these providers have chat completion inference (all providers at the moment)
|
||||
public static final List<AzureAiStudioProvider> chatCompletionProviders = List.of(AzureAiStudioProvider.values());
|
||||
|
||||
// these providers have rerank inference
|
||||
public static final List<AzureAiStudioProvider> rerankProviders = List.of(AzureAiStudioProvider.COHERE);
|
||||
|
||||
// these providers allow token ("pay as you go") embeddings endpoints
|
||||
public static final List<AzureAiStudioProvider> tokenEmbeddingsProviders = List.of(
|
||||
AzureAiStudioProvider.OPENAI,
|
||||
|
@ -31,6 +34,9 @@ public final class AzureAiStudioProviderCapabilities {
|
|||
// these providers allow realtime embeddings endpoints (none at the moment)
|
||||
public static final List<AzureAiStudioProvider> realtimeEmbeddingsProviders = List.of();
|
||||
|
||||
// these providers allow realtime rerank endpoints (none at the moment)
|
||||
public static final List<AzureAiStudioProvider> realtimeRerankProviders = List.of();
|
||||
|
||||
// these providers allow token ("pay as you go") chat completion endpoints
|
||||
public static final List<AzureAiStudioProvider> tokenChatCompletionProviders = List.of(
|
||||
AzureAiStudioProvider.OPENAI,
|
||||
|
@ -54,6 +60,9 @@ public final class AzureAiStudioProviderCapabilities {
|
|||
case TEXT_EMBEDDING -> {
|
||||
return embeddingProviders.contains(provider);
|
||||
}
|
||||
case RERANK -> {
|
||||
return rerankProviders.contains(provider);
|
||||
}
|
||||
default -> {
|
||||
return false;
|
||||
}
|
||||
|
@ -76,6 +85,11 @@ public final class AzureAiStudioProviderCapabilities {
|
|||
? tokenEmbeddingsProviders.contains(provider)
|
||||
: realtimeEmbeddingsProviders.contains(provider);
|
||||
}
|
||||
case RERANK -> {
|
||||
return (endpointType == AzureAiStudioEndpointType.TOKEN)
|
||||
? rerankProviders.contains(provider)
|
||||
: realtimeRerankProviders.contains(provider);
|
||||
}
|
||||
default -> {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.azureaistudio;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
|
||||
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.request.AzureAiStudioRerankRequest;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.response.AzureAiStudioRerankResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.services.azureopenai.response.AzureMistralOpenAiExternalResponseHandler;
|
||||
|
||||
import java.util.function.Supplier;
|
||||
|
||||
public class AzureAiStudioRerankRequestManager extends AzureAiStudioRequestManager {
|
||||
private static final Logger logger = LogManager.getLogger(AzureAiStudioRerankRequestManager.class);
|
||||
|
||||
private static final ResponseHandler HANDLER = createRerankHandler();
|
||||
|
||||
private final AzureAiStudioRerankModel model;
|
||||
|
||||
public AzureAiStudioRerankRequestManager(AzureAiStudioRerankModel model, ThreadPool threadPool) {
|
||||
super(threadPool, model);
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(
|
||||
InferenceInputs inferenceInputs,
|
||||
RequestSender requestSender,
|
||||
Supplier<Boolean> hasRequestRerankFunction,
|
||||
ActionListener<InferenceServiceResults> listener
|
||||
) {
|
||||
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
|
||||
AzureAiStudioRerankRequest request = new AzureAiStudioRerankRequest(
|
||||
model,
|
||||
rerankInput.getQuery(),
|
||||
rerankInput.getChunks(),
|
||||
rerankInput.getReturnDocuments(),
|
||||
rerankInput.getTopN()
|
||||
);
|
||||
|
||||
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestRerankFunction, listener));
|
||||
}
|
||||
|
||||
private static ResponseHandler createRerankHandler() {
|
||||
// This currently covers response handling for Azure AI Studio
|
||||
return new AzureMistralOpenAiExternalResponseHandler(
|
||||
"azure ai studio rerank",
|
||||
new AzureAiStudioRerankResponseEntity(),
|
||||
ErrorMessageResponseEntity::fromResponse,
|
||||
true
|
||||
);
|
||||
}
|
||||
}
|
|
@ -44,6 +44,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.completion.Azure
|
|||
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
|
||||
|
@ -71,10 +72,10 @@ import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFie
|
|||
|
||||
public class AzureAiStudioService extends SenderService {
|
||||
|
||||
static final String NAME = "azureaistudio";
|
||||
public static final String NAME = "azureaistudio";
|
||||
|
||||
private static final String SERVICE_NAME = "Azure AI Studio";
|
||||
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
|
||||
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.RERANK);
|
||||
|
||||
private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
|
||||
InputType.INGEST,
|
||||
|
@ -270,8 +271,9 @@ public class AzureAiStudioService extends SenderService {
|
|||
ConfigurationParseContext context
|
||||
) {
|
||||
|
||||
if (taskType == TaskType.TEXT_EMBEDDING) {
|
||||
var embeddingsModel = new AzureAiStudioEmbeddingsModel(
|
||||
AzureAiStudioModel model;
|
||||
switch (taskType) {
|
||||
case TEXT_EMBEDDING -> model = new AzureAiStudioEmbeddingsModel(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
NAME,
|
||||
|
@ -281,16 +283,7 @@ public class AzureAiStudioService extends SenderService {
|
|||
secretSettings,
|
||||
context
|
||||
);
|
||||
checkProviderAndEndpointTypeForTask(
|
||||
TaskType.TEXT_EMBEDDING,
|
||||
embeddingsModel.getServiceSettings().provider(),
|
||||
embeddingsModel.getServiceSettings().endpointType()
|
||||
);
|
||||
return embeddingsModel;
|
||||
}
|
||||
|
||||
if (taskType == TaskType.COMPLETION) {
|
||||
var completionModel = new AzureAiStudioChatCompletionModel(
|
||||
case COMPLETION -> model = new AzureAiStudioChatCompletionModel(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
NAME,
|
||||
|
@ -299,15 +292,12 @@ public class AzureAiStudioService extends SenderService {
|
|||
secretSettings,
|
||||
context
|
||||
);
|
||||
checkProviderAndEndpointTypeForTask(
|
||||
TaskType.COMPLETION,
|
||||
completionModel.getServiceSettings().provider(),
|
||||
completionModel.getServiceSettings().endpointType()
|
||||
);
|
||||
return completionModel;
|
||||
case RERANK -> model = new AzureAiStudioRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context);
|
||||
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
|
||||
}
|
||||
|
||||
throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
|
||||
final var azureAiStudioServiceSettings = (AzureAiStudioServiceSettings) model.getServiceSettings();
|
||||
checkProviderAndEndpointTypeForTask(taskType, azureAiStudioServiceSettings.provider(), azureAiStudioServiceSettings.endpointType());
|
||||
return model;
|
||||
}
|
||||
|
||||
private AzureAiStudioModel createModelFromPersistent(
|
||||
|
|
|
@ -13,8 +13,10 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
|||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioChatCompletionRequestManager;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEmbeddingsRequestManager;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioRerankRequestManager;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
@ -49,4 +51,12 @@ public class AzureAiStudioActionCreator implements AzureAiStudioActionVisitor {
|
|||
var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio embeddings");
|
||||
return new SenderExecutableAction(sender, requestManager, errorMessage);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ExecutableAction create(AzureAiStudioRerankModel rerankModel, Map<String, Object> taskSettings) {
|
||||
var overriddenModel = AzureAiStudioRerankModel.of(rerankModel, taskSettings);
|
||||
var requestManager = new AzureAiStudioRerankRequestManager(overriddenModel, serviceComponents.threadPool());
|
||||
var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio rerank");
|
||||
return new SenderExecutableAction(sender, requestManager, errorMessage);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.services.azureaistudio.action;
|
|||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -17,4 +18,6 @@ public interface AzureAiStudioActionVisitor {
|
|||
ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings);
|
||||
|
||||
ExecutableAction create(AzureAiStudioChatCompletionModel completionModel, Map<String, Object> taskSettings);
|
||||
|
||||
ExecutableAction create(AzureAiStudioRerankModel rerankModel, Map<String, Object> taskSettings);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.azureaistudio.request;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.apache.http.entity.ByteArrayEntity;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public class AzureAiStudioRerankRequest extends AzureAiStudioRequest {
|
||||
private final String query;
|
||||
private final List<String> input;
|
||||
private final Boolean returnDocuments;
|
||||
private final Integer topN;
|
||||
private final AzureAiStudioRerankModel rerankModel;
|
||||
|
||||
public AzureAiStudioRerankRequest(
|
||||
AzureAiStudioRerankModel model,
|
||||
String query,
|
||||
List<String> input,
|
||||
@Nullable Boolean returnDocuments,
|
||||
@Nullable Integer topN
|
||||
) {
|
||||
super(model);
|
||||
this.rerankModel = Objects.requireNonNull(model);
|
||||
this.query = query;
|
||||
this.input = Objects.requireNonNull(input);
|
||||
this.returnDocuments = returnDocuments;
|
||||
this.topN = topN;
|
||||
}
|
||||
|
||||
@Override
|
||||
public HttpRequest createHttpRequest() {
|
||||
HttpPost httpPost = new HttpPost(this.uri);
|
||||
|
||||
ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(createRequestEntity()).getBytes(StandardCharsets.UTF_8));
|
||||
httpPost.setEntity(byteEntity);
|
||||
|
||||
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
|
||||
setAuthHeader(httpPost, rerankModel);
|
||||
|
||||
return new HttpRequest(httpPost, getInferenceEntityId());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Request truncate() {
|
||||
// Not applicable for rerank, only used in text embedding requests
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean[] getTruncationInfo() {
|
||||
// Not applicable for rerank, only used in text embedding requests
|
||||
return null;
|
||||
}
|
||||
|
||||
private AzureAiStudioRerankRequestEntity createRequestEntity() {
|
||||
return new AzureAiStudioRerankRequestEntity(query, input, returnDocuments, topN, rerankModel.getTaskSettings());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.azureaistudio.request;
|
||||
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.DOCUMENTS_FIELD;
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.QUERY_FIELD;
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
|
||||
|
||||
public record AzureAiStudioRerankRequestEntity(
|
||||
String query,
|
||||
List<String> input,
|
||||
@Nullable Boolean returnDocuments,
|
||||
@Nullable Integer topN,
|
||||
AzureAiStudioRerankTaskSettings taskSettings
|
||||
) implements ToXContentObject {
|
||||
|
||||
public AzureAiStudioRerankRequestEntity {
|
||||
Objects.requireNonNull(query);
|
||||
Objects.requireNonNull(input);
|
||||
Objects.requireNonNull(taskSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
|
||||
builder.field(DOCUMENTS_FIELD, input);
|
||||
builder.field(QUERY_FIELD, query);
|
||||
|
||||
if (returnDocuments != null) {
|
||||
builder.field(RETURN_DOCUMENTS_FIELD, returnDocuments);
|
||||
} else if (taskSettings.returnDocuments() != null) {
|
||||
builder.field(RETURN_DOCUMENTS_FIELD, taskSettings.returnDocuments());
|
||||
}
|
||||
|
||||
if (topN != null) {
|
||||
builder.field(TOP_N_FIELD, topN);
|
||||
} else if (taskSettings.topN() != null) {
|
||||
builder.field(TOP_N_FIELD, taskSettings.topN());
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
|
||||
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioModel;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.action.AzureAiStudioActionVisitor;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RERANK_URI_PATH;
|
||||
|
||||
public class AzureAiStudioRerankModel extends AzureAiStudioModel {
|
||||
|
||||
public static AzureAiStudioRerankModel of(AzureAiStudioRerankModel model, Map<String, Object> taskSettings) {
|
||||
if (taskSettings == null || taskSettings.isEmpty()) {
|
||||
return model;
|
||||
}
|
||||
|
||||
final var requestTaskSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(taskSettings);
|
||||
final var taskSettingToUse = AzureAiStudioRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings);
|
||||
|
||||
return new AzureAiStudioRerankModel(model, taskSettingToUse);
|
||||
}
|
||||
|
||||
public AzureAiStudioRerankModel(
|
||||
String inferenceEntityId,
|
||||
AzureAiStudioRerankServiceSettings serviceSettings,
|
||||
AzureAiStudioRerankTaskSettings taskSettings,
|
||||
DefaultSecretSettings secrets
|
||||
) {
|
||||
super(
|
||||
new ModelConfigurations(inferenceEntityId, TaskType.RERANK, AzureAiStudioService.NAME, serviceSettings, taskSettings),
|
||||
new ModelSecrets(secrets)
|
||||
);
|
||||
}
|
||||
|
||||
public AzureAiStudioRerankModel(
|
||||
String inferenceEntityId,
|
||||
Map<String, Object> serviceSettings,
|
||||
Map<String, Object> taskSettings,
|
||||
@Nullable Map<String, Object> secrets,
|
||||
ConfigurationParseContext context
|
||||
) {
|
||||
this(
|
||||
inferenceEntityId,
|
||||
AzureAiStudioRerankServiceSettings.fromMap(serviceSettings, context),
|
||||
AzureAiStudioRerankTaskSettings.fromMap(taskSettings),
|
||||
DefaultSecretSettings.fromMap(secrets)
|
||||
);
|
||||
}
|
||||
|
||||
public AzureAiStudioRerankModel(AzureAiStudioRerankModel model, AzureAiStudioRerankTaskSettings taskSettings) {
|
||||
super(model, taskSettings, model.getServiceSettings().rateLimitSettings());
|
||||
}
|
||||
|
||||
@Override
|
||||
public AzureAiStudioRerankServiceSettings getServiceSettings() {
|
||||
return (AzureAiStudioRerankServiceSettings) super.getServiceSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
public AzureAiStudioRerankTaskSettings getTaskSettings() {
|
||||
return (AzureAiStudioRerankTaskSettings) super.getTaskSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DefaultSecretSettings getSecretSettings() {
|
||||
return super.getSecretSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected URI getEndpointUri() throws URISyntaxException {
|
||||
return new URI(this.target + RERANK_URI_PATH);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ExecutableAction accept(AzureAiStudioActionVisitor creator, Map<String, Object> taskSettings) {
|
||||
return creator.create(this, taskSettings);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
|
||||
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
|
||||
|
||||
public record AzureAiStudioRerankRequestTaskSettings(@Nullable Boolean returnDocuments, @Nullable Integer topN) {
|
||||
|
||||
public static final AzureAiStudioRerankRequestTaskSettings EMPTY_SETTINGS = new AzureAiStudioRerankRequestTaskSettings(null, null);
|
||||
|
||||
/**
|
||||
* Extracts the task settings from a map. All settings are considered optional and the absence of a setting
|
||||
* does not throw an error.
|
||||
*
|
||||
* @param map the settings received from a request
|
||||
* @return a {@link AzureAiStudioRerankRequestTaskSettings}
|
||||
*/
|
||||
public static AzureAiStudioRerankRequestTaskSettings fromMap(Map<String, Object> map) {
|
||||
if (map.isEmpty()) {
|
||||
return AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS;
|
||||
}
|
||||
|
||||
final var validationException = new ValidationException();
|
||||
|
||||
final var returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS_FIELD, validationException);
|
||||
final var topN = extractOptionalPositiveInteger(map, TOP_N_FIELD, ModelConfigurations.TASK_SETTINGS, validationException);
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
|
||||
return new AzureAiStudioRerankRequestTaskSettings(returnDocuments, topN);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.xcontent.ToXContent;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class AzureAiStudioRerankServiceSettings extends AzureAiStudioServiceSettings {
|
||||
public static final String NAME = "azure_ai_studio_rerank_service_settings";
|
||||
|
||||
public static AzureAiStudioRerankServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
|
||||
final var validationException = new ValidationException();
|
||||
|
||||
final var settings = rerankSettingsFromMap(map, validationException, context);
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
|
||||
return new AzureAiStudioRerankServiceSettings(settings);
|
||||
}
|
||||
|
||||
private static AzureAiStudioRerankServiceSettings.AzureAiStudioRerankCommonFields rerankSettingsFromMap(
|
||||
Map<String, Object> map,
|
||||
ValidationException validationException,
|
||||
ConfigurationParseContext context
|
||||
) {
|
||||
final var baseSettings = AzureAiStudioServiceSettings.fromMap(map, validationException, context);
|
||||
return new AzureAiStudioRerankServiceSettings.AzureAiStudioRerankCommonFields(baseSettings);
|
||||
}
|
||||
|
||||
private record AzureAiStudioRerankCommonFields(BaseAzureAiStudioCommonFields baseCommonFields) {}
|
||||
|
||||
public AzureAiStudioRerankServiceSettings(
|
||||
String target,
|
||||
AzureAiStudioProvider provider,
|
||||
AzureAiStudioEndpointType endpointType,
|
||||
@Nullable RateLimitSettings rateLimitSettings
|
||||
) {
|
||||
super(target, provider, endpointType, rateLimitSettings);
|
||||
}
|
||||
|
||||
public AzureAiStudioRerankServiceSettings(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
}
|
||||
|
||||
private AzureAiStudioRerankServiceSettings(AzureAiStudioRerankServiceSettings.AzureAiStudioRerankCommonFields fields) {
|
||||
this(
|
||||
fields.baseCommonFields.target(),
|
||||
fields.baseCommonFields.provider(),
|
||||
fields.baseCommonFields.endpointType(),
|
||||
fields.baseCommonFields.rateLimitSettings()
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
|
||||
super.addXContentFields(builder, params);
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
super.addExposedXContentFields(builder, params);
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
AzureAiStudioRerankServiceSettings that = (AzureAiStudioRerankServiceSettings) o;
|
||||
|
||||
return Objects.equals(target, that.target)
|
||||
&& Objects.equals(provider, that.provider)
|
||||
&& Objects.equals(endpointType, that.endpointType)
|
||||
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(target, provider, endpointType, rateLimitSettings);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,149 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.TaskSettings;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
|
||||
|
||||
/**
|
||||
* Defines the rerank task settings for the AzureAiStudio service.
|
||||
*/
|
||||
public class AzureAiStudioRerankTaskSettings implements TaskSettings {
|
||||
public static final String NAME = "azure_ai_studio_rerank_task_settings";
|
||||
|
||||
public static AzureAiStudioRerankTaskSettings fromMap(Map<String, Object> map) {
|
||||
final var validationException = new ValidationException();
|
||||
|
||||
final var returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS_FIELD, validationException);
|
||||
final var topN = extractOptionalPositiveInteger(map, TOP_N_FIELD, ModelConfigurations.TASK_SETTINGS, validationException);
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
|
||||
return new AzureAiStudioRerankTaskSettings(returnDocuments, topN);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new {@link AzureAiStudioRerankTaskSettings} object by overriding the values in originalSettings with the ones
|
||||
* passed in via requestSettings if the fields are not null.
|
||||
* @param originalSettings the original {@link AzureAiStudioRerankTaskSettings} from the inference entity configuration from storage
|
||||
* @param requestSettings the {@link AzureAiStudioRerankTaskSettings} from the request
|
||||
* @return a new {@link AzureAiStudioRerankTaskSettings}
|
||||
*/
|
||||
public static AzureAiStudioRerankTaskSettings of(
|
||||
AzureAiStudioRerankTaskSettings originalSettings,
|
||||
AzureAiStudioRerankRequestTaskSettings requestSettings
|
||||
) {
|
||||
|
||||
final var returnDocuments = requestSettings.returnDocuments() == null
|
||||
? originalSettings.returnDocuments()
|
||||
: requestSettings.returnDocuments();
|
||||
final var topN = requestSettings.topN() == null ? originalSettings.topN() : requestSettings.topN();
|
||||
|
||||
return new AzureAiStudioRerankTaskSettings(returnDocuments, topN);
|
||||
}
|
||||
|
||||
public AzureAiStudioRerankTaskSettings(@Nullable Boolean returnDocuments, @Nullable Integer topN) {
|
||||
this.returnDocuments = returnDocuments;
|
||||
this.topN = topN;
|
||||
}
|
||||
|
||||
public AzureAiStudioRerankTaskSettings(StreamInput in) throws IOException {
|
||||
this.returnDocuments = in.readOptionalBoolean();
|
||||
this.topN = in.readOptionalVInt();
|
||||
}
|
||||
|
||||
private final Boolean returnDocuments;
|
||||
private final Integer topN;
|
||||
|
||||
public Boolean returnDocuments() {
|
||||
return returnDocuments;
|
||||
}
|
||||
|
||||
public Integer topN() {
|
||||
return topN;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isEmpty() {
|
||||
return returnDocuments == null && topN == null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeOptionalBoolean(returnDocuments);
|
||||
out.writeOptionalVInt(topN);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
|
||||
if (returnDocuments != null) {
|
||||
builder.field(RETURN_DOCUMENTS_FIELD, returnDocuments);
|
||||
}
|
||||
if (topN != null) {
|
||||
builder.field(TOP_N_FIELD, topN);
|
||||
}
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "AzureAiStudioRerankTaskSettings{" + ", returnDocuments=" + returnDocuments + ", topN=" + topN + '}';
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
AzureAiStudioRerankTaskSettings that = (AzureAiStudioRerankTaskSettings) o;
|
||||
return Objects.equals(returnDocuments, that.returnDocuments) && Objects.equals(topN, that.topN);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(returnDocuments, topN);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
|
||||
AzureAiStudioRerankRequestTaskSettings requestSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(newSettings));
|
||||
return of(this, requestSettings);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,128 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.azureaistudio.response;
|
||||
|
||||
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.xcontent.ParseField;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentParser;
|
||||
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.external.response.BaseResponseEntity;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class AzureAiStudioRerankResponseEntity extends BaseResponseEntity {
|
||||
/**
|
||||
* Parses the AzureAiStudio Search rerank json response.
|
||||
* For a request like:
|
||||
*
|
||||
* <pre>
|
||||
* <code>
|
||||
* {
|
||||
* "model": "rerank-v3.5",
|
||||
* "query": "What is the capital of the United States?",
|
||||
* "top_n": 2,
|
||||
* "documents": ["Carson City is the capital city of the American state of Nevada.",
|
||||
* "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean."]
|
||||
* }
|
||||
* </code>
|
||||
* </pre>
|
||||
*
|
||||
* The response would look like:
|
||||
*
|
||||
* <pre>
|
||||
* <code>
|
||||
* {
|
||||
* "id": "ff2feb42-5d3a-45d7-ba29-c3dabf59988b",
|
||||
* "results": [
|
||||
* {
|
||||
* "document": {
|
||||
* "text": "Carson City is the capital city of the American state of Nevada."
|
||||
* },
|
||||
* "index": 0,
|
||||
* "relevance_score": 0.1728413
|
||||
* },
|
||||
* {
|
||||
* "document": {
|
||||
* "text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean."
|
||||
* },
|
||||
* "index": 1,
|
||||
* "relevance_score": 0.031005697
|
||||
* }
|
||||
* ],
|
||||
* "meta": {
|
||||
* "api_version": {
|
||||
* "version": "1"
|
||||
* },
|
||||
* "billed_units": {
|
||||
* "search_units": 1
|
||||
* }
|
||||
* }
|
||||
* }
|
||||
* </code>
|
||||
* </pre>
|
||||
*/
|
||||
@Override
|
||||
protected InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
|
||||
final var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
|
||||
|
||||
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
|
||||
var rerankResult = RerankResult.PARSER.apply(jsonParser, null);
|
||||
return new RankedDocsResults(rerankResult.entries.stream().map(RerankResultEntry::toRankedDoc).toList());
|
||||
}
|
||||
}
|
||||
|
||||
record RerankResult(List<RerankResultEntry> entries) {
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<RerankResult, Void> PARSER = new ConstructingObjectParser<>(
|
||||
RerankResult.class.getSimpleName(),
|
||||
true,
|
||||
args -> new RerankResult((List<RerankResultEntry>) args[0])
|
||||
);
|
||||
static {
|
||||
PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("results"));
|
||||
}
|
||||
}
|
||||
|
||||
record RerankResultEntry(Float relevanceScore, Integer index, @Nullable ObjectParser document) {
|
||||
|
||||
public static final ConstructingObjectParser<RerankResultEntry, Void> PARSER = new ConstructingObjectParser<>(
|
||||
RerankResultEntry.class.getSimpleName(),
|
||||
args -> new RerankResultEntry((Float) args[0], (Integer) args[1], (ObjectParser) args[2])
|
||||
);
|
||||
static {
|
||||
PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
|
||||
PARSER.declareInt(constructorArg(), new ParseField("index"));
|
||||
PARSER.declareObject(optionalConstructorArg(), ObjectParser.PARSER::apply, new ParseField("document"));
|
||||
}
|
||||
public RankedDocsResults.RankedDoc toRankedDoc() {
|
||||
return new RankedDocsResults.RankedDoc(index, relevanceScore, document == null ? null : document.text);
|
||||
}
|
||||
}
|
||||
|
||||
record ObjectParser(String text) {
|
||||
public static final ConstructingObjectParser<ObjectParser, Void> PARSER = new ConstructingObjectParser<>(
|
||||
ObjectParser.class.getSimpleName(),
|
||||
args -> new AzureAiStudioRerankResponseEntity.ObjectParser((String) args[0])
|
||||
);
|
||||
static {
|
||||
PARSER.declareString(optionalConstructorArg(), new ParseField("text"));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -38,6 +38,7 @@ import org.elasticsearch.xcontent.XContentType;
|
|||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
|
||||
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||
|
@ -54,6 +55,10 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.Azure
|
|||
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettingsTests;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettingsTests;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankServiceSettingsTests;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettingsTests;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
|
||||
import org.hamcrest.CoreMatchers;
|
||||
import org.hamcrest.Matchers;
|
||||
|
@ -219,6 +224,33 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_CreatesAnAzureAiStudioRerankModel() throws IOException {
|
||||
try (var service = createService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
|
||||
assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
|
||||
|
||||
var rerankModel = (AzureAiStudioRerankModel) model;
|
||||
assertThat(rerankModel.getServiceSettings().target(), is("http://target.local"));
|
||||
assertThat(rerankModel.getServiceSettings().provider(), is(AzureAiStudioProvider.COHERE));
|
||||
assertThat(rerankModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
|
||||
assertThat(rerankModel.getSecretSettings().apiKey().toString(), is("secret"));
|
||||
assertNull(rerankModel.getTaskSettings().returnDocuments());
|
||||
assertNull(rerankModel.getTaskSettings().topN());
|
||||
}, exception -> fail("Unexpected exception: " + exception));
|
||||
|
||||
service.parseRequestConfig(
|
||||
"id",
|
||||
TaskType.RERANK,
|
||||
getRequestConfigMap(
|
||||
getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
|
||||
getRerankTaskSettingsMap(null, null),
|
||||
getSecretSettingsMap("secret")
|
||||
),
|
||||
modelVerificationListener
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
|
||||
try (var service = createService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
|
@ -441,6 +473,80 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankServiceSettingsMap() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
|
||||
serviceSettings.put("extra_key", "value");
|
||||
|
||||
var config = getRequestConfigMap(serviceSettings, getRerankTaskSettingsMap(null, null), getSecretSettingsMap("secret"));
|
||||
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
model -> fail("Expected exception, but got model: " + model),
|
||||
exception -> {
|
||||
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||
assertThat(
|
||||
exception.getMessage(),
|
||||
is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankTaskSettingsMap() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var taskSettings = getRerankTaskSettingsMap(null, null);
|
||||
taskSettings.put("extra_key", "value");
|
||||
|
||||
var config = getRequestConfigMap(
|
||||
getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
|
||||
taskSettings,
|
||||
getSecretSettingsMap("secret")
|
||||
);
|
||||
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
model -> fail("Expected exception, but got model: " + model),
|
||||
exception -> {
|
||||
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||
assertThat(
|
||||
exception.getMessage(),
|
||||
is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankSecretSettingsMap() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var secretSettings = getSecretSettingsMap("secret");
|
||||
secretSettings.put("extra_key", "value");
|
||||
|
||||
var config = getRequestConfigMap(
|
||||
getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
|
||||
getRerankTaskSettingsMap(null, null),
|
||||
secretSettings
|
||||
);
|
||||
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
model -> fail("Expected exception, but got model: " + model),
|
||||
exception -> {
|
||||
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||
assertThat(
|
||||
exception.getMessage(),
|
||||
is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsWhenProviderIsNotValidForEmbeddings() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var serviceSettings = getEmbeddingsServiceSettingsMap("http://target.local", "databricks", "token", null, null, null, null);
|
||||
|
@ -505,6 +611,45 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsWhenProviderIsNotValidForRerank() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var serviceSettings = getRerankServiceSettingsMap("http://target.local", "databricks", "token");
|
||||
|
||||
var config = getRequestConfigMap(serviceSettings, getRerankTaskSettingsMap(null, null), getSecretSettingsMap("secret"));
|
||||
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
model -> fail("Expected exception, but got model: " + model),
|
||||
exception -> {
|
||||
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||
assertThat(exception.getMessage(), is("The [rerank] task type for provider [databricks] is not available"));
|
||||
}
|
||||
);
|
||||
|
||||
service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsWhenEndpointTypeIsNotValidForRerankProvider() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "realtime");
|
||||
|
||||
var config = getRequestConfigMap(serviceSettings, getRerankTaskSettingsMap(null, null), getSecretSettingsMap("secret"));
|
||||
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
model -> fail("Expected exception, but got model: " + model),
|
||||
exception -> {
|
||||
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||
assertThat(
|
||||
exception.getMessage(),
|
||||
is("The [realtime] endpoint type with [rerank] task type for provider [cohere] is not available")
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_CreatesAnAzureAiStudioEmbeddingsModel() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var config = getPersistedConfigMap(
|
||||
|
@ -603,6 +748,27 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_CreatesAnAzureAiStudioRerankModel() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var config = getPersistedConfigMap(
|
||||
getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
|
||||
getRerankTaskSettingsMap(true, 2),
|
||||
getSecretSettingsMap("secret")
|
||||
);
|
||||
|
||||
var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
|
||||
|
||||
assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
|
||||
|
||||
var chatCompletionModel = (AzureAiStudioRerankModel) model;
|
||||
assertThat(chatCompletionModel.getServiceSettings().target(), is("http://target.local"));
|
||||
assertThat(chatCompletionModel.getServiceSettings().provider(), is(AzureAiStudioProvider.COHERE));
|
||||
assertThat(chatCompletionModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
|
||||
assertThat(chatCompletionModel.getTaskSettings().returnDocuments(), is(true));
|
||||
assertThat(chatCompletionModel.getTaskSettings().topN(), is(2));
|
||||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_ThrowsUnsupportedModelType() throws IOException {
|
||||
try (var service = createService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
|
@ -747,6 +913,48 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankServiceSettingsMap() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
|
||||
serviceSettings.put("extra_key", "value");
|
||||
var taskSettings = getRerankTaskSettingsMap(true, 2);
|
||||
var secretSettings = getSecretSettingsMap("secret");
|
||||
var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
|
||||
|
||||
var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
|
||||
|
||||
assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
|
||||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankTaskSettingsMap() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
|
||||
var taskSettings = getRerankTaskSettingsMap(true, 2);
|
||||
taskSettings.put("extra_key", "value");
|
||||
var secretSettings = getSecretSettingsMap("secret");
|
||||
var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
|
||||
|
||||
var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
|
||||
|
||||
assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
|
||||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankSecretSettingsMap() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
|
||||
var taskSettings = getRerankTaskSettingsMap(true, 2);
|
||||
var secretSettings = getSecretSettingsMap("secret");
|
||||
secretSettings.put("extra_key", "value");
|
||||
var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
|
||||
|
||||
var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
|
||||
|
||||
assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
|
||||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var config = getPersistedConfigMap(
|
||||
|
@ -842,6 +1050,27 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testParsePersistedConfig_WithoutSecretsCreatesRerankModel() throws IOException {
|
||||
try (var service = createService()) {
|
||||
var config = getPersistedConfigMap(
|
||||
getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
|
||||
getRerankTaskSettingsMap(true, 2),
|
||||
Map.of()
|
||||
);
|
||||
|
||||
var model = service.parsePersistedConfig("id", TaskType.RERANK, config.config());
|
||||
|
||||
assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
|
||||
|
||||
var rerankModel = (AzureAiStudioRerankModel) model;
|
||||
assertThat(rerankModel.getServiceSettings().target(), is("http://target.local"));
|
||||
assertThat(rerankModel.getServiceSettings().provider(), is(AzureAiStudioProvider.COHERE));
|
||||
assertThat(rerankModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
|
||||
assertThat(rerankModel.getTaskSettings().returnDocuments(), is(true));
|
||||
assertThat(rerankModel.getTaskSettings().topN(), is(2));
|
||||
}
|
||||
}
|
||||
|
||||
public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||
|
@ -1184,6 +1413,47 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testInfer_WithRerankModel() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testRerankTokenResponseJson));
|
||||
|
||||
var model = AzureAiStudioRerankModelTests.createModel(
|
||||
"id",
|
||||
getUrl(webServer),
|
||||
AzureAiStudioProvider.COHERE,
|
||||
AzureAiStudioEndpointType.TOKEN,
|
||||
"apikey"
|
||||
);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
service.infer(
|
||||
model,
|
||||
"query",
|
||||
false,
|
||||
2,
|
||||
List.of("abc"),
|
||||
false,
|
||||
new HashMap<>(),
|
||||
InputType.INGEST,
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||
listener
|
||||
);
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
assertThat(result, CoreMatchers.instanceOf(RankedDocsResults.class));
|
||||
|
||||
var rankedDocsResults = (RankedDocsResults) result;
|
||||
var rankedDocs = rankedDocsResults.getRankedDocs();
|
||||
assertThat(rankedDocs.size(), is(2));
|
||||
assertThat(rankedDocs.get(0).relevanceScore(), is(0.1111111F));
|
||||
assertThat(rankedDocs.get(0).index(), is(0));
|
||||
assertThat(rankedDocs.get(1).relevanceScore(), is(0.2222222F));
|
||||
assertThat(rankedDocs.get(1).index(), is(1));
|
||||
}
|
||||
}
|
||||
|
||||
public void testInfer_UnauthorisedResponse() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
|
@ -1320,7 +1590,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
{
|
||||
"service": "azureaistudio",
|
||||
"name": "Azure AI Studio",
|
||||
"task_types": ["text_embedding", "completion"],
|
||||
"task_types": ["text_embedding", "rerank", "completion"],
|
||||
"configurations": {
|
||||
"dimensions": {
|
||||
"description": "The number of dimensions the resulting embeddings should have. For more information refer to https://learn.microsoft.com/en-us/azure/ai-studio/reference/reference-model-inference-embeddings.",
|
||||
|
@ -1338,7 +1608,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
"sensitive": false,
|
||||
"updatable": false,
|
||||
"type": "str",
|
||||
"supported_task_types": ["text_embedding", "completion"]
|
||||
"supported_task_types": ["text_embedding", "rerank", "completion"]
|
||||
},
|
||||
"provider": {
|
||||
"description": "The model provider for your deployment.",
|
||||
|
@ -1347,7 +1617,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
"sensitive": false,
|
||||
"updatable": false,
|
||||
"type": "str",
|
||||
"supported_task_types": ["text_embedding", "completion"]
|
||||
"supported_task_types": ["text_embedding", "rerank", "completion"]
|
||||
},
|
||||
"api_key": {
|
||||
"description": "API Key for the provider you're connecting to.",
|
||||
|
@ -1356,7 +1626,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
"sensitive": true,
|
||||
"updatable": true,
|
||||
"type": "str",
|
||||
"supported_task_types": ["text_embedding", "completion"]
|
||||
"supported_task_types": ["text_embedding", "rerank", "completion"]
|
||||
},
|
||||
"rate_limit.requests_per_minute": {
|
||||
"description": "Minimize the number of rate limit errors.",
|
||||
|
@ -1365,7 +1635,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
"sensitive": false,
|
||||
"updatable": false,
|
||||
"type": "int",
|
||||
"supported_task_types": ["text_embedding", "completion"]
|
||||
"supported_task_types": ["text_embedding", "rerank", "completion"]
|
||||
},
|
||||
"target": {
|
||||
"description": "The target URL of your Azure AI Studio model deployment.",
|
||||
|
@ -1374,7 +1644,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
"sensitive": false,
|
||||
"updatable": false,
|
||||
"type": "str",
|
||||
"supported_task_types": ["text_embedding", "completion"]
|
||||
"supported_task_types": ["text_embedding", "rerank", "completion"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1462,6 +1732,10 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
return AzureAiStudioChatCompletionServiceSettingsTests.createRequestSettingsMap(target, provider, endpointType);
|
||||
}
|
||||
|
||||
private static HashMap<String, Object> getRerankServiceSettingsMap(String target, String provider, String endpointType) {
|
||||
return AzureAiStudioRerankServiceSettingsTests.createRequestSettingsMap(target, provider, endpointType);
|
||||
}
|
||||
|
||||
public static Map<String, Object> getChatCompletionTaskSettingsMap(
|
||||
@Nullable Double temperature,
|
||||
@Nullable Double topP,
|
||||
|
@ -1471,6 +1745,10 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
return AzureAiStudioChatCompletionTaskSettingsTests.getTaskSettingsMap(temperature, topP, doSample, maxNewTokens);
|
||||
}
|
||||
|
||||
public static Map<String, Object> getRerankTaskSettingsMap(@Nullable Boolean returnDocuments, @Nullable Integer topN) {
|
||||
return AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap(returnDocuments, topN);
|
||||
}
|
||||
|
||||
private static Map<String, Object> getSecretSettingsMap(String apiKey) {
|
||||
return new HashMap<>(Map.of(API_KEY_FIELD, apiKey));
|
||||
}
|
||||
|
@ -1520,4 +1798,28 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
""";
|
||||
|
||||
private static final String testRerankTokenResponseJson = """
|
||||
{
|
||||
"id": "ff2feb42-5d3a-45d7-ba29-c3dabf59988b",
|
||||
"results": [
|
||||
{
|
||||
"index": 0,
|
||||
"relevance_score": 0.1111111
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"relevance_score": 0.2222222
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"api_version": {
|
||||
"version": "1"
|
||||
},
|
||||
"billed_units": {
|
||||
"search_units": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
""";
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import org.elasticsearch.test.http.MockWebServer;
|
|||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
|
||||
import org.elasticsearch.xpack.inference.InputTypeTests;
|
||||
import org.elasticsearch.xpack.inference.common.TruncatorTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||
|
@ -27,6 +28,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInpu
|
|||
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
|
||||
|
@ -34,6 +36,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEnd
|
|||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModelTests;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
|
@ -78,31 +81,20 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testEmbeddingsRequestAction() throws IOException {
|
||||
var senderFactory = new HttpRequestSender.Factory(
|
||||
final var senderFactory = new HttpRequestSender.Factory(
|
||||
ServiceComponentsTests.createWithEmptySettings(threadPool),
|
||||
clientManager,
|
||||
mockClusterServiceEmpty()
|
||||
);
|
||||
|
||||
var timeoutSettings = buildSettingsWithRetryFields(
|
||||
TimeValue.timeValueMillis(1),
|
||||
TimeValue.timeValueMinutes(1),
|
||||
TimeValue.timeValueSeconds(0)
|
||||
);
|
||||
|
||||
var serviceComponents = new ServiceComponents(
|
||||
threadPool,
|
||||
mock(ThrottlerManager.class),
|
||||
timeoutSettings,
|
||||
TruncatorTests.createTruncator()
|
||||
);
|
||||
final var serviceComponents = getServiceComponents();
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingsTokenResponseJson));
|
||||
|
||||
var model = AzureAiStudioEmbeddingsModelTests.createModel(
|
||||
final var model = AzureAiStudioEmbeddingsModelTests.createModel(
|
||||
"id",
|
||||
"http://will-be-replaced.local",
|
||||
AzureAiStudioProvider.OPENAI,
|
||||
|
@ -111,21 +103,18 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
|
|||
);
|
||||
model.setURI(getUrl(webServer));
|
||||
|
||||
var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
|
||||
var action = creator.create(model, Map.of());
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
var inputType = InputTypeTests.randomSearchAndIngestWithNull();
|
||||
final var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
|
||||
final var action = creator.create(model, Map.of());
|
||||
final PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
final var inputType = InputTypeTests.randomSearchAndIngestWithNull();
|
||||
action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
final var result = listener.actionGet(TIMEOUT);
|
||||
|
||||
assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F }))));
|
||||
assertThat(webServer.requests(), hasSize(1));
|
||||
assertNull(webServer.requests().get(0).getUri().getQuery());
|
||||
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
|
||||
assertThat(webServer.requests().get(0).getHeader(API_KEY_HEADER), equalTo("apikey"));
|
||||
assertWebServerRequest(API_KEY_HEADER, "apikey");
|
||||
|
||||
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
||||
final var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
||||
assertThat(requestMap.size(), is(InputType.isSpecified(inputType) ? 2 : 1));
|
||||
assertThat(requestMap.get("input"), is(List.of("abc")));
|
||||
if (InputType.isSpecified(inputType)) {
|
||||
|
@ -136,27 +125,15 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testChatCompletionRequestAction() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
var timeoutSettings = buildSettingsWithRetryFields(
|
||||
TimeValue.timeValueMillis(1),
|
||||
TimeValue.timeValueMinutes(1),
|
||||
TimeValue.timeValueSeconds(0)
|
||||
);
|
||||
|
||||
var serviceComponents = new ServiceComponents(
|
||||
threadPool,
|
||||
mock(ThrottlerManager.class),
|
||||
timeoutSettings,
|
||||
TruncatorTests.createTruncator()
|
||||
);
|
||||
final var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
final var serviceComponents = getServiceComponents();
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testCompletionTokenResponseJson));
|
||||
var webserverUrl = getUrl(webServer);
|
||||
var model = AzureAiStudioChatCompletionModelTests.createModel(
|
||||
final var webserverUrl = getUrl(webServer);
|
||||
final var model = AzureAiStudioChatCompletionModelTests.createModel(
|
||||
"id",
|
||||
"http://will-be-replaced.local",
|
||||
AzureAiStudioProvider.COHERE,
|
||||
|
@ -165,30 +142,101 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
|
|||
);
|
||||
model.setURI(webserverUrl);
|
||||
|
||||
var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
|
||||
var action = creator.create(model, Map.of());
|
||||
final var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
|
||||
final var action = creator.create(model, Map.of());
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
final PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
final var result = listener.actionGet(TIMEOUT);
|
||||
|
||||
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test input string"))));
|
||||
assertThat(webServer.requests(), hasSize(1));
|
||||
|
||||
MockRequest request = webServer.requests().get(0);
|
||||
assertWebServerRequest(HttpHeaders.AUTHORIZATION, "apikey");
|
||||
|
||||
assertNull(request.getUri().getQuery());
|
||||
assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
|
||||
assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("apikey"));
|
||||
|
||||
var requestMap = entityAsMap(request.getBody());
|
||||
final MockRequest request = webServer.requests().get(0);
|
||||
final var requestMap = entityAsMap(request.getBody());
|
||||
assertThat(requestMap.size(), is(1));
|
||||
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc"))));
|
||||
}
|
||||
}
|
||||
|
||||
private static String testEmbeddingsTokenResponseJson = """
|
||||
public void testRerankRequestAction() throws IOException {
|
||||
final var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
final var serviceComponents = getServiceComponents();
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testRerankTokenResponseJson));
|
||||
final var webserverUrl = getUrl(webServer);
|
||||
final var model = AzureAiStudioRerankModelTests.createModel(
|
||||
"id",
|
||||
"http://will-be-replaced.local",
|
||||
AzureAiStudioProvider.COHERE,
|
||||
AzureAiStudioEndpointType.TOKEN,
|
||||
"apikey"
|
||||
);
|
||||
model.setURI(webserverUrl);
|
||||
|
||||
final var topN = 2;
|
||||
final var returnDocuments = false;
|
||||
final var query = "query";
|
||||
final var documents = List.of("document 1", "document 2", "document 3");
|
||||
|
||||
final var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
|
||||
final var action = creator.create(model, Map.of());
|
||||
|
||||
final PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(
|
||||
new QueryAndDocsInputs(query, documents, returnDocuments, topN, false),
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||
listener
|
||||
);
|
||||
|
||||
final var result = listener.actionGet(TIMEOUT);
|
||||
|
||||
assertThat(
|
||||
result.asMap(),
|
||||
equalTo(
|
||||
RankedDocsResultsTests.buildExpectationRerank(
|
||||
List.of(
|
||||
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", 0.1111111f)),
|
||||
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 0.2222222f))
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
assertWebServerRequest(HttpHeaders.AUTHORIZATION, "apikey");
|
||||
|
||||
final var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
||||
|
||||
assertThat(requestMap.size(), is(4));
|
||||
assertThat(requestMap.get("documents"), is(documents));
|
||||
assertThat(requestMap.get("query"), is(query));
|
||||
assertThat(requestMap.get("top_n"), is(topN));
|
||||
assertThat(requestMap.get("return_documents"), is(returnDocuments));
|
||||
}
|
||||
}
|
||||
|
||||
private void assertWebServerRequest(String authorization, String authorizationHeaderValue) {
|
||||
assertThat(webServer.requests(), hasSize(1));
|
||||
assertNull(webServer.requests().get(0).getUri().getQuery());
|
||||
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
|
||||
assertThat(webServer.requests().get(0).getHeader(authorization), equalTo(authorizationHeaderValue));
|
||||
}
|
||||
|
||||
private ServiceComponents getServiceComponents() {
|
||||
final var timeoutSettings = buildSettingsWithRetryFields(
|
||||
TimeValue.timeValueMillis(1),
|
||||
TimeValue.timeValueMinutes(1),
|
||||
TimeValue.timeValueSeconds(0)
|
||||
);
|
||||
return new ServiceComponents(threadPool, mock(ThrottlerManager.class), timeoutSettings, TruncatorTests.createTruncator());
|
||||
}
|
||||
|
||||
private final String testEmbeddingsTokenResponseJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
|
@ -209,7 +257,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
|
|||
}
|
||||
""";
|
||||
|
||||
private static String testCompletionTokenResponseJson = """
|
||||
private final String testCompletionTokenResponseJson = """
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
|
@ -233,4 +281,27 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
|
|||
}
|
||||
}""";
|
||||
|
||||
private final String testRerankTokenResponseJson = """
|
||||
{
|
||||
"id": "ff2feb42-5d3a-45d7-ba29-c3dabf59988b",
|
||||
"results": [
|
||||
{
|
||||
"index": 0,
|
||||
"relevance_score": 0.1111111
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"relevance_score": 0.2222222
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"api_version": {
|
||||
"version": "1"
|
||||
},
|
||||
"billed_units": {
|
||||
"search_units": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
""";
|
||||
}
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.azureaistudio.request;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.ToXContent;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace;
|
||||
|
||||
public class AzureAiStudioRerankRequestEntityTests extends ESTestCase {
|
||||
private static final String INPUT = "texts";
|
||||
private static final String QUERY = "query";
|
||||
private static final Boolean RETURN_DOCUMENTS = false;
|
||||
private static final Integer TOP_N = 8;
|
||||
|
||||
public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
|
||||
final var entity = new AzureAiStudioRerankRequestEntity(
|
||||
QUERY,
|
||||
List.of(INPUT),
|
||||
Boolean.TRUE,
|
||||
TOP_N,
|
||||
new AzureAiStudioRerankTaskSettings(RETURN_DOCUMENTS, TOP_N)
|
||||
);
|
||||
|
||||
final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
final String xContentResult = Strings.toString(builder);
|
||||
final String expected = """
|
||||
{"documents":["texts"],
|
||||
"query":"query",
|
||||
"return_documents":true,
|
||||
"top_n":8}""";
|
||||
assertEquals(stripWhitespace(expected), xContentResult);
|
||||
}
|
||||
|
||||
public void testXContent_WritesMinimalFields() throws IOException {
|
||||
final var entity = new AzureAiStudioRerankRequestEntity(
|
||||
QUERY,
|
||||
List.of(INPUT),
|
||||
null,
|
||||
null,
|
||||
new AzureAiStudioRerankTaskSettings(null, null)
|
||||
);
|
||||
|
||||
final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
final String xContentResult = Strings.toString(builder);
|
||||
final String expected = """
|
||||
{"documents":["texts"],"query":"query"}""";
|
||||
assertEquals(stripWhitespace(expected), xContentResult);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,159 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.azureaistudio.request;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModelTests;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
|
||||
import static org.elasticsearch.xpack.inference.services.azureopenai.request.AzureOpenAiUtils.API_KEY_HEADER;
|
||||
import static org.hamcrest.Matchers.aMapWithSize;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class AzureAiStudioRerankRequestTests extends ESTestCase {
|
||||
private static final String TARGET_URI = "http://testtarget.local";
|
||||
private static final String INPUT = "documents";
|
||||
private static final String QUERY = "query";
|
||||
private static final Integer TOP_N = 2;
|
||||
|
||||
public void testCreateRequest_WithCohereProviderTokenEndpoint_NoParams() throws IOException {
|
||||
final var input = randomAlphaOfLength(3);
|
||||
final var query = randomAlphaOfLength(3);
|
||||
final var apikey = randomAlphaOfLength(3);
|
||||
final var request = createRequest(TARGET_URI, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, apikey, query, input);
|
||||
final var httpPost = getHttpPost(request, apikey);
|
||||
final var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
assertThat(requestMap, aMapWithSize(2));
|
||||
assertThat(requestMap.get(QUERY), is(query));
|
||||
assertThat(requestMap.get(INPUT), is(List.of(input)));
|
||||
}
|
||||
|
||||
public void testCreateRequest_WithCohereProviderTokenEndpoint_WithTopNParam() throws IOException {
|
||||
final var input = randomAlphaOfLength(3);
|
||||
final var query = randomAlphaOfLength(3);
|
||||
final var apikey = randomAlphaOfLength(3);
|
||||
final var request = createRequest(
|
||||
TARGET_URI,
|
||||
AzureAiStudioProvider.COHERE,
|
||||
AzureAiStudioEndpointType.TOKEN,
|
||||
apikey,
|
||||
null,
|
||||
TOP_N,
|
||||
query,
|
||||
input
|
||||
);
|
||||
final var httpPost = getHttpPost(request, apikey);
|
||||
final var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
assertThat(requestMap, aMapWithSize(3));
|
||||
assertThat(requestMap.get(QUERY), is(query));
|
||||
assertThat(requestMap.get(INPUT), is(List.of(input)));
|
||||
assertThat(requestMap.get(TOP_N_FIELD), is(TOP_N));
|
||||
}
|
||||
|
||||
public void testCreateRequest_WithCohereProviderTokenEndpoint_WithReturnDocumentsParam() throws IOException {
|
||||
final var input = randomAlphaOfLength(3);
|
||||
final var query = randomAlphaOfLength(3);
|
||||
final var apikey = randomAlphaOfLength(3);
|
||||
final var request = createRequest(
|
||||
TARGET_URI,
|
||||
AzureAiStudioProvider.COHERE,
|
||||
AzureAiStudioEndpointType.TOKEN,
|
||||
apikey,
|
||||
true,
|
||||
null,
|
||||
query,
|
||||
input
|
||||
);
|
||||
final var httpPost = getHttpPost(request, apikey);
|
||||
final var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
assertThat(requestMap, aMapWithSize(3));
|
||||
assertThat(requestMap.get(QUERY), is(query));
|
||||
assertThat(requestMap.get(INPUT), is(List.of(input)));
|
||||
assertThat(requestMap.get(RETURN_DOCUMENTS_FIELD), is(true));
|
||||
}
|
||||
|
||||
private HttpPost getHttpPost(AzureAiStudioRerankRequest request, String apikey) {
|
||||
final var httpRequest = request.createHttpRequest();
|
||||
|
||||
final var httpPost = validateRequestUrlAndContentType(httpRequest, TARGET_URI + "/v1/rerank");
|
||||
validateRequestApiKey(httpPost, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, apikey);
|
||||
return httpPost;
|
||||
}
|
||||
|
||||
private HttpPost validateRequestUrlAndContentType(HttpRequest request, String expectedUrl) {
|
||||
assertThat(request.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
final var httpPost = (HttpPost) request.httpRequestBase();
|
||||
assertThat(httpPost.getURI().toString(), is(expectedUrl));
|
||||
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||
return httpPost;
|
||||
}
|
||||
|
||||
private void validateRequestApiKey(
|
||||
HttpPost httpPost,
|
||||
AzureAiStudioProvider provider,
|
||||
AzureAiStudioEndpointType endpointType,
|
||||
String apiKey
|
||||
) {
|
||||
if (endpointType == AzureAiStudioEndpointType.TOKEN) {
|
||||
if (provider == AzureAiStudioProvider.OPENAI) {
|
||||
assertThat(httpPost.getLastHeader(API_KEY_HEADER).getValue(), is(apiKey));
|
||||
} else {
|
||||
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(apiKey));
|
||||
}
|
||||
} else {
|
||||
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + apiKey));
|
||||
}
|
||||
}
|
||||
|
||||
public static AzureAiStudioRerankRequest createRequest(
|
||||
String target,
|
||||
AzureAiStudioProvider provider,
|
||||
AzureAiStudioEndpointType endpointType,
|
||||
String apiKey,
|
||||
String query,
|
||||
String input
|
||||
) {
|
||||
return createRequest(target, provider, endpointType, apiKey, null, null, query, input);
|
||||
}
|
||||
|
||||
public static AzureAiStudioRerankRequest createRequest(
|
||||
String target,
|
||||
AzureAiStudioProvider provider,
|
||||
AzureAiStudioEndpointType endpointType,
|
||||
String apiKey,
|
||||
@Nullable Boolean returnDocuments,
|
||||
@Nullable Integer topN,
|
||||
String query,
|
||||
String input
|
||||
) {
|
||||
final var model = AzureAiStudioRerankModelTests.createModel(
|
||||
"id",
|
||||
target,
|
||||
provider,
|
||||
endpointType,
|
||||
apiKey,
|
||||
returnDocuments,
|
||||
topN,
|
||||
null
|
||||
);
|
||||
return new AzureAiStudioRerankRequest(model, query, List.of(input), returnDocuments, topN);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,130 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
|
||||
|
||||
import org.elasticsearch.common.settings.SecureString;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
|
||||
import java.net.URISyntaxException;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.sameInstance;
|
||||
|
||||
public class AzureAiStudioRerankModelTests extends ESTestCase {
|
||||
private static final String MODEL_ID = "id";
|
||||
private static final String TARGET_URI = "http://testtarget.local";
|
||||
private static final String API_KEY = "apikey";
|
||||
private static final Integer TOP_N = 1;
|
||||
private static final Integer TOP_N_OVERRIDE = 2;
|
||||
|
||||
public void testOverrideWith_OverridesWithoutValues() {
|
||||
final var model = createModel(
|
||||
MODEL_ID,
|
||||
TARGET_URI,
|
||||
AzureAiStudioProvider.COHERE,
|
||||
AzureAiStudioEndpointType.TOKEN,
|
||||
API_KEY,
|
||||
true,
|
||||
TOP_N,
|
||||
null
|
||||
);
|
||||
final var requestTaskSettingsMap = getTaskSettingsMap(null, null);
|
||||
final var overriddenModel = AzureAiStudioRerankModel.of(model, requestTaskSettingsMap);
|
||||
|
||||
assertThat(overriddenModel, sameInstance(overriddenModel));
|
||||
}
|
||||
|
||||
public void testOverrideWith_returnDocuments() {
|
||||
final var model = createModel(
|
||||
MODEL_ID,
|
||||
TARGET_URI,
|
||||
AzureAiStudioProvider.COHERE,
|
||||
AzureAiStudioEndpointType.TOKEN,
|
||||
API_KEY,
|
||||
true,
|
||||
null,
|
||||
null
|
||||
);
|
||||
final var requestTaskSettings = AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap(false, null);
|
||||
final var overriddenModel = AzureAiStudioRerankModel.of(model, requestTaskSettings);
|
||||
|
||||
assertThat(
|
||||
overriddenModel,
|
||||
is(createModel(MODEL_ID, TARGET_URI, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, API_KEY, false, null, null))
|
||||
);
|
||||
}
|
||||
|
||||
public void testOverrideWith_topN() {
|
||||
final var model = createModel(
|
||||
MODEL_ID,
|
||||
TARGET_URI,
|
||||
AzureAiStudioProvider.COHERE,
|
||||
AzureAiStudioEndpointType.TOKEN,
|
||||
API_KEY,
|
||||
null,
|
||||
TOP_N,
|
||||
null
|
||||
);
|
||||
final var requestTaskSettings = AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap(null, TOP_N_OVERRIDE);
|
||||
final var overriddenModel = AzureAiStudioRerankModel.of(model, requestTaskSettings);
|
||||
assertThat(
|
||||
overriddenModel,
|
||||
is(
|
||||
createModel(
|
||||
MODEL_ID,
|
||||
TARGET_URI,
|
||||
AzureAiStudioProvider.COHERE,
|
||||
AzureAiStudioEndpointType.TOKEN,
|
||||
API_KEY,
|
||||
null,
|
||||
TOP_N_OVERRIDE,
|
||||
null
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public void testSetsProperUrlForCohereTokenModel() throws URISyntaxException {
|
||||
final var model = createModel(MODEL_ID, TARGET_URI, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, API_KEY);
|
||||
assertThat(model.getEndpointUri().toString(), is(TARGET_URI + "/v1/rerank"));
|
||||
}
|
||||
|
||||
public static AzureAiStudioRerankModel createModel(
|
||||
String id,
|
||||
String target,
|
||||
AzureAiStudioProvider provider,
|
||||
AzureAiStudioEndpointType endpointType,
|
||||
String apiKey
|
||||
) {
|
||||
return createModel(id, target, provider, endpointType, apiKey, null, null, null);
|
||||
}
|
||||
|
||||
public static AzureAiStudioRerankModel createModel(
|
||||
String id,
|
||||
String target,
|
||||
AzureAiStudioProvider provider,
|
||||
AzureAiStudioEndpointType endpointType,
|
||||
String apiKey,
|
||||
@Nullable Boolean returnDocuments,
|
||||
@Nullable Integer topN,
|
||||
@Nullable RateLimitSettings rateLimitSettings
|
||||
) {
|
||||
return new AzureAiStudioRerankModel(
|
||||
id,
|
||||
new AzureAiStudioRerankServiceSettings(target, provider, endpointType, rateLimitSettings),
|
||||
new AzureAiStudioRerankTaskSettings(returnDocuments, topN),
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class AzureAiStudioRerankRequestTaskSettingsTests extends ESTestCase {
|
||||
private static final String INVALID_FIELD_TYPE_STRING = "invalid";
|
||||
private static final boolean RETURN_DOCUMENTS = true;
|
||||
private static final int TOP_N = 2;
|
||||
|
||||
public void testFromMap_ReturnsEmptySettings_WhenTheMapIsEmpty() {
|
||||
assertThat(
|
||||
AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of())),
|
||||
is(AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS)
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() {
|
||||
assertThat(
|
||||
AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "model"))),
|
||||
is(AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS)
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_ReturnsReturnDocuments() {
|
||||
assertThat(
|
||||
AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(RETURN_DOCUMENTS_FIELD, RETURN_DOCUMENTS))),
|
||||
is(new AzureAiStudioRerankRequestTaskSettings(RETURN_DOCUMENTS, null))
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_ReturnsTopN() {
|
||||
assertThat(
|
||||
AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(TOP_N_FIELD, TOP_N))),
|
||||
is(new AzureAiStudioRerankRequestTaskSettings(null, TOP_N))
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_ReturnDocumentsIsInvalidValue_ThrowsValidationException() {
|
||||
assertThrowsValidationExceptionIfStringValueProvidedFor(RETURN_DOCUMENTS_FIELD);
|
||||
}
|
||||
|
||||
public void testFromMap_TopNIsInvalidValue_ThrowsValidationException() {
|
||||
assertThrowsValidationExceptionIfStringValueProvidedFor(TOP_N_FIELD);
|
||||
}
|
||||
|
||||
private void assertThrowsValidationExceptionIfStringValueProvidedFor(String field) {
|
||||
final var thrownException = expectThrows(
|
||||
ValidationException.class,
|
||||
() -> AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(field, INVALID_FIELD_TYPE_STRING)))
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
thrownException.getMessage(),
|
||||
containsString(
|
||||
Strings.format(
|
||||
"field ["
|
||||
+ field
|
||||
+ "] is not of the expected type. The value ["
|
||||
+ INVALID_FIELD_TYPE_STRING
|
||||
+ "] cannot be converted to a "
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
|
||||
import org.hamcrest.CoreMatchers;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.ENDPOINT_TYPE_FIELD;
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.PROVIDER_FIELD;
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TARGET_FIELD;
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType.TOKEN;
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider.COHERE;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class AzureAiStudioRerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase<AzureAiStudioRerankServiceSettings> {
|
||||
private static final String TARGET_URI = "http://testtarget.local";
|
||||
|
||||
public void testFromMap_Request_CreatesSettingsCorrectly() {
|
||||
final var serviceSettings = AzureAiStudioRerankServiceSettings.fromMap(
|
||||
createRequestSettingsMap(TARGET_URI, COHERE.name(), TOKEN.name()),
|
||||
ConfigurationParseContext.REQUEST
|
||||
);
|
||||
|
||||
assertThat(serviceSettings, is(new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, null)));
|
||||
}
|
||||
|
||||
public void testFromMap_RequestWithRateLimit_CreatesSettingsCorrectly() {
|
||||
final var settingsMap = createRequestSettingsMap(TARGET_URI, COHERE.name(), TOKEN.name());
|
||||
settingsMap.put(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3)));
|
||||
|
||||
final var serviceSettings = AzureAiStudioRerankServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST);
|
||||
|
||||
assertThat(serviceSettings, is(new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, new RateLimitSettings(3))));
|
||||
}
|
||||
|
||||
public void testFromMap_Persistent_CreatesSettingsCorrectly() {
|
||||
final var serviceSettings = AzureAiStudioRerankServiceSettings.fromMap(
|
||||
createRequestSettingsMap(TARGET_URI, COHERE.name(), TOKEN.name()),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
);
|
||||
|
||||
assertThat(serviceSettings, is(new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, null)));
|
||||
}
|
||||
|
||||
public void testToXContent_WritesAllValues() throws IOException {
|
||||
final var settings = new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, new RateLimitSettings(3));
|
||||
final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
settings.toXContent(builder, null);
|
||||
final String xContentResult = Strings.toString(builder);
|
||||
|
||||
assertThat(xContentResult, CoreMatchers.is("""
|
||||
{"target":"http://testtarget.local","provider":"cohere","endpoint_type":"token",""" + """
|
||||
"rate_limit":{"requests_per_minute":3}}"""));
|
||||
}
|
||||
|
||||
public void testToFilteredXContent_WritesAllValues() throws IOException {
|
||||
final var settings = new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, new RateLimitSettings(3));
|
||||
final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
final var filteredXContent = settings.getFilteredXContentObject();
|
||||
filteredXContent.toXContent(builder, null);
|
||||
final String xContentResult = Strings.toString(builder);
|
||||
|
||||
assertThat(xContentResult, CoreMatchers.is("""
|
||||
{"target":"http://testtarget.local","provider":"cohere","endpoint_type":"token",""" + """
|
||||
"rate_limit":{"requests_per_minute":3}}"""));
|
||||
}
|
||||
|
||||
public static HashMap<String, Object> createRequestSettingsMap(String target, String provider, String endpointType) {
|
||||
return new HashMap<>(Map.of(TARGET_FIELD, target, PROVIDER_FIELD, provider, ENDPOINT_TYPE_FIELD, endpointType));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<AzureAiStudioRerankServiceSettings> instanceReader() {
|
||||
return AzureAiStudioRerankServiceSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AzureAiStudioRerankServiceSettings createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AzureAiStudioRerankServiceSettings mutateInstance(AzureAiStudioRerankServiceSettings instance) throws IOException {
|
||||
return randomValueOtherThan(instance, AzureAiStudioRerankServiceSettingsTests::createRandom);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AzureAiStudioRerankServiceSettings mutateInstanceForVersion(
|
||||
AzureAiStudioRerankServiceSettings instance,
|
||||
TransportVersion version
|
||||
) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
private static AzureAiStudioRerankServiceSettings createRandom() {
|
||||
return new AzureAiStudioRerankServiceSettings(
|
||||
randomAlphaOfLength(10),
|
||||
randomFrom(AzureAiStudioProvider.values()),
|
||||
randomFrom(AzureAiStudioEndpointType.values()),
|
||||
RateLimitSettingsTests.createRandom()
|
||||
);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,230 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
|
||||
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class AzureAiStudioRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase<AzureAiStudioRerankTaskSettings> {
|
||||
private static final String INVALID_FIELD_TYPE_STRING = "invalid";
|
||||
|
||||
public void testIsEmpty() {
|
||||
final var randomSettings = createRandom();
|
||||
final var stringRep = Strings.toString(randomSettings);
|
||||
assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}"));
|
||||
}
|
||||
|
||||
public void testUpdatedTaskSettings_WithAllValues() {
|
||||
final AzureAiStudioRerankTaskSettings initialSettings = createRandom();
|
||||
AzureAiStudioRerankTaskSettings newSettings = createRandom(initialSettings);
|
||||
assertUpdateSettings(newSettings, initialSettings);
|
||||
}
|
||||
|
||||
public void testUpdatedTaskSettings_WithReturnDocumentsValue() {
|
||||
final AzureAiStudioRerankTaskSettings initialSettings = createRandom();
|
||||
AzureAiStudioRerankTaskSettings newSettings = createRandom(initialSettings);
|
||||
assertUpdateSettings(newSettings, initialSettings);
|
||||
}
|
||||
|
||||
public void testUpdatedTaskSettings_WithTopNValue() {
|
||||
final AzureAiStudioRerankTaskSettings initialSettings = createRandom();
|
||||
AzureAiStudioRerankTaskSettings newSettings = createRandom(initialSettings);
|
||||
assertUpdateSettings(newSettings, initialSettings);
|
||||
}
|
||||
|
||||
public void testUpdatedTaskSettings_WithNoValues() {
|
||||
AzureAiStudioRerankTaskSettings initialSettings = createRandom();
|
||||
final AzureAiStudioRerankTaskSettings newSettings = new AzureAiStudioRerankTaskSettings(null, null);
|
||||
assertUpdateSettings(newSettings, initialSettings);
|
||||
}
|
||||
|
||||
private void assertUpdateSettings(AzureAiStudioRerankTaskSettings newSettings, AzureAiStudioRerankTaskSettings initialSettings) {
|
||||
final var settingsMap = new HashMap<String, Object>();
|
||||
if (newSettings.returnDocuments() != null) settingsMap.put(RETURN_DOCUMENTS_FIELD, newSettings.returnDocuments());
|
||||
if (newSettings.topN() != null) settingsMap.put(TOP_N_FIELD, newSettings.topN());
|
||||
|
||||
final AzureAiStudioRerankTaskSettings updatedSettings = (AzureAiStudioRerankTaskSettings) initialSettings.updatedTaskSettings(
|
||||
Collections.unmodifiableMap(settingsMap)
|
||||
);
|
||||
assertEquals(
|
||||
newSettings.returnDocuments() == null ? initialSettings.returnDocuments() : newSettings.returnDocuments(),
|
||||
updatedSettings.returnDocuments()
|
||||
);
|
||||
assertEquals(newSettings.topN() == null ? initialSettings.topN() : newSettings.topN(), updatedSettings.topN());
|
||||
}
|
||||
|
||||
public void testFromMap_AllValues() {
|
||||
assertEquals(new AzureAiStudioRerankTaskSettings(true, 2), AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, 2)));
|
||||
}
|
||||
|
||||
public void testFromMap_ReturnDocuments() {
|
||||
assertEquals(
|
||||
new AzureAiStudioRerankTaskSettings(true, null),
|
||||
AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, null))
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromMap_TopN() {
|
||||
assertEquals(new AzureAiStudioRerankTaskSettings(null, 2), AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(null, 2)));
|
||||
}
|
||||
|
||||
public void testFromMap_ReturnDocumentsIsInvalidValue_ThrowsValidationException() {
|
||||
getTaskSettingsMap(true, 2).put(RETURN_DOCUMENTS_FIELD, INVALID_FIELD_TYPE_STRING);
|
||||
assertThrowsValidationExceptionIfStringValueProvidedFor(RETURN_DOCUMENTS_FIELD);
|
||||
}
|
||||
|
||||
public void testFromMap_TopNIsInvalidValue_ThrowsValidationException() {
|
||||
getTaskSettingsMap(true, 2).put(TOP_N_FIELD, INVALID_FIELD_TYPE_STRING);
|
||||
assertThrowsValidationExceptionIfStringValueProvidedFor(TOP_N_FIELD);
|
||||
}
|
||||
|
||||
public void testFromMap_WithNoValues_DoesNotThrowException() {
|
||||
final var taskMap = AzureAiStudioRerankTaskSettings.fromMap(new HashMap<>(Map.of()));
|
||||
assertNull(taskMap.returnDocuments());
|
||||
assertNull(taskMap.topN());
|
||||
}
|
||||
|
||||
public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() {
|
||||
final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, 2));
|
||||
final var overrideSettings = AzureAiStudioRerankTaskSettings.of(settings, AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS);
|
||||
MatcherAssert.assertThat(overrideSettings, is(settings));
|
||||
}
|
||||
|
||||
public void testOverrideWith_UsesReturnDocumentsOverride() {
|
||||
final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, null));
|
||||
final var overrideSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(getTaskSettingsMap(false, null));
|
||||
final var overriddenTaskSettings = AzureAiStudioRerankTaskSettings.of(settings, overrideSettings);
|
||||
MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureAiStudioRerankTaskSettings(false, null)));
|
||||
}
|
||||
|
||||
public void testOverrideWith_UsesTopNOverride() {
|
||||
final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(null, 2));
|
||||
final var overrideSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(getTaskSettingsMap(null, 1));
|
||||
final var overriddenTaskSettings = AzureAiStudioRerankTaskSettings.of(settings, overrideSettings);
|
||||
MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureAiStudioRerankTaskSettings(null, 1)));
|
||||
}
|
||||
|
||||
public void testOverrideWith_UsesAllParametersOverride() {
|
||||
final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(false, 2));
|
||||
final var overrideSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(getTaskSettingsMap(true, 1));
|
||||
final var overriddenTaskSettings = AzureAiStudioRerankTaskSettings.of(settings, overrideSettings);
|
||||
MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureAiStudioRerankTaskSettings(true, 1)));
|
||||
}
|
||||
|
||||
public void testToXContent_WithoutParameters() throws IOException {
|
||||
assertThat(getXContentResult(null, null), is("{}"));
|
||||
}
|
||||
|
||||
public void testToXContent_WithReturnDocumentsParameter() throws IOException {
|
||||
assertThat(getXContentResult(true, null), is("""
|
||||
{"return_documents":true}"""));
|
||||
}
|
||||
|
||||
public void testToXContent_WithTopNParameter() throws IOException {
|
||||
assertThat(getXContentResult(null, 2), is("""
|
||||
{"top_n":2}"""));
|
||||
}
|
||||
|
||||
public void testToXContent_WithParameters() throws IOException {
|
||||
assertThat(getXContentResult(true, 2), is("""
|
||||
{"return_documents":true,"top_n":2}"""));
|
||||
}
|
||||
|
||||
private String getXContentResult(Boolean returnDocuments, Integer topN) throws IOException {
|
||||
final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(returnDocuments, topN));
|
||||
final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
settings.toXContent(builder, null);
|
||||
return Strings.toString(builder);
|
||||
}
|
||||
|
||||
public static Map<String, Object> getTaskSettingsMap(@Nullable Boolean returnDocuments, @Nullable Integer topN) {
|
||||
final var map = new HashMap<String, Object>();
|
||||
|
||||
if (returnDocuments != null) {
|
||||
map.put(RETURN_DOCUMENTS_FIELD, returnDocuments);
|
||||
}
|
||||
|
||||
if (topN != null) {
|
||||
map.put(TOP_N_FIELD, topN);
|
||||
}
|
||||
|
||||
return map;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<AzureAiStudioRerankTaskSettings> instanceReader() {
|
||||
return AzureAiStudioRerankTaskSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AzureAiStudioRerankTaskSettings createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AzureAiStudioRerankTaskSettings mutateInstance(AzureAiStudioRerankTaskSettings instance) throws IOException {
|
||||
return randomValueOtherThan(instance, AzureAiStudioRerankTaskSettingsTests::createRandom);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AzureAiStudioRerankTaskSettings mutateInstanceForVersion(AzureAiStudioRerankTaskSettings instance, TransportVersion version) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
private static AzureAiStudioRerankTaskSettings createRandom() {
|
||||
return new AzureAiStudioRerankTaskSettings(
|
||||
randomFrom(new Boolean[] { null, randomBoolean() }),
|
||||
randomFrom(new Integer[] { null, randomNonNegativeInt() })
|
||||
);
|
||||
}
|
||||
|
||||
private static AzureAiStudioRerankTaskSettings createRandom(AzureAiStudioRerankTaskSettings settings) {
|
||||
return new AzureAiStudioRerankTaskSettings(
|
||||
randomValueOtherThan(settings.returnDocuments(), () -> randomFrom(new Boolean[] { null, randomBoolean() })),
|
||||
randomValueOtherThan(settings.topN(), () -> randomFrom(new Integer[] { null, randomNonNegativeInt() }))
|
||||
);
|
||||
}
|
||||
|
||||
private void assertThrowsValidationExceptionIfStringValueProvidedFor(String field) {
|
||||
final var thrownException = expectThrows(
|
||||
ValidationException.class,
|
||||
() -> AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(field, INVALID_FIELD_TYPE_STRING)))
|
||||
);
|
||||
|
||||
MatcherAssert.assertThat(
|
||||
thrownException.getMessage(),
|
||||
containsString(
|
||||
Strings.format(
|
||||
"field ["
|
||||
+ field
|
||||
+ "] is not of the expected type. The value ["
|
||||
+ INVALID_FIELD_TYPE_STRING
|
||||
+ "] cannot be converted to a "
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.azureaistudio.response;
|
||||
|
||||
import org.apache.http.HttpResponse;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public class AzureAiStudioRerankResponseEntityTests extends ESTestCase {
|
||||
public void testResponse_WithDocuments() throws IOException {
|
||||
final String responseJson = getResponseJsonWithDocuments();
|
||||
|
||||
final var parsedResults = getParsedResults(responseJson);
|
||||
final var expectedResults = List.of(
|
||||
new RankedDocsResults.RankedDoc(0, 0.1111111F, "test text one"),
|
||||
new RankedDocsResults.RankedDoc(1, 0.2222222F, "test text two")
|
||||
);
|
||||
|
||||
assertThat(parsedResults.getRankedDocs(), is(expectedResults));
|
||||
}
|
||||
|
||||
public void testResponse_NoDocuments() throws IOException {
|
||||
final String responseJson = getResponseJsonNoDocuments();
|
||||
|
||||
final var parsedResults = getParsedResults(responseJson);
|
||||
final var expectedResults = List.of(
|
||||
new RankedDocsResults.RankedDoc(0, 0.1111111F, null),
|
||||
new RankedDocsResults.RankedDoc(1, 0.2222222F, null)
|
||||
);
|
||||
|
||||
assertThat(parsedResults.getRankedDocs(), is(expectedResults));
|
||||
}
|
||||
|
||||
private RankedDocsResults getParsedResults(String responseJson) throws IOException {
|
||||
final var entity = new AzureAiStudioRerankResponseEntity();
|
||||
return (RankedDocsResults) entity.apply(
|
||||
mock(Request.class),
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
}
|
||||
|
||||
private String getResponseJsonWithDocuments() {
|
||||
return """
|
||||
{
|
||||
"id": "222e59de-c712-40cb-ae87-ecd402d0d2f1",
|
||||
"results": [
|
||||
{
|
||||
"document": {
|
||||
"text": "test text one"
|
||||
},
|
||||
"index": 0,
|
||||
"relevance_score": 0.1111111
|
||||
},
|
||||
{
|
||||
"document": {
|
||||
"text": "test text two"
|
||||
},
|
||||
"index": 1,
|
||||
"relevance_score": 0.2222222
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"api_version": {
|
||||
"version": "1"
|
||||
},
|
||||
"billed_units": {
|
||||
"search_units": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
""";
|
||||
}
|
||||
|
||||
private String getResponseJsonNoDocuments() {
|
||||
return """
|
||||
{
|
||||
"id": "222e59de-c712-40cb-ae87-ecd402d0d2f1",
|
||||
"results": [
|
||||
{
|
||||
"index": 0,
|
||||
"relevance_score": 0.1111111
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"relevance_score": 0.2222222
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"api_version": {
|
||||
"version": "1"
|
||||
},
|
||||
"billed_units": {
|
||||
"search_units": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
""";
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue