Example #1
        public override torch.Tensor forward(torch.Tensor x, int hiddenSize)
        {
            // The largest hidden size is the reference width: nothing to transfer.
            if (hiddenSize == SearchSpace.HiddenSizeChoices[SearchSpace.HiddenSizeChoices.Length - 1])
            {
                return x;
            }

            // Otherwise route through the transfer module registered for this size.
            var index = SearchSpace.HiddenSizeChoices.ToList().IndexOf(hiddenSize);

            return index == -1
                ? x.alias()
                : HiddenTransfer[index].forward(x);
        }
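As a hedged illustration of what such a hidden transfer does (the types below are hypothetical stand-ins, not ML.NET's actual SearchSpace or HiddenTransfer): each candidate hidden size below the widest gets its own linear projection back to the reference width, and the widest size is the identity case.

        using System;
        using System.Linq;
        using TorchSharp;

        static class HiddenTransferSketch
        {
            // Hypothetical stand-ins for SearchSpace.HiddenSizeChoices and HiddenTransfer.
            static readonly int[] HiddenSizeChoices = { 128, 192, 256 };
            static readonly TorchSharp.Modules.Linear[] Transfers =
                HiddenSizeChoices.Take(HiddenSizeChoices.Length - 1)
                                 .Select(h => torch.nn.Linear(h, HiddenSizeChoices[^1]))
                                 .ToArray();

            static torch.Tensor Forward(torch.Tensor x, int hiddenSize)
            {
                // The widest choice needs no projection.
                if (hiddenSize == HiddenSizeChoices[^1])
                {
                    return x.alias();
                }

                var index = Array.IndexOf(HiddenSizeChoices, hiddenSize);
                return index == -1 ? x.alias() : Transfers[index].forward(x);
            }
        }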
Example #2
        private static torch.Tensor ForwardOneLayer(torch.Tensor input, torch.Tensor selfAttentionPaddingMask,
                                                    torch.nn.Module convLayer, torch.nn.Module layerNorm)
        {
            using var disposeScope = torch.NewDisposeScope();

            // Zero out padded positions so the convolution never sees them.
            torch.Tensor x = selfAttentionPaddingMask.IsNull()
                ? input.alias()
                : input.masked_fill(selfAttentionPaddingMask.T.unsqueeze(-1), 0);

            var conv = convLayer.forward(x);

            // Residual connection, then layer normalization.
            conv.add_(input);
            var norm = layerNorm.forward(conv);

            return norm.MoveToOuterDisposeScope();
        }
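Both ForwardOneLayer overloads lean on the same TorchSharp memory pattern: open a dispose scope, let intermediates be tracked by it, and promote only the result. A minimal, self-contained sketch of that pattern (the tensor ops here are arbitrary placeholders):

        using TorchSharp;

        static class DisposeScopeSketch
        {
            static torch.Tensor Demo(torch.Tensor input)
            {
                using var scope = torch.NewDisposeScope();

                var a = input.mul(2);  // tracked by the scope
                var b = a.add(1);      // tracked by the scope

                // Only b escapes; a is disposed when the scope ends.
                return b.MoveToOuterDisposeScope();
            }
        }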
        private torch.Tensor ForwardOneLayer(torch.Tensor input, torch.Tensor paddingMask,
                                             int i, int blockPerLayer, ref int blockIndex)
        {
            using var disposeScope = torch.NewDisposeScope();

            var x     = input.alias(); // alias so the input survives this dispose scope
            var layer = Layers[i];

            // Entering a new block: adapt the hidden size on the way in.
            if (i % blockPerLayer == 0)
            {
                x = ((HiddenTransfer)HiddenTransferList[blockIndex]).forward(x, HiddenSizePerBlock[blockIndex], true);
            }

            x = ((TransformerCell)layer).forward(x, null, paddingMask);

            // Leaving the block: adapt the hidden size on the way out and advance.
            if ((i + 1) % blockPerLayer == 0)
            {
                x = ((HiddenTransfer)HiddenTransferList[blockIndex]).forward(x, HiddenSizePerBlock[blockIndex], false);
                ++blockIndex;
            }

            return x.MoveToOuterDisposeScope();
        }
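A hedged sketch of how a caller might drive this overload: blockIndex only advances inside ForwardOneLayer when a block boundary is crossed, so it has to be threaded through a plain loop by ref. Layers, the field names, and the placeholder body below are assumptions for illustration, not the actual ML.NET code:

        using System.Collections.Generic;
        using TorchSharp;

        class EncoderDriverSketch
        {
            // Hypothetical stand-ins for the surrounding class's state.
            private readonly List<torch.nn.Module> Layers = new();
            private readonly int _blockPerLayer = 4;

            private torch.Tensor ForwardOneLayer(torch.Tensor x, torch.Tensor mask,
                                                 int i, int blockPerLayer, ref int blockIndex)
                => x.alias(); // placeholder for the real method above

            public torch.Tensor Forward(torch.Tensor input, torch.Tensor paddingMask)
            {
                using var scope = torch.NewDisposeScope();

                var x          = input.alias();
                var blockIndex = 0;

                for (var i = 0; i < Layers.Count; i++)
                {
                    x = ForwardOneLayer(x, paddingMask, i, _blockPerLayer, ref blockIndex);
                }

                return x.MoveToOuterDisposeScope();
            }
        }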
        public virtual torch.Tensor forward(torch.Tensor input, Dictionary<string, object> param = null)
        {
            // Identity by default; derived modules override this.
            return input.alias();
        }
        public torch.Tensor forward(
            torch.Tensor query,
            torch.Tensor key,
            torch.Tensor value,
            out torch.Tensor outAttentionWeights,
            torch.Tensor keyPaddingMask = null,
            Dictionary<string, Dictionary<string, torch.Tensor>> incrementalState = null,
            bool needWeights = true,
            bool staticKv = false,
            torch.Tensor attentionMask = null)
        {
            outAttentionWeights = null;

            if (query.IsNull() || query.size().Length != 3 || query.size(2) != _embeddingDim)
            {
                throw new ArgumentException("query must not be null and must be 3D in multi-head attention; " +
                                            "the last dimension should equal the embedding dimension.");
            }

            using var disposeScope = torch.NewDisposeScope();

            var qSize     = query.size();
            var tgtLen    = qSize[0];
            var batchSize = qSize[1];
            var embedDim  = qSize[2];

            // Get saved state from incrementalState
            Dictionary<string, torch.Tensor> savedState = null;

            if (incrementalState != null)
            {
                savedState = GetInputBuffer(incrementalState);

                // previous time steps are cached - no need to recompute key and value if they are static.
                if (savedState.ContainsKey(PrevKeyKey) && savedState.ContainsKey(PrevValueKey) && staticKv)
                {
                    if (_selfAttention || !_encoderDecoderAttention)
                    {
                        throw new ArgumentException(
                                  "prevKey and prevValue are only valid in encoder-decoder attention.");
                    }

                    key = value = null;
                }
            }

            // Compute the current q/k/v projections.
            var (q, k, v) = QkvProjection(query, key, value);

            // Work on aliases of the masks so the callers' tensors are not
            // disposed by this scope.
            torch.Tensor attentionMaskPad  = attentionMask?.alias();
            torch.Tensor keyPaddingMaskPad = keyPaddingMask?.alias();
            q.mul_(_scaling);

            if (_addBiasKv)
            {
                // Append the learned key/value bias rows, repeated per batch.
                var kRepeat = KBias.repeat(1, batchSize, 1);
                var vRepeat = VBias.repeat(1, batchSize, 1);
                k = torch.cat(new List<torch.Tensor> { k, kRepeat }, dimension: 0);
                v = torch.cat(new List<torch.Tensor> { v, vRepeat }, dimension: 0);
                attentionMaskPad  = PadMask(attentionMaskPad);
                keyPaddingMaskPad = PadMask(keyPaddingMaskPad);
            }

            q = q.view(tgtLen, batchSize * _numHeads, _headDim).transpose_(0, 1);
            k = k?.view(-1, batchSize * _numHeads, _headDim).transpose_(0, 1);
            v = v?.view(-1, batchSize * _numHeads, _headDim).transpose_(0, 1);

            if (savedState != null)
            {
                // Saved states are stored with shape (batchSize, numHeads, seqLen, headDim).
                if (savedState.ContainsKey(PrevKeyKey))
                {
                    var prevKey = savedState[PrevKeyKey].view(batchSize * _numHeads, -1, _headDim);
                    k = staticKv
                        ? prevKey
                        : torch.cat(new List<torch.Tensor> { prevKey, k }, dimension: 1);
                }

                if (savedState.ContainsKey(PrevValueKey))
                {
                    var prevValue = savedState[PrevValueKey].view(batchSize * _numHeads, -1, _headDim);
                    v = staticKv
                        ? prevValue
                        : torch.cat(new List<torch.Tensor> { prevValue, v }, dimension: 1);
                }

                // Replace the cached tensors with the updated key/value states.
                savedState[PrevKeyKey].Dispose();
                savedState[PrevKeyKey] = k?.view(batchSize, _numHeads, -1, _headDim);
                savedState[PrevValueKey].Dispose();
                savedState[PrevValueKey] = v?.view(batchSize, _numHeads, -1, _headDim);

                SetInputBuffer(incrementalState, savedState);
            }

            Debug.Assert(k.IsNotNull() && v.IsNotNull());
            var srcLen = k!.size(1);

            // This is part of a workaround to get around fork/join parallelism not supporting Optional types.
            if (keyPaddingMaskPad?.shape.Length == 0)
            {
                keyPaddingMaskPad = null;
            }
            Debug.Assert(keyPaddingMaskPad.IsNull() ||
                         (keyPaddingMaskPad.size(0) == batchSize && keyPaddingMaskPad.size(1) == srcLen));

            if (_addZeroAttention)
            {
                // Append an all-zero key/value slot so attention can "attend to nothing".
                srcLen += 1;
                var zeroPadSize = k.size();
                zeroPadSize[1] = 1;
                var kZeros = k.new_zeros(zeroPadSize);
                var vZeros = v!.new_zeros(zeroPadSize);
                k = torch.cat(new List<torch.Tensor> { k, kZeros }, dimension: 1);
                v = torch.cat(new List<torch.Tensor> { v, vZeros }, dimension: 1);
                attentionMaskPad  = PadMask(attentionMaskPad);
                keyPaddingMaskPad = PadMask(keyPaddingMaskPad);
            }

            var attentionWeights = torch.matmul(q, k.transpose(1, 2));

            Debug.Assert(attentionWeights.size().SequenceEqual(new[] { batchSize * _numHeads, tgtLen, srcLen }));

            if (attentionMaskPad.IsNotNull())
            {
                attentionWeights.add_(attentionMaskPad.unsqueeze(0));
            }

            if (keyPaddingMaskPad.IsNotNull())
            {
                // Don't attend to pad symbols
                keyPaddingMaskPad = keyPaddingMaskPad.unsqueeze(1).unsqueeze(2);

                attentionWeights = attentionWeights
                                   .view(batchSize, _numHeads, tgtLen, srcLen)
                                   .masked_fill(keyPaddingMaskPad, float.NegativeInfinity)
                                   .view(batchSize * _numHeads, tgtLen, srcLen);
            }

            attentionWeights = torch.nn.functional.softmax(attentionWeights, dim: -1);
            attentionWeights = DropoutLayer.forward(attentionWeights);

            if (needWeights)
            {
                // Average attention weights over heads
                var weightsView = attentionWeights.view(batchSize, _numHeads, tgtLen, srcLen);
                outAttentionWeights = weightsView.sum(dim: 1).div_(_numHeads);
            }

            var attention = torch.matmul(attentionWeights, v);

            Debug.Assert(attention.size().SequenceEqual(new[] { batchSize * _numHeads, tgtLen, _headDim }));
            attention = attention.transpose(0, 1).contiguous().view(tgtLen, batchSize, embedDim);
            var attentionOutput = OutProjLinear.forward(attention);

            outAttentionWeights?.MoveToOuterDisposeScope();
            return attentionOutput.MoveToOuterDisposeScope();
        }
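Stripped of incremental caching, masking, and bias handling, the core of this method is scaled dot-product attention over (batchSize * numHeads, seqLen, headDim) tensors. A minimal sketch of just that core with stock TorchSharp ops (names here are illustrative):

        using System;
        using TorchSharp;

        static class AttentionCoreSketch
        {
            // q, k, v: (batch * numHeads, seqLen, headDim)
            static torch.Tensor ScaledDotProduct(torch.Tensor q, torch.Tensor k, torch.Tensor v, long headDim)
            {
                using var scope = torch.NewDisposeScope();

                var scaling = 1.0 / Math.Sqrt(headDim);
                var weights = torch.matmul(q.mul(scaling), k.transpose(1, 2)); // (B*H, tgt, src)
                weights     = torch.nn.functional.softmax(weights, dim: -1);

                return torch.matmul(weights, v).MoveToOuterDisposeScope();     // (B*H, tgt, headDim)
            }
        }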
Example #6
        public override torch.Tensor forward(torch.Tensor x, Dictionary<string, object> param = null)
        {
            // Identity module: hand back an alias of the input.
            return x.alias();
        }
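Why alias() rather than returning x directly: the alias shares storage with the input but has its own handle and lifetime, so the caller can dispose the tensor it received without invalidating the original. A small sketch of that behavior:

        using TorchSharp;

        static class AliasSketch
        {
            static void Demo()
            {
                var x = torch.ones(2, 3);
                var y = x.alias();      // same storage, independent handle

                y.Dispose();            // disposing the alias...
                var total = x.sum();    // ...leaves the original usable
            }
        }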