Ejemplo n.º 1
0
        /// <summary>
        /// Evaluate the detection output.
        /// </summary>
        /// <remarks>
        /// Every output row holds 5 values.  The output first receives one header row per class other than
        /// the background label, laid out as [-1, label, num_pos, -1, -1], where num_pos is the number of
        /// ground truth boxes counted for that label.  Difficult boxes are only counted when difficult ground
        /// truth is evaluated.  One row per detection follows, laid out as
        /// [image_id, label, confidence, true_pos, false_pos].  A detection whose best match is a difficult
        /// ground truth box that is not being evaluated is ignored and reports 0 for both true_pos and false_pos.
        /// </remarks>
        /// <param name="colBottom">bottom input Blob vector (Length 2)
        ///  -# @f$ (1 \times 1 \times N \times 7) @f$ N detection results.
        ///  -# @f$ (1 \times 1 \times M \times 7) @f$ M ground truth.
        /// </param>
        /// <param name="colTop">top output Blob vector (Length 1)
        ///  -# @f$ (1 \times 1 \times N \times 5) @f$ N is the number of class header rows plus the number of detections, and each detection row is: [image_id, label, confidence, true_pos, false_pos].
        /// </param>
        protected override void forward(BlobCollection<T> colBottom, BlobCollection<T> colTop)
        {
            float[] rgfDetData = convertF(colBottom[0].mutable_cpu_data);
            float[] rgfGtData = convertF(colBottom[1].mutable_cpu_data);

            // Retrieve all detection results.
            Dictionary<int, LabelBBox> rgAllDetections = m_bboxUtil.GetDetectionResults(rgfDetData, colBottom[0].height, m_nBackgroundLabelId);

            // Retrieve all ground truth (including difficult ones).
            Dictionary<int, LabelBBox> rgAllGtBboxes = m_bboxUtil.GetGroundTruthEx(rgfGtData, colBottom[1].height, m_nBackgroundLabelId, true);

            // Zero the output first.  Detections matched to an ignored (difficult) ground truth box
            // must report 0 for both true_pos and false_pos.
            colTop[0].SetData(0);
            float[] rgfTopData = convertF(colTop[0].mutable_cpu_data);
            int nNumDet = 0;

            // Count the ground truth boxes (the positives) for each label, across all images.
            Dictionary<int, int> rgNumPos = new Dictionary<int, int>();

            foreach (KeyValuePair<int, LabelBBox> kv in rgAllGtBboxes)
            {
                foreach (KeyValuePair<int, List<NormalizedBBox>> kvLabel in kv.Value.ToList())
                {
                    int nCount = 0;

                    if (m_bEvaluateDifficultGt)
                    {
                        nCount = kvLabel.Value.Count;
                    }
                    else
                    {
                        // Only non-difficult ground truth counts as a positive.
                        for (int i = 0; i < kvLabel.Value.Count; i++)
                        {
                            if (!kvLabel.Value[i].difficult)
                            {
                                nCount++;
                            }
                        }
                    }

                    int nExisting;
                    if (rgNumPos.TryGetValue(kvLabel.Key, out nExisting))
                    {
                        rgNumPos[kvLabel.Key] = nExisting + nCount;
                    }
                    else
                    {
                        rgNumPos.Add(kvLabel.Key, nCount);
                    }
                }
            }

            // Write one header row per non-background class: [-1, label, num_pos, -1, -1].
            for (int c = 0; c < m_nNumClasses; c++)
            {
                if (c == m_nBackgroundLabelId)
                {
                    continue;
                }

                // TryGetValue leaves nNumPos at 0 when the label has no ground truth.
                int nNumPos;
                rgNumPos.TryGetValue(c, out nNumPos);

                setDetectionEvalRow(rgfTopData, nNumDet, -1, c, nNumPos, -1, -1);
                nNumDet++;
            }

            // Write the evaluation status of each detection.
            foreach (KeyValuePair<int, LabelBBox> kv in rgAllDetections)
            {
                int nImageId = kv.Key;
                LabelBBox detections = kv.Value;

                LabelBBox label_bboxes;
                bool bImageHasGt = rgAllGtBboxes.TryGetValue(nImageId, out label_bboxes);

                foreach (KeyValuePair<int, List<NormalizedBBox>> kvLabel in detections.ToList())
                {
                    int nLabel = kvLabel.Key;
                    if (nLabel == -1)
                    {
                        continue;
                    }

                    List<NormalizedBBox> bboxes = kvLabel.Value;

                    // No ground truth for the current image or label: all detections become false_pos.
                    if (!bImageHasGt || !label_bboxes.Contains(nLabel))
                    {
                        for (int i = 0; i < bboxes.Count; i++)
                        {
                            setDetectionEvalRow(rgfTopData, nNumDet, nImageId, nLabel, bboxes[i].score, 0, 1);
                            nNumDet++;
                        }

                        continue;
                    }

                    // Ground truth for the current label found.
                    List<NormalizedBBox> gt_bboxes = label_bboxes[nLabel];

                    // Scale ground truth if needed.  m_nCount indexes m_rgSizes, so it must be strictly less than its count.
                    if (!m_bUseNormalizedBbox)
                    {
                        m_log.CHECK_LT(m_nCount, m_rgSizes.Count, "The count must be < the sizes count.");
                        for (int i = 0; i < gt_bboxes.Count; i++)
                        {
                            gt_bboxes[i] = m_bboxUtil.Output(gt_bboxes[i], m_rgSizes[m_nCount], m_resizeParam);
                        }
                    }

                    // Tracks which ground truth boxes were already claimed by a higher-scoring detection.
                    List<bool> rgbVisited = Utility.Create<bool>(gt_bboxes.Count, false);

                    // Sort detections in descending order based on scores.
                    if (bboxes.Count > 1)
                    {
                        bboxes.Sort(new Comparison<NormalizedBBox>(sortBboxDescending));
                    }

                    for (int i = 0; i < bboxes.Count; i++)
                    {
                        // Capture the score before any rescaling of the box.
                        float fScore = bboxes[i].score;

                        if (!m_bUseNormalizedBbox)
                        {
                            bboxes[i] = m_bboxUtil.Output(bboxes[i], m_rgSizes[m_nCount], m_resizeParam);
                        }

                        // Compare with each ground truth bbox and keep the best overlap.
                        float fOverlapMax = -1;
                        int nJmax = -1;

                        for (int j = 0; j < gt_bboxes.Count; j++)
                        {
                            float fOverlap = m_bboxUtil.JaccardOverlap(bboxes[i], gt_bboxes[j], m_bUseNormalizedBbox);
                            if (fOverlap > fOverlapMax)
                            {
                                fOverlapMax = fOverlap;
                                nJmax = j;
                            }
                        }

                        float fTruePos = 0;
                        float fFalsePos = 0;

                        if (nJmax >= 0 && fOverlapMax >= m_fOverlapThreshold)
                        {
                            // A match to a difficult box is ignored (neither TP nor FP) unless difficult GT is evaluated.
                            if (m_bEvaluateDifficultGt || !gt_bboxes[nJmax].difficult)
                            {
                                if (!rgbVisited[nJmax])
                                {
                                    // True positive.
                                    fTruePos = 1;
                                    rgbVisited[nJmax] = true;
                                }
                                else
                                {
                                    // False positive (multiple detection).
                                    fFalsePos = 1;
                                }
                            }
                        }
                        else
                        {
                            // False positive.
                            fFalsePos = 1;
                        }

                        setDetectionEvalRow(rgfTopData, nNumDet, nImageId, nLabel, fScore, fTruePos, fFalsePos);
                        nNumDet++;
                    }
                }

                if (m_rgSizes.Count > 0)
                {
                    m_nCount++;

                    // Reset count after a full iteration through the DB.
                    if (m_nCount == m_rgSizes.Count)
                    {
                        m_nCount = 0;
                    }
                }
            }

            colTop[0].mutable_cpu_data = convert(rgfTopData);
        }

        /// <summary>
        /// Write one 5-value evaluation row at row index <paramref name="nRow"/> of <paramref name="rgfTopData"/>.
        /// </summary>
        /// <param name="rgfTopData">Output buffer laid out as consecutive 5-float rows.</param>
        /// <param name="nRow">Zero-based row index to write.</param>
        /// <param name="fImageId">Image id (-1 for class header rows).</param>
        /// <param name="fLabel">Class label.</param>
        /// <param name="fScore">Detection confidence, or the positive count for header rows.</param>
        /// <param name="fTruePos">True-positive flag (-1 for header rows).</param>
        /// <param name="fFalsePos">False-positive flag (-1 for header rows).</param>
        private static void setDetectionEvalRow(float[] rgfTopData, int nRow, float fImageId, float fLabel, float fScore, float fTruePos, float fFalsePos)
        {
            int nOffset = nRow * 5;
            rgfTopData[nOffset + 0] = fImageId;
            rgfTopData[nOffset + 1] = fLabel;
            rgfTopData[nOffset + 2] = fScore;
            rgfTopData[nOffset + 3] = fTruePos;
            rgfTopData[nOffset + 4] = fFalsePos;
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Do non-maximum suppression (nms) on prediction results.
        /// </summary>
        /// <remarks>
        /// Per image and per non-background class, NMS is applied to the decoded boxes.  If more than
        /// keep_top_k detections survive for an image, only the keep_top_k highest-scoring ones are kept.
        /// When no detections are kept at all, the output holds one placeholder row per image whose
        /// first value is the image index and whose remaining values are -1.  When saving is enabled,
        /// detections are accumulated and written to disk once every m_nNumTestImage images.
        /// </remarks>
        /// <param name="colBottom">bottom input Blob vector (at least 3)
        ///  -# @f$ (N \times C1 \times 1 \times 1) @f$ the location predictions with C1 predictions.
        ///  -# @f$ (N \times C2 \times 1 \times 1) @f$ the confidence predictions with C2 predictions.
        ///  -# @f$ (N \times 2 \times C3 \times 1) @f$ the prior bounding boxes with C3 values.
        /// </param>
        /// <param name="colTop">top output Blob vector (Length 1)
        ///  -# @f$ (1 \times 1 \times N \times 7) @f$ N is the number of detections after nms, and each row is: [image_id, label, confidence, xmin, ymin, xmax, ymax].
        /// </param>
        protected override void forward(BlobCollection<T> colBottom, BlobCollection<T> colTop)
        {
            float[] rgfLocData = convertF(colBottom[0].mutable_cpu_data);
            float[] rgfConfData = convertF(colBottom[1].mutable_cpu_data);
            float[] rgfPriorData = convertF(colBottom[2].mutable_cpu_data);
            int nNum = colBottom[0].num;

            // Retrieve all location predictions.
            List<LabelBBox> rgAllLocPreds = m_bboxUtil.GetLocPredictions(rgfLocData, nNum, m_nNumPriors, m_nNumLocClasses, m_bShareLocations);

            // Retrieve all confidence scores.
            List<Dictionary<int, List<float>>> rgAllConfScores = m_bboxUtil.GetConfidenceScores(rgfConfData, nNum, m_nNumPriors, m_nNumClasses);

            // Retrieve all prior bboxes, which is the same within a batch since we assume all
            // images in a batch are of the same dimension.
            List<List<float>> rgrgPriorVariances;
            List<NormalizedBBox> rgPriorBboxes = m_bboxUtil.GetPrior(rgfPriorData, m_nNumPriors, out rgrgPriorVariances);

            // Decode all loc predictions to bboxes.
            bool bClipBbox = false;
            List<LabelBBox> rgAllDecodeBboxes = m_bboxUtil.DecodeAll(rgAllLocPreds, rgPriorBboxes, rgrgPriorVariances, nNum, m_bShareLocations, m_nNumLocClasses, m_nBackgroundLabelId, m_codeType, m_bVarianceEncodedInTarget, bClipBbox);

            int nNumKept = 0;
            List<Dictionary<int, List<int>>> rgAllIndices = new List<Dictionary<int, List<int>>>();

            for (int i = 0; i < nNum; i++)
            {
                LabelBBox decode_bboxes = rgAllDecodeBboxes[i];
                Dictionary<int, List<float>> rgConfScores = rgAllConfScores[i];
                Dictionary<int, List<int>> rgIndices = new Dictionary<int, List<int>>();
                int nNumDet = 0;

                for (int c = 0; c < m_nNumClasses; c++)
                {
                    // Ignore background class.
                    if (c == m_nBackgroundLabelId)
                    {
                        continue;
                    }

                    // Something bad happened if there are no predictions for the current label.
                    if (!rgConfScores.ContainsKey(c))
                    {
                        m_log.FAIL("Could not find confidence predictions for label '" + c.ToString() + "'!");
                    }

                    List<float> rgfScores = rgConfScores[c];
                    int nLabel = (m_bShareLocations) ? -1 : c;

                    // Something bad happened if there are no locations for the current label.
                    if (!decode_bboxes.Contains(nLabel))
                    {
                        m_log.FAIL("Could not find location predictions for the label '" + nLabel.ToString() + "'!");
                    }

                    List<NormalizedBBox> rgBboxes = decode_bboxes[nLabel];
                    List<int> rgIndexes;
                    m_bboxUtil.ApplyNMSFast(rgBboxes, rgfScores, m_fConfidenceThreshold, m_fNmsThreshold, m_fEta, m_nTopK, out rgIndexes);
                    rgIndices[c] = rgIndexes;
                    nNumDet += rgIndexes.Count;
                }

                if (m_nKeepTopK > -1 && nNumDet > m_nKeepTopK)
                {
                    // Gather (score, (label, index)) for every surviving detection of this image.
                    List<Tuple<float, Tuple<int, int>>> rgScoreIndexPairs = new List<Tuple<float, Tuple<int, int>>>();

                    foreach (KeyValuePair<int, List<int>> kv in rgIndices)
                    {
                        int nLabel = kv.Key;
                        List<int> rgLabelIndices = kv.Value;

                        // Something bad happened for the current label.
                        if (!rgConfScores.ContainsKey(nLabel))
                        {
                            m_log.FAIL("Could not find confidence predictions for label " + nLabel.ToString() + "!");
                        }

                        List<float> rgScores = rgConfScores[nLabel];
                        for (int j = 0; j < rgLabelIndices.Count; j++)
                        {
                            int nIdx = rgLabelIndices[j];
                            m_log.CHECK_LT(nIdx, rgScores.Count, "The current index must be less than the number of scores!");
                            rgScoreIndexPairs.Add(new Tuple<float, Tuple<int, int>>(rgScores[nIdx], new Tuple<int, int>(nLabel, nIdx)));
                        }
                    }

                    // Keep top k results per image.
                    rgScoreIndexPairs = rgScoreIndexPairs.OrderByDescending(p => p.Item1).Take(m_nKeepTopK).ToList();

                    // Store the new indices.
                    Dictionary<int, List<int>> rgNewIndices = new Dictionary<int, List<int>>();
                    for (int j = 0; j < rgScoreIndexPairs.Count; j++)
                    {
                        int nLabel = rgScoreIndexPairs[j].Item2.Item1;
                        int nIdx = rgScoreIndexPairs[j].Item2.Item2;

                        List<int> rgLabelList;
                        if (!rgNewIndices.TryGetValue(nLabel, out rgLabelList))
                        {
                            rgLabelList = new List<int>();
                            rgNewIndices.Add(nLabel, rgLabelList);
                        }

                        rgLabelList.Add(nIdx);
                    }

                    rgAllIndices.Add(rgNewIndices);
                    // Count exactly what was stored so the output rows written below match the blob size.
                    nNumKept += rgScoreIndexPairs.Count;
                }
                else
                {
                    rgAllIndices.Add(rgIndices);
                    nNumKept += nNumDet;
                }
            }

            List<int> rgTopShape = Utility.Create<int>(2, 1);

            rgTopShape.Add(nNumKept);
            rgTopShape.Add(7);
            float[] rgfTopData = null;

            if (nNumKept == 0)
            {
                m_log.WriteLine("WARNING: Could not find any detections.");
                rgTopShape[2] = nNum;
                colTop[0].Reshape(rgTopShape);

                colTop[0].SetData(-1);
                rgfTopData = convertF(colTop[0].mutable_cpu_data);

                // Generate fake results per image: [image_id, -1, -1, -1, -1, -1, -1].
                for (int i = 0; i < nNum; i++)
                {
                    rgfTopData[i * 7] = i;
                }
            }
            else
            {
                colTop[0].Reshape(rgTopShape);
                rgfTopData = convertF(colTop[0].mutable_cpu_data);
            }

            int nCount = 0;

            for (int i = 0; i < nNum; i++)
            {
                Dictionary<int, List<float>> rgConfScores = rgAllConfScores[i];
                LabelBBox decode_bboxes = rgAllDecodeBboxes[i];

                foreach (KeyValuePair<int, List<int>> kv in rgAllIndices[i])
                {
                    int nLabel = kv.Key;

                    // Something bad happened if there are no predictions for the current label.
                    if (!rgConfScores.ContainsKey(nLabel))
                    {
                        m_log.FAIL("Could not find confidence predictions for label '" + nLabel.ToString() + "'!");
                    }

                    List<float> rgfScores = rgConfScores[nLabel];
                    int nLocLabel = (m_bShareLocations) ? -1 : nLabel;

                    // Something bad happened if there are no location predictions for the current label.
                    if (!decode_bboxes.Contains(nLocLabel))
                    {
                        m_log.FAIL("Could not find location predictions for label '" + nLabel.ToString() + "'!");
                    }

                    List<NormalizedBBox> rgBboxes = decode_bboxes[nLocLabel];
                    List<int> rgIndices = kv.Value;

                    if (m_bNeedSave)
                    {
                        m_log.CHECK(m_rgLabelToName.ContainsKey(nLabel), "The label to name mapping does not contain the label '" + nLabel.ToString() + "'!");
                        m_log.CHECK_LT(m_nNameCount, m_rgstrNames.Count, "The name count must be less than the number of names.");
                    }

                    for (int j = 0; j < rgIndices.Count; j++)
                    {
                        int nIdx = rgIndices[j];
                        int nOffset = nCount * 7;
                        NormalizedBBox bbox = rgBboxes[nIdx];

                        rgfTopData[nOffset + 0] = i;
                        rgfTopData[nOffset + 1] = nLabel;
                        rgfTopData[nOffset + 2] = rgfScores[nIdx];
                        rgfTopData[nOffset + 3] = bbox.xmin;
                        rgfTopData[nOffset + 4] = bbox.ymin;
                        rgfTopData[nOffset + 5] = bbox.xmax;
                        rgfTopData[nOffset + 6] = bbox.ymax;

                        if (m_bNeedSave)
                        {
                            // Convert back to the image's size and record as [xmin, ymin, width, height] rounded to 2 decimals.
                            NormalizedBBox out_bbox = m_bboxUtil.Output(bbox, m_rgSizes[m_nNameCount], m_resizeParam);

                            float fScore = rgfTopData[nOffset + 2];
                            float fXmin = out_bbox.xmin;
                            float fYmin = out_bbox.ymin;
                            float fXmax = out_bbox.xmax;
                            float fYmax = out_bbox.ymax;

                            PropertyTree pt_xmin = new PropertyTree();
                            pt_xmin.Put("", Math.Round(fXmin * 100) / 100);

                            PropertyTree pt_ymin = new PropertyTree();
                            pt_ymin.Put("", Math.Round(fYmin * 100) / 100);

                            PropertyTree pt_wd = new PropertyTree();
                            pt_wd.Put("", Math.Round((fXmax - fXmin) * 100) / 100);

                            PropertyTree pt_ht = new PropertyTree();
                            pt_ht.Put("", Math.Round((fYmax - fYmin) * 100) / 100);

                            PropertyTree cur_bbox = new PropertyTree();
                            cur_bbox.AddChild("", pt_xmin);
                            cur_bbox.AddChild("", pt_ymin);
                            cur_bbox.AddChild("", pt_wd);
                            cur_bbox.AddChild("", pt_ht);

                            PropertyTree cur_det = new PropertyTree();
                            cur_det.Put("image_id", m_rgstrNames[m_nNameCount]);
                            if (m_outputFormat == SaveOutputParameter.OUTPUT_FORMAT.ILSVRC)
                            {
                                cur_det.Put("category_id", nLabel);
                            }
                            else
                            {
                                cur_det.Put("category_id", m_rgLabelToName[nLabel]);
                            }

                            cur_det.AddChild("bbox", cur_bbox);
                            cur_det.Put("score", fScore);

                            m_detections.AddChild("", cur_det);
                        }

                        nCount++;
                    }
                }

                if (m_bNeedSave)
                {
                    m_nNameCount++;

                    // Flush the accumulated detections to disk once every m_nNumTestImage images.
                    if (m_nNameCount % m_nNumTestImage == 0)
                    {
                        saveDetectionResults();

                        m_nNameCount = 0;
                        m_detections.Clear();
                    }
                }

                if (m_bVisualize)
                {
#warning DetectionOutputLayer - does not visualize detections yet.
                    // TBD.
                }
            }

            colTop[0].mutable_cpu_data = convert(rgfTopData);
            colTop[0].type = BLOB_TYPE.MULTIBBOX;
        }

        /// <summary>
        /// Write the detections accumulated in m_detections to disk in the configured output format.
        /// VOC writes one text file per non-background class, COCO writes a single JSON file and
        /// ILSVRC writes a single text file.
        /// </summary>
        private void saveDetectionResults()
        {
            if (m_outputFormat == SaveOutputParameter.OUTPUT_FORMAT.VOC)
            {
                Dictionary<string, StreamWriter> rgOutFiles = new Dictionary<string, StreamWriter>();

                try
                {
                    for (int c = 0; c < m_nNumClasses; c++)
                    {
                        if (c == m_nBackgroundLabelId)
                        {
                            continue;
                        }

                        string strLabelName = m_rgLabelToName[c];
                        string strFile = getFileName(strLabelName, "txt");
                        rgOutFiles.Add(strLabelName, new StreamWriter(strFile));
                    }

                    foreach (PropertyTree pt in m_detections.Children)
                    {
                        string strLabel = pt.Get("category_id").Value;
                        StreamWriter swOut;

                        if (!rgOutFiles.TryGetValue(strLabel, out swOut))
                        {
                            m_log.WriteLine("WARNING! Cannot find '" + strLabel + "' label in the output files!");
                            continue;
                        }

                        // VOC lines carry no label column; the label is implied by the per-class file.
                        swOut.WriteLine(formatSavedDetection(pt, null));
                    }
                }
                finally
                {
                    // Dispose flushes and closes every opened file, even if writing failed part way.
                    foreach (StreamWriter swFile in rgOutFiles.Values)
                    {
                        swFile.Dispose();
                    }
                }
            }
            else if (m_outputFormat == SaveOutputParameter.OUTPUT_FORMAT.COCO)
            {
                string strFile = getFileName("", "json");
                using (StreamWriter sw = new StreamWriter(strFile))
                {
                    PropertyTree output = new PropertyTree();
                    output.AddChild("detections", m_detections);
                    string strOut = output.ToJson();
                    sw.Write(strOut);
                }
            }
            else if (m_outputFormat == SaveOutputParameter.OUTPUT_FORMAT.ILSVRC)
            {
                string strFile = getFileName("", "txt");
                using (StreamWriter sw = new StreamWriter(strFile))
                {
                    foreach (PropertyTree pt in m_detections.Children)
                    {
                        // ILSVRC lines include the numeric label right after the image name.
                        int nLabel = (int)pt.Get("category_id").Numeric;
                        sw.WriteLine(formatSavedDetection(pt, nLabel.ToString(System.Globalization.CultureInfo.InvariantCulture)));
                    }
                }
            }
        }

        /// <summary>
        /// Format one saved detection as "image_name [label] score xmin ymin xmax ymax".
        /// The bbox child values (stored as xmin, ymin, width, height) are truncated to integers.
        /// </summary>
        /// <param name="pt">Detection entry holding 'image_id', 'score' and a 'bbox' child.</param>
        /// <param name="strLabel">Label text to insert after the image name, or null to omit the label column.</param>
        /// <returns>The formatted line, using invariant-culture number formatting so files are machine-readable in any locale.</returns>
        private static string formatSavedDetection(PropertyTree pt, string strLabel)
        {
            System.Globalization.CultureInfo ci = System.Globalization.CultureInfo.InvariantCulture;

            string strImageName = pt.Get("image_id").Value;
            float fScore = (float)pt.Get("score").Numeric;

            List<int> bbox = new List<int>();
            foreach (Property elm in pt.GetChildren("bbox"))
            {
                bbox.Add((int)elm.Numeric);
            }

            string strLine = strImageName;
            if (strLabel != null)
            {
                strLine += " " + strLabel;
            }

            strLine += " " + fScore.ToString(ci);
            strLine += " " + bbox[0].ToString(ci) + " " + bbox[1].ToString(ci);
            strLine += " " + (bbox[0] + bbox[2]).ToString(ci);
            strLine += " " + (bbox[1] + bbox[3]).ToString(ci);

            return strLine;
        }