public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper owner,
                               RoleMappedSchema schema)
            {
                // Sanity-check the arguments; a feature column must be present in the role-mapped schema.
                Contracts.AssertValue(ectx);
                ectx.AssertValue(owner);
                ectx.AssertValue(schema);
                ectx.Assert(schema.Feature.HasValue);

                _ectx = ectx;
                _owner = owner;
                InputRoleMappedSchema = schema;

                // Sizes that drive the three output vector types.
                var treeCount = owner._ensemble.TrainedEnsemble.NumTrees;
                var leafCount = owner._totalLeafCount;

                // One slot per tree: the output value of each tree for a given example.
                var treeValueType = new VectorType(NumberType.Float, treeCount);

                // One slot per leaf in the whole ensemble: indicator of the leaf the example lands in, for every tree.
                var leafIdType = new VectorType(NumberType.Float, leafCount);

                // One slot per internal node in the whole ensemble: indicator of the nodes lying on the example's paths.
                // In a binary tree, #nodes = #internal + #leaf, and also #nodes = 2 * #internal + 1 (every node except
                // the root is a child of an internal node). Hence #internal = #leaf - 1 per tree, and summed over
                // the ensemble the number of internal nodes is #leaf - #trees.
                var pathIdType = new VectorType(NumberType.Float, leafCount - treeCount);

                OutputSchema = Schema.Create(new SchemaImpl(ectx, owner, treeValueType, leafIdType, pathIdType));
            }
示例#2
0
        internal TreeEnsembleFeaturizationTransformer(IHostEnvironment env, DataViewSchema inputSchema,
                                                      DataViewSchema.Column featureColumn, TreeEnsembleModelParameters modelParameters,
                                                      string treesColumnName, string leavesColumnName, string pathsColumnName) :
            base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TreeEnsembleFeaturizationTransformer)), modelParameters, inputSchema)
        {
            // Keep the feature column in detached form: a fitted transformer may later be applied to
            // IDataViews whose schemas differ from the one seen here.
            _featureDetachedColumn = new DataViewSchema.DetachedColumn(featureColumn);

            // The column found in inputSchema at featureColumn.Index must be compatible with the detached
            // column stored above (same name and type).
            var indexedColumn = inputSchema[featureColumn.Index];
            CheckFeatureColumnCompatibility(indexedColumn);

            // Remember the output column names so the transformer can be saved to a file later.
            _treesColumnName  = treesColumnName;
            _leavesColumnName = leavesColumnName;
            _pathsColumnName  = pathsColumnName;

            // Forward the same output column names to the underlying scorer through its arguments object.
            _scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments
            {
                TreesColumnName  = treesColumnName,
                LeavesColumnName = leavesColumnName,
                PathsColumnName  = pathsColumnName
            };

            // The bindable mapper carries the core computation; it can be attached to any IDataView to
            // produce a transformed IDataView.
            BindableMapper = new TreeEnsembleFeaturizerBindableMapper(env, _scorerArgs, modelParameters);

            // Build the scorer over an empty view that has the input schema.
            var roleMappedSchema = MakeFeatureRoleMappedSchema(inputSchema);
            var emptyInput       = new EmptyDataView(Host, inputSchema);
            var boundMapper      = BindableMapper.Bind(Host, roleMappedSchema);

            Scorer = new GenericScorer(Host, _scorerArgs, emptyInput, boundMapper, roleMappedSchema);
        }
示例#3
0
        private TreeEnsembleFeaturizationTransformer(IHostEnvironment host, ModelLoadContext ctx)
            : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(TreeEnsembleFeaturizationTransformer)), ctx)
        {
            // *** Binary format ***
            // <base info>
            // string: name of the feature column.
            // string: name of the column holding the trees' prediction values.
            // string: name of the column holding the trees' leaves.
            // string: name of the column holding the trees' paths.

            // Read the stored fields in exactly the order listed above.
            string storedFeatureName = ctx.LoadString();
            var trainFeatureColumn   = TrainSchema[storedFeatureName];

            _featureDetachedColumn = new DataViewSchema.DetachedColumn(trainFeatureColumn);
            _treesColumnName       = ctx.LoadStringOrNull();
            _leavesColumnName      = ctx.LoadStringOrNull();
            _pathsColumnName       = ctx.LoadStringOrNull();

            // Arguments object that tells the scorer which output column names to use.
            _scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments
            {
                TreesColumnName  = _treesColumnName,
                LeavesColumnName = _leavesColumnName,
                PathsColumnName  = _pathsColumnName
            };

            // The bindable mapper carries the core computation; it can be attached to any IDataView to
            // produce a transformed IDataView.
            BindableMapper = new TreeEnsembleFeaturizerBindableMapper(host, _scorerArgs, Model);

            // Build the scorer over an empty view that has the training schema.
            var roleMappedSchema = MakeFeatureRoleMappedSchema(TrainSchema);
            var emptyInput       = new EmptyDataView(Host, TrainSchema);
            var boundMapper      = BindableMapper.Bind(Host, roleMappedSchema);

            Scorer = new GenericScorer(Host, _scorerArgs, emptyInput, boundMapper, roleMappedSchema);
        }
示例#4
0
        public static IDataTransform CreateForEntryPoint(IHostEnvironment env, ArgumentsForEntryPoint args, IDataView input)
        {
            // Validate the environment first, then everything else through the registered host.
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("Tree Featurizer Transform");

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));
            host.CheckUserArg(args.PredictorModel != null, nameof(args.PredictorModel), "Please specify a predictor model.");

            using (var ch = host.Start("Create Tree Ensemble Scorer"))
            {
                var predictor = args.PredictorModel.Predictor;

                ch.Trace("Prepare data");
                args.PredictorModel.PrepareData(env, input, out RoleMappedData data, out var preparedPredictor);
                ch.AssertValue(data);
                ch.Assert(data.Schema.Feature.HasValue);
                ch.Assert(predictor == preparedPredictor);

                // A calibrated model wraps the actual predictor; unwrap it before inspecting its input type.
                if (predictor is CalibratedModelParametersBase <IPredictorProducing <float>, Calibrators.ICalibrator> calibrated)
                {
                    predictor = calibrated.SubModel;
                }

                // The predictor is expected to be a TreeEnsembleModelParameters, which implements IValueMapper,
                // so the cast below should succeed.
                var vm = predictor as IValueMapper;
                ch.CheckUserArg(vm != null, nameof(args.PredictorModel), "Predictor does not have compatible type");

                // The predictor's expected feature count must match the feature column of the data.
                if (data != null)
                {
                    var expectedSize = vm.InputType.GetVectorSize();
                    var actualSize   = data.Schema.Feature.Value.Type.GetVectorSize();
                    if (expectedSize != actualSize)
                    {
                        throw ch.ExceptUserArg(nameof(args.PredictorModel),
                                               "Predictor expects {0} features, but data has {1} features",
                                               expectedSize, actualSize);
                    }
                }

                // The entry point always emits the three fixed output column names.
                var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments()
                {
                    Suffix           = args.Suffix,
                    TreesColumnName  = "Trees",
                    LeavesColumnName = "Leaves",
                    PathsColumnName  = "Paths"
                };

                ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
                var bound = bindable.Bind(env, data.Schema);
                return new GenericScorer(env, scorerArgs, data.Data, bound, data.Schema);
            }
        }
                public SchemaImpl(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper parent,
                                  ColumnType treeValueColType, ColumnType leafIdColType, ColumnType pathIdColType)
                {
                    // The exception context is allowed to be null; all other arguments are required.
                    Contracts.CheckValueOrNull(ectx);
                    _ectx = ectx;
                    ectx.AssertValue(parent);
                    ectx.AssertValue(treeValueColType);
                    ectx.AssertValue(leafIdColType);
                    ectx.AssertValue(pathIdColType);

                    _parent = parent;

                    // Column names, placed at the slot assigned to each of the three outputs.
                    var columnNames = new string[3];
                    columnNames[TreeIdx] = OutputColumnNames.Trees;
                    columnNames[LeafIdx] = OutputColumnNames.Leaves;
                    columnNames[PathIdx] = OutputColumnNames.Paths;
                    _names = columnNames;

                    // Column types, stored at the same slots as the names above.
                    var columnTypes = new ColumnType[3];
                    columnTypes[TreeIdx] = treeValueColType;
                    columnTypes[LeafIdx] = leafIdColType;
                    columnTypes[PathIdx] = pathIdColType;
                    _types = columnTypes;
                }
        // Factory method for SignatureDataTransform. Builds the featurizer either from a saved model file
        // or, when no model file is given, by training with the supplied trainer.
        private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("Tree Featurizer Transform");

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));
            host.CheckUserArg(!string.IsNullOrWhiteSpace(args.TrainedModelFile) || args.Trainer != null, nameof(args.TrainedModelFile),
                              "Please specify either a trainer or an input model file.");
            host.CheckUserArg(!string.IsNullOrEmpty(args.FeatureColumn), nameof(args.FeatureColumn), "Transform needs an input features column");

            using (var ch = host.Start("Create Tree Ensemble Scorer"))
            {
                var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments()
                {
                    Suffix = args.Suffix
                };

                // No model file: train a model and score with it.
                if (string.IsNullOrWhiteSpace(args.TrainedModelFile))
                {
                    ch.AssertValue(args.Trainer);

                    ch.Trace("Creating TrainAndScoreTransform");

                    var trainScoreArgs = new TrainAndScoreTransformer.Arguments();
                    args.CopyTo(trainScoreArgs);
                    trainScoreArgs.Trainer = args.Trainer;

                    trainScoreArgs.Scorer = ComponentFactoryUtils.CreateFromFunction <IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>(
                        (e, view, mapper, trainSchema) => Create(e, scorerArgs, view, mapper, trainSchema));

                    var mapperFactory = ComponentFactoryUtils.CreateFromFunction <IPredictor, ISchemaBindableMapper>(
                        (e, trained) => new TreeEnsembleFeaturizerBindableMapper(e, scorerArgs, trained));

                    var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed);
                    var scoreXf    = TrainAndScoreTransformer.Create(host, trainScoreArgs, labelInput, mapperFactory);

                    // When AppendLabelTransform handed back the input itself, the scorer can be returned as is;
                    // otherwise the transforms are re-applied on top of the original input.
                    if (input == labelInput)
                    {
                        return scoreXf;
                    }
                    return (IDataTransform)ApplyTransformUtils.ApplyAllTransformsToData(host, scoreXf, input, labelInput);
                }

                // A model file was given; it takes precedence over any trainer.
                if (args.Trainer != null)
                {
                    ch.Warning("Both an input model and a trainer were specified. Using the model file.");
                }

                ch.Trace("Loading model");
                IPredictor predictor;
                using (Stream strm = new FileStream(args.TrainedModelFile, FileMode.Open, FileAccess.Read))
                {
                    using (var rep = RepositoryReader.Open(strm, ch))
                    {
                        ModelLoadContext.LoadModel <IPredictor, SignatureLoadModel>(host, out predictor, rep, ModelFileUtils.DirPredictor);
                    }
                }

                ch.Trace("Creating scorer");
                var data = TrainAndScoreTransformer.CreateDataFromArgs(ch, input, args);
                Contracts.Assert(data.Schema.Feature.HasValue);

                // A calibrated predictor wraps the actual one; unwrap it before inspecting its input type.
                if (predictor is CalibratedPredictorBase calibrated)
                {
                    predictor = calibrated.SubPredictor;
                }

                // The predictor is expected to be a TreeEnsembleModelParameters, which implements IValueMapper,
                // so the cast below should succeed.
                var vm = predictor as IValueMapper;
                ch.CheckUserArg(vm != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type");

                // The predictor's expected feature count must match the feature column of the data.
                var expectedSize = vm.InputType.VectorSize;
                var actualSize   = data.Schema.Feature.Value.Type.VectorSize;
                if (expectedSize != actualSize)
                {
                    throw ch.ExceptUserArg(nameof(args.TrainedModelFile),
                                           "Predictor in model file expects {0} features, but data has {1} features",
                                           expectedSize, actualSize);
                }

                ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
                var bound = bindable.Bind(env, data.Schema);
                return new GenericScorer(env, scorerArgs, input, bound, data.Schema);
            }
        }
        public void TreeEnsembleFeaturizerOutputSchemaTest()
        {
            // Ensemble shape used throughout the test, and the output sizes that follow from it.
            const int treeCount         = 10;
            const int leavesPerTree     = 5;
            const int totalLeafCount    = treeCount * leavesPerTree;   // 50
            const int totalPathIdCount  = totalLeafCount - treeCount;  // 40

            // Build the training data.
            var samples      = SamplesUtils.DatasetUtils.GenerateBinaryLabelFloatFeatureVectorFloatWeightSamples(1000).ToList();
            var trainingView = ML.Data.LoadFromEnumerable(samples);

            // Configure the tree model whose trees will back the featurizer.
            var options = new FastTreeBinaryTrainer.Options
            {
                NumberOfThreads = 1,
                NumberOfTrees   = treeCount,
                NumberOfLeaves  = leavesPerTree,
            };
            var trainer = ML.BinaryClassification.Trainers.FastTree(options);

            // Train it.
            var model = trainer.Fit(trainingView);

            // Wrap the trained trees in a featurizer mapper.
            var treeFeaturizer = new TreeEnsembleFeaturizerBindableMapper(Env, new TreeEnsembleFeaturizerBindableMapper.Arguments(), model.Model);

            // Bind(...) needs a RoleMappedSchema, so build one that names the label and feature columns.
            var roleMappedSchema = new RoleMappedSchema(trainingView.Schema,
                                                        label: nameof(SamplesUtils.DatasetUtils.BinaryLabelFloatFeatureVectorFloatWeightSample.Label),
                                                        feature: nameof(SamplesUtils.DatasetUtils.BinaryLabelFloatFeatureVectorFloatWeightSample.Features));

            // Bind and fetch the output schema under test.
            var boundMapper  = (treeFeaturizer as ISchemaBindableMapper).Bind(Env, roleMappedSchema);
            var outputSchema = boundMapper.OutputSchema;

            {
                // Column 0: per-tree output values.
                var treesColumn = outputSchema[0];
                Assert.Equal("Trees", treesColumn.Name);
                VectorDataViewType treesType = treesColumn.Type as VectorDataViewType;
                Assert.NotNull(treesType);
                Assert.Equal(NumberDataViewType.Single, treesType.ItemType);
                Assert.Equal(treeCount, treesType.Size);
                // This column carries exactly one annotation: the slot names.
                Assert.Single(treesColumn.Annotations.Schema);
                VBuffer <ReadOnlyMemory <char> > treeSlotNames = default;
                treesColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref treeSlotNames);
                Assert.Equal(treeCount, treeSlotNames.Length);
                // Spot-check the first and the last slot name.
                Assert.Equal("Tree000", treeSlotNames.GetItemOrDefault(0).ToString());
                Assert.Equal("Tree009", treeSlotNames.GetItemOrDefault(treeCount - 1).ToString());
            }

            {
                // Column 1: leaf IDs.
                var leavesColumn = outputSchema[1];
                Assert.Equal("Leaves", leavesColumn.Name);
                VectorDataViewType leavesType = leavesColumn.Type as VectorDataViewType;
                Assert.NotNull(leavesType);
                Assert.Equal(NumberDataViewType.Single, leavesType.ItemType);
                Assert.Equal(totalLeafCount, leavesType.Size);
                // This column carries two annotations: IsNormalized and SlotNames.
                Assert.Equal(2, leavesColumn.Annotations.Schema.Count);
                // IsNormalized must read back as true.
                bool leavesNormalized = false;
                leavesColumn.Annotations.GetValue(AnnotationUtils.Kinds.IsNormalized, ref leavesNormalized);
                Assert.True(leavesNormalized);
                // SlotNames must have one entry per leaf.
                VBuffer <ReadOnlyMemory <char> > leafSlotNames = default;
                leavesColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref leafSlotNames);
                Assert.Equal(totalLeafCount, leafSlotNames.Length);
                // Spot-check the first and the last slot name.
                Assert.Equal("Tree000Leaf000", leafSlotNames.GetItemOrDefault(0).ToString());
                Assert.Equal("Tree009Leaf004", leafSlotNames.GetItemOrDefault(totalLeafCount - 1).ToString());
            }

            {
                // Column 2: path IDs.
                var pathsColumn = outputSchema[2];
                Assert.Equal("Paths", pathsColumn.Name);
                VectorDataViewType pathsType = pathsColumn.Type as VectorDataViewType;
                Assert.NotNull(pathsType);
                Assert.Equal(NumberDataViewType.Single, pathsType.ItemType);
                Assert.Equal(totalPathIdCount, pathsType.Size);
                // This column carries two annotations: IsNormalized and SlotNames.
                Assert.Equal(2, pathsColumn.Annotations.Schema.Count);
                // IsNormalized must read back as true.
                bool pathsNormalized = false;
                pathsColumn.Annotations.GetValue(AnnotationUtils.Kinds.IsNormalized, ref pathsNormalized);
                Assert.True(pathsNormalized);
                // SlotNames must have one entry per path ID.
                VBuffer <ReadOnlyMemory <char> > pathSlotNames = default;
                pathsColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref pathSlotNames);
                Assert.Equal(totalPathIdCount, pathSlotNames.Length);
                // Spot-check the first and the last slot name.
                Assert.Equal("Tree000Node000", pathSlotNames.GetItemOrDefault(0).ToString());
                Assert.Equal("Tree009Node003", pathSlotNames.GetItemOrDefault(totalPathIdCount - 1).ToString());
            }
        }
            public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper owner,
                               RoleMappedSchema schema)
            {
                // Sanity-check the arguments; a feature column must be present in the role-mapped schema.
                Contracts.AssertValue(ectx);
                ectx.AssertValue(owner);
                ectx.AssertValue(schema);
                ectx.Assert(schema.Feature.HasValue);

                _ectx = ectx;
                _owner = owner;
                InputRoleMappedSchema = schema;

                // Sizes that drive the three output vector types.
                var treeCount = owner._ensemble.TrainedEnsemble.NumTrees;
                var leafCount = owner._totalLeafCount;

                // One slot per tree: the output value of each tree for a given example.
                var treeValueType = new VectorType(NumberType.Float, treeCount);

                // One slot per leaf in the whole ensemble: indicator of the leaf the example lands in, for every tree.
                var leafIdType = new VectorType(NumberType.Float, leafCount);

                // One slot per internal node in the whole ensemble: indicator of the nodes lying on the example's paths.
                // In a binary tree, #nodes = #internal + #leaf, and also #nodes = 2 * #internal + 1 (every node except
                // the root is a child of an internal node). Hence #internal = #leaf - 1 per tree, and summed over
                // the ensemble the number of internal nodes is #leaf - #trees.
                var pathIdType = new VectorType(NumberType.Float, leafCount - treeCount);

                // Tree values: slot names only.
                var treeValueMetadata = new MetadataBuilder();
                treeValueMetadata.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(treeValueType.Size),
                                      (ValueGetter <VBuffer <ReadOnlyMemory <char> > >)owner.GetTreeSlotNames);

                // Leaf IDs: slot names plus an IsNormalized flag that always reads true.
                var leafIdMetadata = new MetadataBuilder();
                leafIdMetadata.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(leafIdType.Size),
                                   (ValueGetter <VBuffer <ReadOnlyMemory <char> > >)owner.GetLeafSlotNames);
                leafIdMetadata.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, (ref bool value) => value = true);

                // Path IDs: slot names plus an IsNormalized flag that always reads true.
                var pathIdMetadata = new MetadataBuilder();
                pathIdMetadata.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(pathIdType.Size),
                                   (ValueGetter <VBuffer <ReadOnlyMemory <char> > >)owner.GetPathSlotNames);
                pathIdMetadata.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, (ref bool value) => value = true);

                // Assemble the output schema. The columns are added in the order trees, leaves, paths;
                // the assertions at the end verify the resulting indices.
                var schemaBuilder = new SchemaBuilder();
                schemaBuilder.AddColumn(OutputColumnNames.Trees, treeValueType, treeValueMetadata.GetMetadata());
                schemaBuilder.AddColumn(OutputColumnNames.Leaves, leafIdType, leafIdMetadata.GetMetadata());
                schemaBuilder.AddColumn(OutputColumnNames.Paths, pathIdType, pathIdMetadata.GetMetadata());

                OutputSchema = schemaBuilder.GetSchema();

                // Verify the column order: tree values first, leaf IDs second, path IDs third.
                Contracts.Assert(OutputSchema[OutputColumnNames.Trees].Index == TreeValuesColumnId);
                Contracts.Assert(OutputSchema[OutputColumnNames.Leaves].Index == LeafIdsColumnId);
                Contracts.Assert(OutputSchema[OutputColumnNames.Paths].Index == PathIdsColumnId);
            }