/// <summary>
        /// Multi-head self-attention block with a residual connection.
        /// <paramref name="inputQ"/> is layer-normalized first. A single fused affine layer then projects it to Q, K and V.
        /// Scaled dot-product attention runs across <c>m_multiHeadNum</c> heads of width <c>m_d</c>. The concatenated heads
        /// go through the output affine layer and dropout, and the un-normalized <paramref name="inputQ"/> is added back.
        /// </summary>
        /// <param name="inputQ">Input tensor; its row count is divided by <paramref name="batchSize"/> to obtain the sequence length.</param>
        /// <param name="keyMask">Optional mask added (in place) to the attention logits before softmax; the logits are viewed as
        /// [batchSize, heads, seqLen, seqLen] at that point. Ignored when null.</param>
        /// <param name="batchSize">Batch size of input data set</param>
        /// <param name="graph">The instance of computing graph. Both returned tensors are created on this graph, not on the
        /// temporary sub-graph used for intermediates.</param>
        /// <param name="outputAttenWeights">When true, also return the attention probabilities averaged over all heads.</param>
        /// <returns>
        /// A tuple: (attention output plus residual <paramref name="inputQ"/>,
        /// head-averaged attention probabilities viewed as [batchSize * seqLen, seqLen], or null when
        /// <paramref name="outputAttenWeights"/> is false).
        /// </returns>
        public (IWeightTensor, IWeightTensor) Perform(IWeightTensor inputQ, IWeightTensor keyMask, int batchSize, IComputeGraph graph, bool outputAttenWeights = false)
        {
            // Intermediates are built on a sub-graph that is disposed when this method returns.
            using IComputeGraph subGraph = graph.CreateSubGraph($"{m_name}_MultiHeadAttention");

            int seqLen = inputQ.Rows / batchSize;
            var headCount = m_multiHeadNum;
            var headDim = m_d;
            var batchHeads = batchSize * headCount;

            IWeightTensor normedInput = layerNormQ.Norm(inputQ, subGraph);

            // One affine produces Q, K and V side by side; view as [batch, seq, 3, heads, headDim] and slice axis 2.
            var fusedProjection = subGraph.Affine(normedInput, QKV, QKVb);
            var qkv = subGraph.View(fusedProjection, dims: new long[] { batchSize, seqLen, 3, headCount, headDim });
            var queries = subGraph.Select(qkv, 2, 0);
            var keys = subGraph.Select(qkv, 2, 1);
            var values = subGraph.Select(qkv, 2, 2);

            // Bring the head axis in front of the sequence axis, then fold batch and heads into one batch axis.
            var queriesByHead = subGraph.Transpose(queries, 1, 2);
            IWeightTensor qHeads = subGraph.View(subGraph.AsContiguous(queriesByHead), dims: new long[] { batchHeads, seqLen, headDim });

            // Keys additionally swap their last two axes so that Q x K^T is a plain batched matmul.
            var keysByHead = subGraph.Transpose(keys, 1, 2);
            var keysTransposed = subGraph.Transpose(keysByHead, 2, 3);
            IWeightTensor kHeadsT = subGraph.View(subGraph.AsContiguous(keysTransposed), dims: new long[] { batchHeads, headDim, seqLen });

            var valuesByHead = subGraph.Transpose(values, 1, 2);
            IWeightTensor vHeads = subGraph.View(subGraph.AsContiguous(valuesByHead), dims: new long[] { batchHeads, seqLen, headDim });

            // Logits scaled by 1/sqrt(headDim).
            float invSqrtDim = 1.0f / (float)Math.Sqrt(headDim);
            var scores = subGraph.MulBatch(qHeads, kHeadsT, invSqrtDim);
            scores = subGraph.View(scores, dims: new long[] { batchSize, headCount, seqLen, seqLen });

            if (keyMask == null)
            {
                // No masking requested.
            }
            else
            {
                scores = subGraph.Add(scores, keyMask, inPlace: true);
            }

            var probs = subGraph.Softmax(scores, inPlace: true);

            // Built on the parent graph so the tensor outlives the disposed sub-graph.
            IWeightTensor avgAttnWeights = null;
            if (outputAttenWeights)
            {
                IWeightTensor summedOverHeads = graph.Sum(probs, 1);
                IWeightTensor averaged = graph.Div(summedOverHeads, (float)headCount);
                avgAttnWeights = graph.View(averaged, new long[] { batchSize * seqLen, seqLen });
            }

            probs = subGraph.View(probs, dims: new long[] { batchHeads, seqLen, seqLen });

            // Weighted sum of values, then undo the head folding: [batch, heads, seq, d] -> [batch * seq, heads * d].
            var context = subGraph.MulBatch(probs, vHeads);
            IWeightTensor contextByHead = subGraph.View(context, dims: new long[] { batchSize, headCount, seqLen, headDim });
            var contextBySeq = subGraph.Transpose(contextByHead, 1, 2);
            IWeightTensor mergedHeads = subGraph.View(subGraph.AsContiguous(contextBySeq), dims: new long[] { batchSize * seqLen, headCount * headDim });

            var outputProjection = subGraph.Affine(mergedHeads, W0, b0);
            IWeightTensor attnOut = subGraph.Dropout(outputProjection, batchSize, m_dropoutRatio, inPlace: true);

            // Residual connection with the un-normalized input, created on the parent graph.
            IWeightTensor withResidual = graph.Add(attnOut, inputQ, inPlace: true);

            return (withResidual, avgAttnWeights);
        }
Esempio n. 2
0
        /// <summary>
        /// Runs the forward pass for one batch on a single device and, when training, seeds the output gradients.
        /// Source token groups 0 and 1 of the batch are each encoded. The two encodings are then handled in one of two ways.
        /// In "Continuous" mode (<c>SimilarityType</c>) they are compared by cosine similarity.
        /// Otherwise they are multiplied element-wise and classified through the feed-forward layer and a softmax over the
        /// classification vocabulary.
        /// </summary>
        /// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
        /// <param name="sntPairBatch">The input batch. Source token groups 0 and 1 are the two sides being compared. The first token of
        /// each target sentence in group 0 is the golden score (continuous mode) or class label (classification mode).</param>
        /// <param name="deviceIdIdx">The index of current device</param>
        /// <param name="isTraining">True to compute the cost and write output gradients; false to emit predictions.</param>
        /// <param name="decodingOptions">Not used by this implementation.</param>
        /// <returns>
        /// A list holding a single <c>NetworkResult</c>. When training, its <c>Cost</c> is the total cost divided by the batch size.
        /// Otherwise, <c>Output[0][k]</c> contains one string (the similarity score or the predicted label) for sentence <c>k</c>.
        /// </returns>
        /// <remarks>
        /// Golden scores are parsed, and predicted scores formatted, with the invariant culture.
        /// This keeps '.' as the decimal separator regardless of the machine's locale.
        /// </remarks>
        public override List <NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions)
        {
            int batchSize = sntPairBatch.BatchSize;

            float cost = 0.0f;
            var nrs = new List<NetworkResult>();
            var nr = new NetworkResult
            {
                Output = new List<List<List<string>>>()
            };

            (IEncoder encoder, IWeightTensor srcEmbedding, IFeedForwardLayer encoderFFLayer, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

            IWeightTensor encOutput1;
            IWeightTensor encOutput2;

            if (!isTraining && (m_options.ProcessorType == ProcessorTypeEnums.CPU))
            {
                // The encoder-output cache is consulted only for CPU inference; a hit skips re-encoding that group.
                string cacheKey1 = GenerateCacheKey(sntPairBatch.GetSrcTokens(0));
                if (!m_memoryCache.TryGetValue(cacheKey1, out encOutput1))
                {
                    encOutput1 = Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, posEmbedding, segmentEmbedding, 0); // output shape: [batch_size, dim]

                    // A weights-reference copy is cached rather than the graph-owned tensor itself.
                    var cacheEntryOptions = new MemoryCacheEntryOptions().SetSize(1);
                    m_memoryCache.Set(cacheKey1, encOutput1.CopyWeightsRef($"cache_{encOutput1.Name}", false), cacheEntryOptions);
                }

                string cacheKey2 = GenerateCacheKey(sntPairBatch.GetSrcTokens(1));
                if (!m_memoryCache.TryGetValue(cacheKey2, out encOutput2))
                {
                    encOutput2 = Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, posEmbedding, segmentEmbedding, 1); // output_shape: [batch_size, dim]

                    var cacheEntryOptions = new MemoryCacheEntryOptions().SetSize(1);
                    m_memoryCache.Set(cacheKey2, encOutput2.CopyWeightsRef($"cache_{encOutput2.Name}", false), cacheEntryOptions);
                }
            }
            else
            {
                // We always run the encoder network at training time or when using GPUs.
                encOutput1 = Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, posEmbedding, segmentEmbedding, 0); // output shape: [batch_size, dim]
                encOutput2 = Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, posEmbedding, segmentEmbedding, 1); // output_shape: [batch_size, dim]
            }

            // "Continuous" is a configuration identifier, so compare it ordinally rather than linguistically.
            if (m_modelMetaData.SimilarityType.Equals("Continuous", StringComparison.OrdinalIgnoreCase))
            {
                // Cosine similarity: sum(a*b) * rsqrt(sum(a*a) * sum(b*b)), reduced along axis 1.
                var w12 = computeGraph.EltMul(encOutput1, encOutput2);
                w12 = computeGraph.Sum(w12, 1);
                var w1 = computeGraph.EltMul(encOutput1, encOutput1);
                w1 = computeGraph.Sum(w1, 1);
                var w2 = computeGraph.EltMul(encOutput2, encOutput2);
                w2 = computeGraph.Sum(w2, 1);
                var n12 = computeGraph.EltMul(w1, w2);
                n12 = computeGraph.Rsqrt(n12);
                var probs = computeGraph.EltMul(w12, n12);

                if (isTraining)
                {
                    var tgtSnts = sntPairBatch.GetTgtTokens(0);
                    for (int k = 0; k < batchSize; k++)
                    {
                        // Golden similarity score from the target side; invariant culture so "0.85" parses the same in every locale.
                        float goldenScore = float.Parse(tgtSnts[k][0], System.Globalization.CultureInfo.InvariantCulture);
                        float score = probs.GetWeightAt(new long[] { k, 0 });
                        float residual = score - goldenScore;

                        // NOTE(review): the reported cost is the absolute error (L1), but the value seeded as the gradient is
                        // (score - golden), which is the derivative of 0.5 * (score - golden)^2, not of |score - golden|.
                        // Confirm this MAE-reporting / MSE-training split is intended.
                        probs.SetWeightAt(residual, new long[] { k, 0 });
                        cost += Math.Abs(residual);
                    }

                    probs.CopyWeightsToGradients(probs);
                    nr.Cost = cost / batchSize;
                }
                else
                {
                    nr.Output.Add(new List<List<string>>());
                    for (int k = 0; k < batchSize; k++)
                    {
                        float score = probs.GetWeightAt(new long[] { k, 0 });

                        nr.Output[0].Add(new List<string>());
                        // Same invariant format that training data is parsed with, so outputs can be fed back in.
                        nr.Output[0][k].Add(score.ToString(System.Globalization.CultureInfo.InvariantCulture));
                    }
                }
            }
            else
            {
                IWeightTensor encOutput = computeGraph.EltMul(encOutput1, encOutput2);
                IWeightTensor ffLayer = encoderFFLayer.Process(encOutput, batchSize, computeGraph);
                using (IWeightTensor probs = computeGraph.Softmax(ffLayer, runGradients: false, inPlace: true))
                {
                    if (isTraining)
                    {
                        var tgtSnts = sntPairBatch.GetTgtTokens(0);
                        for (int k = 0; k < batchSize; k++)
                        {
                            int targetIx = m_modelMetaData.ClsVocab.GetWordIndex(tgtSnts[k][0]);
                            float score = probs.GetWeightAt(new long[] { k, targetIx });
                            cost += (float)-Math.Log(score);

                            // Softmax + cross-entropy gradient is (p - onehot): subtract 1 at the golden class.
                            probs.SetWeightAt(score - 1, new long[] { k, targetIx });
                        }

                        ffLayer.CopyWeightsToGradients(probs);

                        nr.Cost = cost / batchSize;
                    }
                    else
                    {
                        // Predicted class per sentence = argmax over axis 1 of the softmax output.
                        using var targetIdxTensor = computeGraph.Argmax(probs, 1);
                        float[] targetIdx = targetIdxTensor.ToWeightArray();
                        List<string> targetWords = m_modelMetaData.ClsVocab.ConvertIdsToString(targetIdx.ToList());
                        nr.Output.Add(new List<List<string>>());

                        for (int k = 0; k < batchSize; k++)
                        {
                            nr.Output[0].Add(new List<string>());
                            nr.Output[0][k].Add(targetWords[k]);
                        }
                    }
                }
            }

            nrs.Add(nr);

            return nrs;
        }