/// <summary>
/// Works out the output vector type for the column at <paramref name="iinfo"/>, reports through
/// <paramref name="concat"/> whether the output is a concatenation of several per-value vectors,
/// and registers the output metadata for that column with <paramref name="md"/>.
/// </summary>
/// <param name="trans">Transform supplying the slot-name getters that are registered as metadata.</param>
/// <param name="input">Schema queried for the key-value metadata of the source column.</param>
/// <param name="iinfo">Index of the column whose metadata is being built.</param>
/// <param name="info">Column info providing the source column index and source type.</param>
/// <param name="md">Dispatcher used to build the metadata for the column.</param>
/// <param name="type">Receives the output vector type.</param>
/// <param name="concat">Receives false when the source holds exactly one value, true otherwise.</param>
/// <param name="bitsPerColumn">Receives the number of output slots produced per source value.</param>
private static void ComputeType(KeyToBinaryVectorTransform trans, ISchema input, int iinfo, ColInfo info, MetadataDispatcher md,
    out VectorType type, out bool concat, out int bitsPerColumn)
{
    Contracts.AssertValue(trans);
    Contracts.AssertValue(input);
    Contracts.AssertValue(info);
    Contracts.Assert(info.TypeSrc.ItemType.IsKey);

    var keyCount = info.TypeSrc.ItemType.KeyCount;
    Contracts.Assert(keyCount > 0);

    // The width includes an additional bit so that the all-1s pattern can represent missing values.
    bitsPerColumn = Utils.IbitHigh((uint)keyCount) + 2;
    Contracts.Assert(bitsPerColumn > 0);

    // Key names are only usable when the source exposes them as a known-size text vector
    // with exactly one entry per key value.
    var keyNamesType = input.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, info.Source);
    bool hasKeyNames =
        keyNamesType != null &&
        keyNamesType.IsKnownSizeVector &&
        keyNamesType.ItemType.IsText &&
        keyNamesType.VectorSize == keyCount;

    // Metadata is built from scratch; nothing from the source column is passed through.
    using (var bldr = md.BuildMetadata(iinfo))
    {
        concat = info.TypeSrc.ValueCount != 1;
        if (concat)
        {
            // Several source values: the output is the concatenation of one vector per value.
            type = new VectorType(NumberType.Float, info.TypeSrc.ValueCount, bitsPerColumn);
            if (hasKeyNames && type.VectorSize > 0)
            {
                bldr.AddGetter<VBuffer<DvText>>(MetadataUtils.Kinds.SlotNames,
                    new VectorType(TextType.Instance, type), trans.GetSlotNames);
            }
        }
        else
        {
            // Exactly one source value: the output is a single vector of bitsPerColumn slots.
            type = new VectorType(NumberType.Float, bitsPerColumn);
            if (hasKeyNames)
            {
                bldr.AddGetter<VBuffer<DvText>>(MetadataUtils.Kinds.SlotNames,
                    new VectorType(TextType.Instance, type), trans.GetKeyNames);
            }

            bldr.AddPrimitive(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, DvBool.True);
        }
    }
}
/// <summary>
/// Wraps <paramref name="input"/> in a key-to-vector (or key-to-binary-vector) transform covering
/// every column whose effective output kind is not <see cref="OutputKind.Key"/>.
/// </summary>
/// <param name="argsOutputKind">Default output kind, used for columns that have no per-column override.</param>
/// <param name="columns">Columns to convert. Each one is sanitized, and the conversion reuses the column name as both source and destination.</param>
/// <param name="columnOutputKinds">Per-column output kinds, parallel to <paramref name="columns"/>; a null entry falls back to <paramref name="argsOutputKind"/>.</param>
/// <param name="input">The transform to wrap.</param>
/// <param name="h">Host used to open the channel, report an invalid column name and construct the resulting transform.</param>
/// <param name="env">Environment used to report an unrecognized output kind.</param>
/// <param name="catHashArgs">Optional hash arguments; only inspected to warn when InvertHash is non-zero and binary encoding is used.</param>
/// <returns>
/// <paramref name="input"/> itself when no column needs conversion; otherwise the newly created transform.
/// </returns>
public static IDataTransform CreateTransformCore(
    OutputKind argsOutputKind,
    OneToOneColumn[] columns,
    List<OutputKind?> columnOutputKinds,
    IDataTransform input,
    IHost h,
    IHostEnvironment env,
    CategoricalHashTransform.Arguments catHashArgs = null)
{
    Contracts.CheckValue(columns, nameof(columns));
    Contracts.CheckValue(columnOutputKinds, nameof(columnOutputKinds));
    Contracts.CheckParam(columns.Length == columnOutputKinds.Count, nameof(columns));

    using (var ch = h.Start("Create Transform Core"))
    {
        // Create the KeyToVectorTransform, if needed.
        var cols = new List<KeyToVectorTransform.Column>();

        // Binary encoding is selected when the default kind is Bin or when any single column asks
        // for Bin; once selected, it applies to every column collected below.
        bool binaryEncoding = argsOutputKind == OutputKind.Bin;
        for (int i = 0; i < columns.Length; i++)
        {
            var column = columns[i];
            if (!column.TrySanitize())
            {
                throw h.ExceptUserArg(nameof(Column.Name));
            }

            bool? bag;
            OutputKind kind = columnOutputKinds[i] ?? argsOutputKind;
            switch (kind)
            {
                default:
                    throw env.ExceptUserArg(nameof(Column.OutputKind));
                case OutputKind.Key:
                    // Key output needs no further conversion; leave the column as it is.
                    continue;
                case OutputKind.Bin:
                    binaryEncoding = true;
                    bag = false;
                    break;
                case OutputKind.Ind:
                    bag = false;
                    break;
                case OutputKind.Bag:
                    bag = true;
                    break;
            }

            cols.Add(new KeyToVectorTransform.Column
            {
                Name = column.Name,
                Source = column.Name,
                Bag = bag
            });
        }

        if (cols.Count == 0)
        {
            // Nothing to convert. This is a normal exit, so signal completion on the channel
            // exactly like the path below does, instead of disposing it without Done().
            ch.Done();
            return input;
        }

        IDataTransform transform;
        if (binaryEncoding)
        {
            if ((catHashArgs?.InvertHash ?? 0) != 0)
            {
                ch.Warning("Invert hashing is being used with binary encoding.");
            }

            var keyToBinaryArgs = new KeyToBinaryVectorTransform.Arguments
            {
                Column = cols.ToArray()
            };
            transform = new KeyToBinaryVectorTransform(h, keyToBinaryArgs, input);
        }
        else
        {
            var keyToVecArgs = new KeyToVectorTransform.Arguments
            {
                Bag = argsOutputKind == OutputKind.Bag,
                Column = cols.ToArray()
            };
            transform = new KeyToVectorTransform(h, keyToVecArgs, input);
        }

        ch.Done();
        return transform;
    }
}
/// <summary>
/// Wraps <paramref name="input"/> in a key-to-vector (or key-to-binary-vector) transform covering
/// every column whose effective output kind is not <see cref="CategoricalTransform.OutputKind.Key"/>.
/// </summary>
/// <param name="argsOutputKind">Default output kind, used for columns that have no per-column override.</param>
/// <param name="columns">Columns to convert. Each one is sanitized, and the conversion reuses the column name as both source and destination.</param>
/// <param name="columnOutputKinds">Per-column output kinds, parallel to <paramref name="columns"/>; a null entry falls back to <paramref name="argsOutputKind"/>.</param>
/// <param name="input">The transform to wrap.</param>
/// <param name="h">Host used to open the channel, report an invalid column name and create the resulting transform.</param>
/// <param name="catHashArgs">Optional hash arguments; only inspected to warn when InvertHash is non-zero and binary encoding is used.</param>
/// <returns>
/// <paramref name="input"/> itself when no column needs conversion; otherwise the newly created transform.
/// </returns>
private static IDataTransform CreateTransformCore(CategoricalTransform.OutputKind argsOutputKind, OneToOneColumn[] columns,
    List<CategoricalTransform.OutputKind?> columnOutputKinds, IDataTransform input, IHost h, Arguments catHashArgs = null)
{
    Contracts.CheckValue(columns, nameof(columns));
    Contracts.CheckValue(columnOutputKinds, nameof(columnOutputKinds));
    Contracts.CheckParam(columns.Length == columnOutputKinds.Count, nameof(columns));

    using (var ch = h.Start("Create Transform Core"))
    {
        // Create the KeyToVectorTransform, if needed.
        var cols = new List<KeyToVectorTransform.Column>();

        // Binary encoding is selected when the default kind is Bin or when any single column asks
        // for Bin; once selected, it applies to every column collected below.
        bool binaryEncoding = argsOutputKind == CategoricalTransform.OutputKind.Bin;
        for (int i = 0; i < columns.Length; i++)
        {
            var column = columns[i];
            if (!column.TrySanitize())
            {
                throw h.ExceptUserArg(nameof(Column.Name));
            }

            bool? bag;
            CategoricalTransform.OutputKind kind = columnOutputKinds[i] ?? argsOutputKind;
            switch (kind)
            {
                default:
                    throw ch.ExceptUserArg(nameof(Column.OutputKind));
                case CategoricalTransform.OutputKind.Key:
                    // Key output needs no further conversion; leave the column as it is.
                    continue;
                case CategoricalTransform.OutputKind.Bin:
                    binaryEncoding = true;
                    bag = false;
                    break;
                case CategoricalTransform.OutputKind.Ind:
                    bag = false;
                    break;
                case CategoricalTransform.OutputKind.Bag:
                    bag = true;
                    break;
            }

            cols.Add(new KeyToVectorTransform.Column
            {
                Name = column.Name,
                Source = column.Name,
                Bag = bag
            });
        }

        if (cols.Count == 0)
        {
            // Nothing to convert. This is a normal exit, so signal completion on the channel
            // exactly like the path below does, instead of disposing it without Done().
            ch.Done();
            return input;
        }

        IDataTransform transform;
        if (binaryEncoding)
        {
            if ((catHashArgs?.InvertHash ?? 0) != 0)
            {
                ch.Warning("Invert hashing is being used with binary encoding.");
            }

            var keyToBinaryVecCols = cols
                .Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.Source, x.Name))
                .ToArray();
            transform = KeyToBinaryVectorTransform.Create(h, input, keyToBinaryVecCols);
        }
        else
        {
            // `??` binds more loosely than `==`, so the fallback is the comparison result;
            // the parentheses only make that existing grouping explicit.
            var keyToVecCols = cols
                .Select(x => new KeyToVectorTransform.ColumnInfo(
                    x.Source, x.Name, x.Bag ?? (argsOutputKind == CategoricalTransform.OutputKind.Bag)))
                .ToArray();
            transform = KeyToVectorTransform.Create(h, input, keyToVecCols);
        }

        ch.Done();
        return transform;
    }
}