Enable force inference endpoint deleting for invalid models and after stopping model deployment fails (#129090)
* Enable force inference endpoint deleting for invalid models and after stopping model deployment fails
* Update docs/changelog/129090.yaml

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
parent 037ddaa5c8
commit 9c6cf90456
docs/changelog/129090.yaml (new file)
@@ -0,0 +1,6 @@
+pr: 129090
+summary: Enable force inference endpoint deleting for invalid models and after stopping
+  model deployment fails
+area: Machine Learning
+type: enhancement
+issues: []
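At the API level, a delete request with the force flag set now succeeds even when the endpoint's stored configuration can no longer be parsed or when stopping the underlying model deployment fails. A minimal caller-side sketch, with the Request constructor argument order taken from the tests below (endpoint id, task type, force flag, and a fourth boolean that the tests always pass as false); imports are omitted and the endpoint id is hypothetical:

    // With force = true the action now acknowledges the delete even if parsing
    // the persisted config or stopping the model deployment fails.
    var request = new DeleteInferenceEndpointAction.Request(
        "my-invalid-endpoint",      // hypothetical inference endpoint id
        TaskType.SPARSE_EMBEDDING,  // any task type
        true,                       // force delete
        false                       // fourth flag, passed as false throughout the tests
    );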
TransportDeleteInferenceEndpointAction.java
@@ -23,6 +23,7 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.injection.guice.Inject;
 import org.elasticsearch.rest.RestStatus;
@@ -128,10 +129,38 @@ public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeA
         }

         var service = serviceRegistry.getService(unparsedModel.service());
+        Model model;
         if (service.isPresent()) {
-            var model = service.get()
-                .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
-            service.get().stop(model, listener);
+            try {
+                model = service.get()
+                    .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
+            } catch (Exception e) {
+                if (request.isForceDelete()) {
+                    listener.onResponse(true);
+                    return;
+                } else {
+                    listener.onFailure(
+                        new ElasticsearchStatusException(
+                            Strings.format(
+                                "Failed to parse model configuration for inference endpoint [%s]",
+                                request.getInferenceEndpointId()
+                            ),
+                            RestStatus.INTERNAL_SERVER_ERROR,
+                            e
+                        )
+                    );
+                    return;
+                }
+            }
+            service.get().stop(model, listener.delegateResponse((l, e) -> {
+                if (request.isForceDelete()) {
+                    l.onResponse(true);
+                } else {
+                    l.onFailure(e);
+                }
+            }));
         } else if (request.isForceDelete()) {
             listener.onResponse(true);
         } else {
             listener.onFailure(
                 new ElasticsearchStatusException(
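The hunk above adds the two force paths: a failure to parse the persisted configuration is caught directly, while a failure from stop() is absorbed by wrapping the listener with ActionListener#delegateResponse. A standalone sketch of that wrapping pattern follows; the class and method names are hypothetical, and the Consumer stands in for the real service call:

import org.elasticsearch.action.ActionListener;

import java.util.function.Consumer;

final class ForceStopSketch {
    // Hypothetical helper, not the action's actual method: run a stop callback and,
    // when force is requested, report success to the caller even if the stop fails.
    static void stopRespectingForce(boolean force, Consumer<ActionListener<Boolean>> stopCall, ActionListener<Boolean> listener) {
        stopCall.accept(listener.delegateResponse((delegate, exception) -> {
            if (force) {
                delegate.onResponse(true);      // force delete: swallow the stop failure
            } else {
                delegate.onFailure(exception);  // otherwise propagate it unchanged
            }
        }));
    }
}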
TransportDeleteInferenceEndpointActionTests.java
@@ -17,8 +17,10 @@ import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
@@ -32,11 +34,17 @@ import java.util.Map;
import java.util.Optional;

import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

public class TransportDeleteInferenceEndpointActionTests extends ESTestCase {
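The tests in the hunk below invoke the asynchronous masterOperation directly and wait for its result with PlainActionFuture, which implements ActionListener and blocks in actionGet until the listener completes. A minimal sketch of that pattern, with a hypothetical async call standing in for the real action:

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.core.TimeValue;

final class FutureListenerSketch {
    // Hypothetical async API: completes the listener instead of returning a value.
    static void deleteAsync(String endpointId, ActionListener<Boolean> listener) {
        listener.onResponse(true);
    }

    static boolean deleteAndWait(String endpointId) {
        var future = new PlainActionFuture<Boolean>();            // doubles as the ActionListener
        deleteAsync(endpointId, future);
        return future.actionGet(TimeValue.timeValueSeconds(30));  // block until completed or timed out
    }
}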
@@ -130,4 +138,213 @@ public class TransportDeleteInferenceEndpointActionTests extends ESTestCase {

        assertTrue(response.isAcknowledged());
    }

    public void testFailsToDeleteUnparsableEndpoint_WhenForceIsFalse() {
        var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
        var serviceName = randomAlphanumericOfLength(10);
        var taskType = randomFrom(TaskType.values());
        var mockService = mock(InferenceService.class);
        mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService);
        when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);

        var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
        action.masterOperation(
            mock(Task.class),
            new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
            ClusterState.EMPTY_STATE,
            listener
        );

        var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
        assertThat(exception.getMessage(), containsString("Failed to parse model configuration for inference endpoint"));

        verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
        verify(mockInferenceServiceRegistry).getService(eq(serviceName));
        verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
        verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
        verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService);
    }

    public void testDeletesUnparsableEndpoint_WhenForceIsTrue() {
        var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
        var serviceName = randomAlphanumericOfLength(10);
        var taskType = randomFrom(TaskType.values());
        var mockService = mock(InferenceService.class);
        mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService);
        doAnswer(invocationOnMock -> {
            ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
            listener.onResponse(true);
            return Void.TYPE;
        }).when(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());

        var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();

        action.masterOperation(
            mock(Task.class),
            new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
            ClusterState.EMPTY_STATE,
            listener
        );

        var response = listener.actionGet(TIMEOUT);
        assertTrue(response.isAcknowledged());

        verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
        verify(mockInferenceServiceRegistry).getService(eq(serviceName));
        verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
        verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
        verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService);
    }

    private void mockUnparsableModel(String inferenceEndpointId, String serviceName, TaskType taskType, InferenceService mockService) {
        doAnswer(invocationOnMock -> {
            ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
            listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
            return Void.TYPE;
        }).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
        doThrow(new ElasticsearchStatusException(randomAlphanumericOfLength(10), RestStatus.INTERNAL_SERVER_ERROR)).when(mockService)
            .parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
        when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService));
    }

    public void testDeletesEndpointWithNoService_WhenForceIsTrue() {
        var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
        var serviceName = randomAlphanumericOfLength(10);
        var taskType = randomFrom(TaskType.values());
        mockNoService(inferenceEndpointId, serviceName, taskType);
        doAnswer(invocationOnMock -> {
            ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
            listener.onResponse(true);
            return Void.TYPE;
        }).when(mockModelRegistry).deleteModel(anyString(), any());

        var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();

        action.masterOperation(
            mock(Task.class),
            new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
            ClusterState.EMPTY_STATE,
            listener
        );

        var response = listener.actionGet(TIMEOUT);
        assertTrue(response.isAcknowledged());
        verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
        verify(mockInferenceServiceRegistry).getService(eq(serviceName));
        verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
        verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry);
    }

    public void testFailsToDeleteEndpointWithNoService_WhenForceIsFalse() {
        var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
        var serviceName = randomAlphanumericOfLength(10);
        var taskType = randomFrom(TaskType.values());
        mockNoService(inferenceEndpointId, serviceName, taskType);
        when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);

        var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();

        action.masterOperation(
            mock(Task.class),
            new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
            ClusterState.EMPTY_STATE,
            listener
        );

        var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
        assertThat(exception.getMessage(), containsString("No service found for this inference endpoint"));
        verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
        verify(mockInferenceServiceRegistry).getService(eq(serviceName));
        verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
        verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry);
    }

    private void mockNoService(String inferenceEndpointId, String serviceName, TaskType taskType) {
        doAnswer(invocationOnMock -> {
            ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
            listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
            return Void.TYPE;
        }).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
        when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.empty());
    }

    public void testFailsToDeleteEndpointIfModelDeploymentStopFails_WhenForceIsFalse() {
        var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
        var serviceName = randomAlphanumericOfLength(10);
        var taskType = randomFrom(TaskType.values());
        var mockService = mock(InferenceService.class);
        var mockModel = mock(Model.class);
        mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel);
        when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);

        var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
        action.masterOperation(
            mock(Task.class),
            new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
            ClusterState.EMPTY_STATE,
            listener
        );

        var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
        assertThat(exception.getMessage(), containsString("Failed to stop model deployment"));
        verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
        verify(mockInferenceServiceRegistry).getService(eq(serviceName));
        verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
        verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
        verify(mockService).stop(eq(mockModel), any());
        verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel);
    }

    public void testDeletesEndpointIfModelDeploymentStopFails_WhenForceIsTrue() {
        var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
        var serviceName = randomAlphanumericOfLength(10);
        var taskType = randomFrom(TaskType.values());
        var mockService = mock(InferenceService.class);
        var mockModel = mock(Model.class);
        mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel);
        doAnswer(invocationOnMock -> {
            ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
            listener.onResponse(true);
            return Void.TYPE;
        }).when(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());

        var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
        action.masterOperation(
            mock(Task.class),
            new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
            ClusterState.EMPTY_STATE,
            listener
        );

        var response = listener.actionGet(TIMEOUT);
        assertTrue(response.isAcknowledged());
        verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
        verify(mockInferenceServiceRegistry).getService(eq(serviceName));
        verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
        verify(mockService).stop(eq(mockModel), any());
        verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
        verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel);
    }

    private void mockStopDeploymentFails(
        String inferenceEndpointId,
        String serviceName,
        TaskType taskType,
        InferenceService mockService,
        Model mockModel
    ) {
        doAnswer(invocationOnMock -> {
            ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
            listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
            return Void.TYPE;
        }).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
        when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService));
        doReturn(mockModel).when(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
        doAnswer(invocationOnMock -> {
            ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
            listener.onFailure(new ElasticsearchStatusException("Failed to stop model deployment", RestStatus.INTERNAL_SERVER_ERROR));
            return Void.TYPE;
        }).when(mockService).stop(eq(mockModel), any());
    }

}
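The mock helpers above all follow the same Mockito recipe for listener-based methods: doAnswer pulls the ActionListener out of the invocation and completes it by hand, since there is no return value to stub. A distilled sketch of that recipe against a hypothetical listener-based interface:

import org.elasticsearch.action.ActionListener;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

final class ListenerStubSketch {
    // Hypothetical async registry used only for this illustration.
    interface Registry {
        void deleteModel(String inferenceEntityId, ActionListener<Boolean> listener);
    }

    static Registry stubbedRegistry(String endpointId) {
        Registry registry = mock(Registry.class);
        doAnswer(invocation -> {
            // The second argument is the listener; complete it as the real code would.
            ActionListener<Boolean> listener = invocation.getArgument(1);
            listener.onResponse(true);
            return null; // void method, nothing to return
        }).when(registry).deleteModel(eq(endpointId), any());
        return registry;
    }
}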