/// <summary>
/// Builds a single <see cref="ClassificationMetrics"/> from a one-row overall-metrics view
/// and its companion confusion-matrix view.
/// </summary>
/// <param name="env">Host environment used for assertions and for creating exceptions.</param>
/// <param name="overallMetrics">View read as <c>SerializationClass</c> rows; must contain exactly one row.</param>
/// <param name="confusionMatrix">View passed to <c>ConfusionMatrix.Create</c> to build the confusion matrix.</param>
/// <returns>The metrics copied from the single overall row plus the created confusion matrix.</returns>
/// <remarks>Throws the exception produced by <c>env.Except</c> when the overall view has zero rows or more than one row.</remarks>
internal static ClassificationMetrics FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
        {
            Contracts.AssertValue(env);
            env.AssertValue(overallMetrics);
            env.AssertValue(confusionMatrix);

            var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);

            SerializationClass metrics;
            // Dispose the enumerator so the cursor it opens over the data view is released.
            using (var enumerator = metricsEnumerable.GetEnumerator())
            {
                if (!enumerator.MoveNext())
                {
                    throw env.Except("The overall ClassificationMetrics didn't have any rows.");
                }

                metrics = enumerator.Current;

                // Exactly one row is expected; a second row means the caller passed the wrong view.
                if (enumerator.MoveNext())
                {
                    throw env.Except("The overall ClassificationMetrics contained more than 1 row.");
                }
            }

            return new ClassificationMetrics()
            {
                AccuracyMicro = metrics.AccuracyMicro,
                AccuracyMacro = metrics.AccuracyMacro,
                LogLoss = metrics.LogLoss,
                LogLossReduction = metrics.LogLossReduction,
                TopKAccuracy = metrics.TopKAccuracy,
                PerClassLogLoss = metrics.PerClassLogLoss,
                ConfusionMatrix = ConfusionMatrix.Create(env, confusionMatrix)
            };
        }
示例#2
0
        /// <summary>
        /// Builds one <see cref="BinaryClassificationMetrics"/> per row of the overall-metrics view,
        /// pairing rows with confusion matrices from the confusion-matrix view.
        /// </summary>
        /// <param name="env">Host environment used for assertions and for creating exceptions.</param>
        /// <param name="overallMetrics">View read as <c>SerializationClass</c> rows; must contain at least one row.</param>
        /// <param name="confusionMatrix">View passed to <c>ConfusionMatrix.Create</c>; its matrices are consumed in order.</param>
        /// <param name="confusionMatriceStartIndex">
        /// Zero-based index of the first overall row that has a confusion matrix. Rows before it get a
        /// <c>null</c> <c>ConfusionMatrix</c>. NOTE(review): presumably those leading rows are aggregate
        /// rows (e.g. averages); confirm against callers.
        /// </param>
        /// <returns>A list with one entry per overall row, in row order.</returns>
        /// <remarks>
        /// Throws the exception produced by <c>env.Except</c> when the overall view is empty, or when
        /// there are fewer confusion matrices than rows at or after <paramref name="confusionMatriceStartIndex"/>.
        /// </remarks>
        internal static List<BinaryClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix, int confusionMatriceStartIndex = 0)
        {
            Contracts.AssertValue(env);
            env.AssertValue(overallMetrics);
            env.AssertValue(confusionMatrix);

            var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);

            // Fail fast on an empty view, before building the confusion matrices. The probe
            // enumerator is disposed so the cursor it opens is not leaked.
            using (var probe = metricsEnumerable.GetEnumerator())
            {
                if (!probe.MoveNext())
                {
                    throw env.Except("The overall BinaryClassificationMetrics didn't have any rows.");
                }
            }

            var metrics = new List<BinaryClassificationMetrics>();
            var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator();

            int index = 0;
            foreach (var metric in metricsEnumerable)
            {
                // Rows before the start index have no matching matrix. Assign null explicitly
                // instead of reading Current before MoveNext, which IEnumerator<T> leaves undefined.
                ConfusionMatrix rowMatrix = null;
                if (index++ >= confusionMatriceStartIndex)
                {
                    if (!confusionMatrices.MoveNext())
                    {
                        throw env.Except("Confusion matrices didn't have enough matrices.");
                    }

                    rowMatrix = confusionMatrices.Current;
                }

                metrics.Add(new BinaryClassificationMetrics()
                {
                    Auc               = metric.Auc,
                    Accuracy          = metric.Accuracy,
                    PositivePrecision = metric.PositivePrecision,
                    PositiveRecall    = metric.PositiveRecall,
                    NegativePrecision = metric.NegativePrecision,
                    NegativeRecall    = metric.NegativeRecall,
                    LogLoss           = metric.LogLoss,
                    LogLossReduction  = metric.LogLossReduction,
                    Entropy           = metric.Entropy,
                    F1Score           = metric.F1Score,
                    Auprc             = metric.Auprc,
                    RowTag            = metric.RowTag,
                    ConfusionMatrix   = rowMatrix,
                });
            }

            return metrics;
        }
示例#3
0
        /// <summary>
        /// Builds one <see cref="ClassificationMetrics"/> per row of the overall-metrics view,
        /// pairing rows with confusion matrices from the confusion-matrix view.
        /// </summary>
        /// <param name="env">Host environment used for assertions and for creating exceptions.</param>
        /// <param name="overallMetrics">View read as <c>SerializationClass</c> rows; must contain at least one row.</param>
        /// <param name="confusionMatrix">View passed to <c>ConfusionMatrix.Create</c>; its matrices are consumed in order.</param>
        /// <param name="confusionMatriceStartIndex">
        /// Zero-based index of the first overall row that has a confusion matrix. Rows before it get a
        /// <c>null</c> <c>ConfusionMatrix</c>. NOTE(review): presumably those leading rows are aggregate
        /// rows (e.g. averages); confirm against callers.
        /// </param>
        /// <returns>A list with one entry per overall row, in row order.</returns>
        /// <remarks>
        /// Throws the exception produced by <c>env.Except</c> when the overall view is empty, or when
        /// there are fewer confusion matrices than rows at or after <paramref name="confusionMatriceStartIndex"/>.
        /// </remarks>
        internal static List<ClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix,
                                                                 int confusionMatriceStartIndex = 0)
        {
            Contracts.AssertValue(env);
            env.AssertValue(overallMetrics);
            env.AssertValue(confusionMatrix);

            var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);

            // Fail fast on an empty view, before building the confusion matrices. The probe
            // enumerator is disposed so the cursor it opens is not leaked.
            using (var probe = metricsEnumerable.GetEnumerator())
            {
                if (!probe.MoveNext())
                {
                    throw env.Except("The overall ClassificationMetrics didn't have any rows.");
                }
            }

            var metrics = new List<ClassificationMetrics>();
            var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator();

            int index = 0;
            foreach (var metric in metricsEnumerable)
            {
                // Rows before the start index have no matching matrix. Assign null explicitly
                // instead of reading Current before MoveNext, which IEnumerator<T> leaves undefined.
                ConfusionMatrix rowMatrix = null;
                if (index++ >= confusionMatriceStartIndex)
                {
                    if (!confusionMatrices.MoveNext())
                    {
                        throw env.Except("Confusion matrices didn't have enough matrices.");
                    }

                    rowMatrix = confusionMatrices.Current;
                }

                metrics.Add(new ClassificationMetrics()
                {
                    AccuracyMicro    = metric.AccuracyMicro,
                    AccuracyMacro    = metric.AccuracyMacro,
                    LogLoss          = metric.LogLoss,
                    LogLossReduction = metric.LogLossReduction,
                    TopKAccuracy     = metric.TopKAccuracy,
                    PerClassLogLoss  = metric.PerClassLogLoss,
                    ConfusionMatrix  = rowMatrix,
                    RowTag           = metric.RowTag,
                });
            }

            return metrics;
        }
示例#4
0
        /// <summary>
        /// Builds a single <see cref="BinaryClassificationMetrics"/> from a one-row overall-metrics view
        /// and its companion confusion-matrix view.
        /// </summary>
        /// <param name="env">Host environment used for assertions and for creating exceptions.</param>
        /// <param name="overallMetrics">View read as <c>SerializationClass</c> rows; must contain exactly one row.</param>
        /// <param name="confusionMatrix">View passed to <c>ConfusionMatrix.Create</c> to build the confusion matrix.</param>
        /// <returns>The metrics copied from the single overall row plus the created confusion matrix.</returns>
        /// <remarks>Throws the exception produced by <c>env.Except</c> when the overall view has zero rows or more than one row.</remarks>
        internal static BinaryClassificationMetrics FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
        {
            Contracts.AssertValue(env);
            env.AssertValue(overallMetrics);
            env.AssertValue(confusionMatrix);

            var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);

            SerializationClass metrics;
            // Dispose the enumerator so the cursor it opens over the data view is released.
            using (var enumerator = metricsEnumerable.GetEnumerator())
            {
                if (!enumerator.MoveNext())
                {
                    throw env.Except("The overall BinaryClassificationMetrics didn't have any rows.");
                }

                metrics = enumerator.Current;

                // Exactly one row is expected; a second row means the caller passed the wrong view.
                if (enumerator.MoveNext())
                {
                    throw env.Except("The overall BinaryClassificationMetrics contained more than 1 row.");
                }
            }

            return new BinaryClassificationMetrics()
            {
                Auc = metrics.Auc,
                Accuracy = metrics.Accuracy,
                PositivePrecision = metrics.PositivePrecision,
                PositiveRecall = metrics.PositiveRecall,
                NegativePrecision = metrics.NegativePrecision,
                NegativeRecall = metrics.NegativeRecall,
                LogLoss = metrics.LogLoss,
                LogLossReduction = metrics.LogLossReduction,
                Entropy = metrics.Entropy,
                F1Score = metrics.F1Score,
                Auprc = metrics.Auprc,
                ConfusionMatrix = ConfusionMatrix.Create(env, confusionMatrix),
            };
        }