Enable force inference endpoint deleting for invalid models and after stopping model deployment fails (#129090)

* Enable force inference endpoint deleting for invalid models and after stopping model deployment fails

* Update docs/changelog/129090.yaml

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
Dan Rubinstein 2025-07-16 12:44:48 -04:00 committed by GitHub
parent 037ddaa5c8
commit 9c6cf90456
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 255 additions and 3 deletions

View File

@ -0,0 +1,6 @@
pr: 129090
summary: Enable force inference endpoint deleting for invalid models and after stopping
model deployment fails
area: Machine Learning
type: enhancement
issues: []

View File

@ -23,6 +23,7 @@ import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.rest.RestStatus;
@ -128,10 +129,38 @@ public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeA
}
var service = serviceRegistry.getService(unparsedModel.service());
Model model;
if (service.isPresent()) {
var model = service.get()
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
service.get().stop(model, listener);
try {
model = service.get()
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
} catch (Exception e) {
if (request.isForceDelete()) {
listener.onResponse(true);
return;
} else {
listener.onFailure(
new ElasticsearchStatusException(
Strings.format(
"Failed to parse model configuration for inference endpoint [%s]",
request.getInferenceEndpointId()
),
RestStatus.INTERNAL_SERVER_ERROR,
e
)
);
return;
}
}
service.get().stop(model, listener.delegateResponse((l, e) -> {
if (request.isForceDelete()) {
l.onResponse(true);
} else {
l.onFailure(e);
}
}));
} else if (request.isForceDelete()) {
listener.onResponse(true);
} else {
listener.onFailure(
new ElasticsearchStatusException(

View File

@ -17,8 +17,10 @@ import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
@ -32,11 +34,17 @@ import java.util.Map;
import java.util.Optional;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
public class TransportDeleteInferenceEndpointActionTests extends ESTestCase {
@ -130,4 +138,213 @@ public class TransportDeleteInferenceEndpointActionTests extends ESTestCase {
assertTrue(response.isAcknowledged());
}
/**
 * Deleting an endpoint whose persisted configuration cannot be parsed must fail when the request
 * is not forced: the caller sees the "Failed to parse model configuration" error and the registry
 * is never asked to delete anything.
 */
public void testFailsToDeleteUnparsableEndpoint_WhenForceIsFalse() {
    var endpointId = randomAlphaOfLengthBetween(5, 10);
    var serviceId = randomAlphanumericOfLength(10);
    var type = randomFrom(TaskType.values());
    var unparsingService = mock(InferenceService.class);

    mockUnparsableModel(endpointId, serviceId, type, unparsingService);
    when(mockModelRegistry.containsDefaultConfigId(endpointId)).thenReturn(false);

    // NOTE(review): the third constructor argument is presumably the force flag (per the test name) - confirm.
    var request = new DeleteInferenceEndpointAction.Request(endpointId, type, false, false);
    PlainActionFuture<DeleteInferenceEndpointAction.Response> future = new PlainActionFuture<>();
    action.masterOperation(mock(Task.class), request, ClusterState.EMPTY_STATE, future);

    var failure = expectThrows(ElasticsearchStatusException.class, () -> future.actionGet(TIMEOUT));
    assertThat(failure.getMessage(), containsString("Failed to parse model configuration for inference endpoint"));

    // Only the lookup, default-id check and parse attempt may have happened.
    verify(mockModelRegistry).getModel(eq(endpointId), any());
    verify(mockInferenceServiceRegistry).getService(eq(serviceId));
    verify(mockModelRegistry).containsDefaultConfigId(eq(endpointId));
    verify(unparsingService).parsePersistedConfig(eq(endpointId), eq(type), any());
    verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, unparsingService);
}
/**
 * With the force option set, an endpoint whose persisted configuration cannot be parsed is still
 * removed from the registry and the response is acknowledged.
 */
public void testDeletesUnparsableEndpoint_WhenForceIsTrue() {
    var endpointId = randomAlphaOfLengthBetween(5, 10);
    var serviceId = randomAlphanumericOfLength(10);
    var type = randomFrom(TaskType.values());
    var unparsingService = mock(InferenceService.class);

    mockUnparsableModel(endpointId, serviceId, type, unparsingService);
    // The registry delete succeeds immediately.
    doAnswer(invocation -> {
        ActionListener<Boolean> deleteListener = invocation.getArgument(1);
        deleteListener.onResponse(true);
        return Void.TYPE;
    }).when(mockModelRegistry).deleteModel(eq(endpointId), any());

    // NOTE(review): the third constructor argument is presumably the force flag (per the test name) - confirm.
    var request = new DeleteInferenceEndpointAction.Request(endpointId, type, true, false);
    PlainActionFuture<DeleteInferenceEndpointAction.Response> future = new PlainActionFuture<>();
    action.masterOperation(mock(Task.class), request, ClusterState.EMPTY_STATE, future);

    assertTrue(future.actionGet(TIMEOUT).isAcknowledged());

    verify(mockModelRegistry).getModel(eq(endpointId), any());
    verify(mockInferenceServiceRegistry).getService(eq(serviceId));
    verify(unparsingService).parsePersistedConfig(eq(endpointId), eq(type), any());
    verify(mockModelRegistry).deleteModel(eq(endpointId), any());
    verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, unparsingService);
}
/**
 * Stubs the fixture so that the endpoint is found in the model registry and its service is
 * registered, but parsing the persisted configuration throws an
 * {@link ElasticsearchStatusException} with a random message.
 */
private void mockUnparsableModel(String inferenceEndpointId, String serviceName, TaskType taskType, InferenceService mockService) {
    // Registry lookup answers with an unparsed model carrying empty settings and secrets maps.
    doAnswer(invocation -> {
        ActionListener<UnparsedModel> modelListener = invocation.getArgument(1);
        var stored = new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of());
        modelListener.onResponse(stored);
        return Void.TYPE;
    }).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());

    var parseFailure = new ElasticsearchStatusException(randomAlphanumericOfLength(10), RestStatus.INTERNAL_SERVER_ERROR);
    doThrow(parseFailure).when(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());

    when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService));
}
/**
 * With the force option set, an endpoint whose service is not registered is still removed from
 * the registry and the response is acknowledged.
 */
public void testDeletesEndpointWithNoService_WhenForceIsTrue() {
    var endpointId = randomAlphaOfLengthBetween(5, 10);
    var serviceId = randomAlphanumericOfLength(10);
    var type = randomFrom(TaskType.values());

    mockNoService(endpointId, serviceId, type);
    // The registry delete succeeds immediately, whatever id it is called with.
    doAnswer(invocation -> {
        ActionListener<Boolean> deleteListener = invocation.getArgument(1);
        deleteListener.onResponse(true);
        return Void.TYPE;
    }).when(mockModelRegistry).deleteModel(anyString(), any());

    // NOTE(review): the third constructor argument is presumably the force flag (per the test name) - confirm.
    var request = new DeleteInferenceEndpointAction.Request(endpointId, type, true, false);
    PlainActionFuture<DeleteInferenceEndpointAction.Response> future = new PlainActionFuture<>();
    action.masterOperation(mock(Task.class), request, ClusterState.EMPTY_STATE, future);

    assertTrue(future.actionGet(TIMEOUT).isAcknowledged());

    verify(mockModelRegistry).getModel(eq(endpointId), any());
    verify(mockInferenceServiceRegistry).getService(eq(serviceId));
    verify(mockModelRegistry).deleteModel(eq(endpointId), any());
    verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry);
}
/**
 * Without the force option, deleting an endpoint whose service is not registered fails with the
 * "No service found for this inference endpoint" error and nothing is deleted from the registry.
 */
public void testFailsToDeleteEndpointWithNoService_WhenForceIsFalse() {
    var endpointId = randomAlphaOfLengthBetween(5, 10);
    var serviceId = randomAlphanumericOfLength(10);
    var type = randomFrom(TaskType.values());

    mockNoService(endpointId, serviceId, type);
    when(mockModelRegistry.containsDefaultConfigId(endpointId)).thenReturn(false);

    // NOTE(review): the third constructor argument is presumably the force flag (per the test name) - confirm.
    var request = new DeleteInferenceEndpointAction.Request(endpointId, type, false, false);
    PlainActionFuture<DeleteInferenceEndpointAction.Response> future = new PlainActionFuture<>();
    action.masterOperation(mock(Task.class), request, ClusterState.EMPTY_STATE, future);

    var failure = expectThrows(ElasticsearchStatusException.class, () -> future.actionGet(TIMEOUT));
    assertThat(failure.getMessage(), containsString("No service found for this inference endpoint"));

    verify(mockModelRegistry).getModel(eq(endpointId), any());
    verify(mockInferenceServiceRegistry).getService(eq(serviceId));
    verify(mockModelRegistry).containsDefaultConfigId(eq(endpointId));
    verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry);
}
/**
 * Stubs the fixture so that the endpoint is found in the model registry while the service
 * registry reports no service under {@code serviceName}.
 */
private void mockNoService(String inferenceEndpointId, String serviceName, TaskType taskType) {
    // Registry lookup answers with an unparsed model carrying empty settings and secrets maps.
    doAnswer(invocation -> {
        ActionListener<UnparsedModel> modelListener = invocation.getArgument(1);
        var stored = new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of());
        modelListener.onResponse(stored);
        return Void.TYPE;
    }).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());

    when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.empty());
}
/**
 * Without the force option, a failure reported by the service's {@code stop} call is propagated to
 * the caller ("Failed to stop model deployment") and nothing is deleted from the registry.
 */
public void testFailsToDeleteEndpointIfModelDeploymentStopFails_WhenForceIsFalse() {
    var endpointId = randomAlphaOfLengthBetween(5, 10);
    var serviceId = randomAlphanumericOfLength(10);
    var type = randomFrom(TaskType.values());
    var failingService = mock(InferenceService.class);
    var parsedModel = mock(Model.class);

    mockStopDeploymentFails(endpointId, serviceId, type, failingService, parsedModel);
    when(mockModelRegistry.containsDefaultConfigId(endpointId)).thenReturn(false);

    // NOTE(review): the third constructor argument is presumably the force flag (per the test name) - confirm.
    var request = new DeleteInferenceEndpointAction.Request(endpointId, type, false, false);
    PlainActionFuture<DeleteInferenceEndpointAction.Response> future = new PlainActionFuture<>();
    action.masterOperation(mock(Task.class), request, ClusterState.EMPTY_STATE, future);

    var failure = expectThrows(ElasticsearchStatusException.class, () -> future.actionGet(TIMEOUT));
    assertThat(failure.getMessage(), containsString("Failed to stop model deployment"));

    verify(mockModelRegistry).getModel(eq(endpointId), any());
    verify(mockInferenceServiceRegistry).getService(eq(serviceId));
    verify(mockModelRegistry).containsDefaultConfigId(eq(endpointId));
    verify(failingService).parsePersistedConfig(eq(endpointId), eq(type), any());
    verify(failingService).stop(eq(parsedModel), any());
    verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, failingService, parsedModel);
}
/**
 * With the force option set, a failure reported by the service's {@code stop} call does not block
 * the delete: the endpoint is removed from the registry and the response is acknowledged.
 */
public void testDeletesEndpointIfModelDeploymentStopFails_WhenForceIsTrue() {
    var endpointId = randomAlphaOfLengthBetween(5, 10);
    var serviceId = randomAlphanumericOfLength(10);
    var type = randomFrom(TaskType.values());
    var failingService = mock(InferenceService.class);
    var parsedModel = mock(Model.class);

    mockStopDeploymentFails(endpointId, serviceId, type, failingService, parsedModel);
    // The registry delete succeeds immediately.
    doAnswer(invocation -> {
        ActionListener<Boolean> deleteListener = invocation.getArgument(1);
        deleteListener.onResponse(true);
        return Void.TYPE;
    }).when(mockModelRegistry).deleteModel(eq(endpointId), any());

    // NOTE(review): the third constructor argument is presumably the force flag (per the test name) - confirm.
    var request = new DeleteInferenceEndpointAction.Request(endpointId, type, true, false);
    PlainActionFuture<DeleteInferenceEndpointAction.Response> future = new PlainActionFuture<>();
    action.masterOperation(mock(Task.class), request, ClusterState.EMPTY_STATE, future);

    assertTrue(future.actionGet(TIMEOUT).isAcknowledged());

    verify(mockModelRegistry).getModel(eq(endpointId), any());
    verify(mockInferenceServiceRegistry).getService(eq(serviceId));
    verify(failingService).parsePersistedConfig(eq(endpointId), eq(type), any());
    verify(failingService).stop(eq(parsedModel), any());
    verify(mockModelRegistry).deleteModel(eq(endpointId), any());
    verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, failingService, parsedModel);
}
/**
 * Stubs the fixture so that the endpoint is found, its service is registered and parses the
 * persisted configuration into {@code mockModel}, but the service's {@code stop} call reports a
 * failure ("Failed to stop model deployment") to its listener.
 */
private void mockStopDeploymentFails(
    String inferenceEndpointId,
    String serviceName,
    TaskType taskType,
    InferenceService mockService,
    Model mockModel
) {
    // Registry lookup answers with an unparsed model carrying empty settings and secrets maps.
    doAnswer(invocation -> {
        ActionListener<UnparsedModel> modelListener = invocation.getArgument(1);
        var stored = new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of());
        modelListener.onResponse(stored);
        return Void.TYPE;
    }).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());

    when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService));
    doReturn(mockModel).when(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());

    // Stopping the deployment fails asynchronously via the listener rather than by throwing.
    doAnswer(invocation -> {
        ActionListener<Boolean> stopListener = invocation.getArgument(1);
        var stopFailure = new ElasticsearchStatusException("Failed to stop model deployment", RestStatus.INTERNAL_SERVER_ERROR);
        stopListener.onFailure(stopFailure);
        return Void.TYPE;
    }).when(mockService).stop(eq(mockModel), any());
}
}