/// <summary>
/// Run forward part on given single device
/// </summary>
/// <param name="g">The computing graph for the current device. It gets created and passed by the framework</param>
/// <param name="srcSnts">A batch of tokenized input sentences on the source side</param>
/// <param name="tgtSnts">A batch of tokenized output sentences on the target side. In training mode it holds the golden target tokens; otherwise it receives the target tokens generated by the decoder</param>
/// <param name="deviceIdIdx">The index of the current device</param>
/// <returns>The cost of the forward part</returns>
private float RunForwardOnSingleDevice(IComputeGraph g, List<List<string>> srcSnts, List<List<string>> tgtSnts, int deviceIdIdx, bool isTraining)
{
    (IEncoder encoder, IWeightTensor srcEmbedding, IWeightTensor posEmbedding, FeedForwardLayer decoderFFLayer) = GetNetworksOnDeviceAt(deviceIdIdx);

    // Reset networks
    encoder.Reset(g.GetWeightFactory(), srcSnts.Count);

    List<int> originalSrcLengths = ParallelCorpus.PadSentences(srcSnts);
    int seqLen = srcSnts[0].Count;
    int batchSize = srcSnts.Count;

    // Encoding input source sentences
    IWeightTensor encOutput = Encode(g, srcSnts, encoder, srcEmbedding, null, posEmbedding, originalSrcLengths);
    IWeightTensor ffLayer = decoderFFLayer.Process(encOutput, batchSize, g);
    IWeightTensor ffLayerBatch = g.TransposeBatch(ffLayer, batchSize);

    float cost = 0.0f;
    using (IWeightTensor probs = g.Softmax(ffLayerBatch, runGradients: false, inPlace: true))
    {
        if (isTraining)
        {
            // Calculate loss for each word in the batch
            for (int k = 0; k < batchSize; k++)
            {
                for (int j = 0; j < seqLen; j++)
                {
                    using (IWeightTensor probs_k_j = g.PeekRow(probs, k * seqLen + j, runGradients: false))
                    {
                        int ix_targets_k_j = m_modelMetaData.Vocab.GetTargetWordIndex(tgtSnts[k][j]);
                        float score_k = probs_k_j.GetWeightAt(ix_targets_k_j);
                        cost += (float)-Math.Log(score_k);

                        probs_k_j.SetWeightAt(score_k - 1, ix_targets_k_j);
                    }
                }

                ////CRF part
                //using (var probs_k = g.PeekRow(probs, k * seqLen, seqLen, runGradients: false))
                //{
                //    var weights_k = probs_k.ToWeightArray();
                //    var crfOutput_k = m_crfDecoder.ForwardBackward(seqLen, weights_k);

                //    int[] trueTags = new int[seqLen];
                //    for (int j = 0; j < seqLen; j++)
                //    {
                //        trueTags[j] = m_modelMetaData.Vocab.GetTargetWordIndex(tgtSnts[k][j]);
                //    }
                //    m_crfDecoder.UpdateBigramTransition(seqLen, crfOutput_k, trueTags);
                //}
            }

            ffLayerBatch.CopyWeightsToGradients(probs);
        }
        else
        {
            // CRF decoder
            //for (int k = 0; k < batchSize; k++)
            //{
            //    //CRF part
            //    using (var probs_k = g.PeekRow(probs, k * seqLen, seqLen, runGradients: false))
            //    {
            //        var weights_k = probs_k.ToWeightArray();
            //        var crfOutput_k = m_crfDecoder.DecodeNBestCRF(weights_k, seqLen, 1);

            //        var targetWords = m_modelMetaData.Vocab.ConvertTargetIdsToString(crfOutput_k[0].ToList());
            //        tgtSnts.Add(targetWords);
            //    }
            //}

            // Output the predicted label for every position in the batch
            int[] targetIdx = g.Argmax(probs, 1);
            List<string> targetWords = m_modelMetaData.Vocab.ConvertTargetIdsToString(targetIdx.ToList());

            for (int k = 0; k < batchSize; k++)
            {
                tgtSnts[k] = targetWords.GetRange(k * seqLen, seqLen);
            }
        }
    }

    return cost;
}
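// Note: the training branch above relies on a standard identity: for a softmax output p and a
// one-hot target y, the cross-entropy loss is -log(p[target]) and its gradient with respect to the
// pre-softmax logits is simply (p - y). That is why the code subtracts 1 from the probability of
// the gold label and then copies the probability buffer into the gradients. The helper below is an
// illustrative sketch on plain arrays (hypothetical name, not part of the framework API).
private static float CrossEntropyWithGradient(float[] probsRow, int targetIdx, float[] gradRow)
{
    // Loss for this position: negative log-likelihood of the gold label
    float loss = (float)-Math.Log(probsRow[targetIdx]);

    // Gradient w.r.t. the logits: softmax output minus one-hot target
    for (int i = 0; i < probsRow.Length; i++)
    {
        gradRow[i] = probsRow[i];
    }
    gradRow[targetIdx] -= 1.0f;

    return loss;
}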
public void Valid(ParallelCorpus validCorpus, List<IMetric> metrics)
{
    RunValid(validCorpus, RunForwardOnSingleDevice, metrics, true);
}
/// <summary>
/// Build vocabulary from training corpus
/// </summary>
/// <param name="trainCorpus"></param>
/// <param name="minFreq"></param>
public Vocab(ParallelCorpus trainCorpus, int minFreq = 1)
{
    Logger.WriteLine($"Building vocabulary from given training corpus.");

    // Count up all words on the source and target sides
    Dictionary<string, int> s_d = new Dictionary<string, int>();
    Dictionary<string, int> t_d = new Dictionary<string, int>();

    CreateIndex();

    foreach (SntPairBatch sntPairBatch in trainCorpus)
    {
        foreach (SntPair sntPair in sntPairBatch.SntPairs)
        {
            var item = sntPair.SrcSnt;
            for (int i = 0, n = item.Length; i < n; i++)
            {
                var txti = item[i];
                if (s_d.ContainsKey(txti)) { s_d[txti] += 1; }
                else { s_d.Add(txti, 1); }
            }

            var item2 = sntPair.TgtSnt;
            for (int i = 0, n = item2.Length; i < n; i++)
            {
                var txti = item2[i];
                if (t_d.ContainsKey(txti)) { t_d[txti] += 1; }
                else { t_d.Add(txti, 1); }
            }
        }
    }

    // Word ids start at 3; the lower ids are reserved for special tokens (e.g. BOS/EOS/UNK)
    var q = 3;
    foreach (var ch in s_d)
    {
        if (ch.Value >= minFreq)
        {
            // Add word to source vocabulary
            SrcWordToIndex[ch.Key] = q;
            m_srcIndexToWord[q] = ch.Key;
            m_srcVocab.Add(ch.Key);
            q++;
        }
    }

    Logger.WriteLine($"Source language Max term id = '{q}'");

    q = 3;
    foreach (var ch in t_d)
    {
        if (ch.Value >= minFreq)
        {
            // Add word to target vocabulary
            TgtWordToIndex[ch.Key] = q;
            m_tgtIndexToWord[q] = ch.Key;
            m_tgtVocab.Add(ch.Key);
            q++;
        }
    }

    Logger.WriteLine($"Target language Max term id = '{q}'");
}
private static void Main(string[] args)
{
    Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{GetTimeStamp(DateTime.Now)}.log";
    ShowOptions(args);

    // Parse command line
    Options opts = new Options();
    ArgParser argParser = new ArgParser(args, opts);

    if (string.IsNullOrEmpty(opts.ConfigFilePath) == false)
    {
        Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
        opts = JsonConvert.DeserializeObject<Options>(File.ReadAllText(opts.ConfigFilePath));
    }

    AttentionSeq2Seq ss = null;
    ProcessorTypeEnums processorType = (ProcessorTypeEnums)Enum.Parse(typeof(ProcessorTypeEnums), opts.ProcessorType);
    EncoderTypeEnums encoderType = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), opts.EncoderType);
    DecoderTypeEnums decoderType = (DecoderTypeEnums)Enum.Parse(typeof(DecoderTypeEnums), opts.DecoderType);
    ModeEnums mode = (ModeEnums)Enum.Parse(typeof(ModeEnums), opts.TaskName);

    // Parse device ids from options
    int[] deviceIds = opts.DeviceIds.Split(',').Select(x => int.Parse(x)).ToArray();

    if (mode == ModeEnums.Train)
    {
        // Load train corpus
        ParallelCorpus trainCorpus = new ParallelCorpus(opts.TrainCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);

        // Load valid corpus
        ParallelCorpus validCorpus = string.IsNullOrEmpty(opts.ValidCorpusPath) ? null : new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);

        // Create learning rate
        ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

        // Create optimizer
        AdamOptimizer optimizer = new AdamOptimizer(opts.GradClip, opts.Beta1, opts.Beta2);

        // Create metrics
        List<IMetric> metrics = new List<IMetric> { new BleuMetric(), new LengthRatioMetric() };

        if (!String.IsNullOrEmpty(opts.ModelFilePath) && File.Exists(opts.ModelFilePath))
        {
            // Incremental training
            Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
            ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, dropoutRatio: opts.DropoutRatio, deviceIds: deviceIds,
                                      isSrcEmbTrainable: opts.IsSrcEmbeddingTrainable, isTgtEmbTrainable: opts.IsTgtEmbeddingTrainable,
                                      isEncoderTrainable: opts.IsEncoderTrainable, isDecoderTrainable: opts.IsDecoderTrainable, maxTgtSntSize: opts.MaxSentLength);
        }
        else
        {
            // Load or build vocabulary
            Vocab vocab = null;
            if (!string.IsNullOrEmpty(opts.SrcVocab) && !string.IsNullOrEmpty(opts.TgtVocab))
            {
                // Vocabulary files are specified, so we load them
                vocab = new Vocab(opts.SrcVocab, opts.TgtVocab);
            }
            else
            {
                // We don't specify vocabulary, so we build it from train corpus
                vocab = new Vocab(trainCorpus);
            }

            // New training
            ss = new AttentionSeq2Seq(embeddingDim: opts.WordVectorSize, hiddenDim: opts.HiddenSize, encoderLayerDepth: opts.EncoderLayerDepth, decoderLayerDepth: opts.DecoderLayerDepth,
                                      srcEmbeddingFilePath: opts.SrcEmbeddingModelFilePath, tgtEmbeddingFilePath: opts.TgtEmbeddingModelFilePath, vocab: vocab, modelFilePath: opts.ModelFilePath,
                                      dropoutRatio: opts.DropoutRatio, processorType: processorType, deviceIds: deviceIds, multiHeadNum: opts.MultiHeadNum,
                                      encoderType: encoderType, decoderType: decoderType, maxTgtSntSize: opts.MaxSentLength, enableCoverageModel: opts.EnableCoverageModel);
        }

        // Add event handler for monitoring
        ss.IterationDone += ss_IterationDone;

        // Kick off training
        ss.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpus: validCorpus, learningRate: learningRate, optimizer: optimizer, metrics: metrics);
    }
    else if (mode == ModeEnums.Valid)
    {
        Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPath}'");

        // Create metrics
        List<IMetric> metrics = new List<IMetric> { new BleuMetric(), new LengthRatioMetric() };

        // Load valid corpus
        ParallelCorpus validCorpus = new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);

        ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);
        ss.Valid(validCorpus: validCorpus, metrics: metrics);
    }
    else if (mode == ModeEnums.Test)
    {
        Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

        // Test trained model
        ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);

        List<string> outputLines = new List<string>();
        string[] data_sents_raw1 = File.ReadAllLines(opts.InputTestFile);
        foreach (string line in data_sents_raw1)
        {
            if (opts.BeamSearch > 1)
            {
                // Beam search decoding
                List<List<string>> outputWordsList = ss.Predict(line.ToLower().Trim().Split(' ').ToList(), opts.BeamSearch);
                outputLines.AddRange(outputWordsList.Select(x => string.Join(" ", x)));
            }
            else
            {
                // Greedy decoding
                var outputTokensBatch = ss.Test(ParallelCorpus.ConstructInputTokens(line.ToLower().Trim().Split(' ').ToList()));
                outputLines.AddRange(outputTokensBatch.Select(x => String.Join(" ", x)));
            }
        }

        File.WriteAllLines(opts.OutputTestFile, outputLines);
    }
    else if (mode == ModeEnums.VisualizeNetwork)
    {
        ss = new AttentionSeq2Seq(embeddingDim: opts.WordVectorSize, hiddenDim: opts.HiddenSize, encoderLayerDepth: opts.EncoderLayerDepth, decoderLayerDepth: opts.DecoderLayerDepth,
                                  vocab: new Vocab(), srcEmbeddingFilePath: null, tgtEmbeddingFilePath: null, modelFilePath: opts.ModelFilePath, dropoutRatio: opts.DropoutRatio,
                                  processorType: processorType, deviceIds: new int[1] { 0 }, multiHeadNum: opts.MultiHeadNum, encoderType: encoderType, decoderType: decoderType,
                                  enableCoverageModel: opts.EnableCoverageModel);

        ss.VisualizeNeuralNetwork(opts.VisualizeNNFilePath);
    }
    else if (mode == ModeEnums.DumpVocab)
    {
        ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);
        ss.DumpVocabToFiles(opts.SrcVocab, opts.TgtVocab);
    }
    else
    {
        argParser.Usage();
    }
}
/// <summary>
/// Given an input sentence, generate the output sentence with the seq2seq model using beam search
/// </summary>
/// <param name="input"></param>
/// <param name="beamSearchSize"></param>
/// <param name="maxOutputLength"></param>
/// <returns></returns>
public List<List<string>> Predict(List<string> input, int beamSearchSize = 1, int maxOutputLength = 100)
{
    (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding, IWeightTensor posEmbedding) = GetNetworksOnDeviceAt(-1);

    List<List<string>> inputSeqs = ParallelCorpus.ConstructInputTokens(input);
    int batchSize = 1; // For prediction with beam search, we currently only support one sentence per call

    IComputeGraph g = CreateComputGraph(m_defaultDeviceId, needBack: false);
    AttentionDecoder rnnDecoder = decoder as AttentionDecoder;

    encoder.Reset(g.GetWeightFactory(), batchSize);
    rnnDecoder.Reset(g.GetWeightFactory(), batchSize);

    // Construct beam search status list
    List<BeamSearchStatus> bssList = new List<BeamSearchStatus>();

    BeamSearchStatus bss = new BeamSearchStatus();
    bss.OutputIds.Add((int)SENTTAGS.START);
    bss.CTs = rnnDecoder.GetCTs();
    bss.HTs = rnnDecoder.GetHTs();
    bssList.Add(bss);

    IWeightTensor encodedWeightMatrix = Encode(g, inputSeqs, encoder, srcEmbedding, null, posEmbedding, null);
    AttentionPreProcessResult attPreProcessResult = rnnDecoder.PreProcess(encodedWeightMatrix, batchSize, g);

    List<BeamSearchStatus> newBSSList = new List<BeamSearchStatus>();
    bool finished = false;
    int outputLength = 0;
    while (finished == false && outputLength < maxOutputLength)
    {
        finished = true;
        for (int i = 0; i < bssList.Count; i++)
        {
            bss = bssList[i];
            if (bss.OutputIds[bss.OutputIds.Count - 1] == (int)SENTTAGS.END)
            {
                newBSSList.Add(bss);
            }
            else if (bss.OutputIds.Count > maxOutputLength)
            {
                newBSSList.Add(bss);
            }
            else
            {
                finished = false;
                int ix_input = bss.OutputIds[bss.OutputIds.Count - 1];
                rnnDecoder.SetCTs(bss.CTs);
                rnnDecoder.SetHTs(bss.HTs);

                IWeightTensor x = g.PeekRow(tgtEmbedding, ix_input);
                IWeightTensor eOutput = rnnDecoder.Decode(x, attPreProcessResult, batchSize, g);
                using (IWeightTensor probs = g.Softmax(eOutput))
                {
                    List<int> preds = probs.GetTopNMaxWeightIdx(beamSearchSize);
                    for (int j = 0; j < preds.Count; j++)
                    {
                        BeamSearchStatus newBSS = new BeamSearchStatus();
                        newBSS.OutputIds.AddRange(bss.OutputIds);
                        newBSS.OutputIds.Add(preds[j]);
                        newBSS.CTs = rnnDecoder.GetCTs();
                        newBSS.HTs = rnnDecoder.GetHTs();

                        float score = probs.GetWeightAt(preds[j]);
                        newBSS.Score = bss.Score;
                        newBSS.Score += (float)(-Math.Log(score));

                        //var lengthPenalty = Math.Pow((5.0f + newBSS.OutputIds.Count) / 6, 0.6);
                        //newBSS.Score /= (float)lengthPenalty;

                        newBSSList.Add(newBSS);
                    }
                }
            }
        }

        bssList = BeamSearch.GetTopNBSS(newBSSList, beamSearchSize);
        newBSSList.Clear();

        outputLength++;
    }

    // Convert output target word ids to readable strings
    List<List<string>> results = new List<List<string>>();
    for (int i = 0; i < bssList.Count; i++)
    {
        results.Add(m_modelMetaData.Vocab.ConvertTargetIdsToString(bssList[i].OutputIds));
    }

    return results;
}
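// Note: each beam hypothesis above is scored by the accumulated negative log-probability of its
// tokens, so a lower Score means a more likely hypothesis. The commented-out lines apply the
// GNMT-style length penalty ((5 + len) / 6)^0.6 (Wu et al., 2016) so that longer hypotheses are
// not unfairly penalized. The helper below is an illustrative sketch of that normalization only,
// with a hypothetical name; it is not part of the framework API.
private static float NormalizeBeamScore(float sumNegLogProb, int outputLength, double alpha = 0.6)
{
    // Length penalty: longer outputs get a larger divisor, offsetting their larger summed cost
    double lengthPenalty = Math.Pow((5.0 + outputLength) / 6.0, alpha);
    return (float)(sumNegLogProb / lengthPenalty);
}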
/// <summary>
/// Decode output sentences using the attention-based LSTM decoder
/// </summary>
/// <param name="outputSnts">In training mode, they are golden target sentences, otherwise, they are target sentences generated by the decoder</param>
/// <param name="g"></param>
/// <param name="encOutputs"></param>
/// <param name="decoder"></param>
/// <param name="tgtEmbedding"></param>
/// <param name="batchSize"></param>
/// <param name="isTraining"></param>
/// <returns></returns>
private float DecodeAttentionLSTM(List<List<string>> outputSnts, IComputeGraph g, IWeightTensor encOutputs, AttentionDecoder decoder, IWeightTensor tgtEmbedding, int batchSize, bool isTraining = true)
{
    float cost = 0.0f;
    int[] ix_inputs = new int[batchSize];
    for (int i = 0; i < ix_inputs.Length; i++)
    {
        ix_inputs[i] = m_modelMetaData.Vocab.GetTargetWordIndex(outputSnts[i][0]);
    }

    // Initialize variables according to the current mode
    List<int> originalOutputLengths = isTraining ? ParallelCorpus.PadSentences(outputSnts) : null;
    int seqLen = isTraining ? outputSnts[0].Count : 64;
    float dropoutRatio = isTraining ? m_dropoutRatio : 0.0f;
    HashSet<int> setEndSentId = isTraining ? null : new HashSet<int>();

    // Pre-process for attention model
    AttentionPreProcessResult attPreProcessResult = decoder.PreProcess(encOutputs, batchSize, g);

    for (int i = 1; i < seqLen; i++)
    {
        // Get embeddings for all sentences in the batch at position i
        List<IWeightTensor> inputs = new List<IWeightTensor>();
        for (int j = 0; j < batchSize; j++)
        {
            inputs.Add(g.PeekRow(tgtEmbedding, ix_inputs[j]));
        }
        IWeightTensor inputsM = g.ConcatRows(inputs);

        // Decode output sentence at position i
        IWeightTensor eOutput = decoder.Decode(inputsM, attPreProcessResult, batchSize, g);

        // Softmax for output
        using (IWeightTensor probs = g.Softmax(eOutput, runGradients: false, inPlace: true))
        {
            if (isTraining)
            {
                // Calculate loss for each word in the batch
                for (int k = 0; k < batchSize; k++)
                {
                    using (IWeightTensor probs_k = g.PeekRow(probs, k, runGradients: false))
                    {
                        int ix_targets_k = m_modelMetaData.Vocab.GetTargetWordIndex(outputSnts[k][i]);
                        float score_k = probs_k.GetWeightAt(ix_targets_k);
                        if (i < originalOutputLengths[k])
                        {
                            var lcost = (float)-Math.Log(score_k);
                            if (float.IsNaN(lcost))
                            {
                                throw new ArithmeticException($"Score = '{score_k}' Cost = Nan at index '{i}' word '{outputSnts[k][i]}', Output Sentence = '{String.Join(" ", outputSnts[k])}'");
                            }
                            else
                            {
                                cost += lcost;
                            }
                        }

                        probs_k.SetWeightAt(score_k - 1, ix_targets_k);
                        ix_inputs[k] = ix_targets_k;
                    }
                }

                eOutput.CopyWeightsToGradients(probs);
            }
            //if (isTraining)
            //{
            //    //Calculate loss for each word in the batch
            //    int[] targetIds = new int[batchSize];
            //    int ids = 0;
            //    for (int k = 0; k < batchSize; k++)
            //    {
            //        int targetsId_k = m_modelMetaData.Vocab.GetTargetWordIndex(outputSnts[k][i]);
            //        targetIds[ids] = i < originalOutputLengths[k] ? targetsId_k : -1;
            //        ix_inputs[k] = targetsId_k;
            //        ids++;
            //    }

            //    cost += g.UpdateCost(probs, targetIds);
            //    eOutput.CopyWeightsToGradients(probs);
            //}
            else
            {
                // Output the i-th target word
                int[] targetIdx = g.Argmax(probs, 1);
                List<string> targetWords = m_modelMetaData.Vocab.ConvertTargetIdsToString(targetIdx.ToList());
                for (int j = 0; j < targetWords.Count; j++)
                {
                    if (setEndSentId.Contains(j) == false)
                    {
                        outputSnts[j].Add(targetWords[j]);
                        if (targetWords[j] == ParallelCorpus.EOS)
                        {
                            setEndSentId.Add(j);
                        }
                    }
                }

                if (setEndSentId.Count == batchSize)
                {
                    // All target sentences in the current batch are finished, so we exit.
                    break;
                }

                ix_inputs = targetIdx;
            }
        }
    }

    return cost;
}
private float DecodeTransformer(List<List<string>> tgtSeqs, IComputeGraph g, IWeightTensor encOutputs, TransformerDecoder decoder, IWeightTensor tgtEmbedding, IWeightTensor posEmbedding, int batchSize, int deviceId, List<int> srcOriginalLengths, bool isTraining = true)
{
    float cost = 0.0f;

    var tgtOriginalLengths = ParallelCorpus.PadSentences(tgtSeqs);
    int tgtSeqLen = tgtSeqs[0].Count;
    int srcSeqLen = encOutputs.Rows / batchSize;

    using (IWeightTensor srcTgtMask = MaskUtils.BuildSrcTgtMask(g, srcSeqLen, tgtSeqLen, tgtOriginalLengths, srcOriginalLengths, deviceId))
    {
        using (IWeightTensor tgtSelfTriMask = MaskUtils.BuildPadSelfTriMask(g, tgtSeqLen, tgtOriginalLengths, deviceId))
        {
            List<IWeightTensor> inputs = new List<IWeightTensor>();
            for (int i = 0; i < batchSize; i++)
            {
                for (int j = 0; j < tgtSeqLen; j++)
                {
                    int ix_targets_k = m_modelMetaData.Vocab.GetTargetWordIndex(tgtSeqs[i][j], logUnk: true);
                    var emb = g.PeekRow(tgtEmbedding, ix_targets_k, runGradients: j < tgtOriginalLengths[i] ? true : false);
                    inputs.Add(emb);
                }
            }

            IWeightTensor inputEmbs = inputs.Count > 1 ? g.ConcatRows(inputs) : inputs[0];

            inputEmbs = AddPositionEmbedding(g, posEmbedding, batchSize, tgtSeqLen, inputEmbs);

            IWeightTensor decOutput = decoder.Decode(inputEmbs, encOutputs, tgtSelfTriMask, srcTgtMask, batchSize, g);

            using (IWeightTensor probs = g.Softmax(decOutput, runGradients: false, inPlace: true))
            {
                if (isTraining)
                {
                    var leftShiftInputSeqs = ParallelCorpus.LeftShiftSnts(tgtSeqs, ParallelCorpus.EOS);
                    for (int i = 0; i < batchSize; i++)
                    {
                        for (int j = 0; j < tgtSeqLen; j++)
                        {
                            using (IWeightTensor probs_i_j = g.PeekRow(probs, i * tgtSeqLen + j, runGradients: false))
                            {
                                if (j < tgtOriginalLengths[i])
                                {
                                    int ix_targets_i_j = m_modelMetaData.Vocab.GetTargetWordIndex(leftShiftInputSeqs[i][j], logUnk: true);
                                    float score_i_j = probs_i_j.GetWeightAt(ix_targets_i_j);

                                    cost += (float)-Math.Log(score_i_j);

                                    probs_i_j.SetWeightAt(score_i_j - 1, ix_targets_i_j);
                                }
                                else
                                {
                                    probs_i_j.CleanWeight();
                                }
                            }
                        }
                    }

                    decOutput.CopyWeightsToGradients(probs);
                }
                //if (isTraining)
                //{
                //    var leftShiftInputSeqs = ParallelCorpus.LeftShiftSnts(tgtSeqs, ParallelCorpus.EOS);
                //    int[] targetIds = new int[batchSize * tgtSeqLen];
                //    int ids = 0;
                //    for (int i = 0; i < batchSize; i++)
                //    {
                //        for (int j = 0; j < tgtSeqLen; j++)
                //        {
                //            targetIds[ids] = j < tgtOriginalLengths[i] ? m_modelMetaData.Vocab.GetTargetWordIndex(leftShiftInputSeqs[i][j], logUnk: true) : -1;
                //            ids++;
                //        }
                //    }

                //    cost += g.UpdateCost(probs, targetIds);
                //    decOutput.CopyWeightsToGradients(probs);
                //}
                else
                {
                    // Append the word predicted at the last position of each sentence in the batch
                    int[] targetIdx = g.Argmax(probs, 1);
                    List<string> targetWords = m_modelMetaData.Vocab.ConvertTargetIdsToString(targetIdx.ToList());

                    for (int i = 0; i < batchSize; i++)
                    {
                        tgtSeqs[i].Add(targetWords[i * tgtSeqLen + tgtSeqLen - 1]);
                    }
                }
            }
        }
    }

    return cost;
}
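// Note: in the training branch above, the decoder input at position j is the target token at
// position j, while the label it learns to predict comes from the same sequence shifted one
// position to the left, with EOS filling the final slot (teacher forcing). That is what
// ParallelCorpus.LeftShiftSnts provides. The helper below is a minimal, self-contained sketch of
// that shift with a hypothetical name; it is not the framework implementation.
private static List<List<string>> LeftShiftWithEos(List<List<string>> seqs, string eos)
{
    var shifted = new List<List<string>>();
    foreach (var seq in seqs)
    {
        // Drop the first token and append EOS, e.g. ["<s>", "a", "b"] -> ["a", "b", "</s>"]
        var s = new List<string>(seq.Skip(1)) { eos };
        shifted.Add(s);
    }
    return shifted;
}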
/// <summary>
/// Run forward part on given single device
/// </summary>
/// <param name="computeGraph">The computing graph for the current device. It gets created and passed by the framework</param>
/// <param name="srcSnts">A batch of tokenized input sentences on the source side</param>
/// <param name="tgtSnts">A batch of tokenized output sentences on the target side</param>
/// <param name="deviceIdIdx">The index of the current device</param>
/// <returns>The cost of the forward part</returns>
private float RunForwardOnSingleDevice(IComputeGraph computeGraph, List<List<string>> srcSnts, List<List<string>> tgtSnts, int deviceIdIdx, bool isTraining)
{
    (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding, IWeightTensor posEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

    // Reset networks
    encoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);
    decoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);

    List<int> originalSrcLengths = ParallelCorpus.PadSentences(srcSnts);
    int srcSeqPaddedLen = srcSnts[0].Count;
    int batchSize = srcSnts.Count;

    // If all source sentences in a mini-batch have the same length (ShuffleEnums.NoPaddingInSrc), no source self-attention mask is needed; otherwise build a padding mask.
    IWeightTensor srcSelfMask = m_shuffleType == ShuffleEnums.NoPaddingInSrc ? null : MaskUtils.BuildPadSelfMask(computeGraph, srcSeqPaddedLen, originalSrcLengths, DeviceIds[deviceIdIdx]);

    // Encoding input source sentences
    IWeightTensor encOutput = Encode(computeGraph, srcSnts, encoder, srcEmbedding, srcSelfMask, posEmbedding, originalSrcLengths);

    if (srcSelfMask != null)
    {
        srcSelfMask.Dispose();
    }

    // Generate output decoder sentences
    if (decoder is AttentionDecoder)
    {
        return DecodeAttentionLSTM(tgtSnts, computeGraph, encOutput, decoder as AttentionDecoder, tgtEmbedding, srcSnts.Count, isTraining);
    }
    else
    {
        if (isTraining)
        {
            return DecodeTransformer(tgtSnts, computeGraph, encOutput, decoder as TransformerDecoder, tgtEmbedding, posEmbedding, batchSize, DeviceIds[deviceIdIdx], originalSrcLengths, isTraining);
        }
        else
        {
            for (int i = 0; i < m_maxTgtSntSize; i++)
            {
                using (var g = computeGraph.CreateSubGraph($"TransformerDecoder_Step_{i}"))
                {
                    DecodeTransformer(tgtSnts, g, encOutput, decoder as TransformerDecoder, tgtEmbedding, posEmbedding, batchSize, DeviceIds[deviceIdIdx], originalSrcLengths, isTraining);

                    bool allSntsEnd = true;
                    for (int j = 0; j < tgtSnts.Count; j++)
                    {
                        if (tgtSnts[j][tgtSnts[j].Count - 1] != ParallelCorpus.EOS)
                        {
                            allSntsEnd = false;
                            break;
                        }
                    }

                    if (allSntsEnd)
                    {
                        break;
                    }
                }
            }

            RemoveDuplicatedEOS(tgtSnts);
            return 0.0f;
        }
    }
}
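// Note: MaskUtils.BuildPadSelfMask above produces a mask so that self-attention over a padded
// mini-batch ignores the padded positions of each sentence. The sketch below shows the idea on
// plain boolean arrays (hypothetical helper, not the framework's tensor-based implementation):
// any query position may attend only to key positions that fall inside the sentence's original,
// unpadded length.
private static bool[][,] BuildPadSelfMaskSketch(int paddedLen, List<int> originalLengths)
{
    var masks = new bool[originalLengths.Count][,];
    for (int b = 0; b < originalLengths.Count; b++)
    {
        var mask = new bool[paddedLen, paddedLen];
        for (int i = 0; i < paddedLen; i++)
        {
            for (int j = 0; j < paddedLen; j++)
            {
                // true = query position i may attend to key position j; padded keys are blocked
                mask[i, j] = j < originalLengths[b];
            }
        }
        masks[b] = mask;
    }
    return masks;
}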
private static void Main(string[] args)
{
    ShowOptions(args);
    Logger.LogFile = $"{nameof(SeqLabelConsole)}_{GetTimeStamp(DateTime.Now)}.log";

    // Parse command line
    var opts = new Options();
    var argParser = new ArgParser(args, opts);

    if (string.IsNullOrEmpty(opts.ConfigFilePath) == false)
    {
        Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
        opts = JsonConvert.DeserializeObject<Options>(File.ReadAllText(opts.ConfigFilePath));
    }

    SequenceLabel sl = null;
    var processorType = (ProcessorTypeEnums)Enum.Parse(typeof(ProcessorTypeEnums), opts.ProcessorType);
    var encoderType = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), opts.EncoderType);
    var mode = (ModeEnums)Enum.Parse(typeof(ModeEnums), opts.TaskName);

    // Parse device ids from options
    var deviceIds = opts.DeviceIds.Split(',').Select(int.Parse).ToArray();

    if (mode == ModeEnums.Train)
    {
        // Load train corpus
        var trainCorpus = new SequenceLabelingCorpus(opts.TrainCorpusPath, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);

        // Load valid corpus
        var validCorpus = string.IsNullOrEmpty(opts.ValidCorpusPath) ? null : new SequenceLabelingCorpus(opts.ValidCorpusPath, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);

        // Load or build vocabulary
        Vocab vocab = null;
        if (!string.IsNullOrEmpty(opts.SrcVocab) && !string.IsNullOrEmpty(opts.TgtVocab))
        {
            // Vocabulary files are specified, so we load them
            vocab = new Vocab(opts.SrcVocab, opts.TgtVocab);
        }
        else
        {
            // We don't specify vocabulary, so we build it from train corpus
            vocab = new Vocab(trainCorpus);
        }

        // Create learning rate
        ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

        // Create optimizer
        var optimizer = new AdamOptimizer(opts.GradClip, opts.Beta1, opts.Beta2);

        // Create metrics
        var metrics = new List<IMetric>();
        foreach (var word in vocab.TgtVocab)
        {
            metrics.Add(new SequenceLabelFscoreMetric(word));
        }

        if (File.Exists(opts.ModelFilePath) == false)
        {
            // New training
            sl = new SequenceLabel(opts.HiddenSize, opts.WordVectorSize, opts.EncoderLayerDepth, opts.MultiHeadNum, encoderType, opts.DropoutRatio,
                                   deviceIds: deviceIds, processorType: processorType, modelFilePath: opts.ModelFilePath, vocab: vocab, maxSntSize: opts.MaxSentLength);
        }
        else
        {
            // Incremental training
            Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
            sl = new SequenceLabel(opts.ModelFilePath, processorType, deviceIds, opts.DropoutRatio, opts.MaxSentLength);
        }

        // Add event handler for monitoring
        sl.IterationDone += ss_IterationDone;

        // Kick off training
        sl.Train(opts.MaxEpochNum, trainCorpus, validCorpus, learningRate, optimizer: optimizer, metrics: metrics);
    }
    else if (mode == ModeEnums.Valid)
    {
        Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPath}'");

        // Load valid corpus
        var validCorpus = new SequenceLabelingCorpus(opts.ValidCorpusPath, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);
        var vocab = new Vocab(validCorpus);

        // Create metrics
        var metrics = new List<IMetric>();
        foreach (var word in vocab.TgtVocab)
        {
            metrics.Add(new SequenceLabelFscoreMetric(word));
        }

        sl = new SequenceLabel(opts.ModelFilePath, processorType, deviceIds, maxSntSize: opts.MaxSentLength);
        sl.Valid(validCorpus, metrics);
    }
    else if (mode == ModeEnums.Test)
    {
        Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

        // Test trained model
        sl = new SequenceLabel(opts.ModelFilePath, processorType, deviceIds, maxSntSize: opts.MaxSentLength);
        var outputLines = new List<string>();
        var data_sents_raw1 = File.ReadAllLines(opts.InputTestFile);
        foreach (var line in data_sents_raw1)
        {
            var outputTokensBatch = sl.Test(ParallelCorpus.ConstructInputTokens(line.ToLower().Trim().Split(' ').ToList(), false));
            outputLines.AddRange(outputTokensBatch.Select(x => string.Join(" ", x)));
        }

        File.WriteAllLines(opts.OutputTestFile, outputLines);
    }
    else
    {
        argParser.Usage();
    }
}