Beispiel #1
0
        /// <summary>
        /// Adds every word of <paramref name="srcVocab"/> that this vocabulary does not contain yet.
        /// Words already present keep their current id; the ids used by the source vocabulary are ignored.
        /// New words receive consecutive ids starting one above the largest id currently in use
        /// (so the first new id is 1 when no existing id is greater than 0).
        /// </summary>
        /// <param name="srcVocab">The vocabulary whose words are merged into this one.</param>
        public void MergeVocab(Vocab srcVocab)
        {
            // Determine the largest id that is already taken.
            int highestId = 0;
            foreach (var existing in WordToIndex)
            {
                highestId = existing.Value > highestId ? existing.Value : highestId;
            }

            int nextId = highestId + 1;

            foreach (var candidate in srcVocab.WordToIndex)
            {
                string word = candidate.Key;

                // Known words are left untouched.
                if (WordToIndex.ContainsKey(word))
                {
                    continue;
                }

                // Register the new word in both lookup directions and in the item list.
                WordToIndex.Add(word, nextId);
                IndexToWord.Add(nextId, word);
                Items.Add(word);
                nextId++;
            }
        }
Beispiel #2
0
        /// <summary>
        /// Create the input embedding for a batch of token sequences from token embeddings,
        /// optional segment embeddings and optional tag embeddings.
        /// </summary>
        /// <param name="seqs">Batch of token-id sequences. Every sequence is read for seqs[0].Count positions, so all sequences must be at least that long.</param>
        /// <param name="g">Compute graph used for the index-select, multiply and add operations.</param>
        /// <param name="embeddingsTensor">Embedding table in which token ids (and tag token ids) are looked up.</param>
        /// <param name="segmentEmbedding">Segment embedding table indexed by segment number; pass null to skip segment embeddings.</param>
        /// <param name="vocab">Vocabulary used to turn token ids back into strings in order to detect separator and tag tokens.</param>
        /// <param name="scaleFactor">Factor the token embeddings are multiplied by; no multiplication is done when it is 1.0.</param>
        /// <param name="enableTagEmbedding">
        /// When true, tokens of the form &lt;name&gt; / &lt;/name&gt; that are not predefined tokens are treated as opening / closing tags,
        /// and every non-tag token located inside open tags additionally receives the embedding of each enclosing opening-tag token.
        /// </param>
        /// <returns>The embedding tensor. shape: (batchsize * seqLen, embedding_dim) </returns>
        public static IWeightTensor CreateTokensEmbeddings(List<List<int>> seqs, IComputeGraph g, IWeightTensor embeddingsTensor,
                                                           IWeightTensor segmentEmbedding, Vocab vocab, float scaleFactor = 1.0f, bool enableTagEmbedding = false)
        {
            int batchSize = seqs.Count;
            int seqLen = seqs[0].Count;

            float[] idxs = new float[batchSize * seqLen];
            float[] segIdxs = new float[batchSize * seqLen];

            // One index array per tag nesting level; entries stay at -1 where no tag of that level applies.
            // NOTE(review): presumably IndexSelect with clearWeights: true yields an all-zero row for -1 - confirm.
            List<float[]> tagIdxsList = new List<float[]>();

            for (int i = 0; i < batchSize; i++)
            {
                int segIdx = 0;

                // Token ids of the currently open tags, indexed by nesting level. Tag state is reset for every sequence.
                List<int> currTagIdxs = new List<int>();
                int currTagLevel = 0;

                for (int j = 0; j < seqLen; j++)
                {
                    int offset = i * seqLen + j;
                    int tokenId = seqs[i][j];

                    idxs[offset] = tokenId;
                    segIdxs[offset] = segIdx;

                    string token = vocab.GetString(tokenId);
                    if (token == BuildInTokens.SEP)
                    {
                        // A new segment starts after the separator; the separator itself keeps the old segment index.
                        segIdx++;
                    }

                    if (enableTagEmbedding == false)
                    {
                        continue;
                    }

                    // Ordinal comparison: "<" and ">" are structural markers, not linguistic text.
                    bool isTagToken = token.StartsWith("<", StringComparison.Ordinal)
                                      && token.EndsWith(">", StringComparison.Ordinal)
                                      && BuildInTokens.IsPreDefinedToken(token) == false;

                    if (isTagToken)
                    {
                        if (token[1] == '/')
                        {
                            // Closing tag. An unmatched closing tag (nothing open) is ignored; decrementing
                            // the level below zero would index currTagIdxs at -1 and throw.
                            if (currTagLevel > 0)
                            {
                                currTagLevel--;
                                currTagIdxs[currTagLevel] = -1;
                            }
                        }
                        else
                        {
                            // A new opening tag: make sure storage exists for this nesting level.
                            while (tagIdxsList.Count <= currTagLevel)
                            {
                                float[] tagIdxs = new float[batchSize * seqLen];
                                Array.Fill(tagIdxs, -1.0f);
                                tagIdxsList.Add(tagIdxs);
                            }

                            while (currTagIdxs.Count <= currTagLevel)
                            {
                                currTagIdxs.Add(-1);
                            }

                            currTagIdxs[currTagLevel] = tokenId;
                            currTagLevel++;
                        }
                    }
                    else
                    {
                        // Regular token: record the token id of every enclosing open tag, one per level.
                        for (int k = 0; k < currTagLevel; k++)
                        {
                            tagIdxsList[k][offset] = currTagIdxs[k];
                        }
                    }
                }
            }

            // Sum the tag embeddings over all nesting levels (stays null when no tag was seen or tags are disabled).
            IWeightTensor tagEmbeddings = null;

            if (enableTagEmbedding)
            {
                for (int k = 0; k < tagIdxsList.Count; k++)
                {
                    var tagEmbeddings_k = g.IndexSelect(embeddingsTensor, tagIdxsList[k], clearWeights: true);
                    if (tagEmbeddings == null)
                    {
                        tagEmbeddings = tagEmbeddings_k;
                    }
                    else
                    {
                        tagEmbeddings = g.Add(tagEmbeddings, tagEmbeddings_k);
                    }
                }
            }

            IWeightTensor embeddingRst = g.IndexSelect(embeddingsTensor, idxs);

            // 1.0f is the "no scaling" default, so exact comparison is intended here.
            if (scaleFactor != 1.0f)
            {
                embeddingRst = g.Mul(embeddingRst, scaleFactor, inPlace: true);
            }

            // Apply segment embeddings to the input sequence embeddings
            if (segmentEmbedding != null)
            {
                embeddingRst = g.Add(embeddingRst, g.IndexSelect(segmentEmbedding, segIdxs));
            }

            // Tag embeddings are added unscaled, after the token embeddings have been scaled.
            if (tagEmbeddings != null)
            {
                embeddingRst = g.Add(embeddingRst, tagEmbeddings);
            }

            return embeddingRst;
        }