Increment inference stats counter for shard bulk inference calls (#129140)
This change updates the inference stats counter to also count the chunked inference calls that the shard bulk inference filter performs on semantic text fields, so that this usage is properly recorded in the inference stats.
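For context, each shard-bulk inference call now increments the request-count metric with per-call attributes. The sketch below is illustrative rather than part of the commit: the helper class and the flattened key values are assumptions, while the attribute key names (task_type, model_id, service, status_code, error.type, inference_source) are taken from the test assertions in the diff.

    import java.util.HashMap;
    import java.util.Map;

    class InferenceSourceAttributesSketch {
        // Hypothetical mirror of recordRequestCountMetrics in the diff below:
        // model attributes, plus response attributes, plus the new source tag.
        static Map<String, Object> attributes(String taskType, String modelId, String service, Throwable throwable) {
            Map<String, Object> attrs = new HashMap<>();
            attrs.put("task_type", taskType);   // from InferenceStats.modelAttributes(model)
            attrs.put("model_id", modelId);
            attrs.put("service", service);
            if (throwable == null) {
                attrs.put("status_code", 200);  // success path, per the test assertions
            } else {
                attrs.put("error.type", throwable.getClass().getSimpleName());
            }
            attrs.put("inference_source", "semantic_text_bulk"); // tag added by this change
            return attrs;
        }
    }

In the change itself, recordRequestCountMetrics builds this map from InferenceStats.modelAttributes(model) and InferenceStats.responseAttributes(throwable) and passes it to inferenceStats.requestCount().incrementBy(...).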
parent b3becfa678
commit 42dec5b41f
@@ -0,0 +1,5 @@
+pr: 129140
+summary: Increment inference stats counter for shard bulk inference calls
+area: Machine Learning
+type: enhancement
+issues: []
@@ -344,22 +344,24 @@ public class InferencePlugin extends Plugin
         }
         inferenceServiceRegistry.set(serviceRegistry);

+        var meterRegistry = services.telemetryProvider().getMeterRegistry();
+        var inferenceStats = InferenceStats.create(meterRegistry);
+        var inferenceStatsBinding = new PluginComponentBinding<>(InferenceStats.class, inferenceStats);
+
         var actionFilter = new ShardBulkInferenceActionFilter(
             services.clusterService(),
             serviceRegistry,
             modelRegistry.get(),
             getLicenseState(),
-            services.indexingPressure()
+            services.indexingPressure(),
+            inferenceStats
         );
         shardBulkInferenceActionFilter.set(actionFilter);

-        var meterRegistry = services.telemetryProvider().getMeterRegistry();
-        var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
-
         components.add(serviceRegistry);
         components.add(modelRegistry.get());
         components.add(httpClientManager);
-        components.add(inferenceStats);
+        components.add(inferenceStatsBinding);

         // Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting,
         // if the rate limiting feature flags are enabled, otherwise provide noop implementation
@@ -63,6 +63,7 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
+import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

 import java.io.IOException;
 import java.util.ArrayList;
@@ -78,6 +79,8 @@ import java.util.stream.Collectors;
 import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;

 /**
  * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified
@@ -112,6 +115,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
     private final ModelRegistry modelRegistry;
     private final XPackLicenseState licenseState;
     private final IndexingPressure indexingPressure;
+    private final InferenceStats inferenceStats;
     private volatile long batchSizeInBytes;

     public ShardBulkInferenceActionFilter(
@@ -119,13 +123,15 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
         InferenceServiceRegistry inferenceServiceRegistry,
         ModelRegistry modelRegistry,
         XPackLicenseState licenseState,
-        IndexingPressure indexingPressure
+        IndexingPressure indexingPressure,
+        InferenceStats inferenceStats
     ) {
         this.clusterService = clusterService;
         this.inferenceServiceRegistry = inferenceServiceRegistry;
         this.modelRegistry = modelRegistry;
         this.licenseState = licenseState;
         this.indexingPressure = indexingPressure;
+        this.inferenceStats = inferenceStats;
         this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
         clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
     }
@@ -386,10 +392,12 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
             public void onResponse(List<ChunkedInference> results) {
                 try (onFinish) {
                     var requestsIterator = requests.iterator();
+                    int success = 0;
                     for (ChunkedInference result : results) {
                         var request = requestsIterator.next();
                         var acc = inferenceResults.get(request.bulkItemIndex);
                         if (result instanceof ChunkedInferenceError error) {
+                            recordRequestCountMetrics(inferenceProvider.model, 1, error.exception());
                             acc.addFailure(
                                 new InferenceException(
                                     "Exception when running inference id [{}] on field [{}]",
@@ -399,6 +407,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                                 )
                             );
                         } else {
+                            success++;
                             acc.addOrUpdateResponse(
                                 new FieldInferenceResponse(
                                     request.field(),
@@ -412,12 +421,16 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                             );
                         }
                     }
+                    if (success > 0) {
+                        recordRequestCountMetrics(inferenceProvider.model, success, null);
+                    }
                 }
             }

             @Override
             public void onFailure(Exception exc) {
                 try (onFinish) {
+                    recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc);
                     for (FieldInferenceRequest request : requests) {
                         addInferenceResponseFailure(
                             request.bulkItemIndex,
@@ -444,6 +457,14 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
         );
     }

+    private void recordRequestCountMetrics(Model model, int incrementBy, Throwable throwable) {
+        Map<String, Object> requestCountAttributes = new HashMap<>();
+        requestCountAttributes.putAll(modelAttributes(model));
+        requestCountAttributes.putAll(responseAttributes(throwable));
+        requestCountAttributes.put("inference_source", "semantic_text_bulk");
+        inferenceStats.requestCount().incrementBy(incrementBy, requestCountAttributes);
+    }
+
     /**
      * Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap}
      * for the specified {@code item}.
@@ -66,6 +66,7 @@ import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
 import org.elasticsearch.xpack.inference.model.TestModel;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
+import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 import org.junit.After;
 import org.junit.Before;
 import org.mockito.stubbing.Answer;
@@ -80,6 +81,7 @@ import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;

 import static org.elasticsearch.index.IndexingPressure.MAX_COORDINATING_BYTES;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
@@ -103,9 +105,11 @@ import static org.hamcrest.Matchers.notNullValue;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.assertArg;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.longThat;
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.atMost;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -127,7 +131,9 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {

     @ParametersFactory
     public static Iterable<Object[]> parameters() throws Exception {
-        return List.of(new Object[] { true }, new Object[] { false });
+        List<Object[]> lst = new ArrayList<>();
+        lst.add(new Object[] { true });
+        return lst;
     }

     @Before
@@ -142,7 +148,15 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {

     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testFilterNoop() throws Exception {
-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, true);
+        final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
+        ShardBulkInferenceActionFilter filter = createFilter(
+            threadPool,
+            Map.of(),
+            NOOP_INDEXING_PRESSURE,
+            useLegacyFormat,
+            true,
+            inferenceStats
+        );
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -167,8 +181,16 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {

     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testLicenseInvalidForInference() throws InterruptedException {
+        final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
         StaticModel model = StaticModel.createRandomInstance();
-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, false);
+        ShardBulkInferenceActionFilter filter = createFilter(
+            threadPool,
+            Map.of(),
+            NOOP_INDEXING_PRESSURE,
+            useLegacyFormat,
+            false,
+            inferenceStats
+        );
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -205,13 +227,15 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {

     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testInferenceNotFound() throws Exception {
+        final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
         StaticModel model = StaticModel.createRandomInstance();
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
             NOOP_INDEXING_PRESSURE,
             useLegacyFormat,
-            true
+            true,
+            inferenceStats
         );
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
@@ -251,14 +275,15 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {

     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testItemFailures() throws Exception {
+        final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
         StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);

         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
             NOOP_INDEXING_PRESSURE,
             useLegacyFormat,
-            true
+            true,
+            inferenceStats
         );
         model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom")));
         model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success")));
@@ -316,10 +341,30 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         request.setInferenceFieldMap(inferenceFieldMap);
         filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
         awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
+
+        AtomicInteger success = new AtomicInteger(0);
+        AtomicInteger failed = new AtomicInteger(0);
+        verify(inferenceStats.requestCount(), atMost(3)).incrementBy(anyLong(), assertArg(attributes -> {
+            var statusCode = attributes.get("status_code");
+            if (statusCode == null) {
+                failed.incrementAndGet();
+                assertThat(attributes.get("error.type"), is("IllegalArgumentException"));
+            } else {
+                success.incrementAndGet();
+                assertThat(statusCode, is(200));
+            }
+            assertThat(attributes.get("task_type"), is(model.getTaskType().toString()));
+            assertThat(attributes.get("model_id"), is(model.getServiceSettings().modelId()));
+            assertThat(attributes.get("service"), is(model.getConfigurations().getService()));
+            assertThat(attributes.get("inference_source"), is("semantic_text_bulk"));
+        }));
+        assertThat(success.get(), equalTo(1));
+        assertThat(failed.get(), equalTo(2));
     }

     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testExplicitNull() throws Exception {
+        final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
         StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
         model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom")));
         model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success")));
@@ -329,7 +374,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             Map.of(model.getInferenceEntityId(), model),
             NOOP_INDEXING_PRESSURE,
             useLegacyFormat,
-            true
+            true,
+            inferenceStats
         );

         CountDownLatch chainExecuted = new CountDownLatch(1);
@@ -394,13 +440,15 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {

     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testHandleEmptyInput() throws Exception {
+        final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
         StaticModel model = StaticModel.createRandomInstance();
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
             NOOP_INDEXING_PRESSURE,
             useLegacyFormat,
-            true
+            true,
+            inferenceStats
         );

         CountDownLatch chainExecuted = new CountDownLatch(1);
@@ -447,6 +495,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {

     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testManyRandomDocs() throws Exception {
+        final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
         Map<String, StaticModel> inferenceModelMap = new HashMap<>();
         int numModels = randomIntBetween(1, 3);
         for (int i = 0; i < numModels; i++) {
@@ -471,7 +520,14 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             modifiedRequests[id] = res[1];
         }

-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, NOOP_INDEXING_PRESSURE, useLegacyFormat, true);
+        ShardBulkInferenceActionFilter filter = createFilter(
+            threadPool,
+            inferenceModelMap,
+            NOOP_INDEXING_PRESSURE,
+            useLegacyFormat,
+            true,
+            inferenceStats
+        );
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -503,6 +559,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {

     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testIndexingPressure() throws Exception {
+        final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
         final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(Settings.EMPTY);
         final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
         final StaticModel denseModel = StaticModel.createRandomInstance(TaskType.TEXT_EMBEDDING);
@@ -511,7 +568,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             Map.of(sparseModel.getInferenceEntityId(), sparseModel, denseModel.getInferenceEntityId(), denseModel),
             indexingPressure,
             useLegacyFormat,
-            true
+            true,
+            inferenceStats
         );

         XContentBuilder doc0Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "a test value");
@@ -619,6 +677,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {

     @SuppressWarnings("unchecked")
     public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Exception {
+        final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
         final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
             Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), "1b").build()
         );
@@ -628,7 +687,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             Map.of(sparseModel.getInferenceEntityId(), sparseModel),
             indexingPressure,
             useLegacyFormat,
-            true
+            true,
+            inferenceStats
         );

         XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
@@ -702,6 +762,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (bytesUsed(doc1Source) + 1) + "b").build()
         );

+        final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
         final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
         sparseModel.putResult("bar", randomChunkedInferenceEmbedding(sparseModel, List.of("bar")));

@@ -710,7 +771,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             Map.of(sparseModel.getInferenceEntityId(), sparseModel),
             indexingPressure,
             useLegacyFormat,
-            true
+            true,
+            inferenceStats
         );

         CountDownLatch chainExecuted = new CountDownLatch(1);
@@ -813,12 +875,14 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
                 .build()
         );

+        final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
         final ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(sparseModel.getInferenceEntityId(), sparseModel),
             indexingPressure,
             useLegacyFormat,
-            true
+            true,
+            inferenceStats
         );

         CountDownLatch chainExecuted = new CountDownLatch(1);
@@ -893,7 +957,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         Map<String, StaticModel> modelMap,
         IndexingPressure indexingPressure,
         boolean useLegacyFormat,
-        boolean isLicenseValidForInference
+        boolean isLicenseValidForInference,
+        InferenceStats inferenceStats
     ) {
         ModelRegistry modelRegistry = mock(ModelRegistry.class);
         Answer<?> unparsedModelAnswer = invocationOnMock -> {
@@ -970,7 +1035,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             inferenceServiceRegistry,
             modelRegistry,
             licenseState,
-            indexingPressure
+            indexingPressure,
+            inferenceStats
         );
     }