コード例 #1
0
        /// <summary>
        /// Builds a single <see cref="RegressionMetrics"/> from an overall-metrics data view that must
        /// contain exactly one row.
        /// </summary>
        /// <param name="env">Host environment used for assertions and for creating exceptions.</param>
        /// <param name="overallMetrics">Data view holding the overall regression metrics.</param>
        /// <returns>The metrics copied from the single row of <paramref name="overallMetrics"/>.</returns>
        /// <remarks>
        /// Throws the exception built by <c>env.Except</c> when <paramref name="overallMetrics"/> has
        /// no rows or more than one row.
        /// </remarks>
        internal static RegressionMetrics FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics)
        {
            Contracts.AssertValue(env);
            env.AssertValue(overallMetrics);

            var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);

            // IEnumerator<T> is IDisposable: dispose it (also on the throwing paths) so whatever it holds
            // (presumably a data-view cursor) is released. The result is built inside the scope so the row
            // object is read before disposal.
            using (var enumerator = metricsEnumerable.GetEnumerator())
            {
                if (!enumerator.MoveNext())
                {
                    throw env.Except("The overall RegressionMetrics didn't have any rows.");
                }

                SerializationClass metrics = enumerator.Current;

                if (enumerator.MoveNext())
                {
                    throw env.Except("The overall RegressionMetrics contained more than 1 row.");
                }

                return new RegressionMetrics()
                {
                    L1 = metrics.L1,
                    L2 = metrics.L2,
                    Rms = metrics.Rms,
                    LossFn = metrics.LossFn,
                    RSquared = metrics.RSquared,
                };
            }
        }
コード例 #2
0
        /// <summary>
        /// Builds one <see cref="ClusterMetrics"/> per row of an overall-metrics data view.
        /// </summary>
        /// <param name="env">Host environment used for assertions and for creating exceptions.</param>
        /// <param name="overallMetrics">Data view holding the overall clustering metrics; must have at least one row.</param>
        /// <returns>One metrics object per row, in row order.</returns>
        /// <remarks>
        /// Throws the exception built by <c>env.Except</c> when <paramref name="overallMetrics"/> has no rows.
        /// </remarks>
        internal static List<ClusterMetrics> FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics)
        {
            Contracts.AssertValue(env);
            env.AssertValue(overallMetrics);

            var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);
            var metrics = new List<ClusterMetrics>();

            // One enumerator serves both the emptiness check and the copy loop, and it is disposed.
            // Previously a separate enumerator was opened only to probe for a first row and never disposed.
            using (var enumerator = metricsEnumerable.GetEnumerator())
            {
                if (!enumerator.MoveNext())
                {
                    throw env.Except("The overall ClusteringMetrics didn't have any rows.");
                }

                do
                {
                    // The second AsEnumerable argument (true) requests row-object reuse, so values are
                    // copied into a fresh ClusterMetrics for every row.
                    var metric = enumerator.Current;
                    metrics.Add(new ClusterMetrics()
                    {
                        AvgMinScore = metric.AvgMinScore,
                        Nmi         = metric.Nmi,
                        Dbi         = metric.Dbi,
                        RowTag      = metric.RowTag,
                    });
                }
                while (enumerator.MoveNext());
            }

            return metrics;
        }
コード例 #3
0
        /// <summary>
        /// Builds one <see cref="RegressionMetrics"/> per row of an overall-metrics data view.
        /// </summary>
        /// <param name="env">Host environment used for assertions and for creating exceptions.</param>
        /// <param name="overallMetrics">Data view holding the overall regression metrics; must have at least one row.</param>
        /// <returns>One metrics object per row, in row order.</returns>
        /// <remarks>
        /// Throws the exception built by <c>env.Except</c> when <paramref name="overallMetrics"/> has no rows.
        /// </remarks>
        internal static List<RegressionMetrics> FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics)
        {
            Contracts.AssertValue(env);
            env.AssertValue(overallMetrics);

            var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);
            var metrics = new List<RegressionMetrics>();

            // One enumerator serves both the emptiness check and the copy loop, and it is disposed.
            // Previously a separate enumerator was opened only to probe for a first row and never disposed.
            using (var enumerator = metricsEnumerable.GetEnumerator())
            {
                if (!enumerator.MoveNext())
                {
                    throw env.Except("The overall RegressionMetrics didn't have any rows.");
                }

                do
                {
                    // The second AsEnumerable argument (true) requests row-object reuse, so values are
                    // copied into a fresh RegressionMetrics for every row.
                    var metric = enumerator.Current;
                    metrics.Add(new RegressionMetrics()
                    {
                        L1       = metric.L1,
                        L2       = metric.L2,
                        Rms      = metric.Rms,
                        LossFn   = metric.LossFn,
                        RSquared = metric.RSquared,
                    });
                }
                while (enumerator.MoveNext());
            }

            return metrics;
        }
コード例 #4
0
        /// <summary>
        /// Builds a single <see cref="ClassificationMetrics"/> from an overall-metrics data view that must
        /// contain exactly one row, attaching the confusion matrix created from <paramref name="confusionMatrix"/>.
        /// </summary>
        /// <param name="env">Host environment used for assertions and for creating exceptions.</param>
        /// <param name="overallMetrics">Data view holding the overall multi-class metrics.</param>
        /// <param name="confusionMatrix">Data view passed to <c>ConfusionMatrix.Create</c>.</param>
        /// <returns>The metrics copied from the single row of <paramref name="overallMetrics"/>.</returns>
        /// <remarks>
        /// Throws the exception built by <c>env.Except</c> when <paramref name="overallMetrics"/> has
        /// no rows or more than one row.
        /// </remarks>
        internal static ClassificationMetrics FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
        {
            Contracts.AssertValue(env);
            env.AssertValue(overallMetrics);
            env.AssertValue(confusionMatrix);

            var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);

            // Dispose the enumerator on every path, including the throwing ones. The result is built
            // inside the scope so the row object is read before disposal.
            using (var enumerator = metricsEnumerable.GetEnumerator())
            {
                if (!enumerator.MoveNext())
                {
                    throw env.Except("The overall ClassificationMetrics didn't have any rows.");
                }

                SerializationClass metrics = enumerator.Current;

                if (enumerator.MoveNext())
                {
                    throw env.Except("The overall ClassificationMetrics contained more than 1 row.");
                }

                return new ClassificationMetrics()
                {
                    AccuracyMicro = metrics.AccuracyMicro,
                    AccuracyMacro = metrics.AccuracyMacro,
                    LogLoss = metrics.LogLoss,
                    LogLossReduction = metrics.LogLossReduction,
                    TopKAccuracy = metrics.TopKAccuracy,
                    PerClassLogLoss = metrics.PerClassLogLoss,
                    ConfusionMatrix = ConfusionMatrix.Create(env, confusionMatrix)
                };
            }
        }
コード例 #5
0
        /// <summary>
        /// Environment-less convenience overload: creates a new <c>ConsoleEnvironment</c> and forwards all
        /// other arguments, unchanged and in the same order, to the overload that takes an environment.
        /// </summary>
        /// <typeparam name="TRow">Row type to materialize each data-view row into.</typeparam>
        /// <param name="data">Source data view.</param>
        /// <param name="reuseRowObject">Forwarded as-is to the environment-taking overload.</param>
        /// <param name="ignoreMissingColumns">Forwarded as-is to the environment-taking overload.</param>
        /// <param name="schemaDefinition">Forwarded as-is to the environment-taking overload.</param>
        /// <returns>Whatever the environment-taking overload returns.</returns>
        public static IEnumerable<TRow> AsEnumerable<TRow>(this IDataView data, bool reuseRowObject,
            bool ignoreMissingColumns = false, SchemaDefinition schemaDefinition = null)
            where TRow : class, new()
        {
            // REVIEW: Take an env as a parameter.
            // NOTE(review): the environment created here is never disposed in this method; the returned
            // sequence may presumably still rely on it while being enumerated, so confirm before adding disposal.
            var hostEnvironment = new ConsoleEnvironment();
            var rows = data.AsEnumerable<TRow>(hostEnvironment, reuseRowObject, ignoreMissingColumns, schemaDefinition);
            return rows;
        }
コード例 #6
0
 /// <summary>
 /// Prints to the console up to <paramref name="count"/> rows of <paramref name="data"/> whose
 /// <c>PredictedLabel</c> equals <paramref name="label"/>.
 /// </summary>
 /// <param name="env">Environment passed to <c>AsEnumerable</c>.</param>
 /// <param name="data">Scored data view to read predictions from.</param>
 /// <param name="label">Predicted label value to filter on.</param>
 /// <param name="count">Maximum number of matching rows to print.</param>
 public static void ShowPredictions(LocalEnvironment env, IDataView data, bool label = true, int count = 2)
 {
     // Rows are not reused (reuseRowObject: false) because they are collected into a list. The whole
     // selection is materialized before anything is printed.
     List<TransactionFraudPrediction> selected = data
         .AsEnumerable<TransactionFraudPrediction>(env, reuseRowObject: false)
         .Where(prediction => prediction.PredictedLabel == label)
         .Take(count)
         .ToList();

     foreach (TransactionFraudPrediction prediction in selected)
     {
         prediction.PrintToConsole();
     }
 }
コード例 #7
0
        /// <summary>
        /// Builds one <see cref="BinaryClassificationMetrics"/> per row of <paramref name="overallMetrics"/>,
        /// pairing rows with the confusion matrices created from <paramref name="confusionMatrix"/>.
        /// </summary>
        /// <param name="env">Host environment used for assertions and for creating exceptions.</param>
        /// <param name="overallMetrics">Data view of overall binary-classification metrics; must have at least one row.</param>
        /// <param name="confusionMatrix">Data view passed to <c>ConfusionMatrix.Create</c>.</param>
        /// <param name="confusionMatriceStartIndex">
        /// Zero-based index of the first metrics row that receives a confusion matrix. Rows before it get a
        /// null <c>ConfusionMatrix</c>. Each row from this index on consumes the next matrix, in order.
        /// </param>
        /// <returns>One metrics object per row, in row order.</returns>
        /// <remarks>
        /// Throws the exception built by <c>env.Except</c> when <paramref name="overallMetrics"/> has no rows,
        /// or when there are fewer confusion matrices than rows that need one.
        /// </remarks>
        internal static List<BinaryClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix, int confusionMatriceStartIndex = 0)
        {
            Contracts.AssertValue(env);
            env.AssertValue(overallMetrics);
            env.AssertValue(confusionMatrix);

            var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);
            var metrics = new List<BinaryClassificationMetrics>();

            // One disposed enumerator serves both the emptiness check and the copy loop. Previously a
            // second enumerator was opened only to probe for a first row and never disposed.
            using (var rows = metricsEnumerable.GetEnumerator())
            {
                if (!rows.MoveNext())
                {
                    throw env.Except("The overall BinaryClassificationMetrics didn't have any rows.");
                }

                // Created only after the emptiness check, matching the original call order.
                using (IEnumerator<ConfusionMatrix> confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator())
                {
                    int index = 0;

                    do
                    {
                        var metric = rows.Current;

                        // Rows before the start index get no matrix. Reading Current before the first
                        // MoveNext is undefined for IEnumerator<T>, so null is assigned explicitly.
                        ConfusionMatrix rowMatrix = null;
                        if (index++ >= confusionMatriceStartIndex)
                        {
                            if (!confusionMatrices.MoveNext())
                            {
                                throw env.Except("Confusion matrices didn't have enough matrices.");
                            }

                            rowMatrix = confusionMatrices.Current;
                        }

                        metrics.Add(new BinaryClassificationMetrics()
                        {
                            Auc               = metric.Auc,
                            Accuracy          = metric.Accuracy,
                            PositivePrecision = metric.PositivePrecision,
                            PositiveRecall    = metric.PositiveRecall,
                            NegativePrecision = metric.NegativePrecision,
                            NegativeRecall    = metric.NegativeRecall,
                            LogLoss           = metric.LogLoss,
                            LogLossReduction  = metric.LogLossReduction,
                            Entropy           = metric.Entropy,
                            F1Score           = metric.F1Score,
                            Auprc             = metric.Auprc,
                            RowTag            = metric.RowTag,
                            ConfusionMatrix   = rowMatrix,
                        });
                    }
                    while (rows.MoveNext());
                }
            }

            return metrics;
        }
コード例 #8
0
        /// <summary>
        /// Builds one <see cref="ClassificationMetrics"/> per row of <paramref name="overallMetrics"/>,
        /// pairing rows with the confusion matrices created from <paramref name="confusionMatrix"/>.
        /// </summary>
        /// <param name="env">Host environment used for assertions and for creating exceptions.</param>
        /// <param name="overallMetrics">Data view of overall multi-class metrics; must have at least one row.</param>
        /// <param name="confusionMatrix">Data view passed to <c>ConfusionMatrix.Create</c>.</param>
        /// <param name="confusionMatriceStartIndex">
        /// Zero-based index of the first metrics row that receives a confusion matrix. Rows before it get a
        /// null <c>ConfusionMatrix</c>. Each row from this index on consumes the next matrix, in order.
        /// </param>
        /// <returns>One metrics object per row, in row order.</returns>
        /// <remarks>
        /// Throws the exception built by <c>env.Except</c> when <paramref name="overallMetrics"/> has no rows,
        /// or when there are fewer confusion matrices than rows that need one.
        /// </remarks>
        internal static List<ClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix,
                                                                 int confusionMatriceStartIndex = 0)
        {
            Contracts.AssertValue(env);
            env.AssertValue(overallMetrics);
            env.AssertValue(confusionMatrix);

            var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);
            var metrics = new List<ClassificationMetrics>();

            // One disposed enumerator serves both the emptiness check and the copy loop. Previously a
            // second enumerator was opened only to probe for a first row and never disposed.
            using (var rows = metricsEnumerable.GetEnumerator())
            {
                if (!rows.MoveNext())
                {
                    throw env.Except("The overall ClassificationMetrics didn't have any rows.");
                }

                // Created only after the emptiness check, matching the original call order.
                using (IEnumerator<ConfusionMatrix> confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator())
                {
                    int index = 0;

                    do
                    {
                        var metric = rows.Current;

                        // Rows before the start index get no matrix. Reading Current before the first
                        // MoveNext is undefined for IEnumerator<T>, so null is assigned explicitly.
                        ConfusionMatrix rowMatrix = null;
                        if (index++ >= confusionMatriceStartIndex)
                        {
                            if (!confusionMatrices.MoveNext())
                            {
                                throw env.Except("Confusion matrices didn't have enough matrices.");
                            }

                            rowMatrix = confusionMatrices.Current;
                        }

                        metrics.Add(new ClassificationMetrics()
                        {
                            AccuracyMicro    = metric.AccuracyMicro,
                            AccuracyMacro    = metric.AccuracyMacro,
                            LogLoss          = metric.LogLoss,
                            LogLossReduction = metric.LogLossReduction,
                            TopKAccuracy     = metric.TopKAccuracy,
                            PerClassLogLoss  = metric.PerClassLogLoss,
                            ConfusionMatrix  = rowMatrix,
                            RowTag           = metric.RowTag,
                        });
                    }
                    while (rows.MoveNext());
                }
            }

            return metrics;
        }
コード例 #9
0
        /// <summary>
        /// Builds a single <see cref="BinaryClassificationMetrics"/> from an overall-metrics data view that
        /// must contain exactly one row, attaching the confusion matrix created from <paramref name="confusionMatrix"/>.
        /// </summary>
        /// <param name="env">Host environment used for assertions and for creating exceptions.</param>
        /// <param name="overallMetrics">Data view holding the overall binary-classification metrics.</param>
        /// <param name="confusionMatrix">Data view passed to <c>ConfusionMatrix.Create</c>.</param>
        /// <returns>The metrics copied from the single row of <paramref name="overallMetrics"/>.</returns>
        /// <remarks>
        /// Throws the exception built by <c>env.Except</c> when <paramref name="overallMetrics"/> has
        /// no rows or more than one row.
        /// </remarks>
        internal static BinaryClassificationMetrics FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
        {
            Contracts.AssertValue(env);
            env.AssertValue(overallMetrics);
            env.AssertValue(confusionMatrix);

            var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);

            // Dispose the enumerator on every path, including the throwing ones. The result is built
            // inside the scope so the row object is read before disposal.
            using (var enumerator = metricsEnumerable.GetEnumerator())
            {
                if (!enumerator.MoveNext())
                {
                    throw env.Except("The overall BinaryClassificationMetrics didn't have any rows.");
                }

                SerializationClass metrics = enumerator.Current;

                if (enumerator.MoveNext())
                {
                    throw env.Except("The overall BinaryClassificationMetrics contained more than 1 row.");
                }

                return new BinaryClassificationMetrics()
                {
                    Auc = metrics.Auc,
                    Accuracy = metrics.Accuracy,
                    PositivePrecision = metrics.PositivePrecision,
                    PositiveRecall = metrics.PositiveRecall,
                    NegativePrecision = metrics.NegativePrecision,
                    NegativeRecall = metrics.NegativeRecall,
                    LogLoss = metrics.LogLoss,
                    LogLossReduction = metrics.LogLossReduction,
                    Entropy = metrics.Entropy,
                    F1Score = metrics.F1Score,
                    Auprc = metrics.Auprc,
                    ConfusionMatrix = ConfusionMatrix.Create(env, confusionMatrix),
                };
            }
        }