Exemplo n.º 1
0
        /// <summary>
        /// Evaluate the detection output.
        /// </summary>
        /// <param name="colBottom">bottom input Blob vector (Length 2)
        ///  -# @f$ (1 \times 1 \times N \times 7) @f$ N detection results.
        ///  -# @f$ (1 \times 1 \times M \times 7) @f$ M ground truth.
        /// </param>
        /// <param name="colTop">top output Blob vector (Length 1)
        ///  -# @f$ (1 \times 1 \times N \times 5) @f$ each row is: [image_id, label, confidence, true_pos, false_pos].
        /// </param>
        /// <remarks>
        /// The first rows written are one summary row per non-background class of the form
        /// [-1, label, num_positive_ground_truth, -1, -1].  They are followed by one row per
        /// detection of the form [image_id, label, score, true_pos, false_pos].  Difficult ground
        /// truth is only counted (and only matched) when m_bEvaluateDifficultGt is set.
        /// </remarks>
        protected override void forward(BlobCollection<T> colBottom, BlobCollection<T> colTop)
        {
            float[] rgfDetData = convertF(colBottom[0].mutable_cpu_data);
            float[] rgfGtData = convertF(colBottom[1].mutable_cpu_data);

            // Retrieve all detection results.
            Dictionary<int, LabelBBox> rgAllDetections = m_bboxUtil.GetDetectionResults(rgfDetData, colBottom[0].height, m_nBackgroundLabelId);

            // Retrieve all ground truth (including difficult ones).
            Dictionary<int, LabelBBox> rgAllGtBboxes = m_bboxUtil.GetGroundTruthEx(rgfGtData, colBottom[1].height, m_nBackgroundLabelId, true);

            // Start from all zeros; a detection matched to an ignored (difficult) ground truth
            // is never written below and therefore keeps true_pos = false_pos = 0.
            colTop[0].SetData(0);
            float[] rgfTopData = convertF(colTop[0].mutable_cpu_data);
            int nNumDet = 0;

            // Count the ground truth bboxes for each label across all images.
            Dictionary<int, int> rgNumPos = new Dictionary<int, int>();

            foreach (KeyValuePair<int, LabelBBox> kv in rgAllGtBboxes)
            {
                List<KeyValuePair<int, List<NormalizedBBox>>> kvLabels = kv.Value.ToList();
                foreach (KeyValuePair<int, List<NormalizedBBox>> kvLabel in kvLabels)
                {
                    int nCount = 0;

                    if (m_bEvaluateDifficultGt)
                    {
                        nCount = kvLabel.Value.Count;
                    }
                    else
                    {
                        // Count only the non-difficult ground truth.
                        foreach (NormalizedBBox gt in kvLabel.Value)
                        {
                            if (!gt.difficult)
                            {
                                nCount++;
                            }
                        }
                    }

                    // TryGetValue leaves nExisting at 0 when the label has not been seen yet.
                    int nExisting;
                    rgNumPos.TryGetValue(kvLabel.Key, out nExisting);
                    rgNumPos[kvLabel.Key] = nExisting + nCount;
                }
            }

            // Write one summary row per non-background class: [-1, label, num_pos, -1, -1].
            for (int c = 0; c < m_nNumClasses; c++)
            {
                if (c == m_nBackgroundLabelId)
                {
                    continue;
                }

                // nNumPos stays 0 when the class has no ground truth.
                int nNumPos;
                rgNumPos.TryGetValue(c, out nNumPos);

                rgfTopData[nNumDet * 5 + 0] = -1;
                rgfTopData[nNumDet * 5 + 1] = c;
                rgfTopData[nNumDet * 5 + 2] = nNumPos;
                rgfTopData[nNumDet * 5 + 3] = -1;
                rgfTopData[nNumDet * 5 + 4] = -1;
                nNumDet++;
            }

            // Insert the evaluation status of every detection.
            foreach (KeyValuePair<int, LabelBBox> kv in rgAllDetections)
            {
                int nImageId = kv.Key;
                LabelBBox detections = kv.Value;

                // Ground truth for the current image; left null when the image has none.
                LabelBBox label_bboxes;
                rgAllGtBboxes.TryGetValue(nImageId, out label_bboxes);

                List<KeyValuePair<int, List<NormalizedBBox>>> kvLabels = detections.ToList();
                foreach (KeyValuePair<int, List<NormalizedBBox>> kvLabel in kvLabels)
                {
                    int nLabel = kvLabel.Key;
                    if (nLabel == -1)
                    {
                        continue;
                    }

                    List<NormalizedBBox> bboxes = kvLabel.Value;

                    // No ground truth for the current image or label: all detections become false_pos.
                    if (label_bboxes == null || !label_bboxes.Contains(nLabel))
                    {
                        for (int i = 0; i < bboxes.Count; i++)
                        {
                            rgfTopData[nNumDet * 5 + 0] = nImageId;
                            rgfTopData[nNumDet * 5 + 1] = nLabel;
                            rgfTopData[nNumDet * 5 + 2] = bboxes[i].score;
                            rgfTopData[nNumDet * 5 + 3] = 0;
                            rgfTopData[nNumDet * 5 + 4] = 1;
                            nNumDet++;
                        }

                        continue;
                    }

                    // Ground truth for the current label found.
                    List<NormalizedBBox> gt_bboxes = label_bboxes[nLabel];

                    // Scale ground truth if needed.
                    if (!m_bUseNormalizedBbox)
                    {
                        // m_nCount indexes m_rgSizes below, so it must be strictly less than its count.
                        m_log.CHECK_LT(m_nCount, m_rgSizes.Count, "The count must be < the sizes count.");
                        for (int j = 0; j < gt_bboxes.Count; j++)
                        {
                            gt_bboxes[j] = m_bboxUtil.Output(gt_bboxes[j], m_rgSizes[m_nCount], m_resizeParam);
                        }
                    }

                    // Each ground truth bbox may be claimed by at most one detection.
                    List<bool> rgbVisited = Utility.Create<bool>(gt_bboxes.Count, false);

                    // Sort detections in descending order based on scores.
                    if (bboxes.Count > 1)
                    {
                        bboxes.Sort(new Comparison<NormalizedBBox>(sortBboxDescending));
                    }

                    for (int i = 0; i < bboxes.Count; i++)
                    {
                        rgfTopData[nNumDet * 5 + 0] = nImageId;
                        rgfTopData[nNumDet * 5 + 1] = nLabel;
                        rgfTopData[nNumDet * 5 + 2] = bboxes[i].score;

                        if (!m_bUseNormalizedBbox)
                        {
                            bboxes[i] = m_bboxUtil.Output(bboxes[i], m_rgSizes[m_nCount], m_resizeParam);
                        }

                        // Find the ground truth bbox with the largest overlap.
                        float fOverlapMax = -1;
                        int nJmax = -1;

                        for (int j = 0; j < gt_bboxes.Count; j++)
                        {
                            float fOverlap = m_bboxUtil.JaccardOverlap(bboxes[i], gt_bboxes[j], m_bUseNormalizedBbox);
                            if (fOverlap > fOverlapMax)
                            {
                                fOverlapMax = fOverlap;
                                nJmax = j;
                            }
                        }

                        if (nJmax >= 0 && fOverlapMax >= m_fOverlapThreshold)
                        {
                            // A match to a difficult ground truth is ignored (the row keeps its
                            // zero fill) unless difficult ground truth is being evaluated.
                            if (m_bEvaluateDifficultGt || !gt_bboxes[nJmax].difficult)
                            {
                                if (!rgbVisited[nJmax])
                                {
                                    // True positive.
                                    rgfTopData[nNumDet * 5 + 3] = 1;
                                    rgfTopData[nNumDet * 5 + 4] = 0;
                                    rgbVisited[nJmax] = true;
                                }
                                else
                                {
                                    // False positive (multiple detection of the same ground truth).
                                    rgfTopData[nNumDet * 5 + 3] = 0;
                                    rgfTopData[nNumDet * 5 + 4] = 1;
                                }
                            }
                        }
                        else
                        {
                            // False positive.
                            rgfTopData[nNumDet * 5 + 3] = 0;
                            rgfTopData[nNumDet * 5 + 4] = 1;
                        }

                        nNumDet++;
                    }
                }

                // Advance m_nCount once per detection image, wrapping back to 0 after
                // m_rgSizes.Count images.
                if (m_rgSizes.Count > 0)
                {
                    m_nCount++;

                    // Reset count after a full iteration through the DB.
                    if (m_nCount == m_rgSizes.Count)
                    {
                        m_nCount = 0;
                    }
                }
            }

            colTop[0].mutable_cpu_data = convert(rgfTopData);
        }
Exemplo n.º 2
0
        /// <summary>
        /// Check if the sampled bbox satisfies the constraints with at least one of the object bboxes.
        /// </summary>
        /// <remarks>
        /// For a given object bbox, every constraint group that is defined (jaccard overlap,
        /// sample coverage, object coverage) must hold; only the min/max bounds that have a
        /// value are tested.  When no constraint is defined at all, the sampled bbox is
        /// accepted unconditionally.
        /// </remarks>
        /// <param name="sampledBBox">Specifies the sampled BBox.</param>
        /// <param name="rgObjectBboxes">Specifies the list of object normalized BBoxes.</param>
        /// <param name="sampleConstraint">Specifies the sample constraint.</param>
        /// <returns>Returns whether or not the sample constraints are satisfied.</returns>
        public bool SatisfySampleConstraint(NormalizedBBox sampledBBox, List<NormalizedBBox> rgObjectBboxes, SamplerConstraint sampleConstraint)
        {
            bool bHasJaccardOverlap = sampleConstraint.min_jaccard_overlap.HasValue || sampleConstraint.max_jaccard_overlap.HasValue;
            bool bHasSampleCoverage = sampleConstraint.min_sample_coverage.HasValue || sampleConstraint.max_sample_coverage.HasValue;
            bool bHasObjectCoverage = sampleConstraint.min_object_coverage.HasValue || sampleConstraint.max_object_coverage.HasValue;

            // By default, the sampledBBox is 'positive' if no constraints are defined.
            if (!bHasJaccardOverlap && !bHasSampleCoverage && !bHasObjectCoverage)
            {
                return (true);
            }

            // Check constraints: any 'continue' rejects the current object bbox and moves on.
            for (int i = 0; i < rgObjectBboxes.Count; i++)
            {
                NormalizedBBox objectBbox = rgObjectBboxes[i];

                // Test jaccard overlap.
                if (bHasJaccardOverlap)
                {
                    float fJaccardOverlap = m_util.JaccardOverlap(sampledBBox, objectBbox);

                    if (sampleConstraint.min_jaccard_overlap.HasValue && fJaccardOverlap < sampleConstraint.min_jaccard_overlap.Value)
                    {
                        continue;
                    }

                    if (sampleConstraint.max_jaccard_overlap.HasValue && fJaccardOverlap > sampleConstraint.max_jaccard_overlap.Value)
                    {
                        continue;
                    }
                }

                // Test sample coverage.
                if (bHasSampleCoverage)
                {
                    float fSampleCoverage = m_util.Coverage(sampledBBox, objectBbox);

                    if (sampleConstraint.min_sample_coverage.HasValue && fSampleCoverage < sampleConstraint.min_sample_coverage.Value)
                    {
                        continue;
                    }

                    if (sampleConstraint.max_sample_coverage.HasValue && fSampleCoverage > sampleConstraint.max_sample_coverage.Value)
                    {
                        continue;
                    }
                }

                // Test object coverage.
                if (bHasObjectCoverage)
                {
                    float fObjectCoverage = m_util.Coverage(objectBbox, sampledBBox);

                    if (sampleConstraint.min_object_coverage.HasValue && fObjectCoverage < sampleConstraint.min_object_coverage.Value)
                    {
                        continue;
                    }

                    if (sampleConstraint.max_object_coverage.HasValue && fObjectCoverage > sampleConstraint.max_object_coverage.Value)
                    {
                        continue;
                    }
                }

                // Every defined constraint holds for this object bbox.
                return (true);
            }

            return (false);
        }