[ML] Sync Inference with Trained Model stats (#130544)
When the Trained Model stats are read, either during `GET _inference` or `PUT _inference`, the Inference stats are updated to reflect the Trained Model stats. Fixes #130339
This commit is contained in:
parent
636b9bbc89
commit
63b753f396
|
@ -0,0 +1,6 @@
|
|||
pr: 130544
|
||||
summary: Sync Inference with Trained Model stats
|
||||
area: Machine Learning
|
||||
type: bug
|
||||
issues:
|
||||
- 130339
|
|
@ -114,7 +114,7 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
|
|||
}
|
||||
}).<Boolean>andThen((l2, modelDidPut) -> {
|
||||
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
|
||||
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
|
||||
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(esModel, l2);
|
||||
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
|
||||
});
|
||||
subscribableListener.addTimeout(timeout, threadPool, inferenceExecutor);
|
||||
|
|
|
@ -10,7 +10,6 @@ package org.elasticsearch.xpack.inference.services.elasticsearch;
|
|||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.ChunkingSettings;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
|
||||
|
@ -43,7 +42,7 @@ public class ElasticDeployedModel extends ElasticsearchInternalModel {
|
|||
|
||||
@Override
|
||||
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
|
||||
Model model,
|
||||
ElasticsearchInternalModel esModel,
|
||||
ActionListener<Boolean> listener
|
||||
) {
|
||||
throw new IllegalStateException("cannot start model that uses an existing deployment");
|
||||
|
|
|
@ -9,7 +9,6 @@ package org.elasticsearch.xpack.inference.services.elasticsearch;
|
|||
|
||||
import org.elasticsearch.ResourceNotFoundException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
@ -33,7 +32,7 @@ public class ElasticRerankerModel extends ElasticsearchInternalModel {
|
|||
|
||||
@Override
|
||||
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
|
||||
Model model,
|
||||
ElasticsearchInternalModel esModel,
|
||||
ActionListener<Boolean> listener
|
||||
) {
|
||||
|
||||
|
|
|
@ -21,6 +21,8 @@ import org.elasticsearch.inference.TaskType;
|
|||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
|
||||
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus.State.STARTED;
|
||||
|
@ -85,12 +87,13 @@ public abstract class ElasticsearchInternalModel extends Model {
|
|||
}
|
||||
|
||||
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
|
||||
Model model,
|
||||
ElasticsearchInternalModel esModel,
|
||||
ActionListener<Boolean> listener
|
||||
) {
|
||||
return new ActionListener<>() {
|
||||
@Override
|
||||
public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
|
||||
esModel.updateServiceSettings(response.getTrainedModelAssignment());
|
||||
listener.onResponse(Boolean.TRUE);
|
||||
}
|
||||
|
||||
|
@ -98,7 +101,7 @@ public abstract class ElasticsearchInternalModel extends Model {
|
|||
public void onFailure(Exception e) {
|
||||
var cause = ExceptionsHelper.unwrapCause(e);
|
||||
if (cause instanceof ResourceNotFoundException) {
|
||||
listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(internalServiceSettings.modelId())));
|
||||
listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(esModel.internalServiceSettings.modelId())));
|
||||
return;
|
||||
} else if (cause instanceof ElasticsearchStatusException statusException) {
|
||||
if (statusException.status() == RestStatus.CONFLICT
|
||||
|
@ -128,8 +131,18 @@ public abstract class ElasticsearchInternalModel extends Model {
|
|||
return (ElasticsearchInternalServiceSettings) super.getServiceSettings();
|
||||
}
|
||||
|
||||
public void updateNumAllocations(Integer numAllocations) {
|
||||
this.internalServiceSettings.setNumAllocations(numAllocations);
|
||||
public void updateServiceSettings(AssignmentStats assignmentStats) {
|
||||
this.internalServiceSettings.setAllocations(
|
||||
assignmentStats.getNumberOfAllocations(),
|
||||
assignmentStats.getAdaptiveAllocationsSettings()
|
||||
);
|
||||
}
|
||||
|
||||
private void updateServiceSettings(TrainedModelAssignment trainedModelAssignment) {
|
||||
this.internalServiceSettings.setAllocations(
|
||||
this.internalServiceSettings.getNumAllocations(),
|
||||
trainedModelAssignment.getAdaptiveAllocationsSettings()
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -890,7 +890,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
|
|||
ActionListener.wrap(stats -> {
|
||||
for (var deploymentStats : stats.getStats().results()) {
|
||||
var modelsForDeploymentId = modelsByDeploymentIds.get(deploymentStats.getDeploymentId());
|
||||
modelsForDeploymentId.forEach(model -> model.updateNumAllocations(deploymentStats.getNumberOfAllocations()));
|
||||
modelsForDeploymentId.forEach(model -> model.updateServiceSettings(deploymentStats));
|
||||
}
|
||||
var updatedModels = new ArrayList<Model>();
|
||||
modelsByDeploymentIds.values().forEach(updatedModels::addAll);
|
||||
|
|
|
@ -43,7 +43,7 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings {
|
|||
private Integer numAllocations;
|
||||
private final int numThreads;
|
||||
private final String modelId;
|
||||
private final AdaptiveAllocationsSettings adaptiveAllocationsSettings;
|
||||
private AdaptiveAllocationsSettings adaptiveAllocationsSettings;
|
||||
private final String deploymentId;
|
||||
|
||||
public static ElasticsearchInternalServiceSettings fromPersistedMap(Map<String, Object> map) {
|
||||
|
@ -158,8 +158,9 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings {
|
|||
this.deploymentId = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readOptionalString() : null;
|
||||
}
|
||||
|
||||
public void setNumAllocations(Integer numAllocations) {
|
||||
public void setAllocations(Integer numAllocations, @Nullable AdaptiveAllocationsSettings adaptiveAllocationsSettings) {
|
||||
this.numAllocations = numAllocations;
|
||||
this.adaptiveAllocationsSettings = adaptiveAllocationsSettings;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -108,10 +108,12 @@ import static org.elasticsearch.xpack.inference.services.elasticsearch.Elasticse
|
|||
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME;
|
||||
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.assertArg;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.ArgumentMatchers.same;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
|
@ -1767,7 +1769,9 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
|
|||
modelsByDeploymentId.forEach((deploymentId, models) -> {
|
||||
var expectedNumberOfAllocations = updatedNumberOfAllocations.get(deploymentId);
|
||||
models.forEach(model -> {
|
||||
verify((ElasticsearchInternalModel) model).updateNumAllocations(expectedNumberOfAllocations);
|
||||
verify((ElasticsearchInternalModel) model).updateServiceSettings(assertArg(assignmentStats -> {
|
||||
assertThat(assignmentStats.getNumberOfAllocations(), equalTo(expectedNumberOfAllocations));
|
||||
}));
|
||||
verify((ElasticsearchInternalModel) model).mlNodeDeploymentId();
|
||||
verifyNoMoreInteractions(model);
|
||||
});
|
||||
|
@ -1858,7 +1862,9 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
|
|||
var latch = new CountDownLatch(1);
|
||||
service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(r -> latch.countDown()));
|
||||
assertTrue(latch.await(30, TimeUnit.SECONDS));
|
||||
verify(model).updateNumAllocations(3);
|
||||
verify(model).updateServiceSettings(
|
||||
assertArg(assignmentStats -> { assertThat(assignmentStats.getNumberOfAllocations(), equalTo(3)); })
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7,8 +7,17 @@
|
|||
|
||||
package org.elasticsearch.xpack.inference.services.elasticsearch;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
|
||||
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
|
||||
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentTests;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class ElserInternalModelTests extends ESTestCase {
|
||||
public void testUpdateNumAllocation() {
|
||||
|
@ -21,10 +30,22 @@ public class ElserInternalModelTests extends ESTestCase {
|
|||
null
|
||||
);
|
||||
|
||||
model.updateNumAllocations(1);
|
||||
assertEquals(1, model.getServiceSettings().getNumAllocations().intValue());
|
||||
AssignmentStats assignmentStats = mock();
|
||||
when(assignmentStats.getNumberOfAllocations()).thenReturn(1);
|
||||
model.updateServiceSettings(assignmentStats);
|
||||
|
||||
model.updateNumAllocations(null);
|
||||
assertNull(model.getServiceSettings().getNumAllocations());
|
||||
assertThat(model.getServiceSettings().getNumAllocations(), equalTo(1));
|
||||
assertNull(model.getServiceSettings().getAdaptiveAllocationsSettings());
|
||||
|
||||
TrainedModelAssignment trainedModelAssignment = TrainedModelAssignmentTests.randomInstance();
|
||||
CreateTrainedModelAssignmentAction.Response response = mock();
|
||||
when(response.getTrainedModelAssignment()).thenReturn(trainedModelAssignment);
|
||||
model.getCreateTrainedModelAssignmentActionListener(model, ActionListener.noop()).onResponse(response);
|
||||
|
||||
assertThat(model.getServiceSettings().getNumAllocations(), equalTo(1));
|
||||
assertThat(
|
||||
model.getServiceSettings().getAdaptiveAllocationsSettings(),
|
||||
equalTo(trainedModelAssignment.getAdaptiveAllocationsSettings())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue