Ejemplo n.º 1
0
        /*
         *     """Compute the number of true/false positives/negatives for each class
         *
         * Parameters
         * ----------
         * y_true : array-like or list of labels or label indicator matrix
         * Ground truth (correct) labels.
         *
         * y_pred : array-like or list of labels or label indicator matrix
         * Predicted labels, as returned by a classifier.
         *
         * labels : array, shape = [n_labels], optional
         * Integer array of labels.
         *
         * Returns
         * -------
         * true_pos : array of int, shape = [n_unique_labels]
         * Number of true positives
         *
         * true_neg : array of int, shape = [n_unique_labels]
         * Number of true negatives
         *
         * false_pos : array of int, shape = [n_unique_labels]
         * Number of false positives
         *
         * false_neg : array of int, shape = [n_unique_labels]
         * Number of false negatives
         *
         * Examples
         * --------
         * In the binary case:
         *
         * >>> from sklearn.metrics.metrics import _tp_tn_fp_fn
         * >>> y_pred = [0, 1, 0, 0]
         * >>> y_true = [0, 1, 0, 1]
         * >>> _tp_tn_fp_fn(y_true, y_pred)
         * (array([2, 1]), array([1, 2]), array([1, 0]), array([0, 1]))
         *
         * In the multiclass case:
         * >>> y_true = np.array([0, 1, 2, 0, 1, 2])
         * >>> y_pred = np.array([0, 2, 1, 0, 0, 1])
         * >>> _tp_tn_fp_fn(y_true, y_pred)
         * (array([2, 0, 0]), array([3, 2, 3]), array([1, 2, 1]), array([0, 2, 2]))
         *
         * In the multilabel case with binary indicator format:
         *
         * >>> _tp_tn_fp_fn(np.array([[0.0, 1.0], [1.0, 1.0]]), np.zeros((2, 2)))
         * (array([0, 0]), array([1, 0]), array([0, 0]), array([1, 2]))
         *
         * and with a list of labels format:
         *
         * >>> _tp_tn_fp_fn([(1, 2), (3,)], [(1, 2), tuple()])  # doctest: +ELLIPSIS
         * (array([1, 1, 0]), array([1, 1, 1]), array([0, 0, 0]), array([0, 0, 1]))
         *
         * """
         *
         */
        /// <summary>
        /// Counts true positives, true negatives, false positives and false negatives
        /// for each label, treating that label as the positive class (one-vs-rest).
        /// </summary>
        /// <param name="yTrue">Ground-truth labels, one per sample.</param>
        /// <param name="yPred">Predicted labels, aligned index-by-index with <paramref name="yTrue"/>.</param>
        /// <param name="labels">
        /// Labels to count for; when null, <c>Multiclass.unique_labels(yTrue, yPred)</c> is used.
        /// </param>
        /// <returns>
        /// (truePos, trueNeg, falsePos, falseNeg), each with one entry per element of
        /// <paramref name="labels"/>, in the same order.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="yTrue"/> or <paramref name="yPred"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="yTrue"/> and <paramref name="yPred"/> differ in length.</exception>
        private static Tuple <int[], int[], int[], int[]> TpTnFpFn(int[] yTrue, int[] yPred, int[] labels = null)
        {
            if (yTrue == null)
            {
                throw new ArgumentNullException("yTrue");
            }

            if (yPred == null)
            {
                throw new ArgumentNullException("yPred");
            }

            // Counts are taken position by position, so the two arrays must be aligned;
            // a length mismatch would otherwise silently skew every count.
            if (yTrue.Length != yPred.Length)
            {
                throw new ArgumentException(
                    string.Format(
                        "yTrue and yPred must have the same length ({0} != {1}).",
                        yTrue.Length,
                        yPred.Length),
                    "yPred");
            }

            if (labels == null)
            {
                labels = Multiclass.unique_labels(yTrue, yPred);
            }

            int nLabels  = labels.Length;
            int nSamples = yTrue.Length;
            var truePos  = new int[nLabels];
            var trueNeg  = new int[nLabels];
            var falsePos = new int[nLabels];
            var falseNeg = new int[nLabels];

            for (int i = 0; i < nLabels; i++)
            {
                int label = labels[i];

                // Single pass per label: every sample lands in exactly one bucket,
                // so tp + tn + fp + fn == nSamples.
                for (int j = 0; j < nSamples; j++)
                {
                    bool isTrue = yTrue[j] == label;
                    bool isPred = yPred[j] == label;

                    if (isTrue && isPred)
                    {
                        truePos[i]++;
                    }
                    else if (isTrue)
                    {
                        falseNeg[i]++;
                    }
                    else if (isPred)
                    {
                        falsePos[i]++;
                    }
                    else
                    {
                        trueNeg[i]++;
                    }
                }
            }

            return Tuple.Create(truePos, trueNeg, falsePos, falseNeg);
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Computes precision, recall, F-beta score and support, either per label
        /// or aggregated according to <paramref name="average"/>.
        /// </summary>
        /// <param name="yTrue">Ground-truth labels.</param>
        /// <param name="yPred">Predicted labels.</param>
        /// <param name="beta">Weight of recall relative to precision in the F-score; must be &gt; 0.</param>
        /// <param name="labels">
        /// Labels to report on; when null, <c>Multiclass.unique_labels(yTrue, yPred)</c> is used.
        /// </param>
        /// <param name="posLabel">
        /// Label treated as positive when the targets are binary and an average is requested.
        /// </param>
        /// <param name="average">
        /// null to return one value per label; otherwise Micro, Macro or Weighted aggregation.
        /// For binary targets with <paramref name="posLabel"/> set, the positive label's scores
        /// are returned instead of an aggregate.
        /// </param>
        /// <returns>
        /// Per-label arrays when <paramref name="average"/> is null; otherwise single-element
        /// arrays. Support is null for aggregated (Micro/Macro/Weighted) results.
        /// </returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="beta"/> is not &gt; 0, <paramref name="posLabel"/> is not a valid label,
        /// or <paramref name="average"/> is unsupported.
        /// </exception>
        private static PrecisionRecallResult PrecisionRecallFScoreSupportInternal(
            int[] yTrue,
            int[] yPred,
            double beta         = 1.0,
            int[] labels        = null,
            int?posLabel        = 1,
            AverageKind?average = null)
        {
            // Negated form so that NaN is rejected too ("beta <= 0" is false for NaN).
            if (!(beta > 0))
            {
                throw new ArgumentException("beta should be >0 in the F-beta score");
            }

            var beta2 = beta * beta;

            string yType = CheckClfTargets(yTrue, yPred);

            if (labels == null)
            {
                labels = Multiclass.unique_labels(yTrue, yPred);
            }

            var r        = TpTnFpFn(yTrue, yPred, labels);
            var truePos  = r.Item1;
            var falsePos = r.Item3;
            var falseNeg = r.Item4;

            int nLabels   = truePos.Length;
            var support   = new int[nLabels];
            var precision = new double[nLabels];
            var recall    = new double[nLabels];
            var fscore    = new double[nLabels];

            for (int i = 0; i < nLabels; i++)
            {
                support[i] = truePos[i] + falseNeg[i];

                // A label that is never predicted (tp + fp == 0) or never present
                // (tp + fn == 0) has an undefined ratio; report 0 rather than NaN so
                // macro/weighted averages stay finite (matches scikit-learn).
                precision[i] = RatioOrZero(truePos[i], truePos[i] + falsePos[i]);
                recall[i]    = RatioOrZero(truePos[i], support[i]);
                fscore[i]    = FBetaOrZero(precision[i], recall[i], beta2);
            }

            if (average == null)
            {
                return new PrecisionRecallResult
                {
                    Precision = precision,
                    FBetaScore = fscore,
                    Recall = recall,
                    Support = support
                };
            }

            if (yType == "binary" && posLabel.HasValue)
            {
                int posLabelIdx = Array.IndexOf(labels, posLabel.Value);
                if (posLabelIdx < 0)
                {
                    if (labels.Length == 1)
                    {
                        // Only negative labels
                        return new PrecisionRecallResult
                        {
                            FBetaScore = new double[1],
                            Precision = new double[1],
                            Recall = new double[1],
                            Support = new int[1]
                        };
                    }

                    throw new ArgumentException(
                              string.Format(
                                  "pos_label={0} is not a valid label: {1}",
                                  posLabel,
                                  string.Join(",", labels)));
                }

                return new PrecisionRecallResult
                {
                    Precision = new[] { precision[posLabelIdx] },
                    Recall = new[] { recall[posLabelIdx] },
                    FBetaScore = new[] { fscore[posLabelIdx] },
                    Support = new[] { support[posLabelIdx] }
                };
            }

            double avgPrecision;
            double avgRecall;
            double avgFscore;

            if (average == AverageKind.Micro)
            {
                // Pool the counts over all labels, then compute a single ratio.
                int tpSum = truePos.Sum();
                int fpSum = falsePos.Sum();
                int fnSum = falseNeg.Sum();

                avgPrecision = RatioOrZero(tpSum, tpSum + fpSum);
                avgRecall    = RatioOrZero(tpSum, tpSum + fnSum);
                avgFscore    = FBetaOrZero(avgPrecision, avgRecall, beta2);
            }
            else if (average == AverageKind.Macro)
            {
                avgPrecision = precision.Average();
                avgRecall    = recall.Average();
                avgFscore    = fscore.Average();
            }
            else if (average == AverageKind.Weighted)
            {
                if (support.All(v => v == 0))
                {
                    avgPrecision = 0.0;
                    avgRecall    = 0.0;
                    avgFscore    = 0.0;
                }
                else
                {
                    // Average of the per-label scores weighted by each label's support.
                    double weightedPrecision = 0.0;
                    double weightedRecall    = 0.0;
                    double weightedFscore    = 0.0;
                    for (int i = 0; i < nLabels; i++)
                    {
                        weightedPrecision += precision[i] * support[i];
                        weightedRecall    += recall[i] * support[i];
                        weightedFscore    += fscore[i] * support[i];
                    }

                    int supportSum = support.Sum();
                    avgPrecision = weightedPrecision / supportSum;
                    avgRecall    = weightedRecall / supportSum;
                    avgFscore    = weightedFscore / supportSum;
                }
            }
            else
            {
                throw new ArgumentException("Unsupported argument value", "average");
            }

            return new PrecisionRecallResult
            {
                Precision = new[] { avgPrecision },
                Recall = new[] { avgRecall },
                FBetaScore = new[] { avgFscore },
                Support = null
            };
        }

        /// <summary>
        /// Returns <paramref name="numerator"/> / <paramref name="denominator"/> as a double,
        /// or 0 when the denominator is 0 (undefined precision/recall is reported as 0).
        /// </summary>
        private static double RatioOrZero(int numerator, int denominator)
        {
            return denominator == 0 ? 0.0 : (double)numerator / denominator;
        }

        /// <summary>
        /// F-beta score from precision and recall, given beta squared; 0 when the
        /// score is undefined (precision and recall both 0).
        /// </summary>
        private static double FBetaOrZero(double precision, double recall, double beta2)
        {
            double score = (1 + beta2) * precision * recall / ((beta2 * precision) + recall);
            return double.IsNaN(score) ? 0.0 : score;
        }