/// <summary>
/// Combines a set of already-trained regression models into a single regression ensemble
/// predictor model, using the score combiner selected in <paramref name="input"/>.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="input">
/// Entry-point input; must not be null and must contain at least one model in
/// <c>Models</c>. Every model's <c>Predictor</c> must be an
/// <c>IPredictorProducing&lt;float&gt;</c>.
/// </param>
/// <returns>
/// A <c>CommonOutputs.RegressionOutput</c> whose <c>PredictorModel</c> wraps the combined
/// ensemble together with the data pipeline obtained from <c>GetPipeline</c>.
/// </returns>
/// <remarks>
/// Throws (via the host's checks) when <paramref name="env"/> or <paramref name="input"/> is null,
/// when <c>Models</c> is empty, when <c>ModelCombiner</c> is neither <c>Median</c> nor
/// <c>Average</c>, or when any model is not a float-producing predictor.
/// </remarks>
public static CommonOutputs.RegressionOutput CreateRegressionEnsemble(IHostEnvironment env, RegressionInput input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("CombineModels");
    host.CheckValue(input, nameof(input));
    host.CheckNonEmpty(input.Models, nameof(input.Models));

    GetPipeline(host, input, out IDataView startingData, out RoleMappedData transformedData);

    var args = new RegressionEnsembleTrainer.Arguments();
    switch (input.ModelCombiner)
    {
        case ScoreCombiner.Median:
            args.OutputCombiner = new MedianFactory();
            break;
        case ScoreCombiner.Average:
            args.OutputCombiner = new AverageFactory();
            break;
        default:
            throw host.Except("Unknown combiner kind");
    }

    // Materialize the casts once so CombineModels never re-runs the deferred query, and
    // fail fast with a clear message instead of forwarding a null predictor when a model
    // is not a float-producing (regression) predictor.
    var predictors = input.Models
        .Select(pm => pm?.Predictor as IPredictorProducing<float>)
        .ToArray();
    for (int i = 0; i < predictors.Length; i++)
    {
        if (predictors[i] is null)
        {
            throw host.ExceptParam(nameof(input.Models),
                $"Model at index {i} is not a regression predictor producing float scores.");
        }
    }

    var trainer = new RegressionEnsembleTrainer(host, args);
    var ensemble = trainer.CombineModels(predictors);
    var predictorModel = new PredictorModelImpl(host, transformedData, startingData, ensemble);
    return new CommonOutputs.RegressionOutput { PredictorModel = predictorModel };
}
/// <summary>
/// Entry point that trains a regression ensemble from <paramref name="input"/> by delegating
/// to <c>TrainerEntryPointsUtils.Train</c>.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="input">Trainer arguments; must not be null and must pass <c>EntryPointUtils.CheckInputArgs</c>.</param>
/// <returns>The <c>CommonOutputs.RegressionOutput</c> produced by <c>TrainerEntryPointsUtils.Train</c>.</returns>
public static CommonOutputs.RegressionOutput CreateRegressionEnsemble(IHostEnvironment env, RegressionEnsembleTrainer.Arguments input)
{
    Contracts.CheckValue(env, nameof(env));

    var trainingHost = env.Register("TrainRegressionEnsemble");
    trainingHost.CheckValue(input, nameof(input));
    EntryPointUtils.CheckInputArgs(trainingHost, input);

    // Both factories are invoked lazily by Train: one builds the trainer,
    // the other resolves the label column against the training data's schema.
    var result = TrainerEntryPointsUtils.Train<RegressionEnsembleTrainer.Arguments, CommonOutputs.RegressionOutput>(
        trainingHost,
        input,
        () =>
        {
            return new RegressionEnsembleTrainer(trainingHost, input);
        },
        () =>
        {
            var trainingSchema = input.TrainingData.Schema;
            return TrainerEntryPointsUtils.FindColumn(trainingHost, trainingSchema, input.LabelColumnName);
        });

    return result;
}