示例#1
0
        /// <summary>
        /// Assembles the (column role, column name) pairs describing the input columns.
        /// </summary>
        /// <param name="schema">The role-mapped schema; validated as non-null through the host.</param>
        /// <param name="needStrat">When true and <c>StratCols</c> is not null, one pair per stratification column is included.</param>
        /// <param name="needName">When true and the schema has a name column, a name pair is added through <c>AnnotationUtils.Prepend</c>.</param>
        /// <returns>The optional pairs above, followed by whatever <c>GetInputColumnRolesCore</c> yields.</returns>
        private protected IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles(RoleMappedSchema schema, bool needStrat = false, bool needName = false)
        {
            Host.CheckValue(schema, nameof(schema));

            // Start from the stratification columns when they were requested and are available.
            IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> pairs;
            if (needStrat && StratCols != null)
            {
                pairs = StratCols.Select(stratCol => RoleMappedSchema.CreatePair(Strat, stratCol));
            }
            else
            {
                pairs = Enumerable.Empty<KeyValuePair<RoleMappedSchema.ColumnRole, string>>();
            }

            // Optionally add the name column of the schema.
            if (needName && schema.Name.HasValue)
            {
                var namePair = RoleMappedSchema.ColumnRole.Name.Bind(schema.Name.Value.Name);
                pairs = AnnotationUtils.Prepend(pairs, namePair);
            }

            // The roles supplied by the core implementation always come last.
            var corePairs = GetInputColumnRolesCore(schema);
            return pairs.Concat(corePairs);
        }
        /// <summary>
        /// Fetches the annotation of the given kind for the column at index <paramref name="iinfo"/>.
        /// The score-column-set-id kind is served by <c>_getScoreColumnSetId</c>; every other kind is
        /// forwarded to the mapper's output schema, except for the first <c>DerivedColumnCount</c>
        /// columns, for which the annotation-not-found exception is thrown.
        /// </summary>
        /// <typeparam name="TValue">Type of the annotation value.</typeparam>
        /// <param name="kind">The annotation kind being requested.</param>
        /// <param name="iinfo">Column index; asserted to be in <c>[0, InfoCount)</c>.</param>
        /// <param name="value">Receives the annotation value.</param>
        protected override void GetAnnotationCore<TValue>(string kind, int iinfo, ref TValue value)
        {
            Contracts.Assert(0 <= iinfo && iinfo < InfoCount);

            if (kind == AnnotationUtils.Kinds.ScoreColumnSetId)
            {
                _getScoreColumnSetId.Marshal(iinfo, ref value);
                return;
            }

            // Any other kind is not available on the derived columns.
            if (iinfo < DerivedColumnCount)
            {
                throw AnnotationUtils.ExceptGetAnnotation();
            }

            // The remaining columns map onto the mapper's output schema, shifted by the derived column count.
            var mapperColumn = Mapper.OutputSchema[iinfo - DerivedColumnCount];
            mapperColumn.Annotations.GetValue(kind, ref value);
        }
            /// <summary>
            /// Creates a mapper bound to <paramref name="schema"/> and builds its output schema.
            /// A column is added for each non-null column name, in the order trees, leaves, paths.
            /// </summary>
            /// <param name="ectx">Exception context used for the assertions below.</param>
            /// <param name="owner">The bindable mapper supplying the ensemble, leaf count and slot-name getters.</param>
            /// <param name="schema">Input role-mapped schema; asserted to have a feature column.</param>
            /// <param name="treesColumnName">Name of the per-tree output column, or null to omit it.</param>
            /// <param name="leavesColumnName">Name of the leaf indicator column, or null to omit it.</param>
            /// <param name="pathsColumnName">Name of the path indicator column, or null to omit it.</param>
            public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper owner, RoleMappedSchema schema,
                               string treesColumnName, string leavesColumnName, string pathsColumnName)
            {
                Contracts.AssertValue(ectx);
                ectx.AssertValue(owner);
                ectx.AssertValue(schema);
                ectx.Assert(schema.Feature.HasValue);

                _ectx = ectx;
                _owner = owner;
                InputRoleMappedSchema = schema;
                _treesColumnName = treesColumnName;
                _leavesColumnName = leavesColumnName;
                _pathsColumnName = pathsColumnName;

                var numTrees = owner._ensemble.TrainedEnsemble.NumTrees;

                // One slot per tree: the output of each tree on a given example.
                var treeOutputType = new VectorDataViewType(NumberDataViewType.Single, numTrees);

                // One slot per leaf in the whole ensemble: an indicator of the leaf the example
                // lands in, for every tree.
                var leafIndicatorType = new VectorDataViewType(NumberDataViewType.Single, owner._totalLeafCount);

                // One slot per internal node in the whole ensemble: an indicator of the nodes lying on
                // the example's path in every tree.
                // For a single binary tree, #nodes = #internal + #leaf. Counting children instead,
                // #nodes = 2 * #internal + 1 (the root is nobody's child). Hence #internal = #leaf - 1,
                // and summed over the ensemble the internal node count is #leaf - #trees.
                var pathIndicatorType = new VectorDataViewType(NumberDataViewType.Single, owner._totalLeafCount - numTrees);

                // Assemble the output schema from the types above.
                var outputBuilder = new DataViewSchema.Builder();

                if (treesColumnName != null)
                {
                    // Annotations of the tree-output column: slot names only.
                    var treeAnnotations = new DataViewSchema.Annotations.Builder();
                    treeAnnotations.Add(
                        AnnotationUtils.Kinds.SlotNames,
                        AnnotationUtils.GetNamesType(treeOutputType.Size),
                        (ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetTreeSlotNames);

                    outputBuilder.AddColumn(treesColumnName, treeOutputType, treeAnnotations.ToAnnotations());
                }

                if (leavesColumnName != null)
                {
                    // Annotations of the leaf-indicator column: slot names plus the normalized flag.
                    var leafAnnotations = new DataViewSchema.Annotations.Builder();
                    leafAnnotations.Add(
                        AnnotationUtils.Kinds.SlotNames,
                        AnnotationUtils.GetNamesType(leafIndicatorType.Size),
                        (ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetLeafSlotNames);
                    leafAnnotations.Add(
                        AnnotationUtils.Kinds.IsNormalized,
                        BooleanDataViewType.Instance,
                        (ref bool value) => value = true);

                    outputBuilder.AddColumn(leavesColumnName, leafIndicatorType, leafAnnotations.ToAnnotations());
                }

                if (pathsColumnName != null)
                {
                    // Annotations of the path-indicator column: slot names plus the normalized flag.
                    var pathAnnotations = new DataViewSchema.Annotations.Builder();
                    pathAnnotations.Add(
                        AnnotationUtils.Kinds.SlotNames,
                        AnnotationUtils.GetNamesType(pathIndicatorType.Size),
                        (ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetPathSlotNames);
                    pathAnnotations.Add(
                        AnnotationUtils.Kinds.IsNormalized,
                        BooleanDataViewType.Instance,
                        (ref bool value) => value = true);

                    outputBuilder.AddColumn(pathsColumnName, pathIndicatorType, pathAnnotations.ToAnnotations());
                }

                OutputSchema = outputBuilder.ToSchema();
            }
示例#4
0
 /// <summary>
 /// Base implementation of the annotation lookup: after asserting its arguments it always
 /// throws the exception produced by <c>AnnotationUtils.ExceptGetAnnotation</c>.
 /// </summary>
 /// <typeparam name="TValue">Type of the annotation value.</typeparam>
 /// <param name="kind">The annotation kind; asserted to be non-empty.</param>
 /// <param name="iinfo">Column index; asserted to be in <c>[0, InfoCount)</c>.</param>
 /// <param name="value">Never assigned by this implementation.</param>
 protected virtual void GetAnnotationCore<TValue>(string kind, int iinfo, ref TValue value)
 {
     Contracts.AssertNonEmpty(kind);
     Contracts.Assert(iinfo >= 0 && InfoCount > iinfo);

     // This implementation provides no annotations of its own.
     throw AnnotationUtils.ExceptGetAnnotation();
 }