Beispiel #1
0
        public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Arguments input)
        {
            // Entry point: trains an AveragedPerceptronTrainer as a binary classifier, forwarding
            // input.Calibrator and input.MaxCalibrationExamples to the shared Train helper.
            Contracts.CheckValue(env, nameof(env));
            var trainerHost = env.Register("TrainAP");
            trainerHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(trainerHost, input);

            // The trainer factory and label lookup are lambdas, so they only run when Train invokes them.
            var output = LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(
                trainerHost,
                input,
                () => new AveragedPerceptronTrainer(trainerHost, input),
                getLabel: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.LabelColumn),
                calibrator: input.Calibrator,
                maxCalibrationExamples: input.MaxCalibrationExamples);
            return output;
        }
Beispiel #2
0
        public static CommonOutputs.BinaryClassificationOutput CreateBinaryEnsemble(IHostEnvironment env, EnsembleTrainer.Arguments input)
        {
            // Entry point: trains an EnsembleTrainer built from the supplied arguments as a binary classifier.
            Contracts.CheckValue(env, nameof(env));
            var trainerHost = env.Register("TrainBinaryEnsemble");
            trainerHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(trainerHost, input);

            // Only the label column is resolved here; the lookup runs when Train invokes the lambda.
            var output = LearnerEntryPointsUtils.Train<EnsembleTrainer.Arguments, CommonOutputs.BinaryClassificationOutput>(
                trainerHost,
                input,
                () => new EnsembleTrainer(trainerHost, input),
                getLabel: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.LabelColumn));
            return output;
        }
        public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, MulticlassLogisticRegression.Arguments input)
        {
            // Entry point: trains a MulticlassLogisticRegression model, resolving label and weight columns.
            Contracts.CheckValue(env, nameof(env));
            var trainerHost = env.Register("TrainLRMultiClass");
            trainerHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(trainerHost, input);

            // Column lookups are deferred: each lambda runs only when Train invokes it.
            var output = LearnerEntryPointsUtils.Train<MulticlassLogisticRegression.Arguments, CommonOutputs.MulticlassClassificationOutput>(
                trainerHost,
                input,
                () => new MulticlassLogisticRegression(trainerHost, input),
                getLabel: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.LabelColumn),
                getWeight: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.WeightColumn));
            return output;
        }
Beispiel #4
0
        public static CommonOutputs.RankingOutput TrainRanking(IHostEnvironment env, Options input)
        {
            // Entry point: trains a LightGbmRankingTrainer, resolving label, weight and group-id columns.
            Contracts.CheckValue(env, nameof(env));
            var trainerHost = env.Register("TrainLightGBM");
            trainerHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(trainerHost, input);

            // Column lookups are deferred: each lambda runs only when Train invokes it.
            var output = LearnerEntryPointsUtils.Train<Options, CommonOutputs.RankingOutput>(
                trainerHost,
                input,
                () => new LightGbmRankingTrainer(trainerHost, input),
                getLabel: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.LabelColumn),
                getWeight: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.WeightColumn),
                getGroup: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.GroupIdColumn));
            return output;
        }
Beispiel #5
0
        public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, FastForestRegression.Arguments input)
        {
            // Entry point: trains a FastForestRegression model, resolving label, weight and group-id columns.
            Contracts.CheckValue(env, nameof(env));
            var trainerHost = env.Register("TrainFastForest");
            trainerHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(trainerHost, input);

            // NOTE(review): a group-id column is forwarded even though this is a regression trainer;
            // kept as-is to match the original call. Lookups run only when Train invokes each lambda.
            var output = LearnerEntryPointsUtils.Train<FastForestRegression.Arguments, CommonOutputs.RegressionOutput>(
                trainerHost,
                input,
                () => new FastForestRegression(trainerHost, input),
                getLabel: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.LabelColumn),
                getWeight: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.WeightColumn),
                getGroup: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.GroupIdColumn));
            return output;
        }
Beispiel #6
0
        internal static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input)
        {
            // Entry point: trains a FieldAwareFactorizationMachineTrainer as a binary classifier.
            Contracts.CheckValue(env, nameof(env));
            var trainerHost = env.Register("Train a field-aware factorization machine");
            trainerHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(trainerHost, input);

            // Only the label column is resolved; the lookup runs when Train invokes the lambda.
            var output = LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(
                trainerHost,
                input,
                () => new FieldAwareFactorizationMachineTrainer(trainerHost, input),
                getLabel: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.LabelColumn));
            return output;
        }
Beispiel #7
0
        public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, BinaryClassificationGamTrainer.Options input)
        {
            // Entry point: trains a BinaryClassificationGamTrainer, resolving label and weight columns.
            Contracts.CheckValue(env, nameof(env));
            var trainerHost = env.Register("TrainGAM");
            trainerHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(trainerHost, input);

            // Column lookups are deferred: each lambda runs only when Train invokes it.
            var output = LearnerEntryPointsUtils.Train<BinaryClassificationGamTrainer.Options, CommonOutputs.BinaryClassificationOutput>(
                trainerHost,
                input,
                () => new BinaryClassificationGamTrainer(trainerHost, input),
                getLabel: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.LabelColumn),
                getWeight: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.WeightColumn));
            return output;
        }
        internal static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, Options input)
        {
            // Entry point: trains an OnlineGradientDescentTrainer as a regression model.
            Contracts.CheckValue(env, nameof(env));
            var trainerHost = env.Register("TrainOGD");
            trainerHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(trainerHost, input);

            // Only the label column is resolved; the lookup runs when Train invokes the lambda.
            var output = LearnerEntryPointsUtils.Train<Options, CommonOutputs.RegressionOutput>(
                trainerHost,
                input,
                () => new OnlineGradientDescentTrainer(trainerHost, input),
                getLabel: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.LabelColumn));
            return output;
        }
        public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, Options options)
        {
            // Entry point: trains an OlsLinearRegressionTrainer, resolving label and weight columns.
            Contracts.CheckValue(env, nameof(env));
            var trainerHost = env.Register("TrainOLS");
            trainerHost.CheckValue(options, nameof(options));
            EntryPointUtils.CheckInputArgs(trainerHost, options);

            // Column lookups are deferred: each lambda runs only when Train invokes it.
            var output = LearnerEntryPointsUtils.Train<Options, CommonOutputs.RegressionOutput>(
                trainerHost,
                options,
                () => new OlsLinearRegressionTrainer(trainerHost, options),
                getLabel: () => LearnerEntryPointsUtils.FindColumn(trainerHost, options.TrainingData.Schema, options.LabelColumn),
                getWeight: () => LearnerEntryPointsUtils.FindColumn(trainerHost, options.TrainingData.Schema, options.WeightColumn));
            return output;
        }
Beispiel #10
0
        public static CommonOutputs.AnomalyDetectionOutput TrainPcaAnomaly(IHostEnvironment env, Arguments input)
        {
            // Entry point: trains a RandomizedPcaTrainer for anomaly detection.
            Contracts.CheckValue(env, nameof(env));
            var trainerHost = env.Register("TrainPCAAnomaly");
            trainerHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(trainerHost, input);

            // No label lookup is passed; only the weight column is resolved, when Train invokes the lambda.
            var output = LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.AnomalyDetectionOutput>(
                trainerHost,
                input,
                () => new RandomizedPcaTrainer(trainerHost, input),
                getWeight: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.WeightColumn));
            return output;
        }
        public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, NearestNeighborsBinaryClassificationTrainer_ArgumentsEntryPoint input)
        {
            // Entry point: trains a NearestNeighborsBinaryClassificationTrainer, resolving label and weight columns.
            Contracts.CheckValue(env, nameof(env));
            var trainerHost = env.Register("Train" + NearestNeighborsBinary.Name);
            trainerHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(trainerHost, input);

            // NOTE(review): this entry point goes through EntryPointsHelper.Train rather than
            // LearnerEntryPointsUtils.Train like its siblings; confirm whether that is intentional.
            var output = EntryPointsHelper.Train<NearestNeighborsBinaryClassificationTrainer_ArgumentsEntryPoint, CommonOutputs.BinaryClassificationOutput>(
                trainerHost,
                input,
                () => new NearestNeighborsBinaryClassificationTrainer(trainerHost, input),
                getLabel: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.LabelColumn),
                getWeight: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.WeightColumn));
            return output;
        }
Beispiel #12
0
        internal static CommonOutputs.MulticlassClassificationOutput TrainMultiClassNaiveBayesTrainer(IHostEnvironment env, Options input)
        {
            // Entry point: trains a MultiClassNaiveBayesTrainer as a multiclass classifier.
            Contracts.CheckValue(env, nameof(env));
            var trainerHost = env.Register("TrainMultiClassNaiveBayes");
            trainerHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(trainerHost, input);

            // Only the label column is resolved; the lookup runs when Train invokes the lambda.
            var output = LearnerEntryPointsUtils.Train<Options, CommonOutputs.MulticlassClassificationOutput>(
                trainerHost,
                input,
                () => new MultiClassNaiveBayesTrainer(trainerHost, input),
                getLabel: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.LabelColumn));
            return output;
        }
Beispiel #13
0
        internal static CommonOutputs.BinaryClassificationOutput TrainSymSgd(IHostEnvironment env, Options options)
        {
            // Entry point: trains a SymSgdClassificationTrainer as a binary classifier.
            Contracts.CheckValue(env, nameof(env));
            var trainerHost = env.Register("TrainSymSGD");
            trainerHost.CheckValue(options, nameof(options));
            EntryPointUtils.CheckInputArgs(trainerHost, options);

            // Only the label column is resolved; the lookup runs when Train invokes the lambda.
            var output = LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(
                trainerHost,
                options,
                () => new SymSgdClassificationTrainer(trainerHost, options),
                getLabel: () => LearnerEntryPointsUtils.FindColumn(trainerHost, options.TrainingData.Schema, options.LabelColumn));
            return output;
        }
        public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastForestClassification.Options input)
        {
            // Entry point: trains a FastForestClassification binary classifier, resolving label, weight and
            // group-id columns and forwarding input.Calibrator / input.MaxCalibrationExamples to Train.
            Contracts.CheckValue(env, nameof(env));
            var trainerHost = env.Register("TrainFastForest");
            trainerHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(trainerHost, input);

            // Column lookups are deferred: each lambda runs only when Train invokes it.
            var output = LearnerEntryPointsUtils.Train<FastForestClassification.Options, CommonOutputs.BinaryClassificationOutput>(
                trainerHost,
                input,
                () => new FastForestClassification(trainerHost, input),
                getLabel: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.LabelColumn),
                getWeight: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.WeightColumn),
                getGroup: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.GroupIdColumn),
                calibrator: input.Calibrator,
                maxCalibrationExamples: input.MaxCalibrationExamples);
            return output;
        }
        public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, XGBoostArguments input)
        {
            // Entry point: trains an XGBoostBinaryTrainer, resolving label and weight columns.
            Contracts.CheckValue(env, nameof(env));
            var trainerHost = env.Register("Train" + EntryPointName);
            trainerHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(trainerHost, input);

            // Column lookups are deferred: each lambda runs only when Train invokes it.
            var output = LearnerEntryPointsUtils.Train<XGBoostArguments, CommonOutputs.BinaryClassificationOutput>(
                trainerHost,
                input,
                () => new XGBoostBinaryTrainer(trainerHost, input),
                getLabel: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.LabelColumn),
                getWeight: () => LearnerEntryPointsUtils.FindColumn(trainerHost, input.TrainingData.Schema, input.WeightColumn));
            return output;
        }