예제 #1
0
        /// <summary>
        /// Encodes the source tokens of one token group and returns the encoder output rows selected
        /// at the first position of every sentence in the batch.
        /// </summary>
        /// <param name="computeGraph">Compute graph used to run the encoder and the final IndexSelect</param>
        /// <param name="sntPairBatch">Batch that supplies the source tokens of the group and the batch size</param>
        /// <param name="shuffleType">Forwarded unchanged to InnerRunner</param>
        /// <param name="encoder">Encoder network, forwarded to InnerRunner</param>
        /// <param name="modelMetaData">Model metadata; its source vocabulary maps tokens to ids</param>
        /// <param name="srcEmbedding">Source embedding tensor, forwarded to InnerRunner</param>
        /// <param name="posEmbedding">Positional embedding tensor, forwarded to InnerRunner</param>
        /// <param name="segmentEmbedding">Segment embedding tensor, forwarded to InnerRunner</param>
        /// <param name="groupId">Index of the source token group to encode</param>
        /// <returns>Encoder output rows at index (sentence index * padded length), one per sentence</returns>
        public static IWeightTensor BuildTensorForSourceTokenGroupAt(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, ShuffleEnums shuffleType, IEncoder encoder, IModel modelMetaData, IWeightTensor srcEmbedding, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding, int groupId)
        {
            var groupTokens = InsertCLSToken(sntPairBatch.GetSrcTokens(groupId));

            // Order matters: padding runs before the id lookup, so ids are taken from the sentences as
            // PadSentences left them. NOTE(review): PadSentences presumably pads in place and returns the
            // pre-padding lengths - confirm against its implementation.
            var unpaddedLengths = BuildInTokens.PadSentences(groupTokens);
            var groupTokenIds = modelMetaData.SrcVocab.GetWordIndex(groupTokens);

            IWeightTensor encoded = InnerRunner(computeGraph, groupTokenIds, unpaddedLengths, shuffleType, encoder, modelMetaData, srcEmbedding, posEmbedding, segmentEmbedding);

            int paddedLen = groupTokens[0].Count;

            // Row index of the first token of each sentence.
            // Assumes the encoder output is flattened batch-major with paddedLen rows per sentence - TODO confirm.
            float[] firstTokenRows = new float[sntPairBatch.BatchSize];
            for (int sentIdx = 0; sentIdx < firstTokenRows.Length; sentIdx++)
            {
                firstTokenRows[sentIdx] = sentIdx * paddedLen;
            }

            return computeGraph.IndexSelect(encoded, firstTokenRows);
        }
예제 #2
0
        /// <summary>
        /// Runs the forward pass on a single device: encodes the source sentences, selects the encoder
        /// output at each sentence's CLS token and passes it through every classification feed-forward layer.
        /// </summary>
        /// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
        /// <param name="sntPairBatch">Batch providing the source tokens (group 0) and, when training, one target token group per feed-forward layer</param>
        /// <param name="deviceIdIdx">The index of current device</param>
        /// <param name="isTraining">True to accumulate the cost and write gradients; false to output the arg-max label of each sentence</param>
        /// <param name="decodingOptions">Not read by this method</param>
        /// <returns>One result per feed-forward layer: Cost is set when training, Output holds the predicted label per sentence otherwise</returns>
        public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions)
        {
            (IEncoder encoder, IWeightTensor srcEmbedding, List<IFeedForwardLayer> encoderFFLayer, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

            // Padding must happen before the id lookup so that ids are taken from the padded sentences.
            var sourceSentences = sntPairBatch.GetSrcTokens(0);
            var unpaddedSrcLengths = BuildInTokens.PadSentences(sourceSentences);
            var sourceTokenIds = m_modelMetaData.SrcVocab.GetWordIndex(sourceSentences);

            IWeightTensor encoded = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbedding, segmentEmbedding, sourceTokenIds, unpaddedSrcLengths);

            int paddedLen = sourceSentences[0].Count;
            int batchSize = sourceSentences.Count;

            // Flattened row of the first CLS token in each sentence; stays 0 when a sentence has no CLS token.
            float[] clsRows = new float[batchSize];
            for (int row = 0; row < batchSize; row++)
            {
                var sentence = sourceSentences[row];

                int col = 0;
                while (col < sentence.Count && sentence[col] != BuildInTokens.CLS)
                {
                    col++;
                }

                if (col < sentence.Count)
                {
                    clsRows[row] = row * paddedLen + col;
                }
            }

            IWeightTensor clsOutput = computeGraph.IndexSelect(encoded, clsRows);

            var results = new List<NetworkResult>();
            for (int layerIdx = 0; layerIdx < m_encoderFFLayer.Length; layerIdx++)
            {
                var layerResult = new NetworkResult
                {
                    Output = new List<List<List<string>>>()
                };

                IWeightTensor layerOutput = encoderFFLayer[layerIdx].Process(clsOutput, batchSize, computeGraph);
                using (IWeightTensor probs = computeGraph.Softmax(layerOutput, runGradients: false, inPlace: true))
                {
                    if (!isTraining)
                    {
                        // Inference: take the highest-probability class of each sentence and map it back to a label.
                        using var predictedIdxTensor = computeGraph.Argmax(probs, 1);
                        float[] predictedIdx = predictedIdxTensor.ToWeightArray();
                        List<string> predictedLabels = m_modelMetaData.ClsVocabs[layerIdx].ConvertIdsToString(predictedIdx.ToList());

                        var labelsPerSentence = new List<List<string>>();
                        for (int k = 0; k < batchSize; k++)
                        {
                            labelsPerSentence.Add(new List<string> { predictedLabels[k] });
                        }

                        layerResult.Output.Add(labelsPerSentence);
                    }
                    else
                    {
                        // Training: sum -log(p_gold) over the batch and overwrite p_gold with (p_gold - 1),
                        // then hand the modified probabilities to the layer output as its gradients.
                        float totalCost = 0.0f;
                        var goldLabels = sntPairBatch.GetTgtTokens(layerIdx);
                        for (int k = 0; k < batchSize; k++)
                        {
                            int goldIdx = m_modelMetaData.ClsVocabs[layerIdx].GetWordIndex(goldLabels[k][0]);
                            float goldProb = probs.GetWeightAt(new long[] { k, goldIdx });
                            totalCost += (float)-Math.Log(goldProb);
                            probs.SetWeightAt(goldProb - 1, new long[] { k, goldIdx });
                        }

                        layerOutput.CopyWeightsToGradients(probs);

                        layerResult.Cost = totalCost / batchSize;
                    }
                }

                results.Add(layerResult);
            }

            return results;
        }
예제 #3
0
        /// <summary>
        /// Create input embedding from token embeddings, optionally scaled, plus optional segment
        /// embeddings and optional tag embeddings.
        /// </summary>
        /// <param name="seqs">Token ids of each sequence in the batch. Every sequence is read up to the length of the first one, so the batch must already share a common length</param>
        /// <param name="g">Compute graph used for the IndexSelect, Mul and Add operations</param>
        /// <param name="embeddingsTensor">Embedding table indexed by token id; tag embeddings are selected from the same table</param>
        /// <param name="segmentEmbedding">Segment embedding table indexed by segment number, or null to skip segment embeddings</param>
        /// <param name="vocab">Vocabulary used to turn token ids back into strings so SEP and tag tokens can be recognised</param>
        /// <param name="scaleFactor">Multiplier applied to the token embeddings; no multiplication is issued when it is exactly 1.0</param>
        /// <param name="enableTagEmbedding">When true, every non-tag token located between an opening tag and its closing tag also receives the embedding of each enclosing opening tag</param>
        /// <returns>The embedding tensor. shape: (batchsize * seqLen, embedding_dim) </returns>
        public static IWeightTensor CreateTokensEmbeddings(List<List<int>> seqs, IComputeGraph g, IWeightTensor embeddingsTensor,
                                                           IWeightTensor segmentEmbedding, Vocab vocab, float scaleFactor = 1.0f, bool enableTagEmbedding = false)
        {
            int batchSize = seqs.Count;
            int seqLen = seqs[0].Count;

            float[] idxs = new float[batchSize * seqLen];
            float[] segIdxs = new float[batchSize * seqLen];

            // One index array per tag nesting level, each spanning the whole batch. Entries stay -1
            // wherever no tag of that level encloses the token.
            List<float[]> tagIdxsList = new List<float[]>();

            for (int i = 0; i < batchSize; i++)
            {
                int segIdx = 0;

                // Stack of currently open tags for this sequence: currTagIdxs[level] is the token id of
                // the opening tag at that nesting level, currTagLevel is the number of open tags.
                List<int> currTagIdxs = new List<int>();
                int currTagLevel = 0;

                for (int j = 0; j < seqLen; j++)
                {
                    int flatIdx = i * seqLen + j;
                    idxs[flatIdx] = seqs[i][j];
                    segIdxs[flatIdx] = segIdx;

                    string token = vocab.GetString(seqs[i][j]);
                    if (token == BuildInTokens.SEP)
                    {
                        // The SEP token itself keeps the current segment; the next token starts a new one.
                        segIdx++;
                    }

                    if (!enableTagEmbedding)
                    {
                        continue;
                    }

                    // Ordinal comparison: tag detection is about the literal '<' and '>' characters and
                    // must not depend on the current culture.
                    bool isTagToken = token.StartsWith("<", StringComparison.Ordinal)
                                      && token.EndsWith(">", StringComparison.Ordinal)
                                      && !BuildInTokens.IsPreDefinedToken(token);

                    if (!isTagToken)
                    {
                        // Regular token: record every enclosing opening tag, one per nesting level.
                        for (int k = 0; k < currTagLevel; k++)
                        {
                            tagIdxsList[k][flatIdx] = currTagIdxs[k];
                        }

                        continue;
                    }

                    if (token[1] == '/')
                    {
                        // Closing tag: pop one nesting level. A closing tag with no tag open is ignored,
                        // because popping would make the level negative and index currTagIdxs out of range.
                        if (currTagLevel > 0)
                        {
                            currTagLevel--;
                            currTagIdxs[currTagLevel] = -1;
                        }
                    }
                    else
                    {
                        // Opening tag: make sure storage exists for this nesting level, then push it.
                        while (tagIdxsList.Count <= currTagLevel)
                        {
                            float[] tagIdxs = new float[batchSize * seqLen];
                            Array.Fill(tagIdxs, -1.0f);
                            tagIdxsList.Add(tagIdxs);
                        }

                        while (currTagIdxs.Count <= currTagLevel)
                        {
                            currTagIdxs.Add(-1);
                        }

                        currTagIdxs[currTagLevel] = seqs[i][j];
                        currTagLevel++;
                    }
                }
            }

            // Sum the tag embeddings of all nesting levels.
            // NOTE(review): relies on IndexSelect with clearWeights: true tolerating index -1
            // (presumably leaving those rows zero) - confirm against the IndexSelect implementation.
            IWeightTensor tagEmbeddings = null;
            if (enableTagEmbedding)
            {
                for (int k = 0; k < tagIdxsList.Count; k++)
                {
                    var tagEmbeddings_k = g.IndexSelect(embeddingsTensor, tagIdxsList[k], clearWeights: true);
                    tagEmbeddings = (tagEmbeddings == null) ? tagEmbeddings_k : g.Add(tagEmbeddings, tagEmbeddings_k);
                }
            }

            IWeightTensor embeddingRst = g.IndexSelect(embeddingsTensor, idxs);

            // Exact comparison is intentional: 1.0f is the "no scaling" default value.
            if (scaleFactor != 1.0f)
            {
                embeddingRst = g.Mul(embeddingRst, scaleFactor, inPlace: true);
            }

            // Apply segment embeddings to the input sequence embeddings
            if (segmentEmbedding != null)
            {
                embeddingRst = g.Add(embeddingRst, g.IndexSelect(segmentEmbedding, segIdxs));
            }

            if (tagEmbeddings != null)
            {
                embeddingRst = g.Add(embeddingRst, tagEmbeddings);
            }

            return embeddingRst;
        }
        /// <summary>
        /// Runs the forward pass on a single device. The source sentences are encoded once; the encoder
        /// output at each CLS token feeds a classification layer, and the full encoder output feeds the decoder.
        /// </summary>
        /// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
        /// <param name="sntPairBatch">Batch providing the source tokens (group 0), the classification targets (target group 0) and the decoder target tokens (target group 1)</param>
        /// <param name="deviceIdIdx">The index of current device</param>
        /// <param name="isTraining">True to compute costs and gradients; false to output predicted labels and decoded sentences</param>
        /// <param name="decodingOptions">Supplies MaxTgtSentLength and BeamSearchSize for transformer decoding when not training</param>
        /// <returns>Two results in this order: the classification result, then the decoder result</returns>
        public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions)
        {
            (IEncoder encoder, IDecoder decoder, IFeedForwardLayer encoderFFLayer, IFeedForwardLayer decoderFFLayer, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

            // Padding must happen before the id lookup so that ids are taken from the padded sentences.
            var sourceSentences = sntPairBatch.GetSrcTokens(0);
            var unpaddedSrcLengths = BuildInTokens.PadSentences(sourceSentences);
            var sourceTokenIds = m_modelMetaData.SrcVocab.GetWordIndex(sourceSentences);

            IWeightTensor encoded = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbedding, segmentEmbedding, sourceTokenIds, unpaddedSrcLengths);

            int paddedLen = sourceSentences[0].Count;
            int batchSize = sourceSentences.Count;

            // Flattened row of the first CLS token in each sentence; stays 0 when a sentence has no CLS token.
            float[] clsRows = new float[batchSize];
            for (int row = 0; row < batchSize; row++)
            {
                var sentence = sourceSentences[row];

                int col = 0;
                while (col < sentence.Count && sentence[col] != BuildInTokens.CLS)
                {
                    col++;
                }

                if (col < sentence.Count)
                {
                    clsRows[row] = row * paddedLen + col;
                }
            }

            IWeightTensor clsOutput = computeGraph.IndexSelect(encoded, clsRows);

            // ---- Classification head ----
            var clsResult = new NetworkResult
            {
                Output = new List<List<List<string>>>()
            };

            IWeightTensor clsLayerOutput = encoderFFLayer.Process(clsOutput, batchSize, computeGraph);

            using (IWeightTensor probs = computeGraph.Softmax(clsLayerOutput, runGradients: false, inPlace: true))
            {
                if (!isTraining)
                {
                    // Inference: take the highest-probability class of each sentence and map it back to a label.
                    using var predictedIdxTensor = computeGraph.Argmax(probs, 1);
                    float[] predictedIdx = predictedIdxTensor.ToWeightArray();
                    List<string> predictedLabels = m_modelMetaData.ClsVocab.ConvertIdsToString(predictedIdx.ToList());

                    var labelsPerSentence = new List<List<string>>();
                    for (int k = 0; k < batchSize; k++)
                    {
                        labelsPerSentence.Add(new List<string> { predictedLabels[k] });
                    }

                    clsResult.Output.Add(labelsPerSentence);
                }
                else
                {
                    // Training: sum -log(p_gold) over the batch and overwrite p_gold with (p_gold - 1),
                    // then hand the modified probabilities to the layer output as its gradients.
                    float totalCost = 0.0f;
                    var goldLabels = sntPairBatch.GetTgtTokens(0);
                    for (int k = 0; k < batchSize; k++)
                    {
                        int goldIdx = m_modelMetaData.ClsVocab.GetWordIndex(goldLabels[k][0]);
                        float goldProb = probs.GetWeightAt(new long[] { k, goldIdx });
                        totalCost += (float)-Math.Log(goldProb);
                        probs.SetWeightAt(goldProb - 1, new long[] { k, goldIdx });
                    }

                    clsLayerOutput.CopyWeightsToGradients(probs);

                    clsResult.Cost = totalCost / batchSize;
                }
            }

            // ---- Decoder ----
            // Reset networks
            decoder.Reset(computeGraph.GetWeightFactory(), sourceSentences.Count);

            // Generate output decoder sentences
            var targetSentences = sntPairBatch.GetTgtTokens(1);
            var targetTokenIds = m_modelMetaData.TgtVocab.GetWordIndex(targetSentences);

            var decoderResult = new NetworkResult();

            if (decoder is AttentionDecoder attentionDecoder)
            {
                decoderResult.Cost = Decoder.DecodeAttentionLSTM(targetTokenIds, computeGraph, encoded, attentionDecoder, decoderFFLayer, tgtEmbedding, m_modelMetaData.TgtVocab, sourceSentences.Count, isTraining);
                decoderResult.Output = new List<List<List<string>>>
                {
                    m_modelMetaData.TgtVocab.ConvertIdsToString(targetTokenIds)
                };
            }
            else
            {
                // Any non-attention decoder is handed over as a TransformerDecoder (null if it is not one).
                TransformerDecoder transformerDecoder = decoder as TransformerDecoder;

                if (isTraining)
                {
                    (var trainingCost, _) = Decoder.DecodeTransformer(targetTokenIds, computeGraph, encoded, transformerDecoder, decoderFFLayer, tgtEmbedding, posEmbedding, unpaddedSrcLengths, m_modelMetaData.TgtVocab, m_shuffleType, m_options.DropoutRatio, null, isTraining);
                    decoderResult.Cost = trainingCost;
                    decoderResult.Output = null;
                }
                else
                {
                    List<List<BeamSearchStatus>> beam2batchStatus = Decoder.InitBeamSearchStatusListList(batchSize, targetTokenIds);

                    // Generate one target position per step until every sentence is complete or the length limit is hit.
                    for (int step = 0; step < decodingOptions.MaxTgtSentLength; step++)
                    {
                        List<List<BeamSearchStatus>> batch2beam2seq = null; //(batch_size, beam_search_size)
                        try
                        {
                            foreach (var batchStatus in beam2batchStatus)
                            {
                                var batch2tgtTokens = Decoder.ExtractBatchTokens(batchStatus);

                                // The sub-graph lives for exactly one beam of one step.
                                using (var stepGraph = computeGraph.CreateSubGraph($"TransformerDecoder_Step_{step}"))
                                {
                                    (_, var beamResults) = Decoder.DecodeTransformer(batch2tgtTokens, stepGraph, encoded, transformerDecoder, decoderFFLayer, tgtEmbedding, posEmbedding,
                                                                                     unpaddedSrcLengths, m_modelMetaData.TgtVocab, m_shuffleType, 0.0f, decodingOptions, isTraining,
                                                                                     outputSentScore: decodingOptions.BeamSearchSize > 1, previousBeamSearchResults: batchStatus);

                                    beamResults = Decoder.SwapBeamAndBatch(beamResults);
                                    batch2beam2seq = Decoder.CombineBeamSearchResults(batch2beam2seq, beamResults);
                                }
                            }
                        }
                        catch (OutOfMemoryException)
                        {
                            // Best effort: stop decoding here and keep what the previous steps produced.
                            GC.Collect();
                            Logger.WriteLine(Logger.Level.warn, $"We have out of memory while generating '{step}th' tokens, so terminate decoding for current sequences.");
                            break;
                        }

                        if (decodingOptions.BeamSearchSize > 1)
                        {
                            // Keep top N result and drop all others
                            for (int k = 0; k < batchSize; k++)
                            {
                                batch2beam2seq[k] = BeamSearch.GetTopNBSS(batch2beam2seq[k], decodingOptions.BeamSearchSize);
                            }
                        }

                        beam2batchStatus = Decoder.SwapBeamAndBatch(batch2beam2seq);
                        if (Decoder.AreAllSentsCompleted(beam2batchStatus))
                        {
                            break;
                        }
                    }

                    decoderResult.Cost = 0.0f;
                    decoderResult.Output = m_modelMetaData.TgtVocab.ExtractTokens(beam2batchStatus);
                }
            }

            decoderResult.RemoveDuplicatedEOS();

            return new List<NetworkResult> { clsResult, decoderResult };
        }