/** @copydoc LayerParameterBase::Copy */
public override void Copy(LayerParameterBase src)
{
    if (src == null)
    {
        throw new ArgumentNullException("src");
    }

    // Validate the type up front so a mismatched source produces a descriptive
    // error instead of a NullReferenceException on the first field access.
    MultiBoxLossParameter p = src as MultiBoxLossParameter;
    if (p == null)
    {
        throw new ArgumentException("The parameter to copy must be a MultiBoxLossParameter, but was '" + src.GetType().FullName + "'.", "src");
    }

    m_locLossType = p.loc_loss_type;
    m_confLossType = p.conf_loss_type;
    m_fLocWeight = p.loc_weight;
    m_nNumClasses = p.num_classes;
    m_bShareLocation = p.share_location;
    m_matchType = p.match_type;
    m_fOverlapThreshold = p.overlap_threshold;
    m_nBackgroundLabelId = p.background_label_id;
    m_bUseDifficultGt = p.use_difficult_gt;
    m_bDoNegMining = p.do_neg_mining;
    m_fNegPosRatio = p.neg_pos_ratio;
    m_fNegOverlap = p.neg_overlap;
    m_codeType = p.code_type;
    m_bEncodeVarianceInTarget = p.encode_variance_in_target;
    m_bMapObjectToAgnostic = p.map_object_to_agnostic;
    m_bIgnoreCrossBoundaryBbox = p.ignore_cross_boundary_bbox;
    m_bBpInside = p.bp_inside;
    m_miningType = p.mining_type;

    // The NMS sub-parameter is copied via Clone() rather than by reference.
    m_nmsParam = p.nms_param.Clone();

    m_nSampleSize = p.sample_size;
    m_bUsePriorForNms = p.use_prior_for_nms;
    m_bUsePriorForMatching = p.use_prior_for_matching;
}
/// <summary>
/// Copy one parameter to another.
/// </summary>
/// <remarks>
/// Scalar settings are assigned directly; the NMS and save-output
/// sub-parameters are copied via Clone() rather than shared by reference.
/// </remarks>
/// <param name="src">Specifies the parameter to copy.</param>
public override void Copy(LayerParameterBase src)
{
    DetectionOutputParameter other = (DetectionOutputParameter)src;

    m_nNumClasses = other.m_nNumClasses;
    m_bShareLocation = other.m_bShareLocation;
    m_nBackgroundLabelId = other.m_nBackgroundLabelId;

    // Sub-parameters: cloned so this instance gets its own copies.
    m_nmsParam = other.m_nmsParam.Clone();
    m_saveOutputParam = other.save_output_param.Clone() as SaveOutputParameter;

    m_codeType = other.m_codeType;
    m_bVarianceEncodedInTarget = other.m_bVarianceEncodedInTarget;
    m_nKeepTopK = other.m_nKeepTopK;
    m_fConfidenceThreshold = other.m_fConfidenceThreshold;

    // Visualization settings.
    m_bVisualize = other.m_bVisualize;
    m_fVisualizeThreshold = other.m_fVisualizeThreshold;
    m_strSaveFile = other.m_strSaveFile;
}
/// <summary>
/// Setup the layer.
/// </summary>
/// <remarks>
/// Reads the detection_output_param settings and validates them:
/// num_classes must be &gt; 0, nms_threshold &gt;= 0, and 0 &lt; eta &lt;= 1.
/// Saving of output is enabled only when an output directory is given AND both
/// the label_map_file and the name_size_file are specified and exist.
/// When visualization is enabled, a DataTransformer is created.
/// Finally, the internal working blobs are reshaped to match the bottom blobs.
/// </remarks>
/// <param name="colBottom">Specifies the collection of bottom (input) Blobs.</param>
/// <param name="colTop">Specifies the collection of top (output) Blobs.</param>
public override void LayerSetUp(BlobCollection<T> colBottom, BlobCollection<T> colTop)
{
    m_log.CHECK_GT(m_param.detection_output_param.num_classes, 0, "There must be at least one class specified.");
    m_nNumClasses = (int)m_param.detection_output_param.num_classes;
    m_bShareLocations = m_param.detection_output_param.share_location;
    m_nNumLocClasses = (m_bShareLocations) ? 1 : m_nNumClasses;
    m_nBackgroundLabelId = m_param.detection_output_param.background_label_id;
    m_codeType = m_param.detection_output_param.code_type;
    m_bVarianceEncodedInTarget = m_param.detection_output_param.variance_encoded_in_target;
    m_nKeepTopK = m_param.detection_output_param.keep_top_k;
    // Falls back to -float.MaxValue when no confidence_threshold is set.
    m_fConfidenceThreshold = m_param.detection_output_param.confidence_threshold.GetValueOrDefault(-float.MaxValue);

    // Parameters used in nms.
    m_fNmsThreshold = m_param.detection_output_param.nms_param.nms_threshold;
    m_log.CHECK_GE(m_fNmsThreshold, 0, "The nms_threshold must be non negative.");
    m_fEta = m_param.detection_output_param.nms_param.eta;
    m_log.CHECK_GT(m_fEta, 0, "The nms_param.eta must be > 0.");
    m_log.CHECK_LE(m_fEta, 1, "The nms_param.eta must be <= 1.");
    m_nTopK = m_param.detection_output_param.nms_param.top_k.GetValueOrDefault(-1);

    // Saving is requested only when an output directory is specified.
    m_strOutputDir = m_param.detection_output_param.save_output_param.output_directory;
    m_bNeedSave = !string.IsNullOrEmpty(m_strOutputDir);

    if (m_bNeedSave && !Directory.Exists(m_strOutputDir))
    {
        Directory.CreateDirectory(m_strOutputDir);
    }

    m_strOutputNamePrefix = m_param.detection_output_param.save_output_param.output_name_prefix;
    m_outputFormat = m_param.detection_output_param.save_output_param.output_format;

    string strLabelMapFile = m_param.detection_output_param.save_output_param.label_map_file;
    if (!string.IsNullOrEmpty(strLabelMapFile))
    {
        if (!File.Exists(strLabelMapFile))
        {
            // Ignore saving if there is no label map file.
            m_log.WriteLine("WARNING: Could not find the label_map_file '" + strLabelMapFile + "'!");
            m_bNeedSave = false;
        }
        else
        {
            loadSaveLabelMap(strLabelMapFile);
        }
    }
    else
    {
        m_bNeedSave = false;
    }

    // Reset previously loaded names/sizes so a repeated LayerSetUp does not
    // accumulate duplicate entries (and thereby inflate m_nNumTestImage).
    m_rgstrNames.Clear();
    m_rgSizes.Clear();

    string strNameSizeFile = m_param.detection_output_param.save_output_param.name_size_file;
    if (!string.IsNullOrEmpty(strNameSizeFile))
    {
        if (!File.Exists(strNameSizeFile))
        {
            // Ignore saving if there is no name size file.
            m_log.WriteLine("WARNING: Could not find the name_size_file '" + strNameSizeFile + "'!");
            m_bNeedSave = false;
        }
        else
        {
            loadSaveNameSizeFile(strNameSizeFile);

            if (m_param.detection_output_param.save_output_param.num_test_image.HasValue)
            {
                m_nNumTestImage = (int)m_param.detection_output_param.save_output_param.num_test_image.Value;
            }
            else
            {
                m_nNumTestImage = m_rgstrNames.Count;
            }

            m_log.CHECK_LE(m_nNumTestImage, m_rgstrNames.Count, "The number of test images cannot exceed the number of names.");
        }
    }
    else
    {
        m_bNeedSave = false;
    }

    if (m_param.detection_output_param.save_output_param.resize_param != null &&
        m_param.detection_output_param.save_output_param.resize_param.Active)
    {
        m_resizeParam = m_param.detection_output_param.save_output_param.resize_param;
    }

    m_nNameCount = 0;

    m_bVisualize = m_param.detection_output_param.visualize;
    if (m_bVisualize)
    {
        m_fVisualizeThreshold = m_param.detection_output_param.visualize_threshold.GetValueOrDefault(0.6f);
        m_transformer = new DataTransformer<T>(m_cuda, m_log, m_param.transform_param, m_phase, 0, 0, 0);
        m_transformer.InitRand();
        m_strSaveFile = m_param.detection_output_param.save_file;
    }

    m_blobBboxPreds.ReshapeLike(colBottom[0]);

    // The permute blob for box predictions is only reshaped when locations are not shared.
    if (!m_bShareLocations)
    {
        m_blobBboxPermute.ReshapeLike(colBottom[0]);
    }

    m_blobConfPermute.ReshapeLike(colBottom[1]);
}

/// <summary>
/// Reads the label map file and fills m_rgLabelToName and m_rgLabelToDisplayName from it.
/// </summary>
/// <param name="strLabelMapFile">Specifies the label map file (the caller has already verified it exists).</param>
private void loadSaveLabelMap(string strLabelMapFile)
{
    LabelMap label_map;

    try
    {
        RawProto proto = RawProto.FromFile(strLabelMapFile);
        label_map = LabelMap.FromProto(proto);
    }
    catch (Exception excpt)
    {
        throw new Exception("Failed to read label map file!", excpt);
    }

    // NOTE(review): the final MapToName flag appears to select display names
    // (true) vs. plain names (false), judging by the target fields -- confirm.
    try
    {
        m_rgLabelToName = label_map.MapToName(m_log, true, false);
    }
    catch (Exception excpt)
    {
        throw new Exception("Failed to convert the label to name!", excpt);
    }

    try
    {
        m_rgLabelToDisplayName = label_map.MapToName(m_log, true, true);
    }
    catch (Exception excpt)
    {
        throw new Exception("Failed to convert the label to display name!", excpt);
    }
}

/// <summary>
/// Parses the name size file, appending one name to m_rgstrNames and one
/// SizeF(width, height) to m_rgSizes per non-blank line.
/// </summary>
/// <remarks>
/// Each line holds either 3 columns (name height width) or 4 columns, in which
/// case the first column is skipped and the name is taken from the second.
/// Columns are separated by runs of spaces/tabs; a trailing ',' on a column is ignored.
/// Blank lines are skipped.
/// </remarks>
/// <param name="strNameSizeFile">Specifies the name size file (the caller has already verified it exists).</param>
private void loadSaveNameSizeFile(string strNameSizeFile)
{
    char[] rgDelimiters = new char[] { ' ', '\t' };

    using (StreamReader sr = new StreamReader(strNameSizeFile))
    {
        string strLine;

        while ((strLine = sr.ReadLine()) != null)
        {
            // Skip blank lines such as a trailing newline at the end of the file.
            if (string.IsNullOrWhiteSpace(strLine))
            {
                continue;
            }

            string[] rgstr = strLine.Split(rgDelimiters, StringSplitOptions.RemoveEmptyEntries);
            if (rgstr.Length != 3 && rgstr.Length != 4)
            {
                throw new Exception("Invalid name_size_file format, expected 'name' 'height' 'width'");
            }

            int nNameIdx = (rgstr.Length == 4) ? 1 : 0;
            string strName = rgstr[nNameIdx].Trim(',');
            int nHeight = int.Parse(rgstr[nNameIdx + 1].Trim(','), System.Globalization.CultureInfo.InvariantCulture);
            int nWidth = int.Parse(rgstr[nNameIdx + 2].Trim(','), System.Globalization.CultureInfo.InvariantCulture);

            m_rgstrNames.Add(strName);
            m_rgSizes.Add(new SizeF(nWidth, nHeight));
        }
    }
}