// Example #1
        /// <summary>
        /// Convert a tensor of token indices to a string.
        /// Can optionally remove BPE symbols or escape "&lt;unk&gt;" words.
        /// </summary>
        /// <param name="tensor">1-D tensor of token indices into <c>_symbols</c>, or a
        /// 2-D (batch x tokens) tensor; a 2-D tensor is rendered one sentence per line.</param>
        /// <param name="bpeSymbol">BPE continuation marker to strip; null is treated as "".</param>
        /// <param name="escapeUnk">Whether to escape "&lt;unk&gt;" words.
        /// NOTE(review): this flag is currently only forwarded to the recursive 2-D call and
        /// never consumed — confirm whether ProcessBpeSymbol should receive it.</param>
        /// <returns>The detokenized sentence(s), or <see cref="string.Empty"/> for a null tensor.</returns>
        public string Tensor2String(torch.Tensor tensor, string bpeSymbol = null, bool escapeUnk = false)
        {
            if (tensor.IsNull())
            {
                return(string.Empty);
            }
            bpeSymbol ??= "";

            List <string> subStrings;

            if (tensor.dim() == 2)
            {
                // Batched input: convert each row independently, one sentence per line.
                subStrings = Enumerable.Range(0, (int)tensor.shape[0])
                             .Select(i => Tensor2String(tensor[i], bpeSymbol, escapeUnk))
                             .ToList();
                return(string.Join("\n", subStrings));
            }

            // BUG FIX: the vocabulary must be indexed by the token id STORED in the tensor,
            // not by the loop position. The previous code (_symbols[i]) ignored the tensor
            // contents entirely and emitted the first N vocabulary entries for any input.
            // item<long>() assumes an integer-typed tensor — TODO confirm dtype at call sites.
            subStrings = Enumerable.Range(0, (int)tensor.shape[0])
                         .Select(i => _symbols[(int)tensor[i].item<long>()])
                         .ToList();
            var sentence = string.Join(" ", subStrings);

            return(ProcessBpeSymbol(sentence, bpeSymbol));
        }
        /// <summary>
        /// Appends one column of zeros to <paramref name="tensor"/> along dimension 1,
        /// keeping attention/padding masks aligned after the key/value tensors are padded.
        /// Returns null when the input is null.
        /// </summary>
        private static torch.Tensor PadMask(torch.Tensor tensor)
        {
            if (tensor.IsNull())
            {
                return(null);
            }

            // One extra zero column: shape (rows, 1), same dtype/device as the mask.
            using var padColumn = tensor.new_zeros(tensor.size(0), 1);
            var pieces = new List <torch.Tensor> {
                tensor, padColumn
            };
            return(torch.cat(pieces, dimension: 1));
        }
// Example #3
        /// <summary>
        /// Runs one convolutional layer with a residual connection: optionally zeroes
        /// padded positions, applies the convolution, adds the input back in place,
        /// then layer-normalizes the result.
        /// </summary>
        private static torch.Tensor ForwardOneLayer(torch.Tensor input, torch.Tensor selfAttentionPaddingMask,
                                                    torch.nn.Module convLayer, torch.nn.Module layerNorm)
        {
            using var disposeScope = torch.NewDisposeScope();

            torch.Tensor masked;
            if (selfAttentionPaddingMask.IsNull())
            {
                masked = input.alias();
            }
            else
            {
                // Zero out padded positions so they do not leak into the convolution.
                masked = input.masked_fill(selfAttentionPaddingMask.T.unsqueeze(-1), 0);
            }

            var convOut = convLayer.forward(masked);

            // Residual connection, added in place on the convolution output.
            convOut.add_(input);
            var normalized = layerNorm.forward(convOut);

            return(normalized.MoveToOuterDisposeScope());
        }
        /// <summary>
        /// Projects query/key/value through their linear layers according to the attention
        /// mode (self-attention, encoder-decoder attention, or generic attention).
        /// In encoder-decoder mode with a null key, k and v are returned as null — the
        /// caller then reuses the cached prevKey/prevValue from the incremental state.
        /// </summary>
        private (torch.Tensor, torch.Tensor, torch.Tensor) QkvProjection(
            torch.Tensor query, torch.Tensor key, torch.Tensor value)
        {
            using var disposeScope = torch.NewDisposeScope();

            torch.Tensor q = null;
            torch.Tensor k = null;
            torch.Tensor v = null;
            if (_selfAttention)
            {
                // Self-attention: q, k and v are all projections of the query.
                q = QProjection.forward(query);
                k = KProjection.forward(query);
                v = VProjection.forward(query);
            }
            else if (_encoderDecoderAttention)
            {
                q = QProjection.forward(query);
                if (key.IsNull())
                {
                    // Static k/v are cached by the caller; signal that with nulls.
                    k = v = null;
                }
                else
                {
                    k = KProjection.forward(key);
                    v = VProjection.forward(key);
                }
            }
            else
            {
                q = QProjection.forward(query);
                k = KProjection.forward(key);
                v = VProjection.forward(value);
            }

            // BUG FIX: k and v are legitimately null in encoder-decoder attention with a
            // null key (see above); the previous unconditional MoveToOuterDisposeScope()
            // calls threw NullReferenceException in exactly that supported case.
            return(q.MoveToOuterDisposeScope(), k?.MoveToOuterDisposeScope(), v?.MoveToOuterDisposeScope());
        }
        /// <summary>
        /// Multi-head attention forward pass. Projects query/key/value, applies optional
        /// attention and key-padding masks, softmax + dropout, and the output projection.
        /// Supports incremental decoding via <paramref name="incrementalState"/>, where
        /// previously computed keys/values are cached and extended (or reused verbatim
        /// when <paramref name="staticKv"/> is true).
        /// </summary>
        /// <param name="query">Shape (tgtLen, batchSize, embeddingDim) — enforced below.</param>
        /// <param name="key">Key input; may be null when cached static k/v are used.</param>
        /// <param name="value">Value input; may be null when cached static k/v are used.</param>
        /// <param name="outAttentionWeights">Head-averaged attention weights when
        /// <paramref name="needWeights"/> is true; otherwise null.</param>
        /// <param name="keyPaddingMask">Optional (batchSize, srcLen) mask of pad positions.</param>
        /// <param name="incrementalState">Per-module cache for incremental decoding; null disables caching.</param>
        /// <param name="needWeights">Whether to compute and return averaged attention weights.</param>
        /// <param name="staticKv">Reuse cached k/v as-is instead of recomputing/concatenating.</param>
        /// <param name="attentionMask">Optional additive mask applied to the attention logits.</param>
        /// <returns>Attention output of shape (tgtLen, batchSize, embeddingDim).</returns>
        /// <exception cref="ArgumentException">query is null, not 3-D, or has a wrong last dimension.</exception>
        public torch.Tensor forward(
            torch.Tensor query,
            torch.Tensor key,
            torch.Tensor value,
            out torch.Tensor outAttentionWeights,
            torch.Tensor keyPaddingMask = null,
            Dictionary <string, Dictionary <string, torch.Tensor> > incrementalState = null,
            bool needWeights           = true,
            bool staticKv              = false,
            torch.Tensor attentionMask = null)
        {
            outAttentionWeights = null;

            if (query.IsNull() || query.size().Length != 3 || query.size(2) != _embeddingDim)
            {
                throw new ArgumentException("query must NOT be null and must be 3D in multi-head attention;" +
                                            "the last dimension should be the same as embedding dimension.");
            }

            // All intermediates die with this scope unless explicitly moved out.
            using var disposeScope = torch.NewDisposeScope();

            var qSize     = query.size();
            var tgtLen    = qSize[0];
            var batchSize = qSize[1];
            var embedDim  = qSize[2];

            // Get saved state from incrementalState
            Dictionary <string, torch.Tensor> savedState = null;

            if (incrementalState != null)
            {
                savedState = GetInputBuffer(incrementalState);

                // previous time steps are cached - no need to recompute key and value if they are static.
                if (savedState.ContainsKey(PrevKeyKey) && savedState.ContainsKey(PrevValueKey) && staticKv)
                {
                    if (_selfAttention || !_encoderDecoderAttention)
                    {
                        throw new ArgumentException(
                                  "prevKey and prevValue are only valid in encoder-decoder attention.");
                    }

                    // Null key/value makes QkvProjection skip the k/v projections.
                    key = value = null;
                }
            }

            // Calculate current qkv projection
            var(q, k, v) = QkvProjection(query, key, value);

            // Simulate using-statement by try-finally
            // Aliases so the originals passed by the caller are never disposed here.
            torch.Tensor attentionMaskPad  = attentionMask?.alias();
            torch.Tensor keyPaddingMaskPad = keyPaddingMask?.alias();
            // In-place scaling of the query (presumably 1/sqrt(headDim) — set elsewhere).
            q.mul_(_scaling);

            if (_addBiasKv)
            {
                // Append the learned bias key/value as one extra source position per batch,
                // and pad both masks by one column to stay aligned.
                var kRepeat = KBias.repeat(1, batchSize, 1);
                var vRepeat = VBias.repeat(1, batchSize, 1);
                k = torch.cat(new List <torch.Tensor> {
                    k, kRepeat
                }, dimension: 0);
                v = torch.cat(new List <torch.Tensor> {
                    v, vRepeat
                }, dimension: 0);
                attentionMaskPad  = PadMask(attentionMaskPad);
                keyPaddingMaskPad = PadMask(keyPaddingMaskPad);
            }

            // Fold heads into the batch dimension: (len, batch, embed) -> (batch*heads, len, headDim).
            q = q.view(tgtLen, batchSize * _numHeads, _headDim).transpose_(0, 1);
            k = k?.view(-1, batchSize * _numHeads, _headDim).transpose_(0, 1);
            v = v?.view(-1, batchSize * _numHeads, _headDim).transpose_(0, 1);

            if (savedState != null)
            {
                // saved states are stored with shape (batchSize, NumHeads, seqLen, HeadDim)
                if (savedState.ContainsKey(PrevKeyKey))
                {
                    // staticKv: reuse the cache verbatim; otherwise append the new step(s).
                    var prevKey = savedState[PrevKeyKey].view(batchSize * _numHeads, -1, _headDim);
                    k = staticKv
                        ? prevKey
                        : torch.cat(new List <torch.Tensor> {
                        prevKey, k
                    }, dimension: 1);
                }

                if (savedState.ContainsKey(PrevValueKey))
                {
                    var prevValue = savedState[PrevValueKey].view(batchSize * _numHeads, -1, _headDim);
                    v = staticKv
                        ? prevValue
                        : torch.cat(new List <torch.Tensor> {
                        prevValue, v
                    }, dimension: 1);
                }

                // Replace the cache entries; the old tensors are disposed explicitly since
                // the new views must outlive this scope.
                savedState[PrevKeyKey].Dispose();
                savedState[PrevKeyKey] = k?.view(batchSize, _numHeads, -1, _headDim);
                savedState[PrevValueKey].Dispose();
                savedState[PrevValueKey] = v?.view(batchSize, _numHeads, -1, _headDim);

                SetInputBuffer(incrementalState, savedState);
            }

            Debug.Assert(k.IsNotNull() && v.IsNotNull());
            var srcLen = k !.size(1);

            // This is part of a workaround to get around fork/join parallelism not supporting Optional types.
            if (keyPaddingMaskPad?.shape.Length == 0)
            {
                keyPaddingMaskPad = null;
            }
            Debug.Assert(keyPaddingMaskPad.IsNull() ||
                         (keyPaddingMaskPad.size(0) == batchSize && keyPaddingMaskPad.size(1) == srcLen));

            if (_addZeroAttention)
            {
                // Add one all-zero source position so attention can "attend to nothing";
                // masks are padded by one column to match.
                srcLen += 1;
                var zeroPadSize = k.size();
                zeroPadSize[1] = 1;
                var kZeros = k.new_zeros(zeroPadSize);
                var vZeros = v !.new_zeros(zeroPadSize);
                k = torch.cat(new List <torch.Tensor> {
                    k, kZeros
                }, dimension: 1);
                v = torch.cat(new List <torch.Tensor> {
                    v, vZeros
                }, dimension: 1);
                attentionMaskPad  = PadMask(attentionMaskPad);
                keyPaddingMaskPad = PadMask(keyPaddingMaskPad);
            }

            // Scaled dot-product logits: (batch*heads, tgtLen, srcLen).
            var attentionWeights = torch.matmul(q, k.transpose(1, 2));

            Debug.Assert(attentionWeights.size().SequenceEqual(new[] { batchSize *_numHeads, tgtLen, srcLen }));

            if (attentionMaskPad.IsNotNull())
            {
                // Additive mask (e.g. causal), broadcast over the batch*heads dimension.
                attentionWeights.add_(attentionMaskPad.unsqueeze(0));
            }

            if (keyPaddingMaskPad.IsNotNull())
            {
                // Don't attend to pad symbols
                keyPaddingMaskPad = keyPaddingMaskPad.unsqueeze(1).unsqueeze(2);

                // Unfold heads to broadcast the (batch, 1, 1, srcLen) mask, then refold.
                attentionWeights = attentionWeights
                                   .view(batchSize, _numHeads, tgtLen, srcLen)
                                   .masked_fill(keyPaddingMaskPad, float.NegativeInfinity)
                                   .view(batchSize * _numHeads, tgtLen, srcLen);
            }

            attentionWeights = torch.nn.functional.softmax(attentionWeights, dim: -1);
            attentionWeights = DropoutLayer.forward(attentionWeights);

            if (needWeights)
            {
                // Average attention weights over heads
                var weightsView = attentionWeights.view(batchSize, _numHeads, tgtLen, srcLen);
                outAttentionWeights = weightsView.sum(dim: 1).div_(_numHeads);
            }

            // Weighted sum of values: (batch*heads, tgtLen, headDim).
            var attention = torch.matmul(attentionWeights, v);

            Debug.Assert(attention.size().SequenceEqual(new[] { batchSize *_numHeads, tgtLen, _headDim }));
            // Back to (tgtLen, batchSize, embedDim) before the final linear projection.
            attention = attention.transpose(0, 1).contiguous().view(tgtLen, batchSize, embedDim);
            var attentionOutput = OutProjLinear.forward(attention);

            // Survivors of the dispose scope: the result and (optionally) the weights.
            outAttentionWeights?.MoveToOuterDisposeScope();
            return(attentionOutput.MoveToOuterDisposeScope());
        }