Example #1
0
        /// <summary>
        /// Run forward part on given single device: encode the source batch, then hand over to the
        /// decoder-specific routine (attention LSTM or transformer).
        /// </summary>
        /// <remarks>
        /// <see cref="ParallelCorpus.PadSentences"/> is called on <paramref name="srcSnts"/> and, on the
        /// transformer path, on <paramref name="tgtSnts"/>, so callers should expect both lists to be
        /// modified by this call (presumably padded to a common length - confirm against ParallelCorpus).
        /// </remarks>
        /// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
        /// <param name="srcSnts">A batch of input tokenized sentences in source side</param>
        /// <param name="tgtSnts">A batch of output tokenized sentences in target side</param>
        /// <param name="deviceIdIdx">The index of current device</param>
        /// <param name="isTraining">
        /// When true the transformer path runs a single decode over the whole target batch and returns its cost.
        /// When false the transformer path decodes step by step (at most m_maxTgtSntSize steps) and returns 0.
        /// The flag is forwarded unchanged to the attention LSTM path.
        /// </param>
        /// <returns>
        /// The cost reported by the decoder routine, or 0 for step-by-step transformer decoding.
        /// </returns>
        /// <exception cref="System.NotSupportedException">
        /// The decoder on this device is neither an AttentionDecoder nor a TransformerDecoder.
        /// </exception>
        private float RunForwardOnSingleDevice(IComputeGraph computeGraph, List<List<string>> srcSnts, List<List<string>> tgtSnts, int deviceIdIdx, bool isTraining)
        {
            (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

            // Reset networks
            encoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);
            decoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);

            List<int> originalSrcLengths = ParallelCorpus.PadSentences(srcSnts);
            int srcSeqPaddedLen = srcSnts[0].Count;
            int batchSize = srcSnts.Count;

            IWeightTensor encSelfMask = MaskUtils.BuildPadSelfMask(computeGraph, srcSeqPaddedLen, originalSrcLengths, deviceIdIdx);
            IWeightTensor encDimMask = MaskUtils.BuildPadDimMask(computeGraph, srcSeqPaddedLen, originalSrcLengths, m_modelMetaData.HiddenDim, deviceIdIdx);

            // Encoding input source sentences
            IWeightTensor encOutput = Encode(computeGraph, srcSnts, encoder, srcEmbedding, encSelfMask, encDimMask);

            // Attention LSTM decoder: a single call handles both training and inference.
            if (decoder is AttentionDecoder attentionDecoder)
            {
                return DecodeAttentionLSTM(tgtSnts, computeGraph, encOutput, attentionDecoder, tgtEmbedding, batchSize, isTraining);
            }

            // Fail early with a clear message instead of passing a null decoder further down.
            if (!(decoder is TransformerDecoder transformerDecoder))
            {
                throw new System.NotSupportedException($"Decoder type '{decoder.GetType().Name}' is not supported by RunForwardOnSingleDevice.");
            }

            if (isTraining)
            {
                List<int> originalTgtLengths = ParallelCorpus.PadSentences(tgtSnts);
                int tgtSeqPaddedLen = tgtSnts[0].Count;
                IWeightTensor encDecMask = MaskUtils.BuildSrcTgtMask(computeGraph, srcSeqPaddedLen, tgtSeqPaddedLen, originalSrcLengths, originalTgtLengths, deviceIdIdx);

                return DecodeTransformer(tgtSnts, computeGraph, encOutput, encDecMask, transformerDecoder, tgtEmbedding, batchSize, deviceIdIdx, isTraining);
            }

            // Inference: decode one step at a time, each step in its own sub-graph that is disposed
            // when the step finishes.
            for (int i = 0; i < m_maxTgtSntSize; i++)
            {
                using (var g = computeGraph.CreateSubGraph($"TransformerDecoder_Step_{i}"))
                {
                    // The mask is rebuilt every step because the target length is re-read after each decode
                    // (DecodeTransformer is expected to extend tgtSnts during inference - confirm).
                    List<int> originalTgtLengths = ParallelCorpus.PadSentences(tgtSnts);
                    int tgtSeqPaddedLen = tgtSnts[0].Count;
                    IWeightTensor encDecMask = MaskUtils.BuildSrcTgtMask(g, srcSeqPaddedLen, tgtSeqPaddedLen, originalSrcLengths, originalTgtLengths, deviceIdIdx);

                    DecodeTransformer(tgtSnts, g, encOutput, encDecMask, transformerDecoder, tgtEmbedding, batchSize, deviceIdIdx, isTraining);

                    // Stop as soon as the last token of every sentence is the end-of-sentence marker.
                    // NOTE(review): only the LAST token is inspected, so a sentence that emitted EOS on an
                    // earlier step and then a different token counts as unfinished - confirm this is intended.
                    if (tgtSnts.TrueForAll(snt => snt[snt.Count - 1] == ParallelCorpus.EOS))
                    {
                        break;
                    }
                }
            }

            return 0.0f;
        }
Example #2
0
        /// <summary>
        /// Runs the transformer decoder over a batch of target sequences.
        /// In training mode it accumulates the negative log-likelihood of the left-shifted targets and
        /// writes the resulting values into the decoder output's gradients; otherwise it appends one
        /// predicted word to every sequence in <paramref name="tgtSeqs"/>.
        /// </summary>
        /// <param name="tgtSeqs">Target token sequences. Passed through ParallelCorpus.PadSentences, and extended by one token each when not training.</param>
        /// <param name="g">Compute graph used for every tensor operation in this call.</param>
        /// <param name="encOutputs">Encoder output; its row count divided by <paramref name="batchSize"/> is used as the source sequence length.</param>
        /// <param name="decoder">Transformer decoder to run.</param>
        /// <param name="tgtEmbedding">Target embedding table; one row is peeked per target token.</param>
        /// <param name="posEmbedding">Position embedding handed to AddPositionEmbedding.</param>
        /// <param name="batchSize">Number of sequences in the batch.</param>
        /// <param name="deviceId">Device id forwarded to the mask builders.</param>
        /// <param name="srcOriginalLenghts">Source sentence lengths forwarded to MaskUtils.BuildSrcTgtMask.</param>
        /// <param name="isTraining">True to compute cost and gradients, false to emit the next word.</param>
        /// <returns>The accumulated cost in training mode; 0 otherwise.</returns>
        private float DecodeTransformer(List<List<string>> tgtSeqs, IComputeGraph g, IWeightTensor encOutputs, TransformerDecoder decoder,
                                        IWeightTensor tgtEmbedding, IWeightTensor posEmbedding, int batchSize, int deviceId, List<int> srcOriginalLenghts, bool isTraining = true)
        {
            float totalCost = 0.0f;

            var tgtOriginalLengths = ParallelCorpus.PadSentences(tgtSeqs);
            int tgtSeqLen = tgtSeqs[0].Count;
            int srcSeqLen = encOutputs.Rows / batchSize;

            using (IWeightTensor srcTgtMask = MaskUtils.BuildSrcTgtMask(g, srcSeqLen, tgtSeqLen, tgtOriginalLengths, srcOriginalLenghts, deviceId))
            using (IWeightTensor tgtSelfTriMask = MaskUtils.BuildPadSelfTriMask(g, tgtSeqLen, tgtOriginalLengths, deviceId))
            {
                // Collect one embedding row per (sentence, position), sentence-major.
                List<IWeightTensor> embeddingRows = new List<IWeightTensor>();
                for (int sntIdx = 0; sntIdx < batchSize; sntIdx++)
                {
                    for (int pos = 0; pos < tgtSeqLen; pos++)
                    {
                        int wordId = m_modelMetaData.Vocab.GetTargetWordIndex(tgtSeqs[sntIdx][pos], logUnk: true);

                        // Positions at or beyond the original sentence length do not run gradients.
                        bool insideSentence = pos < tgtOriginalLengths[sntIdx];
                        IWeightTensor embeddingRow = g.PeekRow(tgtEmbedding, wordId, runGradients: insideSentence);

                        embeddingRows.Add(embeddingRow);
                    }
                }

                IWeightTensor tokenEmbs;
                if (embeddingRows.Count > 1)
                {
                    tokenEmbs = g.ConcatRows(embeddingRows);
                }
                else
                {
                    tokenEmbs = embeddingRows[0];
                }

                IWeightTensor decoderInput = AddPositionEmbedding(g, posEmbedding, batchSize, tgtSeqLen, tokenEmbs);
                IWeightTensor decOutput = decoder.Decode(decoderInput, encOutputs, tgtSelfTriMask, srcTgtMask, batchSize, g);

                using (IWeightTensor probs = g.Softmax(decOutput, runGradients: false, inPlace: true))
                {
                    if (isTraining)
                    {
                        var shiftedTargets = ParallelCorpus.LeftShiftSnts(tgtSeqs, ParallelCorpus.EOS);
                        for (int sntIdx = 0; sntIdx < batchSize; sntIdx++)
                        {
                            for (int pos = 0; pos < tgtSeqLen; pos++)
                            {
                                using (IWeightTensor rowProbs = g.PeekRow(probs, sntIdx * tgtSeqLen + pos, runGradients: false))
                                {
                                    if (pos >= tgtOriginalLengths[sntIdx])
                                    {
                                        // Beyond the original sentence length: clear the row.
                                        rowProbs.CleanWeight();
                                        continue;
                                    }

                                    int goldId = m_modelMetaData.Vocab.GetTargetWordIndex(shiftedTargets[sntIdx][pos], logUnk: true);
                                    float goldProb = rowProbs.GetWeightAt(goldId);

                                    totalCost += (float)-Math.Log(goldProb);

                                    // Subtract 1 at the expected token's entry before the weights are copied
                                    // to gradients below (presumably the peeked row shares storage with
                                    // probs - confirm).
                                    rowProbs.SetWeightAt(goldProb - 1, goldId);
                                }
                            }
                        }

                        decOutput.CopyWeightsToGradients(probs);
                    }
                    else
                    {
                        // Take the prediction at the last position of every sequence and append it.
                        int[] bestIds = g.Argmax(probs, 1);
                        List<string> bestWords = m_modelMetaData.Vocab.ConvertTargetIdsToString(bestIds.ToList());

                        for (int sntIdx = 0; sntIdx < batchSize; sntIdx++)
                        {
                            int lastPos = sntIdx * tgtSeqLen + tgtSeqLen - 1;
                            tgtSeqs[sntIdx].Add(bestWords[lastPos]);
                        }
                    }
                }
            }

            return totalCost;
        }