/// <summary>
/// Loads the label map stored in the 'labelmap_voc.prototxt' data file of the current dataset.
/// </summary>
/// <returns>The LabelMap built from the parsed prototxt file.</returns>
private LabelMap loadLabelMap()
{
    RawProto labelMapProto = RawProtoFile.LoadFromFile(getDataFile(dataset_name, "labelmap_voc.prototxt"));
    return LabelMap.FromProto(labelMapProto);
}
/// <summary>
/// Setup the layer.
/// </summary>
/// <remarks>
/// Reads the detection_output_param settings (class counts, location sharing, code type,
/// keep_top_k, confidence threshold), the nms_param settings, the optional save_output_param
/// settings (output directory, label map file, name/size file, resize parameter) and the
/// optional visualization settings, then reshapes the internal working blobs to match the
/// bottom blobs.
/// </remarks>
/// <param name="colBottom">Specifies the collection of bottom (input) Blobs.</param>
/// <param name="colTop">Specifies the collection of top (output) Blobs.</param>
/// <exception cref="Exception">Thrown when the label map file cannot be read or converted to
/// name/display-name maps, or when a non-blank name_size_file line does not have 3 or 4 columns.</exception>
/// <exception cref="FormatException">Thrown when a height or width in the name_size_file is not an integer.</exception>
public override void LayerSetUp(BlobCollection<T> colBottom, BlobCollection<T> colTop)
{
    var detParam = m_param.detection_output_param;
    var saveParam = detParam.save_output_param;

    m_log.CHECK_GT(detParam.num_classes, 0, "There must be at least one class specified.");
    m_nNumClasses = (int)detParam.num_classes;
    m_bShareLocations = detParam.share_location;
    // Only one location class is used when locations are shared across classes.
    m_nNumLocClasses = (m_bShareLocations) ? 1 : m_nNumClasses;
    m_nBackgroundLabelId = detParam.background_label_id;
    m_codeType = detParam.code_type;
    m_bVarianceEncodedInTarget = detParam.variance_encoded_in_target;
    m_nKeepTopK = detParam.keep_top_k;
    // Without an explicit threshold, -float.MaxValue is used (presumably so no confidence is filtered out).
    m_fConfidenceThreshold = detParam.confidence_threshold.GetValueOrDefault(-float.MaxValue);

    // Parameters used in nms.
    m_fNmsThreshold = detParam.nms_param.nms_threshold;
    m_log.CHECK_GE(m_fNmsThreshold, 0, "The nms_threshold must be non negative.");
    m_fEta = detParam.nms_param.eta;
    m_log.CHECK_GT(m_fEta, 0, "The nms_param.eta must be > 0.");
    m_log.CHECK_LE(m_fEta, 1, "The nms_param.eta must be <= 1.");
    m_nTopK = detParam.nms_param.top_k.GetValueOrDefault(-1);

    // Parameters used when saving the detection output.  Saving is requested by providing an
    // output directory, but is turned off again below when the label map or name/size file is missing.
    m_strOutputDir = saveParam.output_directory;
    m_bNeedSave = !string.IsNullOrEmpty(m_strOutputDir);

    // NOTE: the directory is created before the label map and name/size files are validated,
    // so it may exist even when saving ends up disabled.
    if (m_bNeedSave && !Directory.Exists(m_strOutputDir))
    {
        Directory.CreateDirectory(m_strOutputDir);
    }

    m_strOutputNamePrefix = saveParam.output_name_prefix;
    m_outputFormat = saveParam.output_format;

    // Label map: required for saving, used to convert labels to (display) names.
    string strLabelMapFile = saveParam.label_map_file;
    if (string.IsNullOrEmpty(strLabelMapFile))
    {
        m_bNeedSave = false;
    }
    else if (!File.Exists(strLabelMapFile))
    {
        // Ignore saving if there is no label map file.
        m_log.WriteLine("WARNING: Could not find the label_map_file '" + strLabelMapFile + "'!");
        m_bNeedSave = false;
    }
    else
    {
        LabelMap label_map;

        try
        {
            RawProto proto = RawProto.FromFile(strLabelMapFile);
            label_map = LabelMap.FromProto(proto);
        }
        catch (Exception excpt)
        {
            throw new Exception("Failed to read label map file!", excpt);
        }

        try
        {
            m_rgLabelToName = label_map.MapToName(m_log, true, false);
        }
        catch (Exception excpt)
        {
            throw new Exception("Failed to convert the label to name!", excpt);
        }

        try
        {
            m_rgLabelToDisplayName = label_map.MapToName(m_log, true, true);
        }
        catch (Exception excpt)
        {
            throw new Exception("Failed to convert the label to display name!", excpt);
        }
    }

    // Name/size file: one entry per line, either 'name height width' or, with 4 columns,
    // a leading column that is skipped followed by 'name height width'.  Trailing commas
    // on each column are trimmed.
    string strNameSizeFile = saveParam.name_size_file;
    if (string.IsNullOrEmpty(strNameSizeFile))
    {
        m_bNeedSave = false;
    }
    else if (!File.Exists(strNameSizeFile))
    {
        // Ignore saving if there is no name size file.
        m_log.WriteLine("WARNING: Could not find the name_size_file '" + strNameSizeFile + "'!");
        m_bNeedSave = false;
    }
    else
    {
        using (StreamReader sr = new StreamReader(strNameSizeFile))
        {
            string strLine;

            while ((strLine = sr.ReadLine()) != null)
            {
                // Split on runs of spaces/tabs so repeated or trailing whitespace does not
                // create empty columns (which would shift or break the column indexing).
                string[] rgstr = strLine.Split(new char[] { ' ', '\t' }, StringSplitOptions.RemoveEmptyEntries);

                // Skip blank lines.
                if (rgstr.Length == 0)
                {
                    continue;
                }

                if (rgstr.Length != 3 && rgstr.Length != 4)
                {
                    throw new Exception("Invalid name_size_file format, expected 'name' 'height' 'width'");
                }

                int nNameIdx = (rgstr.Length == 4) ? 1 : 0;
                string strName = rgstr[nNameIdx].Trim(',');
                int nHeight = int.Parse(rgstr[nNameIdx + 1].Trim(','), System.Globalization.CultureInfo.InvariantCulture);
                int nWidth = int.Parse(rgstr[nNameIdx + 2].Trim(','), System.Globalization.CultureInfo.InvariantCulture);

                m_rgstrNames.Add(strName);
                m_rgSizes.Add(new SizeF(nWidth, nHeight));
            }
        }

        if (saveParam.num_test_image.HasValue)
        {
            m_nNumTestImage = (int)saveParam.num_test_image.Value;
        }
        else
        {
            m_nNumTestImage = m_rgstrNames.Count;
        }

        m_log.CHECK_LE(m_nNumTestImage, m_rgstrNames.Count, "The number of test images cannot exceed the number of names.");
    }

    // Only keep the resize parameter when it is present and active.
    if (saveParam.resize_param != null && saveParam.resize_param.Active)
    {
        m_resizeParam = saveParam.resize_param;
    }

    m_nNameCount = 0;

    // Visualization settings.
    m_bVisualize = detParam.visualize;
    if (m_bVisualize)
    {
        m_fVisualizeThreshold = detParam.visualize_threshold.GetValueOrDefault(0.6f);
        m_transformer = new DataTransformer<T>(m_cuda, m_log, m_param.transform_param, m_phase, 0, 0, 0);
        m_transformer.InitRand();
        m_strSaveFile = detParam.save_file;
    }

    // Size the working blobs from bottom[0] and bottom[1]; the permute blob is only
    // needed when locations are not shared.
    m_blobBboxPreds.ReshapeLike(colBottom[0]);

    if (!m_bShareLocations)
    {
        m_blobBboxPermute.ReshapeLike(colBottom[0]);
    }

    m_blobConfPermute.ReshapeLike(colBottom[1]);
}