/// <summary>
/// Do non-maximum suppression (nms) on prediction results.
/// </summary>
/// <remarks>
/// For each image, per-class NMS is applied to the decoded boxes. If keep_top_k is enabled (m_nKeepTopK > -1),
/// only the m_nKeepTopK highest scoring detections across all classes are kept. When saving is enabled
/// (m_bNeedSave), detections are also accumulated in m_detections and written to file(s) once every
/// m_nNumTestImage images.
/// </remarks>
/// <param name="colBottom">bottom input Blob vector (Length 3)
///  -# @f$ (N \times C1 \times 1 \times 1) @f$ the location predictions with C1 predictions.
///  -# @f$ (N \times C2 \times 1 \times 1) @f$ the confidence predictions with C2 predictions.
///  -# @f$ (N \times 2 \times C3 \times 1) @f$ the prior bounding boxes with C3 values.
/// </param>
/// <param name="colTop">top output Blob vector (Length 1)
///  -# @f$ (1 \times 1 \times N \times 7) @f$ N is the number of detections kept after nms (and keep_top_k),
///  and each row is: [image_id, label, confidence, xmin, ymin, xmax, ymax].  When no detection is kept,
///  N is the batch size and each row is a placeholder [image_id, -1, -1, -1, -1, -1, -1].
/// </param>
protected override void forward(BlobCollection<T> colBottom, BlobCollection<T> colTop)
{
    float[] rgfLocData = convertF(colBottom[0].mutable_cpu_data);
    float[] rgfConfData = convertF(colBottom[1].mutable_cpu_data);
    float[] rgfPriorData = convertF(colBottom[2].mutable_cpu_data);
    int nNum = colBottom[0].num;

    // Retrieve all location predictions.
    List<LabelBBox> rgAllLocPreds = m_bboxUtil.GetLocPredictions(rgfLocData, nNum, m_nNumPriors, m_nNumLocClasses, m_bShareLocations);

    // Retrieve all confidence scores.
    List<Dictionary<int, List<float>>> rgAllConfScores = m_bboxUtil.GetConfidenceScores(rgfConfData, nNum, m_nNumPriors, m_nNumClasses);

    // Retrieve all prior bboxes, which is the same within a batch since we assume all
    // images in a batch are of the same dimension.
    List<List<float>> rgrgPriorVariances;
    List<NormalizedBBox> rgPriorBboxes = m_bboxUtil.GetPrior(rgfPriorData, m_nNumPriors, out rgrgPriorVariances);

    // Decode all loc predictions to bboxes (decoded boxes are not clipped here).
    bool bClipBbox = false;
    List<LabelBBox> rgAllDecodeBboxes = m_bboxUtil.DecodeAll(rgAllLocPreds, rgPriorBboxes, rgrgPriorVariances, nNum, m_bShareLocations, m_nNumLocClasses, m_nBackgroundLabelId, m_codeType, m_bVarianceEncodedInTarget, bClipBbox);

    // Select the detections to keep for each image.
    int nNumKept = 0;
    List<Dictionary<int, List<int>>> rgAllIndices = new List<Dictionary<int, List<int>>>();

    for (int i = 0; i < nNum; i++)
    {
        int nNumDet;
        Dictionary<int, List<int>> rgIndices = applyNmsPerClass(rgAllDecodeBboxes[i], rgAllConfScores[i], out nNumDet);

        if (m_nKeepTopK > -1 && nNumDet > m_nKeepTopK)
        {
            rgIndices = keepTopKDetections(rgIndices, rgAllConfScores[i]);
            // nNumDet > m_nKeepTopK, so exactly m_nKeepTopK detections survive the trim.
            nNumKept += m_nKeepTopK;
        }
        else
        {
            nNumKept += nNumDet;
        }

        rgAllIndices.Add(rgIndices);
    }

    // Build the (1 x 1 x nNumKept x 7) top shape.
    List<int> rgTopShape = Utility.Create<int>(2, 1);
    rgTopShape.Add(nNumKept);
    rgTopShape.Add(7);

    float[] rgfTopData;

    if (nNumKept == 0)
    {
        m_log.WriteLine("WARNING: Could not find any detections.");

        // Generate fake results per image: one row per image, filled with -1 except for the image_id.
        rgTopShape[2] = nNum;
        colTop[0].Reshape(rgTopShape);
        colTop[0].SetData(-1);
        rgfTopData = convertF(colTop[0].mutable_cpu_data);

        for (int i = 0; i < nNum; i++)
        {
            rgfTopData[i * 7] = i;
        }
    }
    else
    {
        colTop[0].Reshape(rgTopShape);
        rgfTopData = convertF(colTop[0].mutable_cpu_data);
    }

    int nCount = 0;

    for (int i = 0; i < nNum; i++)
    {
        Dictionary<int, List<float>> rgConfScores = rgAllConfScores[i];
        LabelBBox decode_bboxes = rgAllDecodeBboxes[i];

        foreach (KeyValuePair<int, List<int>> kv in rgAllIndices[i])
        {
            int nLabel = kv.Key;

            // Something bad happened if there are no predictions for the current label.
            if (!rgConfScores.ContainsKey(nLabel))
            {
                m_log.FAIL("Could not find confidence predictions for label '" + nLabel.ToString() + "'!");
            }

            List<float> rgfScores = rgConfScores[nLabel];
            int nLocLabel = (m_bShareLocations) ? -1 : nLabel;

            // Something bad happened if there are no locations for the current label.
            if (!decode_bboxes.Contains(nLocLabel))
            {
                m_log.FAIL("Could not find location predictions for label '" + nLabel.ToString() + "'!");
            }

            List<NormalizedBBox> rgBboxes = decode_bboxes[nLocLabel];

            if (m_bNeedSave)
            {
                m_log.CHECK(m_rgLabelToName.ContainsKey(nLabel), "The label to name mapping does not contain the label '" + nLabel.ToString() + "'!");
                m_log.CHECK_LT(m_nNameCount, m_rgstrNames.Count, "The name count must be less than the number of names.");
            }

            foreach (int nIdx in kv.Value)
            {
                int nOffset = nCount * 7;
                float fScore = rgfScores[nIdx];
                NormalizedBBox bbox = rgBboxes[nIdx];

                rgfTopData[nOffset + 0] = i;
                rgfTopData[nOffset + 1] = nLabel;
                rgfTopData[nOffset + 2] = fScore;
                rgfTopData[nOffset + 3] = bbox.xmin;
                rgfTopData[nOffset + 4] = bbox.ymin;
                rgfTopData[nOffset + 5] = bbox.xmax;
                rgfTopData[nOffset + 6] = bbox.ymax;

                if (m_bNeedSave)
                {
                    addDetectionToSaveTree(nLabel, fScore, bbox);
                }

                nCount++;
            }
        }

        if (m_bNeedSave)
        {
            m_nNameCount++;

            // Flush the accumulated detections once every m_nNumTestImage images.
            if (m_nNameCount % m_nNumTestImage == 0)
            {
                writeDetectionResults();
                m_nNameCount = 0;
                m_detections.Clear();
            }
        }

        if (m_bVisualize)
        {
#warning DetectionOutputLayer - does not visualize detections yet.
            // TBD.
        }
    }

    colTop[0].mutable_cpu_data = convert(rgfTopData);
    colTop[0].type = BLOB_TYPE.MULTIBBOX;
}

/// <summary>
/// Runs per-class NMS (via m_bboxUtil.ApplyNMSFast) on the decoded boxes of a single image.
/// </summary>
/// <param name="decode_bboxes">decoded boxes of the image, keyed by location label (-1 when locations are shared).</param>
/// <param name="rgConfScores">confidence scores of the image, keyed by class label.</param>
/// <param name="nNumDet">returns the total number of indices kept across all classes.</param>
/// <returns>A map from class label to the kept indices; every non-background class gets an entry (possibly empty).</returns>
private Dictionary<int, List<int>> applyNmsPerClass(LabelBBox decode_bboxes, Dictionary<int, List<float>> rgConfScores, out int nNumDet)
{
    Dictionary<int, List<int>> rgIndices = new Dictionary<int, List<int>>();
    nNumDet = 0;

    for (int c = 0; c < m_nNumClasses; c++)
    {
        // Ignore background class.
        if (c == m_nBackgroundLabelId)
        {
            continue;
        }

        // Something bad happened if there are no predictions for the current label.
        if (!rgConfScores.ContainsKey(c))
        {
            m_log.FAIL("Could not find confidence predictions for label '" + c.ToString() + "'!");
        }

        List<float> rgfScores = rgConfScores[c];
        int nLabel = (m_bShareLocations) ? -1 : c;

        // Something bad happened if there are no locations for the current label.
        if (!decode_bboxes.Contains(nLabel))
        {
            m_log.FAIL("Could not find location predictions for the label '" + nLabel.ToString() + "'!");
        }

        List<NormalizedBBox> rgBboxes = decode_bboxes[nLabel];
        List<int> rgIndexes;
        m_bboxUtil.ApplyNMSFast(rgBboxes, rgfScores, m_fConfidenceThreshold, m_fNmsThreshold, m_fEta, m_nTopK, out rgIndexes);

        rgIndices[c] = rgIndexes;
        nNumDet += rgIndexes.Count;
    }

    return rgIndices;
}

/// <summary>
/// Trims the per-class NMS results of one image to the m_nKeepTopK highest scoring detections across all classes.
/// </summary>
/// <param name="rgIndices">per-class kept indices produced by applyNmsPerClass.</param>
/// <param name="rgConfScores">confidence scores of the image, keyed by class label.</param>
/// <returns>A new map from class label to kept indices; classes with no surviving detection are absent.</returns>
private Dictionary<int, List<int>> keepTopKDetections(Dictionary<int, List<int>> rgIndices, Dictionary<int, List<float>> rgConfScores)
{
    List<Tuple<float, Tuple<int, int>>> rgScoreIndexPairs = new List<Tuple<float, Tuple<int, int>>>();

    foreach (KeyValuePair<int, List<int>> kv in rgIndices)
    {
        int nLabel = kv.Key;

        // Something bad happened for the current label.
        if (!rgConfScores.ContainsKey(nLabel))
        {
            m_log.FAIL("Could not find confidence predictions for label " + nLabel.ToString() + "!");
        }

        List<float> rgScores = rgConfScores[nLabel];

        foreach (int nIdx in kv.Value)
        {
            m_log.CHECK_LT(nIdx, rgScores.Count, "The current index must be less than the number of scores!");
            rgScoreIndexPairs.Add(new Tuple<float, Tuple<int, int>>(rgScores[nIdx], new Tuple<int, int>(nLabel, nIdx)));
        }
    }

    // Keep top k results per image.  OrderByDescending is a stable sort, so ties keep their collection order.
    rgScoreIndexPairs = rgScoreIndexPairs.OrderByDescending(p => p.Item1).Take(m_nKeepTopK).ToList();

    // Store the new indices.
    Dictionary<int, List<int>> rgNewIndices = new Dictionary<int, List<int>>();

    foreach (Tuple<float, Tuple<int, int>> pair in rgScoreIndexPairs)
    {
        int nLabel = pair.Item2.Item1;
        int nIdx = pair.Item2.Item2;
        List<int> rgLabelIndices;

        if (!rgNewIndices.TryGetValue(nLabel, out rgLabelIndices))
        {
            rgLabelIndices = new List<int>();
            rgNewIndices.Add(nLabel, rgLabelIndices);
        }

        rgLabelIndices.Add(nIdx);
    }

    return rgNewIndices;
}

/// <summary>
/// Appends one detection to m_detections for later saving, converting the box with m_bboxUtil.Output
/// using the size of the current image (m_rgSizes[m_nNameCount]).
/// </summary>
/// <param name="nLabel">class label of the detection.</param>
/// <param name="fScore">confidence score of the detection.</param>
/// <param name="bbox">decoded (normalized) box of the detection.</param>
private void addDetectionToSaveTree(int nLabel, float fScore, NormalizedBBox bbox)
{
    NormalizedBBox out_bbox = m_bboxUtil.Output(bbox, m_rgSizes[m_nNameCount], m_resizeParam);
    float fXmin = out_bbox.xmin;
    float fYmin = out_bbox.ymin;
    float fXmax = out_bbox.xmax;
    float fYmax = out_bbox.ymax;

    // The bbox is stored as [xmin, ymin, width, height], each rounded to two decimals.
    PropertyTree pt_xmin = new PropertyTree();
    pt_xmin.Put("", Math.Round(fXmin * 100) / 100);
    PropertyTree pt_ymin = new PropertyTree();
    pt_ymin.Put("", Math.Round(fYmin * 100) / 100);
    PropertyTree pt_wd = new PropertyTree();
    pt_wd.Put("", Math.Round((fXmax - fXmin) * 100) / 100);
    PropertyTree pt_ht = new PropertyTree();
    pt_ht.Put("", Math.Round((fYmax - fYmin) * 100) / 100);

    PropertyTree cur_bbox = new PropertyTree();
    cur_bbox.AddChild("", pt_xmin);
    cur_bbox.AddChild("", pt_ymin);
    cur_bbox.AddChild("", pt_wd);
    cur_bbox.AddChild("", pt_ht);

    PropertyTree cur_det = new PropertyTree();
    cur_det.Put("image_id", m_rgstrNames[m_nNameCount]);

    // ILSVRC stores the numeric label; the other formats store the label name.
    if (m_outputFormat == SaveOutputParameter.OUTPUT_FORMAT.ILSVRC)
    {
        cur_det.Put("category_id", nLabel);
    }
    else
    {
        cur_det.Put("category_id", m_rgLabelToName[nLabel]);
    }

    cur_det.AddChild("bbox", cur_bbox);
    cur_det.Put("score", fScore);

    m_detections.AddChild("", cur_det);
}

/// <summary>
/// Writes the detections accumulated in m_detections to file(s) using the configured output format.
/// </summary>
/// <remarks>
/// VOC: one text file per non-background class with lines 'image_id score xmin ymin xmax ymax'.
/// COCO: a single json file holding all detections.
/// ILSVRC: a single text file with lines 'image_id label score xmin ymin xmax ymax'.
/// </remarks>
private void writeDetectionResults()
{
    if (m_outputFormat == SaveOutputParameter.OUTPUT_FORMAT.VOC)
    {
        Dictionary<string, StreamWriter> rgOutFiles = new Dictionary<string, StreamWriter>();

        try
        {
            for (int c = 0; c < m_nNumClasses; c++)
            {
                if (c == m_nBackgroundLabelId)
                {
                    continue;
                }

                string strLabelName = m_rgLabelToName[c];
                string strFile = getFileName(strLabelName, "txt");
                rgOutFiles.Add(strLabelName, new StreamWriter(strFile));
            }

            foreach (PropertyTree pt in m_detections.Children)
            {
                string strLabel = pt.Get("category_id").Value;
                StreamWriter swLabel;

                if (!rgOutFiles.TryGetValue(strLabel, out swLabel))
                {
                    m_log.WriteLine("WARNING! Cannot find '" + strLabel + "' label in the output files!");
                    continue;
                }

                swLabel.WriteLine(formatDetectionLine(pt, null));
            }
        }
        finally
        {
            // Dispose flushes and closes every opened file, even if writing failed part way through.
            foreach (StreamWriter swOut in rgOutFiles.Values)
            {
                swOut.Dispose();
            }
        }
    }
    else if (m_outputFormat == SaveOutputParameter.OUTPUT_FORMAT.COCO)
    {
        string strFile = getFileName("", "json");

        using (StreamWriter sw = new StreamWriter(strFile))
        {
            PropertyTree output = new PropertyTree();
            output.AddChild("detections", m_detections);
            string strOut = output.ToJson();
            sw.Write(strOut);
        }
    }
    else if (m_outputFormat == SaveOutputParameter.OUTPUT_FORMAT.ILSVRC)
    {
        string strFile = getFileName("", "txt");

        using (StreamWriter sw = new StreamWriter(strFile))
        {
            foreach (PropertyTree pt in m_detections.Children)
            {
                // All classes share one file, so the label column is required to identify each detection's class.
                int nLabel = (int)pt.Get("category_id").Numeric;
                sw.WriteLine(formatDetectionLine(pt, nLabel.ToString(System.Globalization.CultureInfo.InvariantCulture)));
            }
        }
    }
}

/// <summary>
/// Formats one saved detection as 'image_id [label] score xmin ymin xmax ymax'.
/// </summary>
/// <remarks>
/// The stored [xmin, ymin, width, height] values are truncated to integers, and xmax/ymax are computed as
/// xmin + width and ymin + height.  Numbers are written with the invariant culture so the files stay
/// machine-readable regardless of the current locale.
/// </remarks>
/// <param name="pt">detection entry taken from m_detections.</param>
/// <param name="strLabel">optional label column inserted after the image id; pass null to omit it.</param>
/// <returns>The formatted text line.</returns>
private string formatDetectionLine(PropertyTree pt, string strLabel)
{
    System.Globalization.CultureInfo ci = System.Globalization.CultureInfo.InvariantCulture;
    string strImageName = pt.Get("image_id").Value;
    float fScore = (float)pt.Get("score").Numeric;
    List<int> bbox = new List<int>();

    foreach (Property elm in pt.GetChildren("bbox"))
    {
        bbox.Add((int)elm.Numeric);
    }

    string strLine = strImageName;

    if (strLabel != null)
    {
        strLine += " " + strLabel;
    }

    strLine += " " + fScore.ToString(ci);
    strLine += " " + bbox[0].ToString(ci) + " " + bbox[1].ToString(ci);
    strLine += " " + (bbox[0] + bbox[2]).ToString(ci);
    strLine += " " + (bbox[1] + bbox[3]).ToString(ci);

    return strLine;
}