Пример #1
0
        /// <summary>
        /// Builds the calibrator estimator and describes the columns it reads and produces.
        /// </summary>
        /// <param name="env">The host environment, stored as <c>Host</c>.</param>
        /// <param name="calibratorTrainer">The trainer used to fit the calibrator, stored as <c>CalibratorTrainer</c>.</param>
        /// <param name="predictor">Optional predictor, stored as <c>Predictor</c>.</param>
        /// <param name="labelColumn">Name of the boolean label column.</param>
        /// <param name="featureColumn">Name of the float vector feature column.</param>
        /// <param name="weightColumn">Optional name of the float weight column; when null, <c>WeightColumn</c> is left unset.</param>
        protected CalibratorEstimatorBase(IHostEnvironment env,
            TCalibratorTrainer calibratorTrainer,
            IPredictor predictor = null,
            string labelColumn = DefaultColumnNames.Label,
            string featureColumn = DefaultColumnNames.Features,
            string weightColumn = null)
        {
            Host = env;
            Predictor = predictor;
            CalibratorTrainer = calibratorTrainer;

            // REVIEW: Do we fathom the score column being named anything else (renaming column)? Complete metadata?
            ScoreColumn = TrainerUtils.MakeR4ScalarColumn(DefaultColumnNames.Score);
            LabelColumn = TrainerUtils.MakeBoolScalarLabel(labelColumn);
            FeatureColumn = TrainerUtils.MakeR4VecFeature(featureColumn);

            // The predicted label is a scalar boolean carrying the standard trainer-output metadata.
            var trainerOutputMetadata = new SchemaShape(MetadataUtils.GetTrainerOutputMetadata());
            PredictedLabel = new SchemaShape.Column(DefaultColumnNames.PredictedLabel,
                SchemaShape.Column.VectorKind.Scalar,
                BoolType.Instance,
                false,
                trainerOutputMetadata);

            if (weightColumn == null)
            {
                return;
            }

            WeightColumn = TrainerUtils.MakeR4ScalarWeightColumn(weightColumn);
        }
        /// <summary>
        /// Creates a LightGBM trainer from individual column names and the most commonly tuned hyperparameters.
        /// A new <c>Options</c> instance is populated from these arguments.
        /// </summary>
        /// <param name="env">The host environment; a child host named <paramref name="name"/> is registered on it.</param>
        /// <param name="name">Name under which the trainer's host is registered.</param>
        /// <param name="labelColumn">Schema shape of the label column.</param>
        /// <param name="featureColumnName">Name of the float vector feature column.</param>
        /// <param name="exampleWeightColumnName">Optional name of the example-weight column.</param>
        /// <param name="rowGroupColumnName">Optional name of the row-group column.</param>
        /// <param name="numberOfLeaves">Optional maximum number of leaves.</param>
        /// <param name="minimumExampleCountPerLeaf">Optional minimum number of examples per leaf.</param>
        /// <param name="learningRate">Optional learning rate.</param>
        /// <param name="numberOfIterations">Number of boosting iterations.</param>
        private protected LightGbmTrainerBase(IHostEnvironment env,
                                              string name,
                                              SchemaShape.Column labelColumn,
                                              string featureColumnName,
                                              string exampleWeightColumnName,
                                              string rowGroupColumnName,
                                              int? numberOfLeaves,
                                              int? minimumExampleCountPerLeaf,
                                              double? learningRate,
                                              int numberOfIterations)
            : base(Contracts.CheckRef(env, nameof(env)).Register(name),
                   TrainerUtils.MakeR4VecFeature(featureColumnName),
                   labelColumn,
                   TrainerUtils.MakeR4ScalarWeightColumn(exampleWeightColumnName),
                   TrainerUtils.MakeU4ScalarColumn(rowGroupColumnName))
        {
            // Copy the constructor arguments into a fresh options object.
            LightGbmTrainerOptions = new Options
            {
                NumberOfLeaves = numberOfLeaves,
                MinimumExampleCountPerLeaf = minimumExampleCountPerLeaf,
                LearningRate = learningRate,
                NumberOfIterations = numberOfIterations,

                LabelColumnName = labelColumn.Name,
                FeatureColumnName = featureColumnName,
                ExampleWeightColumnName = exampleWeightColumnName,
                RowGroupColumnName = rowGroupColumnName,
            };

            InitParallelTraining();
        }
Пример #3
0
 /// <summary>
 /// Initializes a new instance of <see cref="OrdinaryLeastSquaresRegressionTrainer"/>.
 /// </summary>
 /// <param name="env">The host environment; a child host named <c>LoadNameValue</c> is registered on it.</param>
 /// <param name="options">Column names and hyperparameters for the trainer. Must not be null.</param>
 internal OrdinaryLeastSquaresRegressionTrainer(IHostEnvironment env, Options options)
     // options is read while the base-constructor arguments are built, so it must be null-checked here;
     // a check in the constructor body would only run after a NullReferenceException had been thrown.
     : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue),
            TrainerUtils.MakeR4VecFeature(Contracts.CheckRef(options, nameof(options)).FeatureColumnName),
            TrainerUtils.MakeR4ScalarColumn(options.LabelColumnName),
            TrainerUtils.MakeR4ScalarWeightColumn(options.ExampleWeightColumnName))
 {
     Host.CheckUserArg(options.L2Weight >= 0, nameof(options.L2Weight), "L2 regularization term cannot be negative");
     _l2Weight = options.L2Weight;
     _perParameterSignificance = options.PerParameterSignificance;
 }
 /// <summary>
 /// Initializes a new instance of <see cref="OlsLinearRegressionTrainer"/>.
 /// </summary>
 /// <param name="env">The host environment; a child host named <c>LoadNameValue</c> is registered on it.</param>
 /// <param name="args">Column names and hyperparameters for the trainer. Must not be null.</param>
 internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
     // args is read while the base-constructor arguments are built, so it must be null-checked here;
     // a check in the constructor body would only run after a NullReferenceException had been thrown.
     : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue),
            TrainerUtils.MakeR4VecFeature(Contracts.CheckRef(args, nameof(args)).FeatureColumn),
            TrainerUtils.MakeR4ScalarLabel(args.LabelColumn),
            TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
 {
     Host.CheckUserArg(args.L2Weight >= 0, nameof(args.L2Weight), "L2 regularization term cannot be negative");
     _l2Weight = args.L2Weight;
     _perParameterSignificance = args.PerParameterSignificance;
 }
Пример #5
0
        /// <summary>
        /// Creates an SDCA regression trainer from explicit label and weight column names.
        /// The loss is built from <c>args.LossFunction</c>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): <paramref name="featureColumn"/> is validated but not forwarded to the base
        /// constructor; presumably the feature column is taken from <paramref name="args"/> — confirm.
        /// </remarks>
        internal SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null)
            : base(env,
                   args,
                   TrainerUtils.MakeR4ScalarLabel(labelColumn),
                   TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
        {
            Host.CheckValue(labelColumn, nameof(labelColumn));
            Host.CheckValue(featureColumn, nameof(featureColumn));

            // Store the constructed loss in both the private field and the Loss member.
            Loss = _loss = args.LossFunction.CreateComponent(env);
        }
        /// <summary>
        /// Creates a LightGBM trainer from a fully populated options object.
        /// </summary>
        /// <param name="env">The host environment; a child host named <paramref name="name"/> is registered on it.</param>
        /// <param name="name">Name under which the trainer's host is registered.</param>
        /// <param name="options">Trainer options, stored as <c>LightGbmTrainerOptions</c>. Must not be null.</param>
        /// <param name="label">Schema shape of the label column.</param>
        private protected LightGbmTrainerBase(IHostEnvironment env, string name, Options options, SchemaShape.Column label)
            // options is read while the base-constructor arguments are built, so it must be null-checked here;
            // a check in the constructor body would only run after a NullReferenceException had been thrown.
            : base(Contracts.CheckRef(env, nameof(env)).Register(name),
                   TrainerUtils.MakeR4VecFeature(Contracts.CheckRef(options, nameof(options)).FeatureColumnName),
                   label,
                   TrainerUtils.MakeR4ScalarWeightColumn(options.ExampleWeightColumnName),
                   TrainerUtils.MakeU4ScalarColumn(options.RowGroupColumnName))
        {
            LightGbmTrainerOptions = options;
            InitParallelTraining();
        }
        /// <summary>
        /// Creates an L-BFGS based trainer from <paramref name="args"/>, optionally adjusted by
        /// <paramref name="advancedSettings"/>, and validates the resulting hyperparameters.
        /// </summary>
        /// <param name="env">The host environment; a child host named <c>RegisterName</c> is registered on it.</param>
        /// <param name="args">Trainer arguments, stored as <c>Args</c>. Must not be null.</param>
        /// <param name="labelColumn">Schema shape of the label column.</param>
        /// <param name="advancedSettings">Optional callback that mutates <paramref name="args"/> before validation.</param>
        internal LbfgsTrainerBase(IHostEnvironment env,
                                  TArgs args,
                                  SchemaShape.Column labelColumn,
                                  Action<TArgs> advancedSettings = null)
            // args is read while the base-constructor arguments are built, so it must be null-checked here;
            // a check in the constructor body would only run after a NullReferenceException had been thrown.
            : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName),
                   TrainerUtils.MakeR4VecFeature(Contracts.CheckRef(args, nameof(args)).FeatureColumn),
                   labelColumn,
                   TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
        {
            Args = args;

            // Apply the advanced args, if the user supplied any.
            advancedSettings?.Invoke(args);

            // The column names resolved by the base constructor overwrite anything advancedSettings assigned.
            args.FeatureColumn = FeatureColumn.Name;
            args.LabelColumn = LabelColumn.Name;
            args.WeightColumn = WeightColumn?.Name;

            // These checks fully cover L2Weight, L1Weight, OptTol and MemorySize, so no further
            // "if provided" checks on the same values are needed.
            Host.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null,
                              nameof(Args.NumThreads), "numThreads must be positive (or empty for default)");
            Host.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative");
            Host.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative");
            Host.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive");
            Host.CheckUserArg(Args.MemorySize > 0, nameof(Args.MemorySize), "Must be positive");
            Host.CheckUserArg(Args.MaxIterations > 0, nameof(Args.MaxIterations), "Must be positive");
            Host.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative");
            Host.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative");

            L2Weight = Args.L2Weight;
            L1Weight = Args.L1Weight;
            OptTol = Args.OptTol;
            MemorySize = Args.MemorySize;
            MaxIterations = Args.MaxIterations;
            SgdInitializationTolerance = Args.SgdInitializationTolerance;
            Quiet = Args.Quiet;
            InitWtsDiameter = Args.InitWtsDiameter;
            UseThreads = Args.UseThreads;
            NumThreads = Args.NumThreads;
            DenseOptimizer = Args.DenseOptimizer;
            EnforceNonNegativity = Args.EnforceNonNegativity;

            if (EnforceNonNegativity && ShowTrainingStats)
            {
                ShowTrainingStats = false;
                using (var ch = Host.Start("Initialization"))
                {
                    ch.Warning("The training statistics cannot be computed with non-negativity constraint.");
                }
            }

            // NOTE(review): ShowTrainingStats is reset unconditionally here, so the warning above only fires
            // if it already reads true during base construction — confirm this is intended.
            ShowTrainingStats = false;
            _srcPredictor = default;
        }
Пример #8
0
        /// <summary>
        /// Creates an L-BFGS based trainer from <paramref name="args"/>. Each non-null optional
        /// hyperparameter overrides the corresponding value in <paramref name="args"/>.
        /// </summary>
        /// <param name="env">The host environment; a child host named <c>RegisterName</c> is registered on it.</param>
        /// <param name="args">Trainer arguments, stored as <c>Args</c>. Must not be null.</param>
        /// <param name="labelColumn">Schema shape of the label column.</param>
        /// <param name="l1Weight">Optional L1 weight override; must be non-negative if provided.</param>
        /// <param name="l2Weight">Optional L2 weight override; must be non-negative if provided.</param>
        /// <param name="optimizationTolerance">Optional tolerance override; must be positive if provided.</param>
        /// <param name="memorySize">Optional memory size override; must be positive if provided.</param>
        /// <param name="enforceNoNegativity">Optional override for <c>EnforceNonNegativity</c>.</param>
        internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn,
                                  float? l1Weight = null,
                                  float? l2Weight = null,
                                  float? optimizationTolerance = null,
                                  int? memorySize = null,
                                  bool? enforceNoNegativity = null)
            // args is read while the base-constructor arguments are built, so it must be null-checked here;
            // a check in the constructor body would only run after a NullReferenceException had been thrown.
            : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName),
                   TrainerUtils.MakeR4VecFeature(Contracts.CheckRef(args, nameof(args)).FeatureColumn),
                   labelColumn,
                   TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
        {
            Args = args;

            Host.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null,
                              nameof(Args.NumThreads), "numThreads must be positive (or empty for default)");
            Host.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative");
            Host.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative");
            Host.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive");
            Host.CheckUserArg(Args.MemorySize > 0, nameof(Args.MemorySize), "Must be positive");
            Host.CheckUserArg(Args.MaxIterations > 0, nameof(Args.MaxIterations), "Must be positive");
            Host.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative");
            Host.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative");

            // A null override compares false against any bound, so only supplied values are validated.
            Host.CheckParam(!(l2Weight < 0), nameof(l2Weight), "Must be non-negative, if provided.");
            Host.CheckParam(!(l1Weight < 0), nameof(l1Weight), "Must be non-negative, if provided");
            Host.CheckParam(!(optimizationTolerance <= 0), nameof(optimizationTolerance), "Must be positive, if provided.");
            Host.CheckParam(!(memorySize <= 0), nameof(memorySize), "Must be positive, if provided.");

            // REVIEW: Warn about the overriding behavior.
            L2Weight = l2Weight ?? Args.L2Weight;
            L1Weight = l1Weight ?? Args.L1Weight;
            OptTol = optimizationTolerance ?? Args.OptTol;
            MemorySize = memorySize ?? Args.MemorySize;
            MaxIterations = Args.MaxIterations;
            SgdInitializationTolerance = Args.SgdInitializationTolerance;
            Quiet = Args.Quiet;
            InitWtsDiameter = Args.InitWtsDiameter;
            UseThreads = Args.UseThreads;
            NumThreads = Args.NumThreads;
            DenseOptimizer = Args.DenseOptimizer;
            EnforceNonNegativity = enforceNoNegativity ?? Args.EnforceNonNegativity;

            if (EnforceNonNegativity && ShowTrainingStats)
            {
                ShowTrainingStats = false;
                using (var ch = Host.Start("Initialization"))
                {
                    ch.Warning("The training statistics cannot be computed with non-negativity constraint.");
                    ch.Done();
                }
            }

            // NOTE(review): ShowTrainingStats is reset unconditionally here, so the warning above only fires
            // if it already reads true during base construction — confirm this is intended.
            ShowTrainingStats = false;
            _srcPredictor = default;
        }
Пример #9
0
        /// <summary>
        /// Creates an SDCA regression trainer from explicit feature, label and weight column names.
        /// The loss is built from <c>args.LossFunction</c>, and <paramref name="args"/> is kept in <c>_args</c>.
        /// </summary>
        public SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null)
            : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue),
                   args,
                   TrainerUtils.MakeR4VecFeature(featureColumn),
                   TrainerUtils.MakeR4ScalarLabel(labelColumn),
                   TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
        {
            Host.CheckValue(labelColumn, nameof(labelColumn));
            Host.CheckValue(featureColumn, nameof(featureColumn));

            // Store the constructed loss in both the private field and the Loss member.
            Loss = _loss = args.LossFunction.CreateComponent(env);
            _args = args;
        }
Пример #10
0
 /// <summary>
 /// Initializes a new instance of <see cref="SdcaMultiClassTrainer"/>.
 /// </summary>
 /// <param name="env">The environment to use.</param>
 /// <param name="featureColumn">The features, or independent variables. Must be non-empty.</param>
 /// <param name="labelColumn">The label, or dependent variable. Must be non-empty.</param>
 /// <param name="weightColumn">The optional example weights.</param>
 /// <param name="loss">The custom loss. When null, the loss is created from <c>Args.LossFunction</c>.</param>
 /// <param name="l2Const">The L2 regularization hyperparameter.</param>
 /// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values tend to lead to a sparser model.</param>
 /// <param name="maxIterations">The maximum number of passes to perform over the data.</param>
 /// <param name="advancedSettings">A delegate to set more settings.</param>
 public SdcaMultiClassTrainer(IHostEnvironment env,
     string featureColumn,
     string labelColumn,
     string weightColumn = null,
     ISupportSdcaClassificationLoss loss = null,
     float? l2Const = null,
     float? l1Threshold = null,
     int? maxIterations = null,
     Action<Arguments> advancedSettings = null)
     : base(env,
            featureColumn,
            TrainerUtils.MakeU4ScalarLabel(labelColumn),
            TrainerUtils.MakeR4ScalarWeightColumn(weightColumn),
            advancedSettings,
            l2Const,
            l1Threshold,
            maxIterations)
 {
     Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
     Host.CheckNonEmpty(labelColumn, nameof(labelColumn));

     // Prefer the caller's loss; otherwise build the one configured in Args.
     Loss = _loss = loss ?? Args.LossFunction.CreateComponent(env);
 }
        /// <summary>
        /// Builds the calibrator estimator and describes the label, score and weight columns it consumes.
        /// </summary>
        /// <param name="env">The host environment, stored as <c>Host</c>. Must not be null.</param>
        /// <param name="calibratorTrainer">The trainer used to fit the calibrator. Must not be null.</param>
        /// <param name="labelColumn">Name of the boolean label column; may be null or blank only when
        /// <paramref name="calibratorTrainer"/> does not need training.</param>
        /// <param name="scoreColumn">Name of the float score column.</param>
        /// <param name="weightColumn">Optional name of the float weight column; when null, <c>WeightColumn</c> is left unset.</param>
        private protected CalibratorEstimatorBase(IHostEnvironment env,
                                                  ICalibratorTrainer calibratorTrainer, string labelColumn, string scoreColumn, string weightColumn)
        {
            // Both arguments are dereferenced below (env.CheckParam, calibratorTrainer.NeedsTraining), and a null
            // calibrator would otherwise be stored silently, so validate them up front.
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(calibratorTrainer, nameof(calibratorTrainer));

            Host = env;
            _calibratorTrainer = calibratorTrainer;

            if (!string.IsNullOrWhiteSpace(labelColumn))
            {
                LabelColumn = TrainerUtils.MakeBoolScalarLabel(labelColumn);
            }
            else
            {
                env.CheckParam(!calibratorTrainer.NeedsTraining, nameof(labelColumn), "For trained calibrators, " + nameof(labelColumn) + " must be specified.");
            }

            // REVIEW: Do we fathom this being named anything else (renaming column)? Complete metadata?
            ScoreColumn = TrainerUtils.MakeR4ScalarColumn(scoreColumn);

            if (weightColumn != null)
            {
                WeightColumn = TrainerUtils.MakeR4ScalarWeightColumn(weightColumn);
            }
        }
Пример #12
0
        /// <summary>
        /// Creates a GAM trainer from a fully populated arguments object and validates its hyperparameters.
        /// </summary>
        /// <param name="env">The host environment; a child host named <paramref name="name"/> is registered on it.</param>
        /// <param name="args">Trainer arguments, stored as <c>Args</c>. Must not be null.</param>
        /// <param name="name">Name under which the trainer's host is registered.</param>
        /// <param name="label">Schema shape of the label column.</param>
        private protected GamTrainerBase(IHostEnvironment env, TArgs args, string name, SchemaShape.Column label)
            // env and args are both read while the base-constructor arguments are built, so they are
            // null-checked here; checks in the constructor body would run too late to be useful.
            : base(Contracts.CheckRef(env, nameof(env)).Register(name),
                   TrainerUtils.MakeR4VecFeature(Contracts.CheckRef(args, nameof(args)).FeatureColumn),
                   label,
                   TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
        {
            Host.CheckParam(args.LearningRates > 0, nameof(args.LearningRates), "Must be positive.");
            Host.CheckParam(args.NumThreads == null || args.NumThreads > 0, nameof(args.NumThreads), "Must be positive.");
            Host.CheckParam(0 <= args.EntropyCoefficient && args.EntropyCoefficient <= 1, nameof(args.EntropyCoefficient), "Must be in [0, 1].");
            Host.CheckParam(0 <= args.GainConfidenceLevel && args.GainConfidenceLevel < 1, nameof(args.GainConfidenceLevel), "Must be in [0, 1).");
            Host.CheckParam(0 < args.MaxBins, nameof(args.MaxBins), "Must be positive.");
            Host.CheckParam(0 < args.NumIterations, nameof(args.NumIterations), "Must be positive.");
            Host.CheckParam(0 < args.MinDocuments, nameof(args.MinDocuments), "Must be positive.");

            Args = args;

            Info = new TrainerInfo(normalization: false, calibration: NeedCalibration, caching: false, supportValid: true);
            // Squared probit of the two-sided confidence level.
            _gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - Args.GainConfidenceLevel) * 0.5), 2);
            _entropyCoefficient = Args.EntropyCoefficient * 1e-6;

            InitializeThreads();
        }
Пример #13
0
        /// <summary>
        /// Creates a GAM trainer from a fully populated options object and validates its hyperparameters.
        /// </summary>
        /// <param name="env">The host environment; a child host named <paramref name="name"/> is registered on it.</param>
        /// <param name="options">Trainer options, stored as <c>GamTrainerOptions</c>. Must not be null.</param>
        /// <param name="name">Name under which the trainer's host is registered.</param>
        /// <param name="label">Schema shape of the label column.</param>
        private protected GamTrainerBase(IHostEnvironment env, TOptions options, string name, SchemaShape.Column label)
            // env and options are both read while the base-constructor arguments are built, so they are
            // null-checked here; checks in the constructor body would run too late to be useful.
            : base(Contracts.CheckRef(env, nameof(env)).Register(name),
                   TrainerUtils.MakeR4VecFeature(Contracts.CheckRef(options, nameof(options)).FeatureColumnName),
                   label,
                   TrainerUtils.MakeR4ScalarWeightColumn(options.ExampleWeightColumnName))
        {
            Host.CheckParam(options.LearningRate > 0, nameof(options.LearningRate), "Must be positive.");
            Host.CheckParam(options.NumberOfThreads == null || options.NumberOfThreads > 0, nameof(options.NumberOfThreads), "Must be positive.");
            Host.CheckParam(0 <= options.EntropyCoefficient && options.EntropyCoefficient <= 1, nameof(options.EntropyCoefficient), "Must be in [0, 1].");
            Host.CheckParam(0 <= options.GainConfidenceLevel && options.GainConfidenceLevel < 1, nameof(options.GainConfidenceLevel), "Must be in [0, 1).");
            Host.CheckParam(0 < options.MaximumBinCountPerFeature, nameof(options.MaximumBinCountPerFeature), "Must be positive.");
            Host.CheckParam(0 < options.NumberOfIterations, nameof(options.NumberOfIterations), "Must be positive.");
            Host.CheckParam(0 < options.MinimumExampleCountPerLeaf, nameof(options.MinimumExampleCountPerLeaf), "Must be positive.");

            GamTrainerOptions = options;

            Info = new TrainerInfo(normalization: false, calibration: NeedCalibration, caching: false, supportValid: true);
            // Squared probit of the two-sided confidence level.
            _gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - GamTrainerOptions.GainConfidenceLevel) * 0.5), 2);
            _entropyCoefficient = GamTrainerOptions.EntropyCoefficient * 1e-6;

            InitializeThreads();
        }
Пример #14
0
        /// <summary>
        /// Shared constructor for the randomized PCA trainer. Hyperparameters come from <paramref name="options"/>
        /// when it is non-null; otherwise they come from the individual arguments.
        /// </summary>
        private RandomizedPcaTrainer(IHostEnvironment env, Options options, string featureColumn, string weightColumn,
                                     int rank = 20, int oversampling = 20, bool center = true, int? seed = null)
            : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue),
                   TrainerUtils.MakeR4VecFeature(featureColumn),
                   default,
                   TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
        {
            // A non-null options object means we got here from maml and the internal constructor.
            bool useOptions = options != null;

            _rank = useOptions ? options.Rank : rank;
            _center = useOptions ? options.Center : center;
            _oversampling = useOptions ? options.Oversampling : oversampling;

            // Draw a seed from the host's random source only when none was supplied.
            int? requestedSeed = useOptions ? options.Seed : seed;
            _seed = requestedSeed ?? Host.Rand.Next();

            _featureColumn = featureColumn;

            Host.CheckUserArg(_rank > 0, nameof(_rank), "Rank must be positive");
            Host.CheckUserArg(_oversampling >= 0, nameof(_oversampling), "Oversampling must be non-negative");
        }
        /// <summary>
        /// Creates an online linear trainer and validates its iteration, initialization and cache settings.
        /// </summary>
        /// <param name="args">Trainer arguments, stored as <c>Args</c>. Must not be null.</param>
        /// <param name="env">The host environment; a child host named <paramref name="name"/> is registered on it.</param>
        /// <param name="name">Trainer name, stored as <c>Name</c>.</param>
        /// <param name="label">Schema shape of the label column.</param>
        protected OnlineLinearTrainer(OnlineLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
            // args is read while the base-constructor arguments are built, so it must be null-checked here.
            // No example-weight column is declared. args.InitialWeights was previously passed in this slot, but
            // it appears to carry initial model weights rather than the name of a weight column.
            : base(Contracts.CheckRef(env, nameof(env)).Register(name),
                   TrainerUtils.MakeR4VecFeature(Contracts.CheckRef(args, nameof(args)).FeatureColumn),
                   label,
                   default)
        {
            Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), UserErrorPositive);
            Contracts.CheckUserArg(args.InitWtsDiameter >= 0, nameof(args.InitWtsDiameter), UserErrorNonNegative);
            Contracts.CheckUserArg(args.StreamingCacheSize > 0, nameof(args.StreamingCacheSize), UserErrorPositive);

            Args = args;
            Name = name;
            // REVIEW: Caching could be false for one iteration, if we got around the whole shuffling issue.
            Info = new TrainerInfo(calibration: NeedCalibration, supportIncrementalTrain: true);
        }
Пример #16
0
        /// <summary>
        /// Creates a GAM trainer from individual column names and the most commonly tuned hyperparameters.
        /// A new <c>TOptions</c> instance is populated from these arguments.
        /// </summary>
        /// <param name="weightCrowGroupColumnName">Despite its name, this is the example-weight column; it is
        /// copied into <c>ExampleWeightColumnName</c> only when non-null.</param>
        private protected GamTrainerBase(IHostEnvironment env,
                                         string name,
                                         SchemaShape.Column label,
                                         string featureColumnName,
                                         string weightCrowGroupColumnName,
                                         int numberOfIterations,
                                         double learningRate,
                                         int maximumBinCountPerFeature)
            : base(Contracts.CheckRef(env, nameof(env)).Register(name),
                   TrainerUtils.MakeR4VecFeature(featureColumnName),
                   label,
                   TrainerUtils.MakeR4ScalarWeightColumn(weightCrowGroupColumnName))
        {
            GamTrainerOptions = new TOptions
            {
                NumberOfIterations = numberOfIterations,
                LearningRate = learningRate,
                MaximumBinCountPerFeature = maximumBinCountPerFeature,
                LabelColumnName = label.Name,
                FeatureColumnName = featureColumnName,
            };

            // Keep the options' default weight column unless the caller named one.
            if (weightCrowGroupColumnName != null)
            {
                GamTrainerOptions.ExampleWeightColumnName = weightCrowGroupColumnName;
            }

            Info = new TrainerInfo(normalization: false, calibration: NeedCalibration, caching: false, supportValid: true);

            // Squared probit of the two-sided confidence level.
            var gainConfidenceLevel = GamTrainerOptions.GainConfidenceLevel;
            _gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - gainConfidenceLevel) * 0.5), 2);
            _entropyCoefficient = GamTrainerOptions.EntropyCoefficient * 1e-6;

            InitializeThreads();
        }
Пример #17
0
        /// <summary>
        /// Creates a LightGBM trainer from individual column names and the most commonly tuned hyperparameters.
        /// A new <c>Options</c> instance is populated and stored in <c>Args</c>; the weight and group columns are
        /// marked explicit only when supplied.
        /// </summary>
        private protected LightGbmTrainerBase(IHostEnvironment env,
                                              string name,
                                              SchemaShape.Column label,
                                              string featureColumn,
                                              string weightColumn,
                                              string groupIdColumn,
                                              int? numLeaves,
                                              int? minDataPerLeaf,
                                              double? learningRate,
                                              int numBoostRound)
            : base(Contracts.CheckRef(env, nameof(env)).Register(name),
                   TrainerUtils.MakeR4VecFeature(featureColumn),
                   label,
                   TrainerUtils.MakeR4ScalarWeightColumn(weightColumn),
                   TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
        {
            Args = new Options
            {
                NumLeaves = numLeaves,
                MinDataPerLeaf = minDataPerLeaf,
                LearningRate = learningRate,
                NumBoostRound = numBoostRound,

                LabelColumn = label.Name,
                FeatureColumn = featureColumn,
            };

            // Optional columns are recorded as explicitly set only when the caller provided them.
            if (weightColumn != null)
            {
                Args.WeightColumn = Optional<string>.Explicit(weightColumn);
            }

            if (groupIdColumn != null)
            {
                Args.GroupIdColumn = Optional<string>.Explicit(groupIdColumn);
            }

            InitParallelTraining();
        }
        /// <summary>
        /// Creates a LightGBM trainer from individual column names and the most commonly tuned hyperparameters.
        /// A new <c>Options</c> instance is populated and stored in <c>LightGbmTrainerOptions</c>.
        /// </summary>
        private protected LightGbmTrainerBase(IHostEnvironment env,
                                              string name,
                                              SchemaShape.Column label,
                                              string featureColumn,
                                              string weightColumn,
                                              string groupIdColumn,
                                              int? numLeaves,
                                              int? minDataPerLeaf,
                                              double? learningRate,
                                              int numBoostRound)
            : base(Contracts.CheckRef(env, nameof(env)).Register(name),
                   TrainerUtils.MakeR4VecFeature(featureColumn),
                   label,
                   TrainerUtils.MakeR4ScalarWeightColumn(weightColumn),
                   TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
        {
            // Copy the constructor arguments into a fresh options object.
            LightGbmTrainerOptions = new Options
            {
                NumLeaves = numLeaves,
                MinDataPerLeaf = minDataPerLeaf,
                LearningRate = learningRate,
                NumBoostRound = numBoostRound,

                LabelColumnName = label.Name,
                FeatureColumnName = featureColumn,
                ExampleWeightColumnName = weightColumn,
                RowGroupColumnName = groupIdColumn,
            };

            InitParallelTraining();
        }
Пример #19
0
        /// <summary>
        /// Creates a GAM trainer from individual column names and the most commonly tuned hyperparameters.
        /// A new <c>TArgs</c> instance is populated and stored in <c>Args</c>.
        /// </summary>
        private protected GamTrainerBase(IHostEnvironment env,
                                         string name,
                                         SchemaShape.Column label,
                                         string featureColumn,
                                         string weightColumn,
                                         int numIterations,
                                         double learningRate,
                                         int maxBins)
            : base(Contracts.CheckRef(env, nameof(env)).Register(name),
                   TrainerUtils.MakeR4VecFeature(featureColumn),
                   label,
                   TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
        {
            Args = new TArgs
            {
                NumIterations = numIterations,
                LearningRates = learningRate,
                MaxBins = maxBins,
                LabelColumn = label.Name,
                FeatureColumn = featureColumn,
            };

            // Keep the arguments' default weight column unless the caller named one.
            if (weightColumn != null)
            {
                Args.WeightColumn = weightColumn;
            }

            Info = new TrainerInfo(normalization: false, calibration: NeedCalibration, caching: false, supportValid: true);

            // Squared probit of the two-sided confidence level.
            var gainConfidenceLevel = Args.GainConfidenceLevel;
            _gainConfidenceInSquaredStandardDeviations = Math.Pow(ProbabilityFunctions.Probit(1 - (1 - gainConfidenceLevel) * 0.5), 2);
            _entropyCoefficient = Args.EntropyCoefficient * 1e-6;

            InitializeThreads();
        }
        /// <summary>
        /// Creates a LightGBM trainer from a fully populated arguments object.
        /// </summary>
        /// <param name="env">The host environment; a child host named <paramref name="name"/> is registered on it.</param>
        /// <param name="name">Name under which the trainer's host is registered.</param>
        /// <param name="args">Trainer arguments, stored as <c>Args</c>. Must not be null.</param>
        /// <param name="label">Schema shape of the label column.</param>
        private protected LightGbmTrainerBase(IHostEnvironment env, string name, LightGbmArguments args, SchemaShape.Column label)
            // args is read while the base-constructor arguments are built, so it must be null-checked here;
            // a check in the constructor body would only run after a NullReferenceException had been thrown.
            : base(Contracts.CheckRef(env, nameof(env)).Register(name),
                   TrainerUtils.MakeR4VecFeature(Contracts.CheckRef(args, nameof(args)).FeatureColumn),
                   label,
                   TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
        {
            Args = args;
            InitParallelTraining();
        }
        /// <summary>
        /// Creates a LightGBM trainer from explicit column names, optionally tuned by <paramref name="advancedSettings"/>.
        /// The explicit column names are applied after the callback runs.
        /// </summary>
        private protected LightGbmTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn,
                                              string weightColumn = null, string groupIdColumn = null, Action<LightGbmArguments> advancedSettings = null)
            : base(Contracts.CheckRef(env, nameof(env)).Register(name),
                   TrainerUtils.MakeR4VecFeature(featureColumn),
                   label,
                   TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
        {
            Args = new LightGbmArguments();

            // Let the caller adjust the remaining settings first.
            if (advancedSettings != null)
            {
                advancedSettings(Args);
            }

            // The callback must not rename label, group, feature or weight columns; those come from the
            // constructor parameters below.
            TrainerUtils.CheckArgsHaveDefaultColNames(Host, Args);

            Args.LabelColumn = label.Name;
            Args.FeatureColumn = featureColumn;

            // Optional columns override the defaults only when supplied.
            if (weightColumn != null)
            {
                Args.WeightColumn = weightColumn;
            }

            if (groupIdColumn != null)
            {
                Args.GroupIdColumn = groupIdColumn;
            }

            InitParallelTraining();
        }