예제 #1
0
        public ITransformer Fit(IDataView input)
        {
            var h = _host;

            h.CheckValue(input, nameof(input));

            var tparams = new TransformApplierParams(this);

            string[]      textCols       = _inputColumns;
            string[]      wordTokCols    = null;
            string[]      charTokCols    = null;
            string        wordFeatureCol = null;
            string        charFeatureCol = null;
            List <string> tempCols       = new List <string>();
            IDataView     view           = input;

            if (tparams.NeedInitialSourceColumnConcatTransform && textCols.Length > 1)
            {
                var srcCols = textCols;
                textCols = new[] { GenerateColumnName(input.Schema, OutputColumn, "InitialConcat") };
                tempCols.Add(textCols[0]);
                view = new ConcatTransform(h, textCols[0], srcCols).Transform(view);
            }

            if (tparams.NeedsNormalizeTransform)
            {
                var      xfCols  = new (string input, string output)[textCols.Length];
예제 #2
0
        /// <summary>
        /// Produces a <see cref="ConcatTransformer"/> from a <see cref="ConcatTransform"/> built over
        /// <paramref name="input"/> using this instance's <c>_name</c> and <c>_source</c>.
        /// </summary>
        /// <param name="input">Data view the concat transform is constructed on; validated with the host's <c>CheckValue</c>.</param>
        /// <returns>A new <see cref="ConcatTransformer"/> wrapping the result of <c>ApplyAllTransformsToData</c>.</returns>
        public ITransformer Fit(IDataView input)
        {
            _host.CheckValue(input, nameof(input));

            var concat = new ConcatTransform(_host, input, _name, _source);

            // A row-less stand-in that exposes the same schema as the input.
            var schemaOnlyView = new EmptyDataView(_host, input.Schema);

            // NOTE(review): presumably this re-roots the transform chain between `input` and `concat`
            // onto the empty view - confirm against ApplyTransformUtils.
            return new ConcatTransformer(
                _host,
                ApplyTransformUtils.ApplyAllTransformsToData(_host, concat, schemaOnlyView, input));
        }
예제 #3
0
        /// <summary>
        /// Validates every entry of <paramref name="columns"/> and, for the entries that list more than
        /// one source column, applies a single <see cref="ConcatTransform"/> on top of <paramref name="input"/>.
        /// </summary>
        /// <param name="env">Host environment used for argument validation and transform creation.</param>
        /// <param name="columns">Many-to-one column specifications; each needs a non-blank name and at least one non-blank source.</param>
        /// <param name="input">Data view the concat transform is applied to.</param>
        /// <returns>
        /// The result of <c>ConcatTransform.Create</c> when at least one column has several sources;
        /// otherwise <paramref name="input"/> itself, unchanged.
        /// </returns>
        public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColumn[] columns, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(columns, nameof(columns));
            env.CheckValue(input, nameof(input));

            var multiSourceCols = new List<ConcatTransform.Column>();

            for (int i = 0; i < columns.Length; i++)
            {
                var current = columns[i];

                env.CheckUserArg(current != null, nameof(WordBagTransform.Arguments.Column));
                env.CheckUserArg(!string.IsNullOrWhiteSpace(current.Name), nameof(current.Name));
                env.CheckUserArg(Utils.Size(current.Source) > 0, nameof(current.Source));
                env.CheckUserArg(!current.Source.Any(src => string.IsNullOrWhiteSpace(src)), nameof(current.Source));

                // Single-source columns need no concatenation.
                if (current.Source.Length <= 1)
                {
                    continue;
                }

                var concatCol = new ConcatTransform.Column();
                concatCol.Source = current.Source;
                concatCol.Name = current.Name;
                multiSourceCols.Add(concatCol);
            }

            // Nothing to concatenate: hand the input back untouched.
            if (multiSourceCols.Count == 0)
            {
                return input;
            }

            var concatArgs = new ConcatTransform.Arguments();
            concatArgs.Column = multiSourceCols.ToArray();
            return ConcatTransform.Create(env, concatArgs, input);
        }
예제 #4
0
        /// <summary>
        /// Builds the text processing chain on top of <paramref name="input"/> and wraps it in a
        /// <c>Transformer</c>. Stages are applied in a fixed order, each one only when the
        /// corresponding flag or factory on <c>TransformApplierParams</c> asks for it:
        /// initial concat, text normalization, word tokenization, stop-word removal, word
        /// feature extraction, optional token output, char tokenization and char feature
        /// extraction, vector normalization, final concat into <c>OutputColumn</c>, and finally
        /// dropping the intermediate columns collected along the way.
        /// </summary>
        /// <param name="input">Source data view; validated with the host's <c>CheckValue</c>.</param>
        /// <returns>A new <c>Transformer</c> constructed from the host, the input and the final view.</returns>
        public ITransformer Fit(IDataView input)
        {
            var h = _host;

            h.CheckValue(input, nameof(input));

            var tparams = new TransformApplierParams(this);

            // Running state of the chain:
            //   textCols       - names of the columns currently holding the text to process
            //   wordTokCols    - names of the word-token columns (set by the tokenizer stage)
            //   charTokCols    - names of the char-token columns (set by the char stage)
            //   wordFeatureCol - name of the word extractor output column, if produced
            //   charFeatureCol - name of the char extractor output column, if produced
            //   tempCols       - intermediate column names, dropped at the end of the method
            //   view           - the current tail of the transform chain
            string[]      textCols       = _inputColumns;
            string[]      wordTokCols    = null;
            string[]      charTokCols    = null;
            string        wordFeatureCol = null;
            string        charFeatureCol = null;
            List <string> tempCols       = new List <string>();
            IDataView     view           = input;

            // Stage 1: collapse several input columns into a single temporary text column.
            if (tparams.NeedInitialSourceColumnConcatTransform && textCols.Length > 1)
            {
                var srcCols = textCols;
                textCols = new[] { GenerateColumnName(input.Schema, OutputColumn, "InitialConcat") };
                tempCols.Add(textCols[0]);
                view = new ConcatTransform(h, textCols[0], srcCols).Transform(view);
            }

            // Stage 2: text normalization, one temporary output column per text column.
            if (tparams.NeedsNormalizeTransform)
            {
                var      xfCols  = new TextNormalizerCol[textCols.Length];
                string[] dstCols = new string[textCols.Length];
                for (int i = 0; i < textCols.Length; i++)
                {
                    dstCols[i] = GenerateColumnName(view.Schema, textCols[i], "TextNormalizer");
                    tempCols.Add(dstCols[i]);
                    xfCols[i] = new TextNormalizerCol()
                    {
                        Source = textCols[i], Name = dstCols[i]
                    };
                }

                view = new TextNormalizerTransform(h,
                                                   new TextNormalizerArgs()
                {
                    Column           = xfCols,
                    KeepDiacritics   = tparams.KeepDiacritics,
                    KeepNumbers      = tparams.KeepNumbers,
                    KeepPunctuations = tparams.KeepPunctuations,
                    TextCase         = tparams.TextCase
                }, view);

                // Later stages read the normalized columns instead of the originals.
                textCols = dstCols;
            }

            // Stage 3: word tokenization of every text column into a temporary token column.
            if (tparams.NeedsWordTokenizationTransform)
            {
                var xfCols = new DelimitedTokenizeTransform.Column[textCols.Length];
                wordTokCols = new string[textCols.Length];
                for (int i = 0; i < textCols.Length; i++)
                {
                    var col = new DelimitedTokenizeTransform.Column();
                    col.Source = textCols[i];
                    col.Name   = GenerateColumnName(view.Schema, textCols[i], "WordTokenizer");

                    xfCols[i] = col;

                    wordTokCols[i] = col.Name;
                    tempCols.Add(col.Name);
                }

                view = new DelimitedTokenizeTransform(h, new DelimitedTokenizeTransform.Arguments()
                {
                    Column = xfCols
                }, view);
            }

            // Stage 4: stop-word removal over the word-token columns.
            if (tparams.NeedsRemoveStopwordsTransform)
            {
                Contracts.Assert(wordTokCols != null, "StopWords transform requires that word tokenization has been applied to the input text.");
                var xfCols  = new StopWordsCol[wordTokCols.Length];
                var dstCols = new string[wordTokCols.Length];
                for (int i = 0; i < wordTokCols.Length; i++)
                {
                    var col = new StopWordsCol();
                    col.Source = wordTokCols[i];
                    col.Name   = GenerateColumnName(view.Schema, wordTokCols[i], "StopWordsRemoverTransform");
                    dstCols[i] = col.Name;
                    tempCols.Add(col.Name);
                    col.Language = tparams.StopwordsLanguage;

                    xfCols[i] = col;
                }
                view        = tparams.StopWordsRemover.CreateComponent(h, view, xfCols);
                // From here on the word tokens are the stop-word-filtered columns.
                wordTokCols = dstCols;
            }

            // Stage 5: word feature extraction into a single temporary column.
            // NOTE(review): wordTokCols is only assigned in stage 3; presumably tparams guarantees
            // NeedsWordTokenizationTransform whenever a word extractor is configured - confirm,
            // otherwise Source below would be null.
            if (tparams.WordExtractorFactory != null)
            {
                var dstCol = GenerateColumnName(view.Schema, OutputColumn, "WordExtractor");
                tempCols.Add(dstCol);
                view = tparams.WordExtractorFactory.Create(h, view, new[] {
                    new ExtractorColumn()
                    {
                        Name          = dstCol,
                        Source        = wordTokCols,
                        FriendlyNames = _inputColumns
                    }
                });
                wordFeatureCol = dstCol;
            }

            // Stage 6: optionally expose the tokens (or the text columns when no tokenization ran)
            // under a name derived from OutputColumn. This column is not added to tempCols, so it
            // is not part of the drop list at the end.
            if (tparams.OutputTextTokens)
            {
                string[] srcCols = wordTokCols ?? textCols;
                view = new ConcatTransform(h, string.Format(TransformedTextColFormat, OutputColumn), srcCols).Transform(view);
            }

            // Stage 7: char tokenization followed by char feature extraction.
            if (tparams.CharExtractorFactory != null)
            {
                {
                    // Char tokens come from the stop-word-filtered word tokens when stop-word
                    // removal is requested, and from the text columns otherwise.
                    var srcCols = tparams.NeedsRemoveStopwordsTransform ? wordTokCols : textCols;
                    charTokCols = new string[srcCols.Length];
                    var xfCols = new CharTokenizeTransform.Column[srcCols.Length];
                    for (int i = 0; i < srcCols.Length; i++)
                    {
                        var col = new CharTokenizeTransform.Column();
                        col.Source = srcCols[i];
                        col.Name   = GenerateColumnName(view.Schema, srcCols[i], "CharTokenizer");
                        tempCols.Add(col.Name);
                        charTokCols[i] = col.Name;
                        xfCols[i]      = col;
                    }
                    view = new CharTokenizeTransform(h, new CharTokenizeTransform.Arguments()
                    {
                        Column = xfCols
                    }, view);
                }

                {
                    charFeatureCol = GenerateColumnName(view.Schema, OutputColumn, "CharExtractor");
                    tempCols.Add(charFeatureCol);
                    view = tparams.CharExtractorFactory.Create(h, view, new[] {
                        new ExtractorColumn()
                        {
                            Source        = charTokCols,
                            FriendlyNames = _inputColumns,
                            Name          = charFeatureCol
                        }
                    });
                }
            }

            // Stage 8: Lp-normalize whichever feature columns exist; the feature column variables
            // are re-pointed at the normalized temporary columns.
            if (tparams.VectorNormalizer != TextNormKind.None)
            {
                var xfCols = new List <LpNormNormalizerTransform.Column>(2);
                if (charFeatureCol != null)
                {
                    var dstCol = GenerateColumnName(view.Schema, charFeatureCol, "LpCharNorm");
                    tempCols.Add(dstCol);
                    xfCols.Add(new LpNormNormalizerTransform.Column()
                    {
                        Source = charFeatureCol,
                        Name   = dstCol
                    });
                    charFeatureCol = dstCol;
                }

                if (wordFeatureCol != null)
                {
                    var dstCol = GenerateColumnName(view.Schema, wordFeatureCol, "LpWordNorm");
                    tempCols.Add(dstCol);
                    xfCols.Add(new LpNormNormalizerTransform.Column()
                    {
                        Source = wordFeatureCol,
                        Name   = dstCol
                    });
                    wordFeatureCol = dstCol;
                }
                if (xfCols.Count > 0)
                {
                    view = new LpNormNormalizerTransform(h, new LpNormNormalizerTransform.Arguments()
                    {
                        NormKind = tparams.LpNormalizerKind,
                        Column   = xfCols.ToArray()
                    }, view);
                }
            }

            // Stage 9: assemble the final OutputColumn from the feature columns.
            // Each pair is stored as (tag, column name).
            {
                var srcTaggedCols = new List <KeyValuePair <string, string> >(2);
                if (charFeatureCol != null && wordFeatureCol != null)
                {
                    // If we're producing both char and word grams, then we need to disambiguate
                    // between them (e.g. the word 'a' vs. the char gram 'a').
                    srcTaggedCols.Add(new KeyValuePair <string, string>("Char", charFeatureCol));
                    srcTaggedCols.Add(new KeyValuePair <string, string>("Word", wordFeatureCol));
                }
                else
                {
                    // Otherwise, simply use the slot names, omitting the original source column names
                    // entirely. For the Concat transform setting the Key == Value of the TaggedColumn
                    // KVP signals this intent.
                    Contracts.Assert(charFeatureCol != null || wordFeatureCol != null || tparams.OutputTextTokens);
                    if (charFeatureCol != null)
                    {
                        srcTaggedCols.Add(new KeyValuePair <string, string>(charFeatureCol, charFeatureCol));
                    }
                    else if (wordFeatureCol != null)
                    {
                        srcTaggedCols.Add(new KeyValuePair <string, string>(wordFeatureCol, wordFeatureCol));
                    }
                }
                // With no feature column at all (token output only) the list stays empty and
                // no OutputColumn concat is added.
                if (srcTaggedCols.Count > 0)
                {
                    // The pairs are flipped to (column name, tag) when handed to ColumnInfo.
                    view = new ConcatTransform(h, new ConcatTransform.ColumnInfo(OutputColumn,
                                                                                 srcTaggedCols.Select(kvp => (kvp.Value, kvp.Key))))
                           .Transform(view);
                }
            }

            // Stage 10: drop every intermediate column recorded in tempCols.
            view = new DropColumnsTransform(h,
                                            new DropColumnsTransform.Arguments()
            {
                Column = tempCols.ToArray()
            }, view);

            return(new Transformer(_host, input, view));
        }
        /// <summary>
        /// Assembles the chain of transforms that replaces missing values in the requested columns
        /// and, for columns that ask for it, concatenates a missing-value indicator onto the
        /// replaced values.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Transform arguments; <c>args.Column</c> must contain at least one entry.</param>
        /// <param name="input">Data view the chain is built on; must not be null.</param>
        /// <returns>The last transform of the assembled chain.</returns>
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            // NOTE(review): the host is registered under the name "Categorical" - confirm this is intended.
            var host = env.Register("Categorical");

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));
            host.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column));

            int columnCount = args.Column.Length;

            // One temporary name per requested column for the indicator and for the replaced values.
            var missingTempNames = input.Schema.GetTempColumnNames(columnCount, "IsMissing");
            var replaceTempNames = input.Schema.GetTempColumnNames(columnCount, "Replace");

            var replacements = new List<NAReplaceTransform.ColumnInfo>();
            var indicators   = new List<NAIndicatorTransform.Column>();
            var conversions  = new List<ConvertTransform.Column>();
            var concats      = new List<ConcatTransform.TaggedColumn>();
            var tempsToDrop  = new List<string>();

            for (int idx = 0; idx < columnCount; idx++)
            {
                var spec = args.Column[idx];

                // Per-column settings fall back to the transform-wide defaults.
                var mode   = (NAReplaceTransform.ColumnInfo.ReplacementMode)(spec.Kind ?? args.ReplaceWith);
                var bySlot = spec.ImputeBySlot ?? args.ImputeBySlot;

                var wantsIndicator = spec.ConcatIndicator ?? args.Concat;
                if (!wantsIndicator)
                {
                    // Plain replacement: write straight to the requested output column.
                    replacements.Add(new NAReplaceTransform.ColumnInfo(spec.Source, spec.Name, mode, bySlot));
                    continue;
                }

                // The indicator is boolean; it must be convertible to the item type of the source
                // column so that both can be concatenated.
                if (!input.Schema.TryGetColumnIndex(spec.Source, out int srcIndex))
                {
                    throw host.Except("Column '{0}' does not exist", spec.Source);
                }

                var srcType = input.Schema.GetColumnType(srcIndex);
                bool convertible = Conversions.Instance.TryGetStandardConversion(
                    BoolType.Instance, srcType.ItemType, out Delegate unusedConverter, out bool isIdentity);
                if (!convertible)
                {
                    throw host.Except("Cannot concatenate indicator column of type '{0}' to input column of type '{1}'",
                                      BoolType.Instance, srcType.ItemType);
                }

                var missingName  = missingTempNames[idx];
                var replacedName = replaceTempNames[idx];

                // Indicator column computed from the original source column.
                var indicatorCol = new NAIndicatorTransform.Column();
                indicatorCol.Name   = missingName;
                indicatorCol.Source = spec.Source;
                indicators.Add(indicatorCol);

                // In-place conversion of the indicator, only when the types differ.
                if (!isIdentity)
                {
                    var convertCol = new ConvertTransform.Column();
                    convertCol.Name       = missingName;
                    convertCol.Source     = missingName;
                    convertCol.ResultType = srcType.ItemType.RawKind;
                    conversions.Add(convertCol);
                }

                // Replacement goes to a temporary column; the concat below produces the real output.
                replacements.Add(new NAReplaceTransform.ColumnInfo(spec.Source, replacedName, mode, bySlot));

                // Tags differ for vector and scalar sources.
                KeyValuePair<string, string> valuePair;
                KeyValuePair<string, string> flagPair;
                if (srcType.IsVector)
                {
                    valuePair = new KeyValuePair<string, string>(replacedName, replacedName);
                    flagPair  = new KeyValuePair<string, string>("IsMissing", missingName);
                }
                else
                {
                    valuePair = new KeyValuePair<string, string>(spec.Source, replacedName);
                    flagPair  = new KeyValuePair<string, string>(string.Format("IsMissing.{0}", spec.Source), missingName);
                }

                var taggedCol = new ConcatTransform.TaggedColumn();
                taggedCol.Name   = spec.Name;
                taggedCol.Source = new[] { valuePair, flagPair };
                concats.Add(taggedCol);

                // Both temporaries are removed at the end of the chain.
                tempsToDrop.Add(missingName);
                tempsToDrop.Add(replacedName);
            }

            IDataTransform chain = null;

            // Step 1: indicator columns.
            if (indicators.Count > 0)
            {
                var indicatorArgs = new NAIndicatorTransform.Arguments();
                indicatorArgs.Column = indicators.ToArray();
                chain = NAIndicatorTransform.Create(host, indicatorArgs, input);
            }

            // Step 2: convert indicators to the type of the replaced values.
            if (conversions.Count > 0)
            {
                host.AssertValue(chain);
                var convertArgs = new ConvertTransform.Arguments();
                convertArgs.Column = conversions.ToArray();
                chain = new ConvertTransform(host, convertArgs, chain);
            }

            // Step 3: missing-value replacement, on top of whatever has been built so far.
            // NOTE(review): this call receives `env`, while the other stages receive the registered host.
            chain = NAReplaceTransform.Create(env, chain ?? input, replacements.ToArray());

            // Step 4: concatenate replaced values with their indicators.
            if (indicators.Count > 0)
            {
                var concatArgs = new ConcatTransform.TaggedArguments();
                concatArgs.Column = concats.ToArray();
                chain = ConcatTransform.Create(host, concatArgs, chain);
            }

            // Step 5: remove the temporary columns.
            if (tempsToDrop.Count > 0)
            {
                var dropArgs = new DropColumnsTransform.Arguments();
                dropArgs.Column = tempsToDrop.ToArray();
                chain = new DropColumnsTransform(host, dropArgs, chain);
            }

            return chain;
        }