Beispiel #1
0
        /// <summary>
        /// Load a batch of data in the background (this is run on an internal thread within the BasePrefetchingDataLayer class).
        /// </summary>
        /// <remarks>
        /// Each of the <c>batch_size</c> items is read from the cursor. Entries for which Skip() returns true are stepped over,
        /// or, when synchronized with a primary dataset, the item is read by the target label that dataset supplied.
        /// The item is then transformed and copied into the batch data. When labels are output, either the datum's
        /// single label or, with <c>load_multiple_labels</c>, the byte-sized multi-label payload stored in the datum's
        /// DataCriteria is copied into the batch labels. The method returns early, without calling SetCPUData on the
        /// batch, when the synchronization wait returns 0.
        /// </remarks>
        /// <param name="batch">Specifies the Batch of data to load.</param>
        protected override void load_batch(Batch <T> batch)
        {
            m_log.CHECK(batch.Data.count() > 0, "There is no space allocated for data!");
            int nBatchSize = (int)m_param.data_param.batch_size;

            T[] rgTopLabel = null;

            if (m_bOutputLabels)
            {
                rgTopLabel = batch.Label.mutable_cpu_data;
            }

            if (m_param.data_param.display_timing)
            {
                m_swTimerBatch.Restart();
                m_dfReadTime  = 0;
                m_dfTransTime = 0;
            }

            Datum      datum;
            int        nDim           = 0;
            List <int> rgLabels       = new List <int>();
            List <int> rgTargetLabels = null;

            // If we are synced with another dataset, wait for it to load the initial data set.
            if (m_param.data_param.synchronize_target)
            {
                int nWait = m_rgBatchLabels.WaitReady;
                if (nWait == 0)
                {
                    return;
                }

                rgTargetLabels = m_rgBatchLabels.Get();
                m_log.CHECK_EQ(nBatchSize, m_rgBatchLabels.Count, "The batch label count (previously loaded by the primary dataset) does not match the batch size '" + m_param.data_param.batch_size.ToString() + "' of this layer!");
            }

            for (int i = 0; i < nBatchSize; i++)
            {
                if (m_param.data_param.display_timing)
                {
                    m_swTimerTransaction.Restart();
                }

                while (Skip())
                {
                    Next();
                }

                // When synchronized, read the item whose label matches the primary dataset's i'th label.
                if (rgTargetLabels == null)
                {
                    datum = m_cursor.GetValue();
                }
                else
                {
                    datum = m_cursor.GetValue(rgTargetLabels[i]);
                }

                if (m_param.data_param.display_timing)
                {
                    m_dfReadTime += m_swTimerTransaction.Elapsed.TotalMilliseconds;
                    m_swTimerTransaction.Restart();
                }

                if (i == 0)
                {
                    // Reshape according to the first datum of each batch
                    // on single input batches allows for inputs of varying dimension.
                    // Use data transformer to infer the expected blob shape for datum.
                    List <int> rgTopShape = m_transformer.InferBlobShape(datum);

                    // Reshape batch according to the batch size.
                    rgTopShape[0] = nBatchSize;
                    batch.Data.Reshape(rgTopShape);

                    // nDim = number of values per batch item (product of all non-batch axes).
                    nDim = 1;
                    for (int k = 1; k < rgTopShape.Count; k++)
                    {
                        nDim *= rgTopShape[k];
                    }

                    // Reuse the host buffer across batches; reallocate only when the required size changes.
                    int nTopLen = nDim * nBatchSize;
                    if (m_rgTopData == null || m_rgTopData.Length != nTopLen)
                    {
                        m_rgTopData = new T[nTopLen];
                    }
                }

                // Apply data transformations (mirror, scaling, crop, etc)
                T[] rgTrans = m_transformer.Transform(datum);
                Array.Copy(rgTrans, 0, m_rgTopData, nDim * i, nDim);

                // Copy label.
                if (m_bOutputLabels)
                {
                    if (m_param.data_param.load_multiple_labels)
                    {
                        if (datum.DataCriteria == null || datum.DataCriteria.Length == 0)
                        {
                            m_log.FAIL("Could not find the multi-label data.  The data source '" + m_param.data_param.source + "' does not appear to have any Image Criteria data.");
                        }

                        // Get the number of items and the item size from the end of the data:
                        // the item count is stored 16 bytes from the end and the item size 12 bytes from the end.
                        int nLen      = BitConverter.ToInt32(datum.DataCriteria, datum.DataCriteria.Length - (sizeof(int) * 4));
                        int nItemSize = BitConverter.ToInt32(datum.DataCriteria, datum.DataCriteria.Length - (sizeof(int) * 3));
                        int nDstIdx   = i * nLen;

                        m_log.CHECK_EQ(nItemSize, 1, "Currently only byte sized labels are supported in multi-label scenarios.");

                        // Array.Copy widens each label byte into the element type T.
                        Array.Copy(datum.DataCriteria, 0, rgTopLabel, nDstIdx, nLen);
                    }
                    else
                    {
                        rgTopLabel[i] = (T)Convert.ChangeType(datum.Label, typeof(T));
                    }
                }

                if (m_param.data_param.display_timing)
                {
                    m_dfTransTime += m_swTimerTransaction.Elapsed.TotalMilliseconds;
                }

                rgLabels.Add(datum.Label);

                Next();
            }

            batch.Data.SetCPUData(m_rgTopData);

            if (m_bOutputLabels)
            {
                batch.Label.SetCPUData(rgTopLabel);
            }

            if (m_param.data_param.display_timing)
            {
                m_swTimerBatch.Stop();
                m_swTimerTransaction.Stop();
                m_log.WriteLine("Prefetch batch: " + m_swTimerBatch.ElapsedMilliseconds.ToString() + " ms.", true);
                m_log.WriteLine("     Read time: " + m_dfReadTime.ToString() + " ms.", true);
                m_log.WriteLine("Transform time: " + m_dfTransTime.ToString() + " ms.", true);
            }

            if (m_param.data_param.synchronize_target)
            {
                if (m_rgBatchLabels != null)
                {
                    m_rgBatchLabels.Done();
                }
            }

            // This method runs on the prefetch thread, so another thread may unsubscribe at any time.
            // Copy the delegate to a local before the null check so the check and the call see the same value.
            var onBatchLoad = OnBatchLoad;
            if (onBatchLoad != null)
            {
                onBatchLoad(this, new LastBatchLoadedArgs(rgLabels));
            }
        }
Beispiel #2
0
        /// <summary>
        /// Load a batch of data in the background (this is run on an internal thread within the BasePrefetchingDataLayer class).
        /// </summary>
        /// <remarks>
        /// When <c>images_per_blob</c> is greater than 1, each batch item stacks the primary datum followed by
        /// <c>images_per_blob - 1</c> additional datums along axis 1 (the channel axis) of the data blob.
        /// The method returns early, without calling SetCPUData on the batch, when the synchronization wait
        /// returns 0 or when the cancel event is signaled.
        /// </remarks>
        /// <param name="batch">Specifies the Batch of data to load.</param>
        protected override void load_batch(Batch <T> batch)
        {
            m_log.CHECK(batch.Data.count() > 0, "There is no space allocated for data!");
            int  nBatchSize        = (int)m_param.data_param.batch_size;
            bool bLoadDataCriteria = false;

            // The DataCriteria (which carries the multi-label payload) is only requested when MULTIPLE labels are output.
            if (m_bOutputLabels && m_param.data_param.label_type == DataParameter.LABEL_TYPE.MULTIPLE)
            {
                bLoadDataCriteria = true;
            }

            if (m_bOutputLabels)
            {
                int nCount = batch.Label.count();
                m_log.CHECK_GT(nCount, 0, "The label count cannot be zero!");

                // Size the reusable label buffer to exactly the label blob's count (matching how m_rgTopData is sized)
                // so that SetCPUData is always handed exactly count() values, even after the label blob shrinks.
                if (m_rgTopLabel == null || m_rgTopLabel.Length != nCount)
                {
                    m_rgTopLabel = new T[nCount];
                }
            }

            if (m_param.data_param.display_timing)
            {
                m_swTimerBatch.Restart();
                m_dfReadTime  = 0;
                m_dfTransTime = 0;
            }

            SimpleDatum datum;
            int         nDim           = 0;
            List <int>  rgTargetLabels = null;

            // This method runs on the prefetch thread, so subscribers may change concurrently.
            // Read the event once and use that snapshot both to decide whether labels are collected
            // and to raise the event, so the two can never disagree and the invocation cannot hit null.
            var         onBatchLoad    = OnBatchLoad;
            List <int>  rgLabels       = null;

            if (onBatchLoad != null)
            {
                rgLabels = new List <int>();
            }

            // If we are synced with another dataset, wait for it to load the initial data set.
            if (m_param.data_param.synchronize_target)
            {
                m_log.CHECK_EQ(m_param.data_param.images_per_blob, 1, "DataLayer synchronize targets are not supported when loading more than 1 image per blob.");

                int nWait = m_rgBatchLabels.WaitReady;
                if (nWait == 0)
                {
                    return;
                }

                rgTargetLabels = m_rgBatchLabels.Get();
                m_log.CHECK_EQ(nBatchSize, m_rgBatchLabels.Count, "The batch label count (previously loaded by the primary dataset) does not match the batch size '" + m_param.data_param.batch_size.ToString() + "' of this layer!");
            }

            for (int i = 0; i < nBatchSize; i++)
            {
                if (m_param.data_param.display_timing)
                {
                    m_swTimerTransaction.Restart();
                }

                while (Skip())
                {
                    if (m_evtCancel.WaitOne(0))
                    {
                        return;
                    }
                    Next();
                }

                if (rgTargetLabels == null)
                {
                    datum = m_cursor.GetValue(null, bLoadDataCriteria);

                    // Load the additional images that are stacked with the primary datum.
                    if (m_param.data_param.images_per_blob > 1)
                    {
                        if (m_rgDatum == null || m_rgDatum.Length != m_param.data_param.images_per_blob - 1)
                        {
                            m_rgDatum = new SimpleDatum[m_param.data_param.images_per_blob - 1];
                        }

                        for (int j = 0; j < m_param.data_param.images_per_blob - 1; j++)
                        {
                            Next();

                            while (Skip())
                            {
                                if (m_evtCancel.WaitOne(0))
                                {
                                    return;
                                }
                                Next();
                            }

                            if (m_param.data_param.balance_matches)
                            {
                                // Alternate between matching and non-matching cycles from one batch item to the next.
                                if (m_bMatchingCycle)
                                {
                                    m_rgDatum[j] = getNextPair(true, datum, bLoadDataCriteria);
                                }
                                else
                                {
                                    if (m_param.data_param.enable_noise_for_nonmatch)
                                    {
                                        m_rgDatum[j] = m_datumNoise;
                                    }
                                    else
                                    {
                                        m_rgDatum[j] = getNextPair(false, datum, bLoadDataCriteria);
                                    }
                                }
                            }
                            else
                            {
                                if (m_param.data_param.enable_noise_for_nonmatch)
                                {
                                    m_rgDatum[j] = m_datumNoise;
                                }
                                else
                                {
                                    m_rgDatum[j] = m_cursor.GetValue(null, bLoadDataCriteria);
                                }
                            }
                        }

                        m_bMatchingCycle = !m_bMatchingCycle;
                    }
                }
                else
                {
                    // Synchronized: read the item whose label matches the primary dataset's i'th label.
                    datum = m_cursor.GetValue(rgTargetLabels[i], bLoadDataCriteria);
                }

                // When debug output is enabled, output information each image loaded.
                if (m_param.data_param.enable_debug_output)
                {
                    saveImageInfo(m_param.data_param.data_debug_param, datum, i, 0);

                    if (m_rgDatum != null)
                    {
                        for (int n = 0; n < m_rgDatum.Length; n++)
                        {
                            saveImageInfo(m_param.data_param.data_debug_param, m_rgDatum[n], i, n + 1);
                        }
                    }
                }

                if (m_param.data_param.display_timing)
                {
                    m_dfReadTime += m_swTimerTransaction.Elapsed.TotalMilliseconds;
                    m_swTimerTransaction.Restart();
                }

                if (i == 0)
                {
                    // Reshape according to the first datum of each batch
                    // on single input batches allows for inputs of varying dimension.
                    // Use data transformer to infer the expected blob shape for datum.
                    m_rgTopShape = m_transformer.InferBlobShape(datum, m_rgTopShape);

                    // Multiply the channels when stacking image pairs: the first image is loaded followed by the others on the channel axis.
                    if (m_rgDatum != null)
                    {
                        m_rgTopShape[1] *= (m_rgDatum.Length + 1);
                    }

                    // Reshape batch according to the batch size.
                    m_rgTopShape[0] = nBatchSize;
                    batch.Data.Reshape(m_rgTopShape);

                    // nDim = number of values per batch item (product of all non-batch axes).
                    nDim = 1;
                    for (int k = 1; k < m_rgTopShape.Length; k++)
                    {
                        nDim *= m_rgTopShape[k];
                    }

                    // Reuse the host buffer across batches; reallocate only when the required size changes.
                    int nTopLen = nDim * nBatchSize;
                    if (m_rgTopData == null || m_rgTopData.Length != nTopLen)
                    {
                        m_rgTopData = new T[nTopLen];
                    }
                }

                // Apply data transformations (mirror, scaling, crop, etc).
                // nDimCount is the number of values contributed by a single image within the batch item.
                int nDimCount = nDim;

                if (m_rgDatum != null)
                {
                    nDimCount /= (m_rgDatum.Length + 1);
                }

                T[] rgTrans = m_transformer.Transform(datum);
                Array.Copy(rgTrans, 0, m_rgTopData, nDim * i, nDimCount);

                // When using load_image_pairs, stack the additional images right after the first.
                if (m_rgDatum != null)
                {
                    for (int j = 0; j < m_rgDatum.Length; j++)
                    {
                        rgTrans = m_transformer.Transform(m_rgDatum[j]);
                        int nOffset = (nDim * i) + (nDimCount * (j + 1));
                        Array.Copy(rgTrans, 0, m_rgTopData, nOffset, nDimCount);
                    }
                }

                // Copy label.
                if (m_bOutputLabels)
                {
                    if (m_param.data_param.label_type == DataParameter.LABEL_TYPE.MULTIPLE)
                    {
                        if (m_param.data_param.images_per_blob > 1)
                        {
                            m_log.FAIL("Loading image pairs (images_per_blob > 1) currently only supports the " + DataParameter.LABEL_TYPE.SINGLE.ToString() + " label type.");
                        }

                        if (m_param.transform_param.label_mapping.Active)
                        {
                            m_log.FAIL("Label mapping is not supported on labels of type 'MULTIPLE'.");
                        }

                        if (datum.DataCriteria == null || datum.DataCriteria.Length == 0)
                        {
                            m_log.FAIL("Could not find the multi-label data.  The data source '" + m_param.data_param.source + "' does not appear to have any Image Criteria data.");
                        }

                        // Get the number of items and the item size from the end of the data:
                        // the item count is stored 16 bytes from the end and the item size 12 bytes from the end.
                        int nLen      = BitConverter.ToInt32(datum.DataCriteria, datum.DataCriteria.Length - (sizeof(int) * 4));
                        int nItemSize = BitConverter.ToInt32(datum.DataCriteria, datum.DataCriteria.Length - (sizeof(int) * 3));
                        int nDstIdx   = i * nLen;

                        m_log.CHECK_EQ(nItemSize, 1, "Currently only byte sized labels are supported in multi-label scenarios.");

                        // Array.Copy widens each label byte into the element type T.
                        Array.Copy(datum.DataCriteria, 0, m_rgTopLabel, nDstIdx, nLen);
                    }
                    else
                    {
                        if (m_rgDatum != null)
                        {
                            if (m_rgDatum.Length == 1)
                            {
                                if (m_param.data_param.output_all_labels)
                                {
                                    // Each item gets images_per_blob labels: the (optionally forced) primary label
                                    // followed by the labels of the paired images.
                                    int nLabelDim = m_param.data_param.images_per_blob;
                                    int nLabel    = datum.Label;

                                    if (m_param.data_param.forced_primary_label >= 0)
                                    {
                                        nLabel = m_param.data_param.forced_primary_label;
                                    }

                                    m_rgTopLabel[i * nLabelDim] = (T)Convert.ChangeType(nLabel, typeof(T));

                                    for (int j = 0; j < m_rgDatum.Length; j++)
                                    {
                                        m_rgTopLabel[i * nLabelDim + 1 + j] = (T)Convert.ChangeType(m_rgDatum[j].Label, typeof(T));
                                    }
                                }
                                else
                                {
                                    // One similarity label per item: 1 when both images share a label, 0 otherwise.
                                    m_rgTopLabel[i] = (datum.Label == m_rgDatum[0].Label) ? m_tOne : m_tZero;
                                }
                            }
                            else
                            {
                                m_log.FAIL("Currently image pairing only supports up to 2 images per blob.");
                            }
                        }
                        else
                        {
                            m_rgTopLabel[i] = (T)Convert.ChangeType(datum.Label, typeof(T));
                        }
                    }
                }

                if (m_param.data_param.display_timing)
                {
                    m_dfTransTime += m_swTimerTransaction.Elapsed.TotalMilliseconds;
                }

                if (rgLabels != null)
                {
                    rgLabels.Add(datum.Label);
                }

                Next();

                if (m_evtCancel.WaitOne(0))
                {
                    return;
                }
            }

            m_nBatchCount++;
            batch.Data.SetCPUData(m_rgTopData);

            if (m_bOutputLabels)
            {
                batch.Label.SetCPUData(m_rgTopLabel);
            }

            if (m_param.data_param.display_timing)
            {
                m_swTimerBatch.Stop();
                m_swTimerTransaction.Stop();
                m_log.WriteLine("Prefetch batch: " + m_swTimerBatch.ElapsedMilliseconds.ToString() + " ms.", true);
                m_log.WriteLine("     Read time: " + m_dfReadTime.ToString() + " ms.", true);
                m_log.WriteLine("Transform time: " + m_dfTransTime.ToString() + " ms.", true);
            }

            if (m_param.data_param.synchronize_target)
            {
                if (m_rgBatchLabels != null)
                {
                    m_rgBatchLabels.Done();
                }
            }

            // Raise the event through the snapshot taken above (rgLabels is non-null whenever onBatchLoad is).
            if (onBatchLoad != null)
            {
                onBatchLoad(this, new LastBatchLoadedArgs(rgLabels));
            }
        }