private torch.Tensor ForwardEmbedding(torch.Tensor tokens, torch.Tensor segmentLabels, torch.Tensor positions)
{
    using var disposeScope = torch.NewDisposeScope();

    var x = TokenEmbedding.forward(tokens);
    if (EmbedScale != null)
    {
        x.mul_(EmbedScale);
    }

    if (PositionalEmbedding != null)
    {
        var positionalEmbedding = PositionalEmbedding.forward(tokens,
            new Dictionary<string, object> { { PositionalEmbedding.PositionKey, positions } });
        x.add_(positionalEmbedding);
    }

    if (SegmentEmbedding != null && segmentLabels.IsNotNull())
    {
        var segmentEmbedding = SegmentEmbedding.forward(segmentLabels);
        x.add_(segmentEmbedding);
    }

    if (EmbeddingLayerNorm != null)
    {
        x = EmbeddingLayerNorm.forward(x);
    }

    x = EmbedTransfer.forward(x, (int)x.size()[x.size().Length - 1]);
    x = DropoutLayer.forward(x);
    return x.MoveToOuterDisposeScope();
}
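// Illustrative sketch (not part of the original source): a hypothetical caller showing
// how ForwardEmbedding might be driven. The (batchSize, seqLen) token layout, the
// vocabSize/seqLen/batchSize constants, and the method name ForwardEmbeddingSketch are
// assumptions for demonstration only; the embedding modules used above are expected to
// be registered on this class.
private torch.Tensor ForwardEmbeddingSketch()
{
    const long vocabSize = 32000;
    const long batchSize = 8;
    const long seqLen = 128;

    using var scope = torch.NewDisposeScope();

    // Integer token ids; all-zero segment labels (single-segment input);
    // positions 0..seqLen-1 broadcast across the batch.
    var tokens = torch.randint(vocabSize, new long[] { batchSize, seqLen }, dtype: torch.ScalarType.Int64);
    var segmentLabels = torch.zeros(new long[] { batchSize, seqLen }, dtype: torch.ScalarType.Int64);
    var positions = torch.arange(seqLen, dtype: torch.ScalarType.Int64).unsqueeze(0).expand(batchSize, seqLen);

    // The returned tensor adds a trailing embedding dimension to the token layout.
    return ForwardEmbedding(tokens, segmentLabels, positions).MoveToOuterDisposeScope();
}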
public torch.Tensor forward(
    torch.Tensor query,
    torch.Tensor key,
    torch.Tensor value,
    out torch.Tensor outAttentionWeights,
    torch.Tensor keyPaddingMask = null,
    Dictionary<string, Dictionary<string, torch.Tensor>> incrementalState = null,
    bool needWeights = true,
    bool staticKv = false,
    torch.Tensor attentionMask = null)
{
    outAttentionWeights = null;
    if (query.IsNull() || query.size().Length != 3 || query.size(2) != _embeddingDim)
    {
        throw new ArgumentException("query must NOT be null and must be 3D in multi-head attention; " +
            "the last dimension should be the same as the embedding dimension.");
    }

    using var disposeScope = torch.NewDisposeScope();

    var qSize = query.size();
    var tgtLen = qSize[0];
    var batchSize = qSize[1];
    var embedDim = qSize[2];

    // Get saved state from incrementalState.
    Dictionary<string, torch.Tensor> savedState = null;
    if (incrementalState != null)
    {
        savedState = GetInputBuffer(incrementalState);

        // Previous time steps are cached - no need to recompute key and value if they are static.
        if (savedState.ContainsKey(PrevKeyKey) && savedState.ContainsKey(PrevValueKey) && staticKv)
        {
            if (_selfAttention || !_encoderDecoderAttention)
            {
                throw new ArgumentException("prevKey and prevValue are only valid in encoder-decoder attention.");
            }
            key = value = null;
        }
    }

    // Calculate the current q/k/v projections.
    var (q, k, v) = QkvProjection(query, key, value);

    // Alias the masks so the padded copies can be modified without touching the caller's tensors.
    torch.Tensor attentionMaskPad = attentionMask?.alias();
    torch.Tensor keyPaddingMaskPad = keyPaddingMask?.alias();

    // Scale queries before the dot product.
    q.mul_(_scaling);

    if (_addBiasKv)
    {
        var kRepeat = KBias.repeat(1, batchSize, 1);
        var vRepeat = VBias.repeat(1, batchSize, 1);
        k = torch.cat(new List<torch.Tensor> { k, kRepeat }, dimension: 0);
        v = torch.cat(new List<torch.Tensor> { v, vRepeat }, dimension: 0);
        attentionMaskPad = PadMask(attentionMaskPad);
        keyPaddingMaskPad = PadMask(keyPaddingMaskPad);
    }

    // Reshape to (batchSize * numHeads, seqLen, headDim) for batched per-head attention.
    q = q.view(tgtLen, batchSize * _numHeads, _headDim).transpose_(0, 1);
    k = k?.view(-1, batchSize * _numHeads, _headDim).transpose_(0, 1);
    v = v?.view(-1, batchSize * _numHeads, _headDim).transpose_(0, 1);

    if (savedState != null)
    {
        // Saved states are stored with shape (batchSize, numHeads, seqLen, headDim).
        if (savedState.ContainsKey(PrevKeyKey))
        {
            var prevKey = savedState[PrevKeyKey].view(batchSize * _numHeads, -1, _headDim);
            k = staticKv ? prevKey : torch.cat(new List<torch.Tensor> { prevKey, k }, dimension: 1);
        }
        if (savedState.ContainsKey(PrevValueKey))
        {
            var prevValue = savedState[PrevValueKey].view(batchSize * _numHeads, -1, _headDim);
            v = staticKv ? prevValue : torch.cat(new List<torch.Tensor> { prevValue, v }, dimension: 1);
        }

        savedState[PrevKeyKey].Dispose();
        savedState[PrevKeyKey] = k?.view(batchSize, _numHeads, -1, _headDim);
        savedState[PrevValueKey].Dispose();
        savedState[PrevValueKey] = v?.view(batchSize, _numHeads, -1, _headDim);
        SetInputBuffer(incrementalState, savedState);
    }

    Debug.Assert(k.IsNotNull() && v.IsNotNull());
    var srcLen = k!.size(1);

    // This is part of a workaround to get around fork/join parallelism not supporting Optional types.
    if (keyPaddingMaskPad?.shape.Length == 0)
    {
        keyPaddingMaskPad = null;
    }

    Debug.Assert(keyPaddingMaskPad.IsNull() ||
        (keyPaddingMaskPad.size(0) == batchSize && keyPaddingMaskPad.size(1) == srcLen));

    if (_addZeroAttention)
    {
        srcLen += 1;
        var zeroPadSize = k.size();
        zeroPadSize[1] = 1;
        var kZeros = k.new_zeros(zeroPadSize);
        var vZeros = v!.new_zeros(zeroPadSize);
        k = torch.cat(new List<torch.Tensor> { k, kZeros }, dimension: 1);
        v = torch.cat(new List<torch.Tensor> { v, vZeros }, dimension: 1);
        attentionMaskPad = PadMask(attentionMaskPad);
        keyPaddingMaskPad = PadMask(keyPaddingMaskPad);
    }

    // Dot-product scores: (batchSize * numHeads, tgtLen, srcLen).
    var attentionWeights = torch.matmul(q, k.transpose(1, 2));
    Debug.Assert(attentionWeights.size().SequenceEqual(new[] { batchSize * _numHeads, tgtLen, srcLen }));

    if (attentionMaskPad.IsNotNull())
    {
        attentionWeights.add_(attentionMaskPad.unsqueeze(0));
    }

    if (keyPaddingMaskPad.IsNotNull())
    {
        // Don't attend to pad symbols.
        keyPaddingMaskPad = keyPaddingMaskPad.unsqueeze(1).unsqueeze(2);
        attentionWeights = attentionWeights
            .view(batchSize, _numHeads, tgtLen, srcLen)
            .masked_fill(keyPaddingMaskPad, float.NegativeInfinity)
            .view(batchSize * _numHeads, tgtLen, srcLen);
    }

    attentionWeights = torch.nn.functional.softmax(attentionWeights, dim: -1);
    attentionWeights = DropoutLayer.forward(attentionWeights);

    if (needWeights)
    {
        // Average attention weights over heads.
        var weightsView = attentionWeights.view(batchSize, _numHeads, tgtLen, srcLen);
        outAttentionWeights = weightsView.sum(dim: 1).div_(_numHeads);
    }

    var attention = torch.matmul(attentionWeights, v);
    Debug.Assert(attention.size().SequenceEqual(new[] { batchSize * _numHeads, tgtLen, _headDim }));

    attention = attention.transpose(0, 1).contiguous().view(tgtLen, batchSize, embedDim);
    var attentionOutput = OutProjLinear.forward(attention);

    outAttentionWeights?.MoveToOuterDisposeScope();
    return attentionOutput.MoveToOuterDisposeScope();
}
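// Illustrative sketch (not part of the original source): a hypothetical self-attention
// invocation of the forward method above, showing the expected time-major
// (tgtLen, batchSize, embeddingDim) input layout and the two optional masks. The
// sequence length, batch size, causal-mask construction, and method name
// SelfAttentionSketch are assumptions for demonstration only.
private torch.Tensor SelfAttentionSketch()
{
    const long seqLen = 16;
    const long batchSize = 4;

    using var scope = torch.NewDisposeScope();

    // Time-major input, matching the shape check at the top of forward.
    var x = torch.randn(seqLen, batchSize, _embeddingDim);

    // Additive causal mask: -inf above the diagonal blocks attention to future positions.
    var causalMask = torch.full(new long[] { seqLen, seqLen }, float.NegativeInfinity).triu(1);

    // Boolean padding mask of shape (batchSize, srcLen); true marks pad positions to ignore.
    // Here nothing is padded, so the mask is all false.
    var keyPaddingMask = torch.zeros(new long[] { batchSize, seqLen }, dtype: torch.ScalarType.Bool);

    // Self-attention: query, key, and value are the same tensor.
    var output = forward(
        x, x, x,
        out var attentionWeights,
        keyPaddingMask: keyPaddingMask,
        needWeights: true,
        attentionMask: causalMask);

    // output: (tgtLen, batchSize, embeddingDim); attentionWeights: (batchSize, tgtLen, srcLen), averaged over heads.
    attentionWeights?.MoveToOuterDisposeScope();
    return output.MoveToOuterDisposeScope();
}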