示例#1
0
        /// <summary>
        /// Run the forward part for one mini-batch on a single device: reset the networks,
        /// pad and encode the source sentences, then decode with the decoder held by that device.
        /// </summary>
        /// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
        /// <param name="srcSnts">A batch of input tokenized sentences in source side. Must contain at least one sentence, because the first one is read to get the padded length.</param>
        /// <param name="tgtSnts">A batch of output tokenized sentences in target side</param>
        /// <param name="deviceIdIdx">The index of current device; used to look up the networks and the entry in DeviceIds</param>
        /// <param name="isTraining">
        /// Forwarded to the decode routines. For a non-attention decoder, <c>true</c> runs a single
        /// DecodeTransformer call and returns its result, while <c>false</c> decodes step by step
        /// (at most m_maxTgtSntSize steps) until every target sentence ends with ParallelCorpus.EOS.
        /// </param>
        /// <returns>
        /// The value returned by the decode routine (the cost of forward part), or 0.0f when a
        /// non-attention decoder is run with <paramref name="isTraining"/> set to <c>false</c>.
        /// </returns>
        private float RunForwardOnSingleDevice(IComputeGraph computeGraph, List <List <string> > srcSnts, List <List <string> > tgtSnts, int deviceIdIdx, bool isTraining)
        {
            (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding, IWeightTensor posEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

            // Reset networks
            encoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);
            decoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);

            // PadSentences returns the pre-padding lengths; the padded length is then read from the
            // first sentence (presumably the sentences are padded in place - confirm in ParallelCorpus).
            List<int> originalSrcLengths = ParallelCorpus.PadSentences(srcSnts);
            int srcSeqPaddedLen = srcSnts[0].Count;
            int batchSize = srcSnts.Count;

            // No self mask is built for ShuffleEnums.NoPaddingInSrc: per the original note, source
            // sentences in a single mini-batch then all have the same length.
            IWeightTensor srcSelfMask = null;
            if (m_shuffleType != ShuffleEnums.NoPaddingInSrc)
            {
                srcSelfMask = MaskUtils.BuildPadSelfMask(computeGraph, srcSeqPaddedLen, originalSrcLengths, DeviceIds[deviceIdIdx]);
            }

            // Encoding input source sentences
            IWeightTensor encOutput;
            try
            {
                encOutput = Encode(computeGraph, srcSnts, encoder, srcEmbedding, srcSelfMask, posEmbedding, originalSrcLengths);
            }
            finally
            {
                // The mask is only needed by the encoder; release it even when encoding throws.
                srcSelfMask?.Dispose();
            }

            // Generate output decoder sentences
            if (decoder is AttentionDecoder attentionDecoder)
            {
                return DecodeAttentionLSTM(tgtSnts, computeGraph, encOutput, attentionDecoder, tgtEmbedding, srcSnts.Count, isTraining);
            }

            // Any other decoder is treated as a transformer decoder. The explicit cast fails fast
            // with InvalidCastException instead of silently handing null to DecodeTransformer.
            var transformerDecoder = (TransformerDecoder)decoder;

            if (isTraining)
            {
                return DecodeTransformer(tgtSnts, computeGraph, encOutput, transformerDecoder, tgtEmbedding, posEmbedding, batchSize, DeviceIds[deviceIdIdx], originalSrcLengths, isTraining);
            }

            // Inference: one DecodeTransformer call per step, each inside its own sub graph.
            for (int step = 0; step < m_maxTgtSntSize; step++)
            {
                using (var stepGraph = computeGraph.CreateSubGraph($"TransformerDecoder_Step_{step}"))
                {
                    DecodeTransformer(tgtSnts, stepGraph, encOutput, transformerDecoder, tgtEmbedding, posEmbedding, batchSize, DeviceIds[deviceIdIdx], originalSrcLengths, isTraining);

                    // Stop as soon as the last token of every target sentence is EOS.
                    if (tgtSnts.TrueForAll(snt => snt[snt.Count - 1] == ParallelCorpus.EOS))
                    {
                        break;
                    }
                }
            }

            RemoveDuplicatedEOS(tgtSnts);
            return 0.0f;
        }
示例#2
0
        /// <summary>
        /// Run the forward part for one mini-batch on a single device: reset the networks,
        /// pad the source sentences, build the encoder masks, encode, then decode with the
        /// decoder held by that device.
        /// </summary>
        /// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
        /// <param name="srcSnts">A batch of input tokenized sentences in source side. Must contain at least one sentence, because the first one is read to get the padded length.</param>
        /// <param name="tgtSnts">A batch of output tokenized sentences in target side. For a non-attention decoder it is padded via ParallelCorpus.PadSentences before each decode call.</param>
        /// <param name="deviceIdIdx">The index of current device; also passed as-is to the mask builders and DecodeTransformer</param>
        /// <param name="isTraining">
        /// Forwarded to the decode routines. For a non-attention decoder, <c>true</c> runs a single
        /// DecodeTransformer call and returns its result, while <c>false</c> decodes step by step
        /// (at most m_maxTgtSntSize steps) until every target sentence ends with ParallelCorpus.EOS.
        /// </param>
        /// <returns>
        /// The value returned by the decode routine (the cost of forward part), or 0.0f when a
        /// non-attention decoder is run with <paramref name="isTraining"/> set to <c>false</c>.
        /// </returns>
        private float RunForwardOnSingleDevice(IComputeGraph computeGraph, List <List <string> > srcSnts, List <List <string> > tgtSnts, int deviceIdIdx, bool isTraining)
        {
            (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

            // Reset networks
            encoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);
            decoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);

            // PadSentences returns the pre-padding lengths; the padded length is then read from the
            // first sentence (presumably the sentences are padded in place - confirm in ParallelCorpus).
            List<int> originalSrcLengths = ParallelCorpus.PadSentences(srcSnts);
            int srcSeqPaddedLen = srcSnts[0].Count;
            int batchSize = srcSnts.Count;

            IWeightTensor encSelfMask = MaskUtils.BuildPadSelfMask(computeGraph, srcSeqPaddedLen, originalSrcLengths, deviceIdIdx);
            IWeightTensor encDimMask = MaskUtils.BuildPadDimMask(computeGraph, srcSeqPaddedLen, originalSrcLengths, m_modelMetaData.HiddenDim, deviceIdIdx);

            // Encoding input source sentences
            IWeightTensor encOutput = Encode(computeGraph, srcSnts, encoder, srcEmbedding, encSelfMask, encDimMask);

            // Generate output decoder sentences
            if (decoder is AttentionDecoder attentionDecoder)
            {
                return DecodeAttentionLSTM(tgtSnts, computeGraph, encOutput, attentionDecoder, tgtEmbedding, srcSnts.Count, isTraining);
            }

            // Any other decoder is treated as a transformer decoder. The explicit cast fails fast
            // with InvalidCastException instead of silently handing null to DecodeTransformer.
            var transformerDecoder = (TransformerDecoder)decoder;

            if (isTraining)
            {
                List<int> originalTgtLengths = ParallelCorpus.PadSentences(tgtSnts);
                int tgtSeqPaddedLen = tgtSnts[0].Count;
                IWeightTensor encDecMask = MaskUtils.BuildSrcTgtMask(computeGraph, srcSeqPaddedLen, tgtSeqPaddedLen, originalSrcLengths, originalTgtLengths, deviceIdIdx);

                return DecodeTransformer(tgtSnts, computeGraph, encOutput, encDecMask, transformerDecoder, tgtEmbedding, batchSize, deviceIdIdx, isTraining);
            }

            // Inference: one DecodeTransformer call per step, each inside its own sub graph.
            for (int step = 0; step < m_maxTgtSntSize; step++)
            {
                using (var stepGraph = computeGraph.CreateSubGraph($"TransformerDecoder_Step_{step}"))
                {
                    // The target batch is re-padded and the source/target mask rebuilt on every step.
                    List<int> originalTgtLengths = ParallelCorpus.PadSentences(tgtSnts);
                    int tgtSeqPaddedLen = tgtSnts[0].Count;
                    IWeightTensor encDecMask = MaskUtils.BuildSrcTgtMask(stepGraph, srcSeqPaddedLen, tgtSeqPaddedLen, originalSrcLengths, originalTgtLengths, deviceIdIdx);

                    DecodeTransformer(tgtSnts, stepGraph, encOutput, encDecMask, transformerDecoder, tgtEmbedding, batchSize, deviceIdIdx, isTraining);

                    // Stop as soon as the last token of every target sentence is EOS.
                    if (tgtSnts.TrueForAll(snt => snt[snt.Count - 1] == ParallelCorpus.EOS))
                    {
                        break;
                    }
                }
            }

            return 0.0f;
        }
示例#3
0
        /// <summary>
        /// Run forward part on given single device
        /// </summary>
        /// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
        /// <param name="srcSnts">A batch of input tokenized sentences in source side</param>
        /// <param name="tgtSnts">A batch of output tokenized sentences in target side</param>
        /// <param name="deviceIdIdx">The index of current device</param>
        /// <returns>The cost of forward part</returns>
        private float RunForwardOnSingleDevice(IComputeGraph computeGraph, List <List <string> > srcSnts, List <List <string> > tgtSnts, int deviceIdIdx, bool isTraining)
        {
            var(encoder, decoder, srcEmbedding, tgtEmbedding, posEmbedding) = this.GetNetworksOnDeviceAt(deviceIdIdx);

            // Reset networks
            encoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);
            decoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);


            var originalSrcLengths = ParallelCorpus.PadSentences(srcSnts);
            var srcSeqPaddedLen    = srcSnts[0].Count;
            var batchSize          = srcSnts.Count;
            var srcSelfMask        = this.m_shuffleType == ShuffleEnums.NoPaddingInSrc ? null : MaskUtils.BuildPadSelfMask(computeGraph, srcSeqPaddedLen, originalSrcLengths, this.DeviceIds[deviceIdIdx]); // The length of source sentences are same in a single mini-batch, so we don't have source mask.

            // Encoding input source sentences
            var encOutput = this.Encode(computeGraph, srcSnts, encoder, srcEmbedding, srcSelfMask, posEmbedding, originalSrcLengths);

            srcSelfMask?.Dispose();

            // Generate output decoder sentences
            if (decoder is AttentionDecoder)
            {
                return(this.DecodeAttentionLSTM(tgtSnts, computeGraph, encOutput, decoder as AttentionDecoder, tgtEmbedding, srcSnts.Count, isTraining));
            }
            else
            {
                if (isTraining)
                {
                    return(this.DecodeTransformer(tgtSnts, computeGraph, encOutput, decoder as TransformerDecoder, tgtEmbedding, posEmbedding, batchSize, this.DeviceIds[deviceIdIdx], originalSrcLengths, isTraining));
                }
                else
                {
                    for (var i = 0; i < this.m_maxTgtSntSize; i++)
                    {
                        using (var g = computeGraph.CreateSubGraph($"TransformerDecoder_Step_{i}"))
                        {
                            this.DecodeTransformer(tgtSnts, g, encOutput, decoder as TransformerDecoder, tgtEmbedding, posEmbedding, batchSize, this.DeviceIds[deviceIdIdx], originalSrcLengths, isTraining);

                            var allSntsEnd = true;
                            foreach (var t in tgtSnts)
                            {
                                if (t[^ 1] != ParallelCorpus.EOS)