Add Bounded Window to Inference Models for Rescoring to Ensure Positive Score Range (#125694)

* apply bounded window inference model

* linting

* add unit tests

* [CI] Auto commit changes from spotless

* add additional tests

* remove unused constructor

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Mark J. Hoy 2025-04-02 11:50:04 -04:00 committed by GitHub
parent 509a12058f
commit e77bf808ab
10 changed files with 471 additions and 48 deletions

Changelog entry (new file)

@@ -0,0 +1,5 @@
pr: 125694
summary: LTR score bounding
area: Ranking
type: bug
issues: []

BoundedInferenceModel.java (new file)

@@ -0,0 +1,14 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
public interface BoundedInferenceModel extends InferenceModel {
double getMinPredictedValue();
double getMaxPredictedValue();
}
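
A minimal sketch of how a caller might consume this contract; the helper below is hypothetical and not part of this change:

// Illustrative only: a negative lower bound means the model can emit scores
// that are invalid for rescoring and needs its whole window shifted upwards.
static boolean needsUpwardShift(BoundedInferenceModel model) {
    return model.getMinPredictedValue() < 0.0;
}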

BoundedWindowInferenceModel.java (new file)

@@ -0,0 +1,123 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import java.util.Map;
public class BoundedWindowInferenceModel implements BoundedInferenceModel {
public static final double DEFAULT_MIN_PREDICTED_VALUE = 0;
private final BoundedInferenceModel model;
private final double minPredictedValue;
private final double maxPredictedValue;
private final double adjustmentValue;
public BoundedWindowInferenceModel(BoundedInferenceModel model) {
this.model = model;
this.minPredictedValue = model.getMinPredictedValue();
this.maxPredictedValue = model.getMaxPredictedValue();
if (this.minPredictedValue < DEFAULT_MIN_PREDICTED_VALUE) {
this.adjustmentValue = DEFAULT_MIN_PREDICTED_VALUE - this.minPredictedValue;
} else {
this.adjustmentValue = 0.0;
}
}
@Override
public String[] getFeatureNames() {
return model.getFeatureNames();
}
@Override
public TargetType targetType() {
return model.targetType();
}
@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
return boundInferenceResultScores(model.infer(fields, config, featureDecoderMap));
}
@Override
public InferenceResults infer(double[] features, InferenceConfig config) {
return boundInferenceResultScores(model.infer(features, config));
}
@Override
public boolean supportsFeatureImportance() {
return model.supportsFeatureImportance();
}
@Override
public String getName() {
return "bounded_window[" + model.getName() + "]";
}
@Override
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
model.rewriteFeatureIndices(newFeatureIndexMapping);
}
@Override
public long ramBytesUsed() {
return model.ramBytesUsed();
}
@Override
public double getMinPredictedValue() {
return minPredictedValue;
}
@Override
public double getMaxPredictedValue() {
return maxPredictedValue;
}
private InferenceResults boundInferenceResultScores(InferenceResults inferenceResult) {
// if the min value < the default minimum, slide the values up by the adjustment value
if (inferenceResult instanceof RegressionInferenceResults regressionInferenceResults) {
double predictedValue = ((Number) regressionInferenceResults.predictedValue()).doubleValue();
predictedValue += this.adjustmentValue;
return new RegressionInferenceResults(
predictedValue,
regressionInferenceResults.getResultsField(),
regressionInferenceResults.getFeatureImportance()
);
}
throw new IllegalStateException(
LoggerMessageFormat.format(
"Model used within a {} should return a {} but got {} instead",
BoundedWindowInferenceModel.class.getSimpleName(),
RegressionInferenceResults.class.getSimpleName(),
inferenceResult.getClass().getSimpleName()
)
);
}
@Override
public String toString() {
return "BoundedWindowInferenceModel{"
+ "model="
+ model
+ ", minPredictedValue="
+ getMinPredictedValue()
+ ", maxPredictedValue="
+ getMaxPredictedValue()
+ '}';
}
}
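
A usage sketch, assuming `tree` is a TreeInferenceModel whose regression leaves span [-2.0, 1.5], so adjustmentValue = 0.0 - (-2.0) = 2.0 (the `fields` map is assumed to hold the feature values):

BoundedWindowInferenceModel windowed = new BoundedWindowInferenceModel(tree);
// a raw prediction of -2.0 now surfaces as 0.0 and a raw 1.5 as 3.5;
// a model whose minimum is already >= 0 passes through with no shift
InferenceResults results = windowed.infer(fields, RegressionConfig.EMPTY_PARAMS, Map.of());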

EnsembleInferenceModel.java

@@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.CachedSupplier;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InferenceResults;
@@ -36,6 +37,7 @@ import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -52,7 +54,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.En
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TRAINED_MODELS;
public class EnsembleInferenceModel implements InferenceModel {
public class EnsembleInferenceModel implements InferenceModel, BoundedInferenceModel {
public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
private static final Logger LOGGER = LogManager.getLogger(EnsembleInferenceModel.class);
@@ -97,6 +99,7 @@
private final List<String> classificationLabels;
private final double[] classificationWeights;
private volatile boolean preparedForInference = false;
private final Supplier<double[]> predictedValuesBoundariesSupplier;
private EnsembleInferenceModel(
List<InferenceModel> models,
@@ -112,6 +115,7 @@
this.classificationWeights = classificationWeights == null
? null
: classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
this.predictedValuesBoundariesSupplier = CachedSupplier.wrap(this::initModelBoundaries);
}
@Override
@@ -328,21 +332,57 @@
@Override
public String toString() {
return "EnsembleInferenceModel{"
+ "featureNames="
+ Arrays.toString(featureNames)
+ ", models="
+ models
+ ", outputAggregator="
+ outputAggregator
+ ", targetType="
+ targetType
+ ", classificationLabels="
+ classificationLabels
+ ", classificationWeights="
+ Arrays.toString(classificationWeights)
+ ", preparedForInference="
+ preparedForInference
+ '}';
StringBuilder builder = new StringBuilder("EnsembleInferenceModel{");
builder.append("featureNames=")
.append(Arrays.toString(featureNames))
.append(", models=")
.append(models)
.append(", outputAggregator=")
.append(outputAggregator)
.append(", targetType=")
.append(targetType);
if (targetType == TargetType.CLASSIFICATION) {
builder.append(", classificationLabels=")
.append(classificationLabels)
.append(", classificationWeights=")
.append(Arrays.toString(classificationWeights));
} else if (targetType == TargetType.REGRESSION) {
builder.append(", minPredictedValue=")
.append(getMinPredictedValue())
.append(", maxPredictedValue=")
.append(getMaxPredictedValue());
}
builder.append(", preparedForInference=").append(preparedForInference);
return builder.append('}').toString();
}
@Override
public double getMinPredictedValue() {
return this.predictedValuesBoundariesSupplier.get()[0];
}
@Override
public double getMaxPredictedValue() {
return this.predictedValuesBoundariesSupplier.get()[1];
}
private double[] initModelBoundaries() {
double[] modelsMinBoundaries = new double[models.size()];
double[] modelsMaxBoundaries = new double[models.size()];
int i = 0;
for (InferenceModel model : models) {
if (model instanceof BoundedInferenceModel boundedInferenceModel) {
modelsMinBoundaries[i] = boundedInferenceModel.getMinPredictedValue();
modelsMaxBoundaries[i++] = boundedInferenceModel.getMaxPredictedValue();
} else {
throw new IllegalStateException("All submodels must implement BoundedInferenceModel");
}
}
return new double[] { outputAggregator.aggregate(modelsMinBoundaries), outputAggregator.aggregate(modelsMaxBoundaries) };
}
}
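
The ensemble's window is derived lazily (via CachedSupplier) from the windows of its sub-models, each of which must itself be bounded. A worked sketch of the arithmetic, using the two trees from the new test further below and assuming WeightedSum's aggregate sums the values it is handed:

// tree1 leaves {0.1, 0.2, 0.3}  ->  per-model bounds [0.1, 0.3]
// tree2 leaves {0.9, 1.5}       ->  per-model bounds [0.9, 1.5]
// ensemble window: min = 0.1 + 0.9 = 1.0, max = 0.3 + 1.5 = 1.8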

InferenceDefinition.java

@@ -14,6 +14,7 @@ import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -79,13 +80,21 @@ public class InferenceDefinition {
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
preProcess(fields);
InferenceModel inferenceModel = trainedModel;
if (config instanceof LearningToRankConfig) {
assert trainedModel instanceof BoundedInferenceModel;
inferenceModel = new BoundedWindowInferenceModel((BoundedInferenceModel) trainedModel);
}
if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
throw ExceptionsHelper.badRequestException(
"Feature importance is not supported for the configured model of type [{}]",
trainedModel.getName()
);
}
return trainedModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
return inferenceModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
}
public TargetType getTargetType() {
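
The practical effect: only learning-to-rank configs pay for the window, while any other config reaches the trained model untouched. A caller's-eye sketch (the `definition`, `fields`, and config variables are illustrative, not from this diff):

// scores from an LTR config are guaranteed non-negative by the wrapper
InferenceResults ltrResults = definition.infer(fields, learningToRankConfig);
// a plain regression config bypasses the wrapper and may return negative scores
InferenceResults rawResults = definition.infer(fields, regressionConfig);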

TreeInferenceModel.java

@@ -58,7 +58,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNo
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.SPLIT_FEATURE;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.THRESHOLD;
public class TreeInferenceModel implements InferenceModel {
public class TreeInferenceModel implements InferenceModel, BoundedInferenceModel {
private static final Logger LOGGER = LogManager.getLogger(TreeInferenceModel.class);
public static final long SHALLOW_SIZE = shallowSizeOfInstance(TreeInferenceModel.class);
@@ -90,7 +90,7 @@
private String[] featureNames;
private final TargetType targetType;
private List<String> classificationLabels;
private final double highOrderCategory;
private final double[] leafBoundaries;
private final int maxDepth;
private final int leafSize;
private volatile boolean preparedForInference = false;
@@ -108,7 +108,7 @@
this.nodes = nodes.stream().map(NodeBuilder::build).toArray(Node[]::new);
this.targetType = targetType == null ? TargetType.REGRESSION : targetType;
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
this.highOrderCategory = maxLeafValue();
this.leafBoundaries = getLeafBoundaries();
int leafSize = 1;
for (Node node : this.nodes) {
if (node instanceof LeafNode leafNode) {
@@ -218,7 +218,7 @@
}
// If we are classification, we should assume that the inference return value is whole.
assert inferenceValue[0] == Math.rint(inferenceValue[0]);
double maxCategory = this.highOrderCategory;
double maxCategory = getHighOrderCategory();
// If we are classification, we should assume that the largest leaf value is whole.
assert maxCategory == Math.rint(maxCategory);
double[] list = Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)
@@ -366,21 +366,20 @@
return size;
}
private double maxLeafValue() {
if (targetType != TargetType.CLASSIFICATION) {
return Double.NaN;
}
double max = 0.0;
private double[] getLeafBoundaries() {
// seed with the infinities: Double.MIN_VALUE is the smallest positive double,
// so it cannot act as a max seed for trees whose leaves are all negative
double[] bounds = new double[] { Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY };
for (Node node : this.nodes) {
if (node instanceof LeafNode leafNode) {
if (leafNode.leafValue.length > 1) {
return leafNode.leafValue.length;
return new double[] { 0, leafNode.leafValue.length };
} else {
max = Math.max(leafNode.leafValue[0], max);
bounds[0] = Math.min(leafNode.leafValue[0], bounds[0]);
bounds[1] = Math.max(leafNode.leafValue[0], bounds[1]);
}
}
}
return max;
return bounds;
}
public Node[] getNodes() {
@@ -389,24 +388,35 @@
@Override
public String toString() {
return "TreeInferenceModel{"
+ "nodes="
+ Arrays.toString(nodes)
+ ", featureNames="
+ Arrays.toString(featureNames)
+ ", targetType="
+ targetType
+ ", classificationLabels="
+ classificationLabels
+ ", highOrderCategory="
+ highOrderCategory
+ ", maxDepth="
+ maxDepth
+ ", leafSize="
+ leafSize
+ ", preparedForInference="
+ preparedForInference
+ '}';
StringBuilder builder = new StringBuilder("TreeInferenceModel{");
builder.append("nodes=")
.append(Arrays.toString(nodes))
.append(", featureNames=")
.append(Arrays.toString(featureNames))
.append(", targetType=")
.append(targetType);
if (targetType == TargetType.CLASSIFICATION) {
builder.append(", classificationLabels=")
.append(classificationLabels)
.append(", highOrderCategory=")
.append(getHighOrderCategory());
} else if (targetType == TargetType.REGRESSION) {
builder.append(", minPredictedValue=")
.append(getMinPredictedValue())
.append(", maxPredictedValue=")
.append(getMaxPredictedValue());
}
builder.append(", maxDepth=")
.append(maxDepth)
.append(", leafSize=")
.append(leafSize)
.append(", preparedForInference=")
.append(preparedForInference);
return builder.append('}').toString();
}
private static int getDepth(Node[] nodes, int nodeIndex, int depth) {
@@ -420,6 +430,20 @@
return Math.max(depthLeft, depthRight) + 1;
}
@Override
public double getMinPredictedValue() {
return leafBoundaries[0];
}
@Override
public double getMaxPredictedValue() {
return leafBoundaries[1];
}
private double getHighOrderCategory() {
return getMaxPredictedValue();
}
static class NodeBuilder {
private static final ObjectParser<NodeBuilder, Void> PARSER = new ObjectParser<>(
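
The boundary scan generalizes the old maxLeafValue(): one pass now records both extremes, short-circuiting to {0, numberOfClasses} for multi-valued (classification) leaves. A standalone sketch of the same min/max pass over single-valued leaves (the helper name is hypothetical):

// seed with the infinities so all-negative or all-positive leaf sets work;
// Double.MIN_VALUE would not, since it is the smallest positive double
static double[] leafBounds(double[] leafValues) {
    double min = Double.POSITIVE_INFINITY;
    double max = Double.NEGATIVE_INFINITY;
    for (double v : leafValues) {
        min = Math.min(v, min);
        max = Math.max(v, max);
    }
    return new double[] { min, max };
}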

BoundedWindowInferenceModelTests.java (new file)

@@ -0,0 +1,116 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
import static org.hamcrest.Matchers.equalTo;
public class BoundedWindowInferenceModelTests extends ESTestCase {
private static final List<String> featureNames = Arrays.asList("foo", "bar");
public void testBoundsSetting() throws IOException {
BoundedWindowInferenceModel testModel = getModel(-2.0, 5.2, 10.5);
assertThat(testModel.getMinPredictedValue(), equalTo(-2.0));
assertThat(testModel.getMaxPredictedValue(), equalTo(10.5));
}
public void testInferenceScoresWithoutAdjustment() throws IOException {
BoundedWindowInferenceModel testModel = getModel(1.0, 5.2, 10.5);
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
Double lowResultValue = ((SingleValueInferenceResults) testModel.infer(
featureMap,
RegressionConfig.EMPTY_PARAMS,
Collections.emptyMap()
)).value();
assertThat(lowResultValue, equalTo(1.0));
featureVector = Arrays.asList(12.0, 0.0);
featureMap = zipObjMap(featureNames, featureVector);
Double highResultValue = ((SingleValueInferenceResults) testModel.infer(
featureMap,
RegressionConfig.EMPTY_PARAMS,
Collections.emptyMap()
)).value();
assertThat(highResultValue, equalTo(10.5));
double[] featureArray = new double[2];
featureArray[0] = 12.0;
featureArray[1] = 0.0;
Double highResultValueFromFeatures = ((SingleValueInferenceResults) testModel.infer(featureArray, RegressionConfig.EMPTY_PARAMS))
.value();
assertThat(highResultValueFromFeatures, equalTo(10.5));
}
public void testInferenceScoresWithAdjustment() throws IOException {
BoundedWindowInferenceModel testModel = getModel(-5.0, 1.2, 6.5);
List<Double> featureVector = Arrays.asList(-10.0, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
Double lowResultValue = ((SingleValueInferenceResults) testModel.infer(
featureMap,
RegressionConfig.EMPTY_PARAMS,
Collections.emptyMap()
)).value();
assertThat(lowResultValue, equalTo(0.0));
featureVector = Arrays.asList(12.0, 0.0);
featureMap = zipObjMap(featureNames, featureVector);
Double highResultValue = ((SingleValueInferenceResults) testModel.infer(
featureMap,
RegressionConfig.EMPTY_PARAMS,
Collections.emptyMap()
)).value();
assertThat(highResultValue, equalTo(11.5));
double[] featureArray = new double[2];
featureArray[0] = 12.0;
featureArray[1] = 0.0;
Double highResultValueFromFeatures = ((SingleValueInferenceResults) testModel.infer(featureArray, RegressionConfig.EMPTY_PARAMS))
.value();
assertThat(highResultValueFromFeatures, equalTo(11.5));
}
private BoundedWindowInferenceModel getModel(double lowerBoundValue, double midValue, double upperBoundValue) throws IOException {
Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
builder.addLeaf(rootNode.getRightChild(), upperBoundValue);
TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
builder.addLeaf(leftChildNode.getLeftChild(), lowerBoundValue);
builder.addLeaf(leftChildNode.getRightChild(), midValue);
Tree treeObject = builder.setFeatureNames(featureNames).build();
TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent);
tree.rewriteFeatureIndices(Collections.emptyMap());
return new BoundedWindowInferenceModel(tree);
}
private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
}
}
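
For reference, the routing in getModel's tree: a value of feature 0 at or above 0.5 reaches the upper-bound leaf, otherwise feature 1 picks between the lower-bound and mid leaves. In the adjustment test the window is [-5.0, 6.5], so the shift is +5.0:

// raw -5.0 (lowest leaf)  + 5.0 ->  0.0  (the lower edge lands exactly on zero)
// raw  6.5 (highest leaf) + 5.0 -> 11.5  (the whole window slides, top included)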

EnsembleInferenceModelTests.java

@@ -39,6 +39,7 @@ import java.util.stream.IntStream;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
@@ -537,6 +538,40 @@ public class EnsembleInferenceModelTests extends ESTestCase {
assertThat(featureImportance[1][0], closeTo(0.1451914, eps));
}
public void testMinAndMaxBoundaries() throws IOException {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0).setLeftChild(1).setRightChild(2).setSplitFeature(0).setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(0.3))
.addNode(TreeNode.builder(2).setThreshold(0.8).setSplitFeature(1).setLeftChild(3).setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.1))
.addNode(TreeNode.builder(4).setLeafValue(0.2))
.build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0).setLeftChild(1).setRightChild(2).setSplitFeature(0).setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(1.5))
.addNode(TreeNode.builder(2).setLeafValue(0.9))
.build();
Ensemble ensembleObject = Ensemble.builder()
.setTargetType(TargetType.REGRESSION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2))
.setOutputAggregator(new WeightedSum(new double[] { 0.5, 0.5 }))
.build();
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(
ensembleObject,
xContentRegistry(),
EnsembleInferenceModel::fromXContent
);
ensemble.rewriteFeatureIndices(Collections.emptyMap());
assertThat(ensemble.getMinPredictedValue(), equalTo(1.0));
assertThat(ensemble.getMaxPredictedValue(), equalTo(1.8));
}
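
Note the asserted window: it is the plain sum of the per-tree extremes, min = 0.1 + 0.9 = 1.0 and max = 0.3 + 1.5 = 1.8. The boundary calculation hands the raw extremes straight to aggregate, which for WeightedSum sums whatever it is given (weights are applied in the separate processValues step), so the 0.5 weights never enter the window arithmetic; the assertions above appear to rely on exactly that.
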
private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
}

InferenceDefinitionTests.java

@@ -13,6 +13,8 @@ import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.NamedXContentRegistry;
@@ -25,17 +27,26 @@ import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvide
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilderTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import java.io.IOException;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests.ENSEMBLE_MODEL;
import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests.TREE_MODEL;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
@@ -176,6 +187,35 @@ public class InferenceDefinitionTests extends ESTestCase {
}
}
public void testWithLearningToRankConfiguration() throws IOException {
Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
builder.addLeaf(rootNode.getRightChild(), -2.0);
TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
builder.addLeaf(leftChildNode.getLeftChild(), 0.2);
builder.addLeaf(leftChildNode.getRightChild(), 1.5);
List<String> featureNames = Arrays.asList("foo", "bar");
Tree treeObject = builder.setFeatureNames(featureNames).build();
TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent);
tree.rewriteFeatureIndices(Collections.emptyMap());
BoundedWindowInferenceModel testModel = new BoundedWindowInferenceModel(tree);
InferenceDefinition definition = new InferenceDefinition(testModel, null);
LearningToRankConfig config = new LearningToRankConfig(
randomBoolean() ? null : randomIntBetween(0, 10),
randomBoolean()
? null
: Stream.generate(QueryExtractorBuilderTests::randomInstance).limit(randomInt(5)).collect(Collectors.toList()),
randomBoolean() ? null : randomMap(0, 10, () -> Tuple.tuple(randomIdentifier(), randomIdentifier()))
);
InferenceResults results = definition.infer(Map.of("foo", 1.0, "bar", 0.0), config);
assertThat(results.predictedValue(), equalTo(2.0));
}
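
Why the expected score is 2.0 rather than 0.0: the tree's raw prediction for foo = 1.0 is its right-hand leaf, -2.0. testModel already shifts scores up by 2.0 (its window minimum is -2.0), and because the config is a LearningToRankConfig, infer wraps the definition's model in a second BoundedWindowInferenceModel over the same window, so the shift is applied twice: -2.0 -> 0.0 -> 2.0.
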
public static String getClassificationDefinition(boolean customPreprocessor) {
return Strings.format("""
{

TreeInferenceModelTests.java

@@ -284,6 +284,23 @@ public class TreeInferenceModelTests extends ESTestCase {
assertThat(featureImportance[1][0], closeTo(2.5, eps));
}
public void testMinAndMaxBoundaries() throws IOException {
Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
builder.addLeaf(rootNode.getRightChild(), 0.3);
TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
builder.addLeaf(leftChildNode.getLeftChild(), 0.1);
builder.addLeaf(leftChildNode.getRightChild(), 0.2);
List<String> featureNames = Arrays.asList("foo", "bar");
Tree treeObject = builder.setFeatureNames(featureNames).build();
TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent);
tree.rewriteFeatureIndices(Collections.emptyMap());
assertThat(tree.getMinPredictedValue(), equalTo(0.1));
assertThat(tree.getMaxPredictedValue(), equalTo(0.3));
}
private static Map<String, Object> zipObjMap(List<String> keys, List<? extends Object> values) {
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
}