[ML] Fixing bug with TransportPutModelAction listener and adding timeout to request (#126805)
* Fixing bug with listener and adding timeout * Update docs/changelog/126805.yaml * Fixing tests * Fixing writeTo
This commit is contained in:
parent
1a1763c591
commit
4c507e27d9
|
@ -0,0 +1,5 @@
|
|||
pr: 126805
|
||||
summary: Adding timeout to request for creating inference endpoint
|
||||
area: Machine Learning
|
||||
type: bug
|
||||
issues: []
|
|
@ -172,6 +172,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion INTRODUCE_FAILURES_LIFECYCLE_BACKPORT_8_19 = def(8_841_0_25);
|
||||
public static final TransportVersion INTRODUCE_FAILURES_DEFAULT_RETENTION_BACKPORT_8_19 = def(8_841_0_26);
|
||||
public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19 = def(8_841_0_27);
|
||||
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
|
||||
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
|
||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
|
||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
|
||||
|
@ -248,6 +249,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion INTRODUCE_FAILURES_DEFAULT_RETENTION = def(9_071_0_00);
|
||||
public static final TransportVersion FILE_SETTINGS_HEALTH_INFO = def(9_072_0_00);
|
||||
public static final TransportVersion FIELD_CAPS_ADD_CLUSTER_ALIAS = def(9_073_0_00);
|
||||
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_074_0_00);
|
||||
|
||||
/*
|
||||
* STOP! READ THIS FIRST! No, really,
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
package org.elasticsearch.xpack.core.inference.action;
|
||||
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.action.ActionRequestValidationException;
|
||||
import org.elasticsearch.action.ActionResponse;
|
||||
import org.elasticsearch.action.ActionType;
|
||||
|
@ -15,6 +16,7 @@ import org.elasticsearch.common.bytes.BytesReference;
|
|||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xcontent.ToXContentObject;
|
||||
|
@ -41,13 +43,15 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
|
|||
private final String inferenceEntityId;
|
||||
private final BytesReference content;
|
||||
private final XContentType contentType;
|
||||
private final TimeValue timeout;
|
||||
|
||||
public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType) {
|
||||
public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType, TimeValue timeout) {
|
||||
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
|
||||
this.taskType = taskType;
|
||||
this.inferenceEntityId = inferenceEntityId;
|
||||
this.content = content;
|
||||
this.contentType = contentType;
|
||||
this.timeout = timeout;
|
||||
}
|
||||
|
||||
public Request(StreamInput in) throws IOException {
|
||||
|
@ -56,6 +60,13 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
|
|||
this.taskType = TaskType.fromStream(in);
|
||||
this.content = in.readBytesReference();
|
||||
this.contentType = in.readEnum(XContentType.class);
|
||||
|
||||
if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
|
||||
|| in.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
|
||||
this.timeout = in.readTimeValue();
|
||||
} else {
|
||||
this.timeout = InferenceAction.Request.DEFAULT_TIMEOUT;
|
||||
}
|
||||
}
|
||||
|
||||
public TaskType getTaskType() {
|
||||
|
@ -74,6 +85,10 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
|
|||
return contentType;
|
||||
}
|
||||
|
||||
/**
 * Returns the timeout carried by this request.
 * <p>
 * This is the value passed to the constructor, or the value read from the stream. When the
 * stream's transport version does not include the timeout field, the stream constructor
 * falls back to {@code InferenceAction.Request.DEFAULT_TIMEOUT}.
 *
 * @return the timeout for this put-endpoint request
 */
public TimeValue getTimeout() {
|
||||
return timeout;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
|
@ -81,6 +96,11 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
|
|||
taskType.writeTo(out);
|
||||
out.writeBytesReference(content);
|
||||
XContentHelper.writeTo(out, contentType);
|
||||
|
||||
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
|
||||
|| out.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
|
||||
out.writeTimeValue(timeout);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -105,12 +125,13 @@ public class PutInferenceModelAction extends ActionType<PutInferenceModelAction.
|
|||
return taskType == request.taskType
|
||||
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
|
||||
&& Objects.equals(content, request.content)
|
||||
&& contentType == request.contentType;
|
||||
&& contentType == request.contentType
|
||||
&& Objects.equals(timeout, request.timeout);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(taskType, inferenceEntityId, content, contentType);
|
||||
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -34,13 +34,25 @@ public class PutInferenceModelActionTests extends ESTestCase {
|
|||
|
||||
public void testValidate() {
|
||||
// valid model ID
|
||||
var request = new PutInferenceModelAction.Request(TASK_TYPE, MODEL_ID + "_-0", BYTES, X_CONTENT_TYPE);
|
||||
var request = new PutInferenceModelAction.Request(
|
||||
TASK_TYPE,
|
||||
MODEL_ID + "_-0",
|
||||
BYTES,
|
||||
X_CONTENT_TYPE,
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT
|
||||
);
|
||||
ActionRequestValidationException validationException = request.validate();
|
||||
assertNull(validationException);
|
||||
|
||||
// invalid model IDs
|
||||
|
||||
var invalidRequest = new PutInferenceModelAction.Request(TASK_TYPE, "", BYTES, X_CONTENT_TYPE);
|
||||
var invalidRequest = new PutInferenceModelAction.Request(
|
||||
TASK_TYPE,
|
||||
"",
|
||||
BYTES,
|
||||
X_CONTENT_TYPE,
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT
|
||||
);
|
||||
validationException = invalidRequest.validate();
|
||||
assertNotNull(validationException);
|
||||
|
||||
|
@ -48,12 +60,19 @@ public class PutInferenceModelActionTests extends ESTestCase {
|
|||
TASK_TYPE,
|
||||
randomAlphaOfLengthBetween(1, 10) + randomFrom(MlStringsTests.SOME_INVALID_CHARS),
|
||||
BYTES,
|
||||
X_CONTENT_TYPE
|
||||
X_CONTENT_TYPE,
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT
|
||||
);
|
||||
validationException = invalidRequest2.validate();
|
||||
assertNotNull(validationException);
|
||||
|
||||
var invalidRequest3 = new PutInferenceModelAction.Request(TASK_TYPE, null, BYTES, X_CONTENT_TYPE);
|
||||
var invalidRequest3 = new PutInferenceModelAction.Request(
|
||||
TASK_TYPE,
|
||||
null,
|
||||
BYTES,
|
||||
X_CONTENT_TYPE,
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT
|
||||
);
|
||||
validationException = invalidRequest3.validate();
|
||||
assertNotNull(validationException);
|
||||
}
|
||||
|
|
|
@ -177,7 +177,7 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
|
|||
return;
|
||||
}
|
||||
|
||||
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener);
|
||||
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener);
|
||||
}
|
||||
|
||||
private void parseAndStoreModel(
|
||||
|
|
|
@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
|
|||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.rest.RestRequest.Method.PUT;
|
||||
import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout;
|
||||
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
|
||||
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH;
|
||||
import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH;
|
||||
|
@ -49,8 +50,15 @@ public class RestPutInferenceModelAction extends BaseRestHandler {
|
|||
taskType = TaskType.ANY; // task type must be defined in the body
|
||||
}
|
||||
|
||||
var inferTimeout = parseTimeout(restRequest);
|
||||
var content = restRequest.requiredContent();
|
||||
var request = new PutInferenceModelAction.Request(taskType, inferenceEntityId, content, restRequest.getXContentType());
|
||||
var request = new PutInferenceModelAction.Request(
|
||||
taskType,
|
||||
inferenceEntityId,
|
||||
content,
|
||||
restRequest.getXContentType(),
|
||||
inferTimeout
|
||||
);
|
||||
return channel -> client.execute(
|
||||
PutInferenceModelAction.INSTANCE,
|
||||
request,
|
||||
|
|
|
@ -7,13 +7,16 @@
|
|||
|
||||
package org.elasticsearch.xpack.inference.action;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
||||
|
||||
public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase<PutInferenceModelAction.Request> {
|
||||
public class PutInferenceModelRequestTests extends AbstractBWCWireSerializationTestCase<PutInferenceModelAction.Request> {
|
||||
@Override
|
||||
protected Writeable.Reader<PutInferenceModelAction.Request> instanceReader() {
|
||||
return PutInferenceModelAction.Request::new;
|
||||
|
@ -25,38 +28,29 @@ public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCa
|
|||
randomFrom(TaskType.values()),
|
||||
randomAlphaOfLength(6),
|
||||
randomBytesReference(50),
|
||||
randomFrom(XContentType.values())
|
||||
randomFrom(XContentType.values()),
|
||||
randomTimeValue()
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) {
|
||||
return switch (randomIntBetween(0, 3)) {
|
||||
case 0 -> new PutInferenceModelAction.Request(
|
||||
TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length],
|
||||
instance.getInferenceEntityId(),
|
||||
instance.getContent(),
|
||||
instance.getContentType()
|
||||
);
|
||||
case 1 -> new PutInferenceModelAction.Request(
|
||||
instance.getTaskType(),
|
||||
instance.getInferenceEntityId() + "foo",
|
||||
instance.getContent(),
|
||||
instance.getContentType()
|
||||
);
|
||||
case 2 -> new PutInferenceModelAction.Request(
|
||||
instance.getTaskType(),
|
||||
instance.getInferenceEntityId(),
|
||||
randomBytesReference(instance.getContent().length() + 1),
|
||||
instance.getContentType()
|
||||
);
|
||||
case 3 -> new PutInferenceModelAction.Request(
|
||||
return randomValueOtherThan(instance, this::createTestInstance);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected PutInferenceModelAction.Request mutateInstanceForVersion(PutInferenceModelAction.Request instance, TransportVersion version) {
|
||||
if (version.onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
|
||||
|| version.isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
|
||||
return instance;
|
||||
} else {
|
||||
return new PutInferenceModelAction.Request(
|
||||
instance.getTaskType(),
|
||||
instance.getInferenceEntityId(),
|
||||
instance.getContent(),
|
||||
XContentType.values()[(instance.getContentType().ordinal() + 1) % XContentType.values().length]
|
||||
instance.getContentType(),
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT
|
||||
);
|
||||
default -> throw new IllegalStateException();
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue