/// <summary>
/// Scaled multi-head attention component with a skip-connected output projection
/// </summary>
/// <param name="inputQ">The input Q tensor</param>
/// <param name="keyMask">The mask for softmax</param>
/// <param name="batchSize">Batch size of the input data set</param>
/// <param name="graph">The instance of the computing graph</param>
/// <returns>Transformed output tensor</returns>
public (IWeightTensor, IWeightTensor) Perform(IWeightTensor inputQ, IWeightTensor keyMask, int batchSize, IComputeGraph graph, bool outputAttenWeights = false)
{
    using IComputeGraph g = graph.CreateSubGraph($"{m_name}_MultiHeadAttention");
    int seqLenQ = inputQ.Rows / batchSize;
    IWeightTensor inputQNorm = layerNormQ.Norm(inputQ, g);

    // Input projections: one fused affine produces Q, K and V
    var weightedQKV = g.View(g.Affine(inputQNorm, QKV, QKVb), dims: new long[] { batchSize, seqLenQ, 3, m_multiHeadNum, m_d });
    var allQ = g.Select(weightedQKV, 2, 0);
    var allK = g.Select(weightedQKV, 2, 1);
    var allV = g.Select(weightedQKV, 2, 2);

    // Multi-head attention
    IWeightTensor Qs = g.View(g.AsContiguous(g.Transpose(allQ, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, seqLenQ, m_d });
    IWeightTensor Ks = g.View(g.AsContiguous(g.Transpose(g.Transpose(allK, 1, 2), 2, 3)), dims: new long[] { batchSize * m_multiHeadNum, m_d, seqLenQ });
    IWeightTensor Vs = g.View(g.AsContiguous(g.Transpose(allV, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, seqLenQ, m_d });

    // Scaled softmax
    float scale = 1.0f / (float)Math.Sqrt(m_d);
    var attn = g.MulBatch(Qs, Ks, scale);
    attn = g.View(attn, dims: new long[] { batchSize, m_multiHeadNum, seqLenQ, seqLenQ });

    if (keyMask != null)
    {
        attn = g.Add(attn, keyMask, inPlace: true);
    }

    var attnProbs = g.Softmax(attn, inPlace: true);

    IWeightTensor sumAttnWeights = null;
    if (outputAttenWeights)
    {
        // Average the attention probabilities over all heads
        sumAttnWeights = graph.Sum(attnProbs, 1);
        sumAttnWeights = graph.Div(sumAttnWeights, (float)m_multiHeadNum);
        sumAttnWeights = graph.View(sumAttnWeights, new long[] { batchSize * seqLenQ, seqLenQ });
    }

    attnProbs = g.View(attnProbs, dims: new long[] { batchSize * m_multiHeadNum, seqLenQ, seqLenQ });

    IWeightTensor o = g.View(g.MulBatch(attnProbs, Vs), dims: new long[] { batchSize, m_multiHeadNum, seqLenQ, m_d });
    IWeightTensor W = g.View(g.AsContiguous(g.Transpose(o, 1, 2)), dims: new long[] { batchSize * seqLenQ, m_multiHeadNum * m_d });

    // Output projection
    IWeightTensor finalAttResults = g.Dropout(g.Affine(W, W0, b0), batchSize, m_dropoutRatio, inPlace: true);
    IWeightTensor result = graph.Add(finalAttResults, inputQ, inPlace: true);

    return (result, sumAttnWeights);
}
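// A minimal, self-contained sketch (illustrative only, not part of the source API) of the
// per-head computation performed above, written with plain arrays to make the math explicit:
// scores = Q·K^T / sqrt(d), probs = softmax(scores + mask), output = probs·V.
static float[,] AttendOneHead(float[,] Q, float[,] K, float[,] V, float[,] mask)
{
    int lenQ = Q.GetLength(0), lenK = K.GetLength(0), d = Q.GetLength(1);
    var output = new float[lenQ, d];
    for (int i = 0; i < lenQ; i++)
    {
        // Scaled dot-product scores for query i against every key, plus the additive mask
        var scores = new double[lenK];
        double max = double.NegativeInfinity;
        for (int j = 0; j < lenK; j++)
        {
            double s = 0.0;
            for (int k = 0; k < d; k++) s += Q[i, k] * K[j, k];
            scores[j] = s / Math.Sqrt(d) + (mask != null ? mask[i, j] : 0.0);
            if (scores[j] > max) max = scores[j];
        }

        // Numerically stable softmax over the key axis
        double sum = 0.0;
        for (int j = 0; j < lenK; j++) { scores[j] = Math.Exp(scores[j] - max); sum += scores[j]; }

        // Weighted sum of the value rows
        for (int j = 0; j < lenK; j++)
        {
            double p = scores[j] / sum;
            for (int k = 0; k < d; k++) output[i, k] += (float)(p * V[j, k]);
        }
    }
    return output;
}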
public IWeightTensor Step(IWeightTensor input, IComputeGraph g)
{
    using (IComputeGraph innerGraph = g.CreateSubGraph(m_name))
    {
        IWeightTensor hidden_prev = m_hidden;
        IWeightTensor cell_prev = m_cell;

        IWeightTensor inputs = innerGraph.Concate(1, input, hidden_prev);
        IWeightTensor hhSum = innerGraph.Affine(inputs, m_Wxh, m_b);
        IWeightTensor hhSum2 = m_layerNorm1.Norm(hhSum, innerGraph);

        (IWeightTensor gates_raw, IWeightTensor cell_write_raw) = innerGraph.SplitColumns(hhSum2, m_hdim * 3, m_hdim);
        IWeightTensor gates = innerGraph.Sigmoid(gates_raw);
        IWeightTensor cell_write = innerGraph.Tanh(cell_write_raw);

        (IWeightTensor input_gate, IWeightTensor forget_gate, IWeightTensor output_gate) = innerGraph.SplitColumns(gates, m_hdim, m_hdim, m_hdim);

        // Compute new cell activation: ct = forget_gate * cell_prev + input_gate * cell_write
        m_cell = g.EltMulMulAdd(forget_gate, cell_prev, input_gate, cell_write);
        IWeightTensor ct2 = m_layerNorm2.Norm(m_cell, innerGraph);

        // Compute hidden state as gated, saturated cell activations
        m_hidden = g.EltMul(output_gate, innerGraph.Tanh(ct2));

        return m_hidden;
    }
}
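// For reference, the recurrence computed by Step (an LSTM with layer normalization on
// the fused pre-activation and on the cell state):
//   [i; f; o | g~] = LN1(W · [x_t; h_{t-1}] + b)
//   i, f, o = sigmoid(gates),  g = tanh(g~)
//   c_t = f ⊙ c_{t-1} + i ⊙ g
//   h_t = o ⊙ tanh(LN2(c_t))
// Usage sketch (hypothetical driver; "cell" and "stepInputs" are assumed names):
//   IWeightTensor h = null;
//   foreach (IWeightTensor x in stepInputs)
//   {
//       h = cell.Step(x, g); // consumes x plus the cell's internal hidden/cell state
//   }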
/// <summary>
/// Update LSTM-Attention cells according to given weights
/// </summary>
/// <param name="context">The context weights for attention</param>
/// <param name="input">The input weights</param>
/// <param name="g">The compute graph to build the workflow</param>
/// <returns>Updated hidden weights</returns>
public IWeightTensor Step(IWeightTensor context, IWeightTensor input, IComputeGraph g)
{
    using (IComputeGraph computeGraph = g.CreateSubGraph(m_name))
    {
        IWeightTensor cell_prev = Cell;
        IWeightTensor hidden_prev = Hidden;

        IWeightTensor hxhc = computeGraph.Concate(1, input, hidden_prev, context);
        IWeightTensor hhSum = computeGraph.Affine(hxhc, m_Wxhc, m_b);
        IWeightTensor hhSum2 = m_layerNorm1.Norm(hhSum, computeGraph);

        (IWeightTensor gates_raw, IWeightTensor cell_write_raw) = computeGraph.SplitColumns(hhSum2, m_hiddenDim * 3, m_hiddenDim);
        IWeightTensor gates = computeGraph.Sigmoid(gates_raw);
        IWeightTensor cell_write = computeGraph.Tanh(cell_write_raw);

        (IWeightTensor input_gate, IWeightTensor forget_gate, IWeightTensor output_gate) = computeGraph.SplitColumns(gates, m_hiddenDim, m_hiddenDim, m_hiddenDim);

        // Compute new cell activation: ct = forget_gate * cell_prev + input_gate * cell_write
        Cell = g.EltMulMulAdd(forget_gate, cell_prev, input_gate, cell_write);
        IWeightTensor ct2 = m_layerNorm2.Norm(Cell, computeGraph);

        // Compute hidden state as gated, saturated cell activations
        Hidden = g.EltMul(output_gate, computeGraph.Tanh(ct2));

        return Hidden;
    }
}
/// <summary>
/// Transformer encoder
/// </summary>
/// <param name="inputs">The input tensor</param>
/// <param name="batchSize">Batch size of the input data set</param>
/// <param name="g">The instance of the computing graph</param>
/// <param name="srcSelfMask">The self-attention mask for the source sequence</param>
/// <returns>Encoded output tensor</returns>
public IWeightTensor Encode(IWeightTensor inputs, int batchSize, IComputeGraph g, IWeightTensor srcSelfMask)
{
    using (IComputeGraph subg = g.CreateSubGraph($"{m_name}_Encoder"))
    {
        IWeightTensor maskTensor = null;
        if (srcSelfMask != null)
        {
            int seqLen = inputs.Rows / batchSize;
            using var keyMaskView = subg.View(srcSelfMask, dims: new long[] { batchSize, 1, seqLen, seqLen });
            maskTensor = subg.Expand(keyMaskView, dims: new long[] { batchSize, m_multiHeadNum, seqLen, seqLen });
        }

        IWeightTensor attnProbs = null;
        for (int k = 0; k < m_encoders.Count; k++)
        {
            (inputs, attnProbs) = m_encoders[k].Perform(inputs, maskTensor, batchSize, subg, outputAttenWeights: false);
            inputs = m_posFFNs[k].Perform(inputs, batchSize, subg);
        }

        inputs = layerNorm.Norm(inputs, subg);
        inputs.UnbindFromComputeGraph();
        if (attnProbs != null)
        {
            attnProbs.UnbindFromComputeGraph();
        }

        if (maskTensor != null)
        {
            maskTensor.Dispose();
        }
    }

    return inputs;
}
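// A sketch (hypothetical helper, not source code) of how an additive srcSelfMask of the
// shape consumed above can be built: 0 where a key token is real, a large negative value
// where it is padding, so softmax assigns padded positions near-zero probability.
static float[] BuildPaddingMask(int[] actualLengths, int seqLen)
{
    int batchSize = actualLengths.Length;
    var mask = new float[batchSize * seqLen * seqLen]; // flattened (batch, seqLenQ, seqLenK)
    for (int b = 0; b < batchSize; b++)
        for (int q = 0; q < seqLen; q++)
            for (int k = 0; k < seqLen; k++)
                mask[(b * seqLen + q) * seqLen + k] = k < actualLengths[b] ? 0.0f : -1e9f;
    return mask;
}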
/// <summary>
/// Transformer decoder
/// </summary>
/// <param name="tgtInputs">The target input tensor</param>
/// <param name="encOutputBatchFirst">The encoder output in batch-first layout</param>
/// <param name="tgtSelfMask">The self-attention mask for the target sequence</param>
/// <param name="srcTgtMask">The cross-attention mask between source and target</param>
/// <param name="batchSize">Batch size of the input data set</param>
/// <param name="g">The instance of the computing graph</param>
/// <returns>Decoded output tensor and, optionally, the attention weights</returns>
public (IWeightTensor, IWeightTensor) Decode(IWeightTensor tgtInputs, IWeightTensor encOutputBatchFirst, IWeightTensor tgtSelfMask, IWeightTensor srcTgtMask, int batchSize, IComputeGraph g,
    bool outputAttnWeights = false, Dictionary<string, IWeightTensor> cachedTensors = null)
{
    IWeightTensor attnProbs = null;
    using (IComputeGraph subg = g.CreateSubGraph($"{m_name}_Decoder"))
    {
        int seqLenQ = tgtInputs.Rows / batchSize;

        // SeqLenK must be equal to SeqLenV
        int seqLenK = encOutputBatchFirst.Rows / batchSize;

        IWeightTensor selfMaskTensor = null;
        if (tgtSelfMask != null)
        {
            selfMaskTensor = subg.Expand(tgtSelfMask, dims: new long[] { batchSize, m_multiHeadNum, seqLenQ, seqLenQ });
        }

        IWeightTensor crossMaskTensor = null;
        if (srcTgtMask != null)
        {
            crossMaskTensor = subg.Expand(srcTgtMask, dims: new long[] { batchSize, m_multiHeadNum, seqLenQ, seqLenK });
        }

        for (int k = 0; k < m_selfAttns.Count; k++)
        {
            (tgtInputs, attnProbs) = m_selfAttns[k].Perform(tgtInputs, selfMaskTensor, batchSize, subg, outputAttenWeights: false);
            (tgtInputs, attnProbs) = m_encAttns[k].Perform(tgtInputs, encOutputBatchFirst, encOutputBatchFirst, crossMaskTensor, batchSize, subg,
                outputAttenWeights: (outputAttnWeights && k == m_selfAttns.Count - 1), cachedTensors: cachedTensors);
            tgtInputs = m_posFFNs[k].Perform(tgtInputs, batchSize, subg);
        }

        tgtInputs = layerNorm.Norm(tgtInputs, subg);
        tgtInputs.UnbindFromComputeGraph();
        if (attnProbs != null)
        {
            attnProbs.UnbindFromComputeGraph();
        }

        if (selfMaskTensor != null)
        {
            selfMaskTensor.Dispose();
        }

        if (crossMaskTensor != null)
        {
            crossMaskTensor.Dispose();
        }
    }

    // tgtInputs = m_decoderFFLayer.Process(tgtInputs, batchSize, g);
    return (tgtInputs, attnProbs);
}
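// A sketch (hypothetical helper) of the causal constraint tgtSelfMask encodes for the
// decoder self-attention above: target position q may attend only to positions k <= q,
// so future tokens receive a large negative additive score.
static float[,] BuildCausalMask(int seqLen)
{
    var mask = new float[seqLen, seqLen];
    for (int q = 0; q < seqLen; q++)
        for (int k = 0; k < seqLen; k++)
            mask[q, k] = (k <= q) ? 0.0f : -1e9f;
    return mask;
}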
/// <summary>
/// Scaled multi-head attention component with a skip-connected output projection
/// </summary>
/// <param name="inputQ">The input Q tensor</param>
/// <param name="inputK">The input K tensor</param>
/// <param name="inputV">The input V tensor</param>
/// <param name="keyMask">The mask for softmax</param>
/// <param name="batchSize">Batch size of the input data set</param>
/// <param name="graph">The instance of the computing graph</param>
/// <returns>Transformed output tensor</returns>
public IWeightTensor Perform(IWeightTensor inputQ, IWeightTensor inputK, IWeightTensor inputV, IWeightTensor keyMask, int batchSize, IComputeGraph graph)
{
    if (m_sharedQKV)
    {
        throw new ArgumentException($"Layer '{m_name}' is in shared QKV mode, please call another Perform function with a single input tensor.");
    }

    using (IComputeGraph g = graph.CreateSubGraph($"{m_name}_MultiHeadAttention"))
    {
        int seqLenQ = inputQ.Rows / batchSize;

        // SeqLenK must be equal to SeqLenV
        int seqLenK = inputK.Rows / batchSize;
        int seqLenV = inputV.Rows / batchSize;

        IWeightTensor inputQNorm = layerNormQ.Norm(inputQ, g);

        // Input projections
        float scale = 1.0f / (float)m_inputDim;
        IWeightTensor allQ = g.View(g.Affine(inputQNorm, Q, Qb, scale), dims: new long[] { batchSize, seqLenQ, m_multiHeadNum, m_d });
        IWeightTensor allK = g.View(g.Affine(inputK, K, Kb, scale), dims: new long[] { batchSize, seqLenK, m_multiHeadNum, m_d });
        IWeightTensor allV = g.View(g.Affine(inputV, V, Vb, scale), dims: new long[] { batchSize, seqLenV, m_multiHeadNum, m_d });

        // Multi-head attention
        IWeightTensor Qs = g.View(g.Permute(allQ, 2, 0, 1, 3), dims: new long[] { m_multiHeadNum * batchSize, seqLenQ, m_d });
        IWeightTensor Ks = g.View(g.Permute(allK, 2, 0, 3, 1), dims: new long[] { m_multiHeadNum * batchSize, m_d, seqLenK });
        IWeightTensor Vs = g.View(g.Permute(allV, 2, 0, 1, 3), dims: new long[] { m_multiHeadNum * batchSize, seqLenV, m_d });

        // Scaled softmax
        scale = 1.0f / (float)m_d;
        IWeightTensor attn = g.MulBatch(Qs, Ks, m_multiHeadNum * batchSize, scale);
        IWeightTensor softmax = g.Softmax(attn, keyMask, inPlace: true);

        IWeightTensor o = g.View(g.MulBatch(softmax, Vs, m_multiHeadNum * batchSize), dims: new long[] { m_multiHeadNum, batchSize, seqLenQ, m_d });
        IWeightTensor W = g.View(g.Permute(o, 1, 2, 0, 3), dims: new long[] { batchSize * seqLenQ, m_multiHeadNum * m_d });

        // Output projection
        IWeightTensor finalAttResults = g.Dropout(g.Affine(W, W0, b0), batchSize, m_dropoutRatio, inPlace: true);

        return graph.Add(finalAttResults, inputQ);
    }
}
/// <summary>
/// Transformer decoder
/// </summary>
/// <param name="tgtInputs">The target input tensor</param>
/// <param name="encOutputBatchFirst">The encoder output in batch-first layout</param>
/// <param name="tgtSelfMask">The self-attention mask for the target sequence</param>
/// <param name="decEncAttnMask">The cross-attention mask between decoder and encoder</param>
/// <param name="tgtDimMask">The padding mask applied to the target inputs</param>
/// <param name="batchSize">Batch size of the input data set</param>
/// <param name="g">The instance of the computing graph</param>
/// <returns>Decoded output tensor</returns>
public IWeightTensor Decode(IWeightTensor tgtInputs, IWeightTensor encOutputBatchFirst, IWeightTensor tgtSelfMask, IWeightTensor decEncAttnMask, IWeightTensor tgtDimMask, int batchSize, IComputeGraph g)
{
    int tgtSeqLen = tgtInputs.Rows / batchSize;
    int srcSeqLen = encOutputBatchFirst.Rows / batchSize;

    // Add (non-trainable) positional embeddings to the scaled input embeddings
    using (IWeightTensor posEmbedding = g.BuildPositionMatrix(tgtSeqLen, m_inputDim))
    {
        using (IWeightTensor posEmbeddingRepeat = g.RepeatRows(posEmbedding, batchSize, runGradient: false))
        {
            tgtInputs = g.AddMul(posEmbeddingRepeat, tgtInputs, (float)Math.Sqrt(m_inputDim), runGradientW1: false, runGradientW2: true);
        }
    }

    tgtInputs = g.Dropout(tgtInputs, batchSize, m_dropoutRatio, inPlace: true);

    // Broadcast both masks over all attention heads, then flatten them for the attention layers
    var tgtSelfMaskRep = g.View(tgtSelfMask, dims: new long[] { 1, batchSize, tgtSeqLen, tgtSeqLen });
    var tgtSelfMaskRepExp = g.Expand(tgtSelfMaskRep, dims: new long[] { m_multiHeadNum, batchSize, tgtSeqLen, tgtSeqLen });

    var decEncAttnMaskRep = g.View(decEncAttnMask, dims: new long[] { 1, batchSize, tgtSeqLen, srcSeqLen });
    var decEncAttnMaskRepExp = g.Expand(decEncAttnMaskRep, dims: new long[] { m_multiHeadNum, batchSize, tgtSeqLen, srcSeqLen });

    var tgtSelfMaskRepExpView = g.View(tgtSelfMaskRepExp, dims: new long[] { m_multiHeadNum * batchSize * tgtSeqLen, tgtSeqLen });
    var decEncAttnMaskRepExpView = g.View(decEncAttnMaskRepExp, dims: new long[] { m_multiHeadNum * batchSize * tgtSeqLen, srcSeqLen });

    tgtSelfMaskRep.Dispose();
    tgtSelfMaskRepExp.Dispose();
    decEncAttnMaskRep.Dispose();
    decEncAttnMaskRepExp.Dispose();

    using (IComputeGraph subg = g.CreateSubGraph($"{m_name}_Decoder"))
    {
        for (int k = 0; k < m_selfAttns.Count; k++)
        {
            tgtInputs = g.MaskFill(tgtInputs, tgtDimMask, 0.0f);

            tgtInputs = m_selfAttns[k].Perform(tgtInputs, tgtInputs, tgtInputs, tgtSelfMaskRepExpView, batchSize, subg);
            tgtInputs = m_encAttns[k].Perform(tgtInputs, encOutputBatchFirst, encOutputBatchFirst, decEncAttnMaskRepExpView, batchSize, subg);
            tgtInputs = m_posFFNs[k].Perform(tgtInputs, batchSize, subg);
        }

        tgtInputs.UnbindFromComputeGraph();
    }

    tgtInputs = layerNorm.Norm(tgtInputs, g);

    // tgtInputs = m_decoderFFLayer.Process(tgtInputs, batchSize, g);
    return tgtInputs;
}
public IWeightTensor Perform(IWeightTensor input, int batchSize, IComputeGraph graph)
{
    using IComputeGraph g = graph.CreateSubGraph($"{m_name}_PositionwiseFeedForward");
    var inputNorm = layerNorm2.Norm(input, g);

    // Feed forward
    IWeightTensor ffnResult = feedForwardLayer1.Process(inputNorm, batchSize, g);
    IWeightTensor reluFFNResult = g.Relu(ffnResult, inPlace: true);
    IWeightTensor ffn2Result = feedForwardLayer2.Process(reluFFNResult, batchSize, g);

    // Skip connection
    IWeightTensor addFFNResult = graph.Add(ffn2Result, input, inPlace: true);

    return addFFNResult;
}
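// A plain-array sketch (illustrative only) of the pre-norm position-wise feed-forward
// transform above, applied to each position independently:
// y = W2 · relu(W1 · x + b1) + b2 + x (the layer normalization of x is omitted here to
// keep the projection math in focus; the residual add of the raw input is kept).
static float[] FeedForward(float[] x, float[][] W1, float[] b1, float[][] W2, float[] b2)
{
    var h = new float[b1.Length];
    for (int i = 0; i < h.Length; i++)
    {
        float s = b1[i];
        for (int j = 0; j < x.Length; j++) s += W1[i][j] * x[j];
        h[i] = Math.Max(0.0f, s); // ReLU
    }

    var y = new float[b2.Length];
    for (int i = 0; i < y.Length; i++)
    {
        float s = b2[i];
        for (int j = 0; j < h.Length; j++) s += W2[i][j] * h[j];
        y[i] = s + x[i]; // skip connection (assumes the output dimension matches the input)
    }
    return y;
}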
/// <summary>
/// Transformer encoder
/// </summary>
/// <param name="inputs">The input tensor</param>
/// <param name="batchSize">Batch size of the input data set</param>
/// <param name="g">The instance of the computing graph</param>
/// <param name="srcSelfMask">The self-attention mask for the source sequence</param>
/// <returns>Encoded output tensor</returns>
public IWeightTensor Encode(IWeightTensor inputs, int batchSize, IComputeGraph g, IWeightTensor srcSelfMask)
{
    using (IComputeGraph subg = g.CreateSubGraph($"{m_name}_Encoder"))
    {
        for (int k = 0; k < m_encoders.Count; k++)
        {
            inputs = m_encoders[k].Perform(inputs, inputs, inputs, srcSelfMask, batchSize, subg);
            inputs = m_posFFNs[k].Perform(inputs, batchSize, subg);
        }

        inputs.UnbindFromComputeGraph();
    }

    inputs = layerNorm.Norm(inputs, g);

    return inputs;
}
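// A plain-array sketch of the per-row layer normalization used throughout (assuming the
// standard formulation; the internals of the source's LayerNorm are not shown here):
// y = gain ⊙ (x - mean(x)) / sqrt(var(x) + eps) + bias.
static float[] LayerNormRow(float[] x, float[] gain, float[] bias, float eps = 1e-5f)
{
    float mean = 0f;
    foreach (float v in x) mean += v;
    mean /= x.Length;

    float variance = 0f;
    foreach (float v in x) variance += (v - mean) * (v - mean);
    variance /= x.Length;

    float invStd = (float)(1.0 / Math.Sqrt(variance + eps));
    var y = new float[x.Length];
    for (int i = 0; i < x.Length; i++) y[i] = gain[i] * (x[i] - mean) * invStd + bias[i];
    return y;
}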
/// <summary>
/// Scaled multi-head attention component with a skip-connected output projection
/// </summary>
/// <param name="inputQ">The input Q tensor</param>
/// <param name="inputK">The input K tensor</param>
/// <param name="inputV">The input V tensor</param>
/// <param name="keyMask">The mask for softmax</param>
/// <param name="batchSize">Batch size of the input data set</param>
/// <param name="graph">The instance of the computing graph</param>
/// <returns>Transformed output tensor</returns>
public IWeightTensor Perform(IWeightTensor inputQ, IWeightTensor inputK, IWeightTensor inputV, IWeightTensor keyMask, int batchSize, IComputeGraph graph)
{
    using (IComputeGraph g = graph.CreateSubGraph($"{m_name}_MultiHeadAttention"))
    {
        int seqLenQ = inputQ.Rows / batchSize;

        // SeqLenK must be equal to SeqLenV
        int seqLenK = inputK.Rows / batchSize;
        int seqLenV = inputV.Rows / batchSize;

        IWeightTensor inputQNorm = layerNorm1.Norm(inputQ, g);
        IWeightTensor inputKNorm = (inputK == inputQ) ? inputQNorm : inputK; // layerNorm1.Norm(inputK, g);
        IWeightTensor inputVNorm = (inputK == inputV) ? inputKNorm : inputV; // layerNorm1.Norm(inputV, g);

        // Input projections
        IWeightTensor allQ = g.View(g.Affine(inputQNorm, Q, Qb), dims: new long[] { batchSize, seqLenQ, m_multiHeadNum, m_d });
        IWeightTensor allK = g.View(g.Affine(inputKNorm, K, Kb), dims: new long[] { batchSize, seqLenK, m_multiHeadNum, m_d });
        IWeightTensor allV = g.View(g.Affine(inputVNorm, V, Vb), dims: new long[] { batchSize, seqLenV, m_multiHeadNum, m_d });

        // Multi-head attention
        IWeightTensor Qs = g.View(g.Permute(allQ, 2, 0, 1, 3), dims: new long[] { m_multiHeadNum * batchSize, seqLenQ, m_d });
        IWeightTensor Ks = g.View(g.Permute(allK, 2, 0, 3, 1), dims: new long[] { m_multiHeadNum * batchSize, m_d, seqLenK });
        IWeightTensor Vs = g.View(g.Permute(allV, 2, 0, 1, 3), dims: new long[] { m_multiHeadNum * batchSize, seqLenV, m_d });

        // Scaled softmax
        float scale = 1.0f / (float)Math.Sqrt(m_d);
        IWeightTensor attn = g.MulBatch(Qs, Ks, m_multiHeadNum * batchSize, scale);
        IWeightTensor attn2 = g.View(attn, dims: new long[] { m_multiHeadNum * batchSize * seqLenQ, seqLenK });

        if (keyMask != null)
        {
            // attn2 = g.Add(attn2, mask, runGradient2: false);
            attn2 = g.MaskFill(attn2, keyMask, -1e9f);
        }

        IWeightTensor softmax = g.Softmax(attn2, inPlace: true);
        IWeightTensor softmax2 = g.View(softmax, dims: new long[] { m_multiHeadNum * batchSize, seqLenQ, seqLenK });
        IWeightTensor o = g.View(g.MulBatch(softmax2, Vs, m_multiHeadNum * batchSize), dims: new long[] { m_multiHeadNum, batchSize, seqLenQ, m_d });
        IWeightTensor W = g.View(g.Permute(o, 1, 2, 0, 3), dims: new long[] { batchSize * seqLenQ, m_multiHeadNum * m_d });

        // Output projection
        IWeightTensor finalAttResults = g.Dropout(g.Affine(W, W0, b0), batchSize, m_dropoutRatio, inPlace: true);

        return graph.Add(finalAttResults, inputQ);
    }
}
/// <summary>
/// Scaled multi-head attention component with skip-connected feed-forward layers
/// </summary>
/// <param name="input">The input tensor</param>
/// <param name="batchSize">Batch size of the input data set</param>
/// <param name="graph">The instance of the computing graph</param>
/// <returns>Transformed output tensor</returns>
public IWeightTensor Perform(IWeightTensor input, int batchSize, IComputeGraph graph)
{
    using (IComputeGraph g = graph.CreateSubGraph(m_name))
    {
        int seqLen = input.Rows / batchSize;
        IWeightTensor nInput = layerNorm1.Norm(input, g);

        // Input projections
        IWeightTensor allQ = g.View(g.Affine(nInput, Q, Qb), batchSize, seqLen, m_multiHeadNum, m_d);
        IWeightTensor allK = g.View(g.Affine(nInput, K, Kb), batchSize, seqLen, m_multiHeadNum, m_d);
        IWeightTensor allV = g.View(g.Affine(nInput, V, Vb), batchSize, seqLen, m_multiHeadNum, m_d);

        // Multi-head attention
        IWeightTensor Qs = g.View(g.Permute(allQ, 2, 0, 1, 3), m_multiHeadNum * batchSize, seqLen, m_d);
        IWeightTensor Ks = g.View(g.Permute(allK, 2, 0, 3, 1), m_multiHeadNum * batchSize, m_d, seqLen);
        IWeightTensor Vs = g.View(g.Permute(allV, 2, 0, 1, 3), m_multiHeadNum * batchSize, seqLen, m_d);

        // Scaled softmax
        float scale = 1.0f / (float)Math.Sqrt(m_d);
        IWeightTensor attn = g.MulBatch(Qs, Ks, m_multiHeadNum * batchSize, scale);
        IWeightTensor attn2 = g.View(attn, m_multiHeadNum * batchSize * seqLen, seqLen);

        IWeightTensor softmax = g.Softmax(attn2, inPlace: true);
        IWeightTensor softmax2 = g.View(softmax, m_multiHeadNum * batchSize, seqLen, seqLen);
        IWeightTensor o = g.View(g.MulBatch(softmax2, Vs, m_multiHeadNum * batchSize), m_multiHeadNum, batchSize, seqLen, m_d);
        IWeightTensor W = g.View(g.Permute(o, 1, 2, 0, 3), batchSize * seqLen, m_multiHeadNum * m_d);

        // Output projection
        IWeightTensor finalAttResults = g.Dropout(g.Affine(W, W0, b0), batchSize, m_dropoutRatio, inPlace: true);

        // Skip connection and layer normalization
        IWeightTensor normAddedAttResult = layerNorm2.AddNorm(finalAttResults, input, g);

        // Feed forward
        IWeightTensor ffnResult = feedForwardLayer1.Process(normAddedAttResult, batchSize, g);
        IWeightTensor reluFFNResult = g.Relu(ffnResult);
        IWeightTensor ffn2Result = feedForwardLayer2.Process(reluFFNResult, batchSize, g);

        // Skip connection
        IWeightTensor addFFNResult = graph.Add(ffn2Result, normAddedAttResult);

        return addFFNResult;
    }
}
/// <summary>
/// Transformer decoder
/// </summary>
/// <param name="tgtInputs">The target input tensor</param>
/// <param name="encOutputBatchFirst">The encoder output in batch-first layout</param>
/// <param name="tgtSelfMask">The self-attention mask for the target sequence</param>
/// <param name="srcTgtMask">The cross-attention mask between source and target</param>
/// <param name="batchSize">Batch size of the input data set</param>
/// <param name="g">The instance of the computing graph</param>
/// <returns>Decoded output tensor</returns>
public IWeightTensor Decode(IWeightTensor tgtInputs, IWeightTensor encOutputBatchFirst, IWeightTensor tgtSelfMask, IWeightTensor srcTgtMask, int batchSize, IComputeGraph g)
{
    using (IComputeGraph subg = g.CreateSubGraph($"{m_name}_Decoder"))
    {
        for (int k = 0; k < m_selfAttns.Count; k++)
        {
            tgtInputs = m_selfAttns[k].Perform(tgtInputs, tgtInputs, tgtInputs, tgtSelfMask, batchSize, subg);
            tgtInputs = m_encAttns[k].Perform(tgtInputs, encOutputBatchFirst, encOutputBatchFirst, srcTgtMask, batchSize, subg);
            tgtInputs = m_posFFNs[k].Perform(tgtInputs, batchSize, subg);
        }

        tgtInputs = layerNorm.Norm(tgtInputs, subg);
        tgtInputs.UnbindFromComputeGraph();
    }

    tgtInputs = m_decoderFFLayer.Process(tgtInputs, batchSize, g);

    return tgtInputs;
}
/// <summary>
/// Transformer encoder
/// </summary>
/// <param name="inputs">The input tensor</param>
/// <param name="selfMask">The self-attention mask</param>
/// <param name="dimMask">The padding mask applied to the inputs</param>
/// <param name="batchSize">Batch size of the input data set</param>
/// <param name="g">The instance of the computing graph</param>
/// <returns>Encoded output tensor</returns>
public IWeightTensor Encode(IWeightTensor inputs, IWeightTensor selfMask, IWeightTensor dimMask, int batchSize, IComputeGraph g)
{
    int seqLen = inputs.Rows / batchSize;

    // Add (non-trainable) positional embeddings to the scaled input embeddings
    using (IWeightTensor posEmbedding = g.BuildPositionMatrix(seqLen, m_inputDim))
    {
        using (IWeightTensor posEmbeddingRepeat = g.RepeatRows(posEmbedding, batchSize, runGradient: false))
        {
            inputs = g.AddMul(posEmbeddingRepeat, inputs, (float)Math.Sqrt(m_inputDim), runGradientW1: false, runGradientW2: true);
        }
    }

    inputs = g.Dropout(inputs, batchSize, m_dropoutRatio, inPlace: true);

    // Broadcast the self-attention mask over all heads, then flatten it for the attention layers
    var selfMaskRep = g.View(selfMask, dims: new long[] { 1, batchSize, seqLen, seqLen });
    var multiHeadhSelfMaskRep = g.Expand(selfMaskRep, dims: new long[] { m_multiHeadNum, batchSize, seqLen, seqLen });
    var multiHeadhSelfMaskRepView = g.View(multiHeadhSelfMaskRep, dims: new long[] { m_multiHeadNum * batchSize * seqLen, seqLen });

    selfMaskRep.Dispose();
    multiHeadhSelfMaskRep.Dispose();

    using (IComputeGraph subg = g.CreateSubGraph($"{m_name}_Encoder"))
    {
        for (int k = 0; k < m_encoders.Count; k++)
        {
            inputs = g.MaskFill(inputs, dimMask, 0.0f);

            inputs = m_encoders[k].Perform(inputs, inputs, inputs, multiHeadhSelfMaskRepView, batchSize, subg);
            inputs = m_posFFNs[k].Perform(inputs, batchSize, subg);
        }

        inputs.UnbindFromComputeGraph();
    }

    inputs = layerNorm.Norm(inputs, g);

    return inputs;
}
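// A sketch (hypothetical helper; the internals of g.BuildPositionMatrix are not shown in
// this code, so the usual sinusoidal scheme is assumed) of a position matrix:
// PE[pos, 2i] = sin(pos / 10000^(2i/dim)), PE[pos, 2i+1] = cos(pos / 10000^(2i/dim)).
static float[,] BuildSinusoidalPositions(int seqLen, int dim)
{
    var pe = new float[seqLen, dim];
    for (int pos = 0; pos < seqLen; pos++)
    {
        for (int i = 0; i < dim; i += 2)
        {
            double angle = pos / Math.Pow(10000.0, (double)i / dim);
            pe[pos, i] = (float)Math.Sin(angle);
            if (i + 1 < dim)
            {
                pe[pos, i + 1] = (float)Math.Cos(angle);
            }
        }
    }
    return pe;
}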
/// <summary>
/// Scaled multi-head attention component with a skip-connected output projection
/// </summary>
/// <param name="inputQ">The input Q tensor</param>
/// <param name="inputK">The input K tensor</param>
/// <param name="inputV">The input V tensor</param>
/// <param name="keyMask">The mask for softmax</param>
/// <param name="batchSize">Batch size of the input data set</param>
/// <param name="graph">The instance of the computing graph</param>
/// <returns>Transformed output tensor</returns>
public IWeightTensor Perform(IWeightTensor inputQ, IWeightTensor inputK, IWeightTensor inputV, IWeightTensor keyMask, int batchSize, IComputeGraph graph)
{
    using (IComputeGraph g = graph.CreateSubGraph($"{m_name}_MultiHeadAttention"))
    {
        int seqLenQ = inputQ.Rows / batchSize;

        // SeqLenK must be equal to SeqLenV
        int seqLenK = inputK.Rows / batchSize;
        int seqLenV = inputV.Rows / batchSize;

        IWeightTensor inputQNorm = layerNormQ.Norm(inputQ, g);
        if (inputK == inputQ)
        {
            inputK = inputQNorm;
        }
        if (inputV == inputQ)
        {
            inputV = inputQNorm;
        }

        // Input projections
        float scale = 1.0f;
        IWeightTensor allQ = g.View(g.Affine(inputQNorm, Q, Qb, scale), dims: new long[] { batchSize, seqLenQ, m_multiHeadNum, m_d });
        IWeightTensor allK = g.View(g.Affine(inputK, K, Kb, scale), dims: new long[] { batchSize, seqLenK, m_multiHeadNum, m_d });
        IWeightTensor allV = g.View(g.Affine(inputV, V, Vb, scale), dims: new long[] { batchSize, seqLenV, m_multiHeadNum, m_d });

        // Multi-head attention
        IWeightTensor Qs = g.View(g.AsContiguous(g.Transpose(allQ, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, seqLenQ, m_d });
        IWeightTensor Ks = g.View(g.AsContiguous(g.Transpose(g.Transpose(allK, 1, 2), 2, 3)), dims: new long[] { batchSize * m_multiHeadNum, m_d, seqLenK });
        IWeightTensor Vs = g.View(g.AsContiguous(g.Transpose(allV, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, seqLenV, m_d });

        // Scaled softmax
        scale = 1.0f / (float)Math.Sqrt(m_d);
        IWeightTensor attn = g.MulBatch(Qs, Ks, batchSize * m_multiHeadNum, scale);

        if (keyMask != null)
        {
            // Broadcast the additive mask over all heads before adding it to the attention scores
            using (var keyMaskView = g.View(keyMask, runGradient: false, dims: new long[] { batchSize, 1, seqLenQ, seqLenK }))
            {
                using (var keyMaskViewExp = g.Expand(keyMaskView, runGradient: false, dims: new long[] { batchSize, m_multiHeadNum, seqLenQ, seqLenK }))
                {
                    using (var keyMaskViewExpConti = g.AsContiguous(keyMaskViewExp, runGradient: false))
                    {
                        using (var keyMaskViewExpContiView = g.View(keyMaskViewExpConti, runGradient: false, dims: new long[] { batchSize * m_multiHeadNum, seqLenQ, seqLenK }))
                        {
                            attn = g.Add(attn, keyMaskViewExpContiView, runGradient1: true, runGradient2: false);
                        }
                    }
                }
            }
        }

        IWeightTensor softmax = g.Softmax(attn, inPlace: true);
        IWeightTensor o = g.View(g.MulBatch(softmax, Vs, batchSize * m_multiHeadNum), dims: new long[] { batchSize, m_multiHeadNum, seqLenQ, m_d });
        IWeightTensor W = g.View(g.AsContiguous(g.Transpose(o, 1, 2)), dims: new long[] { batchSize * seqLenQ, m_multiHeadNum * m_d });

        // Output projection
        IWeightTensor finalAttResults = g.Dropout(g.Affine(W, W0, b0), batchSize, m_dropoutRatio, inPlace: true);

        return graph.Add(finalAttResults, inputQ);
    }
}