/// <summary>
/// Creates an XGBoost trainer configured for multiclass classification.
/// All work is delegated to the base constructor with
/// <c>PredictionKind.MultiClassClassification</c> and the name <c>"eXGBoostMulticlass"</c>.
/// </summary>
/// <param name="env">Host environment forwarded unchanged to the base constructor.</param>
/// <param name="args">XGBoost arguments forwarded unchanged to the base constructor.</param>
public XGBoostMulticlassTrainer(IHostEnvironment env, XGBoostArguments args)
    : base(env, args, PredictionKind.MultiClassClassification, "eXGBoostMulticlass")
{
}
/// <summary>
/// Creates an XGBoost trainer configured for regression.
/// All work is delegated to the base constructor with
/// <c>PredictionKind.Regression</c> and the name <c>"eXGBoostRegressor"</c>.
/// </summary>
/// <param name="env">Host environment forwarded unchanged to the base constructor.</param>
/// <param name="args">XGBoost arguments forwarded unchanged to the base constructor.</param>
public XGBoostRegressorTrainer(IHostEnvironment env, XGBoostArguments args)
    : base(env, args, PredictionKind.Regression, "eXGBoostRegressor")
{
}
/// <summary>
/// Creates an XGBoost trainer configured for ranking.
/// All work is delegated to the base constructor with
/// <c>PredictionKind.Ranking</c> and the name <c>"eXGBoostRanking"</c>.
/// </summary>
/// <param name="env">Host environment forwarded unchanged to the base constructor.</param>
/// <param name="args">XGBoost arguments forwarded unchanged to the base constructor.</param>
public XGBoostRankingTrainer(IHostEnvironment env, XGBoostArguments args)
    : base(env, args, PredictionKind.Ranking, "eXGBoostRanking")
{
}
/// <summary>
/// Creates an XGBoost trainer configured for binary classification.
/// All work is delegated to the base constructor with
/// <c>PredictionKind.BinaryClassification</c> and the name <c>"eXGBoostBinary"</c>.
/// </summary>
/// <param name="env">Host environment forwarded unchanged to the base constructor.</param>
/// <param name="args">XGBoost arguments forwarded unchanged to the base constructor.</param>
public XGBoostBinaryTrainer(IHostEnvironment env, XGBoostArguments args)
    : base(env, args, PredictionKind.BinaryClassification, "eXGBoostBinary")
{
}
/// <summary>
/// Entry point that trains a binary-classification XGBoost model.
/// </summary>
/// <remarks>
/// Registers a child host named <c>"Train" + EntryPointName</c>, validates the
/// inputs, then hands off to <c>LearnerEntryPointsUtils.Train</c>. That helper
/// receives a factory for <c>XGBoostBinaryTrainer</c> plus lazy lookups of the
/// label and weight columns in <c>input.TrainingData.Schema</c>.
/// </remarks>
/// <param name="env">Host environment; checked by <c>Contracts.CheckValue</c> before use.</param>
/// <param name="input">Training arguments; checked by the registered host and by <c>EntryPointUtils.CheckInputArgs</c>.</param>
/// <returns>The output object produced by <c>LearnerEntryPointsUtils.Train</c>.</returns>
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, XGBoostArguments input)
{
    // The environment must be validated before it is used to register a host.
    Contracts.CheckValue(env, nameof(env));

    var registrationName = "Train" + EntryPointName;
    var trainHost = env.Register(registrationName);

    // Argument validation goes through the newly registered host.
    trainHost.CheckValue(input, nameof(input));
    EntryPointUtils.CheckInputArgs(trainHost, input);

    // The column lookups are deferred as lambdas; the Train helper decides when to evaluate them.
    return LearnerEntryPointsUtils.Train<XGBoostArguments, CommonOutputs.BinaryClassificationOutput>(
        trainHost,
        input,
        () => new XGBoostBinaryTrainer(trainHost, input),
        getLabel: () => LearnerEntryPointsUtils.FindColumn(trainHost, input.TrainingData.Schema, input.LabelColumn),
        getWeight: () => LearnerEntryPointsUtils.FindColumn(trainHost, input.TrainingData.Schema, input.WeightColumn));
}