[ML] Add early stopping DFA configuration parameter (#68099)

The PR adds the early_stopping_enabled optional data frame analysis configuration parameter. The enhancement was already described in elastic/ml-cpp#1676, so I mark it here as a non-issue.
This commit is contained in:
Valeriy Khakhutskyy 2021-02-01 11:41:28 +01:00 committed by GitHub
parent 4cbe61467c
commit 78368428b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 257 additions and 104 deletions

View File

@ -63,6 +63,7 @@ public class Classification implements DataFrameAnalysis {
static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");
static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField("max_optimization_rounds_per_hyperparameter");
static final ParseField EARLY_STOPPING_ENABLED = new ParseField("early_stopping_enabled");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Classification, Void> PARSER =
@ -88,7 +89,8 @@ public class Classification implements DataFrameAnalysis {
(Double) a[15],
(Double) a[16],
(Double) a[17],
(Integer) a[18]
(Integer) a[18],
(Boolean) a[19]
));
static {
@ -115,6 +117,7 @@ public class Classification implements DataFrameAnalysis {
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), SOFT_TREE_DEPTH_TOLERANCE);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), DOWNSAMPLE_FACTOR);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), EARLY_STOPPING_ENABLED);
}
private final String dependentVariable;
@ -136,6 +139,7 @@ public class Classification implements DataFrameAnalysis {
private final Double softTreeDepthTolerance;
private final Double downsampleFactor;
private final Integer maxOptimizationRoundsPerHyperparameter;
private final Boolean earlyStoppingEnabled;
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
@ -144,7 +148,7 @@ public class Classification implements DataFrameAnalysis {
@Nullable ClassAssignmentObjective classAssignmentObjective, @Nullable List<PreProcessor> featureProcessors,
@Nullable Double alpha, @Nullable Double etaGrowthRatePerTree, @Nullable Double softTreeDepthLimit,
@Nullable Double softTreeDepthTolerance, @Nullable Double downsampleFactor,
@Nullable Integer maxOptimizationRoundsPerHyperparameter) {
@Nullable Integer maxOptimizationRoundsPerHyperparameter, @Nullable Boolean earlyStoppingEnabled) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
this.lambda = lambda;
this.gamma = gamma;
@ -164,6 +168,7 @@ public class Classification implements DataFrameAnalysis {
this.softTreeDepthTolerance = softTreeDepthTolerance;
this.downsampleFactor = downsampleFactor;
this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
this.earlyStoppingEnabled = earlyStoppingEnabled;
}
@Override
@ -247,6 +252,10 @@ public class Classification implements DataFrameAnalysis {
return maxOptimizationRoundsPerHyperparameter;
}
/**
 * Returns whether early stopping of hyperparameter optimization is enabled,
 * or {@code null} when the user did not set it explicitly.
 */
// Renamed from getEarlyStoppingEnable() for consistency with Regression#getEarlyStoppingEnabled()
// and the server-side classes; the method is new in this PR, so no callers use the old name.
public Boolean getEarlyStoppingEnabled() {
return earlyStoppingEnabled;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@ -305,6 +314,9 @@ public class Classification implements DataFrameAnalysis {
if (maxOptimizationRoundsPerHyperparameter != null) {
builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
}
if (earlyStoppingEnabled != null) {
builder.field(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
}
builder.endObject();
return builder;
}
@ -313,7 +325,8 @@ public class Classification implements DataFrameAnalysis {
// Hashes every field compared in equals(), including the newly added
// earlyStoppingEnabled, so that equal instances always share a hash code.
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses, classAssignmentObjective, featureProcessors, alpha,
etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter,
earlyStoppingEnabled);
}
@Override
@ -339,7 +352,8 @@ public class Classification implements DataFrameAnalysis {
&& Objects.equals(softTreeDepthLimit, that.softTreeDepthLimit)
&& Objects.equals(softTreeDepthTolerance, that.softTreeDepthTolerance)
&& Objects.equals(downsampleFactor, that.downsampleFactor)
&& Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter);
&& Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter)
&& Objects.equals(earlyStoppingEnabled, that.earlyStoppingEnabled);
}
@Override
@ -380,6 +394,7 @@ public class Classification implements DataFrameAnalysis {
private Double softTreeDepthTolerance;
private Double downsampleFactor;
private Integer maxOptimizationRoundsPerHyperparameter;
private Boolean earlyStoppingEnabled;
private Builder(String dependentVariable) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
@ -475,11 +490,16 @@ public class Classification implements DataFrameAnalysis {
return this;
}
/**
 * Sets whether the training process should stop early when it is no longer
 * finding better performing models. {@code null} leaves the choice to the
 * server-side default (enabled).
 */
public Builder setEarlyStoppingEnabled(Boolean earlyStoppingEnabled) {
this.earlyStoppingEnabled = earlyStoppingEnabled;
return this;
}
/**
 * Builds the immutable {@link Classification} configuration.
 * Arguments are positional and must match the private constructor's parameter
 * order exactly; earlyStoppingEnabled is the newly added final parameter.
 */
public Classification build() {
return new Classification(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed,
classAssignmentObjective, featureProcessors, alpha, etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance,
downsampleFactor, maxOptimizationRoundsPerHyperparameter, earlyStoppingEnabled);
}
}
}

View File

@ -65,6 +65,7 @@ public class Regression implements DataFrameAnalysis {
static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");
static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField("max_optimization_rounds_per_hyperparameter");
static final ParseField EARLY_STOPPING_ENABLED = new ParseField("early_stopping_enabled");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Regression, Void> PARSER =
@ -90,7 +91,8 @@ public class Regression implements DataFrameAnalysis {
(Double) a[15],
(Double) a[16],
(Double) a[17],
(Integer) a[18]
(Integer) a[18],
(Boolean) a[19]
));
static {
@ -116,6 +118,7 @@ public class Regression implements DataFrameAnalysis {
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), SOFT_TREE_DEPTH_TOLERANCE);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), DOWNSAMPLE_FACTOR);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), EARLY_STOPPING_ENABLED);
}
private final String dependentVariable;
@ -137,6 +140,7 @@ public class Regression implements DataFrameAnalysis {
private final Double softTreeDepthTolerance;
private final Double downsampleFactor;
private final Integer maxOptimizationRoundsPerHyperparameter;
private final Boolean earlyStoppingEnabled;
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
@ -144,7 +148,8 @@ public class Regression implements DataFrameAnalysis {
@Nullable Double trainingPercent, @Nullable Long randomizeSeed, @Nullable LossFunction lossFunction,
@Nullable Double lossFunctionParameter, @Nullable List<PreProcessor> featureProcessors, @Nullable Double alpha,
@Nullable Double etaGrowthRatePerTree, @Nullable Double softTreeDepthLimit, @Nullable Double softTreeDepthTolerance,
@Nullable Double downsampleFactor, @Nullable Integer maxOptimizationRoundsPerHyperparameter) {
@Nullable Double downsampleFactor, @Nullable Integer maxOptimizationRoundsPerHyperparameter,
@Nullable Boolean earlyStoppingEnabled) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
this.lambda = lambda;
this.gamma = gamma;
@ -164,6 +169,7 @@ public class Regression implements DataFrameAnalysis {
this.softTreeDepthTolerance = softTreeDepthTolerance;
this.downsampleFactor = downsampleFactor;
this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
this.earlyStoppingEnabled = earlyStoppingEnabled;
}
@Override
@ -247,6 +253,10 @@ public class Regression implements DataFrameAnalysis {
return maxOptimizationRoundsPerHyperparameter;
}
/**
 * Returns whether early stopping of hyperparameter optimization is enabled,
 * or {@code null} when the user did not set it explicitly.
 */
public Boolean getEarlyStoppingEnabled() {
return earlyStoppingEnabled;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@ -305,6 +315,9 @@ public class Regression implements DataFrameAnalysis {
if (maxOptimizationRoundsPerHyperparameter != null) {
builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
}
if (earlyStoppingEnabled != null) {
builder.field(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
}
builder.endObject();
return builder;
}
@ -313,7 +326,8 @@ public class Regression implements DataFrameAnalysis {
// Hashes every field compared in equals(), including the newly added
// earlyStoppingEnabled, so that equal instances always share a hash code.
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter, featureProcessors, alpha,
etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter,
earlyStoppingEnabled);
}
@Override
@ -339,7 +353,8 @@ public class Regression implements DataFrameAnalysis {
&& Objects.equals(softTreeDepthLimit, that.softTreeDepthLimit)
&& Objects.equals(softTreeDepthTolerance, that.softTreeDepthTolerance)
&& Objects.equals(downsampleFactor, that.downsampleFactor)
&& Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter);
&& Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter)
&& Objects.equals(earlyStoppingEnabled, that.earlyStoppingEnabled);
}
@Override
@ -367,6 +382,7 @@ public class Regression implements DataFrameAnalysis {
private Double softTreeDepthTolerance;
private Double downsampleFactor;
private Integer maxOptimizationRoundsPerHyperparameter;
private Boolean earlyStoppingEnabled;
private Builder(String dependentVariable) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
@ -462,11 +478,16 @@ public class Regression implements DataFrameAnalysis {
return this;
}
/**
 * Sets whether the training process should stop early when it is no longer
 * finding better performing models. {@code null} leaves the choice to the
 * server-side default (enabled).
 */
public Builder setEarlyStoppingEnabled(Boolean earlyStoppingEnabled) {
this.earlyStoppingEnabled = earlyStoppingEnabled;
return this;
}
/**
 * Builds the immutable {@link Regression} configuration.
 * Arguments are positional and must match the private constructor's parameter
 * order exactly; earlyStoppingEnabled is the newly added final parameter.
 */
public Regression build() {
return new Regression(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter,
featureProcessors, alpha, etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor,
maxOptimizationRoundsPerHyperparameter, earlyStoppingEnabled);
}
}

View File

@ -1366,6 +1366,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.setSoftTreeDepthTolerance(0.1)
.setDownsampleFactor(0.5)
.setMaxOptimizationRoundsPerHyperparameter(3)
.setEarlyStoppingEnabled(false)
.build())
.setDescription("this is a regression")
.build();
@ -1417,6 +1419,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.setSoftTreeDepthTolerance(0.1)
.setDownsampleFactor(0.5)
.setMaxOptimizationRoundsPerHyperparameter(3)
.setEarlyStoppingEnabled(false)
.build())
.setDescription("this is a classification")
.build();

View File

@ -3059,6 +3059,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.setSoftTreeDepthTolerance(1.0) // <17>
.setDownsampleFactor(0.5) // <18>
.setMaxOptimizationRoundsPerHyperparameter(3) // <19>
.setEarlyStoppingEnabled(true) // <20>
.build();
// end::put-data-frame-analytics-classification
@ -3084,6 +3085,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.setSoftTreeDepthTolerance(1.0) // <17>
.setDownsampleFactor(0.5) // <18>
.setMaxOptimizationRoundsPerHyperparameter(3) // <19>
.setEarlyStoppingEnabled(true) // <20>
.build();
// end::put-data-frame-analytics-regression

View File

@ -60,6 +60,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
.setSoftTreeDepthTolerance(randomBoolean() ? null : randomDoubleBetween(0.01, Double.MAX_VALUE, true))
.setDownsampleFactor(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
.setMaxOptimizationRoundsPerHyperparameter(randomBoolean() ? null : randomIntBetween(0, 20))
.setEarlyStoppingEnabled(randomBoolean() ? null : randomBoolean())
.build();
}

View File

@ -59,6 +59,7 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
.setSoftTreeDepthTolerance(randomBoolean() ? null : randomDoubleBetween(0.01, Double.MAX_VALUE, true))
.setDownsampleFactor(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
.setMaxOptimizationRoundsPerHyperparameter(randomBoolean() ? null : randomIntBetween(0, 20))
.setEarlyStoppingEnabled(randomBoolean() ? null : randomBoolean())
.build();
}

View File

@ -134,6 +134,7 @@ include-tagged::{doc-tests-file}[{api}-classification]
<17> The soft tree depth tolerance. Controls how much the soft tree depth limit is respected. A double greater than or equal to 0.01.
<18> The amount by which to downsample the data for stochastic gradient estimates. A double in (0, 1.0].
<19> The maximum number of optimisation rounds we use for hyperparameter optimisation per parameter. An integer in [0, 20].
<20> Whether to enable early stopping, which finishes the training process early if it is not finding better performing models.
===== Regression
@ -164,6 +165,7 @@ fields. Note, automatic categorical {ml-docs}/ml-feature-encoding.html[feature e
<17> The soft tree depth tolerance. Controls how much the soft tree depth limit is respected. A double greater than or equal to 0.01.
<18> The amount by which to downsample the data for stochastic gradient estimates. A double in (0, 1.0].
<19> The maximum number of optimisation rounds we use for hyperparameter optimisation per parameter. An integer in [0, 20].
<20> Whether to enable early stopping, which finishes the training process early if it is not finding better performing models.
==== Analyzed fields

View File

@ -117,6 +117,10 @@ different values in this field.
(Optional, double)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-downsample-factor]
`early_stopping_enabled`::::
(Optional, Boolean)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-early-stopping-enabled]
`eta`::::
(Optional, double)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=eta]
@ -359,6 +363,10 @@ The data type of the field must be numeric.
(Optional, double)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-downsample-factor]
`early_stopping_enabled`::::
(Optional, Boolean)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-early-stopping-enabled]
`eta`::::
(Optional, double)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=eta]

View File

@ -557,6 +557,14 @@ Values must be greater than zero and less than or equal to 1.
By default, this value is calculated during hyperparameter optimization.
end::dfas-downsample-factor[]
tag::dfas-early-stopping-enabled[]
Advanced configuration option.
Specifies whether the training process should finish if it is not finding any
better performing models. If disabled, the training process can take significantly
longer and the chance of finding a better performing model is diminishing.
By default, early stopping is enabled.
end::dfas-early-stopping-enabled[]
tag::dfas-eta-growth[]
Advanced configuration option.
Specifies the rate at which `eta` increases for each new tree that is added

View File

@ -55,6 +55,7 @@ public class Classification implements DataFrameAnalysis {
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
public static final ParseField EARLY_STOPPING_ENABLED = new ParseField("early_stopping_enabled");
private static final String STATE_DOC_ID_INFIX = "_classification_state#";
@ -82,7 +83,8 @@ public class Classification implements DataFrameAnalysis {
(Integer) a[15],
(Double) a[16],
(Long) a[17],
(List<PreProcessor>) a[18]));
(List<PreProcessor>) a[18],
(Boolean) a[19]));
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
BoostedTreeParams.declareFields(parser);
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
@ -96,6 +98,7 @@ public class Classification implements DataFrameAnalysis {
p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
(classification) -> {/*TODO should we throw if this is not set?*/},
FEATURE_PROCESSORS);
parser.declareBoolean(optionalConstructorArg(), EARLY_STOPPING_ENABLED);
return parser;
}
@ -159,6 +162,7 @@ public class Classification implements DataFrameAnalysis {
private final double trainingPercent;
private final long randomizeSeed;
private final List<PreProcessor> featureProcessors;
private final boolean earlyStoppingEnabled;
public Classification(String dependentVariable,
BoostedTreeParams boostedTreeParams,
@ -167,7 +171,8 @@ public class Classification implements DataFrameAnalysis {
@Nullable Integer numTopClasses,
@Nullable Double trainingPercent,
@Nullable Long randomizeSeed,
@Nullable List<PreProcessor> featureProcessors) {
@Nullable List<PreProcessor> featureProcessors,
@Nullable Boolean earlyStoppingEnabled) {
if (numTopClasses != null && (numTopClasses < -1 || numTopClasses > 1000)) {
throw ExceptionsHelper.badRequestException(
"[{}] must be an integer in [0, 1000] or a special value -1", NUM_TOP_CLASSES.getPreferredName());
@ -184,10 +189,12 @@ public class Classification implements DataFrameAnalysis {
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
// Early stopping is true by default
this.earlyStoppingEnabled = earlyStoppingEnabled == null ? true : earlyStoppingEnabled;
}
public Classification(String dependentVariable) {
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null, null);
}
public Classification(StreamInput in) throws IOException {
@ -211,6 +218,11 @@ public class Classification implements DataFrameAnalysis {
} else {
featureProcessors = Collections.emptyList();
}
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
earlyStoppingEnabled = in.readBoolean();
} else {
earlyStoppingEnabled = true;
}
}
public String getDependentVariable() {
@ -246,6 +258,10 @@ public class Classification implements DataFrameAnalysis {
return featureProcessors;
}
/**
 * Returns whether early stopping is enabled. Never {@code null}: the backing
 * field is a primitive boolean that defaults to {@code true} in the
 * constructor when the setting is absent.
 */
public Boolean getEarlyStoppingEnabled() {
return earlyStoppingEnabled;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
@ -267,6 +283,9 @@ public class Classification implements DataFrameAnalysis {
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
out.writeNamedWriteableList(featureProcessors);
}
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
out.writeBoolean(earlyStoppingEnabled);
}
}
@Override
@ -288,6 +307,7 @@ public class Classification implements DataFrameAnalysis {
if (featureProcessors.isEmpty() == false) {
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
}
builder.field(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
builder.endObject();
return builder;
}
@ -312,6 +332,7 @@ public class Classification implements DataFrameAnalysis {
params.put(FEATURE_PROCESSORS.getPreferredName(),
featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
}
params.put(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
return params;
}
@ -457,6 +478,7 @@ public class Classification implements DataFrameAnalysis {
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
&& Objects.equals(numTopClasses, that.numTopClasses)
&& Objects.equals(featureProcessors, that.featureProcessors)
&& Objects.equals(earlyStoppingEnabled, that.earlyStoppingEnabled)
&& trainingPercent == that.trainingPercent
&& randomizeSeed == that.randomizeSeed;
}
@ -464,7 +486,8 @@ public class Classification implements DataFrameAnalysis {
@Override
public int hashCode() {
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective,
numTopClasses, trainingPercent, randomizeSeed, featureProcessors);
numTopClasses, trainingPercent, randomizeSeed, featureProcessors,
earlyStoppingEnabled);
}
public enum ClassAssignmentObjective {

View File

@ -52,6 +52,7 @@ public class Regression implements DataFrameAnalysis {
public static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
public static final ParseField EARLY_STOPPING_ENABLED = new ParseField("early_stopping_enabled");
private static final String STATE_DOC_ID_INFIX = "_regression_state#";
@ -72,7 +73,8 @@ public class Regression implements DataFrameAnalysis {
(Long) a[15],
(LossFunction) a[16],
(Double) a[17],
(List<PreProcessor>) a[18]));
(List<PreProcessor>) a[18],
(Boolean) a[19]));
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
BoostedTreeParams.declareFields(parser);
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
@ -86,6 +88,7 @@ public class Regression implements DataFrameAnalysis {
p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
(regression) -> {/*TODO should we throw if this is not set?*/},
FEATURE_PROCESSORS);
parser.declareBoolean(optionalConstructorArg(), EARLY_STOPPING_ENABLED);
return parser;
}
@ -124,6 +127,7 @@ public class Regression implements DataFrameAnalysis {
private final LossFunction lossFunction;
private final Double lossFunctionParameter;
private final List<PreProcessor> featureProcessors;
private final boolean earlyStoppingEnabled;
public Regression(String dependentVariable,
BoostedTreeParams boostedTreeParams,
@ -132,7 +136,8 @@ public class Regression implements DataFrameAnalysis {
@Nullable Long randomizeSeed,
@Nullable LossFunction lossFunction,
@Nullable Double lossFunctionParameter,
@Nullable List<PreProcessor> featureProcessors) {
@Nullable List<PreProcessor> featureProcessors,
@Nullable Boolean earlyStoppingEnabled) {
if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName());
}
@ -148,10 +153,12 @@ public class Regression implements DataFrameAnalysis {
}
this.lossFunctionParameter = lossFunctionParameter;
this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
// Early stopping is true by default
this.earlyStoppingEnabled = earlyStoppingEnabled == null ? true : earlyStoppingEnabled;
}
public Regression(String dependentVariable) {
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null, null);
}
public Regression(StreamInput in) throws IOException {
@ -167,6 +174,11 @@ public class Regression implements DataFrameAnalysis {
} else {
featureProcessors = Collections.emptyList();
}
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
earlyStoppingEnabled = in.readBoolean();
} else {
earlyStoppingEnabled = true;
}
}
public String getDependentVariable() {
@ -202,6 +214,10 @@ public class Regression implements DataFrameAnalysis {
return featureProcessors;
}
/**
 * Returns whether early stopping is enabled. Never {@code null}: the backing
 * field is a primitive boolean that defaults to {@code true} in the
 * constructor when the setting is absent.
 */
public Boolean getEarlyStoppingEnabled() {
return earlyStoppingEnabled;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
@ -219,6 +235,9 @@ public class Regression implements DataFrameAnalysis {
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
out.writeNamedWriteableList(featureProcessors);
}
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
out.writeBoolean(earlyStoppingEnabled);
}
}
@Override
@ -242,6 +261,7 @@ public class Regression implements DataFrameAnalysis {
if (featureProcessors.isEmpty() == false) {
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
}
builder.field(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
builder.endObject();
return builder;
}
@ -263,6 +283,7 @@ public class Regression implements DataFrameAnalysis {
params.put(FEATURE_PROCESSORS.getPreferredName(),
featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
}
params.put(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
return params;
}
@ -348,13 +369,14 @@ public class Regression implements DataFrameAnalysis {
&& randomizeSeed == that.randomizeSeed
&& lossFunction == that.lossFunction
&& Objects.equals(featureProcessors, that.featureProcessors)
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter)
&& Objects.equals(earlyStoppingEnabled, that.earlyStoppingEnabled);
}
@Override
public int hashCode() {
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
lossFunctionParameter, featureProcessors);
lossFunctionParameter, featureProcessors, earlyStoppingEnabled);
}
public enum LossFunction {

View File

@ -329,6 +329,7 @@ public final class ReservedFieldNames {
Regression.PREDICTION_FIELD_NAME.getPreferredName(),
Regression.TRAINING_PERCENT.getPreferredName(),
Regression.FEATURE_PROCESSORS.getPreferredName(),
Regression.EARLY_STOPPING_ENABLED.getPreferredName(),
Classification.NAME.getPreferredName(),
Classification.DEPENDENT_VARIABLE.getPreferredName(),
Classification.PREDICTION_FIELD_NAME.getPreferredName(),
@ -336,6 +337,7 @@ public final class ReservedFieldNames {
Classification.NUM_TOP_CLASSES.getPreferredName(),
Classification.TRAINING_PERCENT.getPreferredName(),
Classification.FEATURE_PROCESSORS.getPreferredName(),
Classification.EARLY_STOPPING_ENABLED.getPreferredName(),
BoostedTreeParams.ALPHA.getPreferredName(),
BoostedTreeParams.DOWNSAMPLE_FACTOR.getPreferredName(),
BoostedTreeParams.LAMBDA.getPreferredName(),

View File

@ -78,6 +78,9 @@
},
"training_percent" : {
"type" : "double"
},
"early_stopping_enabled" : {
"type": "boolean"
}
}
},
@ -149,6 +152,9 @@
},
"training_percent" : {
"type" : "double"
},
"early_stopping_enabled" : {
"type": "boolean"
}
}
}

View File

@ -151,7 +151,8 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
42L,
bwcRegression.getLossFunction(),
bwcRegression.getLossFunctionParameter(),
bwcRegression.getFeatureProcessors());
bwcRegression.getFeatureProcessors(),
bwcRegression.getEarlyStoppingEnabled());
testAnalysis = new Regression(testRegression.getDependentVariable(),
testRegression.getBoostedTreeParams(),
testRegression.getPredictionFieldName(),
@ -159,7 +160,8 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
42L,
testRegression.getLossFunction(),
testRegression.getLossFunctionParameter(),
bwcRegression.getFeatureProcessors());
testRegression.getFeatureProcessors(),
testRegression.getEarlyStoppingEnabled());
} else {
Classification testClassification = (Classification)testInstance.getAnalysis();
Classification bwcClassification = (Classification)bwcSerializedObject.getAnalysis();
@ -170,7 +172,8 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
bwcClassification.getNumTopClasses(),
bwcClassification.getTrainingPercent(),
42L,
bwcClassification.getFeatureProcessors());
bwcClassification.getFeatureProcessors(),
bwcClassification.getEarlyStoppingEnabled());
testAnalysis = new Classification(testClassification.getDependentVariable(),
testClassification.getBoostedTreeParams(),
testClassification.getPredictionFieldName(),
@ -178,7 +181,8 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
testClassification.getNumTopClasses(),
testClassification.getTrainingPercent(),
42L,
testClassification.getFeatureProcessors());
testClassification.getFeatureProcessors(),
testClassification.getEarlyStoppingEnabled());
}
super.assertOnBWCObject(new DataFrameAnalyticsConfig.Builder(bwcSerializedObject)
.setAnalysis(bwcAnalysis)

View File

@ -97,6 +97,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(-1, 1000);
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
Long randomizeSeed = randomBoolean() ? null : randomLong();
Boolean earlyStoppingEnabled = randomBoolean() ? null : randomBoolean();
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
numTopClasses, trainingPercent, randomizeSeed,
randomBoolean() ?
@ -105,7 +106,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
OneHotEncodingTests.createRandom(true),
TargetMeanEncodingTests.createRandom(true)))
.limit(randomIntBetween(0, 5))
.collect(Collectors.toList()));
.collect(Collectors.toList()),
earlyStoppingEnabled);
}
public static Classification mutateForVersion(Classification instance, Version version) {
@ -116,7 +118,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
instance.getNumTopClasses(),
instance.getTrainingPercent(),
instance.getRandomizeSeed(),
version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList());
version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList(),
version.onOrAfter(Version.V_8_0_0) ? instance.getEarlyStoppingEnabled() : null);
}
@Override
@ -133,7 +136,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
bwcSerializedObject.getNumTopClasses(),
bwcSerializedObject.getTrainingPercent(),
42L,
bwcSerializedObject.getFeatureProcessors());
bwcSerializedObject.getFeatureProcessors(),
bwcSerializedObject.getEarlyStoppingEnabled());
Classification newInstance = new Classification(testInstance.getDependentVariable(),
testInstance.getBoostedTreeParams(),
testInstance.getPredictionFieldName(),
@ -141,7 +145,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
testInstance.getNumTopClasses(),
testInstance.getTrainingPercent(),
42L,
testInstance.getFeatureProcessors());
testInstance.getFeatureProcessors(),
testInstance.getEarlyStoppingEnabled());
super.assertOnBWCObject(newBwc, newInstance, version);
}
@ -202,96 +207,96 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
public void testConstructor_GivenTrainingPercentIsZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.0, randomLong(), null));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.0, randomLong(), null, null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
}
public void testConstructor_GivenTrainingPercentIsLessThanZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, -1.0, randomLong(), null));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, -1.0, randomLong(), null, null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
}
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null, null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
}
public void testConstructor_GivenNumTopClassesIsLessThanMinusOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null, null));
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
}
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null, null));
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
}
public void testGetPredictionFieldName() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null, null);
assertThat(classification.getPredictionFieldName(), equalTo("result"));
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong(), null);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong(), null, null);
assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
}
public void testClassAssignmentObjective() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong(), null);
Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong(), null, null);
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY));
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong(), null);
Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong(), null, null);
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
// class_assignment_objective == null, default applied
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null, null);
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
}
public void testGetNumTopClasses() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null, null);
assertThat(classification.getNumTopClasses(), equalTo(7));
// Special value: num_top_classes == -1
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null, null);
assertThat(classification.getNumTopClasses(), equalTo(-1));
// Boundary condition: num_top_classes == 0
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null, null);
assertThat(classification.getNumTopClasses(), equalTo(0));
// Boundary condition: num_top_classes == 1000
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong(), null);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong(), null, null);
assertThat(classification.getNumTopClasses(), equalTo(1000));
// num_top_classes == null, default applied
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong(), null);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong(), null, null);
assertThat(classification.getNumTopClasses(), equalTo(2));
}
public void testGetTrainingPercent() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null, null);
assertThat(classification.getTrainingPercent(), equalTo(50.0));
// Boundary condition: training_percent == 1.0
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong(), null);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong(), null, null);
assertThat(classification.getTrainingPercent(), equalTo(1.0));
// Boundary condition: training_percent == 100.0
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong(), null);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong(), null, null);
assertThat(classification.getTrainingPercent(), equalTo(100.0));
// training_percent == null, default applied
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong(), null);
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong(), null, null);
assertThat(classification.getTrainingPercent(), equalTo(100.0));
}
@ -316,7 +321,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
"prediction_field_name", "foo_prediction",
"prediction_field_type", "bool",
"num_classes", 10L,
"training_percent", 100.0)));
"training_percent", 100.0,
"early_stopping_enabled", true)));
assertThat(
new Classification("bar").getParams(fieldInfo),
equalTo(
@ -327,7 +333,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
"prediction_field_name", "bar_prediction",
"prediction_field_type", "int",
"num_classes", 20L,
"training_percent", 100.0)));
"training_percent", 100.0,
"early_stopping_enabled", true)));
assertThat(
new Classification("baz",
BoostedTreeParams.builder().build() ,
@ -336,6 +343,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
null,
50.0,
null,
null,
null).getParams(fieldInfo),
equalTo(
Map.of(
@ -345,7 +353,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
"prediction_field_name", "baz_prediction",
"prediction_field_type", "string",
"num_classes", 30L,
"training_percent", 50.0)));
"training_percent", 50.0,
"early_stopping_enabled", true)));
}
public void testRequiredFieldsIsNonEmpty() {

View File

@ -88,6 +88,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
Long randomizeSeed = randomBoolean() ? null : randomLong();
Regression.LossFunction lossFunction = randomBoolean() ? null : randomFrom(Regression.LossFunction.values());
Double lossFunctionParameter = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, false);
Boolean earlyStoppingEnabled = randomBoolean() ? null : randomBoolean();
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
lossFunctionParameter,
randomBoolean() ?
@ -96,7 +97,8 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
OneHotEncodingTests.createRandom(true),
TargetMeanEncodingTests.createRandom(true)))
.limit(randomIntBetween(0, 5))
.collect(Collectors.toList()));
.collect(Collectors.toList()),
earlyStoppingEnabled);
}
public static Regression mutateForVersion(Regression instance, Version version) {
@ -107,7 +109,8 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
instance.getRandomizeSeed(),
instance.getLossFunction(),
instance.getLossFunctionParameter(),
version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList());
version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList(),
version.onOrAfter(Version.V_8_0_0) ? instance.getEarlyStoppingEnabled() : null);
}
@Override
@ -124,7 +127,8 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
42L,
bwcSerializedObject.getLossFunction(),
bwcSerializedObject.getLossFunctionParameter(),
bwcSerializedObject.getFeatureProcessors());
bwcSerializedObject.getFeatureProcessors(),
bwcSerializedObject.getEarlyStoppingEnabled());
Regression newInstance = new Regression(testInstance.getDependentVariable(),
testInstance.getBoostedTreeParams(),
testInstance.getPredictionFieldName(),
@ -132,7 +136,8 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
42L,
testInstance.getLossFunction(),
testInstance.getLossFunctionParameter(),
testInstance.getFeatureProcessors());
testInstance.getFeatureProcessors(),
testInstance.getEarlyStoppingEnabled());
super.assertOnBWCObject(newBwc, newInstance, version);
}
@ -198,21 +203,24 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
public void testConstructor_GivenTrainingPercentIsZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.0, randomLong(), Regression.LossFunction.MSE, null, null));
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.0, randomLong(),
Regression.LossFunction.MSE, null, null, null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
}
public void testConstructor_GivenTrainingPercentIsLessThanZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", -0.01, randomLong(), Regression.LossFunction.MSE, null, null));
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", -0.01, randomLong(),
Regression.LossFunction.MSE, null, null, null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
}
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null, null));
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(),
Regression.LossFunction.MSE, null, null, null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
@ -220,55 +228,48 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
public void testConstructor_GivenLossFunctionParameterIsZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, 0.0, null));
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(),
Regression.LossFunction.MSE, 0.0, null, null));
assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
}
public void testConstructor_GivenLossFunctionParameterIsNegative() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, -1.0, null));
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(),
Regression.LossFunction.MSE, -1.0, null, null));
assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
}
public void testGetPredictionFieldName() {
Regression regression = new Regression(
"foo",
BOOSTED_TREE_PARAMS,
"result",
50.0,
randomLong(),
Regression.LossFunction.MSE,
1.0,
null);
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(),
Regression.LossFunction.MSE, 1.0, null, null);
assertThat(regression.getPredictionFieldName(), equalTo("result"));
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(), Regression.LossFunction.MSE, null, null);
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(),
Regression.LossFunction.MSE, null, null, null);
assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction"));
}
public void testGetTrainingPercent() {
Regression regression = new Regression("foo",
BOOSTED_TREE_PARAMS,
"result",
50.0,
randomLong(),
Regression.LossFunction.MSE,
1.0,
null);
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(),
Regression.LossFunction.MSE, 1.0, null, null);
assertThat(regression.getTrainingPercent(), equalTo(50.0));
// Boundary condition: training_percent == 1.0
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(), Regression.LossFunction.MSE, null, null);
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(),
Regression.LossFunction.MSE, null, null, null);
assertThat(regression.getTrainingPercent(), equalTo(1.0));
// Boundary condition: training_percent == 100.0
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, null, null);
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(),
Regression.LossFunction.MSE, null, null, null);
assertThat(regression.getTrainingPercent(), equalTo(100.0));
// training_percent == null, default applied
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(), Regression.LossFunction.MSE, null, null);
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(),
Regression.LossFunction.MSE, null, null, null);
assertThat(regression.getTrainingPercent(), equalTo(100.0));
}
@ -276,21 +277,17 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
int maxTrees = randomIntBetween(1, 100);
Regression regression = new Regression("foo",
BoostedTreeParams.builder().setMaxTrees(maxTrees).build(),
null,
100.0,
0L,
Regression.LossFunction.MSE,
null,
null);
null, 100.0, 0L, Regression.LossFunction.MSE, null, null, null);
Map<String, Object> params = regression.getParams(null);
assertThat(params.size(), equalTo(5));
assertThat(params.size(), equalTo(6));
assertThat(params.get("dependent_variable"), equalTo("foo"));
assertThat(params.get("prediction_field_name"), equalTo("foo_prediction"));
assertThat(params.get("max_trees"), equalTo(maxTrees));
assertThat(params.get("training_percent"), equalTo(100.0));
assertThat(params.get("loss_function"), equalTo("mse"));
assertThat(params.get("early_stopping_enabled"), equalTo(true));
}
public void testGetParams_GivenRandomWithoutBoostedTreeParams() {
@ -298,7 +295,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
Map<String, Object> params = regression.getParams(null);
int expectedParamsCount = 4
int expectedParamsCount = 5
+ (regression.getLossFunctionParameter() == null ? 0 : 1)
+ (regression.getFeatureProcessors().isEmpty() ? 0 : 1);
assertThat(params.size(), equalTo(expectedParamsCount));
@ -311,6 +308,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
} else {
assertThat(params.get("loss_function_parameter"), equalTo(regression.getLossFunctionParameter()));
}
assertThat(params.get("early_stopping_enabled"), equalTo(regression.getEarlyStoppingEnabled()));
}
public void testRequiredFieldsIsNonEmpty() {

View File

@ -141,6 +141,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
null,
null,
null,
null,
null));
putAnalytics(config);
@ -197,6 +198,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
null,
null,
null,
null,
null));
putAnalytics(config);
@ -317,7 +319,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
new OneHotEncoding(TEXT_FIELD, MapBuilder.<String, String>newMapBuilder()
.put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom_3")
.put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom_3").map(), true)
)));
),
null));
putAnalytics(config);
assertIsStopped(jobId);
@ -386,7 +389,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null, null));
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null,
numTopClasses, 50.0, null, null, null));
putAnalytics(config);
assertIsStopped(jobId);
@ -650,7 +654,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
.build();
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null, null));
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null, null, null));
putAnalytics(firstJob);
startAnalytics(firstJobId);
waitUntilAnalyticsIsStopped(firstJobId);
@ -660,7 +664,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed();
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed, null));
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed, null, null));
putAnalytics(secondJob);
startAnalytics(secondJobId);

View File

@ -128,7 +128,8 @@ public class DataFrameAnalysisCustomFeatureIT extends MlNativeDataFrameAnalytics
new OneHotEncoding("ngram.21", MapBuilder.<String, String>newMapBuilder().put("at", "is_cat").map(), true)
},
true)
)))
),
null))
.setAnalyzedFields(new FetchSourceContext(true, new String[]{TEXT_FIELD, NUMERICAL_FIELD}, new String[]{}))
.build();
putAnalytics(config);

View File

@ -105,6 +105,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
null,
null,
null,
null,
null))
.buildForExplain();
@ -124,6 +125,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
null,
null,
null,
null,
null))
.buildForExplain();
@ -152,6 +154,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
null,
null,
null,
null,
null))
.buildForExplain();

View File

@ -115,6 +115,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
null,
null,
null,
null,
null)
);
putAnalytics(config);
@ -251,7 +252,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null, null, null, null));
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(),
null, 50.0, null, null, null, null, null));
putAnalytics(config);
assertIsStopped(jobId);
@ -371,7 +373,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
.build();
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null, null, null, null));
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0,
null, null, null, null, null));
putAnalytics(firstJob);
startAnalytics(firstJobId);
waitUntilAnalyticsIsStopped(firstJobId);
@ -381,7 +384,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
long randomizeSeed = ((Regression) firstJob.getAnalysis()).getRandomizeSeed();
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed, null, null, null));
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0,
randomizeSeed, null, null, null, null));
putAnalytics(secondJob);
startAnalytics(secondJobId);
@ -438,7 +442,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null, null, null, null));
new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(),
null, null, null, null, null, null, null));
putAnalytics(config);
assertIsStopped(jobId);
@ -465,6 +470,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
null,
null,
null,
null,
null)
);
putAnalytics(config);
@ -562,6 +568,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
null,
null,
null,
null,
null);
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId(jobId)
@ -635,7 +642,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
Arrays.asList(
new OneHotEncoding(DISCRETE_NUMERICAL_FEATURE_FIELD,
Collections.singletonMap(DISCRETE_NUMERICAL_FEATURE_VALUES.get(0).toString(), "tenner"), true)
))
),
null)
);
putAnalytics(config);

View File

@ -1169,7 +1169,8 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
null,
null,
null,
featureprocessors))
featureprocessors,
null))
.build();
}

View File

@ -1513,7 +1513,8 @@ setup:
"soft_tree_depth_limit": 2.0,
"soft_tree_depth_tolerance": 3.0,
"downsample_factor": 0.5,
"max_optimization_rounds_per_hyperparameter": 3
"max_optimization_rounds_per_hyperparameter": 3,
"early_stopping_enabled": true
}
}
}
@ -1538,7 +1539,8 @@ setup:
"soft_tree_depth_limit": 2.0,
"soft_tree_depth_tolerance": 3.0,
"downsample_factor": 0.5,
"max_optimization_rounds_per_hyperparameter": 3
"max_optimization_rounds_per_hyperparameter": 3,
"early_stopping_enabled": true
}
}}
- is_true: create_time
@ -1870,7 +1872,8 @@ setup:
"soft_tree_depth_limit": 2.0,
"soft_tree_depth_tolerance": 3.0,
"downsample_factor": 0.5,
"max_optimization_rounds_per_hyperparameter": 3
"max_optimization_rounds_per_hyperparameter": 3,
"early_stopping_enabled": true
}
}
}
@ -1895,7 +1898,8 @@ setup:
"soft_tree_depth_limit": 2.0,
"soft_tree_depth_tolerance": 3.0,
"downsample_factor": 0.5,
"max_optimization_rounds_per_hyperparameter": 3
"max_optimization_rounds_per_hyperparameter": 3,
"early_stopping_enabled": true
}
}}
- is_true: create_time
@ -1939,7 +1943,8 @@ setup:
"training_percent": 100.0,
"randomize_seed": 24,
"class_assignment_objective": "maximize_minimum_recall",
"num_top_classes": 2
"num_top_classes": 2,
"early_stopping_enabled": true
}
}}
- is_true: create_time
@ -1977,7 +1982,8 @@ setup:
"prediction_field_name": "foo_prediction",
"training_percent": 100.0,
"randomize_seed": 42,
"loss_function": "mse"
"loss_function": "mse",
"early_stopping_enabled": true
}
}}
- is_true: create_time