public ITransformer Fit(IDataView input) { var h = _host; h.CheckValue(input, nameof(input)); var tparams = new TransformApplierParams(this); string[] textCols = _inputColumns; string[] wordTokCols = null; string[] charTokCols = null; string wordFeatureCol = null; string charFeatureCol = null; List <string> tempCols = new List <string>(); IDataView view = input; if (tparams.NeedInitialSourceColumnConcatTransform && textCols.Length > 1) { var srcCols = textCols; textCols = new[] { GenerateColumnName(input.Schema, OutputColumn, "InitialConcat") }; tempCols.Add(textCols[0]); view = new ConcatTransform(h, textCols[0], srcCols).Transform(view); } if (tparams.NeedsNormalizeTransform) { var xfCols = new (string input, string output)[textCols.Length];
/// <summary>
/// Builds a ConcatTransform over <paramref name="input"/> and returns a ConcatTransformer
/// wrapping the result of re-applying it against an empty view that shares the input schema.
/// </summary>
/// <param name="input">The data view to fit on; validated as non-null via the host.</param>
/// <returns>A new ConcatTransformer created from the host and the re-applied transform chain.</returns>
public ITransformer Fit(IDataView input)
{
    _host.CheckValue(input, nameof(input));

    var concat = new ConcatTransform(_host, input, _name, _source);

    // An empty view carrying only the schema of the input.
    var schemaOnlyView = new EmptyDataView(_host, input.Schema);

    // NOTE(review): presumably this re-roots the transform chain (concat .. input) onto the
    // empty view so that the transformer does not hold on to the training data -- confirm
    // against ApplyTransformUtils.ApplyAllTransformsToData.
    var reapplied = ApplyTransformUtils.ApplyAllTransformsToData(_host, concat, schemaOnlyView, input);

    return new ConcatTransformer(_host, reapplied);
}
/// <summary>
/// Validates every entry of <paramref name="columns"/> and, for the entries that list more than
/// one source, applies a single ConcatTransform that merges those sources into the entry's
/// named output column. Entries with exactly one source are validated but otherwise ignored.
/// </summary>
/// <param name="env">Host environment used for argument checks and to create the transform.</param>
/// <param name="columns">Many-to-one column specifications to inspect.</param>
/// <param name="input">The data view to transform.</param>
/// <returns>
/// The ConcatTransform output when at least one entry has multiple sources;
/// otherwise <paramref name="input"/> unchanged.
/// </returns>
public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColumn[] columns, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(columns, nameof(columns));
    env.CheckValue(input, nameof(input));

    var multiSourceCols = new List<ConcatTransform.Column>();
    for (int idx = 0; idx < columns.Length; idx++)
    {
        var column = columns[idx];

        // Every entry is validated, including the ones that end up needing no concatenation.
        env.CheckUserArg(column != null, nameof(WordBagTransform.Arguments.Column));
        env.CheckUserArg(!string.IsNullOrWhiteSpace(column.Name), nameof(column.Name));
        env.CheckUserArg(Utils.Size(column.Source) > 0, nameof(column.Source));
        env.CheckUserArg(column.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(column.Source));

        // A single source has nothing to be concatenated with.
        if (column.Source.Length <= 1)
        {
            continue;
        }

        multiSourceCols.Add(new ConcatTransform.Column { Source = column.Source, Name = column.Name });
    }

    if (multiSourceCols.Count == 0)
    {
        return input;
    }

    var concatArgs = new ConcatTransform.Arguments { Column = multiSourceCols.ToArray() };
    return ConcatTransform.Create(env, concatArgs, input);
}
/// <summary>
/// Chains the configured text-processing stages over <paramref name="input"/> and returns a
/// Transformer built from the resulting view. Stages run in a fixed order, each one enabled by
/// the corresponding TransformApplierParams setting: initial source concatenation, text
/// normalization, word tokenization, stop-word removal, word feature extraction, optional
/// token output, character tokenization and extraction, Lp normalization of the feature
/// vectors, and the final concatenation into <c>OutputColumn</c>. Every intermediate column
/// recorded along the way is dropped before returning.
/// </summary>
/// <param name="input">The data view to fit on; validated as non-null via the host.</param>
/// <returns>A new Transformer created from the host, the original input and the final view.</returns>
public ITransformer Fit(IDataView input)
{
    var host = _host;
    host.CheckValue(input, nameof(input));

    var settings = new TransformApplierParams(this);

    // Column names flowing between stages; each stage may replace them with its outputs.
    string[] textCols = _inputColumns;
    string[] wordTokCols = null;
    string[] charTokCols = null;
    string wordFeatureCol = null;
    string charFeatureCol = null;

    // Names of intermediate columns, dropped in the last step.
    var tempCols = new List<string>();
    IDataView view = input;

    // Stage 1: merge several source columns into a single temporary text column.
    if (settings.NeedInitialSourceColumnConcatTransform && textCols.Length > 1)
    {
        string[] originalCols = textCols;
        string mergedCol = GenerateColumnName(input.Schema, OutputColumn, "InitialConcat");
        textCols = new[] { mergedCol };
        tempCols.Add(mergedCol);
        view = new ConcatTransform(host, mergedCol, originalCols).Transform(view);
    }

    // Stage 2: text normalization, one temporary output column per text column.
    if (settings.NeedsNormalizeTransform)
    {
        var normalizerCols = new TextNormalizerCol[textCols.Length];
        var normalizedNames = new string[textCols.Length];
        for (int i = 0; i < textCols.Length; i++)
        {
            string normalizedName = GenerateColumnName(view.Schema, textCols[i], "TextNormalizer");
            normalizedNames[i] = normalizedName;
            tempCols.Add(normalizedName);
            normalizerCols[i] = new TextNormalizerCol { Source = textCols[i], Name = normalizedName };
        }

        var normalizerArgs = new TextNormalizerArgs
        {
            Column = normalizerCols,
            KeepDiacritics = settings.KeepDiacritics,
            KeepNumbers = settings.KeepNumbers,
            KeepPunctuations = settings.KeepPunctuations,
            TextCase = settings.TextCase
        };
        view = new TextNormalizerTransform(host, normalizerArgs, view);
        textCols = normalizedNames;
    }

    // Stage 3: word tokenization of the current text columns.
    if (settings.NeedsWordTokenizationTransform)
    {
        var tokenizerCols = new DelimitedTokenizeTransform.Column[textCols.Length];
        wordTokCols = new string[textCols.Length];
        for (int i = 0; i < textCols.Length; i++)
        {
            string tokenName = GenerateColumnName(view.Schema, textCols[i], "WordTokenizer");
            tokenizerCols[i] = new DelimitedTokenizeTransform.Column { Source = textCols[i], Name = tokenName };
            wordTokCols[i] = tokenName;
            tempCols.Add(tokenName);
        }

        var tokenizerArgs = new DelimitedTokenizeTransform.Arguments { Column = tokenizerCols };
        view = new DelimitedTokenizeTransform(host, tokenizerArgs, view);
    }

    // Stage 4: stop-word removal; consumes the word-token columns and replaces them.
    if (settings.NeedsRemoveStopwordsTransform)
    {
        Contracts.Assert(wordTokCols != null, "StopWords transform requires that word tokenization has been applied to the input text.");

        var stopWordCols = new StopWordsCol[wordTokCols.Length];
        var filteredNames = new string[wordTokCols.Length];
        for (int i = 0; i < wordTokCols.Length; i++)
        {
            string filteredName = GenerateColumnName(view.Schema, wordTokCols[i], "StopWordsRemoverTransform");
            filteredNames[i] = filteredName;
            tempCols.Add(filteredName);
            stopWordCols[i] = new StopWordsCol
            {
                Source = wordTokCols[i],
                Name = filteredName,
                Language = settings.StopwordsLanguage
            };
        }

        view = settings.StopWordsRemover.CreateComponent(host, view, stopWordCols);
        wordTokCols = filteredNames;
    }

    // Stage 5: word-level feature extraction into a single temporary column.
    if (settings.WordExtractorFactory != null)
    {
        string wordExtractorCol = GenerateColumnName(view.Schema, OutputColumn, "WordExtractor");
        tempCols.Add(wordExtractorCol);

        var wordExtractorSpec = new ExtractorColumn
        {
            Name = wordExtractorCol,
            Source = wordTokCols,
            FriendlyNames = _inputColumns
        };
        view = settings.WordExtractorFactory.Create(host, view, new[] { wordExtractorSpec });
        wordFeatureCol = wordExtractorCol;
    }

    // Stage 6: optionally expose the tokens (or the text columns when no tokenization ran).
    // This column is not recorded in tempCols, so it is not dropped at the end.
    if (settings.OutputTextTokens)
    {
        string[] tokenSources = wordTokCols ?? textCols;
        string tokensColName = string.Format(TransformedTextColFormat, OutputColumn);
        view = new ConcatTransform(host, tokensColName, tokenSources).Transform(view);
    }

    // Stage 7: character tokenization followed by character-level feature extraction.
    if (settings.CharExtractorFactory != null)
    {
        // Read from the stop-word-filtered tokens when that stage ran, else from the text columns.
        string[] charSources = settings.NeedsRemoveStopwordsTransform ? wordTokCols : textCols;
        charTokCols = new string[charSources.Length];
        var charTokenizerCols = new CharTokenizeTransform.Column[charSources.Length];
        for (int i = 0; i < charSources.Length; i++)
        {
            string charTokName = GenerateColumnName(view.Schema, charSources[i], "CharTokenizer");
            tempCols.Add(charTokName);
            charTokCols[i] = charTokName;
            charTokenizerCols[i] = new CharTokenizeTransform.Column { Source = charSources[i], Name = charTokName };
        }

        var charTokenizerArgs = new CharTokenizeTransform.Arguments { Column = charTokenizerCols };
        view = new CharTokenizeTransform(host, charTokenizerArgs, view);

        charFeatureCol = GenerateColumnName(view.Schema, OutputColumn, "CharExtractor");
        tempCols.Add(charFeatureCol);

        var charExtractorSpec = new ExtractorColumn
        {
            Source = charTokCols,
            FriendlyNames = _inputColumns,
            Name = charFeatureCol
        };
        view = settings.CharExtractorFactory.Create(host, view, new[] { charExtractorSpec });
    }

    // Stage 8: Lp-normalize whichever feature columns were produced.
    if (settings.VectorNormalizer != TextNormKind.None)
    {
        var lpCols = new List<LpNormNormalizerTransform.Column>(2);

        if (charFeatureCol != null)
        {
            string charNormCol = GenerateColumnName(view.Schema, charFeatureCol, "LpCharNorm");
            tempCols.Add(charNormCol);
            lpCols.Add(new LpNormNormalizerTransform.Column { Source = charFeatureCol, Name = charNormCol });
            charFeatureCol = charNormCol;
        }

        if (wordFeatureCol != null)
        {
            string wordNormCol = GenerateColumnName(view.Schema, wordFeatureCol, "LpWordNorm");
            tempCols.Add(wordNormCol);
            lpCols.Add(new LpNormNormalizerTransform.Column { Source = wordFeatureCol, Name = wordNormCol });
            wordFeatureCol = wordNormCol;
        }

        if (lpCols.Count > 0)
        {
            var lpArgs = new LpNormNormalizerTransform.Arguments
            {
                NormKind = settings.LpNormalizerKind,
                Column = lpCols.ToArray()
            };
            view = new LpNormNormalizerTransform(host, lpArgs, view);
        }
    }

    // Stage 9: concatenate the feature columns into the output column.
    // Each pair is (tag, column name); it is flipped to (column name, tag) for ColumnInfo below.
    var taggedSources = new List<KeyValuePair<string, string>>(2);
    if (charFeatureCol != null && wordFeatureCol != null)
    {
        // With both char and word grams present the slots must be told apart
        // (e.g. the word 'a' vs. the char gram 'a'), so each source gets its own tag.
        taggedSources.Add(new KeyValuePair<string, string>("Char", charFeatureCol));
        taggedSources.Add(new KeyValuePair<string, string>("Word", wordFeatureCol));
    }
    else
    {
        // With a single feature column, use the slot names alone and omit the source column
        // name; a pair whose key equals its value signals that intent to the Concat transform.
        Contracts.Assert(charFeatureCol != null || wordFeatureCol != null || settings.OutputTextTokens);

        string soleFeatureCol = charFeatureCol ?? wordFeatureCol;
        if (soleFeatureCol != null)
        {
            taggedSources.Add(new KeyValuePair<string, string>(soleFeatureCol, soleFeatureCol));
        }
    }

    if (taggedSources.Count > 0)
    {
        var outputInfo = new ConcatTransform.ColumnInfo(OutputColumn, taggedSources.Select(pair => (pair.Value, pair.Key)));
        view = new ConcatTransform(host, outputInfo).Transform(view);
    }

    // Finally remove every intermediate column recorded above.
    var dropArgs = new DropColumnsTransform.Arguments { Column = tempCols.ToArray() };
    view = new DropColumnsTransform(host, dropArgs, view);

    return new Transformer(_host, input, view);
}
/// <summary>
/// Builds a chain of transforms over <paramref name="input"/> from the per-column settings in
/// <paramref name="args"/>. Columns that do not request an indicator go straight to
/// NAReplaceTransform under their final name. Columns that do request one get a temporary
/// NAIndicatorTransform column (converted with ConvertTransform when the bool-to-item-type
/// conversion is not an identity) and a temporary NAReplaceTransform column; ConcatTransform
/// then joins the two under the final name and DropColumnsTransform removes the temporaries.
/// </summary>
/// <param name="env">Host environment; a child host is registered from it for checks and most transforms.</param>
/// <param name="args">Arguments holding the column specifications and the defaults they fall back to.</param>
/// <param name="input">The data view to transform.</param>
/// <returns>The last transform in the constructed chain.</returns>
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("Categorical");
    host.CheckValue(args, nameof(args));
    host.CheckValue(input, nameof(input));
    host.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column));

    var replaceCols = new List<NAReplaceTransform.ColumnInfo>();
    var naIndicatorCols = new List<NAIndicatorTransform.Column>();
    var naConvCols = new List<ConvertTransform.Column>();
    var concatCols = new List<ConcatTransform.TaggedColumn>();
    var dropCols = new List<string>();

    // One temporary name per column for each of the two intermediate outputs.
    var tmpIsMissingColNames = input.Schema.GetTempColumnNames(args.Column.Length, "IsMissing");
    var tmpReplaceColNames = input.Schema.GetTempColumnNames(args.Column.Length, "Replace");

    for (int i = 0; i < args.Column.Length; i++)
    {
        var column = args.Column[i];

        // Per-column settings fall back to the transform-wide defaults.
        var mode = (NAReplaceTransform.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith);
        var imputeBySlot = column.ImputeBySlot ?? args.ImputeBySlot;
        var wantsIndicator = column.ConcatIndicator ?? args.Concat;

        if (!wantsIndicator)
        {
            // No indicator requested: replace directly into the final output column.
            replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, column.Name, mode, imputeBySlot));
            continue;
        }

        // The bool indicator must be convertible to the item type of the source column,
        // otherwise the two cannot be concatenated.
        if (!input.Schema.TryGetColumnIndex(column.Source, out int inputCol))
        {
            throw host.Except("Column '{0}' does not exist", column.Source);
        }

        var replaceType = input.Schema.GetColumnType(inputCol);
        if (!Conversions.Instance.TryGetStandardConversion(BoolType.Instance, replaceType.ItemType, out Delegate conv, out bool identity))
        {
            throw host.Except("Cannot concatenate indicator column of type '{0}' to input column of type '{1}'", BoolType.Instance, replaceType.ItemType);
        }

        var tmpIsMissingColName = tmpIsMissingColNames[i];
        var tmpReplacementColName = tmpReplaceColNames[i];

        // Indicator column, written to its temporary name.
        naIndicatorCols.Add(new NAIndicatorTransform.Column { Name = tmpIsMissingColName, Source = column.Source });

        // In-place conversion of the indicator, needed only when the conversion is not an identity.
        if (!identity)
        {
            naConvCols.Add(new ConvertTransform.Column
            {
                Name = tmpIsMissingColName,
                Source = tmpIsMissingColName,
                ResultType = replaceType.ItemType.RawKind
            });
        }

        // Replaced values, written to their temporary name.
        replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, tmpReplacementColName, mode, imputeBySlot));

        // Tags for the two concat sources: vector inputs tag the replaced values with the
        // temporary column's own name, scalar inputs tag both sources with the source column name.
        string replacementTag;
        string indicatorTag;
        if (replaceType.IsVector)
        {
            replacementTag = tmpReplacementColName;
            indicatorTag = "IsMissing";
        }
        else
        {
            replacementTag = column.Source;
            indicatorTag = string.Format("IsMissing.{0}", column.Source);
        }

        concatCols.Add(new ConcatTransform.TaggedColumn
        {
            Name = column.Name,
            Source = new[]
            {
                new KeyValuePair<string, string>(replacementTag, tmpReplacementColName),
                new KeyValuePair<string, string>(indicatorTag, tmpIsMissingColName)
            }
        });

        // Both temporaries are removed once the concatenation has been built.
        dropCols.Add(tmpIsMissingColName);
        dropCols.Add(tmpReplacementColName);
    }

    bool hasIndicators = naIndicatorCols.Count > 0;
    IDataTransform output = null;

    // Step 1: indicator columns.
    if (hasIndicators)
    {
        var indicatorArgs = new NAIndicatorTransform.Arguments { Column = naIndicatorCols.ToArray() };
        output = NAIndicatorTransform.Create(host, indicatorArgs, input);
    }

    // Step 2: convert the indicators to the type they will be concatenated with.
    if (naConvCols.Count > 0)
    {
        host.AssertValue(output);
        var convertArgs = new ConvertTransform.Arguments { Column = naConvCols.ToArray() };
        output = new ConvertTransform(host, convertArgs, output);
    }

    // Step 3: missing-value replacement, on top of the indicator chain when there is one.
    output = NAReplaceTransform.Create(env, output ?? input, replaceCols.ToArray());

    // Step 4: join replaced values and indicators under the final column names.
    if (hasIndicators)
    {
        var concatArgs = new ConcatTransform.TaggedArguments { Column = concatCols.ToArray() };
        output = ConcatTransform.Create(host, concatArgs, output);
    }

    // Step 5: remove the temporary columns.
    if (dropCols.Count > 0)
    {
        var dropArgs = new DropColumnsTransform.Arguments { Column = dropCols.ToArray() };
        output = new DropColumnsTransform(host, dropArgs, output);
    }

    return output;
}