/// <summary>
/// Initializes the trainer base from a pre-populated <see cref="LightGbmArguments"/> instance.
/// </summary>
/// <param name="env">Host environment; checked for null and used to register a host called <paramref name="name"/>.</param>
/// <param name="name">Name under which the host is registered.</param>
/// <param name="args">Trainer arguments; supplies the feature and weight column names.</param>
/// <param name="label">Schema description of the label column, forwarded to the base class.</param>
private protected LightGbmTrainerBase(
    IHostEnvironment env,
    string name,
    LightGbmArguments args,
    SchemaShape.Column label)
    : base(
        Contracts.CheckRef(env, nameof(env)).Register(name),
        TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
        label,
        TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
{
    // NOTE(review): args has already been dereferenced in the base initializer above,
    // so this check only matters for code paths that reach the body.
    Host.CheckValue(args, nameof(args));

    Args = args;

    InitParallelTraining();
}
/// <summary>
/// Initializes the trainer base from <paramref name="args"/>, building the option dictionary
/// and choosing the parallel-training component.
/// </summary>
/// <param name="env">Host environment, forwarded to the base class and to the parallel trainer factory.</param>
/// <param name="args">Trainer arguments; must not be null.</param>
/// <param name="name">Name forwarded to the base class.</param>
private protected LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, string name)
    : base(env, name)
{
    Host.CheckValue(args, nameof(args));

    Args = args;
    Options = Args.ToDictionary(Host);

    // Use the configured parallel trainer when one is supplied; otherwise fall back to SingleTrainer.
    if (Args.ParallelTrainer != null)
    {
        ParallelTraining = Args.ParallelTrainer.CreateComponent(env);
    }
    else
    {
        ParallelTraining = new SingleTrainer();
    }

    InitParallelTraining();
}
/// <summary>
/// Trains a <see cref="LightGbmRankingTrainer"/> on the data referenced by <paramref name="input"/>
/// and returns the resulting ranking output.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="input">Trainer arguments, including the training data and the label, weight and group id column names.</param>
/// <returns>The output produced by <c>LearnerEntryPointsUtils.Train</c>.</returns>
public static CommonOutputs.RankingOutput TrainRanking(IHostEnvironment env, LightGbmArguments input)
{
    Contracts.CheckValue(env, nameof(env));

    var entryHost = env.Register("TrainLightGBM");
    entryHost.CheckValue(input, nameof(input));
    EntryPointUtils.CheckInputArgs(entryHost, input);

    // Column lookups are passed as delegates so they are resolved by Train, not here.
    return LearnerEntryPointsUtils.Train<LightGbmArguments, CommonOutputs.RankingOutput>(
        entryHost,
        input,
        () => new LightGbmRankingTrainer(entryHost, input),
        getLabel: () => LearnerEntryPointsUtils.FindColumn(entryHost, input.TrainingData.Schema, input.LabelColumn),
        getWeight: () => LearnerEntryPointsUtils.FindColumn(entryHost, input.TrainingData.Schema, input.WeightColumn),
        getGroup: () => LearnerEntryPointsUtils.FindColumn(entryHost, input.TrainingData.Schema, input.GroupIdColumn));
}
/// <summary>
/// Initializes the trainer base: registers a host, stores the arguments and prediction kind,
/// builds the option dictionary and selects the parallel-training component.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="args">Trainer arguments; must not be null.</param>
/// <param name="predictionKind">Kind of prediction this trainer produces.</param>
/// <param name="name">Non-whitespace name used to register the host.</param>
protected LightGbmTrainerBase(
    IHostEnvironment env,
    LightGbmArguments args,
    PredictionKind predictionKind,
    string name)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckNonWhiteSpace(name, nameof(name));

    Host = env.Register(name);
    Host.CheckValue(args, nameof(args));

    Args = args;
    Options = Args.ToDictionary(Host);

    _env = env;
    _predictionKind = predictionKind;

    // Use the configured parallel trainer when one is supplied; otherwise fall back to SingleTrainer.
    if (Args.ParallelTrainer != null)
    {
        ParallelTraining = Args.ParallelTrainer.CreateComponent(env);
    }
    else
    {
        ParallelTraining = new SingleTrainer();
    }

    InitParallelTraining();
}
/// <summary>
/// Compares the values passed directly to the trainer with the arguments as they stand after
/// the advancedSettings delegate has run. For each of numLeaves, numBoostRound, minDataPerLeaf
/// and learningRate, the directly supplied value, the snapshot value and the current value are
/// handed to <c>TrainerUtils.CheckArgsAndAdvancedSettingMismatch</c>, which is responsible for
/// warning the user about which value is being used.
/// These are the parameters users are most likely to tune. The list should follow the one in
/// the constructor and the extension methods on the <see cref="TrainContextBase"/>.
/// REVIEW: we should somehow mark the arguments that are set apart in those two places.
/// Currently they stand out by their sort order annotation.
/// </summary>
/// <param name="numLeaves">Directly supplied number of leaves, if any.</param>
/// <param name="minDataPerLeaf">Directly supplied minimum data per leaf, if any.</param>
/// <param name="learningRate">Directly supplied learning rate, if any.</param>
/// <param name="numBoostRound">Directly supplied number of boosting rounds.</param>
/// <param name="snapshot">Arguments used as the reference point for the comparison.</param>
/// <param name="currentArgs">Arguments in their current state.</param>
protected void CheckArgsAndAdvancedSettingMismatch(
    int? numLeaves,
    int? minDataPerLeaf,
    double? learningRate,
    int numBoostRound,
    LightGbmArguments snapshot,
    LightGbmArguments currentArgs)
{
    using (var channel = Host.Start("Comparing advanced settings with the directly provided values."))
    {
        // Verify the args do not contradict what the user passed in directly.
        TrainerUtils.CheckArgsAndAdvancedSettingMismatch(
            channel, numLeaves, snapshot.NumLeaves, currentArgs.NumLeaves, nameof(numLeaves));
        TrainerUtils.CheckArgsAndAdvancedSettingMismatch(
            channel, numBoostRound, snapshot.NumBoostRound, currentArgs.NumBoostRound, nameof(numBoostRound));
        TrainerUtils.CheckArgsAndAdvancedSettingMismatch(
            channel, minDataPerLeaf, snapshot.MinDataPerLeaf, currentArgs.MinDataPerLeaf, nameof(minDataPerLeaf));
        TrainerUtils.CheckArgsAndAdvancedSettingMismatch(
            channel, learningRate, snapshot.LearningRate, currentArgs.LearningRate, nameof(learningRate));
    }
}
/// <summary>
/// Initializes the trainer base from individually supplied column names and the most commonly
/// tuned hyper-parameters, then lets the caller adjust the remaining arguments.
/// </summary>
/// <param name="env">Host environment; checked for null and used to register a host called <paramref name="name"/>.</param>
/// <param name="name">Name under which the host is registered.</param>
/// <param name="label">Schema description of the label column; its name is stored in the arguments.</param>
/// <param name="featureColumn">Name of the feature column.</param>
/// <param name="weightColumn">Name of the weight column, or null when there is none.</param>
/// <param name="groupIdColumn">Name of the group id column, or null when there is none.</param>
/// <param name="numLeaves">Value assigned to <c>Args.NumLeaves</c>.</param>
/// <param name="minDataPerLeaf">Value assigned to <c>Args.MinDataPerLeaf</c>.</param>
/// <param name="learningRate">Value assigned to <c>Args.LearningRate</c>.</param>
/// <param name="numBoostRound">Value assigned to <c>Args.NumBoostRound</c>.</param>
/// <param name="advancedSettings">Optional delegate invoked on the arguments after the values above are set.</param>
private protected LightGbmTrainerBase(
    IHostEnvironment env,
    string name,
    SchemaShape.Column label,
    string featureColumn,
    string weightColumn,
    string groupIdColumn,
    int? numLeaves,
    int? minDataPerLeaf,
    double? learningRate,
    int numBoostRound,
    Action<LightGbmArguments> advancedSettings)
    : base(
        Contracts.CheckRef(env, nameof(env)).Register(name),
        TrainerUtils.MakeR4VecFeature(featureColumn),
        label,
        TrainerUtils.MakeR4ScalarWeightColumn(weightColumn),
        TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
{
    Args = new LightGbmArguments
    {
        NumLeaves = numLeaves,
        MinDataPerLeaf = minDataPerLeaf,
        LearningRate = learningRate,
        NumBoostRound = numBoostRound
    };

    // Let the caller override anything set so far.
    if (advancedSettings != null)
    {
        advancedSettings(Args);
    }

    // Column names are written after the delegate runs, so they always reflect the
    // constructor parameters.
    Args.LabelColumn = label.Name;
    Args.FeatureColumn = featureColumn;

    if (weightColumn != null)
    {
        Args.WeightColumn = Optional<string>.Explicit(weightColumn);
    }

    if (groupIdColumn != null)
    {
        Args.GroupIdColumn = Optional<string>.Explicit(groupIdColumn);
    }

    InitParallelTraining();
}
/// <summary>
/// Initializes the trainer base from individually supplied column names, starting from
/// default <see cref="LightGbmArguments"/> that the caller may adjust through
/// <paramref name="advancedSettings"/>.
/// </summary>
/// <param name="env">Host environment; checked for null and used to register a host called <paramref name="name"/>.</param>
/// <param name="name">Name under which the host is registered.</param>
/// <param name="label">Schema description of the label column; its name is stored in the arguments.</param>
/// <param name="featureColumn">Name of the feature column.</param>
/// <param name="weightColumn">Name of the weight column, or null when there is none.</param>
/// <param name="groupIdColumn">Name of the group id column, or null when there is none.</param>
/// <param name="advancedSettings">Optional delegate invoked on the freshly created arguments.</param>
private protected LightGbmTrainerBase(
    IHostEnvironment env,
    string name,
    SchemaShape.Column label,
    string featureColumn,
    string weightColumn = null,
    string groupIdColumn = null,
    Action<LightGbmArguments> advancedSettings = null)
    : base(
        Contracts.CheckRef(env, nameof(env)).Register(name),
        TrainerUtils.MakeR4VecFeature(featureColumn),
        label,
        TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
{
    Args = new LightGbmArguments();

    // Let the caller customize the defaults first.
    if (advancedSettings != null)
    {
        advancedSettings(Args);
    }

    // Column names are written after the delegate runs, so they always reflect the
    // constructor parameters.
    Args.LabelColumn = label.Name;
    Args.FeatureColumn = featureColumn;

    if (weightColumn != null)
    {
        Args.WeightColumn = weightColumn;
    }

    if (groupIdColumn != null)
    {
        Args.GroupIdColumn = groupIdColumn;
    }

    InitParallelTraining();
}
/// <summary>
/// Creates a LightGBM regression trainer from <paramref name="args"/>, declaring the label
/// column named by <c>args.LabelColumn</c> as an R4 scalar.
/// </summary>
/// <param name="env">Host environment forwarded to the base trainer.</param>
/// <param name="args">Trainer arguments.</param>
internal LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args)
    : base(
        env,
        LoadNameValue,
        args,
        TrainerUtils.MakeR4ScalarLabel(args.LabelColumn))
{
}
/// <summary>
/// Creates a LightGBM multiclass trainer from <paramref name="args"/>.
/// </summary>
/// <param name="env">Host environment forwarded to the base trainer.</param>
/// <param name="args">Trainer arguments; <c>args.LabelColumn</c> names the label column.</param>
internal LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args)
    // A multiclass label identifies one of several classes, so it is declared as a U4 scalar
    // column rather than the boolean label shape used by the binary trainer.
    : base(
        env,
        LoadNameValue,
        args,
        TrainerUtils.MakeU4ScalarColumn(args.LabelColumn))
{
    // NOTE(review): -1 looks like a "class count not yet determined" sentinel — confirm
    // against the code that later assigns _numClass.
    _numClass = -1;
}
/// <summary>
/// Creates a LightGBM multiclass trainer, registering it with the base class as
/// <see cref="PredictionKind.MultiClassClassification"/> under the name "LightGBMMulticlass".
/// </summary>
/// <param name="env">Host environment forwarded to the base trainer.</param>
/// <param name="args">Trainer arguments.</param>
public LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args)
    : base(
        env,
        args,
        PredictionKind.MultiClassClassification,
        "LightGBMMulticlass")
{
    // NOTE(review): -1 looks like a "class count not yet determined" sentinel — confirm.
    _numClass = -1;
}
/// <summary>
/// Creates a LightGBM binary classification trainer from <paramref name="args"/>, declaring
/// the label column named by <c>args.LabelColumn</c> as a boolean scalar.
/// </summary>
/// <param name="env">Host environment forwarded to the base trainer.</param>
/// <param name="args">Trainer arguments.</param>
internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args)
    : base(
        env,
        LoadNameValue,
        args,
        TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
{
}
/// <summary>
/// Creates a LightGBM regression trainer, registering it with the base class as
/// <see cref="PredictionKind.Regression"/> under the name "LightGBMRegressor".
/// </summary>
/// <param name="env">Host environment forwarded to the base trainer.</param>
/// <param name="args">Trainer arguments.</param>
public LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args)
    : base(
        env,
        args,
        PredictionKind.Regression,
        "LightGBMRegressor")
{
}
/// <summary>
/// Creates a LightGBM ranking trainer, passing <c>LoadNameValue</c> to the base class as its name.
/// </summary>
/// <param name="env">Host environment forwarded to the base trainer.</param>
/// <param name="args">Trainer arguments.</param>
public LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args)
    : base(
        env,
        args,
        LoadNameValue)
{
}
/// <summary>
/// Creates a LightGBM ranking trainer, registering it with the base class as
/// <see cref="PredictionKind.Ranking"/> under the name "LightGBMRanking".
/// </summary>
/// <param name="env">Host environment forwarded to the base trainer.</param>
/// <param name="args">Trainer arguments.</param>
public LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args)
    : base(
        env,
        args,
        PredictionKind.Ranking,
        "LightGBMRanking")
{
}
/// <summary>
/// Creates a LightGBM binary classification trainer, registering it with the base class as
/// <see cref="PredictionKind.BinaryClassification"/> under the name "LGBBINCL".
/// </summary>
/// <param name="env">Host environment forwarded to the base trainer.</param>
/// <param name="args">Trainer arguments.</param>
public LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args)
    : base(
        env,
        args,
        PredictionKind.BinaryClassification,
        "LGBBINCL")
{
}
/// <summary>
/// Creates a LightGBM multiclass trainer, passing <c>LoadNameValue</c> to the base class as its name.
/// </summary>
/// <param name="env">Host environment forwarded to the base trainer.</param>
/// <param name="args">Trainer arguments.</param>
public LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args)
    : base(
        env,
        args,
        LoadNameValue)
{
    // NOTE(review): -1 looks like a "class count not yet determined" sentinel — confirm.
    _numClass = -1;
}