/// <summary>
/// Appends one zero-filled column to the end of a 2-D mask (concatenation along dimension 1).
/// </summary>
/// <param name="tensor">The mask to extend; may be null.</param>
/// <returns>
/// null when <paramref name="tensor"/> is null; otherwise a new tensor that is
/// <paramref name="tensor"/> followed by a column of zeros created with <c>new_zeros</c>.
/// </returns>
private static torch.Tensor PadMask(torch.Tensor tensor)
        {
            // An absent mask stays absent.
            if (tensor.IsNull())
            {
                return null;
            }

            // One zero per row; the temporary column is released once the concatenation has produced a new tensor.
            var rowCount = tensor.size(0);
            using var zeroColumn = tensor.new_zeros(rowCount, 1);

            var pieces = new List <torch.Tensor> { tensor, zeroColumn };
            return torch.cat(pieces, dimension: 1);
        }
        public override torch.Tensor forward(torch.Tensor input, Dictionary <string, object> param = null)
        {
            using var disposeScope = torch.NewDisposeScope();

            ParseArguments(param, out var incrementalState, out var positions);

            if (positions.IsNull())
            {
                positions = incrementalState
                    ? torch.tensor(PadPositionIndex + input.size(1))
                    : MakePositions(input, PadTokenIndex);
            }

            var embedding = Embedding.forward(positions);

            return(embedding.MoveToOuterDisposeScope());
        }
        public torch.Tensor forward(
            torch.Tensor query,
            torch.Tensor key,
            torch.Tensor value,
            out torch.Tensor outAttentionWeights,
            torch.Tensor keyPaddingMask = null,
            Dictionary <string, Dictionary <string, torch.Tensor> > incrementalState = null,
            bool needWeights           = true,
            bool staticKv              = false,
            torch.Tensor attentionMask = null)
        {
            outAttentionWeights = null;

            if (query.IsNull() || query.size().Length != 3 || query.size(2) != _embeddingDim)
            {
                throw new ArgumentException("query must NOT be null and must be 3D in multi-head attention;" +
                                            "the last dimension should be the same as embedding dimension.");
            }

            using var disposeScope = torch.NewDisposeScope();

            var qSize     = query.size();
            var tgtLen    = qSize[0];
            var batchSize = qSize[1];
            var embedDim  = qSize[2];

            // Get saved state from incrementalState
            Dictionary <string, torch.Tensor> savedState = null;

            if (incrementalState != null)
            {
                savedState = GetInputBuffer(incrementalState);

                // previous time steps are cached - no need to recompute key and value if they are static.
                if (savedState.ContainsKey(PrevKeyKey) && savedState.ContainsKey(PrevValueKey) && staticKv)
                {
                    if (_selfAttention || !_encoderDecoderAttention)
                    {
                        throw new ArgumentException(
                                  "prevKey and prevValue are only valid in encoder-decoder attention.");
                    }

                    key = value = null;
                }
            }

            // Calculate current qkv projection
            var(q, k, v) = QkvProjection(query, key, value);

            // Simulate using-statement by try-finally
            torch.Tensor attentionMaskPad  = attentionMask?.alias();
            torch.Tensor keyPaddingMaskPad = keyPaddingMask?.alias();
            q.mul_(_scaling);

            if (_addBiasKv)
            {
                var kRepeat = KBias.repeat(1, batchSize, 1);
                var vRepeat = VBias.repeat(1, batchSize, 1);
                k = torch.cat(new List <torch.Tensor> {
                    k, kRepeat
                }, dimension: 0);
                v = torch.cat(new List <torch.Tensor> {
                    v, vRepeat
                }, dimension: 0);
                attentionMaskPad  = PadMask(attentionMaskPad);
                keyPaddingMaskPad = PadMask(keyPaddingMaskPad);
            }

            q = q.view(tgtLen, batchSize * _numHeads, _headDim).transpose_(0, 1);
            k = k?.view(-1, batchSize * _numHeads, _headDim).transpose_(0, 1);
            v = v?.view(-1, batchSize * _numHeads, _headDim).transpose_(0, 1);

            if (savedState != null)
            {
                // saved states are stored with shape (batchSize, NumHeads, seqLen, HeadDim)
                if (savedState.ContainsKey(PrevKeyKey))
                {
                    var prevKey = savedState[PrevKeyKey].view(batchSize * _numHeads, -1, _headDim);
                    k = staticKv
                        ? prevKey
                        : torch.cat(new List <torch.Tensor> {
                        prevKey, k
                    }, dimension: 1);
                }

                if (savedState.ContainsKey(PrevValueKey))
                {
                    var prevValue = savedState[PrevValueKey].view(batchSize * _numHeads, -1, _headDim);
                    v = staticKv
                        ? prevValue
                        : torch.cat(new List <torch.Tensor> {
                        prevValue, v
                    }, dimension: 1);
                }

                savedState[PrevKeyKey].Dispose();
                savedState[PrevKeyKey] = k?.view(batchSize, _numHeads, -1, _headDim);
                savedState[PrevValueKey].Dispose();
                savedState[PrevValueKey] = v?.view(batchSize, _numHeads, -1, _headDim);

                SetInputBuffer(incrementalState, savedState);
            }

            Debug.Assert(k.IsNotNull() && v.IsNotNull());
            var srcLen = k !.size(1);

            // This is part of a workaround to get around fork/join parallelism not supporting Optional types.
            if (keyPaddingMaskPad?.shape.Length == 0)
            {
                keyPaddingMaskPad = null;
            }
            Debug.Assert(keyPaddingMaskPad.IsNull() ||
                         (keyPaddingMaskPad.size(0) == batchSize && keyPaddingMaskPad.size(1) == srcLen));

            if (_addZeroAttention)
            {
                srcLen += 1;
                var zeroPadSize = k.size();
                zeroPadSize[1] = 1;
                var kZeros = k.new_zeros(zeroPadSize);
                var vZeros = v !.new_zeros(zeroPadSize);
                k = torch.cat(new List <torch.Tensor> {
                    k, kZeros
                }, dimension: 1);
                v = torch.cat(new List <torch.Tensor> {
                    v, vZeros
                }, dimension: 1);
                attentionMaskPad  = PadMask(attentionMaskPad);
                keyPaddingMaskPad = PadMask(keyPaddingMaskPad);
            }

            var attentionWeights = torch.matmul(q, k.transpose(1, 2));

            Debug.Assert(attentionWeights.size().SequenceEqual(new[] { batchSize *_numHeads, tgtLen, srcLen }));

            if (attentionMaskPad.IsNotNull())
            {
                attentionWeights.add_(attentionMaskPad.unsqueeze(0));
            }

            if (keyPaddingMaskPad.IsNotNull())
            {
                // Don't attend to pad symbols
                keyPaddingMaskPad = keyPaddingMaskPad.unsqueeze(1).unsqueeze(2);

                attentionWeights = attentionWeights
                                   .view(batchSize, _numHeads, tgtLen, srcLen)
                                   .masked_fill(keyPaddingMaskPad, float.NegativeInfinity)
                                   .view(batchSize * _numHeads, tgtLen, srcLen);
            }

            attentionWeights = torch.nn.functional.softmax(attentionWeights, dim: -1);
            attentionWeights = DropoutLayer.forward(attentionWeights);

            if (needWeights)
            {
                // Average attention weights over heads
                var weightsView = attentionWeights.view(batchSize, _numHeads, tgtLen, srcLen);
                outAttentionWeights = weightsView.sum(dim: 1).div_(_numHeads);
            }

            var attention = torch.matmul(attentionWeights, v);

            Debug.Assert(attention.size().SequenceEqual(new[] { batchSize *_numHeads, tgtLen, _headDim }));
            attention = attention.transpose(0, 1).contiguous().view(tgtLen, batchSize, embedDim);
            var attentionOutput = OutProjLinear.forward(attention);

            outAttentionWeights?.MoveToOuterDisposeScope();
            return(attentionOutput.MoveToOuterDisposeScope());
        }