/// <summary>
        /// Reshape the bottom (input) and top (output) blobs.
        /// </summary>
        /// <remarks>
        /// Two target layouts are accepted:
        /// <list type="bullet">
        /// <item>colBottom[1] has the same element count as colBottom[0] (a full target), or</item>
        /// <item>colBottom[1] holds one item per batch entry (count == colBottom[0].num), i.e. target indexes.</item>
        /// </list>
        /// For the index layout an internal blob named "full_label", shaped like colBottom[0], is created
        /// on first use and reshaped on every call.  Filling it from the indexes is not done here;
        /// presumably that happens elsewhere in the layer (not visible in this method).
        /// </remarks>
        /// <param name="colBottom">Specifies the collection of bottom (input) Blobs.</param>
        /// <param name="colTop">Specifies the collection of top (output) Blobs.</param>
        public override void Reshape(BlobCollection <T> colBottom, BlobCollection <T> colTop)
        {
            base.Reshape(colBottom, colTop);

            Blob <T> blobInput = colBottom[0];
            Blob <T> blobLabel = colBottom[1];

            m_nOuterNum = blobInput.shape(0);   // batch size
            m_nInnerNum = blobInput.count(1);   // elements per batch item: |output| == |target|

            bool bFullTarget = (blobInput.count() == blobLabel.count());

            if (!bFullTarget)
            {
                bool bIndexTarget = (blobLabel.count() == blobInput.num);

                if (!bIndexTarget)
                {
                    m_log.FAIL("SOFTMAX_CROSS_ENTROPY_LOSS layer inputs must have the same count, or the target must have 'num' items of indexes.");
                }

                // Lazily allocate the buffer that will hold the expanded (full) target.
                if (m_blobTarget == null)
                {
                    m_blobTarget = new Blob <T>(m_cuda, m_log);
                    m_blobTarget.Name = "full_label";
                }

                m_blobTarget.ReshapeLike(blobInput);
            }

            m_softmaxLayer.Reshape(m_colSoftmaxBottomVec, m_colSoftmaxTopVec);
            m_blobLoss.ReshapeLike(blobInput);
        }
Пример #2
0
 /// <summary>
 /// Reshape the bottom (input) and top (output) blobs.
 /// </summary>
 /// <remarks>
 /// The prediction blob (colBottom[0]) and the target blob (colBottom[1]) must contain the same
 /// number of elements.  The outer size is the first dimension of the prediction and the inner
 /// size is the element count of the remaining dimensions.
 /// </remarks>
 /// <param name="colBottom">Specifies the collection of bottom (input) Blobs.</param>
 /// <param name="colTop">Specifies the collection of top (output) Blobs.</param>
 public override void Reshape(BlobCollection <T> colBottom, BlobCollection <T> colTop)
 {
     base.Reshape(colBottom, colTop);

     Blob <T> blobPrediction = colBottom[0];

     m_nOuterNum = blobPrediction.shape(0);   // batch size
     m_nInnerNum = blobPrediction.count(1);   // elements per batch item: |output| == |target|

     int nPredictionCount = blobPrediction.count();
     int nTargetCount = colBottom[1].count();
     m_log.CHECK_EQ(nPredictionCount, nTargetCount, "SOFTMAX_CROSS_ENTROPY_LOSS layer inputs must have the same count.");

     m_softmaxLayer.Reshape(m_colSoftmaxBottomVec, m_colSoftmaxTopVec);
     m_blobLoss.ReshapeLike(blobPrediction);
 }
Пример #3
0
        /// <summary>
        /// Reshape the bottom (input) and top (output) blobs.
        /// </summary>
        /// <remarks>
        /// The prediction blob (colBottom[0]) is split around the infogain axis into
        /// outer * num_labels * inner elements, and the label blob (colBottom[1]) must hold
        /// outer * inner values.  The infogain matrix H comes from colBottom[2] when a third
        /// bottom is supplied; otherwise the layer's own m_blobInfoGain is used.  In both cases H
        /// must contain num_labels * num_labels values.
        /// </remarks>
        /// <param name="colBottom">Specifies the collection of bottom (input) Blobs.</param>
        /// <param name="colTop">Specifies the collection of top (output) Blobs.</param>
        public override void Reshape(BlobCollection <T> colBottom, BlobCollection <T> colTop)
        {
            base.Reshape(colBottom, colTop);
            m_softmaxLayer.Reshape(m_colSoftmaxBottomVec, m_colSoftmaxTopVec);

            Blob <T> blobPrediction = colBottom[0];

            // Split the prediction shape around the infogain (label) axis.
            m_nInfogainAxis = blobPrediction.CanonicalAxisIndex(m_param.infogain_loss_param.axis);
            m_nOuterNum = blobPrediction.count(0, m_nInfogainAxis);
            m_nInnerNum = blobPrediction.count(m_nInfogainAxis + 1);

            int nPredictionSites = m_nOuterNum * m_nInnerNum;
            m_log.CHECK_EQ(nPredictionSites, colBottom[1].count(), "Number of labels must match the number of predictions; e.g., if infogain_axis == 1 and predictions shape is (N, C, H, W), label count (number of labels) must be N*H*W, with integer values in {0, 1, ..., C-1}.");
            m_nNumLabels = blobPrediction.shape(m_nInfogainAxis);

            // A third bottom overrides the layer's built-in H matrix.
            bool bExternalH = colBottom.Count >= 3;
            Blob <T> blobH = bExternalH ? colBottom[2] : m_blobInfoGain;

            m_log.CHECK_EQ(blobH.count(), m_nNumLabels * m_nNumLabels, "The infogain count must equal 'num_labels' * 'num_labels'.");
            m_blobSumRowsOfH.Reshape(new List <int> { m_nNumLabels });

            // With exactly two bottoms, H is the layer's own parameter blob, so its row sums are
            // computed here (on every Reshape call).  With a third bottom, H may change between
            // passes; presumably its row sums are computed elsewhere (not visible in this method).
            if (colBottom.Count == 2)
            {
                sum_rows_of_H(blobH);
            }

            // An optional second top exposes the softmax output, shaped like the predictions.
            bool bOutputSoftmax = colTop.Count >= 2;
            if (bOutputSoftmax)
            {
                colTop[1].ReshapeLike(blobPrediction);
            }
        }