예제 #1
0
        /// <summary>
        /// Prints the overall metrics after removing the anomaly-count and threshold columns from them.
        /// </summary>
        /// <param name="ch">Channel used for the non-empty assertion and for raising the failure.</param>
        /// <param name="filename">Passed through unchanged to <c>MetricWriter.PrintOverallMetrics</c>.</param>
        /// <param name="metrics">Array of metric dictionaries; asserted to be non-empty. Its length is
        /// forwarded to the metric writer.</param>
        protected override void PrintOverallResultsCore(IChannel ch, string filename, Dictionary<string, IDataView>[] metrics)
        {
            ch.AssertNonEmpty(metrics);

            if (!TryGetOverallMetrics(metrics, out IDataView overallMetrics))
            {
                throw ch.Except("No overall metrics found");
            }

            // These columns are stripped from the overall metrics before they are printed.
            var dropArgs = new DropColumnsTransform.Arguments
            {
                Column = new[]
                {
                    AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies,
                    AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK,
                    AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP,
                    AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos
                }
            };

            IDataView printable = new DropColumnsTransform(Host, dropArgs, overallMetrics);
            MetricWriter.PrintOverallMetrics(Host, ch, filename, printable, metrics.Length);
        }
예제 #2
0
        /// <summary>
        /// Builds the chain of text-processing transforms over <paramref name="input"/> and wraps the
        /// resulting view in a <c>Transformer</c>.
        /// </summary>
        /// <remarks>
        /// Stages are applied in a fixed order, each one only when the corresponding setting on
        /// <c>TransformApplierParams</c> asks for it: initial concat, text normalization, word
        /// tokenization, stop-word removal, word feature extraction, token output, char tokenization and
        /// char feature extraction, vector normalization, final concat into <c>OutputColumn</c>.
        /// Every intermediate column name is recorded in <c>tempCols</c> and dropped at the end.
        /// </remarks>
        /// <param name="input">The data to build the chain on; checked for null.</param>
        /// <returns>A new <c>Transformer</c> constructed from the host, the input and the final view.</returns>
        public ITransformer Fit(IDataView input)
        {
            var h = _host;

            h.CheckValue(input, nameof(input));

            var tparams = new TransformApplierParams(this);

            // textCols always names the columns holding the current (possibly normalized) text;
            // wordTokCols / charTokCols hold token columns once the respective tokenizer has run;
            // wordFeatureCol / charFeatureCol hold the extractor outputs; view is the chain built so far.
            string[]      textCols       = _inputColumns;
            string[]      wordTokCols    = null;
            string[]      charTokCols    = null;
            string        wordFeatureCol = null;
            string        charFeatureCol = null;
            List <string> tempCols       = new List <string>();
            IDataView     view           = input;

            // Optionally merge several input columns into one temporary column, which then becomes
            // the only text column for the following stages.
            if (tparams.NeedInitialSourceColumnConcatTransform && textCols.Length > 1)
            {
                var srcCols = textCols;
                textCols = new[] { GenerateColumnName(input.Schema, OutputColumn, "InitialConcat") };
                tempCols.Add(textCols[0]);
                view = new ConcatTransform(h, textCols[0], srcCols).Transform(view);
            }

            // Text normalization: one temporary output column per text column; afterwards textCols
            // refers to the normalized columns.
            if (tparams.NeedsNormalizeTransform)
            {
                var      xfCols  = new TextNormalizerCol[textCols.Length];
                string[] dstCols = new string[textCols.Length];
                for (int i = 0; i < textCols.Length; i++)
                {
                    dstCols[i] = GenerateColumnName(view.Schema, textCols[i], "TextNormalizer");
                    tempCols.Add(dstCols[i]);
                    xfCols[i] = new TextNormalizerCol()
                    {
                        Source = textCols[i], Name = dstCols[i]
                    };
                }

                view = new TextNormalizerTransform(h,
                                                   new TextNormalizerArgs()
                {
                    Column           = xfCols,
                    KeepDiacritics   = tparams.KeepDiacritics,
                    KeepNumbers      = tparams.KeepNumbers,
                    KeepPunctuations = tparams.KeepPunctuations,
                    TextCase         = tparams.TextCase
                }, view);

                textCols = dstCols;
            }

            // Word tokenization: one temporary token column per text column, recorded in wordTokCols.
            if (tparams.NeedsWordTokenizationTransform)
            {
                var xfCols = new DelimitedTokenizeTransform.Column[textCols.Length];
                wordTokCols = new string[textCols.Length];
                for (int i = 0; i < textCols.Length; i++)
                {
                    var col = new DelimitedTokenizeTransform.Column();
                    col.Source = textCols[i];
                    col.Name   = GenerateColumnName(view.Schema, textCols[i], "WordTokenizer");

                    xfCols[i] = col;

                    wordTokCols[i] = col.Name;
                    tempCols.Add(col.Name);
                }

                view = new DelimitedTokenizeTransform(h, new DelimitedTokenizeTransform.Arguments()
                {
                    Column = xfCols
                }, view);
            }

            // Stop-word removal operates on the word token columns and replaces wordTokCols with its outputs.
            if (tparams.NeedsRemoveStopwordsTransform)
            {
                // NOTE(review): this is only an assertion; presumably TransformApplierParams guarantees that
                // word tokenization is enabled whenever stop-word removal is - confirm.
                Contracts.Assert(wordTokCols != null, "StopWords transform requires that word tokenization has been applied to the input text.");
                var xfCols  = new StopWordsCol[wordTokCols.Length];
                var dstCols = new string[wordTokCols.Length];
                for (int i = 0; i < wordTokCols.Length; i++)
                {
                    var col = new StopWordsCol();
                    col.Source = wordTokCols[i];
                    col.Name   = GenerateColumnName(view.Schema, wordTokCols[i], "StopWordsRemoverTransform");
                    dstCols[i] = col.Name;
                    tempCols.Add(col.Name);
                    col.Language = tparams.StopwordsLanguage;

                    xfCols[i] = col;
                }
                view        = tparams.StopWordsRemover.CreateComponent(h, view, xfCols);
                wordTokCols = dstCols;
            }

            // Word feature extraction over all word token columns into a single temporary column.
            if (tparams.WordExtractorFactory != null)
            {
                var dstCol = GenerateColumnName(view.Schema, OutputColumn, "WordExtractor");
                tempCols.Add(dstCol);
                // NOTE(review): wordTokCols is still null here if the tokenization stage was skipped;
                // presumably a word extractor implies word tokenization - confirm.
                view = tparams.WordExtractorFactory.Create(h, view, new[] {
                    new ExtractorColumn()
                    {
                        Name          = dstCol,
                        Source        = wordTokCols,
                        FriendlyNames = _inputColumns
                    }
                });
                wordFeatureCol = dstCol;
            }

            // Optionally expose the tokens (or, without tokenization, the text columns) as an extra output
            // column. Its name is not added to tempCols, so it is not dropped at the end.
            if (tparams.OutputTextTokens)
            {
                string[] srcCols = wordTokCols ?? textCols;
                view = new ConcatTransform(h, string.Format(TransformedTextColFormat, OutputColumn), srcCols).Transform(view);
            }

            if (tparams.CharExtractorFactory != null)
            {
                // Char tokenization. The source is the stop-word-filtered word tokens when stop-word removal
                // is enabled, and the text columns otherwise.
                {
                    var srcCols = tparams.NeedsRemoveStopwordsTransform ? wordTokCols : textCols;
                    charTokCols = new string[srcCols.Length];
                    var xfCols = new CharTokenizeTransform.Column[srcCols.Length];
                    for (int i = 0; i < srcCols.Length; i++)
                    {
                        var col = new CharTokenizeTransform.Column();
                        col.Source = srcCols[i];
                        col.Name   = GenerateColumnName(view.Schema, srcCols[i], "CharTokenizer");
                        tempCols.Add(col.Name);
                        charTokCols[i] = col.Name;
                        xfCols[i]      = col;
                    }
                    view = new CharTokenizeTransform(h, new CharTokenizeTransform.Arguments()
                    {
                        Column = xfCols
                    }, view);
                }

                // Char feature extraction over all char token columns into a single temporary column.
                {
                    charFeatureCol = GenerateColumnName(view.Schema, OutputColumn, "CharExtractor");
                    tempCols.Add(charFeatureCol);
                    view = tparams.CharExtractorFactory.Create(h, view, new[] {
                        new ExtractorColumn()
                        {
                            Source        = charTokCols,
                            FriendlyNames = _inputColumns,
                            Name          = charFeatureCol
                        }
                    });
                }
            }

            // Vector normalization of whichever feature columns exist. The feature column variables are
            // redirected to the normalized temporary columns. The stage is gated on VectorNormalizer,
            // while the kind handed to the transform is LpNormalizerKind.
            if (tparams.VectorNormalizer != TextNormKind.None)
            {
                var xfCols = new List <LpNormNormalizerTransform.Column>(2);
                if (charFeatureCol != null)
                {
                    var dstCol = GenerateColumnName(view.Schema, charFeatureCol, "LpCharNorm");
                    tempCols.Add(dstCol);
                    xfCols.Add(new LpNormNormalizerTransform.Column()
                    {
                        Source = charFeatureCol,
                        Name   = dstCol
                    });
                    charFeatureCol = dstCol;
                }

                if (wordFeatureCol != null)
                {
                    var dstCol = GenerateColumnName(view.Schema, wordFeatureCol, "LpWordNorm");
                    tempCols.Add(dstCol);
                    xfCols.Add(new LpNormNormalizerTransform.Column()
                    {
                        Source = wordFeatureCol,
                        Name   = dstCol
                    });
                    wordFeatureCol = dstCol;
                }
                if (xfCols.Count > 0)
                {
                    view = new LpNormNormalizerTransform(h, new LpNormNormalizerTransform.Arguments()
                    {
                        NormKind = tparams.LpNormalizerKind,
                        Column   = xfCols.ToArray()
                    }, view);
                }
            }

            // Final concat of the feature columns into OutputColumn. Each pair is (tag, column name).
            {
                var srcTaggedCols = new List <KeyValuePair <string, string> >(2);
                if (charFeatureCol != null && wordFeatureCol != null)
                {
                    // If we're producing both char and word grams, then we need to disambiguate
                    // between them (e.g. the word 'a' vs. the char gram 'a').
                    srcTaggedCols.Add(new KeyValuePair <string, string>("Char", charFeatureCol));
                    srcTaggedCols.Add(new KeyValuePair <string, string>("Word", wordFeatureCol));
                }
                else
                {
                    // Otherwise, simply use the slot names, omitting the original source column names
                    // entirely. For the Concat transform setting the Key == Value of the TaggedColumn
                    // KVP signals this intent.
                    Contracts.Assert(charFeatureCol != null || wordFeatureCol != null || tparams.OutputTextTokens);
                    if (charFeatureCol != null)
                    {
                        srcTaggedCols.Add(new KeyValuePair <string, string>(charFeatureCol, charFeatureCol));
                    }
                    else if (wordFeatureCol != null)
                    {
                        srcTaggedCols.Add(new KeyValuePair <string, string>(wordFeatureCol, wordFeatureCol));
                    }
                }
                // With neither feature column present (token output only), no OutputColumn is created.
                if (srcTaggedCols.Count > 0)
                {
                    // The pairs are handed over in (Value, Key) order, i.e. column name first, tag second.
                    view = new ConcatTransform(h, new ConcatTransform.ColumnInfo(OutputColumn,
                                                                                 srcTaggedCols.Select(kvp => (kvp.Value, kvp.Key))))
                           .Transform(view);
                }
            }

            // Remove every intermediate column created above.
            view = new DropColumnsTransform(h,
                                            new DropColumnsTransform.Arguments()
            {
                Column = tempCols.ToArray()
            }, view);

            return(new Transformer(_host, input, view));
        }
        /// <summary>
        /// Appends the rows of the given data views into a single data view, after removing hidden
        /// columns and reconciling columns that differ between the views.
        /// </summary>
        /// <remarks>
        /// The first view is the baseline: its slot names are what later views are verified against,
        /// and its columns with key names are the ones reconciled and converted with a KeyToValue
        /// transform. Vector columns that fail <c>VerifyVectorColumnsMatch</c> in any later view are
        /// replaced through <c>AddVarLengthColumn</c>, and a warning listing them is written to
        /// <paramref name="ch"/>.
        /// </remarks>
        /// <param name="foldDataViews">The data views to append; enumerated once, in order.</param>
        /// <param name="ch">Channel that receives the warning about mismatching vector columns.</param>
        /// <returns>The view produced by <c>AppendRowsDataView.Create</c> over the processed views.</returns>
        private IDataView AppendPerInstanceDataViews(IEnumerable<IDataView> foldDataViews, IChannel ch)
        {
            // Make sure there are no variable size vector columns.
            // This is a dictionary from the column name to its vector size, recorded the first time
            // the column name is encountered.
            var vectorSizes = new Dictionary<string, int>();
            // Slot names of the first view's vector columns; later views are verified against them.
            var firstDvSlotNames = new Dictionary<string, VBuffer<DvText>>();
            // Non-vector and vector columns of the first view that have key names.
            var firstDvKeyColumns = new List<string>();
            var firstDvVectorKeyColumns = new List<string>();
            // Vector columns for which verification failed in at least one view. Each name is stored
            // only once: the conversion further down drops the original column by name, so a second
            // pass for the same name would look the column up after it has already been removed.
            var variableSizeVectorColumnNames = new List<string>();
            var list = new List<IDataView>();
            int dvNumber = 0;

            foreach (var dv in foldDataViews)
            {
                var hidden = new List<int>();
                for (int i = 0; i < dv.Schema.ColumnCount; i++)
                {
                    if (dv.Schema.IsHidden(i))
                    {
                        hidden.Add(i);
                        continue;
                    }

                    var type = dv.Schema.GetColumnType(i);
                    var name = dv.Schema.GetColumnName(i);
                    if (type.IsVector)
                    {
                        if (dvNumber == 0)
                        {
                            if (dv.Schema.HasKeyNames(i, type.ItemType.KeyCount))
                            {
                                firstDvVectorKeyColumns.Add(name);
                            }
                            // Store the slot names of the 1st idv and use them as baseline.
                            if (dv.Schema.HasSlotNames(i, type.VectorSize))
                            {
                                VBuffer<DvText> slotNames = default(VBuffer<DvText>);
                                dv.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, i, ref slotNames);
                                firstDvSlotNames.Add(name, slotNames);
                            }
                        }

                        int cachedSize;
                        if (vectorSizes.TryGetValue(name, out cachedSize))
                        {
                            VBuffer<DvText> slotNames;
                            // In the event that no slot names were recorded here, then slotNames will be
                            // the default, length 0 vector.
                            firstDvSlotNames.TryGetValue(name, out slotNames);
                            // The verification is still run for every view; only the bookkeeping is de-duplicated.
                            bool matches = VerifyVectorColumnsMatch(cachedSize, i, dv, type, ref slotNames);
                            if (!matches && !variableSizeVectorColumnNames.Contains(name))
                            {
                                variableSizeVectorColumnNames.Add(name);
                            }
                        }
                        else
                        {
                            vectorSizes.Add(name, type.VectorSize);
                        }
                    }
                    else if (dvNumber == 0 && dv.Schema.HasKeyNames(i, type.KeyCount))
                    {
                        // The label column can be a key. Reconcile the key values, and wrap with a KeyToValue transform.
                        firstDvKeyColumns.Add(name);
                    }
                }

                // Strip the hidden columns from this view before it is appended.
                var idv = dv;
                if (hidden.Count > 0)
                {
                    var args = new ChooseColumnsByIndexTransform.Arguments();
                    args.Drop  = true;
                    args.Index = hidden.ToArray();
                    idv        = new ChooseColumnsByIndexTransform(Host, args, idv);
                }
                list.Add(idv);
                dvNumber++;
            }

            // Fast path: nothing to reconcile, append the views as they are.
            // NOTE(review): firstDvVectorKeyColumns is not part of this check, so vector key columns are
            // only reconciled when one of the other two lists is non-empty - confirm this is intended.
            if (variableSizeVectorColumnNames.Count == 0 && firstDvKeyColumns.Count == 0)
            {
                return AppendRowsDataView.Create(Host, null, list.ToArray());
            }

            var views = list.ToArray();

            foreach (var keyCol in firstDvKeyColumns)
            {
                EvaluateUtils.ReconcileKeyValues(Host, views, keyCol);
            }
            foreach (var vectorKeyCol in firstDvVectorKeyColumns)
            {
                EvaluateUtils.ReconcileVectorKeyValues(Host, views, vectorKeyCol);
            }

            // Applies a KeyToValue transform to every key column (scalar and vector) of a view and then
            // drops the columns reported by FindHiddenColumns for that column name.
            Func<IDataView, IDataView> keyToValue =
                idv =>
                {
                    foreach (var keyCol in firstDvKeyColumns.Concat(firstDvVectorKeyColumns))
                    {
                        idv = new KeyToValueTransform(Host, new KeyToValueTransform.Arguments()
                        {
                            Column = new[] { new KeyToValueTransform.Column() { Name = keyCol } }
                        }, idv);
                        var hidden = FindHiddenColumns(idv.Schema, keyCol);
                        idv = new ChooseColumnsByIndexTransform(Host, new ChooseColumnsByIndexTransform.Arguments()
                        {
                            Drop = true, Index = hidden.ToArray()
                        }, idv);
                    }
                    return idv;
                };

            // For every mismatching vector column: add its replacement via AddVarLengthColumn and drop
            // the original column.
            Func<IDataView, IDataView> replaceWithVarLengthCols =
                idv =>
                {
                    foreach (var variableSizeVectorColumnName in variableSizeVectorColumnNames)
                    {
                        // NOTE(review): the lookup result is not checked; if the column were missing from
                        // this view, index would stay 0 and the type of column 0 would be used.
                        int index;
                        idv.Schema.TryGetColumnIndex(variableSizeVectorColumnName, out index);
                        var type = idv.Schema.GetColumnType(index);

                        idv = Utils.MarshalInvoke(AddVarLengthColumn<int>, type.ItemType.RawType, Host, idv,
                                                  variableSizeVectorColumnName, type);

                        // Drop the old column that does not have variable length.
                        idv = new DropColumnsTransform(Host, new DropColumnsTransform.Arguments()
                        {
                            Column = new[] { variableSizeVectorColumnName }
                        }, idv);
                    }
                    return idv;
                };

            if (variableSizeVectorColumnNames.Count > 0)
            {
                ch.Warning("Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-Folds results.", string.Join(", ", variableSizeVectorColumnNames));
            }
            return AppendRowsDataView.Create(Host, null, views.Select(keyToValue).Select(replaceWithVarLengthCols).ToArray());
        }
        /// <summary>
        /// Builds a chain of transforms that replaces missing values in the requested columns and, for
        /// columns that ask for it, concatenates a missing-value indicator to the replaced values.
        /// </summary>
        /// <remarks>
        /// Columns without an indicator are handled by the NA-replace transform alone. For the others,
        /// an indicator column and a replacement column are written under temporary names, the indicator
        /// is converted to the item type of the source column when the conversion is not the identity,
        /// both are concatenated into the requested output column, and the temporary columns are dropped.
        /// </remarks>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="args">Transform arguments; must contain at least one column.</param>
        /// <param name="input">The input data; checked for null.</param>
        /// <returns>The last transform of the chain.</returns>
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            // NOTE(review): the host is registered under the name "Categorical" - confirm that this is
            // the intended name for this transform.
            var host = env.Register("Categorical");

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));
            host.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column));

            int columnCount = args.Column.Length;

            // One temporary name per requested column, whether or not that column ends up using it.
            var tmpIsMissingColNames = input.Schema.GetTempColumnNames(columnCount, "IsMissing");
            var tmpReplaceColNames   = input.Schema.GetTempColumnNames(columnCount, "Replace");

            var replaceCols     = new List<NAReplaceTransform.ColumnInfo>();
            var naIndicatorCols = new List<NAIndicatorTransform.Column>();
            var naConvCols      = new List<ConvertTransform.Column>();
            var concatCols      = new List<ConcatTransform.TaggedColumn>();
            var dropCols        = new List<string>();

            for (int i = 0; i < columnCount; i++)
            {
                var column = args.Column[i];

                // Per-column settings fall back to the transform-wide ones.
                var mode   = (NAReplaceTransform.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith);
                var bySlot = column.ImputeBySlot ?? args.ImputeBySlot;

                bool wantsIndicator = column.ConcatIndicator ?? args.Concat;
                if (!wantsIndicator)
                {
                    // Plain replacement straight into the requested output column.
                    replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, column.Name, mode, bySlot));
                    continue;
                }

                // Check that the indicator column has a type that can be converted to the NAReplaceTransform output type,
                // so that they can be concatenated.
                if (!input.Schema.TryGetColumnIndex(column.Source, out int inputCol))
                {
                    throw host.Except("Column '{0}' does not exist", column.Source);
                }

                var replaceType = input.Schema.GetColumnType(inputCol);
                bool convertible = Conversions.Instance.TryGetStandardConversion(
                    BoolType.Instance, replaceType.ItemType, out Delegate conv, out bool identity);
                if (!convertible)
                {
                    throw host.Except("Cannot concatenate indicator column of type '{0}' to input column of type '{1}'",
                                      BoolType.Instance, replaceType.ItemType);
                }

                // Temporary names for the NAReplaceTransform and NAIndicatorTransform output columns.
                var isMissingName   = tmpIsMissingColNames[i];
                var replacementName = tmpReplaceColNames[i];

                // Indicator column computed from the source column.
                naIndicatorCols.Add(new NAIndicatorTransform.Column
                {
                    Name   = isMissingName,
                    Source = column.Source
                });

                // Convert the indicator in place when its type differs from the source item type.
                if (!identity)
                {
                    naConvCols.Add(new ConvertTransform.Column
                    {
                        Name       = isMissingName,
                        Source     = isMissingName,
                        ResultType = replaceType.ItemType.RawKind
                    });
                }

                // Replacement goes to the temporary column; the concat below produces the real output.
                replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, replacementName, mode, bySlot));

                // The tags differ between vector and scalar sources: a vector source is tagged with the
                // temporary column name itself and a bare "IsMissing", a scalar source with the source
                // column name and "IsMissing.<source>".
                string replacedTag;
                string indicatorTag;
                if (replaceType.IsVector)
                {
                    replacedTag  = replacementName;
                    indicatorTag = "IsMissing";
                }
                else
                {
                    replacedTag  = column.Source;
                    indicatorTag = string.Format("IsMissing.{0}", column.Source);
                }

                concatCols.Add(new ConcatTransform.TaggedColumn
                {
                    Name   = column.Name,
                    Source = new[]
                    {
                        new KeyValuePair<string, string>(replacedTag, replacementName),
                        new KeyValuePair<string, string>(indicatorTag, isMissingName)
                    }
                });

                // Both temporary columns are removed at the end.
                dropCols.Add(isMissingName);
                dropCols.Add(replacementName);
            }

            IDataTransform output = null;

            // Create the indicator columns.
            if (naIndicatorCols.Count > 0)
            {
                var indicatorArgs = new NAIndicatorTransform.Arguments
                {
                    Column = naIndicatorCols.ToArray()
                };
                output = NAIndicatorTransform.Create(host, indicatorArgs, input);
            }

            // Convert the indicator columns to the correct type so that they can be concatenated to the NAReplace outputs.
            if (naConvCols.Count > 0)
            {
                host.AssertValue(output);
                var convertArgs = new ConvertTransform.Arguments
                {
                    Column = naConvCols.ToArray()
                };
                output = new ConvertTransform(host, convertArgs, output);
            }

            // Create the NAReplace transform on top of whatever has been built so far.
            output = NAReplaceTransform.Create(env, output ?? input, replaceCols.ToArray());

            // Concat the NAReplaceTransform output and the NAIndicatorTransform output.
            if (naIndicatorCols.Count > 0)
            {
                var concatArgs = new ConcatTransform.TaggedArguments
                {
                    Column = concatCols.ToArray()
                };
                output = ConcatTransform.Create(host, concatArgs, output);
            }

            // Finally, drop the temporary indicator and replacement columns.
            if (dropCols.Count > 0)
            {
                var dropArgs = new DropColumnsTransform.Arguments
                {
                    Column = dropCols.ToArray()
                };
                output = new DropColumnsTransform(host, dropArgs, output);
            }

            return output;
        }