/// <summary>
/// Builds the objective over <paramref name="trainData"/>. It forwards the learning rate,
/// maximum tree output, derivatives sample rate and seed taken from <paramref name="options"/>
/// to the base constructor, then stores the dataset's regression labels in <c>_labels</c>.
/// </summary>
/// <param name="trainData">Dataset the objective is evaluated on; also the source of the labels.</param>
/// <param name="options">GAM regression trainer options that supply the hyper-parameters.</param>
public ObjectiveImpl(Dataset trainData, GamRegressionTrainer.Options options)
    : base(
        trainData,
        options.LearningRate,
        0,     // NOTE(review): fixed literal, not taken from options; confirm its meaning against the base constructor
        options.MaximumTreeOutput,
        options.GetDerivativesSampleRate,
        false, // NOTE(review): fixed literal flag, not taken from options; confirm its meaning against the base constructor
        options.Seed)
    => _labels = GetDatasetRegressionLabels(trainData);
Ejemplo n.º 2
0
        /// <summary>
        /// Entry point that trains a <see cref="GamRegressionTrainer"/> configured by <paramref name="input"/>.
        /// </summary>
        /// <param name="env">Host environment; validated to be non-null.</param>
        /// <param name="input">Trainer options, including the training data and the label/weight column names; validated to be non-null.</param>
        /// <returns>The regression output returned by <c>TrainerEntryPointsUtils.Train</c>.</returns>
        public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, GamRegressionTrainer.Options input)
        {
            Contracts.CheckValue(env, nameof(env));
            var gamHost = env.Register("TrainGAM");

            gamHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(gamHost, input);

            // Looks a column name up against the training data schema. It is only called from the
            // delegates below, so the schema is read lazily, when Train asks for the column.
            string ResolveColumn(string columnName) =>
                TrainerEntryPointsUtils.FindColumn(gamHost, input.TrainingData.Schema, columnName);

            return TrainerEntryPointsUtils.Train<GamRegressionTrainer.Options, CommonOutputs.RegressionOutput>(
                gamHost,
                input,
                () => new GamRegressionTrainer(gamHost, input),
                () => ResolveColumn(input.LabelColumnName),
                () => ResolveColumn(input.ExampleWeightColumnName));
        }