Exemplo n.º 1
0
        /// <summary>
        /// Loads a minimal two-value input into the bottom blob, sizes the state and clip
        /// blobs to go with it, and rebuilds the bottom collection from those three blobs.
        /// </summary>
        /// <param name="p">Attention parameter; its <c>dim</c> becomes the second axis of the state blob.</param>
        /// <returns>The float array whose contents were assigned to the bottom blob.</returns>
        public float[] FillSmall(AttentionParameter p)
        {
            // Layout: timesteps = 2, batch = 1, input = 1 -> shape (2, 1, 1, 1).
            m_blob_bottom.Reshape(2, 1, 1, 1);

            float[] rgInput = convertF(m_blob_bottom.mutable_cpu_data);
            rgInput[0] = 1.11f; // timestep 1, batch 1
            rgInput[1] = 2.11f; // timestep 2, batch 1
            m_blob_bottom.mutable_cpu_data = convert(rgInput);

            // State blob is shaped (1, dim).
            List<int> rgStateShape = new List<int> { 1, (int)p.dim };
            m_blobState.Reshape(rgStateShape);

            // Clip blob keeps only the first two axes of the bottom shape.
            List<int> rgClipShape = Utility.Clone<int>(m_blob_bottom.shape());
            if (rgClipShape.Count > 2)
            {
                rgClipShape.RemoveRange(2, rgClipShape.Count - 2);
            }
            m_blobClip.Reshape(rgClipShape);
            m_blobClip.SetData(1);

            // Bottom inputs, in order: data, state, clip.
            BottomVec.Clear();
            BottomVec.Add(m_blob_bottom);
            BottomVec.Add(m_blobState);
            BottomVec.Add(m_blobClip);

            return rgInput;
        }
Exemplo n.º 2
0
        /// <summary>
        /// Computes the attention context on the CPU: projects the transposed input and the
        /// state, adds the two projections, applies tanh, scores the result, applies softmax over
        /// the time steps, weights the input by the softmax and sums across the time steps.
        /// Also (re)creates the weight arrays it needs through <c>create_weights</c> and stores
        /// the raw scores in <c>m_rgAa</c> and the softmax output in <c>m_rgSoftmax</c>.
        /// </summary>
        /// <param name="p">Attention parameter; its <c>dim</c> is used as the projection width.</param>
        /// <param name="rgData">Input values, transposed here using the dimensions (nT, nB, nI).</param>
        /// <param name="rgState">State values fed to the second projection.</param>
        /// <param name="nT">Number of time steps.</param>
        /// <param name="nB">Second data dimension (presumably the batch size - confirm with callers).</param>
        /// <param name="nI">Third data dimension (presumably the input size per item - confirm with callers).</param>
        /// <param name="nS">Last state dimension; also used as the inner dimension of the scoring product.</param>
        /// <returns>The array produced by summing the softmax-weighted input across the time steps.</returns>
        private float[] calculateAttention(AttentionParameter p, float[] rgData, float[] rgState, int nT, int nB, int nI, int nS)
        {
            int nDim = (int)p.dim;
            int nRows = nT * nB;

            float[] rgInputT = SimpleDatum.Transpose(rgData, nT, nB, nI);
            // NOTE(review): this transposed state is not consumed below (the untransposed rgState
            // is what gets projected); the call is kept so the execution sequence is unchanged.
            float[] rgStateT = SimpleDatum.Transpose(rgState, 1, nB, nS);

            // Inner product of the input data with the Ua weights, plus the Ua bias term.
            create_weights(ref m_rgUa, nDim * nI);
            create_weights(ref m_rgUab, nDim, 0.1f);
            create_weights(ref m_rgBiasMult, nRows, 1.0f);

            float[] rgInputProj = add(
                gemm(false, false, nRows, nDim, nI, 1.0f, rgInputT, m_rgUa, 0.0f),
                gemm(false, false, nRows, nDim, 1, 1.0f, m_rgBiasMult, m_rgUab, 0.0f));

            // Inner product of the state with the Wa weights, plus the Wa bias term.
            create_weights(ref m_rgWa, nDim * nDim);
            create_weights(ref m_rgWab, nDim, 0.1f);
            // The bias multiplier is cleared first because this stage needs it sized to nB, not nT * nB.
            m_rgBiasMult = null;
            create_weights(ref m_rgBiasMult, nB, 1.0f);

            float[] rgStateProj = add(
                gemm(false, false, nB, nDim, nDim, 1.0f, rgState, m_rgWa, 0.0f),
                gemm(false, false, nB, nDim, 1, 1.0f, m_rgBiasMult, m_rgWab, 0.0f));

            // Copy the state projection across all T, add it to the input projection, then apply tanh.
            float[] rgActivated = tanh(add(rgInputProj, expand(rgStateProj, nT)));

            // Inner product of the activations with the V weights gives one score per row.
            create_weights(ref m_rgV, nS);
            m_rgAa = gemm(false, false, nRows, 1, nS, 1.0f, rgActivated, m_rgV, 0.0f);

            // Softmax over the time steps T.
            m_rgSoftmax = softmax(m_rgAa, nB, nT);

            // Weight the input by the softmax vector, then sum across all T.
            float[] rgWeighted = mul(rgInputT, m_rgSoftmax, nB, nT, nI);

            return sum(rgWeighted, nB, nT, nI);
        }
Exemplo n.º 3
0
        /// <summary>
        /// Loads a 3 x 32 ramp of values into the bottom blob, sizes the state and clip blobs
        /// to go with it, and rebuilds the bottom collection from those three blobs.
        /// </summary>
        /// <param name="p">Attention parameter; its <c>dim</c> becomes the second axis of the state blob.</param>
        /// <returns>The float array whose contents were assigned to the bottom blob.</returns>
        public float[] Fill2(AttentionParameter p)
        {
            const int nInputSize = 32;

            // Layout: timesteps = 3, batch = 1, input = 32.
            m_blob_bottom.Reshape(3, 1, nInputSize, 1);

            float[] rgInput = convertF(m_blob_bottom.mutable_cpu_data);
            int nSteps = m_blob_bottom.num;

            // Each value is (timestep number, 1-based) + 0.01 * (input index).
            for (int nStep = 0; nStep < nSteps; nStep++)
            {
                int nOffset = nStep * nInputSize;

                for (int nInput = 0; nInput < nInputSize; nInput++)
                {
                    rgInput[nOffset + nInput] = (nStep + 1) + (nInput * 0.01f);
                }
            }

            m_blob_bottom.mutable_cpu_data = convert(rgInput);

            // State blob is shaped (1, dim).
            List<int> rgStateShape = new List<int> { 1, (int)p.dim };
            m_blobState.Reshape(rgStateShape);

            // Clip blob keeps only the first two axes of the bottom shape.
            List<int> rgClipShape = Utility.Clone<int>(m_blob_bottom.shape());
            if (rgClipShape.Count > 2)
            {
                rgClipShape.RemoveRange(2, rgClipShape.Count - 2);
            }
            m_blobClip.Reshape(rgClipShape);
            m_blobClip.SetData(1);

            // Bottom inputs, in order: data, state, clip.
            BottomVec.Clear();
            BottomVec.Add(m_blob_bottom);
            BottomVec.Add(m_blobState);
            BottomVec.Add(m_blobClip);

            return rgInput;
        }