public void SetupPredictBenchmarks()
        {
            // Train the model, build a prediction engine from it and run one prediction
            // through the consumer (presumably a warm-up so later iterations skip one-time setup).
            _trainedModel = Train(_dataPath);
            _predictionEngine = _trainedModel.CreatePredictionEngine<IrisData, IrisPrediction>(_env);
            var firstPrediction = _predictionEngine.Predict(_example);
            _consumer.Consume(firstPrediction);

            // Column layout of the file at _dataPath: the label in column 0, then four feature columns.
            var columns = new[]
            {
                new TextLoader.Column("Label", DataKind.R4, 0),
                new TextLoader.Column("SepalLength", DataKind.R4, 1),
                new TextLoader.Column("SepalWidth", DataKind.R4, 2),
                new TextLoader.Column("PetalLength", DataKind.R4, 3),
                new TextLoader.Column("PetalWidth", DataKind.R4, 4),
            };
            var reader = new TextLoader(_env, columns: columns, hasHeader: true);

            // Score the same file with the trained model and keep the multi-class metrics.
            IDataView loaded = reader.Read(_dataPath);
            IDataView scored = _trainedModel.Transform(loaded);
            var evaluatorArgs = new MultiClassClassifierEvaluator.Arguments();
            var evaluator = new MultiClassClassifierEvaluator(_env, evaluatorArgs);
            _metrics = evaluator.Evaluate(scored, DefaultColumnNames.Label, DefaultColumnNames.Score, DefaultColumnNames.PredictedLabel);

            // One pre-built batch per entry of _batchSizes; every slot holds the same example.
            _batches = new IrisData[_batchSizes.Length][];
            for (int sizeIndex = 0; sizeIndex < _batchSizes.Length; sizeIndex++)
            {
                var size = _batchSizes[sizeIndex];
                var filled = new IrisData[size];
                for (int slot = 0; slot < filled.Length; slot++)
                {
                    filled[slot] = _example;
                }
                _batches[sizeIndex] = filled;
            }
        }
예제 #2
0
 /// <summary>
 /// Asserts that a <see cref="MultiClassClassifierMetrics"/> object holds plausible values:
 /// macro, micro and top-K accuracy within [0, 1], and a non-negative log-loss.
 /// </summary>
 /// <param name="metrics">The metrics object to validate.</param>
 public static void AssertMetrics(MultiClassClassifierMetrics metrics)
 {
     // Macro accuracy is checked first, then micro accuracy.
     foreach (var accuracy in new[] { metrics.MacroAccuracy, metrics.MicroAccuracy })
     {
         Assert.InRange(accuracy, 0, 1);
     }

     Assert.True(metrics.LogLoss >= 0);
     Assert.InRange(metrics.TopKAccuracy, 0, 1);
 }
예제 #3
0
 /// <summary>
 /// Writes micro accuracy, macro accuracy, log-loss and log-loss reduction to the console,
 /// one per line, each formatted with two decimal places.
 /// </summary>
 /// <param name="metrics">The <see cref="MultiClassClassifierMetrics"/> to print.</param>
 public static void PrintMetrics(MultiClassClassifierMetrics metrics)
 {
     var report = new[]
     {
         $"Micro Accuracy: {metrics.MicroAccuracy:F2}",
         $"Macro Accuracy: {metrics.MacroAccuracy:F2}",
         $"Log Loss: {metrics.LogLoss:F2}",
         $"Log Loss Reduction: {metrics.LogLossReduction:F2}",
     };

     foreach (var line in report)
     {
         Console.WriteLine(line);
     }
 }
예제 #4
0
        /// <summary>
        /// Scores <paramref name="testData"/> with <paramref name="trainedModel"/> and evaluates
        /// the scored data with the multi-class classification evaluator of <paramref name="context"/>.
        /// </summary>
        /// <param name="context">ML context providing the multi-class evaluator.</param>
        /// <param name="trainedModel">Model used to transform (score) the test data.</param>
        /// <param name="testData">Data to score and evaluate.</param>
        /// <returns>The macro-averaged accuracy reported by the evaluator.</returns>
        public static double Evaluate(MLContext context, ITransformer trainedModel, IDataView testData)
        {
            var scoredData = trainedModel.Transform(testData);
            MultiClassClassifierMetrics evaluation = context.MulticlassClassification.Evaluate(scoredData);
            return evaluation.AccuracyMacro;
        }
예제 #5
0
 /// <summary>
 /// Bundles a trained model with its metrics, its scored validation data and, optionally,
 /// the pipeline it came from.
 /// </summary>
 /// <param name="model">The trained model.</param>
 /// <param name="metrics">Multi-class metrics computed for the model.</param>
 /// <param name="scoredValidationData">Validation data scored by the model.</param>
 /// <param name="pipeline">Pipeline that produced the model; <c>null</c> when not supplied.</param>
 public MulticlassClassificationIterationResult(ITransformer model, MultiClassClassifierMetrics metrics, IDataView scoredValidationData, Pipeline pipeline = null)
 {
     this.Model = model;
     this.Metrics = metrics;
     this.ScoredValidationData = scoredValidationData;
     this.Pipeline = pipeline;
 }
        /// <summary>
        /// Asserts that multi-class metrics match the expected baseline values.
        /// </summary>
        /// <param name="metrics">Metrics produced by the model under test.</param>
        private void CompareMetrics(MultiClassClassifierMetrics metrics)
        {
            // Both accuracies are computed doubles, so compare them to two decimal places;
            // an exact equality check against the literal .98 is fragile.
            Assert.Equal(.98, metrics.MacroAccuracy, 2);
            Assert.Equal(.98, metrics.MicroAccuracy, 2);
            Assert.InRange(metrics.LogLoss, .05, .06);
            Assert.InRange(metrics.LogLossReduction, 94, 96);

            // Expect exactly three per-class log-loss entries, each compared to one decimal place.
            Assert.Equal(3, metrics.PerClassLogLoss.Count);
            Assert.Equal(0, metrics.PerClassLogLoss[0], 1);
            Assert.Equal(.1, metrics.PerClassLogLoss[1], 1);
            Assert.Equal(.1, metrics.PerClassLogLoss[2], 1);
        }
 /// <summary>
 /// Prints a framed summary of multi-class classification metrics to the console.
 /// </summary>
 /// <param name="name">Model name shown in the header line.</param>
 /// <param name="metrics">Metrics to print.</param>
 /// <remarks>
 /// One log-loss line is printed per entry of <c>PerClassLogLoss</c>, numbered from 1. Models
 /// with fewer than three classes therefore no longer fail, and models with more are printed in full.
 /// </remarks>
 public static void PrintMultiClassClassificationMetrics(string name, MultiClassClassifierMetrics metrics)
 {
     Console.WriteLine($"************************************************************");
     Console.WriteLine($"*    Metrics for {name} multi-class classification model   ");
     Console.WriteLine($"*-----------------------------------------------------------");
     Console.WriteLine($"    AccuracyMacro = {metrics.AccuracyMacro:0.####}, a value between 0 and 1, the closer to 1, the better");
     Console.WriteLine($"    AccuracyMicro = {metrics.AccuracyMicro:0.####}, a value between 0 and 1, the closer to 1, the better");
     Console.WriteLine($"    LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");

     // Iterate rather than indexing [0]..[2] so the output follows the actual class count.
     int classNumber = 1;
     foreach (var classLogLoss in metrics.PerClassLogLoss)
     {
         Console.WriteLine($"    LogLoss for class {classNumber} = {classLogLoss:0.####}, the closer to 0, the better");
         classNumber++;
     }

     Console.WriteLine($"************************************************************");
 }
        /// <summary>
        /// Writes the metrics in <paramref name="result"/> to the console: macro/micro accuracy,
        /// log-loss, log-loss reduction, one line per class log-loss (zero-based index), TopK and
        /// top-K accuracy.
        /// </summary>
        /// <param name="result">The multi-class metrics to print.</param>
        public static void ToConsole(this MultiClassClassifierMetrics result)
        {
            // Fixed typo: the labels previously read "Acuracy".
            Console.WriteLine($"Accuracy macro: {result.AccuracyMacro}");
            Console.WriteLine($"Accuracy micro: {result.AccuracyMicro}");
            Console.WriteLine($"Log loss: {result.LogLoss}");
            Console.WriteLine($"Log loss reduction: {result.LogLossReduction}");
            Console.WriteLine("Per class log loss: ");

            int classIndex = 0;
            foreach (var logLossClass in result.PerClassLogLoss)
            {
                Console.WriteLine($"\t [{classIndex++}]: {logLossClass}");
            }

            Console.WriteLine($"Top K: {result.TopK}");
            Console.WriteLine($"Top K accuracy: {result.TopKAccuracy}");
        }
        /// <summary>
        /// Computes the field-wise difference <c>a - b</c> between two multi-class metric sets.
        /// </summary>
        /// <param name="a">Metrics to subtract from.</param>
        /// <param name="b">Metrics to subtract.</param>
        /// <returns>
        /// A new metrics object whose accuracies, log-loss values, top-K accuracy and per-class
        /// log-loss are differences of the inputs; <c>TopK</c> itself is copied from <paramref name="a"/>.
        /// </returns>
        private static MultiClassClassifierMetrics MulticlassClassificationDelta(
            MultiClassClassifierMetrics a, MultiClassClassifierMetrics b)
        {
            // A single assertion suffices: it can only fire when the two TopK values differ.
            // NOTE(review): if Contracts.Assert is debug-only, a TopK mismatch passes silently in release builds — confirm.
            Contracts.Assert(a.TopK == b.TopK, "TopK to compare must be the same length.");

            var logLossDeltas = ComputeArrayDeltas(a.PerClassLogLoss, b.PerClassLogLoss);
            var microDelta = a.MicroAccuracy - b.MicroAccuracy;
            var macroDelta = a.MacroAccuracy - b.MacroAccuracy;

            return new MultiClassClassifierMetrics(
                accuracyMicro: microDelta,
                accuracyMacro: macroDelta,
                logLoss: a.LogLoss - b.LogLoss,
                logLossReduction: a.LogLossReduction - b.LogLossReduction,
                topK: a.TopK,
                topKAccuracy: a.TopKAccuracy - b.TopKAccuracy,
                perClassLogLoss: logLossDeltas);
        }