示例#1
0
        /// <summary>
        /// Builds one <see cref="ClassificationMetrics"/> per row of <paramref name="overallMetrics"/>,
        /// pairing rows with the confusion matrices produced from <paramref name="confusionMatrix"/>.
        /// </summary>
        /// <param name="env">Host environment used for assertions and for creating exceptions.</param>
        /// <param name="overallMetrics">
        /// Data view whose rows are read as <c>SerializationClass</c> (missing columns ignored);
        /// one result is produced per row.
        /// </param>
        /// <param name="confusionMatrix">
        /// Data view handed to <c>ConfusionMatrix.Create</c>; the resulting matrices are assigned in order.
        /// </param>
        /// <param name="confusionMatriceStartIndex">
        /// Number of leading rows of <paramref name="overallMetrics"/> that receive no confusion matrix
        /// (their <c>ConfusionMatrix</c> is <c>null</c>). The row at this index receives the first matrix,
        /// the following row the second, and so on.
        /// </param>
        /// <returns>The metrics, in the same order as the rows of <paramref name="overallMetrics"/>.</returns>
        /// <remarks>
        /// Throws the exception created by <c>env.Except</c> when <paramref name="overallMetrics"/> has no rows,
        /// or when there are fewer confusion matrices than rows at or after
        /// <paramref name="confusionMatriceStartIndex"/>.
        /// </remarks>
        internal static List<ClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix,
                                                                int confusionMatriceStartIndex = 0)
        {
            Contracts.AssertValue(env);
            env.AssertValue(overallMetrics);
            env.AssertValue(confusionMatrix);

            // Each row is materialized into its own SerializationClass instance (second argument false).
            // Reference-typed fields such as PerClassLogLoss are stored directly on the returned metrics,
            // so a reused row object could leave every result sharing one buffer that is overwritten per row.
            // NOTE(review): assumes the second argument is AsEnumerable's reuseRowObject flag — confirm.
            var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, false, ignoreMissingColumns: true);

            // One enumerator serves both the emptiness check and the iteration. It is disposed when done,
            // so the data view is not enumerated twice and no enumerator is left open.
            using (var metricsEnumerator = metricsEnumerable.GetEnumerator())
            {
                if (!metricsEnumerator.MoveNext())
                {
                    throw env.Except("The overall ClassificationMetrics didn't have any rows.");
                }

                var metrics = new List<ClassificationMetrics>();
                var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator();
                int index = 0;

                do
                {
                    var metric = metricsEnumerator.Current;

                    // Rows before confusionMatriceStartIndex get no matrix. Every later row consumes the
                    // next matrix, and running out is an error. Current is read only after a successful
                    // MoveNext, because its value before the first MoveNext is undefined.
                    bool hasConfusionMatrix = index++ >= confusionMatriceStartIndex;
                    if (hasConfusionMatrix && !confusionMatrices.MoveNext())
                    {
                        throw env.Except("Confusion matrices didn't have enough matrices.");
                    }

                    metrics.Add(new ClassificationMetrics()
                    {
                        AccuracyMicro = metric.AccuracyMicro,
                        AccuracyMacro = metric.AccuracyMacro,
                        LogLoss = metric.LogLoss,
                        LogLossReduction = metric.LogLossReduction,
                        TopKAccuracy = metric.TopKAccuracy,
                        PerClassLogLoss = metric.PerClassLogLoss,
                        ConfusionMatrix = hasConfusionMatrix ? confusionMatrices.Current : null,
                        RowTag = metric.RowTag,
                    });
                }
                while (metricsEnumerator.MoveNext());

                return metrics;
            }
        }
示例#2
0
        /// <summary>
        /// Builds a <see cref="BinaryClassificationMetrics"/> from the single row of
        /// <paramref name="overallMetrics"/> and the confusion matrix created from
        /// <paramref name="confusionMatrix"/>.
        /// </summary>
        /// <param name="env">Host environment used for assertions and for creating exceptions.</param>
        /// <param name="overallMetrics">
        /// Data view read as <c>SerializationClass</c> rows (missing columns ignored); must contain exactly one row.
        /// </param>
        /// <param name="confusionMatrix">Data view handed to <c>ConfusionMatrix.Create</c>.</param>
        /// <returns>The metrics copied from the single overall-metrics row.</returns>
        /// <remarks>
        /// Throws the exception created by <c>env.Except</c> when <paramref name="overallMetrics"/>
        /// has zero rows or more than one row.
        /// </remarks>
        internal static BinaryClassificationMetrics FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
        {
            Contracts.AssertValue(env);
            env.AssertValue(overallMetrics);
            env.AssertValue(confusionMatrix);

            var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);
            SerializationClass metrics;

            // Dispose the enumerator once the single-row check is done, so it is not left open.
            using (var enumerator = metricsEnumerable.GetEnumerator())
            {
                if (!enumerator.MoveNext())
                {
                    throw env.Except("The overall BinaryClassificationMetrics didn't have any rows.");
                }

                metrics = enumerator.Current;

                // Exactly one row is expected; a second row is reported rather than silently ignored.
                if (enumerator.MoveNext())
                {
                    throw env.Except("The overall BinaryClassificationMetrics contained more than 1 row.");
                }
            }

            return new BinaryClassificationMetrics()
            {
                Auc = metrics.Auc,
                Accuracy = metrics.Accuracy,
                PositivePrecision = metrics.PositivePrecision,
                PositiveRecall = metrics.PositiveRecall,
                NegativePrecision = metrics.NegativePrecision,
                NegativeRecall = metrics.NegativeRecall,
                LogLoss = metrics.LogLoss,
                LogLossReduction = metrics.LogLossReduction,
                Entropy = metrics.Entropy,
                F1Score = metrics.F1Score,
                Auprc = metrics.Auprc,
                ConfusionMatrix = ConfusionMatrix.Create(env, confusionMatrix),
            };
        }