예제 #1
0
        /// <summary>
        /// Builds the data view via <c>CreateAndSaveLoader</c> and writes it out (optionally only a prefix).
        /// </summary>
        /// <remarks>
        /// Steps, in order:
        /// 1. When <c>Args.Columns</c> is non-blank, keep only the comma-separated columns it names.
        /// 2. Create the saver: <c>Args.Saver</c> when supplied, otherwise a <see cref="TextSaver"/>
        ///    configured with <c>Args.Dense</c>.
        /// 3. Pick the column indices to write. Hidden columns are skipped unless <c>Args.KeepHidden</c>
        ///    is set. Columns the saver reports as unsavable are skipped with an info message.
        ///    The command fails when no column remains.
        /// 4. When <c>Args.Rows</c> is positive, keep only that many leading rows.
        /// 5. A <see cref="TextSaver"/> writes through its own <c>WriteData</c> utility. Any other saver
        ///    saves into memory, and the resulting text is logged on <paramref name="ch"/>.
        /// </remarks>
        /// <param name="ch">Channel for informational messages and for the non-text-saver output.</param>
        private void RunCore(IChannel ch)
        {
            Host.AssertValue(ch);
            IDataView view = CreateAndSaveLoader();

            // Optional user-supplied column subset, given as a comma-separated list of names.
            if (!string.IsNullOrWhiteSpace(Args.Columns))
            {
                string[] requested = Args.Columns.Split(new[] { ',' }, StringSplitOptions.RemoveEmptyEntries);
                var chooseArgs = new ChooseColumnsTransform.Arguments();
                chooseArgs.Column = requested
                    .Select(columnName => new ChooseColumnsTransform.Column() { Name = columnName })
                    .ToArray();
                if (Utils.Size(chooseArgs.Column) > 0)
                {
                    view = new ChooseColumnsTransform(Host, chooseArgs, view);
                }
            }

            IDataSaver saver;
            if (Args.Saver == null)
            {
                saver = new TextSaver(Host, new TextSaver.Arguments() { Dense = Args.Dense });
            }
            else
            {
                saver = Args.Saver.CreateComponent(Host);
            }

            // Collect the indices of the columns that will actually be written.
            var writable = new List<int>();
            var schema = view.Schema;
            for (int col = 0; col < schema.ColumnCount; col++)
            {
                bool skipBecauseHidden = !Args.KeepHidden && schema.IsHidden(col);
                if (skipBecauseHidden)
                {
                    continue;
                }

                if (!saver.IsColumnSavable(schema.GetColumnType(col)))
                {
                    ch.Info(MessageSensitivity.Schema, "The column '{0}' will not be written as it has unsavable column type.", schema.GetColumnName(col));
                    continue;
                }

                writable.Add(col);
            }
            Host.NotSensitive().Check(writable.Count > 0, "No valid columns to save");

            // Limit the output to the first Args.Rows rows when requested.
            if (Args.Rows > 0)
            {
                var takeArgs = new SkipTakeFilter.TakeArguments() { Count = Args.Rows };
                view = SkipTakeFilter.Create(Host, takeArgs, view);
            }

            int[] selected = writable.ToArray();

            // A TextSaver has a dedicated WriteData utility for this purpose; use it directly.
            if (saver is TextSaver textSaver)
            {
                textSaver.WriteData(view, true, selected);
                return;
            }

            // Any other saver: serialize into an in-memory buffer, then log the buffer's text.
            using (var buffer = new MemoryStream())
            {
                using (Stream subset = new SubsetStream(buffer))
                {
                    saver.SaveData(subset, view, selected);
                }

                buffer.Seek(0, SeekOrigin.Begin);
                using (var reader = new StreamReader(buffer))
                {
                    string text = reader.ReadToEnd();
                    ch.Info(MessageSensitivity.UserData | MessageSensitivity.Schema, text);
                }
            }
        }
        /// <summary>
        /// Wraps the evaluator's per-instance data in a <see cref="ChooseColumnsTransform"/>.
        /// The transform keeps only the columns Maml outputs, in this order:
        /// the fold-index column (when present, e.g. after cross-validation), an "Instance" column,
        /// the weight column (when present), and the evaluator's own per-instance columns.
        /// The "Instance" column is copied from the name role when one exists; otherwise it is
        /// generated by a counter.
        /// </summary>
        /// <param name="perInst">Per-instance data and its role mapping, as computed by the evaluator.</param>
        /// <returns>The result of <c>GetPerInstanceMetricsCore</c> applied to the wrapped view.</returns>
        private IDataView WrapPerInstance(RoleMappedData perInst)
        {
            var roles = perInst.Schema;
            IDataView view = perInst.Data;
            var keep = new List<ChooseColumnsTransform.Column>();

            // Cross-validation output carries a fold-index column; keep it when it is there.
            if (roles.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out _))
            {
                keep.Add(new ChooseColumnsTransform.Column() { Source = MetricKinds.ColumnNames.FoldIndex });
            }

            // Maml always outputs an "Instance" column. Copy it from the name role,
            // or synthesize one with a counter when there is no name role.
            var nameRole = roles.Name;
            if (nameRole != null)
            {
                keep.Add(new ChooseColumnsTransform.Column() { Source = nameRole.Name, Name = "Instance" });
            }
            else
            {
                var counterArgs = new GenerateNumberTransform.Arguments()
                {
                    Column = new[] { new GenerateNumberTransform.Column() { Name = "Instance" } },
                    UseCounter = true
                };
                view = new GenerateNumberTransform(Host, counterArgs, view);
                keep.Add(new ChooseColumnsTransform.Column() { Name = "Instance" });
            }

            // The weight column is passed through when the data has one.
            var weightRole = roles.Weight;
            if (weightRole != null)
            {
                keep.Add(new ChooseColumnsTransform.Column() { Name = weightRole.Name });
            }

            // Finally, the evaluator-specific per-instance columns.
            foreach (var evaluatorColumn in GetPerInstanceColumnsToSave(roles))
            {
                keep.Add(new ChooseColumnsTransform.Column() { Name = evaluatorColumn });
            }

            var chooseArgs = new ChooseColumnsTransform.Arguments() { Column = keep.ToArray() };
            IDataView chosen = new ChooseColumnsTransform(Host, chooseArgs, view);
            return GetPerInstanceMetricsCore(chosen, roles);
        }
        /// <summary>
        /// Prints one fold's anomaly-detection results on <paramref name="ch"/>.
        /// </summary>
        /// <remarks>
        /// 1. If the top-k view has any rows, emit them as a table of instance name, anomaly score and label.
        /// 2. Read the anomaly count from the single non-stratified row of the overall metrics.
        ///    It throws if more than one such row exists.
        /// 3. Select the detection-rate, threshold and AUC columns from the overall metrics.
        ///    The detection-rate columns are renamed to include k, p and the anomaly count.
        ///    The result is printed via <c>MetricWriter.GetPerFoldResults</c>.
        /// </remarks>
        /// <param name="ch">Channel that receives the printed text.</param>
        /// <param name="metrics">
        /// Metric views keyed by kind. Must contain <c>AnomalyDetectionEvaluator.TopKResults</c> and
        /// <c>MetricKinds.OverallMetrics</c>; a missing entry or column raises via <c>Host.Except</c>.
        /// </param>
        protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
        {
            if (!metrics.TryGetValue(AnomalyDetectionEvaluator.TopKResults, out IDataView topK))
            {
                throw Host.Except("Did not find the top-k results data view");
            }

            var report = new StringBuilder();
            using (var cursor = topK.GetRowCursor(col => true))
            {
                // Resolves a column index in the top-k view, failing with the given message when absent.
                int ColumnOf(string column, string missingMessage)
                {
                    if (!topK.Schema.TryGetColumnIndex(column, out int found))
                    {
                        throw Host.Except(missingMessage);
                    }
                    return found;
                }

                var getInstance = cursor.GetGetter<ReadOnlyMemory<char>>(
                    ColumnOf(AnomalyDetectionEvaluator.TopKResultsColumns.Instance, "Data view does not contain the 'Instance' column"));
                var getScore = cursor.GetGetter<float>(
                    ColumnOf(AnomalyDetectionEvaluator.TopKResultsColumns.AnomalyScore, "Data view does not contain the 'Anomaly Score' column"));
                var getLabel = cursor.GetGetter<float>(
                    ColumnOf(AnomalyDetectionEvaluator.TopKResultsColumns.Label, "Data view does not contain the 'Label' column"));

                // The header is written lazily so an empty top-k view produces no output at all.
                bool headerWritten = false;
                while (cursor.MoveNext())
                {
                    if (!headerWritten)
                    {
                        report.AppendFormat("{0} Top-scored Results", _topScored).AppendLine();
                        report.AppendLine("=================================================");
                        report.AppendLine("Instance    Anomaly Score     Labeled");
                        headerWritten = true;
                    }

                    ReadOnlyMemory<char> instance = default(ReadOnlyMemory<char>);
                    float score = 0;
                    float label = 0;
                    getInstance(ref instance);
                    getScore(ref score);
                    getLabel(ref label);
                    report.AppendFormat("{0,-10}{1,12:G4}{2,12}", instance, score, label).AppendLine();
                }
            }

            if (report.Length > 0)
            {
                ch.Info(MessageSensitivity.UserData, report.ToString());
            }

            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out IDataView overall))
            {
                throw Host.Except("No overall metrics found");
            }

            // Locate the anomaly-count column and the (optional) stratification columns.
            if (!overall.Schema.TryGetColumnIndex(AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies, out int numAnomIndex))
            {
                throw Host.Except("Could not find the 'NumAnomalies' column");
            }

            bool hasStrat = overall.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out int stratCol);
            bool hasStratVals = overall.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out _);
            Contracts.Assert(hasStrat == hasStratVals);

            long numAnomalies = 0;
            using (var cursor = overall.GetRowCursor(col => col == numAnomIndex || (hasStrat && col == stratCol)))
            {
                var getNumAnomalies = cursor.GetGetter<long>(numAnomIndex);
                ValueGetter<uint> getStrat = hasStrat
                    ? RowCursorUtils.GetGetterAs<uint>(cursor.Schema.GetColumnType(stratCol), cursor, stratCol)
                    : null;

                bool sawUnstratified = false;
                while (cursor.MoveNext())
                {
                    uint strat = 0;
                    if (getStrat != null)
                    {
                        getStrat(ref strat);
                    }

                    // Only the row with a zero strat value (or any row when unstratified) supplies the count.
                    if (strat != 0)
                    {
                        continue;
                    }

                    if (sawUnstratified)
                    {
                        throw Host.Except("Found multiple non-stratified rows in overall results data view");
                    }
                    sawUnstratified = true;
                    getNumAnomalies(ref numAnomalies);
                }
            }

            // Detection-rate columns are renamed so their headers mention k, p and the anomaly count.
            var chooseArgs = new ChooseColumnsTransform.Arguments()
            {
                Column = new[]
                {
                    new ChooseColumnsTransform.Column()
                    {
                        Name = string.Format(FoldDrAtKFormat, _k),
                        Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtK
                    },
                    new ChooseColumnsTransform.Column()
                    {
                        Name = string.Format(FoldDrAtPFormat, _p),
                        Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtPFpr
                    },
                    new ChooseColumnsTransform.Column()
                    {
                        Name = string.Format(FoldDrAtNumAnomaliesFormat, numAnomalies),
                        Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtNumPos
                    },
                    new ChooseColumnsTransform.Column() { Name = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK },
                    new ChooseColumnsTransform.Column() { Name = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP },
                    new ChooseColumnsTransform.Column() { Name = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos },
                    new ChooseColumnsTransform.Column() { Name = BinaryClassifierEvaluator.Auc }
                }
            };

            IDataView foldView = new ChooseColumnsTransform(Host, chooseArgs, overall);
            ch.Info(MetricWriter.GetPerFoldResults(Host, foldView, out string _));
        }