public List<string> GetAllTokens(bool keepBuildInTokens = true)
{
    if (keepBuildInTokens)
    {
        return Items;
    }

    // Filter out the built-in (predefined) tokens
    List<string> results = new List<string>();
    foreach (var item in Items)
    {
        if (BuildInTokens.IsPreDefinedToken(item) == false)
        {
            results.Add(item);
        }
    }

    return results;
}
/// <summary>
/// Load vocabulary from the given file
/// </summary>
public Vocab(string vocabFilePath)
{
    Logger.WriteLine($"Loading vocabulary from '{vocabFilePath}'...");

    string[] vocab = File.ReadAllLines(vocabFilePath);

    CreateIndex(); // Initialize the word/index mappings and register the built-in tokens

    // Built-in tokens occupy the first slots, so regular tokens start at index 3
    int q = 3;
    foreach (string line in vocab)
    {
        // Only the text before the first tab is used as the token
        string[] items = line.Split('\t');
        string word = items[0];

        if (BuildInTokens.IsPreDefinedToken(word) == false)
        {
            Items.Add(word);
            WordToIndex[word] = q;
            IndexToWord[q] = word;
            q++;
        }
    }
}
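// Example (sketch, not part of the original source): the constructor above only uses the text
// before the first tab on each line, so a vocabulary file with optional frequency counts, e.g.
//
//   the<TAB>10876
//   cat<TAB>512
//   sat<TAB>498
//
// loads fine. The file path below is illustrative:
//
//   Vocab srcVocab = new Vocab("data/src.vocab");
//   List<string> userTokens = srcVocab.GetAllTokens(keepBuildInTokens: false); // drops built-in tokens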
/// <summary>
/// Create input embeddings from token embeddings and segment embeddings
/// </summary>
/// <param name="seqs">Token id sequences, one list per sentence in the batch</param>
/// <param name="g">Compute graph used to build the embedding operations</param>
/// <param name="embeddingsTensor">Token embedding table</param>
/// <param name="segmentEmbedding">Segment embedding table, or null to skip segment embeddings</param>
/// <param name="vocab">Vocabulary used to map token ids back to strings</param>
/// <param name="scaleFactor">Scale applied to the token embeddings</param>
/// <param name="enableTagEmbedding">If true, add embeddings of enclosing tag tokens to each token</param>
/// <returns>The embedding tensor. Shape: (batchSize * seqLen, embedding_dim)</returns>
public static IWeightTensor CreateTokensEmbeddings(List<List<int>> seqs, IComputeGraph g, IWeightTensor embeddingsTensor,
    IWeightTensor segmentEmbedding, Vocab vocab, float scaleFactor = 1.0f, bool enableTagEmbedding = false)
{
    int batchSize = seqs.Count;
    int seqLen = seqs[0].Count;

    float[] idxs = new float[batchSize * seqLen];
    float[] segIdxs = new float[batchSize * seqLen];
    List<float[]> tagIdxsList = new List<float[]>();

    for (int i = 0; i < batchSize; i++)
    {
        int segIdx = 0;
        List<int> currTagIdxs = new List<int>();
        int currTagLevel = 0;

        for (int j = 0; j < seqLen; j++)
        {
            idxs[i * seqLen + j] = seqs[i][j];
            segIdxs[i * seqLen + j] = segIdx;

            string token = vocab.GetString(seqs[i][j]);
            if (token == BuildInTokens.SEP)
            {
                // A new segment starts after the separator token
                segIdx++;
            }

            if (enableTagEmbedding)
            {
                if (token.StartsWith("<") && token.EndsWith(">") && BuildInTokens.IsPreDefinedToken(token) == false)
                {
                    if (token[1] == '/')
                    {
                        // A closing tag: pop the innermost open tag
                        currTagLevel--;
                        currTagIdxs[currTagLevel] = -1;
                    }
                    else
                    {
                        // A new opening tag: make sure buffers exist for this nesting level
                        while (tagIdxsList.Count <= currTagLevel)
                        {
                            float[] tagIdxs = new float[batchSize * seqLen];
                            Array.Fill(tagIdxs, -1.0f);
                            tagIdxsList.Add(tagIdxs);
                        }

                        while (currTagIdxs.Count <= currTagLevel)
                        {
                            currTagIdxs.Add(-1);
                        }

                        currTagIdxs[currTagLevel] = seqs[i][j];
                        currTagLevel++;
                    }
                }
                else
                {
                    // A regular token: record every tag that currently encloses it
                    for (int k = 0; k < currTagLevel; k++)
                    {
                        tagIdxsList[k][i * seqLen + j] = currTagIdxs[k];
                    }
                }
            }
        }
    }

    // Sum the embeddings of all enclosing tags, level by level
    IWeightTensor tagEmbeddings = null;
    if (enableTagEmbedding)
    {
        for (int k = 0; k < tagIdxsList.Count; k++)
        {
            var tagEmbeddings_k = g.IndexSelect(embeddingsTensor, tagIdxsList[k], clearWeights: true);
            tagEmbeddings = (tagEmbeddings == null) ? tagEmbeddings_k : g.Add(tagEmbeddings, tagEmbeddings_k);
        }
    }

    IWeightTensor embeddingRst = g.IndexSelect(embeddingsTensor, idxs);
    if (scaleFactor != 1.0f)
    {
        embeddingRst = g.Mul(embeddingRst, scaleFactor, inPlace: true);
    }

    // Apply segment embeddings to the input sequence embeddings
    if (segmentEmbedding != null)
    {
        embeddingRst = g.Add(embeddingRst, g.IndexSelect(segmentEmbedding, segIdxs));
    }

    // Apply tag embeddings to the input sequence embeddings
    if (tagEmbeddings != null)
    {
        embeddingRst = g.Add(embeddingRst, tagEmbeddings);
    }

    return embeddingRst;
}
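// Example (sketch, not part of the original source): how the helper above might be called when
// building encoder inputs. The compute graph, embedding tensors, vocabulary and model dimension
// are assumed to come from an existing model; the names and token ids here are illustrative.
private static IWeightTensor BuildInputEmbeddingsExample(IComputeGraph g, IWeightTensor tokenEmbeddings,
    IWeightTensor segmentEmbeddings, Vocab srcVocab, int modelDim)
{
    // Two already-indexed sequences of equal (padded) length: batchSize = 2, seqLen = 4
    List<List<int>> seqs = new List<List<int>>
    {
        new List<int> { 5, 17, 42, 2 },
        new List<int> { 8, 23, 11, 2 },
    };

    // Scale token embeddings by sqrt(d_model), a common choice for transformer inputs.
    // Returned shape: (batchSize * seqLen, embedding_dim) = (8, embedding_dim)
    return CreateTokensEmbeddings(seqs, g, tokenEmbeddings, segmentEmbeddings, srcVocab,
        scaleFactor: (float)Math.Sqrt(modelDim), enableTagEmbedding: false);
}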
private static void Main(string[] args)
{
    ShowOptions(args);

    Logger.LogFile = $"{nameof(SeqLabelConsole)}_{Utils.GetTimeStamp(DateTime.Now)}.log";

    // Parse command line options
    SeqLabelOptions opts = new SeqLabelOptions();
    ArgParser argParser = new ArgParser(args, opts);

    if (!opts.ConfigFilePath.IsNullOrEmpty())
    {
        Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
        opts = JsonConvert.DeserializeObject<SeqLabelOptions>(File.ReadAllText(opts.ConfigFilePath));
    }

    DecodingOptions decodingOptions = opts.CreateDecodingOptions();
    SeqLabel sl = null;

    // Parse device ids from options
    int[] deviceIds = opts.DeviceIds.Split(',').Select(x => int.Parse(x)).ToArray();

    if (opts.Task == ModeEnums.Train)
    {
        // Load training corpus
        SeqLabelingCorpus trainCorpus = new SeqLabelingCorpus(opts.TrainCorpusPath, opts.BatchSize, opts.ShuffleBlockSize, maxSentLength: opts.MaxTrainSentLength);

        // Load validation corpora (multiple paths separated by ';')
        List<SeqLabelingCorpus> validCorpusList = new List<SeqLabelingCorpus>();
        if (!opts.ValidCorpusPaths.IsNullOrEmpty())
        {
            string[] validCorpusPathList = opts.ValidCorpusPaths.Split(';');
            foreach (var validCorpusPath in validCorpusPathList)
            {
                validCorpusList.Add(new SeqLabelingCorpus(validCorpusPath, opts.BatchSize, opts.ShuffleBlockSize, maxSentLength: opts.MaxTestSentLength));
            }
        }

        // Load or build vocabularies
        Vocab srcVocab = null;
        Vocab tgtVocab = null;
        if (!opts.SrcVocab.IsNullOrEmpty() && !opts.TgtVocab.IsNullOrEmpty())
        {
            // Vocabulary files are specified, so we load them
            srcVocab = new Vocab(opts.SrcVocab);
            tgtVocab = new Vocab(opts.TgtVocab);
        }
        else
        {
            // No vocabulary files are specified, so we build them from the training corpus
            (srcVocab, tgtVocab) = trainCorpus.BuildVocabs(opts.SrcVocabSize);
        }

        // Create learning rate scheduler
        ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

        // Create optimizer
        IOptimizer optimizer = Misc.CreateOptimizer(opts);

        // Create metrics: one F-score metric per target label
        List<IMetric> metrics = new List<IMetric>();
        foreach (string word in tgtVocab.Items)
        {
            if (BuildInTokens.IsPreDefinedToken(word) == false)
            {
                metrics.Add(new SequenceLabelFscoreMetric(word));
            }
        }

        if (File.Exists(opts.ModelFilePath) == false)
        {
            // New training
            sl = new SeqLabel(opts, srcVocab: srcVocab, clsVocab: tgtVocab);
        }
        else
        {
            // Incremental training from an existing model
            Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
            sl = new SeqLabel(opts);
        }

        // Add event handler for monitoring
        sl.StatusUpdateWatcher += Misc.Ss_StatusUpdateWatcher;

        // Kick off training
        sl.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpusList: validCorpusList.ToArray(), learningRate: learningRate, optimizer: optimizer, metrics: metrics, decodingOptions: decodingOptions);
    }
    else if (opts.Task == ModeEnums.Valid)
    {
        Logger.WriteLine($"Evaluating model '{opts.ModelFilePath}' with valid corpus '{opts.ValidCorpusPaths}'");

        // Load valid corpus
        SeqLabelingCorpus validCorpus = new SeqLabelingCorpus(opts.ValidCorpusPaths, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxTestSentLength);
        (Vocab srcVocab, Vocab tgtVocab) = validCorpus.BuildVocabs();

        // Create metrics: one F-score metric per target label
        List<IMetric> metrics = new List<IMetric>();
        foreach (string word in tgtVocab.Items)
        {
            if (BuildInTokens.IsPreDefinedToken(word) == false)
            {
                metrics.Add(new SequenceLabelFscoreMetric(word));
            }
        }

        sl = new SeqLabel(opts);
        sl.Valid(validCorpus: validCorpus, metrics: metrics, decodingOptions: decodingOptions);
    }
    else if (opts.Task == ModeEnums.Test)
    {
        Logger.WriteLine($"Testing model '{opts.ModelFilePath}' with input corpus '{opts.InputTestFile}'");

        // Test the trained model
        sl = new SeqLabel(opts);

        List<string> outputLines = new List<string>();
        string[] data_sents_raw1 = File.ReadAllLines(opts.InputTestFile);
        foreach (string line in data_sents_raw1)
        {
            var nrs = sl.Test<SeqLabelingCorpusBatch>(ConstructInputTokens(line.Trim().Split(' ').ToList()), null, decodingOptions: decodingOptions);
            outputLines.AddRange(nrs[0].Output[0].Select(x => string.Join(" ", x)));
        }

        File.WriteAllLines(opts.OutputFile, outputLines);
    }
    else
    {
        argParser.Usage();
    }
}
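// Example (sketch, not part of the original source): a minimal training configuration that could
// be loaded through ConfigFilePath and deserialized into SeqLabelOptions above. Only properties
// referenced in Main() are shown; exact option names, defaults and the ArgParser command-line
// syntax may differ in the real project.
//
// {
//   "Task": "Train",
//   "TrainCorpusPath": "data/train",
//   "ValidCorpusPaths": "data/valid",
//   "SrcVocab": "data/src.vocab",
//   "TgtVocab": "data/tgt.vocab",
//   "ModelFilePath": "seqlabel.model",
//   "BatchSize": 128,
//   "MaxEpochNum": 20,
//   "DeviceIds": "0"
// }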