Example #1
0
        /// <summary>
        /// Runs the transformer decoder over a batch of target sequences. In training mode it accumulates the
        /// negative log-likelihood of the expected next words and writes the output-layer gradients; otherwise it
        /// appends the most probable next word to every sequence in <paramref name="outInputSeqs"/>.
        /// </summary>
        /// <param name="outInputSeqs">Target-side token sequences, one list per sentence. They are passed to
        /// <c>ParallelCorpus.PadSentences</c> (presumably padded in place to a common length - confirm against that helper).
        /// When <paramref name="isTraining"/> is false, one predicted word is appended to each sequence.</param>
        /// <param name="g">Compute graph used to build every operation of this step</param>
        /// <param name="encOutputs">Encoder output handed to the decoder</param>
        /// <param name="encMask">Mask handed to the decoder together with the encoder output</param>
        /// <param name="decoder">Transformer decoder network</param>
        /// <param name="tgtEmbedding">Target embedding matrix; rows are looked up for the decoder input and its
        /// transpose is multiplied with the decoder output to produce the pre-softmax scores</param>
        /// <param name="batchSize">Number of sequences in <paramref name="outInputSeqs"/> to process</param>
        /// <param name="deviceId">Device id forwarded to the mask builders</param>
        /// <param name="isTraining">True to compute cost and gradients, false to predict the next word</param>
        /// <returns>Sum of -log(probability) over the non-padded target positions when training; 0 otherwise</returns>
        private float DecodeTransformer(List <List <string> > outInputSeqs, IComputeGraph g, IWeightTensor encOutputs, IWeightTensor encMask, TransformerDecoder decoder,
                                        IWeightTensor tgtEmbedding, int batchSize, int deviceId, bool isTraining = true)
        {
            float cost = 0.0f;

            // The padded length is read from the first sentence after padding.
            var originalInputLengths = ParallelCorpus.PadSentences(outInputSeqs);
            int tgtSeqLen            = outInputSeqs[0].Count;

            IWeightTensor tgtDimMask = MaskUtils.BuildPadDimMask(g, tgtSeqLen, originalInputLengths, m_modelMetaData.HiddenDim, deviceId);

            using (IWeightTensor tgtSelfTriMask = MaskUtils.BuildPadSelfTriMask(g, tgtSeqLen, originalInputLengths, deviceId))
            {
                // One embedding row per (sentence, position), laid out sentence by sentence, so the row for
                // sentence i / position j ends up at index i * tgtSeqLen + j. The final size is known up front.
                List <IWeightTensor> inputs = new List <IWeightTensor>(batchSize * tgtSeqLen);
                for (int i = 0; i < batchSize; i++)
                {
                    for (int j = 0; j < tgtSeqLen; j++)
                    {
                        int ix_targets_k = m_modelMetaData.Vocab.GetTargetWordIndex(outInputSeqs[i][j], logUnk: true);
                        inputs.Add(g.PeekRow(tgtEmbedding, ix_targets_k));
                    }
                }

                IWeightTensor tgtInputEmbeddings = inputs.Count > 1 ? g.ConcatRows(inputs) : inputs[0];
                IWeightTensor decOutput          = decoder.Decode(tgtInputEmbeddings, encOutputs, tgtSelfTriMask, encMask, tgtDimMask, batchSize, g);

                // Project the decoder output onto the vocabulary by reusing the transposed target embedding.
                decOutput = g.Mul(decOutput, g.Transpose(tgtEmbedding));

                using (IWeightTensor probs = g.Softmax(decOutput, runGradients: false, inPlace: true))
                {
                    if (isTraining)
                    {
                        // The expected output at position j is the input shifted left by one, terminated with EOS.
                        var leftShiftInputSeqs    = ParallelCorpus.LeftShiftSnts(outInputSeqs, ParallelCorpus.EOS);
                        var originalOutputLengths = ParallelCorpus.PadSentences(leftShiftInputSeqs, tgtSeqLen);

                        for (int i = 0; i < batchSize; i++)
                        {
                            for (int j = 0; j < tgtSeqLen; j++)
                            {
                                // NOTE(review): the writes below are expected to show up in "probs", i.e. PeekRow
                                // presumably returns a view rather than a copy - confirm.
                                using (IWeightTensor probs_i_j = g.PeekRow(probs, i * tgtSeqLen + j, runGradients: false))
                                {
                                    if (j < originalOutputLengths[i])
                                    {
                                        int   ix_targets_i_j = m_modelMetaData.Vocab.GetTargetWordIndex(leftShiftInputSeqs[i][j], logUnk: true);
                                        float score_i_j      = probs_i_j.GetWeightAt(ix_targets_i_j);

                                        cost += (float)-Math.Log(score_i_j);

                                        // Softmax + cross-entropy gradient: p - 1 at the expected word, p elsewhere.
                                        probs_i_j.SetWeightAt(score_i_j - 1, ix_targets_i_j);
                                    }
                                    else
                                    {
                                        // Padded position: presumably clears the row so it adds no gradient - confirm.
                                        probs_i_j.CleanWeight();
                                    }
                                }
                            }
                        }

                        decOutput.CopyWeightsToGradients(probs);
                    }
                    else
                    {
                        // Output "i"th target word
                        int[]         targetIdx   = g.Argmax(probs, 1);
                        List <string> targetWords = m_modelMetaData.Vocab.ConvertTargetIdsToString(targetIdx.ToList());

                        // Only the prediction at the last position of each sentence is the new word.
                        for (int i = 0; i < batchSize; i++)
                        {
                            outInputSeqs[i].Add(targetWords[i * tgtSeqLen + tgtSeqLen - 1]);
                        }
                    }
                }
            }


            return(cost);
        }
Example #2
0
        /// <summary>
        /// Run forward part on given single device
        /// </summary>
        /// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
        /// <param name="srcSnts">A batch of input tokenized sentences in source side</param>
        /// <param name="tgtSnts">A batch of output tokenized sentences in target side</param>
        /// <param name="deviceIdIdx">The index of current device</param>
        /// <param name="isTraining">True to run a single training pass; false to decode. With a transformer decoder,
        /// decoding runs step by step for at most m_maxTgtSntSize steps</param>
        /// <returns>The cost of forward part. The step-by-step transformer decoding path always returns 0</returns>
        private float RunForwardOnSingleDevice(IComputeGraph computeGraph, List <List <string> > srcSnts, List <List <string> > tgtSnts, int deviceIdIdx, bool isTraining)
        {
            (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

            // Reset networks
            encoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);
            decoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);


            List <int> originalSrcLengths = ParallelCorpus.PadSentences(srcSnts);
            int        srcSeqPaddedLen    = srcSnts[0].Count;
            int        batchSize          = srcSnts.Count;

            IWeightTensor encSelfMask = MaskUtils.BuildPadSelfMask(computeGraph, srcSeqPaddedLen, originalSrcLengths, deviceIdIdx);
            IWeightTensor encDimMask  = MaskUtils.BuildPadDimMask(computeGraph, srcSeqPaddedLen, originalSrcLengths, m_modelMetaData.HiddenDim, deviceIdIdx);

            // Encoding input source sentences
            IWeightTensor encOutput = Encode(computeGraph, srcSnts, encoder, srcEmbedding, encSelfMask, encDimMask);

            // Generate output decoder sentences
            if (decoder is AttentionDecoder attentionDecoder)
            {
                return(DecodeAttentionLSTM(tgtSnts, computeGraph, encOutput, attentionDecoder, tgtEmbedding, batchSize, isTraining));
            }

            // Any other decoder must be a transformer decoder. The direct cast fails right here with an
            // InvalidCastException for an unexpected decoder type, instead of handing a null decoder to
            // DecodeTransformer and failing later with a NullReferenceException.
            TransformerDecoder transformerDecoder = (TransformerDecoder)decoder;

            if (isTraining)
            {
                List <int>    originalTgtLengths = ParallelCorpus.PadSentences(tgtSnts);
                int           tgtSeqPaddedLen    = tgtSnts[0].Count;
                IWeightTensor encDecMask         = MaskUtils.BuildSrcTgtMask(computeGraph, srcSeqPaddedLen, tgtSeqPaddedLen, originalSrcLengths, originalTgtLengths, deviceIdIdx);

                return(DecodeTransformer(tgtSnts, computeGraph, encOutput, encDecMask, transformerDecoder, tgtEmbedding, batchSize, deviceIdIdx, isTraining));
            }

            // Decoding: each step runs in its own sub-graph and appends one word to every target sentence.
            for (int i = 0; i < m_maxTgtSntSize; i++)
            {
                using (var g = computeGraph.CreateSubGraph($"TransformerDecoder_Step_{i}"))
                {
                    // The target sentences grow every step, so lengths and the source/target mask are rebuilt.
                    List <int>    originalTgtLengths = ParallelCorpus.PadSentences(tgtSnts);
                    int           tgtSeqPaddedLen    = tgtSnts[0].Count;
                    IWeightTensor encDecMask         = MaskUtils.BuildSrcTgtMask(g, srcSeqPaddedLen, tgtSeqPaddedLen, originalSrcLengths, originalTgtLengths, deviceIdIdx);

                    DecodeTransformer(tgtSnts, g, encOutput, encDecMask, transformerDecoder, tgtEmbedding, batchSize, deviceIdIdx, isTraining);

                    // Stop early once the word just produced is EOS for every sentence.
                    bool allSntsEnd = true;
                    for (int j = 0; j < tgtSnts.Count; j++)
                    {
                        if (tgtSnts[j][tgtSnts[j].Count - 1] != ParallelCorpus.EOS)
                        {
                            allSntsEnd = false;
                            break;
                        }
                    }

                    if (allSntsEnd)
                    {
                        break;
                    }
                }
            }

            return(0.0f);
        }