/// <summary>
/// Transformer encoder: combines the input with a position embedding, applies dropout,
/// and runs the result through every layer in <c>m_encoders</c>.
/// </summary>
/// <param name="rawInput">Input tensor; its row count is divided by <paramref name="batchSize"/> to obtain the sequence length.</param>
/// <param name="batchSize">Number of sequences in the batch.</param>
/// <param name="g">Compute graph on which all operations are performed.</param>
/// <returns>The output of the last encoder layer, transposed back with <c>TransposeBatch</c>.</returns>
public IWeightTensor Encode(IWeightTensor rawInput, int batchSize, IComputeGraph g)
{
    int seqLen = rawInput.Rows / batchSize;

    IWeightTensor inputs;

    // The position embedding is not updated (runGradient / runGradientW1 are false), so both
    // tensors are released as soon as they have been combined with the input. The using
    // statements guarantee the release even when one of the graph operations throws.
    using (IWeightTensor posEmbedding = g.BuildPositionMatrix(seqLen, m_inputDim))
    using (IWeightTensor posEmbeddingRepeat = g.RepeatRows(posEmbedding, batchSize, runGradient: false))
    {
        // Transpose to batch-first based sequence
        inputs = g.TransposeBatch(rawInput, batchSize);

        // NOTE(review): AddMul presumably computes posEmbeddingRepeat + inputs * sqrt(m_inputDim) - confirm.
        inputs = g.AddMul(posEmbeddingRepeat, inputs, (float)Math.Sqrt(m_inputDim), runGradientW1: false, runGradientW2: true);
    }

    inputs = g.Dropout(inputs, batchSize, m_dropoutRatio, inPlace: true);

    for (int k = 0; k < m_encoders.Count; k++)
    {
        inputs = m_encoders[k].Perform(inputs, batchSize, g);
    }

    // Transpose back to time-first based sequence
    return g.TransposeBatch(inputs, seqLen);
}
/// <summary>
/// Transformer decoder: combines the target input with a position embedding, applies dropout,
/// then for every layer runs self-attention, encoder-decoder attention and the position-wise
/// feed-forward network, and finally applies layer normalization.
/// </summary>
/// <param name="tgtInputs">Target-side input tensor; its row count divided by <paramref name="batchSize"/> gives the target sequence length.</param>
/// <param name="encOutputBatchFirst">Encoder output; its row count divided by <paramref name="batchSize"/> gives the source sequence length.</param>
/// <param name="tgtSelfMask">Self-attention mask, viewed here as [1, batchSize, tgtSeqLen, tgtSeqLen].</param>
/// <param name="decEncAttnMask">Encoder-decoder attention mask, viewed here as [1, batchSize, tgtSeqLen, srcSeqLen].</param>
/// <param name="tgtDimMask">Mask passed to <c>MaskFill</c> (fill value 0.0) before every layer.</param>
/// <param name="batchSize">Number of sequences in the batch.</param>
/// <param name="g">Compute graph on which the operations are performed.</param>
/// <returns>The layer-normalized output of the last decoder layer.</returns>
public IWeightTensor Decode(IWeightTensor tgtInputs, IWeightTensor encOutputBatchFirst, IWeightTensor tgtSelfMask, IWeightTensor decEncAttnMask, IWeightTensor tgtDimMask, int batchSize, IComputeGraph g)
{
    int tgtSeqLen = tgtInputs.Rows / batchSize;
    int srcSeqLen = encOutputBatchFirst.Rows / batchSize;

    // The position embedding receives no gradient (runGradient / runGradientW1 are false),
    // so it is released right after it has been combined with the input.
    using (IWeightTensor posEmbedding = g.BuildPositionMatrix(tgtSeqLen, m_inputDim))
    using (IWeightTensor posEmbeddingRepeat = g.RepeatRows(posEmbedding, batchSize, runGradient: false))
    {
        // NOTE(review): AddMul presumably computes posEmbeddingRepeat + tgtInputs * sqrt(m_inputDim) - confirm.
        tgtInputs = g.AddMul(posEmbeddingRepeat, tgtInputs, (float)Math.Sqrt(m_inputDim), runGradientW1: false, runGradientW2: true);
    }

    tgtInputs = g.Dropout(tgtInputs, batchSize, m_dropoutRatio, inPlace: true);

    // Expand both masks across the attention heads and flatten them to two dimensions.
    // The intermediate view/expand tensors are only needed to build the final views; the
    // using statements release them even if one of the graph operations throws.
    IWeightTensor tgtSelfMaskView;
    IWeightTensor decEncAttnMaskView;
    using (var tgtSelfMaskRep = g.View(tgtSelfMask, dims: new long[] { 1, batchSize, tgtSeqLen, tgtSeqLen }))
    using (var tgtSelfMaskRepExp = g.Expand(tgtSelfMaskRep, dims: new long[] { m_multiHeadNum, batchSize, tgtSeqLen, tgtSeqLen }))
    using (var decEncAttnMaskRep = g.View(decEncAttnMask, dims: new long[] { 1, batchSize, tgtSeqLen, srcSeqLen }))
    using (var decEncAttnMaskRepExp = g.Expand(decEncAttnMaskRep, dims: new long[] { m_multiHeadNum, batchSize, tgtSeqLen, srcSeqLen }))
    {
        tgtSelfMaskView = g.View(tgtSelfMaskRepExp, dims: new long[] { m_multiHeadNum * batchSize * tgtSeqLen, tgtSeqLen });
        decEncAttnMaskView = g.View(decEncAttnMaskRepExp, dims: new long[] { m_multiHeadNum * batchSize * tgtSeqLen, srcSeqLen });
    }

    using (IComputeGraph subg = g.CreateSubGraph($"{m_name}_Decoder"))
    {
        for (int k = 0; k < m_selfAttns.Count; k++)
        {
            // NOTE(review): MaskFill runs on the outer graph 'g' while the layers below run
            // on 'subg' - confirm this is intentional.
            tgtInputs = g.MaskFill(tgtInputs, tgtDimMask, 0.0f);

            tgtInputs = m_selfAttns[k].Perform(tgtInputs, tgtInputs, tgtInputs, tgtSelfMaskView, batchSize, subg);
            tgtInputs = m_encAttns[k].Perform(tgtInputs, encOutputBatchFirst, encOutputBatchFirst, decEncAttnMaskView, batchSize, subg);
            tgtInputs = m_posFFNs[k].Perform(tgtInputs, batchSize, subg);
        }

        // Unbind the result before the sub-graph is disposed; presumably this keeps the
        // tensor alive past the end of the using block - confirm.
        tgtInputs.UnbindFromComputeGraph();
    }

    tgtInputs = layerNorm.Norm(tgtInputs, g);

    return tgtInputs;
}
/// <summary>
/// Transformer encoder: combines the input with a position embedding, applies dropout,
/// then for every layer runs masked self-attention and the position-wise feed-forward
/// network, and finally applies layer normalization.
/// </summary>
/// <param name="inputs">Input tensor; its row count divided by <paramref name="batchSize"/> gives the sequence length.</param>
/// <param name="selfMask">Self-attention mask, viewed here as [1, batchSize, seqLen, seqLen].</param>
/// <param name="dimMask">Mask passed to <c>MaskFill</c> (fill value 0.0) before every layer.</param>
/// <param name="batchSize">Number of sequences in the batch.</param>
/// <param name="g">Compute graph on which the operations are performed.</param>
/// <returns>The layer-normalized output of the last encoder layer.</returns>
public IWeightTensor Encode(IWeightTensor inputs, IWeightTensor selfMask, IWeightTensor dimMask, int batchSize, IComputeGraph g)
{
    int seqLen = inputs.Rows / batchSize;

    // The position embedding receives no gradient (runGradient / runGradientW1 are false),
    // so it is released right after it has been combined with the input.
    using (IWeightTensor posEmbedding = g.BuildPositionMatrix(seqLen, m_inputDim))
    using (IWeightTensor posEmbeddingRepeat = g.RepeatRows(posEmbedding, batchSize, runGradient: false))
    {
        // NOTE(review): AddMul presumably computes posEmbeddingRepeat + inputs * sqrt(m_inputDim) - confirm.
        inputs = g.AddMul(posEmbeddingRepeat, inputs, (float)Math.Sqrt(m_inputDim), runGradientW1: false, runGradientW2: true);
    }

    inputs = g.Dropout(inputs, batchSize, m_dropoutRatio, inPlace: true);

    // Expand the mask across the attention heads and flatten it to two dimensions.
    // The intermediate view/expand tensors are only needed to build the final view; the
    // using statements release them even if one of the graph operations throws.
    IWeightTensor multiHeadSelfMaskView;
    using (var selfMaskRep = g.View(selfMask, dims: new long[] { 1, batchSize, seqLen, seqLen }))
    using (var multiHeadSelfMask = g.Expand(selfMaskRep, dims: new long[] { m_multiHeadNum, batchSize, seqLen, seqLen }))
    {
        multiHeadSelfMaskView = g.View(multiHeadSelfMask, dims: new long[] { m_multiHeadNum * batchSize * seqLen, seqLen });
    }

    using (IComputeGraph subg = g.CreateSubGraph($"{m_name}_Encoder"))
    {
        for (int k = 0; k < m_encoders.Count; k++)
        {
            // NOTE(review): MaskFill runs on the outer graph 'g' while the layers below run
            // on 'subg' - confirm this is intentional.
            inputs = g.MaskFill(inputs, dimMask, 0.0f);

            inputs = m_encoders[k].Perform(inputs, inputs, inputs, multiHeadSelfMaskView, batchSize, subg);
            inputs = m_posFFNs[k].Perform(inputs, batchSize, subg);
        }

        // Unbind the result before the sub-graph is disposed; presumably this keeps the
        // tensor alive past the end of the using block - confirm.
        inputs.UnbindFromComputeGraph();
    }

    inputs = layerNorm.Norm(inputs, g);

    return inputs;
}