예제 #1
0
        /** @copydoc LayerParameterBase::Copy */
        public override void Copy(LayerParameterBase src)
        {
            if (src == null)
            {
                throw new ArgumentNullException(nameof(src));
            }

            // Direct cast instead of 'as': passing the wrong parameter type fails here with an
            // InvalidCastException naming the type, rather than a NullReferenceException on
            // the first field read below.
            MultiBoxLossParameter p = (MultiBoxLossParameter)src;

            // Loss configuration.
            m_locLossType              = p.loc_loss_type;
            m_confLossType             = p.conf_loss_type;
            m_fLocWeight               = p.loc_weight;
            m_nNumClasses              = p.num_classes;
            m_bShareLocation           = p.share_location;

            // Prior/ground-truth matching configuration.
            m_matchType                = p.match_type;
            m_fOverlapThreshold        = p.overlap_threshold;
            m_nBackgroundLabelId       = p.background_label_id;
            m_bUseDifficultGt          = p.use_difficult_gt;

            // Negative mining configuration.
            m_bDoNegMining             = p.do_neg_mining;
            m_fNegPosRatio             = p.neg_pos_ratio;
            m_fNegOverlap              = p.neg_overlap;

            // Box encoding configuration.
            m_codeType                 = p.code_type;
            m_bEncodeVarianceInTarget  = p.encode_variance_in_target;
            m_bMapObjectToAgnostic     = p.map_object_to_agnostic;
            m_bIgnoreCrossBoundaryBbox = p.ignore_cross_boundary_bbox;
            m_bBpInside                = p.bp_inside;
            m_miningType               = p.mining_type;

            // Clone the NMS settings rather than sharing the source's instance.
            m_nmsParam                 = p.nms_param.Clone();
            m_nSampleSize              = p.sample_size;
            m_bUsePriorForNms          = p.use_prior_for_nms;
            m_bUsePriorForMatching     = p.use_prior_for_matching;
        }
        /// <summary>
        /// Copy one parameter to another.
        /// </summary>
        /// <param name="src">Specifies the parameter to copy; must be a DetectionOutputParameter.</param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="src"/> is null.</exception>
        /// <exception cref="InvalidCastException">Thrown when <paramref name="src"/> is not a DetectionOutputParameter.</exception>
        public override void Copy(LayerParameterBase src)
        {
            if (src == null)
            {
                throw new ArgumentNullException(nameof(src));
            }

            DetectionOutputParameter p = (DetectionOutputParameter)src;

            m_nNumClasses              = p.m_nNumClasses;
            m_bShareLocation           = p.m_bShareLocation;
            m_nBackgroundLabelId       = p.m_nBackgroundLabelId;

            // Nested parameter objects are cloned so this instance does not share them with src.
            // A direct cast (rather than 'as') makes an unexpected Clone() result type fail loudly
            // instead of silently storing null.
            m_nmsParam                 = p.m_nmsParam.Clone();
            m_saveOutputParam          = (SaveOutputParameter)p.save_output_param.Clone();

            m_codeType                 = p.m_codeType;
            m_bVarianceEncodedInTarget = p.m_bVarianceEncodedInTarget;
            m_nKeepTopK                = p.m_nKeepTopK;
            m_fConfidenceThreshold     = p.m_fConfidenceThreshold;
            m_bVisualize               = p.m_bVisualize;
            m_fVisualizeThreshold      = p.m_fVisualizeThreshold;
            m_strSaveFile              = p.m_strSaveFile;
        }
예제 #3
0
        /// <summary>
        /// Setup the layer.
        /// </summary>
        /// <remarks>
        /// Reads the detection_output_param settings (class count, location sharing, background label,
        /// box coding, keep-top-k, confidence threshold and NMS settings). It then loads the optional
        /// label map and name/size files used when saving results, creates the visualization
        /// transformer when visualize is enabled, and shapes the internal prediction/permute blobs.
        /// Saving is disabled (m_bNeedSave = false) when no output directory, label map file or
        /// name/size file is configured, or when either file cannot be found.
        /// </remarks>
        /// <param name="colBottom">Specifies the collection of bottom (input) Blobs. colBottom[0] shapes the
        /// bbox prediction blobs and colBottom[1] shapes the confidence permute blob.</param>
        /// <param name="colTop">Specifies the collection of top (output) Blobs.</param>
        public override void LayerSetUp(BlobCollection<T> colBottom, BlobCollection<T> colTop)
        {
            DetectionOutputParameter p = m_param.detection_output_param;

            m_log.CHECK_GT(p.num_classes, 0, "There must be at least one class specified.");
            m_nNumClasses = (int)p.num_classes;

            m_bShareLocations          = p.share_location;
            // With shared locations a single set of box predictions is used for every class.
            m_nNumLocClasses           = (m_bShareLocations) ? 1 : m_nNumClasses;
            m_nBackgroundLabelId       = p.background_label_id;
            m_codeType                 = p.code_type;
            m_bVarianceEncodedInTarget = p.variance_encoded_in_target;
            m_nKeepTopK                = p.keep_top_k;
            // An unset confidence threshold defaults to -float.MaxValue.
            m_fConfidenceThreshold     = p.confidence_threshold.GetValueOrDefault(-float.MaxValue);

            // Parameters used in nms; eta must lie in (0, 1].
            m_fNmsThreshold = p.nms_param.nms_threshold;
            m_log.CHECK_GE(m_fNmsThreshold, 0, "The nms_threshold must be non negative.");
            m_fEta = p.nms_param.eta;
            m_log.CHECK_GT(m_fEta, 0, "The nms_param.eta must be > 0.");
            m_log.CHECK_LE(m_fEta, 1, "The nms_param.eta must be <= 1.");

            // An unset top_k defaults to -1.
            m_nTopK = p.nms_param.top_k.GetValueOrDefault(-1);

            SaveOutputParameter save = p.save_output_param;

            // Saving is requested only when an output directory is configured.
            m_strOutputDir = save.output_directory;
            m_bNeedSave    = !string.IsNullOrEmpty(m_strOutputDir);
            if (m_bNeedSave && !Directory.Exists(m_strOutputDir))
            {
                Directory.CreateDirectory(m_strOutputDir);
            }

            m_strOutputNamePrefix = save.output_name_prefix;
            m_outputFormat        = save.output_format;

            if (!string.IsNullOrEmpty(save.label_map_file))
            {
                string strLabelMapFile = save.label_map_file;
                if (!File.Exists(strLabelMapFile))
                {
                    // Ignore saving if there is no label map file.
                    m_log.WriteLine("WARNING: Could not find the label_map_file '" + strLabelMapFile + "'!");
                    m_bNeedSave = false;
                }
                else
                {
                    LabelMap label_map;

                    try
                    {
                        RawProto proto = RawProto.FromFile(strLabelMapFile);
                        label_map = LabelMap.FromProto(proto);
                    }
                    catch (Exception excpt)
                    {
                        throw new Exception("Failed to read label map file!", excpt);
                    }

                    try
                    {
                        m_rgLabelToName = label_map.MapToName(m_log, true, false);
                    }
                    catch (Exception excpt)
                    {
                        throw new Exception("Failed to convert the label to name!", excpt);
                    }

                    try
                    {
                        m_rgLabelToDisplayName = label_map.MapToName(m_log, true, true);
                    }
                    catch (Exception excpt)
                    {
                        throw new Exception("Failed to convert the label to display name!", excpt);
                    }
                }
            }
            else
            {
                m_bNeedSave = false;
            }

            if (!string.IsNullOrEmpty(save.name_size_file))
            {
                string strNameSizeFile = save.name_size_file;
                if (!File.Exists(strNameSizeFile))
                {
                    // Ignore saving if there is no name size file.
                    m_log.WriteLine("WARNING: Could not find the name_size_file '" + strNameSizeFile + "'!");
                    m_bNeedSave = false;
                }
                else
                {
                    // Start from empty lists so a repeated LayerSetUp does not append duplicate entries
                    // (m_nNameCount is likewise reset below).
                    m_rgstrNames.Clear();
                    m_rgSizes.Clear();

                    using (StreamReader sr = new StreamReader(strNameSizeFile))
                    {
                        string strLine;
                        while ((strLine = sr.ReadLine()) != null)
                        {
                            // Skip blank lines, e.g. a trailing empty line at the end of the file.
                            if (string.IsNullOrWhiteSpace(strLine))
                            {
                                continue;
                            }

                            // Each line is 'name height width', optionally preceded by one extra leading
                            // field that is ignored. Fields may be separated by runs of spaces/tabs and may
                            // carry trailing commas.
                            string[] rgstr = strLine.Split(new char[] { ' ', '\t' }, StringSplitOptions.RemoveEmptyEntries);
                            if (rgstr.Length != 3 && rgstr.Length != 4)
                            {
                                throw new Exception("Invalid name_size_file format, expected 'name' 'height' 'width'");
                            }

                            int nNameIdx   = (rgstr.Length == 4) ? 1 : 0;
                            string strName = rgstr[nNameIdx].Trim(',');
                            int nHeight    = int.Parse(rgstr[nNameIdx + 1].Trim(','));
                            int nWidth     = int.Parse(rgstr[nNameIdx + 2].Trim(','));

                            m_rgstrNames.Add(strName);
                            // SizeF takes (width, height), the reverse of the file's column order.
                            m_rgSizes.Add(new SizeF(nWidth, nHeight));
                        }
                    }

                    if (save.num_test_image.HasValue)
                    {
                        m_nNumTestImage = (int)save.num_test_image.Value;
                    }
                    else
                    {
                        m_nNumTestImage = m_rgstrNames.Count;
                    }

                    m_log.CHECK_LE(m_nNumTestImage, m_rgstrNames.Count, "The number of test images cannot exceed the number of names.");
                }
            }
            else
            {
                m_bNeedSave = false;
            }

            if (save.resize_param != null && save.resize_param.Active)
            {
                m_resizeParam = save.resize_param;
            }

            m_nNameCount = 0;

            m_bVisualize = p.visualize;
            if (m_bVisualize)
            {
                // An unset visualize_threshold defaults to 0.6.
                m_fVisualizeThreshold = p.visualize_threshold.GetValueOrDefault(0.6f);
                m_transformer         = new DataTransformer<T>(m_cuda, m_log, m_param.transform_param, m_phase, 0, 0, 0);
                m_transformer.InitRand();
                m_strSaveFile = p.save_file;
            }

            m_blobBboxPreds.ReshapeLike(colBottom[0]);

            // The bbox permute blob is only shaped when locations are not shared across classes.
            if (!m_bShareLocations)
            {
                m_blobBboxPermute.ReshapeLike(colBottom[0]);
            }

            m_blobConfPermute.ReshapeLike(colBottom[1]);
        }