/// <summary>
        /// Combines the predictors of the supplied models into a single ensemble and wraps the
        /// result in a binary-classification output.
        /// </summary>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">
        /// Input carrying the models to combine (<c>Models</c>, must be non-empty) and the
        /// combiner kind (<c>ModelCombiner</c>).
        /// </param>
        /// <returns>An output whose <c>PredictorModel</c> holds the combined ensemble.</returns>
        /// <remarks>
        /// Throws via the host when <c>ModelCombiner</c> is not one of the handled kinds, or when
        /// any model's predictor is missing or is not an <c>IPredictorProducing&lt;float&gt;</c>.
        /// </remarks>
        public static CommonOutputs.BinaryClassificationOutput CreateBinaryEnsemble(IHostEnvironment env, ClassifierInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("CombineModels");

            host.CheckValue(input, nameof(input));
            host.CheckNonEmpty(input.Models, nameof(input.Models));

            GetPipeline(host, input, out IDataView startingData, out RoleMappedData transformedData);

            // Map the requested combiner kind onto the matching output-combiner factory.
            var args = new EnsembleTrainer.Arguments();
            switch (input.ModelCombiner)
            {
                case ClassifierCombiner.Median:
                    args.OutputCombiner = new MedianFactory();
                    break;

                case ClassifierCombiner.Average:
                    args.OutputCombiner = new AverageFactory();
                    break;

                case ClassifierCombiner.Vote:
                    args.OutputCombiner = new VotingFactory();
                    break;

                default:
                    throw host.Except("Unknown combiner kind");
            }

            var trainer = new EnsembleTrainer(host, args);

            // A bare `as` cast would silently hand null entries to CombineModels when a model has
            // no predictor or one of an incompatible type. Validate eagerly (ToArray forces the
            // check here rather than somewhere inside the combiner) and fail with a clear message.
            var predictors = input.Models
                .Select(pm =>
                {
                    if (pm?.Predictor is IPredictorProducing<float> predictor)
                    {
                        return predictor;
                    }

                    throw host.Except("Input model predictor is missing or is not an IPredictorProducing<float>");
                })
                .ToArray();

            var ensemble = trainer.CombineModels(predictors);

            var predictorModel = new PredictorModelImpl(host, transformedData, startingData, ensemble);

            return new CommonOutputs.BinaryClassificationOutput
            {
                PredictorModel = predictorModel
            };
        }
Example #2
0
        /// <summary>
        /// Entry point for the binary ensemble trainer: validates the arguments and hands off to
        /// <c>LearnerEntryPointsUtils.Train</c>.
        /// </summary>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Ensemble trainer arguments; checked for null and run through <c>EntryPointUtils.CheckInputArgs</c>.</param>
        /// <returns>The binary-classification output returned by <c>LearnerEntryPointsUtils.Train</c>.</returns>
        public static CommonOutputs.BinaryClassificationOutput CreateBinaryEnsemble(IHostEnvironment env, EnsembleTrainer.Arguments input)
        {
            Contracts.CheckValue(env, nameof(env));

            var host = env.Register("TrainBinaryEnsemble");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            // First callback builds the trainer; second looks up the label column in the
            // training data's schema.
            return LearnerEntryPointsUtils.Train<EnsembleTrainer.Arguments, CommonOutputs.BinaryClassificationOutput>(
                host,
                input,
                () => new EnsembleTrainer(host, input),
                () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
        }