Beispiel #1
0
            /// <summary>
            /// Creates a mapper bound to <paramref name="schema"/> and builds its output schema, which consists of
            /// three vector columns: the per-tree output values, the leaf indicator vector and the path indicator vector.
            /// </summary>
            /// <param name="ectx">Exception context used for the argument assertions and stored on the mapper.</param>
            /// <param name="owner">The bindable mapper supplying the ensemble, the total leaf count and the slot-name getters.</param>
            /// <param name="schema">The role-mapped input schema; it is asserted to have a feature column.</param>
            public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper owner,
                               RoleMappedSchema schema)
            {
                Contracts.AssertValue(ectx);
                ectx.AssertValue(owner);
                ectx.AssertValue(schema);
                ectx.Assert(schema.Feature.HasValue);

                _ectx = ectx;
                _owner = owner;
                InputRoleMappedSchema = schema;

                var treeCount = owner._ensemble.TrainedEnsemble.NumTrees;
                var leafCount = owner._totalLeafCount;

                // Getter shared by the two indicator columns to flag them as normalized.
                ValueGetter<bool> normalizedGetter = (ref bool value) => value = true;

                // --- Tree values: one slot per tree, holding that tree's output on a given example.
                var treesType = new VectorType(NumberType.Float, treeCount);
                ValueGetter<VBuffer<ReadOnlyMemory<char>>> treeSlotNames = owner.GetTreeSlotNames;
                var treesMetadata = new MetadataBuilder();
                treesMetadata.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(treesType.Size), treeSlotNames);

                // --- Leaf IDs: indicator vector over every leaf of the ensemble, marking the leaf the example
                // ends up in for each tree.
                var leavesType = new VectorType(NumberType.Float, leafCount);
                ValueGetter<VBuffer<ReadOnlyMemory<char>>> leafSlotNames = owner.GetLeafSlotNames;
                var leavesMetadata = new MetadataBuilder();
                leavesMetadata.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(leavesType.Size), leafSlotNames);
                leavesMetadata.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, normalizedGetter);

                // --- Path IDs: indicator vector over every internal node of the ensemble, marking the nodes that
                // lie on the example's path through each tree.
                // Sizing: a binary tree has (#internal + #leaf) nodes in total. Counting the same nodes as
                // "children of internal nodes, plus the root" gives 2 * #internal + 1. Equating the two yields
                // #internal = #leaf - 1 per tree, so the whole ensemble has (#leaf - #trees) internal nodes.
                var pathsType = new VectorType(NumberType.Float, leafCount - treeCount);
                ValueGetter<VBuffer<ReadOnlyMemory<char>>> pathSlotNames = owner.GetPathSlotNames;
                var pathsMetadata = new MetadataBuilder();
                pathsMetadata.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(pathsType.Size), pathSlotNames);
                pathsMetadata.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, normalizedGetter);

                // Assemble the output schema. The order of these three calls fixes the column indices that are
                // asserted below, so it must stay Trees, Leaves, Paths.
                var outputBuilder = new SchemaBuilder();
                outputBuilder.AddColumn(OutputColumnNames.Trees, treesType, treesMetadata.GetMetadata());
                outputBuilder.AddColumn(OutputColumnNames.Leaves, leavesType, leavesMetadata.GetMetadata());
                outputBuilder.AddColumn(OutputColumnNames.Paths, pathsType, pathsMetadata.GetMetadata());
                OutputSchema = outputBuilder.GetSchema();

                // Sanity-check the resulting layout: tree values first, leaf IDs second, path IDs third.
                Contracts.Assert(OutputSchema[OutputColumnNames.Trees].Index == TreeValuesColumnId);
                Contracts.Assert(OutputSchema[OutputColumnNames.Leaves].Index == LeafIdsColumnId);
                Contracts.Assert(OutputSchema[OutputColumnNames.Paths].Index == PathIdsColumnId);
            }
Beispiel #2
0
        /// <summary>
        /// Create a new instance of <see cref="ColumnBindings"/>.
        /// </summary>
        /// <remarks>
        /// Input columns keep their original relative order. An added column whose name has not been seen yet is
        /// appended at the end; an added column whose name is already in use is placed directly after the last
        /// column that currently carries that name.
        /// </remarks>
        /// <param name="input">The input schema that we're adding columns to.</param>
        /// <param name="addedColumns">The columns being added.</param>
        public ColumnBindings(Schema input, Schema.DetachedColumn[] addedColumns)
        {
            Contracts.CheckValue(input, nameof(input));
            Contracts.CheckValue(addedColumns, nameof(addedColumns));

            InputSchema = input;

            // Output layout. A non-negative entry k stands for input column k; a negative entry e stands for
            // added column ~e.
            var layout = new List<int>();
            var seenNames = new HashSet<string>();

            // Resolves a layout entry to the name of the column it stands for.
            string ResolveName(int slot)
            {
                return slot < 0 ? addedColumns[~slot].Name : input[slot].Name;
            }

            // Every input column goes first, in its original order.
            for (int src = 0; src < input.Count; src++)
            {
                layout.Add(src);
                seenNames.Add(input[src].Name);
            }

            for (int added = 0; added < addedColumns.Length; added++)
            {
                string name = addedColumns[added].Name;
                if (seenNames.Add(name))
                {
                    // First time this name shows up: the column simply goes to the end.
                    layout.Add(~added);
                    continue;
                }

                // The name is already taken: slot the column in right behind the last one with the same name.
                int lastMatch = layout.FindLastIndex(slot => ResolveName(slot) == name);
                if (lastMatch >= 0)
                {
                    layout.Insert(lastMatch + 1, ~added);
                }
            }
            Contracts.Assert(layout.Count == addedColumns.Length + input.Count);

            // Materialize the output schema from the layout.
            Schema = SchemaBuilder.MakeSchema(
                layout.Select(slot => slot < 0 ? addedColumns[~slot] : new Schema.DetachedColumn(input[slot])));

            // Remember the layout, plus the output position of each added column.
            _colMap = layout.ToArray();
            var addedPositions = new int[addedColumns.Length];

            for (int outIdx = 0; outIdx < _colMap.Length; outIdx++)
            {
                int mapped = _colMap[outIdx];
                if (mapped >= 0)
                {
                    continue;
                }

                Contracts.Assert(addedPositions[~mapped] == 0);
                addedPositions[~mapped] = outIdx;
            }

            AddedColumnIndices = addedPositions.AsReadOnly();
        }