Example #1
        /// <summary>
        /// Attempts to share a Layer Blob if another parameter Blob with the same name and acceptable size is found.
        /// </summary>
        /// <param name="b">Specifies the Blob to share.</param>
        /// <param name="rgMinShape">Specifies the minimum shape requried to share.</param>
        /// <returns>If the Blob is shared, <i>true</i> is returned, otherwise <i>false</i> is returned.</returns>
        protected bool shareLayerBlob(Blob <T> b, List <int> rgMinShape)
        {
            LayerParameterEx <T> paramEx = m_param as LayerParameterEx <T>;

            if (paramEx == null)
            {
                return(false);
            }

            if (paramEx.SharedLayerBlobs == null)
            {
                return(false);
            }

            return(paramEx.SharedLayerBlobs.Share(b, rgMinShape, false));
        }
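
A minimal usage sketch of shareLayerBlob from a derived layer's setup code, assuming the usual MyCaffe members (m_cuda, m_log, m_param); blobCell and the shape values below are hypothetical and only illustrate the share-or-allocate pattern:

        // Hypothetical working blob the layer would like to share with other layers.
        Blob<T> blobCell = new Blob<T>(m_cuda, m_log);
        blobCell.Name = m_param.name + " cell";

        List<int> rgCellShape = new List<int>() { m_nN, m_nH };

        // Reuse an already-registered blob with the same name (and at least this size)
        // when one exists; otherwise allocate the memory locally.
        if (!shareLayerBlob(blobCell, rgCellShape))
            blobCell.Reshape(rgCellShape);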
Example #2
        /// <summary>
        /// Setup the layer.
        /// </summary>
        /// <param name="colBottom">Specifies the collection of bottom (input) Blobs.</param>
        /// <param name="colTop">Specifies the collection of top (output) Blobs.</param>
        public override void LayerSetUp(BlobCollection <T> colBottom, BlobCollection <T> colTop)
        {
            LSTMAttentionParameter p = m_param.lstm_attention_param;

            if (m_param.lstm_attention_param.enable_attention)
            {
                m_log.CHECK_GE(colBottom.Count, 4, "When using attention, four bottoms are required: x, xClip, encoding, encodingClip.");
                m_log.CHECK_LE(colBottom.Count, 5, "When using attention, no more than five bottoms may be used: x, xClip, encoding, encodingClip, vocabcount (optional).");

                if (colBottom.Count == 5)
                {
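                    // The optional fifth bottom carries the vocabulary count; when the extra
                    // inner-product output is enabled (num_output_ip != 0), its size is taken
                    // from that bottom.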
                    if (p.num_output_ip != 0)
                    {
                        p.num_output_ip = (uint)convertF(colBottom[4].GetData(0));
                    }
                }
            }
            else
            {
                m_log.CHECK_GE(colBottom.Count, 1, "When not using attention, at least one bottom is required: x.");
                m_log.CHECK_LE(colBottom.Count, 2, "When not using attention, no more than two bottoms may be used: x, clip.");
            }

            m_dfClippingThreshold = p.clipping_threshold;
            m_nN = colBottom[0].channels;
            m_nH = (int)p.num_output;      // number of hidden units.
            m_nI = colBottom[0].count(2);  // input dimension.

            // Check if we need to set up the weights.
            if (m_colBlobs.Count > 0)
            {
                m_log.WriteLine("Skipping parameter initialization.");
            }
            else
            {
                m_colBlobs = new BlobCollection <T>();

                Filler <T> weight_filler = Filler <T> .Create(m_cuda, m_log, p.weight_filler);

                Filler <T> bias_filler = Filler <T> .Create(m_cuda, m_log, p.bias_filler);

                // input-to-hidden weights
                // Initialize the weight.
                List <int> rgShape1 = new List <int>()
                {
                    4 * m_nH, m_nI
                };
                Blob <T> blobWeights_I_H = new Blob <T>(m_cuda, m_log);
                blobWeights_I_H.Name = m_param.name + " weights I to H";
                blobWeights_I_H.type = BLOB_TYPE.WEIGHT;

                if (!shareParameter(blobWeights_I_H, rgShape1))
                {
                    blobWeights_I_H.Reshape(rgShape1);
                    weight_filler.Fill(blobWeights_I_H);
                }
                m_nWeightItoHidx = m_colBlobs.Count;
                m_colBlobs.Add(blobWeights_I_H);

                // hidden-to-hidden weights
                // Initialize the weight.
                List <int> rgShape2 = new List <int>()
                {
                    4 * m_nH, m_nH
                };
                Blob <T> blobWeights_H_H = new Blob <T>(m_cuda, m_log);
                blobWeights_H_H.Name = m_param.name + " weights H to H";
                blobWeights_H_H.type = BLOB_TYPE.WEIGHT;

                if (!shareParameter(blobWeights_H_H, rgShape2))
                {
                    blobWeights_H_H.Reshape(rgShape2);
                    weight_filler.Fill(blobWeights_H_H);
                }
                m_nWeightHtoHidx = m_colBlobs.Count;
                m_colBlobs.Add(blobWeights_H_H);

                // If necessary, initialize and fill the bias term.
                List <int> rgShape3 = new List <int>()
                {
                    4 * m_nH
                };
                Blob <T> blobBias = new Blob <T>(m_cuda, m_log);
                blobBias.Name = m_param.name + " bias weights";
                blobBias.type = BLOB_TYPE.WEIGHT;

                if (!shareParameter(blobBias, rgShape3))
                {
                    blobBias.Reshape(rgShape3);
                    bias_filler.Fill(blobBias);
                }
                m_nWeightBiasidx = m_colBlobs.Count;
                m_colBlobs.Add(blobBias);

                // Initialize the bias for the forget gate to 5.0 as described in the
                // Clockwork RNN paper:
                // [1] Koutnik, J., Greff, K., Gomez, F., Schmidhuber, J., 'A Clockwork RNN', 2014
                if (p.enable_clockwork_forgetgate_bias)
                {
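                    // The bias is laid out in four gate-sized blocks of m_nH values; the second
                    // block (indices [m_nH, 2*m_nH)) is assumed to hold the forget-gate bias.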
                    double[] rgBias = convertD(blobBias.mutable_cpu_data);

                    for (int i = m_nH; i < 2 * m_nH; i++)
                    {
                        rgBias[i] = 5.0;
                    }

                    blobBias.mutable_cpu_data = convert(rgBias);
                }

                if (m_param.lstm_attention_param.num_output_ip > 0)
                {
                    Blob <T> blobWeightWhd = new Blob <T>(m_cuda, m_log);
                    blobWeightWhd.Name = m_param.name + " weights Whd";
                    blobWeightWhd.type = BLOB_TYPE.WEIGHT;

                    List <int> rgShapeWhd = new List <int>()
                    {
                        m_nH, (int)m_param.lstm_attention_param.num_output_ip
                    };
                    if (!shareParameter(blobWeightWhd, rgShapeWhd))
                    {
                        blobWeightWhd.Reshape(rgShapeWhd);
                        weight_filler.Fill(blobWeightWhd);
                    }
                    m_nWeightWhdidx = m_colBlobs.Count;
                    m_colBlobs.Add(blobWeightWhd);

                    Blob <T> blobWeightWhdb = new Blob <T>(m_cuda, m_log);
                    blobWeightWhdb.Name = m_param.name + " weights Whdb";
                    blobWeightWhdb.type = BLOB_TYPE.WEIGHT;

                    List <int> rgShapeWhdb = new List <int>()
                    {
                        1, (int)m_param.lstm_attention_param.num_output_ip
                    };
                    if (!shareParameter(blobWeightWhdb, rgShapeWhdb))
                    {
                        blobWeightWhdb.Reshape(rgShapeWhdb);
                        bias_filler.Fill(blobWeightWhdb);
                    }
                    m_nWeightWhdbidx = m_colBlobs.Count;
                    m_colBlobs.Add(blobWeightWhdb);
                }

                if (m_param.lstm_attention_param.enable_attention)
                {
                    // context-to-hidden weights
                    // Initialize the weight.
                    Blob <T> blobWeights_C_H = new Blob <T>(m_cuda, m_log);
                    blobWeights_C_H.Name = m_param.name + " weights C to H";
                    blobWeights_C_H.type = BLOB_TYPE.WEIGHT;

                    if (!shareParameter(blobWeights_C_H, rgShape1))
                    {
                        blobWeights_C_H.Reshape(rgShape1); // same shape as I to H
                        weight_filler.Fill(blobWeights_C_H);
                    }
                    m_nWeightCtoHidx = m_colBlobs.Count;
                    m_colBlobs.Add(blobWeights_C_H);
                }
            }

            m_rgbParamPropagateDown = new DictionaryMap <bool>(m_colBlobs.Count, true);

            List <int> rgCellShape = new List <int>()
            {
                m_nN, m_nH
            };

            m_blob_C_0.Reshape(rgCellShape);
            m_blob_H_0.Reshape(rgCellShape);
            m_blob_C_T.Reshape(rgCellShape);
            m_blob_H_T.Reshape(rgCellShape);
            m_blob_H_to_H.Reshape(rgCellShape);

            List <int> rgGateShape = new List <int>()
            {
                m_nN, 4, m_nH
            };

            m_blob_H_to_Gate.Reshape(rgGateShape);

            // Attention settings
            if (m_param.lstm_attention_param.enable_attention)
            {
                m_blob_C_to_Gate      = new Blob <T>(m_cuda, m_log, false);
                m_blob_C_to_Gate.Name = m_param.name + " c_to_gate";
                m_blob_C_to_Gate.Reshape(rgGateShape);

                m_blobContext      = new Blob <T>(m_cuda, m_log);
                m_blobContext.Name = "context_out";

                m_blobContextFull      = new Blob <T>(m_cuda, m_log);
                m_blobContextFull.Name = "context_full";

                m_blobPrevCt      = new Blob <T>(m_cuda, m_log);
                m_blobPrevCt.Name = "prev_ct";

                LayerParameter attentionParam = new LayerParameter(LayerParameter.LayerType.ATTENTION);
                attentionParam.attention_param.axis          = 2;
                attentionParam.attention_param.dim           = m_param.lstm_attention_param.num_output;
                attentionParam.attention_param.weight_filler = m_param.lstm_attention_param.weight_filler;
                attentionParam.attention_param.bias_filler   = m_param.lstm_attention_param.bias_filler;

                if (m_param is LayerParameterEx <T> )
                {
                    LayerParameterEx <T> pEx = m_param as LayerParameterEx <T>;
                    attentionParam = new LayerParameterEx <T>(attentionParam, pEx.SharedBlobs, pEx.SharedLayerBlobs, pEx.SharedLayer);
                }

                m_attention = new AttentionLayer <T>(m_cuda, m_log, attentionParam);

                Blob <T> blobEncoding     = colBottom[2];
                Blob <T> blobEncodingClip = colBottom[3];
                addInternal(new List <Blob <T> >()
                {
                    blobEncoding, m_blob_C_T, blobEncodingClip
                }, m_blobContext);
                m_attention.Setup(m_colInternalBottom, m_colInternalTop);

                foreach (Blob <T> b in m_attention.blobs)
                {
                    m_colBlobs.Add(b);
                }
            }
        }
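
For orientation, a sketch of the shapes configured by LayerSetUp above, using illustrative values (N is bottom[0].channels, H is num_output, I is bottom[0].count(2)); this is a summary only, not part of the layer:

        // Sketch: parameter and working-blob shapes created by LayerSetUp, with
        // illustrative values for N (batch), H (num_output) and I (input dimension).
        int nN = 32, nH = 256, nI = 128;

        List<int> rgWeightItoH = new List<int>() { 4 * nH, nI };   // input-to-hidden weights
        List<int> rgWeightHtoH = new List<int>() { 4 * nH, nH };   // hidden-to-hidden weights
        List<int> rgBiasShape  = new List<int>() { 4 * nH };       // one bias value per gate unit
        List<int> rgCellShape  = new List<int>() { nN, nH };       // c_0, h_0, c_T, h_T, h_to_h
        List<int> rgGateShape  = new List<int>() { nN, 4, nH };    // pre-activation gate buffer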
        /// <summary>
        /// The AttentionLayer constructor.
        /// </summary>
        /// <param name="cuda">Specifies the CudaDnn connection to Cuda.</param>
        /// <param name="log">Specifies the Log for output.</param>
        /// <param name="p">provides LayerParameter inner_product_param, with options:
        /// </param>
        public AttentionLayer(CudaDnn <T> cuda, Log log, LayerParameter p)
            : base(cuda, log, p)
        {
            m_type = LayerParameter.LayerType.ATTENTION;

            List <int> rgDimClip = new List <int>()
            {
                1, 0
            };
            LayerParameter transposeClipparam = new LayerParameter(LayerParameter.LayerType.TRANSPOSE);

            transposeClipparam.transpose_param.dim = new List <int>(rgDimClip);

            m_transposeClip = new TransposeLayer <T>(cuda, log, transposeClipparam);

            LayerParameter ipUaParam = new LayerParameter(LayerParameter.LayerType.INNERPRODUCT);

            ipUaParam.name = "ipUa";
            ipUaParam.inner_product_param.axis          = 2;
            ipUaParam.inner_product_param.num_output    = m_param.attention_param.dim;
            ipUaParam.inner_product_param.weight_filler = m_param.attention_param.weight_filler;
            ipUaParam.inner_product_param.bias_filler   = m_param.attention_param.bias_filler;

            if (m_param is LayerParameterEx <T> )
            {
                LayerParameterEx <T> pEx = m_param as LayerParameterEx <T>;
                ipUaParam = new LayerParameterEx <T>(ipUaParam, pEx.SharedBlobs, pEx.SharedLayerBlobs, pEx.SharedLayer);
            }

            m_ipUa = new InnerProductLayer <T>(cuda, log, ipUaParam);

            LayerParameter ipWaParam = new LayerParameter(LayerParameter.LayerType.INNERPRODUCT);

            ipWaParam.name = "ipWa";
            ipWaParam.inner_product_param.axis          = 1;
            ipWaParam.inner_product_param.num_output    = m_param.attention_param.dim;
            ipWaParam.inner_product_param.weight_filler = m_param.attention_param.weight_filler;
            ipWaParam.inner_product_param.bias_filler   = m_param.attention_param.bias_filler;

            if (m_param is LayerParameterEx <T> )
            {
                LayerParameterEx <T> pEx = m_param as LayerParameterEx <T>;
                ipWaParam = new LayerParameterEx <T>(ipWaParam, pEx.SharedBlobs, pEx.SharedLayerBlobs, pEx.SharedLayer);
            }

            m_ipWa = new InnerProductLayer <T>(cuda, log, ipWaParam);

            LayerParameter addParam = new LayerParameter(LayerParameter.LayerType.ELTWISE);

            addParam.name = "add";
            addParam.eltwise_param.operation = EltwiseParameter.EltwiseOp.SUM;

            m_add1 = new EltwiseLayer <T>(cuda, log, addParam);

            LayerParameter tanhParam = new LayerParameter(LayerParameter.LayerType.TANH);

            tanhParam.name = "tanh";
            tanhParam.tanh_param.engine = EngineParameter.Engine.CUDNN;

            m_tanh = new TanhLayer <T>(cuda, log, tanhParam);

            LayerParameter ipVParam = new LayerParameter(LayerParameter.LayerType.INNERPRODUCT);

            ipVParam.name = "ipV";
            ipVParam.inner_product_param.axis          = 2;
            ipVParam.inner_product_param.num_output    = 1;
            ipVParam.inner_product_param.bias_term     = false;
            ipVParam.inner_product_param.weight_filler = m_param.attention_param.weight_filler;

            if (m_param is LayerParameterEx <T> )
            {
                LayerParameterEx <T> pEx = m_param as LayerParameterEx <T>;
                ipVParam = new LayerParameterEx <T>(ipVParam, pEx.SharedBlobs, pEx.SharedLayerBlobs, pEx.SharedLayer);
            }

            m_ipV = new InnerProductLayer <T>(cuda, log, ipVParam);

            m_blobX      = new Blob <T>(cuda, log);
            m_blobX.Name = "x";

            m_blobClip      = new Blob <T>(cuda, log);
            m_blobClip.Name = "clip";

            m_blobX1      = new Blob <T>(cuda, log);
            m_blobX1.Name = "x1";

            m_blobState      = new Blob <T>(cuda, log);
            m_blobState.Name = "state";

            m_blobUh      = new Blob <T>(cuda, log);
            m_blobUh.Name = "Uh";

            m_blobWc      = new Blob <T>(cuda, log);
            m_blobWc.Name = "Wc";

            m_blobFullWc      = new Blob <T>(cuda, log);
            m_blobFullWc.Name = "Full Wc";

            m_blobAddOutput      = new Blob <T>(cuda, log);
            m_blobAddOutput.Name = "addOut";

            m_blobGG      = new Blob <T>(cuda, log);
            m_blobGG.Name = "gg";

            m_blobAA      = new Blob <T>(cuda, log);
            m_blobAA.Name = "aa";

            m_blobScale      = new Blob <T>(cuda, log, false);
            m_blobScale.Name = "scale";

            m_blobSoftmax      = new Blob <T>(cuda, log);
            m_blobSoftmax.Name = "softmax";

            m_blobFocusedInput      = new Blob <T>(cuda, log);
            m_blobFocusedInput.Name = "softmax_full";

            m_blobContext      = new Blob <T>(cuda, log);
            m_blobContext.Name = "context";

            m_blobWork      = new Blob <T>(cuda, log);
            m_blobWork.Name = "work";
        }
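
The inner-product, eltwise-sum, tanh and scoring layers created above follow the usual additive (Bahdanau-style) attention pattern. Below is a minimal sketch of that computation on plain arrays, independent of the MyCaffe types and CUDA kernels; all names are illustrative, and only System.Math is used:

        // Minimal sketch (not the MyCaffe implementation) of additive attention:
        //   score[t] = v . tanh(Ua*enc[t] + Wa*state),  a = softmax(score),
        //   context  = sum_t a[t] * enc[t]
        // 'uaEnc' and 'waState' stand for the ipUa and ipWa projections; 'v' for the
        // ipV scoring vector.
        static double[] AdditiveAttention(double[][] enc, double[][] uaEnc, double[] waState, double[] v)
        {
            int nT = enc.Length;
            double[] rgScore = new double[nT];
            double dfMax = double.MinValue;

            for (int t = 0; t < nT; t++)
            {
                double dfSum = 0;
                for (int j = 0; j < v.Length; j++)
                    dfSum += v[j] * Math.Tanh(uaEnc[t][j] + waState[j]);
                rgScore[t] = dfSum;
                dfMax = Math.Max(dfMax, dfSum);
            }

            // Softmax over the time axis (shifted by the max for numerical stability).
            double dfZ = 0;
            for (int t = 0; t < nT; t++)
            {
                rgScore[t] = Math.Exp(rgScore[t] - dfMax);
                dfZ += rgScore[t];
            }

            // Context vector = attention-weighted sum of the encoder outputs.
            double[] rgContext = new double[enc[0].Length];
            for (int t = 0; t < nT; t++)
            {
                for (int i = 0; i < rgContext.Length; i++)
                    rgContext[i] += (rgScore[t] / dfZ) * enc[t][i];
            }

            return rgContext;
        }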
        /// <summary>
        /// Setup the ImageDataLayer by starting up the pre-fetching.
        /// </summary>
        /// <param name="colBottom">Not used.</param>
        /// <param name="colTop">Specifies the collection of top (output) Blobs.</param>
        protected override void DataLayerSetUp(BlobCollection <T> colBottom, BlobCollection <T> colTop)
        {
            int    nBatchSize    = (int)m_param.data_param.batch_size;
            int    nNewHeight    = (int)m_param.image_data_param.new_height;
            int    nNewWidth     = (int)m_param.image_data_param.new_width;
            bool   bIsColor      = m_param.image_data_param.is_color;
            string strRootFolder = getRootFolder();

            m_log.CHECK((nNewHeight == 0 && nNewWidth == 0) || (nNewHeight > 0 && nNewWidth > 0), "Current implementation requires new_height and new_width to be set at the same time.");

            // Read the file with filenames and labels.
            loadFileList();

            // Randomly shuffle the images.
            if (m_param.image_data_param.shuffle)
            {
                shuffleImages();
            }
            else if (m_param.image_data_param.rand_skip == 0)
            {
                LayerParameterEx <T> layer_param = m_param as LayerParameterEx <T>;
                if (layer_param != null && layer_param.solver_rank > 0)
                {
                    m_log.WriteLine("WARNING: Shuffling or skipping recommended for multi-GPU.");
                }
            }

            m_log.WriteLine("A total of " + m_rgLines.Count.ToString("N0") + " images.");

            m_nLinesId = 0;
            // Check if we would need to randomly skip a few data points.
            if (m_param.image_data_param.rand_skip > 0)
            {
                int nSkip = m_random.Next((int)m_param.image_data_param.rand_skip);
                m_log.WriteLine("Skipping first " + nSkip.ToString() + " data points.");
                m_log.CHECK_GT(m_rgLines.Count, nSkip, "Not enough data points to skip.");
                m_nLinesId = nSkip;
            }

            // Read an image and use it to initialize the top blob.
            Datum datum = loadImage(strRootFolder, m_rgLines[m_nLinesId], bIsColor, nNewHeight, nNewWidth);
            // Use the data_transformer to infer the expected blob shape from the image.
            List <int> rgTopShape = m_transformer.InferBlobShape(datum);

            // Reshape colTop[0] and prefetch data according to the batch size.
            rgTopShape[0] = nBatchSize;
            colTop[0].Reshape(rgTopShape);

            for (int i = 0; i < m_rgPrefetch.Length; i++)
            {
                m_rgPrefetch[i].Data.Reshape(rgTopShape);
            }

            m_log.WriteLine("output data size: " + colTop[0].ToSizeString());

            // label.
            List <int> rgLabelShape = new List <int>()
            {
                nBatchSize
            };

            colTop[1].Reshape(rgLabelShape);

            for (int i = 0; i < m_rgPrefetch.Length; i++)
            {
                m_rgPrefetch[i].Label.Reshape(rgLabelShape);
            }
        }
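
A rough sketch of the top shapes this setup produces (illustrative values only): colTop[0] takes the transformer-inferred image shape with its first dimension replaced by the batch size, and colTop[1] holds one label per image:

        // Sketch only: expected top shapes after DataLayerSetUp.
        int nBatchSize = 16;

        // Shape inferred from the first image, e.g. {1, channels, height, width}.
        List<int> rgTopShape = new List<int>() { 1, 3, 227, 227 };
        rgTopShape[0] = nBatchSize;                                 // colTop[0]: {16, 3, 227, 227}

        List<int> rgLabelShape = new List<int>() { nBatchSize };    // colTop[1]: {16}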