/// <summary>
/// Pre-norm scaled dot-product multi-head self-attention with a residual connection.
/// The input is layer-normalized, projected to Q/K/V by one fused affine layer, attended
/// per head, projected back through W0/b0, passed through dropout and finally added to the
/// original (un-normalized) input.
/// </summary>
/// <param name="inputQ">The input tensor; its row count is treated as batchSize * seqLenQ.</param>
/// <param name="keyMask">Optional additive mask applied to the attention logits before softmax; skipped when null.</param>
/// <param name="batchSize">Number of sequences in the batch.</param>
/// <param name="graph">The parent computing graph; the tensors returned to the caller are created on it.</param>
/// <param name="outputAttenWeights">When true, also return the attention probabilities averaged over all heads.</param>
/// <returns>
/// A tuple of (the projected attention output added to <paramref name="inputQ"/>;
/// the head-averaged attention weights viewed as [batchSize * seqLenQ, seqLenQ], or null when
/// <paramref name="outputAttenWeights"/> is false).
/// </returns>
public (IWeightTensor, IWeightTensor) Perform(IWeightTensor inputQ, IWeightTensor keyMask, int batchSize, IComputeGraph graph, bool outputAttenWeights = false)
{
    // Intermediate tensors live on a sub-graph that is disposed when this method returns.
    using IComputeGraph subGraph = graph.CreateSubGraph($"{m_name}_MultiHeadAttention");

    int seqLen = inputQ.Rows / batchSize;
    var headBatch = batchSize * m_multiHeadNum;

    IWeightTensor normedInput = layerNormQ.Norm(inputQ, subGraph);

    // Fused Q/K/V projection viewed as [batch, seq, 3, heads, headDim]; Q, K and V are then
    // picked out along axis 2.
    var projected = subGraph.Affine(normedInput, QKV, QKVb);
    var qkv = subGraph.View(projected, dims: new long[] { batchSize, seqLen, 3, m_multiHeadNum, m_d });
    var qAll = subGraph.Select(qkv, 2, 0);
    var kAll = subGraph.Select(qkv, 2, 1);
    var vAll = subGraph.Select(qkv, 2, 2);

    // Swap the sequence and head axes, then fold the head axis into the batch axis:
    // [batch * heads, seq, headDim].
    IWeightTensor SplitHeads(IWeightTensor t) =>
        subGraph.View(subGraph.AsContiguous(subGraph.Transpose(t, 1, 2)), dims: new long[] { headBatch, seqLen, m_d });

    IWeightTensor queries = SplitHeads(qAll);
    // Keys additionally swap their last two axes ([batch * heads, headDim, seq]) so a batched
    // matmul with the queries yields Q * K^T.
    IWeightTensor keysT = subGraph.View(subGraph.AsContiguous(subGraph.Transpose(subGraph.Transpose(kAll, 1, 2), 2, 3)), dims: new long[] { headBatch, m_d, seqLen });
    IWeightTensor values = SplitHeads(vAll);

    // Scaled dot-product logits, viewed as [batch, heads, seq, seq].
    float scale = 1.0f / (float)Math.Sqrt(m_d);
    var logits = subGraph.MulBatch(queries, keysT, scale);
    logits = subGraph.View(logits, dims: new long[] { batchSize, m_multiHeadNum, seqLen, seqLen });
    if (keyMask != null)
    {
        logits = subGraph.Add(logits, keyMask, inPlace: true);
    }
    var probs = subGraph.Softmax(logits, inPlace: true);

    IWeightTensor averagedWeights = null;
    if (outputAttenWeights)
    {
        // Built on the parent graph, presumably so the tensor outlives the disposed sub-graph.
        averagedWeights = graph.Div(graph.Sum(probs, 1), (float)m_multiHeadNum);
        averagedWeights = graph.View(averagedWeights, new long[] { batchSize * seqLen, seqLen });
    }

    // Weighted sum of values per head: [batch, heads, seq, headDim].
    var flatProbs = subGraph.View(probs, dims: new long[] { headBatch, seqLen, seqLen });
    IWeightTensor context = subGraph.View(subGraph.MulBatch(flatProbs, values), dims: new long[] { batchSize, m_multiHeadNum, seqLen, m_d });

    // Concatenate the heads back into [batch * seq, heads * headDim].
    IWeightTensor merged = subGraph.View(subGraph.AsContiguous(subGraph.Transpose(context, 1, 2)), dims: new long[] { batchSize * seqLen, m_multiHeadNum * m_d });

    // Output projection + dropout, then the residual connection on the parent graph.
    IWeightTensor attended = subGraph.Dropout(subGraph.Affine(merged, W0, b0), batchSize, m_dropoutRatio, inPlace: true);
    IWeightTensor output = graph.Add(attended, inputQ, inPlace: true);

    return (output, averagedWeights);
}
/// <summary>
/// Runs the forward pass of the sentence-pair model on a single device. When training, it
/// also computes the cost and seeds the output gradients for the backward pass; otherwise it
/// produces one prediction string per sentence pair.
/// </summary>
/// <param name="computeGraph">The computing graph for the current device. It is created and passed in by the framework.</param>
/// <param name="sntPairBatch">
/// A batch of sentence pairs. Source token groups 0 and 1 are the two sentences being compared.
/// When training, the first token of each target sentence in group 0 is the label: a numeric
/// score for "Continuous" similarity, or a class name otherwise.
/// </param>
/// <param name="deviceIdIdx">The index of the current device.</param>
/// <param name="isTraining">True to compute the cost and gradients; false to produce predictions.</param>
/// <param name="decodingOptions">Decoding options. Not used by this method.</param>
/// <returns>
/// A single-element list. Its result carries the average cost per sentence pair when training,
/// or, at inference, one single-token prediction per sentence pair in <c>Output[0]</c>.
/// </returns>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions)
{
    int batchSize = sntPairBatch.BatchSize;
    float cost = 0.0f;
    var nr = new NetworkResult { Output = new List<List<List<string>>>() };

    (IEncoder encoder, IWeightTensor srcEmbedding, IFeedForwardLayer encoderFFLayer, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

    // Runs the shared encoder over source token group `groupIdx`. Output shape: [batch_size, dim].
    IWeightTensor BuildEncoderOutput(int groupIdx) =>
        Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, posEmbedding, segmentEmbedding, groupIdx);

    // Returns the cached encoder output for these source tokens if present. Otherwise it
    // encodes them and caches a CopyWeightsRef copy of the result (cache entry size 1).
    IWeightTensor GetOrBuildEncoderOutput(int groupIdx)
    {
        string cacheKey = GenerateCacheKey(sntPairBatch.GetSrcTokens(groupIdx));
        if (m_memoryCache.TryGetValue(cacheKey, out IWeightTensor cachedOutput))
        {
            return cachedOutput;
        }

        IWeightTensor output = BuildEncoderOutput(groupIdx);
        var cacheEntryOptions = new MemoryCacheEntryOptions().SetSize(1);
        m_memoryCache.Set(cacheKey, output.CopyWeightsRef($"cache_{output.Name}", false), cacheEntryOptions);
        return output;
    }

    IWeightTensor encOutput1;
    IWeightTensor encOutput2;
    if (!isTraining && m_options.ProcessorType == ProcessorTypeEnums.CPU)
    {
        // The cache is only consulted at inference time on CPU.
        encOutput1 = GetOrBuildEncoderOutput(0);
        encOutput2 = GetOrBuildEncoderOutput(1);
    }
    else
    {
        // During training, or when running on GPUs, the encoder is always run.
        encOutput1 = BuildEncoderOutput(0);
        encOutput2 = BuildEncoderOutput(1);
    }

    if (string.Equals(m_modelMetaData.SimilarityType, "Continuous", StringComparison.OrdinalIgnoreCase))
    {
        // Cosine similarity per row: (e1 . e2) * rsqrt(|e1|^2 * |e2|^2), shape [batch_size, 1].
        var w12 = computeGraph.Sum(computeGraph.EltMul(encOutput1, encOutput2), 1);
        var w1 = computeGraph.Sum(computeGraph.EltMul(encOutput1, encOutput1), 1);
        var w2 = computeGraph.Sum(computeGraph.EltMul(encOutput2, encOutput2), 1);
        var n12 = computeGraph.Rsqrt(computeGraph.EltMul(w1, w2));
        var probs = computeGraph.EltMul(w12, n12);

        if (isTraining)
        {
            var tgtSnts = sntPairBatch.GetTgtTokens(0);
            for (int k = 0; k < batchSize; k++)
            {
                // Parse the golden score culture-invariantly, so that "0.75" is read the same
                // way regardless of the machine's decimal separator.
                float goldenScore = float.Parse(tgtSnts[k][0], System.Globalization.CultureInfo.InvariantCulture);
                float score = probs.GetWeightAt(new long[] { k, 0 });
                float diff = score - goldenScore;

                // The residual (score - golden) becomes the output gradient, i.e. the derivative
                // of 0.5 * (score - golden)^2. The reported cost, however, is the mean absolute error.
                probs.SetWeightAt(diff, new long[] { k, 0 });
                cost += Math.Abs(diff);
            }

            // Seed probs' gradient with the residuals just written into its weights.
            probs.CopyWeightsToGradients(probs);
            nr.Cost = cost / batchSize;
        }
        else
        {
            nr.Output.Add(new List<List<string>>());
            for (int k = 0; k < batchSize; k++)
            {
                float score = probs.GetWeightAt(new long[] { k, 0 });
                // Emit with the same invariant format that golden scores are parsed with.
                nr.Output[0].Add(new List<string> { score.ToString(System.Globalization.CultureInfo.InvariantCulture) });
            }
        }
    }
    else
    {
        // Classification: the element-wise product of both encodings is fed to the FF layer.
        IWeightTensor encOutput = computeGraph.EltMul(encOutput1, encOutput2);
        IWeightTensor ffLayer = encoderFFLayer.Process(encOutput, batchSize, computeGraph);
        using (IWeightTensor probs = computeGraph.Softmax(ffLayer, runGradients: false, inPlace: true))
        {
            if (isTraining)
            {
                var tgtSnts = sntPairBatch.GetTgtTokens(0);
                for (int k = 0; k < batchSize; k++)
                {
                    int ix_targets_k_j = m_modelMetaData.ClsVocab.GetWordIndex(tgtSnts[k][0]);
                    float score_k = probs.GetWeightAt(new long[] { k, ix_targets_k_j });
                    cost += (float)-Math.Log(score_k);

                    // Softmax + cross-entropy gradient w.r.t. the logits: p - onehot(target).
                    probs.SetWeightAt(score_k - 1, new long[] { k, ix_targets_k_j });
                }

                ffLayer.CopyWeightsToGradients(probs);
                nr.Cost = cost / batchSize;
            }
            else
            {
                // Predict the highest-probability class for each sentence pair.
                using var targetIdxTensor = computeGraph.Argmax(probs, 1);
                float[] targetIdx = targetIdxTensor.ToWeightArray();
                List<string> targetWords = m_modelMetaData.ClsVocab.ConvertIdsToString(targetIdx.ToList());
                nr.Output.Add(new List<List<string>>());
                for (int k = 0; k < batchSize; k++)
                {
                    nr.Output[0].Add(new List<string> { targetWords[k] });
                }
            }
        }
    }

    return new List<NetworkResult> { nr };
}