[ML] Fixing bug with TransportPutModelAction listener and adding timeout to request (#126805)

* Fixing bug with listener and adding timeout

* Update docs/changelog/126805.yaml

* Fixing tests

* Fixing writeTo
This commit is contained in:
Jonathan Buttner 2025-05-06 15:39:24 -04:00 committed by GitHub
parent 1a1763c591
commit 4c507e27d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 84 additions and 35 deletions

View File

@ -0,0 +1,5 @@
pr: 126805
summary: Adding timeout to request for creating inference endpoint
area: Machine Learning
type: bug
issues: []

View File

@ -172,6 +172,7 @@ public class TransportVersions {
public static final TransportVersion INTRODUCE_FAILURES_LIFECYCLE_BACKPORT_8_19 = def(8_841_0_25);
public static final TransportVersion INTRODUCE_FAILURES_DEFAULT_RETENTION_BACKPORT_8_19 = def(8_841_0_26);
public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19 = def(8_841_0_27);
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@ -248,6 +249,7 @@ public class TransportVersions {
public static final TransportVersion INTRODUCE_FAILURES_DEFAULT_RETENTION = def(9_071_0_00);
public static final TransportVersion FILE_SETTINGS_HEALTH_INFO = def(9_072_0_00);
public static final TransportVersion FIELD_CAPS_ADD_CLUSTER_ALIAS = def(9_073_0_00);
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_074_0_00);
/*
* STOP! READ THIS FIRST! No, really,

View File

@ -7,6 +7,7 @@
package org.elasticsearch.xpack.core.inference.action;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
@ -15,6 +16,7 @@ import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ToXContentObject;
@ -41,13 +43,15 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
private final String inferenceEntityId;
private final BytesReference content;
private final XContentType contentType;
private final TimeValue timeout;
public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType) {
/**
 * Creates a request to put an inference endpoint.
 *
 * @param taskType          task type stored on the request
 * @param inferenceEntityId identifier of the inference endpoint stored on the request
 * @param content           raw request body
 * @param contentType       content type describing {@code content}
 * @param timeout           timeout stored on the request and returned by {@link #getTimeout()}
 */
public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType, TimeValue timeout) {
// The superclass always receives these two fixed constants; the caller-supplied
// timeout is kept in a separate field and is not passed to the superclass.
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
this.content = content;
this.contentType = contentType;
this.timeout = timeout;
}
public Request(StreamInput in) throws IOException {
@ -56,6 +60,13 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
this.taskType = TaskType.fromStream(in);
this.content = in.readBytesReference();
this.contentType = in.readEnum(XContentType.class);
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
|| in.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
this.timeout = in.readTimeValue();
} else {
this.timeout = InferenceAction.Request.DEFAULT_TIMEOUT;
}
}
public TaskType getTaskType() {
@ -74,6 +85,10 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
return contentType;
}
/**
 * Returns the timeout carried by this request.
 * <p>
 * This is the value passed to the constructor, or, when the request was read from a stream
 * whose transport version does not include the timeout field,
 * {@code InferenceAction.Request.DEFAULT_TIMEOUT}.
 *
 * @return the request timeout
 */
public TimeValue getTimeout() {
return timeout;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
@ -81,6 +96,11 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
taskType.writeTo(out);
out.writeBytesReference(content);
XContentHelper.writeTo(out, contentType);
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
|| out.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
out.writeTimeValue(timeout);
}
}
@Override
@ -105,12 +125,13 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
return taskType == request.taskType
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
&& Objects.equals(content, request.content)
&& contentType == request.contentType;
&& contentType == request.contentType
&& Objects.equals(timeout, request.timeout);
}
@Override
public int hashCode() {
return Objects.hash(taskType, inferenceEntityId, content, contentType);
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout);
}
}

View File

@ -34,13 +34,25 @@ public class PutInferenceModelActionTests extends ESTestCase {
public void testValidate() {
// valid model ID
var request = new PutInferenceModelAction.Request(TASK_TYPE, MODEL_ID + "_-0", BYTES, X_CONTENT_TYPE);
var request = new PutInferenceModelAction.Request(
TASK_TYPE,
MODEL_ID + "_-0",
BYTES,
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
ActionRequestValidationException validationException = request.validate();
assertNull(validationException);
// invalid model IDs
var invalidRequest = new PutInferenceModelAction.Request(TASK_TYPE, "", BYTES, X_CONTENT_TYPE);
var invalidRequest = new PutInferenceModelAction.Request(
TASK_TYPE,
"",
BYTES,
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
validationException = invalidRequest.validate();
assertNotNull(validationException);
@ -48,12 +60,19 @@ public class PutInferenceModelActionTests extends ESTestCase {
TASK_TYPE,
randomAlphaOfLengthBetween(1, 10) + randomFrom(MlStringsTests.SOME_INVALID_CHARS),
BYTES,
X_CONTENT_TYPE
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
validationException = invalidRequest2.validate();
assertNotNull(validationException);
var invalidRequest3 = new PutInferenceModelAction.Request(TASK_TYPE, null, BYTES, X_CONTENT_TYPE);
var invalidRequest3 = new PutInferenceModelAction.Request(
TASK_TYPE,
null,
BYTES,
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
validationException = invalidRequest3.validate();
assertNotNull(validationException);
}

View File

@ -177,7 +177,7 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
return;
}
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener);
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener);
}
private void parseAndStoreModel(

View File

@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import java.util.List;
import static org.elasticsearch.rest.RestRequest.Method.PUT;
import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout;
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH;
import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH;
@ -49,8 +50,15 @@ public class RestPutInferenceModelAction extends BaseRestHandler {
taskType = TaskType.ANY; // task type must be defined in the body
}
var inferTimeout = parseTimeout(restRequest);
var content = restRequest.requiredContent();
var request = new PutInferenceModelAction.Request(taskType, inferenceEntityId, content, restRequest.getXContentType());
var request = new PutInferenceModelAction.Request(
taskType,
inferenceEntityId,
content,
restRequest.getXContentType(),
inferTimeout
);
return channel -> client.execute(
PutInferenceModelAction.INSTANCE,
request,

View File

@ -7,13 +7,16 @@
package org.elasticsearch.xpack.inference.action;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase<PutInferenceModelAction.Request> {
public class PutInferenceModelRequestTests extends AbstractBWCWireSerializationTestCase<PutInferenceModelAction.Request> {
@Override
protected Writeable.Reader<PutInferenceModelAction.Request> instanceReader() {
return PutInferenceModelAction.Request::new;
@ -25,38 +28,29 @@ public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCa
randomFrom(TaskType.values()),
randomAlphaOfLength(6),
randomBytesReference(50),
randomFrom(XContentType.values())
randomFrom(XContentType.values()),
randomTimeValue()
);
}
@Override
protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) {
return switch (randomIntBetween(0, 3)) {
case 0 -> new PutInferenceModelAction.Request(
TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length],
instance.getInferenceEntityId(),
instance.getContent(),
instance.getContentType()
);
case 1 -> new PutInferenceModelAction.Request(
instance.getTaskType(),
instance.getInferenceEntityId() + "foo",
instance.getContent(),
instance.getContentType()
);
case 2 -> new PutInferenceModelAction.Request(
instance.getTaskType(),
instance.getInferenceEntityId(),
randomBytesReference(instance.getContent().length() + 1),
instance.getContentType()
);
case 3 -> new PutInferenceModelAction.Request(
return randomValueOtherThan(instance, this::createTestInstance);
}
@Override
protected PutInferenceModelAction.Request mutateInstanceForVersion(PutInferenceModelAction.Request instance, TransportVersion version) {
if (version.onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
|| version.isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
return instance;
} else {
return new PutInferenceModelAction.Request(
instance.getTaskType(),
instance.getInferenceEntityId(),
instance.getContent(),
XContentType.values()[(instance.getContentType().ordinal() + 1) % XContentType.values().length]
instance.getContentType(),
InferenceAction.Request.DEFAULT_TIMEOUT
);
default -> throw new IllegalStateException();
};
}
}
}