Increment inference stats counter for shard bulk inference calls (#129140)

This change updates the inference stats counter to include chunked inference calls performed by the shard bulk inference filter on all semantic text fields.
It ensures that usage of inference on semantic text fields is properly recorded in the stats.
This commit is contained in:
Jim Ferenczi, 2025-06-10 15:30:33 +01:00 (committed by GitHub)
parent b3becfa678
commit 42dec5b41f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 115 additions and 21 deletions

View File

@ -0,0 +1,5 @@
pr: 129140
summary: Increment inference stats counter for shard bulk inference calls
area: Machine Learning
type: enhancement
issues: []

View File

@ -344,22 +344,24 @@ public class InferencePlugin extends Plugin
}
inferenceServiceRegistry.set(serviceRegistry);
var meterRegistry = services.telemetryProvider().getMeterRegistry();
var inferenceStats = InferenceStats.create(meterRegistry);
var inferenceStatsBinding = new PluginComponentBinding<>(InferenceStats.class, inferenceStats);
var actionFilter = new ShardBulkInferenceActionFilter(
services.clusterService(),
serviceRegistry,
modelRegistry.get(),
getLicenseState(),
services.indexingPressure()
services.indexingPressure(),
inferenceStats
);
shardBulkInferenceActionFilter.set(actionFilter);
var meterRegistry = services.telemetryProvider().getMeterRegistry();
var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
components.add(serviceRegistry);
components.add(modelRegistry.get());
components.add(httpClientManager);
components.add(inferenceStats);
components.add(inferenceStatsBinding);
// Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting,
// if the rate limiting feature flags are enabled, otherwise provide noop implementation

View File

@ -63,6 +63,7 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
import java.io.IOException;
import java.util.ArrayList;
@ -78,6 +79,8 @@ import java.util.stream.Collectors;
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
/**
* A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified
@ -112,6 +115,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
private final ModelRegistry modelRegistry;
private final XPackLicenseState licenseState;
private final IndexingPressure indexingPressure;
private final InferenceStats inferenceStats;
private volatile long batchSizeInBytes;
public ShardBulkInferenceActionFilter(
@ -119,13 +123,15 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
InferenceServiceRegistry inferenceServiceRegistry,
ModelRegistry modelRegistry,
XPackLicenseState licenseState,
IndexingPressure indexingPressure
IndexingPressure indexingPressure,
InferenceStats inferenceStats
) {
this.clusterService = clusterService;
this.inferenceServiceRegistry = inferenceServiceRegistry;
this.modelRegistry = modelRegistry;
this.licenseState = licenseState;
this.indexingPressure = indexingPressure;
this.inferenceStats = inferenceStats;
this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
}
@ -386,10 +392,12 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
public void onResponse(List<ChunkedInference> results) {
try (onFinish) {
var requestsIterator = requests.iterator();
int success = 0;
for (ChunkedInference result : results) {
var request = requestsIterator.next();
var acc = inferenceResults.get(request.bulkItemIndex);
if (result instanceof ChunkedInferenceError error) {
recordRequestCountMetrics(inferenceProvider.model, 1, error.exception());
acc.addFailure(
new InferenceException(
"Exception when running inference id [{}] on field [{}]",
@ -399,6 +407,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
)
);
} else {
success++;
acc.addOrUpdateResponse(
new FieldInferenceResponse(
request.field(),
@ -412,12 +421,16 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
);
}
}
if (success > 0) {
recordRequestCountMetrics(inferenceProvider.model, success, null);
}
}
}
@Override
public void onFailure(Exception exc) {
try (onFinish) {
recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc);
for (FieldInferenceRequest request : requests) {
addInferenceResponseFailure(
request.bulkItemIndex,
@ -444,6 +457,14 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
);
}
/**
 * Records a bump of the inference request-count metric for the given model.
 * <p>
 * The counter is tagged with the model's identifying attributes, the outcome
 * attributes derived from {@code throwable} (a {@code null} throwable marks a
 * successful batch), and a fixed {@code inference_source} attribute so that
 * calls originating from the shard bulk inference filter can be distinguished
 * from direct inference API usage.
 *
 * @param model       the model the inference calls were executed against
 * @param incrementBy how many requests to add to the counter
 * @param throwable   the failure cause, or {@code null} for successful requests
 */
private void recordRequestCountMetrics(Model model, int incrementBy, Throwable throwable) {
    // Seed the attribute map with the model attributes via the copy constructor,
    // then layer on the outcome attributes and the fixed source tag.
    Map<String, Object> attributes = new HashMap<>(modelAttributes(model));
    attributes.putAll(responseAttributes(throwable));
    attributes.put("inference_source", "semantic_text_bulk");
    inferenceStats.requestCount().incrementBy(incrementBy, attributes);
}
/**
* Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap}
* for the specified {@code item}.

View File

@ -66,6 +66,7 @@ import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.model.TestModel;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
import org.junit.After;
import org.junit.Before;
import org.mockito.stubbing.Answer;
@ -80,6 +81,7 @@ import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.index.IndexingPressure.MAX_COORDINATING_BYTES;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
@ -103,9 +105,11 @@ import static org.hamcrest.Matchers.notNullValue;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.assertArg;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.longThat;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.atMost;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@ -127,7 +131,9 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
@ParametersFactory
public static Iterable<Object[]> parameters() throws Exception {
return List.of(new Object[] { true }, new Object[] { false });
List<Object[]> lst = new ArrayList<>();
lst.add(new Object[] { true });
return lst;
}
@Before
@ -142,7 +148,15 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
@SuppressWarnings({ "unchecked", "rawtypes" })
public void testFilterNoop() throws Exception {
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, true);
final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(),
NOOP_INDEXING_PRESSURE,
useLegacyFormat,
true,
inferenceStats
);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
@ -167,8 +181,16 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
@SuppressWarnings({ "unchecked", "rawtypes" })
public void testLicenseInvalidForInference() throws InterruptedException {
final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
StaticModel model = StaticModel.createRandomInstance();
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, false);
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(),
NOOP_INDEXING_PRESSURE,
useLegacyFormat,
false,
inferenceStats
);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
@ -205,13 +227,15 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
@SuppressWarnings({ "unchecked", "rawtypes" })
public void testInferenceNotFound() throws Exception {
final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
StaticModel model = StaticModel.createRandomInstance();
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(model.getInferenceEntityId(), model),
NOOP_INDEXING_PRESSURE,
useLegacyFormat,
true
true,
inferenceStats
);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
@ -251,14 +275,15 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
@SuppressWarnings({ "unchecked", "rawtypes" })
public void testItemFailures() throws Exception {
final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(model.getInferenceEntityId(), model),
NOOP_INDEXING_PRESSURE,
useLegacyFormat,
true
true,
inferenceStats
);
model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom")));
model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success")));
@ -316,10 +341,30 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
request.setInferenceFieldMap(inferenceFieldMap);
filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
AtomicInteger success = new AtomicInteger(0);
AtomicInteger failed = new AtomicInteger(0);
verify(inferenceStats.requestCount(), atMost(3)).incrementBy(anyLong(), assertArg(attributes -> {
var statusCode = attributes.get("status_code");
if (statusCode == null) {
failed.incrementAndGet();
assertThat(attributes.get("error.type"), is("IllegalArgumentException"));
} else {
success.incrementAndGet();
assertThat(statusCode, is(200));
}
assertThat(attributes.get("task_type"), is(model.getTaskType().toString()));
assertThat(attributes.get("model_id"), is(model.getServiceSettings().modelId()));
assertThat(attributes.get("service"), is(model.getConfigurations().getService()));
assertThat(attributes.get("inference_source"), is("semantic_text_bulk"));
}));
assertThat(success.get(), equalTo(1));
assertThat(failed.get(), equalTo(2));
}
@SuppressWarnings({ "unchecked", "rawtypes" })
public void testExplicitNull() throws Exception {
final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom")));
model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success")));
@ -329,7 +374,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
Map.of(model.getInferenceEntityId(), model),
NOOP_INDEXING_PRESSURE,
useLegacyFormat,
true
true,
inferenceStats
);
CountDownLatch chainExecuted = new CountDownLatch(1);
@ -394,13 +440,15 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
@SuppressWarnings({ "unchecked", "rawtypes" })
public void testHandleEmptyInput() throws Exception {
final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
StaticModel model = StaticModel.createRandomInstance();
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(model.getInferenceEntityId(), model),
NOOP_INDEXING_PRESSURE,
useLegacyFormat,
true
true,
inferenceStats
);
CountDownLatch chainExecuted = new CountDownLatch(1);
@ -447,6 +495,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
@SuppressWarnings({ "unchecked", "rawtypes" })
public void testManyRandomDocs() throws Exception {
final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
Map<String, StaticModel> inferenceModelMap = new HashMap<>();
int numModels = randomIntBetween(1, 3);
for (int i = 0; i < numModels; i++) {
@ -471,7 +520,14 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
modifiedRequests[id] = res[1];
}
ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, NOOP_INDEXING_PRESSURE, useLegacyFormat, true);
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
inferenceModelMap,
NOOP_INDEXING_PRESSURE,
useLegacyFormat,
true,
inferenceStats
);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
@ -503,6 +559,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
@SuppressWarnings({ "unchecked", "rawtypes" })
public void testIndexingPressure() throws Exception {
final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(Settings.EMPTY);
final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
final StaticModel denseModel = StaticModel.createRandomInstance(TaskType.TEXT_EMBEDDING);
@ -511,7 +568,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
Map.of(sparseModel.getInferenceEntityId(), sparseModel, denseModel.getInferenceEntityId(), denseModel),
indexingPressure,
useLegacyFormat,
true
true,
inferenceStats
);
XContentBuilder doc0Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "a test value");
@ -619,6 +677,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
@SuppressWarnings("unchecked")
public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Exception {
final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), "1b").build()
);
@ -628,7 +687,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
Map.of(sparseModel.getInferenceEntityId(), sparseModel),
indexingPressure,
useLegacyFormat,
true
true,
inferenceStats
);
XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
@ -702,6 +762,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (bytesUsed(doc1Source) + 1) + "b").build()
);
final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
sparseModel.putResult("bar", randomChunkedInferenceEmbedding(sparseModel, List.of("bar")));
@ -710,7 +771,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
Map.of(sparseModel.getInferenceEntityId(), sparseModel),
indexingPressure,
useLegacyFormat,
true
true,
inferenceStats
);
CountDownLatch chainExecuted = new CountDownLatch(1);
@ -813,12 +875,14 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
.build()
);
final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
final ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(sparseModel.getInferenceEntityId(), sparseModel),
indexingPressure,
useLegacyFormat,
true
true,
inferenceStats
);
CountDownLatch chainExecuted = new CountDownLatch(1);
@ -893,7 +957,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
Map<String, StaticModel> modelMap,
IndexingPressure indexingPressure,
boolean useLegacyFormat,
boolean isLicenseValidForInference
boolean isLicenseValidForInference,
InferenceStats inferenceStats
) {
ModelRegistry modelRegistry = mock(ModelRegistry.class);
Answer<?> unparsedModelAnswer = invocationOnMock -> {
@ -970,7 +1035,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
inferenceServiceRegistry,
modelRegistry,
licenseState,
indexingPressure
indexingPressure,
inferenceStats
);
}