Allow timeout during trained model download process (#129003)

* Allow timeout during trained model download process

* Update docs/changelog/129003.yaml

* Update timeout message

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
Dan Rubinstein 2025-07-02 12:40:52 -04:00 committed by GitHub
parent 72b5c0175a
commit 136442d83c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 48 additions and 17 deletions

View File

@ -0,0 +1,5 @@
pr: 129003
summary: Allow timeout during trained model download process
area: Machine Learning
type: bug
issues: []

View File

@ -5,7 +5,7 @@
* 2.0.
*/
package org.elasticsearch.xpack.ml.inference.assignment;
package org.elasticsearch.xpack.core.ml.inference;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;

View File

@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.services.elasticsearch;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ElasticsearchTimeoutException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
@ -22,6 +23,7 @@ import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
@ -29,6 +31,7 @@ import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
@ -41,12 +44,14 @@ import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
public abstract class BaseElasticsearchInternalService implements InferenceService {
protected final OriginSettingClient client;
protected final ThreadPool threadPool;
protected final ExecutorService inferenceExecutor;
protected final Consumer<ActionListener<PreferredModelVariant>> preferredModelVariantFn;
private final ClusterService clusterService;
@ -60,6 +65,7 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
public BaseElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
this.client = new OriginSettingClient(context.client(), ClientHelper.INFERENCE_ORIGIN);
this.threadPool = context.threadPool();
this.inferenceExecutor = context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME);
this.preferredModelVariantFn = this::preferredVariantFromPlatformArchitecture;
this.clusterService = context.clusterService();
@ -75,6 +81,7 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
Consumer<ActionListener<PreferredModelVariant>> preferredModelVariantFn
) {
this.client = new OriginSettingClient(context.client(), ClientHelper.INFERENCE_ORIGIN);
this.threadPool = context.threadPool();
this.inferenceExecutor = context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME);
this.preferredModelVariantFn = preferredModelVariantFn;
this.clusterService = context.clusterService();
@ -96,20 +103,38 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
return;
}
SubscribableListener.<Boolean>newForked(forkedListener -> { isBuiltinModelPut(model, forkedListener); })
.<Boolean>andThen((l, modelConfigExists) -> {
if (modelConfigExists == false) {
putModel(model, l);
} else {
l.onResponse(true);
}
})
.<Boolean>andThen((l2, modelDidPut) -> {
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
})
.addListener(finalListener);
// instead of a subscribably listener, use some wait to wait for the first one.
var subscribableListener = SubscribableListener.<Boolean>newForked(
forkedListener -> { isBuiltinModelPut(model, forkedListener); }
).<Boolean>andThen((l, modelConfigExists) -> {
if (modelConfigExists == false) {
putModel(model, l);
} else {
l.onResponse(true);
}
}).<Boolean>andThen((l2, modelDidPut) -> {
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
});
subscribableListener.addTimeout(timeout, threadPool, inferenceExecutor);
subscribableListener.addListener(finalListener.delegateResponse((l, e) -> {
if (e instanceof ElasticsearchTimeoutException) {
l.onFailure(
new ModelDeploymentTimeoutException(
format(
"Timed out after [%s] waiting for trained model deployment for inference endpoint [%s] to start. "
+ "The inference endpoint can not be used to perform inference until the deployment has started. "
+ "Use the trained model stats API to track the state of the deployment.",
timeout,
model.getInferenceEntityId()
)
)
);
} else {
l.onFailure(e);
}
}));
} else {
finalListener.onFailure(notElasticsearchModelException(model));

View File

@ -52,6 +52,7 @@ import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams;
import org.elasticsearch.xpack.core.ml.inference.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
@ -65,7 +66,6 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.assignment.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

View File

@ -16,12 +16,12 @@ import org.elasticsearch.core.TimeValue;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.assignment.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentService;
import java.util.HashMap;

View File

@ -29,6 +29,7 @@ import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction;
import org.elasticsearch.xpack.core.ml.inference.ModelDeploymentTimeoutException;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;