/// <summary>
/// Builds one <see cref="ClassificationMetrics"/> per row of <paramref name="overallMetrics"/>,
/// pairing rows with the confusion matrices produced from <paramref name="confusionMatrix"/>.
/// </summary>
/// <param name="env">Host environment used for assertions and for creating exceptions.</param>
/// <param name="overallMetrics">Data view holding the overall metric rows; it must contain at least one row.</param>
/// <param name="confusionMatrix">Data view from which the confusion matrices are created via <c>ConfusionMatrix.Create</c>.</param>
/// <param name="confusionMatriceStartIndex">
/// Zero-based index of the first metrics row that receives a confusion matrix. Rows before this index get a
/// <c>null</c> confusion matrix. From this row on, matrices are consumed in order, one per row.
/// </param>
/// <returns>The metrics, in the same order as the rows of <paramref name="overallMetrics"/>.</returns>
/// <remarks>
/// Throws (via <c>env.Except</c>) when <paramref name="overallMetrics"/> has no rows, or when there are fewer
/// confusion matrices than rows at or after <paramref name="confusionMatriceStartIndex"/>.
/// </remarks>
internal static List<ClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix,
    int confusionMatriceStartIndex = 0)
{
    Contracts.AssertValue(env);
    env.AssertValue(overallMetrics);
    env.AssertValue(confusionMatrix);

    var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);

    // Probe for at least one row. Dispose the probe enumerator so the underlying cursor is released
    // instead of being left open until finalization.
    using (var probe = metricsEnumerable.GetEnumerator())
    {
        if (!probe.MoveNext())
        {
            throw env.Except("The overall ClassificationMetrics didn't have any rows.");
        }
    }

    // Materialize the matrices so they are consumed by explicit position. This avoids reading
    // IEnumerator<T>.Current before the first MoveNext (undefined by contract) and leaves no
    // undisposed enumerator behind.
    var matrices = new List<ConfusionMatrix>(ConfusionMatrix.Create(env, confusionMatrix));

    var metrics = new List<ClassificationMetrics>();
    int rowIndex = 0;
    int matrixPosition = 0;
    foreach (var metric in metricsEnumerable)
    {
        // Rows before the start index carry no confusion matrix.
        ConfusionMatrix matrix = null;
        if (rowIndex >= confusionMatriceStartIndex)
        {
            if (matrixPosition >= matrices.Count)
            {
                throw env.Except("Confusion matrices didn't have enough matrices.");
            }
            matrix = matrices[matrixPosition++];
        }
        rowIndex++;

        metrics.Add(
            new ClassificationMetrics()
            {
                AccuracyMicro = metric.AccuracyMicro,
                AccuracyMacro = metric.AccuracyMacro,
                LogLoss = metric.LogLoss,
                LogLossReduction = metric.LogLossReduction,
                TopKAccuracy = metric.TopKAccuracy,
                PerClassLogLoss = metric.PerClassLogLoss,
                ConfusionMatrix = matrix,
                RowTag = metric.RowTag,
            });
    }

    return metrics;
}