/// <summary>
/// Creates a tree-ensemble featurization transformer from an in-memory model.
/// </summary>
/// <param name="env">Host environment; checked for null before this transformer's host is registered.</param>
/// <param name="inputSchema">Schema passed to the base constructor and used to build the scorer's role-mapped schema.</param>
/// <param name="featureColumn">Feature column; stored in detached form and looked up by index in <paramref name="inputSchema"/>.</param>
/// <param name="modelParameters">Tree ensemble passed to the base constructor and wrapped by the bindable mapper.</param>
/// <param name="treesColumnName">Value forwarded to the scorer argument <c>TreesColumnName</c>.</param>
/// <param name="leavesColumnName">Value forwarded to the scorer argument <c>LeavesColumnName</c>.</param>
/// <param name="pathsColumnName">Value forwarded to the scorer argument <c>PathsColumnName</c>.</param>
internal TreeEnsembleFeaturizationTransformer(
    IHostEnvironment env,
    DataViewSchema inputSchema,
    DataViewSchema.Column featureColumn,
    TreeEnsembleModelParameters modelParameters,
    string treesColumnName,
    string leavesColumnName,
    string pathsColumnName)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TreeEnsembleFeaturizationTransformer)), modelParameters, inputSchema)
{
    // A fitted transformer can be applied to different IDataViews, and those may carry different schemas,
    // so the feature column is kept in detached form.
    _featureDetachedColumn = new DataViewSchema.DetachedColumn(featureColumn);

    // The column indexed in inputSchema must match the detached column (same name and type).
    var indexedFeatureColumn = inputSchema[featureColumn.Index];
    CheckFeatureColumnCompatibility(indexedFeatureColumn);

    // Keep the output column names so that this transformer can be saved into a file later.
    _treesColumnName = treesColumnName;
    _leavesColumnName = leavesColumnName;
    _pathsColumnName = pathsColumnName;

    // The underlying scorer receives the output column names through this argument object.
    var scorerArguments = new TreeEnsembleFeaturizerBindableMapper.Arguments();
    scorerArguments.TreesColumnName = _treesColumnName;
    scorerArguments.LeavesColumnName = _leavesColumnName;
    scorerArguments.PathsColumnName = _pathsColumnName;
    _scorerArgs = scorerArguments;

    // The bindable mapper provides the core computation; it can be attached to any IDataView
    // to produce a transformed IDataView.
    BindableMapper = new TreeEnsembleFeaturizerBindableMapper(env, _scorerArgs, modelParameters);

    // Build the scorer over an empty view that only carries the input schema.
    var roleMappedSchema = MakeFeatureRoleMappedSchema(inputSchema);
    var emptyInput = new EmptyDataView(Host, inputSchema);
    var boundMapper = BindableMapper.Bind(Host, roleMappedSchema);
    Scorer = new GenericScorer(Host, _scorerArgs, emptyInput, boundMapper, roleMappedSchema);
}
/// <summary>
/// Re-creates a tree-ensemble featurization transformer from a model load context.
/// </summary>
/// <param name="host">Host environment; checked for null before this transformer's host is registered.</param>
/// <param name="ctx">Load context passed to the base constructor and then read for this class's own fields.</param>
private TreeEnsembleFeaturizationTransformer(IHostEnvironment host, ModelLoadContext ctx)
    : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(TreeEnsembleFeaturizationTransformer)), ctx)
{
    // *** Binary format ***
    // <base info>
    // string: feature column's name.
    // string: the name of the column where tree prediction values are stored.
    // string: the name of the column where trees' leaves are stored.
    // string: the name of the column where trees' paths are stored.

    // Read the stored fields in the order listed above.
    string storedFeatureName = ctx.LoadString();
    var trainedFeatureColumn = TrainSchema[storedFeatureName];
    _featureDetachedColumn = new DataViewSchema.DetachedColumn(trainedFeatureColumn);

    _treesColumnName = ctx.LoadStringOrNull();
    _leavesColumnName = ctx.LoadStringOrNull();
    _pathsColumnName = ctx.LoadStringOrNull();

    // Argument object that tells the scorer the names of this transformer's output columns.
    var scorerArguments = new TreeEnsembleFeaturizerBindableMapper.Arguments();
    scorerArguments.TreesColumnName = _treesColumnName;
    scorerArguments.LeavesColumnName = _leavesColumnName;
    scorerArguments.PathsColumnName = _pathsColumnName;
    _scorerArgs = scorerArguments;

    // The bindable mapper provides the core computation; it can be attached to any IDataView
    // to produce a transformed IDataView.
    BindableMapper = new TreeEnsembleFeaturizerBindableMapper(host, _scorerArgs, Model);

    // Build the scorer over an empty view that only carries the training schema.
    var roleMappedSchema = MakeFeatureRoleMappedSchema(TrainSchema);
    var emptyInput = new EmptyDataView(Host, TrainSchema);
    var boundMapper = BindableMapper.Bind(Host, roleMappedSchema);
    Scorer = new GenericScorer(Host, _scorerArgs, emptyInput, boundMapper, roleMappedSchema);
}
// Factory method for SignatureBindableMapper.
// Validates all three arguments, then wraps the predictor in a tree-ensemble featurizer bindable mapper.
private static ISchemaBindableMapper Create(IHostEnvironment env, TreeEnsembleFeaturizerBindableMapper.Arguments args, IPredictor predictor)
{
    // env is validated through the static helper first; once known to be non-null it validates the rest.
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(args, nameof(args));
    env.CheckValue(predictor, nameof(predictor));

    ISchemaBindableMapper featurizerMapper = new TreeEnsembleFeaturizerBindableMapper(env, args, predictor);
    return featurizerMapper;
}
/// <summary>
/// Entry-point factory: builds a <c>GenericScorer</c> that applies the tree featurizer derived from
/// <c>args.PredictorModel</c> to <paramref name="input"/>.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="args">Entry-point arguments; <c>PredictorModel</c> must be set and <c>Suffix</c> is forwarded to the scorer arguments.</param>
/// <param name="input">Data view handed to <c>PredictorModel.PrepareData</c>; must not be null.</param>
/// <returns>The scorer transform created over the prepared data.</returns>
public static IDataTransform CreateForEntryPoint(IHostEnvironment env, ArgumentsForEntryPoint args, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("Tree Featurizer Transform");
    host.CheckValue(args, nameof(args));
    host.CheckValue(input, nameof(input));
    host.CheckUserArg(args.PredictorModel != null, nameof(args.PredictorModel), "Please specify a predictor model.");

    IDataTransform xf;
    using (var ch = host.Start("Create Tree Ensemble Scorer"))
    {
        var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments() { Suffix = args.Suffix };
        var predictor = args.PredictorModel.Predictor;

        ch.Trace("Prepare data");
        args.PredictorModel.PrepareData(env, input, out RoleMappedData data, out var predictor2);
        ch.AssertValue(data);
        ch.Assert(predictor == predictor2);

        // For a calibrated predictor, the wrapped sub-predictor is the one that gets featurized.
        // A single declaration pattern replaces the former type test followed by an explicit cast.
        if (predictor is CalibratedPredictorBase calibrated)
        {
            predictor = calibrated.SubPredictor;
        }

        // Predictor should be a FastTreePredictionWrapper, which implements IValueMapper, so this should
        // be non-null.
        var vm = predictor as IValueMapper;
        ch.CheckUserArg(vm != null, nameof(args.PredictorModel), "Predictor does not have compatible type");

        // Make sure that the given predictor has the correct number of input features.
        if (data != null)
        {
            var expectedSize = vm.InputType.VectorSize;
            var actualSize = data.Schema.Feature.Type.VectorSize;
            if (expectedSize != actualSize)
            {
                throw ch.ExceptUserArg(nameof(args.PredictorModel),
                    "Predictor expects {0} features, but data has {1} features",
                    expectedSize, actualSize);
            }
        }

        var bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
        var bound = bindable.Bind(env, data.Schema);
        xf = new GenericScorer(env, scorerArgs, data.Data, bound, data.Schema);
        ch.Done();
    }

    return xf;
}
/// <summary>
/// Entry-point factory: builds a <c>GenericScorer</c> that applies the tree featurizer derived from
/// <c>args.PredictorModel</c> to <paramref name="input"/>, writing to columns named "Trees", "Leaves" and "Paths".
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="args">Entry-point arguments; <c>PredictorModel</c> must be set and <c>Suffix</c> is forwarded to the scorer arguments.</param>
/// <param name="input">Data view handed to <c>PredictorModel.PrepareData</c>; must not be null.</param>
/// <returns>The scorer transform created over the prepared data.</returns>
public static IDataTransform CreateForEntryPoint(IHostEnvironment env, ArgumentsForEntryPoint args, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("Tree Featurizer Transform");
    host.CheckValue(args, nameof(args));
    host.CheckValue(input, nameof(input));
    host.CheckUserArg(args.PredictorModel != null, nameof(args.PredictorModel), "Please specify a predictor model.");

    using (var ch = host.Start("Create Tree Ensemble Scorer"))
    {
        var featurizerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments();
        featurizerArgs.Suffix = args.Suffix;
        featurizerArgs.TreesColumnName = "Trees";
        featurizerArgs.LeavesColumnName = "Leaves";
        featurizerArgs.PathsColumnName = "Paths";

        var predictor = args.PredictorModel.Predictor;

        ch.Trace("Prepare data");
        args.PredictorModel.PrepareData(env, input, out RoleMappedData data, out var preparedPredictor);
        ch.AssertValue(data);
        ch.Assert(data.Schema.Feature.HasValue);
        ch.Assert(predictor == preparedPredictor);

        // For a calibrated model, the wrapped sub-model is the one that gets featurized.
        if (predictor is CalibratedModelParametersBase<IPredictorProducing<float>, Calibrators.ICalibrator> calibrated)
        {
            predictor = calibrated.SubModel;
        }

        // Predictor should be a TreeEnsembleModelParameters, which implements IValueMapper, so this should
        // be non-null.
        var valueMapper = predictor as IValueMapper;
        ch.CheckUserArg(valueMapper != null, nameof(args.PredictorModel), "Predictor does not have compatible type");

        // Make sure that the given predictor has the correct number of input features.
        if (data != null)
        {
            var expectedSize = valueMapper.InputType.GetVectorSize();
            var actualSize = data.Schema.Feature.Value.Type.GetVectorSize();
            if (expectedSize != actualSize)
            {
                throw ch.ExceptUserArg(nameof(args.PredictorModel),
                    "Predictor expects {0} features, but data has {1} features",
                    expectedSize, actualSize);
            }
        }

        // Bind is reached through the interface, hence the explicitly typed local.
        ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, featurizerArgs, predictor);
        var boundMapper = bindable.Bind(env, data.Schema);
        return new GenericScorer(env, featurizerArgs, data.Data, boundMapper, data.Schema);
    }
}
// Factory method for SignatureDataTransform.
// Two modes, chosen by args:
//   * args.TrainedModelFile is set: load the predictor from that file and score `input` with the tree featurizer.
//   * otherwise: args.Trainer is used through a TrainAndScoreTransformer whose scorer is the tree featurizer.
// When both are given, the model file wins and a warning is logged.
private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("Tree Featurizer Transform");
    host.CheckValue(args, nameof(args));
    host.CheckValue(input, nameof(input));
    host.CheckUserArg(!string.IsNullOrWhiteSpace(args.TrainedModelFile) || args.Trainer != null,
        nameof(args.TrainedModelFile), "Please specify either a trainer or an input model file.");
    host.CheckUserArg(!string.IsNullOrEmpty(args.FeatureColumn), nameof(args.FeatureColumn),
        "Transform needs an input features column");

    using (var ch = host.Start("Create Tree Ensemble Scorer"))
    {
        var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments() { Suffix = args.Suffix };

        if (!string.IsNullOrWhiteSpace(args.TrainedModelFile))
        {
            if (args.Trainer != null)
            {
                ch.Warning("Both an input model and a trainer were specified. Using the model file.");
            }

            ch.Trace("Loading model");
            IPredictor predictor;
            using (Stream strm = new FileStream(args.TrainedModelFile, FileMode.Open, FileAccess.Read))
            using (var rep = RepositoryReader.Open(strm, ch))
            {
                ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(host, out predictor, rep, ModelFileUtils.DirPredictor);
            }

            ch.Trace("Creating scorer");
            var data = TrainAndScoreTransformer.CreateDataFromArgs(ch, input, args);
            // Asserted through the channel, like the other assertions made while a channel is open.
            ch.Assert(data.Schema.Feature.HasValue);

            // For a calibrated predictor, the wrapped sub-predictor is the one that gets featurized.
            // A single declaration pattern replaces the former type test followed by an explicit cast.
            if (predictor is CalibratedPredictorBase calibrated)
            {
                predictor = calibrated.SubPredictor;
            }

            // Predictor should be a TreeEnsembleModelParameters, which implements IValueMapper, so this should
            // be non-null.
            var vm = predictor as IValueMapper;
            ch.CheckUserArg(vm != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type");

            // Make sure that the given predictor has the correct number of input features.
            var expectedSize = vm.InputType.VectorSize;
            var actualSize = data.Schema.Feature.Value.Type.VectorSize;
            if (expectedSize != actualSize)
            {
                throw ch.ExceptUserArg(nameof(args.TrainedModelFile),
                    "Predictor in model file expects {0} features, but data has {1} features",
                    expectedSize, actualSize);
            }

            ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
            var bound = bindable.Bind(env, data.Schema);
            return new GenericScorer(env, scorerArgs, input, bound, data.Schema);
        }
        else
        {
            ch.AssertValue(args.Trainer);
            ch.Trace("Creating TrainAndScoreTransform");

            var trainScoreArgs = new TrainAndScoreTransformer.Arguments();
            args.CopyTo(trainScoreArgs);
            trainScoreArgs.Trainer = args.Trainer;

            // The scorer factory delegates to the SignatureDataScorer Create overload with the shared scorerArgs.
            trainScoreArgs.Scorer = ComponentFactoryUtils.CreateFromFunction<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>(
                (e, view, mapper, trainSchema) => Create(e, scorerArgs, view, mapper, trainSchema));

            var mapperFactory = ComponentFactoryUtils.CreateFromFunction<IPredictor, ISchemaBindableMapper>(
                (e, trained) => new TreeEnsembleFeaturizerBindableMapper(e, scorerArgs, trained));

            var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed);
            var scoreXf = TrainAndScoreTransformer.Create(host, trainScoreArgs, labelInput, mapperFactory);

            // If AppendLabelTransform handed back the very same view, the score transform is returned as is.
            if (input == labelInput)
            {
                return scoreXf;
            }

            return (IDataTransform)ApplyTransformUtils.ApplyAllTransformsToData(host, scoreXf, input, labelInput);
        }
    }
}
// Factory method for SignatureDataScorer.
// Forwards every argument unchanged to the GenericScorer constructor; no validation happens here.
private static IDataScorerTransform Create(
    IHostEnvironment env,
    TreeEnsembleFeaturizerBindableMapper.Arguments args,
    IDataView data,
    ISchemaBoundMapper mapper,
    RoleMappedSchema trainSchema)
    => new GenericScorer(env, args, data, mapper, trainSchema);
// Trains a small FastTree binary classifier, binds a tree featurizer mapper built from it, and verifies
// the name, type, size and annotations of the three output columns (trees, leaves, paths).
public void TreeEnsembleFeaturizerOutputSchemaTest()
{
    // Training sizes and the vector lengths this test expects for them.
    const int treeCount = 10;
    const int leavesPerTree = 5;
    const int expectedLeafSlots = 50;
    const int expectedPathSlots = 40;

    // Output column names requested from the featurizer.
    const string treesColumnName = "MyTrees";
    const string leavesColumnName = "MyLeaves";
    const string pathsColumnName = "MyPaths";

    // Create data set.
    var samples = SamplesUtils.DatasetUtils.GenerateBinaryLabelFloatFeatureVectorFloatWeightSamples(1000).ToList();
    var dataView = ML.Data.LoadFromEnumerable(samples);

    // Define the tree model whose trees will be extracted to construct a tree featurizer, then train it.
    var trainerOptions = new FastTreeBinaryTrainer.Options
    {
        NumberOfThreads = 1,
        NumberOfTrees = treeCount,
        NumberOfLeaves = leavesPerTree,
    };
    var trainer = ML.BinaryClassification.Trainers.FastTree(trainerOptions);
    var model = trainer.Fit(dataView);

    // From the trained tree model, a mapper of tree featurizer is created.
    var featurizerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments()
    {
        TreesColumnName = treesColumnName,
        LeavesColumnName = leavesColumnName,
        PathsColumnName = pathsColumnName
    };
    var treeFeaturizer = new TreeEnsembleFeaturizerBindableMapper(Env, featurizerArgs, model.Model);

    // To get output schema, we need to create RoleMappedSchema for calling Bind(...).
    var roleMappedSchema = new RoleMappedSchema(
        dataView.Schema,
        label: nameof(SamplesUtils.DatasetUtils.BinaryLabelFloatFeatureVectorFloatWeightSample.Label),
        feature: nameof(SamplesUtils.DatasetUtils.BinaryLabelFloatFeatureVectorFloatWeightSample.Features));

    // Retrieve output schema through the ISchemaBindableMapper interface.
    ISchemaBindableMapper bindable = treeFeaturizer;
    var boundMapper = bindable.Bind(Env, roleMappedSchema);
    var outputSchema = boundMapper.OutputSchema;

    // ---- Column 0: tree prediction values ----
    var treesColumn = outputSchema[0];
    Assert.Equal(treesColumnName, treesColumn.Name);
    var treesType = treesColumn.Type as VectorDataViewType;
    Assert.NotNull(treesType);
    Assert.Equal(NumberDataViewType.Single, treesType.ItemType);
    Assert.Equal(treeCount, treesType.Size);

    // This column carries exactly one annotation: its slot names.
    Assert.Single(treesColumn.Annotations.Schema);
    VBuffer<ReadOnlyMemory<char>> treesSlotNames = default;
    treesColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref treesSlotNames);
    Assert.Equal(treeCount, treesSlotNames.Length);
    // Only the first and last slot names are checked.
    Assert.Equal("Tree000", treesSlotNames.GetItemOrDefault(0).ToString());
    Assert.Equal("Tree009", treesSlotNames.GetItemOrDefault(9).ToString());

    // ---- Column 1: tree leaf IDs ----
    var leavesColumn = outputSchema[1];
    Assert.Equal(leavesColumnName, leavesColumn.Name);
    var leavesType = leavesColumn.Type as VectorDataViewType;
    Assert.NotNull(leavesType);
    Assert.Equal(NumberDataViewType.Single, leavesType.ItemType);
    Assert.Equal(expectedLeafSlots, leavesType.Size);

    // Two annotations are expected on this column: IsNormalized and SlotNames.
    Assert.Equal(2, leavesColumn.Annotations.Schema.Count);

    bool leavesNormalized = false;
    leavesColumn.Annotations.GetValue(AnnotationUtils.Kinds.IsNormalized, ref leavesNormalized);
    Assert.True(leavesNormalized);

    VBuffer<ReadOnlyMemory<char>> leavesSlotNames = default;
    leavesColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref leavesSlotNames);
    Assert.Equal(expectedLeafSlots, leavesSlotNames.Length);
    // Only the first and last slot names are checked.
    Assert.Equal("Tree000Leaf000", leavesSlotNames.GetItemOrDefault(0).ToString());
    Assert.Equal("Tree009Leaf004", leavesSlotNames.GetItemOrDefault(49).ToString());

    // ---- Column 2: path IDs ----
    var pathsColumn = outputSchema[2];
    Assert.Equal(pathsColumnName, pathsColumn.Name);
    var pathsType = pathsColumn.Type as VectorDataViewType;
    Assert.NotNull(pathsType);
    Assert.Equal(NumberDataViewType.Single, pathsType.ItemType);
    Assert.Equal(expectedPathSlots, pathsType.Size);

    // Two annotations are expected on this column: IsNormalized and SlotNames.
    Assert.Equal(2, pathsColumn.Annotations.Schema.Count);

    bool pathsNormalized = false;
    pathsColumn.Annotations.GetValue(AnnotationUtils.Kinds.IsNormalized, ref pathsNormalized);
    Assert.True(pathsNormalized);

    VBuffer<ReadOnlyMemory<char>> pathsSlotNames = default;
    pathsColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref pathsSlotNames);
    Assert.Equal(expectedPathSlots, pathsSlotNames.Length);
    // Only the first and last slot names are checked.
    Assert.Equal("Tree000Node000", pathsSlotNames.GetItemOrDefault(0).ToString());
    Assert.Equal("Tree009Node003", pathsSlotNames.GetItemOrDefault(39).ToString());
}