Ejemplo n.º 1
0
        /// <summary>
        /// Performs the setup shared by every data layer, then delegates to
        /// DataLayerSetUp so that each concrete layer type can perform its own setup.
        /// </summary>
        /// <remarks>
        /// Only BasePrefetchingDataLayer is permitted to override this method.
        /// </remarks>
        /// <param name="colBottom">Specifies the collection of bottom (input) Blobs.</param>
        /// <param name="colTop">Specifies the collection of top (output) Blobs.</param>
        public override void LayerSetUp(BlobCollection <T> colBottom, BlobCollection <T> colTop)
        {
            // Label output is enabled whenever the top collection holds anything
            // other than exactly one blob.
            bool bSingleTop = (colTop.Count == 1);
            m_bOutputLabels = !bSingleTop;

            // Create the transformer from the layer's transform/phase settings and the
            // image mean, then initialize its random state.
            m_transformer = new DataTransformer<T>(m_log,
                                                   m_param.transform_param,
                                                   m_param.phase,
                                                   m_imgMean);
            m_transformer.InitRand();

            // Subclasses are responsible for sizing the bottom and top blobs.
            DataLayerSetUp(colBottom, colTop);
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Implements common data layer setup functionality, and calls
        /// DataLayerSetUp to do special data layer setup for individual layer types.
        /// </summary>
        /// <remarks>
        /// This method may not be overridden except by BasePrefetchingDataLayer.
        ///
        /// The channel (C), height (H) and width (W) passed to the DataTransformer are
        /// taken from the first available source, in this order: the data source (m_src),
        /// the image mean (m_imgMean), or the layer parameters of a MEMORYDATA, DUMMYDATA
        /// or INPUT layer. When at least one of C, H or W is non-zero, any remaining
        /// zero dimension is promoted to 1.
        /// </remarks>
        /// <param name="colBottom">Specifies the collection of bottom (input) Blobs.</param>
        /// <param name="colTop">Specifies the collection of top (output) Blobs.</param>
        /// <exception cref="Exception">Thrown when C, H and W all resolve to zero.</exception>
        public override void LayerSetUp(BlobCollection <T> colBottom, BlobCollection <T> colTop)
        {
            // Label output is enabled whenever the top collection holds anything
            // other than exactly one blob.
            m_bOutputLabels = (colTop.Count != 1);

            int nC = 0;
            int nH = 0;
            int nW = 0;

            if (m_src != null)
            {
                nC = m_src.ImageChannels;
                nH = m_src.ImageHeight;
                nW = m_src.ImageWidth;
            }
            else if (m_imgMean != null)
            {
                nC = m_imgMean.Channels;
                nH = m_imgMean.Height;
                nW = m_imgMean.Width;
            }
            else if (m_param.type == LayerParameter.LayerType.MEMORYDATA)
            {
                nC = (int)m_param.memory_data_param.channels;
                nH = (int)m_param.memory_data_param.height;
                nW = (int)m_param.memory_data_param.width;
            }
            else if (m_param.type == LayerParameter.LayerType.DUMMYDATA)
            {
                // Use the first entry of each dimension list. Empty lists are skipped
                // so that the zero-size check below reports the problem instead of an
                // index-out-of-range exception.
                if (m_param.dummy_data_param.channels.Count > 0)
                {
                    nC = (int)m_param.dummy_data_param.channels[0];
                }

                if (m_param.dummy_data_param.height.Count > 0)
                {
                    nH = (int)m_param.dummy_data_param.height[0];
                }

                // The width must come from the width list (this previously read height[0]).
                if (m_param.dummy_data_param.width.Count > 0)
                {
                    nW = (int)m_param.dummy_data_param.width[0];
                }
            }
            else if (m_param.type == LayerParameter.LayerType.INPUT)
            {
                if (m_param.input_param.shape.Count > 0)
                {
                    // Dimensions 1, 2 and 3 of the first input shape are read as C, H and W;
                    // dimension 0 is not used here.
                    var rgDim = m_param.input_param.shape[0].dim;

                    if (rgDim.Count > 1)
                    {
                        nC = (int)rgDim[1];
                    }

                    if (rgDim.Count > 2)
                    {
                        nH = (int)rgDim[2];
                    }

                    if (rgDim.Count > 3)
                    {
                        nW = (int)rgDim[3];
                    }
                }
            }

            if (nC == 0 && nH == 0 && nW == 0)
            {
                throw new Exception("The sizing of C, H, W cannot be zero for all three!");
            }

            // Any dimension that was not specified defaults to 1.
            if (nC == 0)
            {
                nC = 1;
            }

            if (nH == 0)
            {
                nH = 1;
            }

            if (nW == 0)
            {
                nW = 1;
            }

            m_transformer = new DataTransformer <T>(m_cuda, m_log, m_param.transform_param, m_param.phase, nC, nH, nW, m_imgMean);
            m_transformer.InitRand();

            // The subclasses should setup the size of bottom and top.
            DataLayerSetUp(colBottom, colTop);
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Setup the layer.
        /// </summary>
        /// <remarks>
        /// Reads the detection_output_param settings: class count, location sharing,
        /// background label, code type, variance encoding, keep_top_k, confidence
        /// threshold and the NMS parameters.
        ///
        /// When save_output_param.output_directory is set, that directory is created.
        /// The label_map_file and name_size_file are loaded when they are specified.
        /// Saving is disabled (m_bNeedSave = false) if either file is unspecified or
        /// cannot be found.
        ///
        /// When visualization is enabled, a DataTransformer is created. Finally,
        /// m_blobBboxPreds, m_blobBboxPermute (only when locations are not shared) and
        /// m_blobConfPermute are reshaped to match the bottom blobs.
        /// </remarks>
        /// <param name="colBottom">Specifies the collection of bottom (input) Blobs.</param>
        /// <param name="colTop">Specifies the collection of top (output) Blobs.</param>
        /// <exception cref="Exception">Thrown when the label map cannot be read or converted,
        /// or when a non-blank line of the name_size_file does not contain 3 or 4 fields.</exception>
        public override void LayerSetUp(BlobCollection <T> colBottom, BlobCollection <T> colTop)
        {
            m_log.CHECK_GT(m_param.detection_output_param.num_classes, 0, "There must be at least one class specified.");
            m_nNumClasses = (int)m_param.detection_output_param.num_classes;

            m_bShareLocations          = m_param.detection_output_param.share_location;
            m_nNumLocClasses           = (m_bShareLocations) ? 1 : m_nNumClasses;
            m_nBackgroundLabelId       = m_param.detection_output_param.background_label_id;
            m_codeType                 = m_param.detection_output_param.code_type;
            m_bVarianceEncodedInTarget = m_param.detection_output_param.variance_encoded_in_target;
            m_nKeepTopK                = m_param.detection_output_param.keep_top_k;
            m_fConfidenceThreshold     = m_param.detection_output_param.confidence_threshold.GetValueOrDefault(-float.MaxValue);

            // Parameters used in nms.
            m_fNmsThreshold = m_param.detection_output_param.nms_param.nms_threshold;
            m_log.CHECK_GE(m_fNmsThreshold, 0, "The nms_threshold must be non negative.");
            m_fEta = m_param.detection_output_param.nms_param.eta;
            m_log.CHECK_GT(m_fEta, 0, "The nms_param.eta must be > 0.");
            // The check enforces eta <= 1; the message previously (wrongly) said "< 0".
            m_log.CHECK_LE(m_fEta, 1, "The nms_param.eta must be <= 1.");

            m_nTopK = m_param.detection_output_param.nms_param.top_k.GetValueOrDefault(-1);

            m_strOutputDir = m_param.detection_output_param.save_output_param.output_directory;
            m_bNeedSave    = !string.IsNullOrEmpty(m_strOutputDir);
            if (m_bNeedSave && !Directory.Exists(m_strOutputDir))
            {
                Directory.CreateDirectory(m_strOutputDir);
            }

            m_strOutputNamePrefix = m_param.detection_output_param.save_output_param.output_name_prefix;
            m_outputFormat        = m_param.detection_output_param.save_output_param.output_format;

            if (!string.IsNullOrEmpty(m_param.detection_output_param.save_output_param.label_map_file))
            {
                string strLabelMapFile = m_param.detection_output_param.save_output_param.label_map_file;
                if (!File.Exists(strLabelMapFile))
                {
                    // Ignore saving if there is no label map file.
                    m_log.WriteLine("WARNING: Could not find the label_map_file '" + strLabelMapFile + "'!");
                    m_bNeedSave = false;
                }
                else
                {
                    LabelMap label_map;

                    try
                    {
                        RawProto proto = RawProto.FromFile(strLabelMapFile);
                        label_map = LabelMap.FromProto(proto);
                    }
                    catch (Exception excpt)
                    {
                        throw new Exception("Failed to read label map file!", excpt);
                    }

                    try
                    {
                        m_rgLabelToName = label_map.MapToName(m_log, true, false);
                    }
                    catch (Exception excpt)
                    {
                        throw new Exception("Failed to convert the label to name!", excpt);
                    }

                    try
                    {
                        m_rgLabelToDisplayName = label_map.MapToName(m_log, true, true);
                    }
                    catch (Exception excpt)
                    {
                        throw new Exception("Failed to convert the label to display name!", excpt);
                    }
                }
            }
            else
            {
                m_bNeedSave = false;
            }

            if (!string.IsNullOrEmpty(m_param.detection_output_param.save_output_param.name_size_file))
            {
                string strNameSizeFile = m_param.detection_output_param.save_output_param.name_size_file;
                if (!File.Exists(strNameSizeFile))
                {
                    // Ignore saving if there is no name size file.
                    m_log.WriteLine("WARNING: Could not find the name_size_file '" + strNameSizeFile + "'!");
                    m_bNeedSave = false;
                }
                else
                {
                    // Start from empty lists so that a repeated setup does not append duplicates.
                    m_rgstrNames.Clear();
                    m_rgSizes.Clear();

                    // Split on runs of spaces/tabs. A plain Split(' ') yields empty tokens for
                    // repeated or trailing whitespace, which shifts the columns.
                    char[] rgSeparators = new char[] { ' ', '\t' };

                    using (StreamReader sr = new StreamReader(strNameSizeFile))
                    {
                        string strLine;
                        while ((strLine = sr.ReadLine()) != null)
                        {
                            // Skip blank lines, e.g. a trailing newline at the end of the file.
                            if (string.IsNullOrWhiteSpace(strLine))
                            {
                                continue;
                            }

                            string[] rgstr = strLine.Split(rgSeparators, StringSplitOptions.RemoveEmptyEntries);
                            if (rgstr.Length != 3 && rgstr.Length != 4)
                            {
                                throw new Exception("Invalid name_size_file format, expected 'name' 'height' 'width'");
                            }

                            // A 4-field line has an extra leading field, which is skipped.
                            int nNameIdx = (rgstr.Length == 4) ? 1 : 0;
                            string strName = rgstr[nNameIdx].Trim(',');
                            int nHeight = int.Parse(rgstr[nNameIdx + 1].Trim(','), System.Globalization.CultureInfo.InvariantCulture);
                            int nWidth  = int.Parse(rgstr[nNameIdx + 2].Trim(','), System.Globalization.CultureInfo.InvariantCulture);

                            m_rgstrNames.Add(strName);
                            m_rgSizes.Add(new SizeF(nWidth, nHeight));
                        }
                    }

                    if (m_param.detection_output_param.save_output_param.num_test_image.HasValue)
                    {
                        m_nNumTestImage = (int)m_param.detection_output_param.save_output_param.num_test_image.Value;
                    }
                    else
                    {
                        m_nNumTestImage = m_rgstrNames.Count;
                    }

                    m_log.CHECK_LE(m_nNumTestImage, m_rgstrNames.Count, "The number of test images cannot exceed the number of names.");
                }
            }
            else
            {
                m_bNeedSave = false;
            }

            if (m_param.detection_output_param.save_output_param.resize_param != null && m_param.detection_output_param.save_output_param.resize_param.Active)
            {
                m_resizeParam = m_param.detection_output_param.save_output_param.resize_param;
            }

            m_nNameCount = 0;

            m_bVisualize = m_param.detection_output_param.visualize;
            if (m_bVisualize)
            {
                m_fVisualizeThreshold = m_param.detection_output_param.visualize_threshold.GetValueOrDefault(0.6f);
                m_transformer         = new DataTransformer <T>(m_cuda, m_log, m_param.transform_param, m_phase, 0, 0, 0);
                m_transformer.InitRand();
                m_strSaveFile = m_param.detection_output_param.save_file;
            }

            m_blobBboxPreds.ReshapeLike(colBottom[0]);

            // The permuted location blob is only reshaped when locations are not shared across classes.
            if (!m_bShareLocations)
            {
                m_blobBboxPermute.ReshapeLike(colBottom[0]);
            }

            m_blobConfPermute.ReshapeLike(colBottom[1]);
        }