/// <summary>
/// Builds a featurizer around an already-trained tree ensemble.
/// </summary>
/// <param name="env">Host environment; a child host named after this transformer is registered from it.</param>
/// <param name="inputSchema">Schema the transformer is created against; it is also used to build the internal scorer.</param>
/// <param name="featureColumn">Column of <paramref name="inputSchema"/> whose values are fed into the trees.</param>
/// <param name="modelParameters">The trained tree ensemble.</param>
/// <param name="treesColumnName">Name of the output column for the trees' prediction values.</param>
/// <param name="leavesColumnName">Name of the output column for the trees' leaves.</param>
/// <param name="pathsColumnName">Name of the output column for the trees' paths.</param>
internal TreeEnsembleFeaturizationTransformer(IHostEnvironment env, DataViewSchema inputSchema, DataViewSchema.Column featureColumn,
    TreeEnsembleModelParameters modelParameters, string treesColumnName, string leavesColumnName, string pathsColumnName)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TreeEnsembleFeaturizationTransformer)), modelParameters, inputSchema)
{
    // A fitted transformer may later be applied to IDataViews whose schemas differ from inputSchema,
    // so keep a detached (schema-independent) copy of the feature column instead of the column itself.
    _featureDetachedColumn = new DataViewSchema.DetachedColumn(featureColumn);

    // The column sitting at featureColumn.Index in inputSchema must agree (same name and type)
    // with the detached copy stored above.
    var indexedFeatureColumn = inputSchema[featureColumn.Index];
    CheckFeatureColumnCompatibility(indexedFeatureColumn);

    // Keep the output column names so the transformer can be saved to a file later.
    _treesColumnName = treesColumnName;
    _leavesColumnName = leavesColumnName;
    _pathsColumnName = pathsColumnName;

    // These arguments carry the output column names down to the underlying scorer.
    _scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments
    {
        TreesColumnName = treesColumnName,
        LeavesColumnName = leavesColumnName,
        PathsColumnName = pathsColumnName
    };

    // The bindable mapper performs the core computation; it can be bound to any compatible IDataView.
    BindableMapper = new TreeEnsembleFeaturizerBindableMapper(env, _scorerArgs, modelParameters);

    // Build the scorer over a schema-only (empty) view of inputSchema.
    var featureRoles = MakeFeatureRoleMappedSchema(inputSchema);
    var schemaOnlyView = new EmptyDataView(Host, inputSchema);
    var boundMapper = BindableMapper.Bind(Host, featureRoles);
    Scorer = new GenericScorer(Host, _scorerArgs, schemaOnlyView, boundMapper, featureRoles);
}
/// <summary>
/// Restores a featurizer from a saved model.
/// </summary>
/// <param name="host">Host environment; a child host named after this transformer is registered from it.</param>
/// <param name="ctx">Load context positioned at this transformer's serialized data.</param>
private TreeEnsembleFeaturizationTransformer(IHostEnvironment host, ModelLoadContext ctx)
    : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(TreeEnsembleFeaturizationTransformer)), ctx)
{
    // *** Binary format ***
    // <base info>
    // string: name of the feature column.
    // string (may be null): name of the column holding the trees' prediction values.
    // string (may be null): name of the column holding the trees' leaves.
    // string (may be null): name of the column holding the trees' paths.

    // The feature column is resolved by name against TrainSchema and stored in detached form.
    var featureColumnName = ctx.LoadString();
    var featureColumn = TrainSchema[featureColumnName];
    _featureDetachedColumn = new DataViewSchema.DetachedColumn(featureColumn);

    _treesColumnName = ctx.LoadStringOrNull();
    _leavesColumnName = ctx.LoadStringOrNull();
    _pathsColumnName = ctx.LoadStringOrNull();

    // These arguments carry the output column names down to the underlying scorer.
    _scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments
    {
        TreesColumnName = _treesColumnName,
        LeavesColumnName = _leavesColumnName,
        PathsColumnName = _pathsColumnName
    };

    // The bindable mapper performs the core computation; it can be bound to any compatible IDataView.
    BindableMapper = new TreeEnsembleFeaturizerBindableMapper(host, _scorerArgs, Model);

    // Build the scorer over a schema-only (empty) view of the training schema.
    var trainRoles = MakeFeatureRoleMappedSchema(TrainSchema);
    var schemaOnlyView = new EmptyDataView(Host, TrainSchema);
    var boundMapper = BindableMapper.Bind(Host, trainRoles);
    Scorer = new GenericScorer(Host, _scorerArgs, schemaOnlyView, boundMapper, trainRoles);
}
/// <summary>
/// Factory method for SignatureDataTransform. When <c>args.TrainedModelFile</c> is set, the predictor is loaded
/// from that file and wrapped in a scorer over <paramref name="input"/>; otherwise <c>args.Trainer</c> is used
/// to train a model through <see cref="TrainAndScoreTransformer"/>.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="args">Transform arguments; at least one of the model file or the trainer must be given.</param>
/// <param name="input">Data to featurize.</param>
/// <returns>The transform producing the tree features.</returns>
private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("Tree Featurizer Transform");

    host.CheckValue(args, nameof(args));
    host.CheckValue(input, nameof(input));
    host.CheckUserArg(!string.IsNullOrWhiteSpace(args.TrainedModelFile) || args.Trainer != null,
        nameof(args.TrainedModelFile), "Please specify either a trainer or an input model file.");
    host.CheckUserArg(!string.IsNullOrEmpty(args.FeatureColumn),
        nameof(args.FeatureColumn), "Transform needs an input features column");

    using (var ch = host.Start("Create Tree Ensemble Scorer"))
    {
        var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments() { Suffix = args.Suffix };

        // A model file, when supplied, takes precedence over a trainer.
        bool useModelFile = !string.IsNullOrWhiteSpace(args.TrainedModelFile);
        if (useModelFile)
        {
            if (args.Trainer != null)
            {
                ch.Warning("Both an input model and a trainer were specified. Using the model file.");
            }

            ch.Trace("Loading model");
            IPredictor loadedPredictor;
            using (Stream modelStream = new FileStream(args.TrainedModelFile, FileMode.Open, FileAccess.Read))
            using (var repository = RepositoryReader.Open(modelStream, ch))
            {
                ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(host, out loadedPredictor, repository, ModelFileUtils.DirPredictor);
            }

            ch.Trace("Creating scorer");
            var roleMappedData = TrainAndScoreTransformer.CreateDataFromArgs(ch, input, args);
            Contracts.Assert(roleMappedData.Schema.Feature.HasValue);

            // Unwrap a calibrated predictor so the feature-count check below runs against the underlying model.
            if (loadedPredictor is CalibratedPredictorBase calibrated)
            {
                loadedPredictor = calibrated.SubPredictor;
            }

            // TreeEnsembleModelParameters implements IValueMapper, so a null here signals an incompatible model.
            var valueMapper = loadedPredictor as IValueMapper;
            ch.CheckUserArg(valueMapper != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type");

            var expectedFeatureCount = valueMapper.InputType.VectorSize;
            var actualFeatureCount = roleMappedData.Schema.Feature.Value.Type.VectorSize;
            if (expectedFeatureCount != actualFeatureCount)
            {
                throw ch.ExceptUserArg(nameof(args.TrainedModelFile),
                    "Predictor in model file expects {0} features, but data has {1} features",
                    expectedFeatureCount, actualFeatureCount);
            }

            ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, loadedPredictor);
            var bound = bindable.Bind(env, roleMappedData.Schema);
            return new GenericScorer(env, scorerArgs, input, bound, roleMappedData.Schema);
        }

        // No model file: train one with the supplied trainer.
        ch.AssertValue(args.Trainer);
        ch.Trace("Creating TrainAndScoreTransform");

        var trainScoreArgs = new TrainAndScoreTransformer.Arguments();
        args.CopyTo(trainScoreArgs);
        trainScoreArgs.Trainer = args.Trainer;
        trainScoreArgs.Scorer = ComponentFactoryUtils.CreateFromFunction<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>(
            (scorerEnv, scoredData, boundMapper, trainSchema) => Create(scorerEnv, scorerArgs, scoredData, boundMapper, trainSchema));

        var mapperFactory = ComponentFactoryUtils.CreateFromFunction<IPredictor, ISchemaBindableMapper>(
            (mapperEnv, trainedPredictor) => new TreeEnsembleFeaturizerBindableMapper(mapperEnv, scorerArgs, trainedPredictor));

        var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed);
        var scoreXf = TrainAndScoreTransformer.Create(host, trainScoreArgs, labelInput, mapperFactory);

        if (input == labelInput)
        {
            return scoreXf;
        }

        // A label transform was appended, so re-root the resulting transform chain from labelInput onto input.
        return (IDataTransform)ApplyTransformUtils.ApplyAllTransformsToData(host, scoreXf, input, labelInput);
    }
}