示例#1
0
        /// <summary>
        /// Entry point that applies an <c>LpNormNormalizerTransform</c> configured by
        /// <paramref name="input"/> (presumably global contrast normalization, per the "Gcn" naming)
        /// to <c>input.Data</c>.
        /// </summary>
        /// <param name="env">Host environment used to create the "GcNormalize" host.</param>
        /// <param name="input">Transform arguments; also carries the data view to transform.</param>
        /// <returns>
        /// The transformed data view together with a <c>TransformModel</c> built from that
        /// transform over <c>input.Data</c>.
        /// </returns>
        public static CommonOutputs.TransformOutput GcNormalize(IHostEnvironment env, LpNormNormalizerTransform.GcnArguments input)
        {
            var host = EntryPointUtils.CheckArgsAndCreateHost(env, "GcNormalize", input);
            var normalizer = new LpNormNormalizerTransform(host, input, input.Data);

            return new CommonOutputs.TransformOutput
            {
                Model = new TransformModel(host, normalizer, input.Data),
                OutputData = normalizer
            };
        }
示例#2
0
        /// <summary>
        /// Builds the text-featurization pipeline over <paramref name="input"/> and wraps the
        /// resulting data view in a <c>Transformer</c>.
        /// </summary>
        /// <remarks>
        /// Stages run in a fixed order, each gated by a flag on <c>TransformApplierParams</c>:
        /// optional concatenation of multiple input columns, text normalization, word tokenization,
        /// stop-word removal, word-feature extraction, an optional transformed-text output column,
        /// char tokenization plus char-feature extraction, Lp-normalization of the feature columns,
        /// and a final concat into <c>OutputColumn</c>. Every intermediate column created along the
        /// way is dropped before returning.
        /// </remarks>
        /// <param name="input">The data to featurize; must not be null.</param>
        /// <returns>A transformer built from <paramref name="input"/> and the final data view.</returns>
        public ITransformer Fit(IDataView input)
        {
            var host = _host;
            host.CheckValue(input, nameof(input));

            var settings = new TransformApplierParams(this);

            // Names of every intermediate column created below; all of them are dropped at the end.
            var scratch = new List<string>();
            IDataView pipeline = input;

            string[] textCols = _inputColumns;
            string[] wordTokCols = null;
            string wordFeatureCol = null;
            string charFeatureCol = null;

            if (settings.NeedInitialSourceColumnConcatTransform && textCols.Length > 1)
            {
                textCols = ConcatSourceColumns(textCols);
            }

            if (settings.NeedsNormalizeTransform)
            {
                textCols = NormalizeText(textCols);
            }

            if (settings.NeedsWordTokenizationTransform)
            {
                wordTokCols = TokenizeWords(textCols);
            }

            if (settings.NeedsRemoveStopwordsTransform)
            {
                Contracts.Assert(wordTokCols != null, "StopWords transform requires that word tokenization has been applied to the input text.");
                wordTokCols = RemoveStopWords(wordTokCols);
            }

            if (settings.WordExtractorFactory != null)
            {
                wordFeatureCol = ExtractWordFeatures(wordTokCols);
            }

            if (settings.OutputTextTokens)
            {
                // Exposes the word tokens (or the processed text when no word tokenization ran)
                // under the transformed-text column name. This column is not recorded as scratch.
                pipeline = new ConcatTransform(host, string.Format(TransformedTextColFormat, OutputColumn), wordTokCols ?? textCols).Transform(pipeline);
            }

            if (settings.CharExtractorFactory != null)
            {
                // Char tokens are taken from the stop-word-filtered word tokens when that stage ran.
                var charSources = settings.NeedsRemoveStopwordsTransform ? wordTokCols : textCols;
                charFeatureCol = ExtractCharFeatures(TokenizeChars(charSources));
            }

            if (settings.VectorNormalizer != TextNormKind.None)
            {
                (charFeatureCol, wordFeatureCol) = LpNormalizeFeatures(charFeatureCol, wordFeatureCol);
            }

            // Final concat into OutputColumn. With both feature kinds present the slots are tagged
            // "Char"/"Word" to disambiguate (e.g. the word 'a' vs. the char gram 'a'); with only one,
            // using tag == column name tells ConcatTransform to keep the bare slot names.
            var taggedSources = new List<(string, string)>(2);
            if (charFeatureCol != null && wordFeatureCol != null)
            {
                taggedSources.Add((charFeatureCol, "Char"));
                taggedSources.Add((wordFeatureCol, "Word"));
            }
            else
            {
                Contracts.Assert(charFeatureCol != null || wordFeatureCol != null || settings.OutputTextTokens);
                string soleFeatureCol = charFeatureCol ?? wordFeatureCol;
                if (soleFeatureCol != null)
                {
                    taggedSources.Add((soleFeatureCol, soleFeatureCol));
                }
            }

            if (taggedSources.Count > 0)
            {
                pipeline = new ConcatTransform(host, new ConcatTransform.ColumnInfo(OutputColumn, taggedSources)).Transform(pipeline);
            }

            pipeline = new DropColumnsTransform(host, new DropColumnsTransform.Arguments() { Column = scratch.ToArray() }, pipeline);
            return new Transformer(_host, input, pipeline);

            // Merges all source text columns into a single scratch column; returns it as a 1-element array.
            string[] ConcatSourceColumns(string[] sources)
            {
                string merged = GenerateColumnName(input.Schema, OutputColumn, "InitialConcat");
                scratch.Add(merged);
                pipeline = new ConcatTransform(host, merged, sources).Transform(pipeline);
                return new[] { merged };
            }

            // Applies TextNormalizerTransform to each source column; returns the normalized column names.
            string[] NormalizeText(string[] sources)
            {
                var columns = new TextNormalizerCol[sources.Length];
                var outputs = new string[sources.Length];
                for (int i = 0; i < sources.Length; i++)
                {
                    outputs[i] = GenerateColumnName(pipeline.Schema, sources[i], "TextNormalizer");
                    scratch.Add(outputs[i]);
                    columns[i] = new TextNormalizerCol() { Source = sources[i], Name = outputs[i] };
                }

                pipeline = new TextNormalizerTransform(host, new TextNormalizerArgs()
                {
                    Column = columns,
                    KeepDiacritics = settings.KeepDiacritics,
                    KeepNumbers = settings.KeepNumbers,
                    KeepPunctuations = settings.KeepPunctuations,
                    TextCase = settings.TextCase
                }, pipeline);

                return outputs;
            }

            // Splits each source column into word tokens; returns the token column names.
            string[] TokenizeWords(string[] sources)
            {
                var columns = new DelimitedTokenizeTransform.Column[sources.Length];
                var outputs = new string[sources.Length];
                for (int i = 0; i < sources.Length; i++)
                {
                    string name = GenerateColumnName(pipeline.Schema, sources[i], "WordTokenizer");
                    columns[i] = new DelimitedTokenizeTransform.Column() { Source = sources[i], Name = name };
                    outputs[i] = name;
                    scratch.Add(name);
                }

                pipeline = new DelimitedTokenizeTransform(host, new DelimitedTokenizeTransform.Arguments() { Column = columns }, pipeline);
                return outputs;
            }

            // Runs the configured stop-word remover over each token column; returns the filtered column names.
            string[] RemoveStopWords(string[] sources)
            {
                var columns = new StopWordsCol[sources.Length];
                var outputs = new string[sources.Length];
                for (int i = 0; i < sources.Length; i++)
                {
                    outputs[i] = GenerateColumnName(pipeline.Schema, sources[i], "StopWordsRemoverTransform");
                    scratch.Add(outputs[i]);
                    columns[i] = new StopWordsCol()
                    {
                        Source = sources[i],
                        Name = outputs[i],
                        Language = settings.StopwordsLanguage
                    };
                }

                pipeline = settings.StopWordsRemover.CreateComponent(host, pipeline, columns);
                return outputs;
            }

            // Builds the word-feature column from the given token columns; returns its name.
            string ExtractWordFeatures(string[] sources)
            {
                string name = GenerateColumnName(pipeline.Schema, OutputColumn, "WordExtractor");
                scratch.Add(name);
                pipeline = settings.WordExtractorFactory.Create(host, pipeline, new[]
                {
                    new ExtractorColumn() { Name = name, Source = sources, FriendlyNames = _inputColumns }
                });
                return name;
            }

            // Splits each source column into char tokens; returns the char-token column names.
            string[] TokenizeChars(string[] sources)
            {
                var columns = new CharTokenizeTransform.Column[sources.Length];
                var outputs = new string[sources.Length];
                for (int i = 0; i < sources.Length; i++)
                {
                    string name = GenerateColumnName(pipeline.Schema, sources[i], "CharTokenizer");
                    scratch.Add(name);
                    outputs[i] = name;
                    columns[i] = new CharTokenizeTransform.Column() { Source = sources[i], Name = name };
                }

                pipeline = new CharTokenizeTransform(host, new CharTokenizeTransform.Arguments() { Column = columns }, pipeline);
                return outputs;
            }

            // Builds the char-feature column from the given char-token columns; returns its name.
            string ExtractCharFeatures(string[] sources)
            {
                string name = GenerateColumnName(pipeline.Schema, OutputColumn, "CharExtractor");
                scratch.Add(name);
                pipeline = settings.CharExtractorFactory.Create(host, pipeline, new[]
                {
                    new ExtractorColumn() { Source = sources, FriendlyNames = _inputColumns, Name = name }
                });
                return name;
            }

            // Lp-normalizes whichever feature columns exist; returns the (possibly renamed) char/word columns.
            (string, string) LpNormalizeFeatures(string charCol, string wordCol)
            {
                var columns = new List<LpNormNormalizerTransform.Column>(2);
                if (charCol != null)
                {
                    string normalized = GenerateColumnName(pipeline.Schema, charCol, "LpCharNorm");
                    scratch.Add(normalized);
                    columns.Add(new LpNormNormalizerTransform.Column() { Source = charCol, Name = normalized });
                    charCol = normalized;
                }

                if (wordCol != null)
                {
                    string normalized = GenerateColumnName(pipeline.Schema, wordCol, "LpWordNorm");
                    scratch.Add(normalized);
                    columns.Add(new LpNormNormalizerTransform.Column() { Source = wordCol, Name = normalized });
                    wordCol = normalized;
                }

                if (columns.Count > 0)
                {
                    pipeline = new LpNormNormalizerTransform(host, new LpNormNormalizerTransform.Arguments()
                    {
                        NormKind = settings.LpNormalizerKind,
                        Column = columns.ToArray()
                    }, pipeline);
                }

                return (charCol, wordCol);
            }
        }