示例#1
0
        /// <summary>
        /// Returns the tokens held by this vocabulary.
        /// </summary>
        /// <param name="keepBuildInTokens">
        /// When true, the backing <c>Items</c> list itself is returned (not a copy).
        /// When false, a newly allocated list is returned that contains, in their original order,
        /// only the tokens for which <c>BuildInTokens.IsPreDefinedToken</c> returns false.
        /// </param>
        /// <returns>The requested token list.</returns>
        public List<string> GetAllTokens(bool keepBuildInTokens = true)
        {
            // Fast path: hand back the backing list unchanged.
            if (keepBuildInTokens)
            {
                return Items;
            }

            // FindAll allocates a fresh list, so filtering never touches Items.
            return Items.FindAll(token => !BuildInTokens.IsPreDefinedToken(token));
        }
示例#2
0
        /// <summary>
        /// Builds a vocabulary from a tab-separated vocabulary file.
        /// </summary>
        /// <param name="vocabFilePath">
        /// Path to the vocabulary file. For each line, only the text before the first tab is used
        /// as the word; any further columns are ignored. Words that are predefined (built-in)
        /// tokens are skipped.
        /// </param>
        /// <remarks>
        /// The file is read before <c>CreateIndex</c> runs. Words taken from the file are numbered
        /// consecutively starting at 3.
        /// NOTE(review): 3 presumably equals the number of built-in tokens registered by
        /// <c>CreateIndex</c>; confirm before changing either.
        /// Duplicate words are not de-duplicated. Every occurrence is appended to <c>Items</c> and
        /// consumes its own index.
        /// </remarks>
        public Vocab(string vocabFilePath)
        {
            Logger.WriteLine("Loading vocabulary files...");
            string[] lines = File.ReadAllLines(vocabFilePath);

            CreateIndex();

            const int firstFileWordIndex = 3;
            int nextIndex = firstFileWordIndex;

            foreach (string line in lines)
            {
                // The word is the first tab-separated column.
                string word = line.Split('\t')[0];
                if (BuildInTokens.IsPreDefinedToken(word))
                {
                    continue;
                }

                Items.Add(word);
                WordToIndex[word] = nextIndex;
                IndexToWord[nextIndex] = word;
                nextIndex++;
            }
        }
示例#3
0
        /// <summary>
        /// Creates the input embedding for a batch by looking up token embeddings and, optionally,
        /// adding segment embeddings and tag embeddings.
        /// </summary>
        /// <param name="seqs">
        /// Token ids, one list per sentence. Every sentence is assumed to have the same length as
        /// <c>seqs[0]</c> (i.e. already padded).
        /// </param>
        /// <param name="g">Compute graph used to build the lookup and add operations.</param>
        /// <param name="embeddingsTensor">Token embedding table indexed by token id; also used for tag embeddings.</param>
        /// <param name="segmentEmbedding">
        /// Optional segment embedding table; when null, no segment embedding is added. Within a
        /// sentence the segment index starts at 0 and is incremented after each SEP token, so the
        /// SEP token belongs to the segment it closes.
        /// </param>
        /// <param name="vocab">Vocabulary used to map ids back to strings to detect SEP and tag tokens.</param>
        /// <param name="scaleFactor">
        /// Multiplier applied to the token embeddings before segment/tag embeddings are added.
        /// Skipped when exactly 1.0.
        /// </param>
        /// <param name="enableTagEmbedding">
        /// When true, non-predefined tokens of the form <c>&lt;x&gt;</c> open a (nested) tag and
        /// <c>&lt;/x&gt;</c> closes the innermost one. Every ordinary token gets the embeddings of
        /// all currently open tag tokens added. Tag tokens themselves receive no tag embedding.
        /// Closing tags with no open tag are ignored.
        /// </param>
        /// <returns>The embedding tensor. shape: (batchsize * seqLen, embedding_dim) </returns>
        public static IWeightTensor CreateTokensEmbeddings(List<List<int>> seqs, IComputeGraph g, IWeightTensor embeddingsTensor,
                                                           IWeightTensor segmentEmbedding, Vocab vocab, float scaleFactor = 1.0f, bool enableTagEmbedding = false)
        {
            int batchSize = seqs.Count;
            int seqLen    = seqs[0].Count;

            float[] idxs    = new float[batchSize * seqLen];
            float[] segIdxs = new float[batchSize * seqLen];

            // tagIdxsList[k][pos] is the token id of the tag open at nesting level k at flattened
            // position pos, or -1 when no tag is open at that level.
            List<float[]> tagIdxsList = new List<float[]>();

            for (int i = 0; i < batchSize; i++)
            {
                int       segIdx       = 0;
                List<int> currTagIdxs  = new List<int>(); // id of the open tag at each nesting level of sentence i
                int       currTagLevel = 0;               // number of currently open tags in sentence i

                for (int j = 0; j < seqLen; j++)
                {
                    int pos = i * seqLen + j;
                    idxs[pos]    = seqs[i][j];
                    segIdxs[pos] = segIdx;

                    string token = vocab.GetString(seqs[i][j]);
                    if (token == BuildInTokens.SEP)
                    {
                        // SEP closes the current segment; the next token starts a new one.
                        segIdx++;
                    }

                    if (!enableTagEmbedding)
                    {
                        continue;
                    }

                    // Ordinal comparison: tag detection is a structural check, not a linguistic one.
                    bool isTagToken = token.StartsWith("<", StringComparison.Ordinal)
                                      && token.EndsWith(">", StringComparison.Ordinal)
                                      && BuildInTokens.IsPreDefinedToken(token) == false;

                    if (isTagToken)
                    {
                        if (token[1] == '/')
                        {
                            // Closing tag. An unmatched closer is ignored. Decrementing here would make
                            // currTagLevel negative and index currTagIdxs[-1], which throws.
                            if (currTagLevel > 0)
                            {
                                currTagLevel--;
                                currTagIdxs[currTagLevel] = -1;
                            }
                        }
                        else
                        {
                            // Opening tag: make sure storage exists for this nesting level, then record the tag id.
                            while (tagIdxsList.Count <= currTagLevel)
                            {
                                float[] tagIdxs = new float[batchSize * seqLen];
                                Array.Fill(tagIdxs, -1.0f);
                                tagIdxsList.Add(tagIdxs);
                            }

                            while (currTagIdxs.Count <= currTagLevel)
                            {
                                currTagIdxs.Add(-1);
                            }

                            currTagIdxs[currTagLevel] = seqs[i][j];
                            currTagLevel++;
                        }
                    }
                    else
                    {
                        // Ordinary token: attach every currently open tag.
                        for (int k = 0; k < currTagLevel; k++)
                        {
                            tagIdxsList[k][pos] = currTagIdxs[k];
                        }
                    }
                }
            }

            IWeightTensor tagEmbeddings = null;

            if (enableTagEmbedding)
            {
                // Sum the tag embeddings over all nesting levels.
                // NOTE(review): positions with no open tag at level k carry index -1. Presumably
                // clearWeights: true makes IndexSelect emit zero rows for them; confirm.
                for (int k = 0; k < tagIdxsList.Count; k++)
                {
                    var tagEmbeddings_k = g.IndexSelect(embeddingsTensor, tagIdxsList[k], clearWeights: true);
                    if (tagEmbeddings == null)
                    {
                        tagEmbeddings = tagEmbeddings_k;
                    }
                    else
                    {
                        tagEmbeddings = g.Add(tagEmbeddings, tagEmbeddings_k);
                    }
                }
            }

            IWeightTensor embeddingRst = g.IndexSelect(embeddingsTensor, idxs);

            // Exact comparison is intentional: only the default value 1.0 skips the multiply.
            if (scaleFactor != 1.0f)
            {
                embeddingRst = g.Mul(embeddingRst, scaleFactor, inPlace: true);
            }

            // Apply segment embeddings to the input sequence embeddings
            if (segmentEmbedding != null)
            {
                embeddingRst = g.Add(embeddingRst, g.IndexSelect(segmentEmbedding, segIdxs));
            }

            if (tagEmbeddings != null)
            {
                embeddingRst = g.Add(embeddingRst, tagEmbeddings);
            }

            return embeddingRst;
        }
示例#4
0
        /// <summary>
        /// Console entry point. Parses options from the command line, or from a JSON config file
        /// when one is given, then runs the requested task: train, valid or test.
        /// </summary>
        /// <param name="args">Command line arguments.</param>
        private static void Main(string[] args)
        {
            ShowOptions(args);

            Logger.LogFile = $"{nameof(SeqLabelConsole)}_{Utils.GetTimeStamp(DateTime.Now)}.log";

            //Parse command line
            SeqLabelOptions opts = new SeqLabelOptions();
            ArgParser argParser = new ArgParser(args, opts);

            if (!opts.ConfigFilePath.IsNullOrEmpty())
            {
                // The config file replaces the options object entirely; values parsed from the
                // command line above are not merged into it.
                Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
                opts = JsonConvert.DeserializeObject<SeqLabelOptions>(File.ReadAllText(opts.ConfigFilePath));
            }

            DecodingOptions decodingOptions = opts.CreateDecodingOptions();

            SeqLabel sl = null;

            // NOTE(review): deviceIds is not used in this method. The parse only fails fast on a
            // malformed DeviceIds string; presumably SeqLabel reads opts.DeviceIds itself. Confirm.
            int[] deviceIds = opts.DeviceIds.Split(',').Select(x => int.Parse(x)).ToArray();

            if (opts.Task == ModeEnums.Train)
            {
                // Load train corpus
                SeqLabelingCorpus trainCorpus = new SeqLabelingCorpus(opts.TrainCorpusPath, opts.BatchSize, opts.ShuffleBlockSize, maxSentLength: opts.MaxTrainSentLength);

                // Load valid corpora. ValidCorpusPaths may hold several paths separated by ';'.
                List<SeqLabelingCorpus> validCorpusList = new List<SeqLabelingCorpus>();
                if (!opts.ValidCorpusPaths.IsNullOrEmpty())
                {
                    string[] validCorpusPathList = opts.ValidCorpusPaths.Split(new[] { ';' }, StringSplitOptions.RemoveEmptyEntries);
                    foreach (string validCorpusPath in validCorpusPathList)
                    {
                        // Build each corpus from its own path, not from the whole ';'-joined string.
                        validCorpusList.Add(new SeqLabelingCorpus(validCorpusPath, opts.BatchSize, opts.ShuffleBlockSize, maxSentLength: opts.MaxTestSentLength));
                    }
                }

                // Load or build vocabulary
                Vocab srcVocab = null;
                Vocab tgtVocab = null;
                if (!opts.SrcVocab.IsNullOrEmpty() && !opts.TgtVocab.IsNullOrEmpty())
                {
                    // Vocabulary files are specified, so we load them
                    srcVocab = new Vocab(opts.SrcVocab);
                    tgtVocab = new Vocab(opts.TgtVocab);
                }
                else
                {
                    // We don't specify vocabulary, so we build it from train corpus
                    (srcVocab, tgtVocab) = trainCorpus.BuildVocabs(opts.SrcVocabSize);
                }

                // Create learning rate
                ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

                // Create optimizer
                IOptimizer optimizer = Misc.CreateOptimizer(opts);

                // Create metrics
                List<IMetric> metrics = CreateLabelFscoreMetrics(tgtVocab);

                if (File.Exists(opts.ModelFilePath) == false)
                {
                    //New training
                    sl = new SeqLabel(opts, srcVocab: srcVocab, clsVocab: tgtVocab);
                }
                else
                {
                    //Incremental training
                    Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
                    sl = new SeqLabel(opts);
                }

                // Add event handler for monitoring
                sl.StatusUpdateWatcher += Misc.Ss_StatusUpdateWatcher;

                // Kick off training
                sl.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpusList: validCorpusList.ToArray(), learningRate: learningRate, optimizer: optimizer, metrics: metrics, decodingOptions: decodingOptions);
            }
            else if (opts.Task == ModeEnums.Valid)
            {
                Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPaths}'");

                // Load valid corpus. The named argument keeps this call consistent with the Train branch.
                SeqLabelingCorpus validCorpus = new SeqLabelingCorpus(opts.ValidCorpusPaths, opts.BatchSize, opts.ShuffleBlockSize, maxSentLength: opts.MaxTestSentLength);
                (Vocab srcVocab, Vocab tgtVocab) = validCorpus.BuildVocabs();

                // Create metrics
                List<IMetric> metrics = CreateLabelFscoreMetrics(tgtVocab);

                sl = new SeqLabel(opts);
                sl.Valid(validCorpus: validCorpus, metrics: metrics, decodingOptions: decodingOptions);
            }
            else if (opts.Task == ModeEnums.Test)
            {
                Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

                //Test trained model
                sl = new SeqLabel(opts);

                List<string> outputLines = new List<string>();
                string[] data_sents_raw1 = File.ReadAllLines(opts.InputTestFile);
                foreach (string line in data_sents_raw1)
                {
                    var nrs = sl.Test<SeqLabelingCorpusBatch>(ConstructInputTokens(line.Trim().Split(' ').ToList()), null, decodingOptions: decodingOptions);
                    outputLines.AddRange(nrs[0].Output[0].Select(x => string.Join(" ", x)));
                }

                File.WriteAllLines(opts.OutputFile, outputLines);
            }
            else
            {
                argParser.Usage();
            }
        }

        /// <summary>
        /// Creates one <c>SequenceLabelFscoreMetric</c> per entry of <paramref name="tgtVocab"/>'s
        /// <c>Items</c>, skipping predefined (built-in) tokens. Order follows <c>Items</c>.
        /// </summary>
        /// <param name="tgtVocab">Label vocabulary.</param>
        /// <returns>A new list of metrics.</returns>
        private static List<IMetric> CreateLabelFscoreMetrics(Vocab tgtVocab)
        {
            List<IMetric> metrics = new List<IMetric>();
            foreach (string word in tgtVocab.Items)
            {
                if (BuildInTokens.IsPreDefinedToken(word) == false)
                {
                    metrics.Add(new SequenceLabelFscoreMetric(word));
                }
            }

            return metrics;
        }