/// <summary>
/// Builds a single <c>ClassificationMetrics</c> instance from an evaluator's overall-metrics
/// data view and its confusion-matrix data view.
/// </summary>
/// <param name="env">Host environment used for assertions and for creating exceptions.</param>
/// <param name="overallMetrics">Overall-metrics data view; it must contain exactly one row.</param>
/// <param name="confusionMatrix">Confusion-matrix data view, converted via <c>ConfusionMatrix.Create</c>.</param>
/// <returns>The metrics copied from the single row of <paramref name="overallMetrics"/>.</returns>
/// <remarks>
/// Throws (via <c>env.Except</c>) when <paramref name="overallMetrics"/> has no rows or more than one row.
/// </remarks>
internal static ClassificationMetrics FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
{
    Contracts.AssertValue(env);
    env.AssertValue(overallMetrics);
    env.AssertValue(confusionMatrix);

    // The boolean argument presumably enables row-object reuse; harmless here since only one row is read.
    var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);

    SerializationClass metrics;
    // Dispose the enumerator so whatever it holds (e.g. a data-view cursor) is released,
    // including when one of the row-count checks throws.
    using (var enumerator = metricsEnumerable.GetEnumerator())
    {
        if (!enumerator.MoveNext())
        {
            throw env.Except("The overall ClassificationMetrics didn't have any rows.");
        }

        metrics = enumerator.Current;

        if (enumerator.MoveNext())
        {
            throw env.Except("The overall ClassificationMetrics contained more than 1 row.");
        }
    }

    return new ClassificationMetrics()
    {
        AccuracyMicro = metrics.AccuracyMicro,
        AccuracyMacro = metrics.AccuracyMacro,
        LogLoss = metrics.LogLoss,
        LogLossReduction = metrics.LogLossReduction,
        TopKAccuracy = metrics.TopKAccuracy,
        PerClassLogLoss = metrics.PerClassLogLoss,
        ConfusionMatrix = ConfusionMatrix.Create(env, confusionMatrix)
    };
}
/// <summary>
/// Builds one <c>BinaryClassificationMetrics</c> per row of an evaluator's overall-metrics
/// data view, pairing rows with confusion matrices from the confusion-matrix data view.
/// </summary>
/// <param name="env">Host environment used for assertions and for creating exceptions.</param>
/// <param name="overallMetrics">Overall-metrics data view; it must contain at least one row.</param>
/// <param name="confusionMatrix">Confusion-matrix data view, converted via <c>ConfusionMatrix.Create</c>.</param>
/// <param name="confusionMatriceStartIndex">
/// Index of the first metrics row that consumes a confusion matrix. Rows before it do not
/// advance the confusion-matrix enumerator.
/// </param>
/// <returns>One metrics object per row of <paramref name="overallMetrics"/>, in row order.</returns>
/// <remarks>
/// Throws (via <c>env.Except</c>) when <paramref name="overallMetrics"/> is empty or when there are
/// fewer confusion matrices than rows at or after <paramref name="confusionMatriceStartIndex"/>.
/// </remarks>
internal static List<BinaryClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix, int confusionMatriceStartIndex = 0)
{
    Contracts.AssertValue(env);
    env.AssertValue(overallMetrics);
    env.AssertValue(confusionMatrix);

    var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);

    // Probe for at least one row. The probe enumerator is disposed so it does not leak
    // whatever it holds (e.g. a data-view cursor) when it is abandoned after one MoveNext.
    using (var probe = metricsEnumerable.GetEnumerator())
    {
        if (!probe.MoveNext())
        {
            throw env.Except("The overall BinaryClassificationMetrics didn't have any rows.");
        }
    }

    List<BinaryClassificationMetrics> metrics = new List<BinaryClassificationMetrics>();
    var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator();

    int index = 0;
    foreach (var metric in metricsEnumerable)
    {
        // Only rows at or after the start index advance the confusion-matrix enumerator.
        // Earlier rows read Current before any MoveNext; for a List-backed enumerator that is
        // the default value (null) — NOTE(review): confirm ConfusionMatrix.Create's return type.
        if (index++ >= confusionMatriceStartIndex && !confusionMatrices.MoveNext())
        {
            throw env.Except("Confusion matrices didn't have enough matrices.");
        }

        metrics.Add(
            new BinaryClassificationMetrics()
            {
                Auc = metric.Auc,
                Accuracy = metric.Accuracy,
                PositivePrecision = metric.PositivePrecision,
                PositiveRecall = metric.PositiveRecall,
                NegativePrecision = metric.NegativePrecision,
                NegativeRecall = metric.NegativeRecall,
                LogLoss = metric.LogLoss,
                LogLossReduction = metric.LogLossReduction,
                Entropy = metric.Entropy,
                F1Score = metric.F1Score,
                Auprc = metric.Auprc,
                RowTag = metric.RowTag,
                ConfusionMatrix = confusionMatrices.Current,
            });
    }

    return metrics;
}
/// <summary>
/// Builds one <c>ClassificationMetrics</c> per row of an evaluator's overall-metrics data view,
/// pairing rows with confusion matrices from the confusion-matrix data view.
/// </summary>
/// <param name="env">Host environment used for assertions and for creating exceptions.</param>
/// <param name="overallMetrics">Overall-metrics data view; it must contain at least one row.</param>
/// <param name="confusionMatrix">Confusion-matrix data view, converted via <c>ConfusionMatrix.Create</c>.</param>
/// <param name="confusionMatriceStartIndex">
/// Index of the first metrics row that consumes a confusion matrix. Rows before it do not
/// advance the confusion-matrix enumerator.
/// </param>
/// <returns>One metrics object per row of <paramref name="overallMetrics"/>, in row order.</returns>
/// <remarks>
/// Throws (via <c>env.Except</c>) when <paramref name="overallMetrics"/> is empty or when there are
/// fewer confusion matrices than rows at or after <paramref name="confusionMatriceStartIndex"/>.
/// </remarks>
internal static List<ClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix, int confusionMatriceStartIndex = 0)
{
    Contracts.AssertValue(env);
    env.AssertValue(overallMetrics);
    env.AssertValue(confusionMatrix);

    // The boolean argument presumably enables row-object reuse, so the same SerializationClass
    // instance (and possibly its array fields) is refilled for every row.
    var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);

    // Probe for at least one row. The probe enumerator is disposed so it does not leak
    // whatever it holds (e.g. a data-view cursor) when it is abandoned after one MoveNext.
    using (var probe = metricsEnumerable.GetEnumerator())
    {
        if (!probe.MoveNext())
        {
            throw env.Except("The overall ClassificationMetrics didn't have any rows.");
        }
    }

    List<ClassificationMetrics> metrics = new List<ClassificationMetrics>();
    var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator();

    int index = 0;
    foreach (var metric in metricsEnumerable)
    {
        // Only rows at or after the start index advance the confusion-matrix enumerator.
        // Earlier rows read Current before any MoveNext; for a List-backed enumerator that is
        // the default value (null) — NOTE(review): confirm ConfusionMatrix.Create's return type.
        if (index++ >= confusionMatriceStartIndex && !confusionMatrices.MoveNext())
        {
            throw env.Except("Confusion matrices didn't have enough matrices.");
        }

        metrics.Add(
            new ClassificationMetrics()
            {
                AccuracyMicro = metric.AccuracyMicro,
                AccuracyMacro = metric.AccuracyMacro,
                LogLoss = metric.LogLoss,
                LogLossReduction = metric.LogLossReduction,
                TopKAccuracy = metric.TopKAccuracy,
                // Copy the array: with a reused row object the cursor may recycle this buffer
                // for later rows, which would make every result alias the last row's values.
                PerClassLogLoss = (double[])metric.PerClassLogLoss?.Clone(),
                ConfusionMatrix = confusionMatrices.Current,
                RowTag = metric.RowTag,
            });
    }

    return metrics;
}
/// <summary>
/// Builds a single <c>BinaryClassificationMetrics</c> instance from an evaluator's overall-metrics
/// data view and its confusion-matrix data view.
/// </summary>
/// <param name="env">Host environment used for assertions and for creating exceptions.</param>
/// <param name="overallMetrics">Overall-metrics data view; it must contain exactly one row.</param>
/// <param name="confusionMatrix">Confusion-matrix data view, converted via <c>ConfusionMatrix.Create</c>.</param>
/// <returns>The metrics copied from the single row of <paramref name="overallMetrics"/>.</returns>
/// <remarks>
/// Throws (via <c>env.Except</c>) when <paramref name="overallMetrics"/> has no rows or more than one row.
/// </remarks>
internal static BinaryClassificationMetrics FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
{
    Contracts.AssertValue(env);
    env.AssertValue(overallMetrics);
    env.AssertValue(confusionMatrix);

    var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);

    SerializationClass metrics;
    // Dispose the enumerator so whatever it holds (e.g. a data-view cursor) is released,
    // including when one of the row-count checks throws.
    using (var enumerator = metricsEnumerable.GetEnumerator())
    {
        if (!enumerator.MoveNext())
        {
            throw env.Except("The overall BinaryClassificationMetrics didn't have any rows.");
        }

        metrics = enumerator.Current;

        if (enumerator.MoveNext())
        {
            throw env.Except("The overall BinaryClassificationMetrics contained more than 1 row.");
        }
    }

    return new BinaryClassificationMetrics()
    {
        Auc = metrics.Auc,
        Accuracy = metrics.Accuracy,
        PositivePrecision = metrics.PositivePrecision,
        PositiveRecall = metrics.PositiveRecall,
        NegativePrecision = metrics.NegativePrecision,
        NegativeRecall = metrics.NegativeRecall,
        LogLoss = metrics.LogLoss,
        LogLossReduction = metrics.LogLossReduction,
        Entropy = metrics.Entropy,
        F1Score = metrics.F1Score,
        Auprc = metrics.Auprc,
        ConfusionMatrix = ConfusionMatrix.Create(env, confusionMatrix),
    };
}