/// <summary>
/// Creates the trainer from a fully populated <see cref="LightGbmArguments"/> instance.
/// </summary>
/// <param name="env">Host environment; a child host named <paramref name="name"/> is registered on it.</param>
/// <param name="name">Name used when registering the trainer's host.</param>
/// <param name="args">Trainer arguments; the feature and weight column names are read from it.</param>
/// <param name="label">Schema shape describing the label column.</param>
private protected LightGbmTrainerBase(IHostEnvironment env, string name, LightGbmArguments args, SchemaShape.Column label)
            : base(Contracts.CheckRef(env, nameof(env)).Register(name),
                   // Validate args before dereferencing it: the base-constructor arguments run before
                   // this constructor's body, so a body-level null check would come too late.
                   TrainerUtils.MakeR4VecFeature(Contracts.CheckRef(args, nameof(args)).FeatureColumn),
                   label,
                   TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
        {
            // args is known to be non-null here (checked in the initializer above).
            Args = args;
            InitParallelTraining();
        }
        /// <summary>
        /// Creates the trainer from <paramref name="args"/>: stores them, converts them into the
        /// option dictionary, and selects the parallel-training strategy.
        /// </summary>
        /// <param name="env">Host environment passed to the base class and to the parallel-trainer factory.</param>
        /// <param name="args">Trainer arguments; must not be null.</param>
        /// <param name="name">Name passed to the base class.</param>
        private protected LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, string name)
            : base(env, name)
        {
            Host.CheckValue(args, nameof(args));
            Args = args;
            Options = Args.ToDictionary(Host);

            // Without a configured parallel trainer component, fall back to SingleTrainer.
            if (Args.ParallelTrainer == null)
            {
                ParallelTraining = new SingleTrainer();
            }
            else
            {
                ParallelTraining = Args.ParallelTrainer.CreateComponent(env);
            }

            InitParallelTraining();
        }
        /// <summary>
        /// Entry point that trains a LightGBM ranking model from <paramref name="input"/>.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="input">Entry-point arguments, including the training data and column names.</param>
        /// <returns>The ranking output produced by <c>LearnerEntryPointsUtils.Train</c>.</returns>
        public static CommonOutputs.RankingOutput TrainRanking(IHostEnvironment env, LightGbmArguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("TrainLightGBM");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            // The trainer factory and each column resolver are lambdas; they read
            // input.TrainingData.Schema only when Train invokes them.
            return LearnerEntryPointsUtils.Train<LightGbmArguments, CommonOutputs.RankingOutput>(
                host,
                input,
                () => new LightGbmRankingTrainer(host, input),
                getLabel: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
                getWeight: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn),
                getGroup: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.GroupIdColumn));
        }
Exemple #4
0
        /// <summary>
        /// Creates the trainer, registering its own host and recording the prediction kind.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Trainer arguments; must not be null.</param>
        /// <param name="predictionKind">Kind of prediction this trainer produces.</param>
        /// <param name="name">Host name; must not be null or whitespace.</param>
        protected LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, PredictionKind predictionKind, string name)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckNonWhiteSpace(name, nameof(name));

            Host = env.Register(name);
            Host.CheckValue(args, nameof(args));

            Args = args;
            Options = Args.ToDictionary(Host);
            _predictionKind = predictionKind;
            _env = env;

            // Without a configured parallel trainer component, fall back to SingleTrainer.
            if (Args.ParallelTrainer == null)
            {
                ParallelTraining = new SingleTrainer();
            }
            else
            {
                ParallelTraining = Args.ParallelTrainer.CreateComponent(env);
            }

            InitParallelTraining();
        }
 /// <summary>
 /// After the advancedSettings delegate has been applied, compares the commonly tuned arguments
 /// (numLeaves, numBoostRound, minDataPerLeaf, learningRate) against the values supplied directly to the
 /// extension method. When an argument differs both from its default and from the directly supplied value,
 /// the user is warned about which value is being used.
 /// This list should follow the one in the constructor and the extension methods on <see cref="TrainContextBase"/>.
 /// REVIEW: we should somehow mark the arguments that are set apart in those two places. Currently they stand out by their sort order annotation.
 /// </summary>
 /// <param name="numLeaves">Value supplied directly for the number of leaves.</param>
 /// <param name="minDataPerLeaf">Value supplied directly for the minimum data per leaf.</param>
 /// <param name="learningRate">Value supplied directly for the learning rate.</param>
 /// <param name="numBoostRound">Value supplied directly for the number of boosting rounds.</param>
 /// <param name="snapshot">Reference argument values (presumably the state before advancedSettings ran — confirm with callers).</param>
 /// <param name="currentArgs">Argument values to check (presumably after advancedSettings ran).</param>
 protected void CheckArgsAndAdvancedSettingMismatch(int? numLeaves, int? minDataPerLeaf, double? learningRate, int numBoostRound,
     LightGbmArguments snapshot, LightGbmArguments currentArgs)
 {
     using (var channel = Host.Start("Comparing advanced settings with the directly provided values."))
     {
         // Each comparison is delegated to TrainerUtils; the order below fixes the order of any messages.
         TrainerUtils.CheckArgsAndAdvancedSettingMismatch(channel, numLeaves, snapshot.NumLeaves, currentArgs.NumLeaves, nameof(numLeaves));
         TrainerUtils.CheckArgsAndAdvancedSettingMismatch(channel, numBoostRound, snapshot.NumBoostRound, currentArgs.NumBoostRound, nameof(numBoostRound));
         TrainerUtils.CheckArgsAndAdvancedSettingMismatch(channel, minDataPerLeaf, snapshot.MinDataPerLeaf, currentArgs.MinDataPerLeaf, nameof(minDataPerLeaf));
         TrainerUtils.CheckArgsAndAdvancedSettingMismatch(channel, learningRate, snapshot.LearningRate, currentArgs.LearningRate, nameof(learningRate));
     }
 }
        /// <summary>
        /// Creates the trainer from directly supplied column names and hyperparameters, optionally
        /// refined by an <paramref name="advancedSettings"/> delegate.
        /// </summary>
        /// <param name="env">Host environment; a child host named <paramref name="name"/> is registered on it.</param>
        /// <param name="name">Name used when registering the trainer's host.</param>
        /// <param name="label">Schema shape describing the label column.</param>
        /// <param name="featureColumn">Name of the feature column.</param>
        /// <param name="weightColumn">Name of the weight column, or null.</param>
        /// <param name="groupIdColumn">Name of the group-id column, or null.</param>
        /// <param name="numLeaves">Number of leaves.</param>
        /// <param name="minDataPerLeaf">Minimum data per leaf.</param>
        /// <param name="learningRate">Learning rate.</param>
        /// <param name="numBoostRound">Number of boosting rounds.</param>
        /// <param name="advancedSettings">Optional delegate that further edits the arguments.</param>
        private protected LightGbmTrainerBase(IHostEnvironment env,
            string name,
            SchemaShape.Column label,
            string featureColumn,
            string weightColumn,
            string groupIdColumn,
            int? numLeaves,
            int? minDataPerLeaf,
            double? learningRate,
            int numBoostRound,
            Action<LightGbmArguments> advancedSettings)
            : base(Contracts.CheckRef(env, nameof(env)).Register(name),
                   TrainerUtils.MakeR4VecFeature(featureColumn),
                   label,
                   TrainerUtils.MakeR4ScalarWeightColumn(weightColumn),
                   TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
        {
            // Directly supplied hyperparameters go in first, so advancedSettings may still override them.
            Args = new LightGbmArguments
            {
                NumLeaves = numLeaves,
                MinDataPerLeaf = minDataPerLeaf,
                LearningRate = learningRate,
                NumBoostRound = numBoostRound,
            };

            advancedSettings?.Invoke(Args);

            // Column names are applied after advancedSettings, so the directly supplied names win.
            Args.LabelColumn = label.Name;
            Args.FeatureColumn = featureColumn;

            if (weightColumn is string weight)
            {
                Args.WeightColumn = Optional<string>.Explicit(weight);
            }

            if (groupIdColumn is string groupId)
            {
                Args.GroupIdColumn = Optional<string>.Explicit(groupId);
            }

            InitParallelTraining();
        }
        /// <summary>
        /// Creates the trainer from directly supplied column names, optionally refined by an
        /// <paramref name="advancedSettings"/> delegate.
        /// </summary>
        /// <param name="env">Host environment; a child host named <paramref name="name"/> is registered on it.</param>
        /// <param name="name">Name used when registering the trainer's host.</param>
        /// <param name="label">Schema shape describing the label column.</param>
        /// <param name="featureColumn">Name of the feature column.</param>
        /// <param name="weightColumn">Name of the weight column, or null.</param>
        /// <param name="groupIdColumn">Name of the group-id column, or null.</param>
        /// <param name="advancedSettings">Optional delegate that edits the arguments before the column names are applied.</param>
        private protected LightGbmTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn,
            string weightColumn = null, string groupIdColumn = null, Action<LightGbmArguments> advancedSettings = null)
            : base(Contracts.CheckRef(env, nameof(env)).Register(name),
                   TrainerUtils.MakeR4VecFeature(featureColumn),
                   label,
                   TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
        {
            Args = new LightGbmArguments();
            advancedSettings?.Invoke(Args);

            // Column names are applied after advancedSettings, so the directly supplied names win.
            Args.LabelColumn = label.Name;
            Args.FeatureColumn = featureColumn;

            if (weightColumn is string weight)
            {
                Args.WeightColumn = weight;
            }

            if (groupIdColumn is string groupId)
            {
                Args.GroupIdColumn = groupId;
            }

            InitParallelTraining();
        }
Exemple #8
0
 /// <summary>
 /// Creates the regression trainer from <paramref name="args"/>, declaring the label as an R4 scalar column.
 /// </summary>
 internal LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args)
     : base(env, LoadNameValue, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { }
 /// <summary>
 /// Creates the multiclass trainer from <paramref name="args"/>.
 /// </summary>
 /// <remarks>
 /// The label is declared as a U4 scalar label column. The previous Boolean label shape could only
 /// represent two classes, which is wrong for a multiclass trainer.
 /// </remarks>
 internal LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args)
     : base(env, LoadNameValue, args, TrainerUtils.MakeU4ScalarLabel(args.LabelColumn))
 {
     // -1 presumably marks the class count as not yet determined (set later) — confirm.
     _numClass = -1;
 }
 /// <summary>
 /// Creates the multiclass trainer with prediction kind MultiClassClassification and host name "LightGBMMulticlass".
 /// </summary>
 public LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args)
     : base(env, args, PredictionKind.MultiClassClassification, "LightGBMMulticlass") => _numClass = -1;
Exemple #11
0
 /// <summary>
 /// Creates the binary-classification trainer from <paramref name="args"/>, declaring the label as a Boolean scalar column.
 /// </summary>
 internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args)
     : base(env, LoadNameValue, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) { }
 /// <summary>
 /// Creates the regression trainer with prediction kind Regression and host name "LightGBMRegressor".
 /// </summary>
 public LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args)
     : base(env, args, PredictionKind.Regression, "LightGBMRegressor") { }
 /// <summary>
 /// Creates the ranking trainer from <paramref name="args"/>, using <c>LoadNameValue</c> as the host name.
 /// </summary>
 public LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args)
     : base(env, args, LoadNameValue) { }
 /// <summary>
 /// Creates the ranking trainer with prediction kind Ranking and host name "LightGBMRanking".
 /// </summary>
 public LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args)
     : base(env, args, PredictionKind.Ranking, "LightGBMRanking") { }
 /// <summary>
 /// Creates the binary-classification trainer with prediction kind BinaryClassification and host name "LGBBINCL".
 /// </summary>
 public LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args)
     : base(env, args, PredictionKind.BinaryClassification, "LGBBINCL") { }
Exemple #16
0
 /// <summary>
 /// Creates the multiclass trainer from <paramref name="args"/>, using <c>LoadNameValue</c> as the host name;
 /// the class count starts at -1.
 /// </summary>
 public LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args)
     : base(env, args, LoadNameValue) => _numClass = -1;