Add Bounded Window to Inference Models for Rescoring to Ensure Positive Score Range (#125694)
* apply bounded window inference model
* linting
* add unit tests
* [CI] Auto commit changes from spotless
* add additional tests
* remove unused constructor

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
parent 509a12058f
commit e77bf808ab
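The diff below fixes learning-to-rank (LTR) rescoring by windowing model output into a non-negative range: when a model's minimum predicted value is below zero, every prediction is slid up by that amount so the window starts at zero. A minimal standalone sketch of the rule (plain Java; the names here are illustrative, not the Elasticsearch classes in the diff):

// BoundedWindowSketch.java -- illustrative only; not from the Elasticsearch codebase.
public final class BoundedWindowSketch {

    // If the model can predict below 0, slide every score up so the window starts at 0.
    static double boundScore(double predicted, double modelMin) {
        double adjustment = modelMin < 0 ? -modelMin : 0.0;
        return predicted + adjustment;
    }

    public static void main(String[] args) {
        // A model with predicted-value window [-5.0, 6.5] maps -5.0 -> 0.0 and 6.5 -> 11.5.
        System.out.println(boundScore(-5.0, -5.0)); // 0.0
        System.out.println(boundScore(6.5, -5.0));  // 11.5
        // A model whose minimum is already non-negative is left untouched.
        System.out.println(boundScore(1.0, 1.0));   // 1.0
    }
}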
@@ -0,0 +1,5 @@
pr: 125694
summary: LTR score bounding
area: Ranking
type: bug
issues: []
@@ -0,0 +1,14 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;

public interface BoundedInferenceModel extends InferenceModel {

    double getMinPredictedValue();

    double getMaxPredictedValue();
}
@@ -0,0 +1,123 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;

import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;

import java.util.Map;

public class BoundedWindowInferenceModel implements BoundedInferenceModel {

    public static final double DEFAULT_MIN_PREDICTED_VALUE = 0;

    private final BoundedInferenceModel model;
    private final double minPredictedValue;
    private final double maxPredictedValue;
    private final double adjustmentValue;

    public BoundedWindowInferenceModel(BoundedInferenceModel model) {
        this.model = model;
        this.minPredictedValue = model.getMinPredictedValue();
        this.maxPredictedValue = model.getMaxPredictedValue();

        if (this.minPredictedValue < DEFAULT_MIN_PREDICTED_VALUE) {
            this.adjustmentValue = DEFAULT_MIN_PREDICTED_VALUE - this.minPredictedValue;
        } else {
            this.adjustmentValue = 0.0;
        }
    }

    @Override
    public String[] getFeatureNames() {
        return model.getFeatureNames();
    }

    @Override
    public TargetType targetType() {
        return model.targetType();
    }

    @Override
    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
        return boundInferenceResultScores(model.infer(fields, config, featureDecoderMap));
    }

    @Override
    public InferenceResults infer(double[] features, InferenceConfig config) {
        return boundInferenceResultScores(model.infer(features, config));
    }

    @Override
    public boolean supportsFeatureImportance() {
        return model.supportsFeatureImportance();
    }

    @Override
    public String getName() {
        return "bounded_window[" + model.getName() + "]";
    }

    @Override
    public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
        model.rewriteFeatureIndices(newFeatureIndexMapping);
    }

    @Override
    public long ramBytesUsed() {
        return model.ramBytesUsed();
    }

    @Override
    public double getMinPredictedValue() {
        return minPredictedValue;
    }

    @Override
    public double getMaxPredictedValue() {
        return maxPredictedValue;
    }

    private InferenceResults boundInferenceResultScores(InferenceResults inferenceResult) {
        // if the min value < the default minimum, slide the values up by the adjustment value
        if (inferenceResult instanceof RegressionInferenceResults regressionInferenceResults) {
            double predictedValue = ((Number) regressionInferenceResults.predictedValue()).doubleValue();

            predictedValue += this.adjustmentValue;

            return new RegressionInferenceResults(
                predictedValue,
                inferenceResult.getResultsField(),
                ((RegressionInferenceResults) inferenceResult).getFeatureImportance()
            );
        }

        throw new IllegalStateException(
            LoggerMessageFormat.format(
                "Model used within a {} should return a {} but got {} instead",
                BoundedWindowInferenceModel.class.getSimpleName(),
                RegressionInferenceResults.class.getSimpleName(),
                inferenceResult.getClass().getSimpleName()
            )
        );
    }

    @Override
    public String toString() {
        return "BoundedWindowInferenceModel{"
            + "model="
            + model
            + ", minPredictedValue="
            + getMinPredictedValue()
            + ", maxPredictedValue="
            + getMaxPredictedValue()
            + '}';
    }
}
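A note on the wrapper's contract, as exercised by the unit tests further down: getMinPredictedValue() and getMaxPredictedValue() report the delegate's original window, while infer(...) results are shifted. Wrapping a model whose window is [-5.0, 6.5] yields an adjustment of 5.0, so a raw prediction of -5.0 comes back as 0.0 and 6.5 as 11.5; a model whose minimum is already non-negative (say [1.0, 10.5]) passes through unchanged. Only RegressionInferenceResults are accepted; any other result type throws the IllegalStateException above.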
@@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.util.CachedSupplier;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.inference.InferenceResults;
@@ -36,6 +37,7 @@ import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
@@ -52,7 +54,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.En
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TRAINED_MODELS;
 
-public class EnsembleInferenceModel implements InferenceModel {
+public class EnsembleInferenceModel implements InferenceModel, BoundedInferenceModel {
 
     public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
     private static final Logger LOGGER = LogManager.getLogger(EnsembleInferenceModel.class);
@@ -97,6 +99,7 @@ public class EnsembleInferenceModel implements InferenceModel {
     private final List<String> classificationLabels;
     private final double[] classificationWeights;
     private volatile boolean preparedForInference = false;
+    private final Supplier<double[]> predictedValuesBoundariesSupplier;
 
     private EnsembleInferenceModel(
         List<InferenceModel> models,
@@ -112,6 +115,7 @@ public class EnsembleInferenceModel implements InferenceModel {
         this.classificationWeights = classificationWeights == null
             ? null
             : classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
+        this.predictedValuesBoundariesSupplier = CachedSupplier.wrap(this::initModelBoundaries);
     }
 
     @Override
@@ -328,21 +332,57 @@ public class EnsembleInferenceModel implements InferenceModel {
 
     @Override
     public String toString() {
-        return "EnsembleInferenceModel{"
-            + "featureNames="
-            + Arrays.toString(featureNames)
-            + ", models="
-            + models
-            + ", outputAggregator="
-            + outputAggregator
-            + ", targetType="
-            + targetType
-            + ", classificationLabels="
-            + classificationLabels
-            + ", classificationWeights="
-            + Arrays.toString(classificationWeights)
-            + ", preparedForInference="
-            + preparedForInference
-            + '}';
+        StringBuilder builder = new StringBuilder("EnsembleInferenceModel{");
+
+        builder.append("featureNames=")
+            .append(Arrays.toString(featureNames))
+            .append(", models=")
+            .append(models)
+            .append(", outputAggregator=")
+            .append(outputAggregator)
+            .append(", targetType=")
+            .append(targetType);
+
+        if (targetType == TargetType.CLASSIFICATION) {
+            builder.append(", classificationLabels=")
+                .append(classificationLabels)
+                .append(", classificationWeights=")
+                .append(Arrays.toString(classificationWeights));
+        } else if (targetType == TargetType.REGRESSION) {
+            builder.append(", minPredictedValue=")
+                .append(getMinPredictedValue())
+                .append(", maxPredictedValue=")
+                .append(getMaxPredictedValue());
+        }
+
+        builder.append(", preparedForInference=").append(preparedForInference);
+
+        return builder.append('}').toString();
     }
+
+    @Override
+    public double getMinPredictedValue() {
+        return this.predictedValuesBoundariesSupplier.get()[0];
+    }
+
+    @Override
+    public double getMaxPredictedValue() {
+        return this.predictedValuesBoundariesSupplier.get()[1];
+    }
+
+    private double[] initModelBoundaries() {
+        double[] modelsMinBoundaries = new double[models.size()];
+        double[] modelsMaxBoundaries = new double[models.size()];
+        int i = 0;
+        for (InferenceModel model : models) {
+            if (model instanceof BoundedInferenceModel boundedInferenceModel) {
+                modelsMinBoundaries[i] = boundedInferenceModel.getMinPredictedValue();
+                modelsMaxBoundaries[i++] = boundedInferenceModel.getMaxPredictedValue();
+            } else {
+                throw new IllegalStateException("All submodels have to be bounded");
+            }
+        }
+
+        return new double[] { outputAggregator.aggregate(modelsMinBoundaries), outputAggregator.aggregate(modelsMaxBoundaries) };
+    }
 }
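The ensemble derives its window lazily (via CachedSupplier) by running the sub-models' minima and maxima through the output aggregator. Below is a standalone sketch under the assumption that the aggregator simply sums the raw boundary arrays, which is what the new ensemble test further down asserts for WeightedSum: two trees with windows [0.1, 0.3] and [0.9, 1.5] produce an ensemble window of [1.0, 1.8]. Other OutputAggregator implementations would combine the boundaries differently.

// EnsembleWindowSketch.java -- illustrative only, not the Elasticsearch OutputAggregator API.
import java.util.Arrays;

public final class EnsembleWindowSketch {

    // Aggregate the sub-model minima into the ensemble minimum, and likewise for maxima.
    static double[] window(double[][] subModelWindows) {
        double[] mins = Arrays.stream(subModelWindows).mapToDouble(w -> w[0]).toArray();
        double[] maxes = Arrays.stream(subModelWindows).mapToDouble(w -> w[1]).toArray();
        return new double[] { Arrays.stream(mins).sum(), Arrays.stream(maxes).sum() };
    }

    public static void main(String[] args) {
        // Two trees with windows [0.1, 0.3] and [0.9, 1.5], as in the ensemble test below:
        double[] w = window(new double[][] { { 0.1, 0.3 }, { 0.9, 1.5 } });
        System.out.println(Arrays.toString(w)); // [1.0, 1.8]
    }
}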
@@ -14,6 +14,7 @@ import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
@@ -79,13 +80,21 @@ public class InferenceDefinition {
 
     public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
         preProcess(fields);
+
+        InferenceModel inferenceModel = trainedModel;
+
+        if (config instanceof LearningToRankConfig) {
+            assert trainedModel instanceof BoundedInferenceModel;
+            inferenceModel = new BoundedWindowInferenceModel((BoundedInferenceModel) trainedModel);
+        }
+
         if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
             throw ExceptionsHelper.badRequestException(
                 "Feature importance is not supported for the configured model of type [{}]",
                 trainedModel.getName()
             );
         }
-        return trainedModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
+        return inferenceModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
     }
 
     public TargetType getTargetType() {
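The windowing is applied only when the request carries a LearningToRankConfig, so plain regression and classification inference keep their raw values. The motivation, per the commit title and the "LTR score bounding" changelog entry, is that LTR scores feed search rescoring, and Lucene does not allow negative scores; sliding the whole window up by one constant guarantees a non-negative score range without changing the relative ordering of documents.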
@@ -58,7 +58,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNo
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.SPLIT_FEATURE;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.THRESHOLD;
 
-public class TreeInferenceModel implements InferenceModel {
+public class TreeInferenceModel implements InferenceModel, BoundedInferenceModel {
 
     private static final Logger LOGGER = LogManager.getLogger(TreeInferenceModel.class);
     public static final long SHALLOW_SIZE = shallowSizeOfInstance(TreeInferenceModel.class);
@@ -90,7 +90,7 @@ public class TreeInferenceModel implements InferenceModel {
     private String[] featureNames;
     private final TargetType targetType;
     private List<String> classificationLabels;
-    private final double highOrderCategory;
+    private final double[] leafBoundaries;
     private final int maxDepth;
     private final int leafSize;
     private volatile boolean preparedForInference = false;
@@ -108,7 +108,7 @@ public class TreeInferenceModel implements InferenceModel {
         this.nodes = nodes.stream().map(NodeBuilder::build).toArray(Node[]::new);
         this.targetType = targetType == null ? TargetType.REGRESSION : targetType;
         this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
-        this.highOrderCategory = maxLeafValue();
+        this.leafBoundaries = getLeafBoundaries();
         int leafSize = 1;
         for (Node node : this.nodes) {
             if (node instanceof LeafNode leafNode) {
@@ -218,7 +218,7 @@ public class TreeInferenceModel implements InferenceModel {
         }
         // If we are classification, we should assume that the inference return value is whole.
         assert inferenceValue[0] == Math.rint(inferenceValue[0]);
-        double maxCategory = this.highOrderCategory;
+        double maxCategory = getHighOrderCategory();
         // If we are classification, we should assume that the largest leaf value is whole.
         assert maxCategory == Math.rint(maxCategory);
         double[] list = Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)
@@ -366,21 +366,20 @@ public class TreeInferenceModel implements InferenceModel {
         return size;
     }
 
-    private double maxLeafValue() {
-        if (targetType != TargetType.CLASSIFICATION) {
-            return Double.NaN;
-        }
-        double max = 0.0;
+    private double[] getLeafBoundaries() {
+        double[] bounds = new double[] { Double.MAX_VALUE, Double.MIN_VALUE };
+
         for (Node node : this.nodes) {
             if (node instanceof LeafNode leafNode) {
                 if (leafNode.leafValue.length > 1) {
-                    return leafNode.leafValue.length;
+                    return new double[] { 0, leafNode.leafValue.length };
                 } else {
-                    max = Math.max(leafNode.leafValue[0], max);
+                    bounds[0] = Math.min(leafNode.leafValue[0], bounds[0]);
+                    bounds[1] = Math.max(leafNode.leafValue[0], bounds[1]);
                 }
             }
         }
-        return max;
+        return bounds;
     }
 
     public Node[] getNodes() {
@@ -389,24 +388,35 @@ public class TreeInferenceModel implements InferenceModel {
 
     @Override
     public String toString() {
-        return "TreeInferenceModel{"
-            + "nodes="
-            + Arrays.toString(nodes)
-            + ", featureNames="
-            + Arrays.toString(featureNames)
-            + ", targetType="
-            + targetType
-            + ", classificationLabels="
-            + classificationLabels
-            + ", highOrderCategory="
-            + highOrderCategory
-            + ", maxDepth="
-            + maxDepth
-            + ", leafSize="
-            + leafSize
-            + ", preparedForInference="
-            + preparedForInference
-            + '}';
+        StringBuilder builder = new StringBuilder("TreeInferenceModel{");
+
+        builder.append("nodes=")
+            .append(Arrays.toString(nodes))
+            .append(", featureNames=")
+            .append(Arrays.toString(featureNames))
+            .append(", targetType=")
+            .append(targetType);
+
+        if (targetType == TargetType.CLASSIFICATION) {
+            builder.append(", classificationLabels=")
+                .append(classificationLabels)
+                .append(", highOrderCategory=")
+                .append(getHighOrderCategory());
+        } else if (targetType == TargetType.REGRESSION) {
+            builder.append(", minPredictedValue=")
+                .append(getMinPredictedValue())
+                .append(", maxPredictedValue=")
+                .append(getMaxPredictedValue());
+        }
+
+        builder.append(", maxDepth=")
+            .append(maxDepth)
+            .append(", leafSize=")
+            .append(leafSize)
+            .append(", preparedForInference=")
+            .append(preparedForInference);
+
+        return builder.append('}').toString();
     }
 
     private static int getDepth(Node[] nodes, int nodeIndex, int depth) {
@@ -420,6 +430,20 @@ public class TreeInferenceModel implements InferenceModel {
         return Math.max(depthLeft, depthRight) + 1;
     }
 
+    @Override
+    public double getMinPredictedValue() {
+        return leafBoundaries[0];
+    }
+
+    @Override
+    public double getMaxPredictedValue() {
+        return leafBoundaries[1];
+    }
+
+    private double getHighOrderCategory() {
+        return getMaxPredictedValue();
+    }
+
     static class NodeBuilder {
 
         private static final ObjectParser<NodeBuilder, Void> PARSER = new ObjectParser<>(
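In short, a regression tree's window is just the minimum and maximum over its single-valued leaves (the multi-valued classification case short-circuits to [0, numClasses]). A standalone sketch of that scan with illustrative names; note the sketch seeds the running maximum with -Double.MAX_VALUE, whereas the diff above uses Double.MIN_VALUE, the smallest positive double:

// LeafBoundsSketch.java -- illustrative only, not the Elasticsearch Node/LeafNode types.
public final class LeafBoundsSketch {

    // Scan single-valued leaves for the smallest and largest predicted value.
    static double[] leafBounds(double[] leafValues) {
        double[] bounds = new double[] { Double.MAX_VALUE, -Double.MAX_VALUE };
        for (double v : leafValues) {
            bounds[0] = Math.min(v, bounds[0]);
            bounds[1] = Math.max(v, bounds[1]);
        }
        return bounds;
    }

    public static void main(String[] args) {
        // Leaves 0.3, 0.1, 0.2 give the window [0.1, 0.3], as the tree test below asserts.
        double[] bounds = leafBounds(new double[] { 0.3, 0.1, 0.2 });
        System.out.println(bounds[0] + " .. " + bounds[1]); // 0.1 .. 0.3
    }
}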
@@ -0,0 +1,116 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;

import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
import static org.hamcrest.Matchers.equalTo;

public class BoundedWindowInferenceModelTests extends ESTestCase {

    private static final List<String> featureNames = Arrays.asList("foo", "bar");

    public void testBoundsSetting() throws IOException {
        BoundedWindowInferenceModel testModel = getModel(-2.0, 5.2, 10.5);
        assertThat(testModel.getMinPredictedValue(), equalTo(-2.0));
        assertThat(testModel.getMaxPredictedValue(), equalTo(10.5));
    }

    public void testInferenceScoresWithoutAdjustment() throws IOException {
        BoundedWindowInferenceModel testModel = getModel(1.0, 5.2, 10.5);

        List<Double> featureVector = Arrays.asList(0.4, 0.0);
        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
        Double lowResultValue = ((SingleValueInferenceResults) testModel.infer(
            featureMap,
            RegressionConfig.EMPTY_PARAMS,
            Collections.emptyMap()
        )).value();
        assertThat(lowResultValue, equalTo(1.0));

        featureVector = Arrays.asList(12.0, 0.0);
        featureMap = zipObjMap(featureNames, featureVector);
        Double highResultValue = ((SingleValueInferenceResults) testModel.infer(
            featureMap,
            RegressionConfig.EMPTY_PARAMS,
            Collections.emptyMap()
        )).value();
        assertThat(highResultValue, equalTo(10.5));

        double[] featureArray = new double[2];
        featureArray[0] = 12.0;
        featureArray[1] = 0.0;
        Double highResultValueFromFeatures = ((SingleValueInferenceResults) testModel.infer(featureArray, RegressionConfig.EMPTY_PARAMS))
            .value();
        assertThat(highResultValueFromFeatures, equalTo(10.5));
    }

    public void testInferenceScoresWithAdjustment() throws IOException {
        BoundedWindowInferenceModel testModel = getModel(-5.0, 1.2, 6.5);

        List<Double> featureVector = Arrays.asList(-10.0, 0.0);
        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
        Double lowResultValue = ((SingleValueInferenceResults) testModel.infer(
            featureMap,
            RegressionConfig.EMPTY_PARAMS,
            Collections.emptyMap()
        )).value();
        assertThat(lowResultValue, equalTo(0.0));

        featureVector = Arrays.asList(12.0, 0.0);
        featureMap = zipObjMap(featureNames, featureVector);
        Double highResultValue = ((SingleValueInferenceResults) testModel.infer(
            featureMap,
            RegressionConfig.EMPTY_PARAMS,
            Collections.emptyMap()
        )).value();
        assertThat(highResultValue, equalTo(11.5));

        double[] featureArray = new double[2];
        featureArray[0] = 12.0;
        featureArray[1] = 0.0;
        Double highResultValueFromFeatures = ((SingleValueInferenceResults) testModel.infer(featureArray, RegressionConfig.EMPTY_PARAMS))
            .value();
        assertThat(highResultValueFromFeatures, equalTo(11.5));
    }

    private BoundedWindowInferenceModel getModel(double lowerBoundValue, double midValue, double upperBoundValue) throws IOException {
        Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
        TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
        builder.addLeaf(rootNode.getRightChild(), upperBoundValue);
        TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
        builder.addLeaf(leftChildNode.getLeftChild(), lowerBoundValue);
        builder.addLeaf(leftChildNode.getRightChild(), midValue);

        List<String> featureNames = Arrays.asList("foo", "bar");
        Tree treeObject = builder.setFeatureNames(featureNames).build();
        TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent);
        tree.rewriteFeatureIndices(Collections.emptyMap());

        return new BoundedWindowInferenceModel(tree);
    }

    private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
        return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
    }

}
@@ -39,6 +39,7 @@ import java.util.stream.IntStream;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
 import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.nullValue;
@@ -537,6 +538,40 @@ public class EnsembleInferenceModelTests extends ESTestCase {
         assertThat(featureImportance[1][0], closeTo(0.1451914, eps));
     }
 
+    public void testMinAndMaxBoundaries() throws IOException {
+        List<String> featureNames = Arrays.asList("foo", "bar");
+        Tree tree1 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setRoot(TreeNode.builder(0).setLeftChild(1).setRightChild(2).setSplitFeature(0).setThreshold(0.5))
+            .addNode(TreeNode.builder(1).setLeafValue(0.3))
+            .addNode(TreeNode.builder(2).setThreshold(0.8).setSplitFeature(1).setLeftChild(3).setRightChild(4))
+            .addNode(TreeNode.builder(3).setLeafValue(0.1))
+            .addNode(TreeNode.builder(4).setLeafValue(0.2))
+            .build();
+        Tree tree2 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setRoot(TreeNode.builder(0).setLeftChild(1).setRightChild(2).setSplitFeature(0).setThreshold(0.5))
+            .addNode(TreeNode.builder(1).setLeafValue(1.5))
+            .addNode(TreeNode.builder(2).setLeafValue(0.9))
+            .build();
+        Ensemble ensembleObject = Ensemble.builder()
+            .setTargetType(TargetType.REGRESSION)
+            .setFeatureNames(featureNames)
+            .setTrainedModels(Arrays.asList(tree1, tree2))
+            .setOutputAggregator(new WeightedSum(new double[] { 0.5, 0.5 }))
+            .build();
+
+        EnsembleInferenceModel ensemble = deserializeFromTrainedModel(
+            ensembleObject,
+            xContentRegistry(),
+            EnsembleInferenceModel::fromXContent
+        );
+        ensemble.rewriteFeatureIndices(Collections.emptyMap());
+
+        assertThat(ensemble.getMinPredictedValue(), equalTo(1.0));
+        assertThat(ensemble.getMaxPredictedValue(), equalTo(1.8));
+    }
+
     private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
         return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
     }
@@ -13,6 +13,8 @@ import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.Strings;
+import org.elasticsearch.core.Tuple;
+import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
@@ -25,17 +27,26 @@ import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvide
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilderTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;

import java.io.IOException;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests.ENSEMBLE_MODEL;
import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests.TREE_MODEL;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;

@@ -176,6 +187,35 @@ public class InferenceDefinitionTests extends ESTestCase {
         }
     }
 
+    public void testWithLearningToRankConfiguration() throws IOException {
+        Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
+        TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
+        builder.addLeaf(rootNode.getRightChild(), -2.0);
+        TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
+        builder.addLeaf(leftChildNode.getLeftChild(), 0.2);
+        builder.addLeaf(leftChildNode.getRightChild(), 1.5);
+
+        List<String> featureNames = Arrays.asList("foo", "bar");
+        Tree treeObject = builder.setFeatureNames(featureNames).build();
+        TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent);
+        tree.rewriteFeatureIndices(Collections.emptyMap());
+
+        BoundedWindowInferenceModel testModel = new BoundedWindowInferenceModel(tree);
+
+        InferenceDefinition definition = new InferenceDefinition(testModel, null);
+        LearningToRankConfig config = new LearningToRankConfig(
+            randomBoolean() ? null : randomIntBetween(0, 10),
+            randomBoolean()
+                ? null
+                : Stream.generate(QueryExtractorBuilderTests::randomInstance).limit(randomInt(5)).collect(Collectors.toList()),
+            randomBoolean() ? null : randomMap(0, 10, () -> Tuple.tuple(randomIdentifier(), randomIdentifier()))
+        );
+
+        InferenceResults results = definition.infer(Map.of("foo", 1.0, "bar", 0.0), config);
+
+        assertThat(results.predictedValue(), equalTo(2.0));
+    }
+
     public static String getClassificationDefinition(boolean customPreprocessor) {
         return Strings.format("""
             {
@@ -284,6 +284,23 @@ public class TreeInferenceModelTests extends ESTestCase {
         assertThat(featureImportance[1][0], closeTo(2.5, eps));
     }
 
+    public void testMinAndMaxBoundaries() throws IOException {
+        Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
+        TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
+        builder.addLeaf(rootNode.getRightChild(), 0.3);
+        TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
+        builder.addLeaf(leftChildNode.getLeftChild(), 0.1);
+        builder.addLeaf(leftChildNode.getRightChild(), 0.2);
+
+        List<String> featureNames = Arrays.asList("foo", "bar");
+        Tree treeObject = builder.setFeatureNames(featureNames).build();
+        TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent);
+        tree.rewriteFeatureIndices(Collections.emptyMap());
+
+        assertThat(tree.getMinPredictedValue(), equalTo(0.1));
+        assertThat(tree.getMaxPredictedValue(), equalTo(0.3));
+    }
+
     private static Map<String, Object> zipObjMap(List<String> keys, List<? extends Object> values) {
         return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
     }