[ML] Sync Inference with Trained Model stats (#130544)

When the Trained Model stats are read, either during `GET _inference` or
`PUT _inference`, the Inference stats are updated to reflected the
Trained Model stats.

Fix #130339
This commit is contained in:
Pat Whelan 2025-07-14 15:40:26 -04:00 committed by GitHub
parent 636b9bbc89
commit 63b753f396
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 63 additions and 18 deletions

View File

@ -0,0 +1,6 @@
pr: 130544
summary: Sync Inference with Trained Model stats
area: Machine Learning
type: bug
issues:
- 130339

View File

@ -114,7 +114,7 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
}
}).<Boolean>andThen((l2, modelDidPut) -> {
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(esModel, l2);
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
});
subscribableListener.addTimeout(timeout, threadPool, inferenceExecutor);

View File

@ -10,7 +10,6 @@ package org.elasticsearch.xpack.inference.services.elasticsearch;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
@ -43,7 +42,7 @@ public class ElasticDeployedModel extends ElasticsearchInternalModel {
@Override
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
Model model,
ElasticsearchInternalModel esModel,
ActionListener<Boolean> listener
) {
throw new IllegalStateException("cannot start model that uses an existing deployment");

View File

@ -9,7 +9,6 @@ package org.elasticsearch.xpack.inference.services.elasticsearch;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -33,7 +32,7 @@ public class ElasticRerankerModel extends ElasticsearchInternalModel {
@Override
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
Model model,
ElasticsearchInternalModel esModel,
ActionListener<Boolean> listener
) {

View File

@ -21,6 +21,8 @@ import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import static org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus.State.STARTED;
@ -85,12 +87,13 @@ public abstract class ElasticsearchInternalModel extends Model {
}
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
Model model,
ElasticsearchInternalModel esModel,
ActionListener<Boolean> listener
) {
return new ActionListener<>() {
@Override
public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
esModel.updateServiceSettings(response.getTrainedModelAssignment());
listener.onResponse(Boolean.TRUE);
}
@ -98,7 +101,7 @@ public abstract class ElasticsearchInternalModel extends Model {
public void onFailure(Exception e) {
var cause = ExceptionsHelper.unwrapCause(e);
if (cause instanceof ResourceNotFoundException) {
listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(internalServiceSettings.modelId())));
listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(esModel.internalServiceSettings.modelId())));
return;
} else if (cause instanceof ElasticsearchStatusException statusException) {
if (statusException.status() == RestStatus.CONFLICT
@ -128,8 +131,18 @@ public abstract class ElasticsearchInternalModel extends Model {
return (ElasticsearchInternalServiceSettings) super.getServiceSettings();
}
public void updateNumAllocations(Integer numAllocations) {
this.internalServiceSettings.setNumAllocations(numAllocations);
public void updateServiceSettings(AssignmentStats assignmentStats) {
this.internalServiceSettings.setAllocations(
assignmentStats.getNumberOfAllocations(),
assignmentStats.getAdaptiveAllocationsSettings()
);
}
private void updateServiceSettings(TrainedModelAssignment trainedModelAssignment) {
this.internalServiceSettings.setAllocations(
this.internalServiceSettings.getNumAllocations(),
trainedModelAssignment.getAdaptiveAllocationsSettings()
);
}
@Override

View File

@ -890,7 +890,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
ActionListener.wrap(stats -> {
for (var deploymentStats : stats.getStats().results()) {
var modelsForDeploymentId = modelsByDeploymentIds.get(deploymentStats.getDeploymentId());
modelsForDeploymentId.forEach(model -> model.updateNumAllocations(deploymentStats.getNumberOfAllocations()));
modelsForDeploymentId.forEach(model -> model.updateServiceSettings(deploymentStats));
}
var updatedModels = new ArrayList<Model>();
modelsByDeploymentIds.values().forEach(updatedModels::addAll);

View File

@ -43,7 +43,7 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings {
private Integer numAllocations;
private final int numThreads;
private final String modelId;
private final AdaptiveAllocationsSettings adaptiveAllocationsSettings;
private AdaptiveAllocationsSettings adaptiveAllocationsSettings;
private final String deploymentId;
public static ElasticsearchInternalServiceSettings fromPersistedMap(Map<String, Object> map) {
@ -158,8 +158,9 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings {
this.deploymentId = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readOptionalString() : null;
}
public void setNumAllocations(Integer numAllocations) {
public void setAllocations(Integer numAllocations, @Nullable AdaptiveAllocationsSettings adaptiveAllocationsSettings) {
this.numAllocations = numAllocations;
this.adaptiveAllocationsSettings = adaptiveAllocationsSettings;
}
@Override

View File

@ -108,10 +108,12 @@ import static org.elasticsearch.xpack.inference.services.elasticsearch.Elasticse
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.assertArg;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
@ -1767,7 +1769,9 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
modelsByDeploymentId.forEach((deploymentId, models) -> {
var expectedNumberOfAllocations = updatedNumberOfAllocations.get(deploymentId);
models.forEach(model -> {
verify((ElasticsearchInternalModel) model).updateNumAllocations(expectedNumberOfAllocations);
verify((ElasticsearchInternalModel) model).updateServiceSettings(assertArg(assignmentStats -> {
assertThat(assignmentStats.getNumberOfAllocations(), equalTo(expectedNumberOfAllocations));
}));
verify((ElasticsearchInternalModel) model).mlNodeDeploymentId();
verifyNoMoreInteractions(model);
});
@ -1858,7 +1862,9 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
var latch = new CountDownLatch(1);
service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(r -> latch.countDown()));
assertTrue(latch.await(30, TimeUnit.SECONDS));
verify(model).updateNumAllocations(3);
verify(model).updateServiceSettings(
assertArg(assignmentStats -> { assertThat(assignmentStats.getNumberOfAllocations(), equalTo(3)); })
);
}
}

View File

@ -7,8 +7,17 @@
package org.elasticsearch.xpack.inference.services.elasticsearch;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentTests;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class ElserInternalModelTests extends ESTestCase {
public void testUpdateNumAllocation() {
@ -21,10 +30,22 @@ public class ElserInternalModelTests extends ESTestCase {
null
);
model.updateNumAllocations(1);
assertEquals(1, model.getServiceSettings().getNumAllocations().intValue());
AssignmentStats assignmentStats = mock();
when(assignmentStats.getNumberOfAllocations()).thenReturn(1);
model.updateServiceSettings(assignmentStats);
model.updateNumAllocations(null);
assertNull(model.getServiceSettings().getNumAllocations());
assertThat(model.getServiceSettings().getNumAllocations(), equalTo(1));
assertNull(model.getServiceSettings().getAdaptiveAllocationsSettings());
TrainedModelAssignment trainedModelAssignment = TrainedModelAssignmentTests.randomInstance();
CreateTrainedModelAssignmentAction.Response response = mock();
when(response.getTrainedModelAssignment()).thenReturn(trainedModelAssignment);
model.getCreateTrainedModelAssignmentActionListener(model, ActionListener.noop()).onResponse(response);
assertThat(model.getServiceSettings().getNumAllocations(), equalTo(1));
assertThat(
model.getServiceSettings().getAdaptiveAllocationsSettings(),
equalTo(trainedModelAssignment.getAdaptiveAllocationsSettings())
);
}
}