Example #1
0
        /// <summary>
        /// Wraps an already-built pipeline ensemble in a trainer output of type <typeparamref name="TOut"/>.
        /// </summary>
        /// <typeparam name="TOut">The trainer output type to create; must have a parameterless constructor.</typeparam>
        /// <param name="env">The host environment used to build the data view and predictor model.</param>
        /// <param name="predictors">The models that were combined; must be non-empty. Only the first one is read here, to obtain the input schema.</param>
        /// <param name="ensemble">The bindable ensemble exposed through the returned predictor model; must not be null.</param>
        /// <returns>A new <typeparamref name="TOut"/> whose <c>PredictorModel</c> wraps <paramref name="ensemble"/>.</returns>
        private static TOut CreatePipelineEnsemble<TOut>(IHostEnvironment env, PredictorModel[] predictors, SchemaBindablePipelineEnsembleBase ensemble)
            where TOut : CommonOutputs.TrainerOutput, new()
        {
            // predictors[0] is dereferenced below; report an empty/null array as an argument error
            // rather than letting it surface as an IndexOutOfRangeException.
            env.CheckNonEmpty(predictors, nameof(predictors));
            env.CheckValue(ensemble, nameof(ensemble));

            // The input schema is taken from the first model only.
            // NOTE(review): presumably all combined models share the same input schema — confirm with callers.
            var inputSchema = predictors[0].TransformModel.InputSchema;
            var dv = new EmptyDataView(env, inputSchema);

            // The role mappings are specific to the individual predictors, so no roles are assigned here.
            var rmd = new RoleMappedData(dv);
            var predictorModel = new PredictorModelImpl(env, rmd, dv, ensemble);

            var output = new TOut { PredictorModel = predictorModel };
            return output;
        }
Example #2
0
        /// <summary>
        /// Combines the binary classification models in <paramref name="input"/> into a single ensemble predictor.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="input">
        /// The models to combine and the combiner to apply; <c>input.Models</c> must be non-empty, and every model's
        /// predictor must implement <c>IPredictorProducing&lt;float&gt;</c>.
        /// </param>
        /// <returns>A binary classification output whose predictor model wraps the combined ensemble.</returns>
        /// <remarks>
        /// Throws a host exception for an unrecognized <c>ModelCombiner</c> value, and a user-argument error naming the
        /// index of any model that is null or whose predictor is not a float-producing predictor.
        /// </remarks>
        public static CommonOutputs.BinaryClassificationOutput CreateBinaryEnsemble(IHostEnvironment env, ClassifierInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("CombineModels");

            host.CheckValue(input, nameof(input));
            host.CheckNonEmpty(input.Models, nameof(input.Models));

            GetPipeline(host, input, out IDataView startingData, out RoleMappedData transformedData);

            var args = new EnsembleTrainer.Arguments();

            // Select the factory the ensemble trainer uses to merge the individual model outputs.
            switch (input.ModelCombiner)
            {
            case ClassifierCombiner.Median:
                args.OutputCombiner = new MedianFactory();
                break;

            case ClassifierCombiner.Average:
                args.OutputCombiner = new AverageFactory();
                break;

            case ClassifierCombiner.Vote:
                args.OutputCombiner = new VotingFactory();
                break;

            default:
                throw host.Except("Unknown combiner kind");
            }

            // `as` yields null for an incompatible predictor; reject such models here with a message naming the
            // offending index instead of handing a null element to CombineModels. The sequence is materialized
            // once so this validation runs exactly once, at this point.
            var predictors = input.Models.Select((pm, i) =>
            {
                var predictor = pm?.Predictor as IPredictorProducing<float>;
                host.CheckUserArg(predictor != null, nameof(input.Models),
                    $"Model at index {i} is null or does not produce a float score, so it cannot be combined into a binary ensemble.");
                return predictor;
            }).ToArray();

            var trainer = new EnsembleTrainer(host, args);
            var ensemble = trainer.CombineModels(predictors);

            var predictorModel = new PredictorModelImpl(host, transformedData, startingData, ensemble);

            var output = new CommonOutputs.BinaryClassificationOutput { PredictorModel = predictorModel };
            return output;
        }
        /// <summary>
        /// Loads the object stored at <paramref name="path"/>, interpreting it according to <paramref name="kind"/>,
        /// and binds it to the graph variable <paramref name="varName"/> on <paramref name="runner"/>.
        /// </summary>
        /// <param name="runner">The graph runner that receives the input; must not be null.</param>
        /// <param name="varName">The graph variable to set; must not be null, empty or whitespace.</param>
        /// <param name="path">The file to read from; must not be null, empty or whitespace.</param>
        /// <param name="kind">
        /// How to interpret the file: FileHandle, DataView (via <c>BinaryLoader</c>), PredictorModel or TransformModel.
        /// Any other kind raises a host exception.
        /// </param>
        public void SetInputFromPath(GraphRunner runner, string varName, string path, TlcModule.DataKind kind)
        {
            _host.CheckUserArg(runner != null, nameof(runner), "Provide a GraphRunner instance.");
            _host.CheckUserArg(!string.IsNullOrWhiteSpace(varName), nameof(varName), "Specify a graph variable name.");
            _host.CheckUserArg(!string.IsNullOrWhiteSpace(path), nameof(path), "Specify a valid file path.");

            // Each branch keeps the static type of the value handed to SetInput explicit, and returns as soon as
            // the input is bound.
            switch (kind)
            {
            case TlcModule.DataKind.FileHandle:
            {
                SimpleFileHandle handle = new SimpleFileHandle(_host, path, false, false);
                runner.SetInput(varName, handle);
                return;
            }

            case TlcModule.DataKind.DataView:
            {
                var loaderArgs = new BinaryLoader.Arguments();
                IDataView dataView = new BinaryLoader(_host, loaderArgs, path);
                runner.SetInput(varName, dataView);
                return;
            }

            case TlcModule.DataKind.PredictorModel:
            {
                // The stream is only needed while deserializing; close it before binding the model.
                PredictorModelImpl predictorModel;
                using (var stream = File.OpenRead(path))
                {
                    predictorModel = new PredictorModelImpl(_host, stream);
                }
                runner.SetInput(varName, predictorModel);
                return;
            }

            case TlcModule.DataKind.TransformModel:
            {
                // Same pattern as above: deserialize, release the file, then bind.
                TransformModelImpl transformModel;
                using (var stream = File.OpenRead(path))
                {
                    transformModel = new TransformModelImpl(_host, stream);
                }
                runner.SetInput(varName, transformModel);
                return;
            }

            default:
                throw _host.Except("Port type {0} not supported", kind);
            }
        }