[ML] Add early stopping DFA configuration parameter (#68099)
This PR adds the optional early_stopping_enabled data frame analytics configuration parameter. The enhancement is already described in elastic/ml-cpp#1676, so it is marked here as a non-issue.
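As a usage sketch (not part of this change set), the new option can be toggled through the high-level REST client builders this PR extends; the Classification.builder(...) factory and the import path are assumed from the existing client API rather than shown in this diff:

    import org.elasticsearch.client.ml.dataframe.Classification;

    // Build a classification analysis with early stopping explicitly disabled.
    // Leaving setEarlyStoppingEnabled unset keeps the server-side default (enabled).
    Classification analysis = Classification.builder("my_dependent_variable")
        .setNumTopClasses(2)
        .setEarlyStoppingEnabled(false)   // new optional parameter added in this PR
        .build();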
This commit is contained in:
parent 4cbe61467c
commit 78368428b3
@@ -63,6 +63,7 @@ public class Classification implements DataFrameAnalysis {
static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");
static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField("max_optimization_rounds_per_hyperparameter");
static final ParseField EARLY_STOPPING_ENABLED = new ParseField("early_stopping_enabled");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Classification, Void> PARSER =

@@ -88,7 +89,8 @@ public class Classification implements DataFrameAnalysis {
(Double) a[15],
(Double) a[16],
(Double) a[17],
(Integer) a[18]
(Integer) a[18],
(Boolean) a[19]
));

static {
@ -115,6 +117,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), SOFT_TREE_DEPTH_TOLERANCE);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), DOWNSAMPLE_FACTOR);
|
||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
|
||||
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), EARLY_STOPPING_ENABLED);
|
||||
}
|
||||
|
||||
private final String dependentVariable;
|
||||
|
@ -136,6 +139,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
private final Double softTreeDepthTolerance;
|
||||
private final Double downsampleFactor;
|
||||
private final Integer maxOptimizationRoundsPerHyperparameter;
|
||||
private final Boolean earlyStoppingEnabled;
|
||||
|
||||
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
|
||||
|
@ -144,7 +148,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
@Nullable ClassAssignmentObjective classAssignmentObjective, @Nullable List<PreProcessor> featureProcessors,
|
||||
@Nullable Double alpha, @Nullable Double etaGrowthRatePerTree, @Nullable Double softTreeDepthLimit,
|
||||
@Nullable Double softTreeDepthTolerance, @Nullable Double downsampleFactor,
|
||||
@Nullable Integer maxOptimizationRoundsPerHyperparameter) {
|
||||
@Nullable Integer maxOptimizationRoundsPerHyperparameter, @Nullable Boolean earlyStoppingEnabled) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
this.lambda = lambda;
|
||||
this.gamma = gamma;
|
||||
|
@ -164,6 +168,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
this.softTreeDepthTolerance = softTreeDepthTolerance;
|
||||
this.downsampleFactor = downsampleFactor;
|
||||
this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
|
||||
this.earlyStoppingEnabled = earlyStoppingEnabled;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@@ -247,6 +252,10 @@ public class Classification implements DataFrameAnalysis {
return maxOptimizationRoundsPerHyperparameter;
}

public Boolean getEarlyStoppingEnabled() {
return earlyStoppingEnabled;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@ -305,6 +314,9 @@ public class Classification implements DataFrameAnalysis {
|
|||
if (maxOptimizationRoundsPerHyperparameter != null) {
|
||||
builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
|
||||
}
|
||||
if (earlyStoppingEnabled != null) {
|
||||
builder.field(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -313,7 +325,8 @@ public class Classification implements DataFrameAnalysis {
|
|||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
|
||||
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses, classAssignmentObjective, featureProcessors, alpha,
|
||||
etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter);
|
||||
etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter,
|
||||
earlyStoppingEnabled);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -339,7 +352,8 @@ public class Classification implements DataFrameAnalysis {
|
|||
&& Objects.equals(softTreeDepthLimit, that.softTreeDepthLimit)
|
||||
&& Objects.equals(softTreeDepthTolerance, that.softTreeDepthTolerance)
|
||||
&& Objects.equals(downsampleFactor, that.downsampleFactor)
|
||||
&& Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter);
|
||||
&& Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter)
|
||||
&& Objects.equals(earlyStoppingEnabled, that.earlyStoppingEnabled);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -380,6 +394,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
private Double softTreeDepthTolerance;
|
||||
private Double downsampleFactor;
|
||||
private Integer maxOptimizationRoundsPerHyperparameter;
|
||||
private Boolean earlyStoppingEnabled;
|
||||
|
||||
private Builder(String dependentVariable) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
|
@ -475,11 +490,16 @@ public class Classification implements DataFrameAnalysis {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setEarlyStoppingEnabled(Boolean earlyStoppingEnabled) {
|
||||
this.earlyStoppingEnabled = earlyStoppingEnabled;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Classification build() {
|
||||
return new Classification(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
|
||||
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed,
|
||||
classAssignmentObjective, featureProcessors, alpha, etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance,
|
||||
downsampleFactor, maxOptimizationRoundsPerHyperparameter);
|
||||
downsampleFactor, maxOptimizationRoundsPerHyperparameter, earlyStoppingEnabled);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -65,6 +65,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");
|
||||
static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
|
||||
static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField("max_optimization_rounds_per_hyperparameter");
|
||||
static final ParseField EARLY_STOPPING_ENABLED = new ParseField("early_stopping_enabled");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<Regression, Void> PARSER =
|
||||
|
@ -90,7 +91,8 @@ public class Regression implements DataFrameAnalysis {
|
|||
(Double) a[15],
|
||||
(Double) a[16],
|
||||
(Double) a[17],
|
||||
(Integer) a[18]
|
||||
(Integer) a[18],
|
||||
(Boolean) a[19]
|
||||
));
|
||||
|
||||
static {
|
||||
|
@ -116,6 +118,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), SOFT_TREE_DEPTH_TOLERANCE);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), DOWNSAMPLE_FACTOR);
|
||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
|
||||
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), EARLY_STOPPING_ENABLED);
|
||||
}
|
||||
|
||||
private final String dependentVariable;
|
||||
|
@ -137,6 +140,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
private final Double softTreeDepthTolerance;
|
||||
private final Double downsampleFactor;
|
||||
private final Integer maxOptimizationRoundsPerHyperparameter;
|
||||
private final Boolean earlyStoppingEnabled;
|
||||
|
||||
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
|
||||
|
@ -144,7 +148,8 @@ public class Regression implements DataFrameAnalysis {
|
|||
@Nullable Double trainingPercent, @Nullable Long randomizeSeed, @Nullable LossFunction lossFunction,
|
||||
@Nullable Double lossFunctionParameter, @Nullable List<PreProcessor> featureProcessors, @Nullable Double alpha,
|
||||
@Nullable Double etaGrowthRatePerTree, @Nullable Double softTreeDepthLimit, @Nullable Double softTreeDepthTolerance,
|
||||
@Nullable Double downsampleFactor, @Nullable Integer maxOptimizationRoundsPerHyperparameter) {
|
||||
@Nullable Double downsampleFactor, @Nullable Integer maxOptimizationRoundsPerHyperparameter,
|
||||
@Nullable Boolean earlyStoppingEnabled) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
this.lambda = lambda;
|
||||
this.gamma = gamma;
|
||||
|
@ -164,6 +169,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
this.softTreeDepthTolerance = softTreeDepthTolerance;
|
||||
this.downsampleFactor = downsampleFactor;
|
||||
this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
|
||||
this.earlyStoppingEnabled = earlyStoppingEnabled;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -247,6 +253,10 @@ public class Regression implements DataFrameAnalysis {
|
|||
return maxOptimizationRoundsPerHyperparameter;
|
||||
}
|
||||
|
||||
public Boolean getEarlyStoppingEnabled() {
|
||||
return earlyStoppingEnabled;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
|
@ -305,6 +315,9 @@ public class Regression implements DataFrameAnalysis {
|
|||
if (maxOptimizationRoundsPerHyperparameter != null) {
|
||||
builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
|
||||
}
|
||||
if (earlyStoppingEnabled != null) {
|
||||
builder.field(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -313,7 +326,8 @@ public class Regression implements DataFrameAnalysis {
|
|||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
|
||||
predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter, featureProcessors, alpha,
|
||||
etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter);
|
||||
etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter,
|
||||
earlyStoppingEnabled);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -339,7 +353,8 @@ public class Regression implements DataFrameAnalysis {
|
|||
&& Objects.equals(softTreeDepthLimit, that.softTreeDepthLimit)
|
||||
&& Objects.equals(softTreeDepthTolerance, that.softTreeDepthTolerance)
|
||||
&& Objects.equals(downsampleFactor, that.downsampleFactor)
|
||||
&& Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter);
|
||||
&& Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter)
|
||||
&& Objects.equals(earlyStoppingEnabled, that.earlyStoppingEnabled);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -367,6 +382,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
private Double softTreeDepthTolerance;
|
||||
private Double downsampleFactor;
|
||||
private Integer maxOptimizationRoundsPerHyperparameter;
|
||||
private Boolean earlyStoppingEnabled;
|
||||
|
||||
private Builder(String dependentVariable) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
|
@ -462,11 +478,16 @@ public class Regression implements DataFrameAnalysis {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setEarlyStoppingEnabled(Boolean earlyStoppingEnabled) {
|
||||
this.earlyStoppingEnabled = earlyStoppingEnabled;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Regression build() {
|
||||
return new Regression(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
|
||||
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter,
|
||||
featureProcessors, alpha, etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor,
|
||||
maxOptimizationRoundsPerHyperparameter);
|
||||
maxOptimizationRoundsPerHyperparameter, earlyStoppingEnabled);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@@ -1366,6 +1366,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.setSoftTreeDepthTolerance(0.1)
.setDownsampleFactor(0.5)
.setMaxOptimizationRoundsPerHyperparameter(3)
.setEarlyStoppingEnabled(false)
.build())
.setDescription("this is a regression")
.build();

@@ -1417,6 +1419,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.setSoftTreeDepthTolerance(0.1)
.setDownsampleFactor(0.5)
.setMaxOptimizationRoundsPerHyperparameter(3)
.setEarlyStoppingEnabled(false)
.build())
.setDescription("this is a classification")
.build();
@ -3059,6 +3059,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
.setSoftTreeDepthTolerance(1.0) // <17>
|
||||
.setDownsampleFactor(0.5) // <18>
|
||||
.setMaxOptimizationRoundsPerHyperparameter(3) // <19>
|
||||
.setEarlyStoppingEnabled(true) // <20>
|
||||
.build();
|
||||
// end::put-data-frame-analytics-classification
|
||||
|
||||
|
@ -3084,6 +3085,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
.setSoftTreeDepthTolerance(1.0) // <17>
|
||||
.setDownsampleFactor(0.5) // <18>
|
||||
.setMaxOptimizationRoundsPerHyperparameter(3) // <19>
|
||||
.setEarlyStoppingEnabled(true) // <20>
|
||||
.build();
|
||||
// end::put-data-frame-analytics-regression
|
||||
|
||||
|
|
|
@ -60,6 +60,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
|
|||
.setSoftTreeDepthTolerance(randomBoolean() ? null : randomDoubleBetween(0.01, Double.MAX_VALUE, true))
|
||||
.setDownsampleFactor(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
|
||||
.setMaxOptimizationRoundsPerHyperparameter(randomBoolean() ? null : randomIntBetween(0, 20))
|
||||
.setEarlyStoppingEnabled(randomBoolean() ? null : randomBoolean())
|
||||
.build();
|
||||
}
|
||||
|
||||
|
|
|
@ -59,6 +59,7 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
|||
.setSoftTreeDepthTolerance(randomBoolean() ? null : randomDoubleBetween(0.01, Double.MAX_VALUE, true))
|
||||
.setDownsampleFactor(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
|
||||
.setMaxOptimizationRoundsPerHyperparameter(randomBoolean() ? null : randomIntBetween(0, 20))
|
||||
.setEarlyStoppingEnabled(randomBoolean() ? null : randomBoolean())
|
||||
.build();
|
||||
}
|
||||
|
||||
|
|
|
@@ -134,6 +134,7 @@ include-tagged::{doc-tests-file}[{api}-classification]
<17> The soft tree depth tolerance. Controls how much the soft tree depth limit is respected. A double greater than or equal to 0.01.
<18> The amount by which to downsample the data for stochastic gradient estimates. A double in (0, 1.0].
<19> The maximum number of optimisation rounds we use for hyperparameter optimisation per parameter. An integer in [0, 20].
<20> Whether to enable early stopping, which ends the training process when it is no longer finding better-performing models.

===== Regression

@@ -164,6 +165,7 @@ fields. Note, automatic categorical {ml-docs}/ml-feature-encoding.html[feature e
<17> The soft tree depth tolerance. Controls how much the soft tree depth limit is respected. A double greater than or equal to 0.01.
<18> The amount by which to downsample the data for stochastic gradient estimates. A double in (0, 1.0].
<19> The maximum number of optimisation rounds we use for hyperparameter optimisation per parameter. An integer in [0, 20].
<20> Whether to enable early stopping, which ends the training process when it is no longer finding better-performing models.

==== Analyzed fields
@@ -117,6 +117,10 @@ different values in this field.
(Optional, double)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-downsample-factor]

`early_stopping_enabled`::::
(Optional, Boolean)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-early-stopping-enabled]

`eta`::::
(Optional, double)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=eta]

@@ -359,6 +363,10 @@ The data type of the field must be numeric.
(Optional, double)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-downsample-factor]

`early_stopping_enabled`::::
(Optional, Boolean)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-early-stopping-enabled]

`eta`::::
(Optional, double)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=eta]
@@ -557,6 +557,14 @@ Values must be greater than zero and less than or equal to 1.
By default, this value is calculated during hyperparameter optimization.
end::dfas-downsample-factor[]

tag::dfas-early-stopping-enabled[]
Advanced configuration option.
Specifies whether the training process should finish if it is not finding any
better performing models. If disabled, the training process can take significantly
longer and the chances of finding a better performing model are small.
By default, early stopping is enabled.
end::dfas-early-stopping-enabled[]
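To make the default described in the shared doc snippet above concrete, here is a minimal sketch against the server-side analysis class touched later in this diff; the import path is assumed from the usual x-pack core layout and is not part of this change:

    import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;

    // Omitting early_stopping_enabled falls back to the default: early stopping is on.
    Regression regression = new Regression("price");
    assert Boolean.TRUE.equals(regression.getEarlyStoppingEnabled());
    // Passing false through the full constructor (or the REST parameter) turns it off
    // and lets hyperparameter optimization run to completion.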
tag::dfas-eta-growth[]
Advanced configuration option.
Specifies the rate at which `eta` increases for each new tree that is added
@ -55,6 +55,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
||||
public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
|
||||
public static final ParseField EARLY_STOPPING_ENABLED = new ParseField("early_stopping_enabled");
|
||||
|
||||
private static final String STATE_DOC_ID_INFIX = "_classification_state#";
|
||||
|
||||
|
@ -82,7 +83,8 @@ public class Classification implements DataFrameAnalysis {
|
|||
(Integer) a[15],
|
||||
(Double) a[16],
|
||||
(Long) a[17],
|
||||
(List<PreProcessor>) a[18]));
|
||||
(List<PreProcessor>) a[18],
|
||||
(Boolean) a[19]));
|
||||
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
||||
BoostedTreeParams.declareFields(parser);
|
||||
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
|
@ -96,6 +98,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
|
||||
(classification) -> {/*TODO should we throw if this is not set?*/},
|
||||
FEATURE_PROCESSORS);
|
||||
parser.declareBoolean(optionalConstructorArg(), EARLY_STOPPING_ENABLED);
|
||||
return parser;
|
||||
}
|
||||
|
||||
|
@ -159,6 +162,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
private final double trainingPercent;
|
||||
private final long randomizeSeed;
|
||||
private final List<PreProcessor> featureProcessors;
|
||||
private final boolean earlyStoppingEnabled;
|
||||
|
||||
public Classification(String dependentVariable,
|
||||
BoostedTreeParams boostedTreeParams,
|
||||
|
@ -167,7 +171,8 @@ public class Classification implements DataFrameAnalysis {
|
|||
@Nullable Integer numTopClasses,
|
||||
@Nullable Double trainingPercent,
|
||||
@Nullable Long randomizeSeed,
|
||||
@Nullable List<PreProcessor> featureProcessors) {
|
||||
@Nullable List<PreProcessor> featureProcessors,
|
||||
@Nullable Boolean earlyStoppingEnabled) {
|
||||
if (numTopClasses != null && (numTopClasses < -1 || numTopClasses > 1000)) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"[{}] must be an integer in [0, 1000] or a special value -1", NUM_TOP_CLASSES.getPreferredName());
|
||||
|
@ -184,10 +189,12 @@ public class Classification implements DataFrameAnalysis {
|
|||
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
|
||||
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
|
||||
this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
|
||||
// Early stopping is true by default
|
||||
this.earlyStoppingEnabled = earlyStoppingEnabled == null ? true : earlyStoppingEnabled;
|
||||
}
|
||||
|
||||
public Classification(String dependentVariable) {
|
||||
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
|
||||
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null, null);
|
||||
}
|
||||
|
||||
public Classification(StreamInput in) throws IOException {
|
||||
|
@ -211,6 +218,11 @@ public class Classification implements DataFrameAnalysis {
|
|||
} else {
|
||||
featureProcessors = Collections.emptyList();
|
||||
}
|
||||
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
|
||||
earlyStoppingEnabled = in.readBoolean();
|
||||
} else {
|
||||
earlyStoppingEnabled = true;
|
||||
}
|
||||
}
|
||||
|
||||
public String getDependentVariable() {
|
||||
|
@ -246,6 +258,10 @@ public class Classification implements DataFrameAnalysis {
|
|||
return featureProcessors;
|
||||
}
|
||||
|
||||
public Boolean getEarlyStoppingEnabled() {
|
||||
return earlyStoppingEnabled;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
|
@ -267,6 +283,9 @@ public class Classification implements DataFrameAnalysis {
|
|||
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
out.writeNamedWriteableList(featureProcessors);
|
||||
}
|
||||
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
|
||||
out.writeBoolean(earlyStoppingEnabled);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -288,6 +307,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
if (featureProcessors.isEmpty() == false) {
|
||||
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
|
||||
}
|
||||
builder.field(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -312,6 +332,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
params.put(FEATURE_PROCESSORS.getPreferredName(),
|
||||
featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
|
||||
}
|
||||
params.put(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
|
||||
return params;
|
||||
}
|
||||
|
||||
|
@ -457,6 +478,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
|
||||
&& Objects.equals(numTopClasses, that.numTopClasses)
|
||||
&& Objects.equals(featureProcessors, that.featureProcessors)
|
||||
&& Objects.equals(earlyStoppingEnabled, that.earlyStoppingEnabled)
|
||||
&& trainingPercent == that.trainingPercent
|
||||
&& randomizeSeed == that.randomizeSeed;
|
||||
}
|
||||
|
@ -464,7 +486,8 @@ public class Classification implements DataFrameAnalysis {
|
|||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective,
|
||||
numTopClasses, trainingPercent, randomizeSeed, featureProcessors);
|
||||
numTopClasses, trainingPercent, randomizeSeed, featureProcessors,
|
||||
earlyStoppingEnabled);
|
||||
}
|
||||
|
||||
public enum ClassAssignmentObjective {
|
||||
|
|
|
@ -52,6 +52,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
public static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
|
||||
public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
|
||||
public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
|
||||
public static final ParseField EARLY_STOPPING_ENABLED = new ParseField("early_stopping_enabled");
|
||||
|
||||
private static final String STATE_DOC_ID_INFIX = "_regression_state#";
|
||||
|
||||
|
@ -72,7 +73,8 @@ public class Regression implements DataFrameAnalysis {
|
|||
(Long) a[15],
|
||||
(LossFunction) a[16],
|
||||
(Double) a[17],
|
||||
(List<PreProcessor>) a[18]));
|
||||
(List<PreProcessor>) a[18],
|
||||
(Boolean) a[19]));
|
||||
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
||||
BoostedTreeParams.declareFields(parser);
|
||||
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
|
@ -86,6 +88,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
|
||||
(regression) -> {/*TODO should we throw if this is not set?*/},
|
||||
FEATURE_PROCESSORS);
|
||||
parser.declareBoolean(optionalConstructorArg(), EARLY_STOPPING_ENABLED);
|
||||
return parser;
|
||||
}
|
||||
|
||||
|
@ -124,6 +127,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
private final LossFunction lossFunction;
|
||||
private final Double lossFunctionParameter;
|
||||
private final List<PreProcessor> featureProcessors;
|
||||
private final boolean earlyStoppingEnabled;
|
||||
|
||||
public Regression(String dependentVariable,
|
||||
BoostedTreeParams boostedTreeParams,
|
||||
|
@ -132,7 +136,8 @@ public class Regression implements DataFrameAnalysis {
|
|||
@Nullable Long randomizeSeed,
|
||||
@Nullable LossFunction lossFunction,
|
||||
@Nullable Double lossFunctionParameter,
|
||||
@Nullable List<PreProcessor> featureProcessors) {
|
||||
@Nullable List<PreProcessor> featureProcessors,
|
||||
@Nullable Boolean earlyStoppingEnabled) {
|
||||
if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName());
|
||||
}
|
||||
|
@ -148,10 +153,12 @@ public class Regression implements DataFrameAnalysis {
|
|||
}
|
||||
this.lossFunctionParameter = lossFunctionParameter;
|
||||
this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
|
||||
// Early stopping is true by default
|
||||
this.earlyStoppingEnabled = earlyStoppingEnabled == null ? true : earlyStoppingEnabled;
|
||||
}
|
||||
|
||||
public Regression(String dependentVariable) {
|
||||
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
|
||||
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null, null);
|
||||
}
|
||||
|
||||
public Regression(StreamInput in) throws IOException {
|
||||
|
@ -167,6 +174,11 @@ public class Regression implements DataFrameAnalysis {
|
|||
} else {
|
||||
featureProcessors = Collections.emptyList();
|
||||
}
|
||||
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
|
||||
earlyStoppingEnabled = in.readBoolean();
|
||||
} else {
|
||||
earlyStoppingEnabled = true;
|
||||
}
|
||||
}
|
||||
|
||||
public String getDependentVariable() {
|
||||
|
@ -202,6 +214,10 @@ public class Regression implements DataFrameAnalysis {
|
|||
return featureProcessors;
|
||||
}
|
||||
|
||||
public Boolean getEarlyStoppingEnabled() {
|
||||
return earlyStoppingEnabled;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
|
@ -219,6 +235,9 @@ public class Regression implements DataFrameAnalysis {
|
|||
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
out.writeNamedWriteableList(featureProcessors);
|
||||
}
|
||||
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
|
||||
out.writeBoolean(earlyStoppingEnabled);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -242,6 +261,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
if (featureProcessors.isEmpty() == false) {
|
||||
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
|
||||
}
|
||||
builder.field(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -263,6 +283,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
params.put(FEATURE_PROCESSORS.getPreferredName(),
|
||||
featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
|
||||
}
|
||||
params.put(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
|
||||
return params;
|
||||
}
|
||||
|
||||
|
@ -348,13 +369,14 @@ public class Regression implements DataFrameAnalysis {
|
|||
&& randomizeSeed == that.randomizeSeed
|
||||
&& lossFunction == that.lossFunction
|
||||
&& Objects.equals(featureProcessors, that.featureProcessors)
|
||||
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
|
||||
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter)
|
||||
&& Objects.equals(earlyStoppingEnabled, that.earlyStoppingEnabled);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
|
||||
lossFunctionParameter, featureProcessors);
|
||||
lossFunctionParameter, featureProcessors, earlyStoppingEnabled);
|
||||
}
|
||||
|
||||
public enum LossFunction {
|
||||
|
|
|
@ -329,6 +329,7 @@ public final class ReservedFieldNames {
|
|||
Regression.PREDICTION_FIELD_NAME.getPreferredName(),
|
||||
Regression.TRAINING_PERCENT.getPreferredName(),
|
||||
Regression.FEATURE_PROCESSORS.getPreferredName(),
|
||||
Regression.EARLY_STOPPING_ENABLED.getPreferredName(),
|
||||
Classification.NAME.getPreferredName(),
|
||||
Classification.DEPENDENT_VARIABLE.getPreferredName(),
|
||||
Classification.PREDICTION_FIELD_NAME.getPreferredName(),
|
||||
|
@ -336,6 +337,7 @@ public final class ReservedFieldNames {
|
|||
Classification.NUM_TOP_CLASSES.getPreferredName(),
|
||||
Classification.TRAINING_PERCENT.getPreferredName(),
|
||||
Classification.FEATURE_PROCESSORS.getPreferredName(),
|
||||
Classification.EARLY_STOPPING_ENABLED.getPreferredName(),
|
||||
BoostedTreeParams.ALPHA.getPreferredName(),
|
||||
BoostedTreeParams.DOWNSAMPLE_FACTOR.getPreferredName(),
|
||||
BoostedTreeParams.LAMBDA.getPreferredName(),
|
||||
|
|
|
@ -78,6 +78,9 @@
|
|||
},
|
||||
"training_percent" : {
|
||||
"type" : "double"
|
||||
},
|
||||
"early_stopping_enabled" : {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -149,6 +152,9 @@
|
|||
},
|
||||
"training_percent" : {
|
||||
"type" : "double"
|
||||
},
|
||||
"early_stopping_enabled" : {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -151,7 +151,8 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
|
|||
42L,
|
||||
bwcRegression.getLossFunction(),
|
||||
bwcRegression.getLossFunctionParameter(),
|
||||
bwcRegression.getFeatureProcessors());
|
||||
bwcRegression.getFeatureProcessors(),
|
||||
bwcRegression.getEarlyStoppingEnabled());
|
||||
testAnalysis = new Regression(testRegression.getDependentVariable(),
|
||||
testRegression.getBoostedTreeParams(),
|
||||
testRegression.getPredictionFieldName(),
|
||||
|
@ -159,7 +160,8 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
|
|||
42L,
|
||||
testRegression.getLossFunction(),
|
||||
testRegression.getLossFunctionParameter(),
|
||||
bwcRegression.getFeatureProcessors());
|
||||
testRegression.getFeatureProcessors(),
|
||||
testRegression.getEarlyStoppingEnabled());
|
||||
} else {
|
||||
Classification testClassification = (Classification)testInstance.getAnalysis();
|
||||
Classification bwcClassification = (Classification)bwcSerializedObject.getAnalysis();
|
||||
|
@ -170,7 +172,8 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
|
|||
bwcClassification.getNumTopClasses(),
|
||||
bwcClassification.getTrainingPercent(),
|
||||
42L,
|
||||
bwcClassification.getFeatureProcessors());
|
||||
bwcClassification.getFeatureProcessors(),
|
||||
bwcClassification.getEarlyStoppingEnabled());
|
||||
testAnalysis = new Classification(testClassification.getDependentVariable(),
|
||||
testClassification.getBoostedTreeParams(),
|
||||
testClassification.getPredictionFieldName(),
|
||||
|
@ -178,7 +181,8 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
|
|||
testClassification.getNumTopClasses(),
|
||||
testClassification.getTrainingPercent(),
|
||||
42L,
|
||||
testClassification.getFeatureProcessors());
|
||||
testClassification.getFeatureProcessors(),
|
||||
testClassification.getEarlyStoppingEnabled());
|
||||
}
|
||||
super.assertOnBWCObject(new DataFrameAnalyticsConfig.Builder(bwcSerializedObject)
|
||||
.setAnalysis(bwcAnalysis)
|
||||
|
|
|
@ -97,6 +97,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(-1, 1000);
|
||||
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
|
||||
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
||||
Boolean earlyStoppingEnabled = randomBoolean() ? null : randomBoolean();
|
||||
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
|
||||
numTopClasses, trainingPercent, randomizeSeed,
|
||||
randomBoolean() ?
|
||||
|
@ -105,7 +106,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
OneHotEncodingTests.createRandom(true),
|
||||
TargetMeanEncodingTests.createRandom(true)))
|
||||
.limit(randomIntBetween(0, 5))
|
||||
.collect(Collectors.toList()));
|
||||
.collect(Collectors.toList()),
|
||||
earlyStoppingEnabled);
|
||||
}
|
||||
|
||||
public static Classification mutateForVersion(Classification instance, Version version) {
|
||||
|
@ -116,7 +118,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
instance.getNumTopClasses(),
|
||||
instance.getTrainingPercent(),
|
||||
instance.getRandomizeSeed(),
|
||||
version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList());
|
||||
version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList(),
|
||||
version.onOrAfter(Version.V_8_0_0) ? instance.getEarlyStoppingEnabled() : null);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -133,7 +136,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
bwcSerializedObject.getNumTopClasses(),
|
||||
bwcSerializedObject.getTrainingPercent(),
|
||||
42L,
|
||||
bwcSerializedObject.getFeatureProcessors());
|
||||
bwcSerializedObject.getFeatureProcessors(),
|
||||
bwcSerializedObject.getEarlyStoppingEnabled());
|
||||
Classification newInstance = new Classification(testInstance.getDependentVariable(),
|
||||
testInstance.getBoostedTreeParams(),
|
||||
testInstance.getPredictionFieldName(),
|
||||
|
@ -141,7 +145,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
testInstance.getNumTopClasses(),
|
||||
testInstance.getTrainingPercent(),
|
||||
42L,
|
||||
testInstance.getFeatureProcessors());
|
||||
testInstance.getFeatureProcessors(),
|
||||
testInstance.getEarlyStoppingEnabled());
|
||||
super.assertOnBWCObject(newBwc, newInstance, version);
|
||||
}
|
||||
|
||||
|
@ -202,96 +207,96 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
|
||||
public void testConstructor_GivenTrainingPercentIsZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.0, randomLong(), null));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.0, randomLong(), null, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsLessThanZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, -1.0, randomLong(), null));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, -1.0, randomLong(), null, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenNumTopClassesIsLessThanMinusOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
|
||||
}
|
||||
|
||||
public void testGetPredictionFieldName() {
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null, null);
|
||||
assertThat(classification.getPredictionFieldName(), equalTo("result"));
|
||||
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong(), null);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong(), null, null);
|
||||
assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
|
||||
}
|
||||
|
||||
public void testClassAssignmentObjective() {
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
|
||||
Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong(), null);
|
||||
Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong(), null, null);
|
||||
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY));
|
||||
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
|
||||
Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong(), null);
|
||||
Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong(), null, null);
|
||||
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
|
||||
|
||||
// class_assignment_objective == null, default applied
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null, null);
|
||||
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
|
||||
}
|
||||
|
||||
public void testGetNumTopClasses() {
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null, null);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(7));
|
||||
|
||||
// Special value: num_top_classes == -1
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null, null);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(-1));
|
||||
|
||||
// Boundary condition: num_top_classes == 0
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null, null);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(0));
|
||||
|
||||
// Boundary condition: num_top_classes == 1000
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong(), null);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong(), null, null);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(1000));
|
||||
|
||||
// num_top_classes == null, default applied
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong(), null);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong(), null, null);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(2));
|
||||
}
|
||||
|
||||
public void testGetTrainingPercent() {
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
|
||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null, null);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(50.0));
|
||||
|
||||
// Boundary condition: training_percent == 1.0
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong(), null);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong(), null, null);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(1.0));
|
||||
|
||||
// Boundary condition: training_percent == 100.0
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong(), null);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong(), null, null);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
||||
|
||||
// training_percent == null, default applied
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong(), null);
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong(), null, null);
|
||||
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
|
@ -316,7 +321,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
"prediction_field_name", "foo_prediction",
|
||||
"prediction_field_type", "bool",
|
||||
"num_classes", 10L,
|
||||
"training_percent", 100.0)));
|
||||
"training_percent", 100.0,
|
||||
"early_stopping_enabled", true)));
|
||||
assertThat(
|
||||
new Classification("bar").getParams(fieldInfo),
|
||||
equalTo(
|
||||
|
@ -327,7 +333,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
"prediction_field_name", "bar_prediction",
|
||||
"prediction_field_type", "int",
|
||||
"num_classes", 20L,
|
||||
"training_percent", 100.0)));
|
||||
"training_percent", 100.0,
|
||||
"early_stopping_enabled", true)));
|
||||
assertThat(
|
||||
new Classification("baz",
|
||||
BoostedTreeParams.builder().build() ,
|
||||
|
@ -336,6 +343,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
null,
|
||||
50.0,
|
||||
null,
|
||||
null,
|
||||
null).getParams(fieldInfo),
|
||||
equalTo(
|
||||
Map.of(
|
||||
|
@ -345,7 +353,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
"prediction_field_name", "baz_prediction",
|
||||
"prediction_field_type", "string",
|
||||
"num_classes", 30L,
|
||||
"training_percent", 50.0)));
|
||||
"training_percent", 50.0,
|
||||
"early_stopping_enabled", true)));
|
||||
}
|
||||
|
||||
public void testRequiredFieldsIsNonEmpty() {
|
||||
|
|
|
@ -88,6 +88,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
||||
Regression.LossFunction lossFunction = randomBoolean() ? null : randomFrom(Regression.LossFunction.values());
|
||||
Double lossFunctionParameter = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, false);
|
||||
Boolean earlyStoppingEnabled = randomBoolean() ? null : randomBoolean();
|
||||
return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
|
||||
lossFunctionParameter,
|
||||
randomBoolean() ?
|
||||
|
@ -96,7 +97,8 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
OneHotEncodingTests.createRandom(true),
|
||||
TargetMeanEncodingTests.createRandom(true)))
|
||||
.limit(randomIntBetween(0, 5))
|
||||
.collect(Collectors.toList()));
|
||||
.collect(Collectors.toList()),
|
||||
earlyStoppingEnabled);
|
||||
}
|
||||
|
||||
public static Regression mutateForVersion(Regression instance, Version version) {
|
||||
|
@ -107,7 +109,8 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
instance.getRandomizeSeed(),
|
||||
instance.getLossFunction(),
|
||||
instance.getLossFunctionParameter(),
|
||||
version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList());
|
||||
version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList(),
|
||||
version.onOrAfter(Version.V_8_0_0) ? instance.getEarlyStoppingEnabled() : null);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -124,7 +127,8 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
42L,
|
||||
bwcSerializedObject.getLossFunction(),
|
||||
bwcSerializedObject.getLossFunctionParameter(),
|
||||
bwcSerializedObject.getFeatureProcessors());
|
||||
bwcSerializedObject.getFeatureProcessors(),
|
||||
bwcSerializedObject.getEarlyStoppingEnabled());
|
||||
Regression newInstance = new Regression(testInstance.getDependentVariable(),
|
||||
testInstance.getBoostedTreeParams(),
|
||||
testInstance.getPredictionFieldName(),
|
||||
|
@ -132,7 +136,8 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
42L,
|
||||
testInstance.getLossFunction(),
|
||||
testInstance.getLossFunctionParameter(),
|
||||
testInstance.getFeatureProcessors());
|
||||
testInstance.getFeatureProcessors(),
|
||||
testInstance.getEarlyStoppingEnabled());
|
||||
super.assertOnBWCObject(newBwc, newInstance, version);
|
||||
}
|
||||
|
||||
|
@ -198,21 +203,24 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
|
||||
public void testConstructor_GivenTrainingPercentIsZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.0, randomLong(), Regression.LossFunction.MSE, null, null));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.0, randomLong(),
|
||||
Regression.LossFunction.MSE, null, null, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsLessThanZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", -0.01, randomLong(), Regression.LossFunction.MSE, null, null));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", -0.01, randomLong(),
|
||||
Regression.LossFunction.MSE, null, null, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null, null));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(),
|
||||
Regression.LossFunction.MSE, null, null, null));
|
||||
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
||||
|
@ -220,55 +228,48 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
|
||||
public void testConstructor_GivenLossFunctionParameterIsZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, 0.0, null));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(),
|
||||
Regression.LossFunction.MSE, 0.0, null, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenLossFunctionParameterIsNegative() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, -1.0, null));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(),
|
||||
Regression.LossFunction.MSE, -1.0, null, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
|
||||
}
|
||||
|
||||
public void testGetPredictionFieldName() {
|
||||
Regression regression = new Regression(
|
||||
"foo",
|
||||
BOOSTED_TREE_PARAMS,
|
||||
"result",
|
||||
50.0,
|
||||
randomLong(),
|
||||
Regression.LossFunction.MSE,
|
||||
1.0,
|
||||
null);
|
||||
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(),
|
||||
Regression.LossFunction.MSE, 1.0, null, null);
|
||||
assertThat(regression.getPredictionFieldName(), equalTo("result"));
|
||||
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(), Regression.LossFunction.MSE, null, null);
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(),
|
||||
Regression.LossFunction.MSE, null, null, null);
|
||||
assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction"));
|
||||
}
|
||||
|
||||
public void testGetTrainingPercent() {
|
||||
Regression regression = new Regression("foo",
|
||||
BOOSTED_TREE_PARAMS,
|
||||
"result",
|
||||
50.0,
|
||||
randomLong(),
|
||||
Regression.LossFunction.MSE,
|
||||
1.0,
|
||||
null);
|
||||
Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(),
|
||||
Regression.LossFunction.MSE, 1.0, null, null);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(50.0));
|
||||
|
||||
// Boundary condition: training_percent == 1.0
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(), Regression.LossFunction.MSE, null, null);
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(),
|
||||
Regression.LossFunction.MSE, null, null, null);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(1.0));
|
||||
|
||||
// Boundary condition: training_percent == 100.0
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, null, null);
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(),
|
||||
Regression.LossFunction.MSE, null, null, null);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
|
||||
// training_percent == null, default applied
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(), Regression.LossFunction.MSE, null, null);
|
||||
regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(),
|
||||
Regression.LossFunction.MSE, null, null, null);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
|
@ -276,21 +277,17 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
int maxTrees = randomIntBetween(1, 100);
|
||||
Regression regression = new Regression("foo",
|
||||
BoostedTreeParams.builder().setMaxTrees(maxTrees).build(),
|
||||
null,
|
||||
100.0,
|
||||
0L,
|
||||
Regression.LossFunction.MSE,
|
||||
null,
|
||||
null);
|
||||
null, 100.0, 0L, Regression.LossFunction.MSE, null, null, null);
|
||||
|
||||
Map<String, Object> params = regression.getParams(null);
|
||||
|
||||
assertThat(params.size(), equalTo(5));
|
||||
assertThat(params.size(), equalTo(6));
|
||||
assertThat(params.get("dependent_variable"), equalTo("foo"));
|
||||
assertThat(params.get("prediction_field_name"), equalTo("foo_prediction"));
|
||||
assertThat(params.get("max_trees"), equalTo(maxTrees));
|
||||
assertThat(params.get("training_percent"), equalTo(100.0));
|
||||
assertThat(params.get("loss_function"), equalTo("mse"));
|
||||
assertThat(params.get("early_stopping_enabled"), equalTo(true));
|
||||
}
|
||||
|
||||
public void testGetParams_GivenRandomWithoutBoostedTreeParams() {
|
||||
|
@ -298,7 +295,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
|
||||
Map<String, Object> params = regression.getParams(null);
|
||||
|
||||
int expectedParamsCount = 4
|
||||
int expectedParamsCount = 5
|
||||
+ (regression.getLossFunctionParameter() == null ? 0 : 1)
|
||||
+ (regression.getFeatureProcessors().isEmpty() ? 0 : 1);
|
||||
assertThat(params.size(), equalTo(expectedParamsCount));
|
||||
|
@ -311,6 +308,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
} else {
assertThat(params.get("loss_function_parameter"), equalTo(regression.getLossFunctionParameter()));
}
assertThat(params.get("early_stopping_enabled"), equalTo(regression.getEarlyStoppingEnabled()));
}

public void testRequiredFieldsIsNonEmpty() {

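The added assertion ties the reported parameter back to the getter, so an explicitly configured value and a defaulted one flow through the same way. A small sketch of that invariant with stand-in names; resolving the default in the constructor is an assumption made here for brevity:

import java.util.Map;

// Sketch of the invariant the assertion above encodes: the flag in the params map
// is exactly what the getter returns, whether it was set explicitly or defaulted.
final class EarlyStoppingSketch {

    private final Boolean earlyStoppingEnabled;

    EarlyStoppingSketch(Boolean earlyStoppingEnabled) {
        // Resolve the default once so the getter and the params map cannot disagree
        // (an assumption; the real classes may resolve it elsewhere).
        this.earlyStoppingEnabled = earlyStoppingEnabled == null ? Boolean.TRUE : earlyStoppingEnabled;
    }

    Boolean getEarlyStoppingEnabled() {
        return earlyStoppingEnabled;
    }

    Map<String, Object> getParams() {
        return Map.of("early_stopping_enabled", earlyStoppingEnabled);
    }

    public static void main(String[] args) {
        System.out.println(new EarlyStoppingSketch(null).getParams());   // {early_stopping_enabled=true}
        System.out.println(new EarlyStoppingSketch(false).getParams());  // {early_stopping_enabled=false}
    }
}
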
@ -141,6 +141,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
null,
null,
null,
null,
null));
putAnalytics(config);

@ -197,6 +198,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
null,
null,
null,
null,
null));
putAnalytics(config);

@ -317,7 +319,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
new OneHotEncoding(TEXT_FIELD, MapBuilder.<String, String>newMapBuilder()
.put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom_3")
.put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom_3").map(), true)
)));
),
null));
putAnalytics(config);

assertIsStopped(jobId);
@ -386,7 +389,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null, null));
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null,
numTopClasses, 50.0, null, null, null));
putAnalytics(config);

assertIsStopped(jobId);
@ -650,7 +654,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
.build();

DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null, null));
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null, null, null));
putAnalytics(firstJob);
startAnalytics(firstJobId);
waitUntilAnalyticsIsStopped(firstJobId);
@ -660,7 +664,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {

long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed();
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed, null));
new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed, null, null));

putAnalytics(secondJob);
startAnalytics(secondJobId);

@ -128,7 +128,8 @@ public class DataFrameAnalysisCustomFeatureIT extends MlNativeDataFrameAnalytics
new OneHotEncoding("ngram.21", MapBuilder.<String, String>newMapBuilder().put("at", "is_cat").map(), true)
},
true)
)))
),
null))
.setAnalyzedFields(new FetchSourceContext(true, new String[]{TEXT_FIELD, NUMERICAL_FIELD}, new String[]{}))
.build();
putAnalytics(config);

@ -105,6 +105,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
null,
null,
null,
null,
null))
.buildForExplain();

@ -124,6 +125,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
null,
null,
null,
null,
null))
.buildForExplain();

@ -152,6 +154,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
null,
null,
null,
null,
null))
.buildForExplain();

@ -115,6 +115,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
null,
null,
null,
null,
null)
);
putAnalytics(config);
@ -251,7 +252,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null, null, null, null));
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(),
null, 50.0, null, null, null, null, null));
putAnalytics(config);

assertIsStopped(jobId);
@ -371,7 +373,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
.build();

DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null, null, null, null));
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0,
null, null, null, null, null));
putAnalytics(firstJob);
startAnalytics(firstJobId);
waitUntilAnalyticsIsStopped(firstJobId);
@ -381,7 +384,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {

long randomizeSeed = ((Regression) firstJob.getAnalysis()).getRandomizeSeed();
DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed, null, null, null));
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0,
randomizeSeed, null, null, null, null));

putAnalytics(secondJob);
startAnalytics(secondJobId);
@ -438,7 +442,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null, null, null, null));
new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(),
null, null, null, null, null, null, null));
putAnalytics(config);

assertIsStopped(jobId);
@ -465,6 +470,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
null,
null,
null,
null,
null)
);
putAnalytics(config);
@ -562,6 +568,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
null,
null,
null,
null,
null);
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId(jobId)
@ -635,7 +642,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
Arrays.asList(
new OneHotEncoding(DISCRETE_NUMERICAL_FEATURE_FIELD,
Collections.singletonMap(DISCRETE_NUMERICAL_FEATURE_VALUES.get(0).toString(), "tenner"), true)
))
),
null)
);
putAnalytics(config);

@ -1169,7 +1169,8 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
null,
null,
null,
featureprocessors))
featureprocessors,
null))
.build();
}

@ -1513,7 +1513,8 @@ setup:
"soft_tree_depth_limit": 2.0,
"soft_tree_depth_tolerance": 3.0,
"downsample_factor": 0.5,
"max_optimization_rounds_per_hyperparameter": 3
"max_optimization_rounds_per_hyperparameter": 3,
"early_stopping_enabled": true
}
}
}
@ -1538,7 +1539,8 @@ setup:
"soft_tree_depth_limit": 2.0,
"soft_tree_depth_tolerance": 3.0,
"downsample_factor": 0.5,
"max_optimization_rounds_per_hyperparameter": 3
"max_optimization_rounds_per_hyperparameter": 3,
"early_stopping_enabled": true
}
}}
- is_true: create_time
@ -1870,7 +1872,8 @@ setup:
"soft_tree_depth_limit": 2.0,
"soft_tree_depth_tolerance": 3.0,
"downsample_factor": 0.5,
"max_optimization_rounds_per_hyperparameter": 3
"max_optimization_rounds_per_hyperparameter": 3,
"early_stopping_enabled": true
}
}
}
@ -1895,7 +1898,8 @@ setup:
"soft_tree_depth_limit": 2.0,
"soft_tree_depth_tolerance": 3.0,
"downsample_factor": 0.5,
"max_optimization_rounds_per_hyperparameter": 3
"max_optimization_rounds_per_hyperparameter": 3,
"early_stopping_enabled": true
}
}}
- is_true: create_time
@ -1939,7 +1943,8 @@ setup:
"training_percent": 100.0,
"randomize_seed": 24,
"class_assignment_objective": "maximize_minimum_recall",
"num_top_classes": 2
"num_top_classes": 2,
"early_stopping_enabled": true
}
}}
- is_true: create_time
@ -1977,7 +1982,8 @@ setup:
"prediction_field_name": "foo_prediction",
"training_percent": 100.0,
"randomize_seed": 42,
"loss_function": "mse"
"loss_function": "mse",
"early_stopping_enabled": true
}
}}
- is_true: create_time