/// <summary>
/// Builds one encoder pass and one decoder step on a compute graph created with
/// <c>visNetwork: true</c>, then writes that graph to <paramref name="visNNFilePath"/>
/// through <c>VisualizeNeuralNetToFile</c>.
/// </summary>
/// <param name="visNNFilePath">Output path handed to <c>VisualizeNeuralNetToFile</c>.</param>
/// <exception cref="System.InvalidOperationException">
/// The decoder returned by <c>GetNetworksOnDeviceAt</c> is not an <see cref="AttentionDecoder"/>.
/// </exception>
public void VisualizeNeuralNetwork(string visNNFilePath)
{
    (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding) = GetNetworksOnDeviceAt(-1);

    // Only the attention decoder can be driven by the steps below. Check the type up front
    // so a mismatch fails with a clear message instead of a NullReferenceException raised
    // after the encoder has already been reset.
    if (!(decoder is AttentionDecoder rnnDecoder))
    {
        throw new System.InvalidOperationException(
            $"{nameof(VisualizeNeuralNetwork)} requires an {nameof(AttentionDecoder)}, but the decoder is '{decoder?.GetType().Name ?? "null"}'.");
    }

    // Build input sentence
    List<List<string>> inputSeqs = ParallelCorpus.ConstructInputTokens(null);
    int batchSize = inputSeqs.Count;

    IComputeGraph g = CreateComputGraph(m_defaultDeviceId, needBack: false, visNetwork: true);

    encoder.Reset(g.GetWeightFactory(), batchSize);
    rnnDecoder.Reset(g.GetWeightFactory(), batchSize);

    // Run encoder
    IWeightTensor encodedWeightMatrix = Encode(g, inputSeqs, encoder, srcEmbedding, null, null);

    // Prepare for attention over encoder-decoder
    AttentionPreProcessResult attPreProcessResult = rnnDecoder.PreProcess(encodedWeightMatrix, batchSize, g);

    // Run decoder for a single step, starting from the START tag embedding
    IWeightTensor x = g.PeekRow(tgtEmbedding, (int)SENTTAGS.START);
    IWeightTensor eOutput = rnnDecoder.Decode(x, attPreProcessResult, batchSize, g);

    // The result is not read afterwards; the call is kept because it presumably adds the
    // softmax operation to the graph being visualized (TODO confirm against IComputeGraph).
    IWeightTensor probs = g.Softmax(eOutput);

    g.VisualizeNeuralNetToFile(visNNFilePath);
}
/// <summary>
/// Runs the encoder once and the decoder for a single step on a compute graph created with
/// <c>visNetwork: true</c>, then writes the graph to <paramref name="visNNFilePath"/>.
/// The encoder runs on an "Encoder" sub-graph; the decoder, feed-forward layer and softmax
/// run on a "Decoder" sub-graph, and the file is written from that "Decoder" sub-graph.
/// </summary>
/// <param name="visNNFilePath">Output path handed to <c>VisualizeNeuralNetToFile</c>.</param>
public void VisualizeNeuralNetwork(string visNNFilePath)
{
    (IEncoder encoder, AttentionDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding, FeedForwardLayer decoderFFLayer) = GetNetworksOnDeviceAt(-1);

    // Input batch built from a null sentence; its size drives every Reset/Decode call below.
    var tokenBatch = ParallelCorpus.ConstructInputTokens(null);
    int batchSize = tokenBatch.Count;

    var graph = CreateComputGraph(m_defaultDeviceId, needBack: false, visNetwork: true);
    encoder.Reset(graph.GetWeightFactory(), batchSize);
    decoder.Reset(graph.GetWeightFactory(), batchSize);

    // Encoder stage, recorded under its own sub-graph.
    var encoderGraph = graph.CreateSubGraph("Encoder");
    IWeightTensor encoded = Encode(encoderGraph, tokenBatch, encoder, srcEmbedding);

    // From here on every operation goes through the "Decoder" sub-graph.
    graph = graph.CreateSubGraph("Decoder");

    // Attention pre-processing over the encoder output.
    var attentionInput = decoder.PreProcess(encoded, batchSize, graph);

    // One decoder step fed with the START tag embedding, followed by the output layer.
    var decoderOutput = decoder.Decode(graph.PeekRow(tgtEmbedding, (int)SENTTAGS.START), attentionInput, batchSize, graph);
    var projected = decoderFFLayer.Process(decoderOutput, batchSize, graph);

    // The softmax result itself is never read afterwards.
    var softmaxOutput = graph.Softmax(projected);

    graph.VisualizeNeuralNetToFile(visNNFilePath);
}
/// <summary>
/// Generates output sentences for one tokenized input sentence with the seq2seq model,
/// using beam search over the attention decoder.
/// </summary>
/// <param name="input">Tokens of the source sentence; forwarded to <c>ParallelCorpus.ConstructInputTokens</c>.</param>
/// <param name="beamSearchSize">
/// Number of top-weighted next tokens expanded for each unfinished hypothesis per step; the same
/// value is passed to <c>BeamSearch.GetTopNBSS</c> when the hypothesis list is pruned.
/// </param>
/// <param name="maxOutputLength">Upper bound on the number of decoding steps.</param>
/// <returns>
/// One token list per hypothesis left after the final pruning, converted with
/// <c>m_modelMetaData.Vocab.ConvertTargetIdsToString</c>. The order is whatever
/// <c>BeamSearch.GetTopNBSS</c> returns.
/// </returns>
/// <exception cref="InvalidOperationException">
/// The decoder returned by <c>GetNetworksOnDeviceAt</c> is not an <see cref="AttentionDecoder"/>.
/// </exception>
public List<List<string>> Predict(List<string> input, int beamSearchSize = 1, int maxOutputLength = 100)
{
    (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding) = GetNetworksOnDeviceAt(-1);

    // Beam search below relies on AttentionDecoder members (GetCTs/SetCTs, GetHTs/SetHTs, PreProcess).
    // Check the type up front so a mismatch fails with a clear message instead of a
    // NullReferenceException from an unchecked `as` cast.
    if (!(decoder is AttentionDecoder rnnDecoder))
    {
        throw new InvalidOperationException(
            $"{nameof(Predict)} requires an {nameof(AttentionDecoder)}, but the decoder is '{decoder?.GetType().Name ?? "null"}'.");
    }

    List<List<string>> inputSeqs = ParallelCorpus.ConstructInputTokens(input);
    int batchSize = 1; // For predict with beam search, we currently only support one sentence per call

    IComputeGraph g = CreateComputGraph(m_defaultDeviceId, needBack: false);

    encoder.Reset(g.GetWeightFactory(), batchSize);
    rnnDecoder.Reset(g.GetWeightFactory(), batchSize);

    // Construct beam search status list: a single hypothesis holding only the START tag
    // and the decoder state right after Reset.
    List<BeamSearchStatus> bssList = new List<BeamSearchStatus>();
    BeamSearchStatus bss = new BeamSearchStatus();
    bss.OutputIds.Add((int)SENTTAGS.START);
    bss.CTs = rnnDecoder.GetCTs();
    bss.HTs = rnnDecoder.GetHTs();
    bssList.Add(bss);

    IWeightTensor encodedWeightMatrix = Encode(g, inputSeqs, encoder, srcEmbedding, null, null);
    AttentionPreProcessResult attPreProcessResult = rnnDecoder.PreProcess(encodedWeightMatrix, batchSize, g);

    List<BeamSearchStatus> newBSSList = new List<BeamSearchStatus>();
    bool finished = false;
    int outputLength = 0;
    while (!finished && outputLength < maxOutputLength)
    {
        // Stays true only if no hypothesis was expanded during this step.
        finished = true;
        for (int i = 0; i < bssList.Count; i++)
        {
            bss = bssList[i];
            int lastId = bss.OutputIds[bss.OutputIds.Count - 1];

            // Hypotheses that already emitted END, or that exceed the length limit,
            // are carried over unchanged.
            if (lastId == (int)SENTTAGS.END || bss.OutputIds.Count > maxOutputLength)
            {
                newBSSList.Add(bss);
                continue;
            }

            finished = false;

            // Restore the decoder state that belongs to this hypothesis before decoding one more step.
            rnnDecoder.SetCTs(bss.CTs);
            rnnDecoder.SetHTs(bss.HTs);

            IWeightTensor x = g.PeekRow(tgtEmbedding, lastId);
            IWeightTensor eOutput = rnnDecoder.Decode(x, attPreProcessResult, batchSize, g);
            using (IWeightTensor probs = g.Softmax(eOutput))
            {
                List<int> preds = probs.GetTopNMaxWeightIdx(beamSearchSize);
                for (int j = 0; j < preds.Count; j++)
                {
                    BeamSearchStatus newBSS = new BeamSearchStatus();
                    newBSS.OutputIds.AddRange(bss.OutputIds);
                    newBSS.OutputIds.Add(preds[j]);

                    newBSS.CTs = rnnDecoder.GetCTs();
                    newBSS.HTs = rnnDecoder.GetHTs();

                    // Score is the accumulated negative log of the weight of every chosen token.
                    float score = probs.GetWeightAt(preds[j]);
                    newBSS.Score = bss.Score;
                    newBSS.Score += (float)(-Math.Log(score));

                    // Length penalty is currently disabled; the formula that was tried is kept for reference:
                    //var lengthPenalty = Math.Pow((5.0f + newBSS.OutputIds.Count) / 6, 0.6);
                    //newBSS.Score /= (float)lengthPenalty;

                    newBSSList.Add(newBSS);
                }
            }
        }

        bssList = BeamSearch.GetTopNBSS(newBSSList, beamSearchSize);
        newBSSList.Clear();

        outputLength++;
    }

    // Convert output target word ids to real string
    List<List<string>> results = new List<List<string>>();
    for (int i = 0; i < bssList.Count; i++)
    {
        results.Add(m_modelMetaData.Vocab.ConvertTargetIdsToString(bssList[i].OutputIds));
    }

    return results;
}
/// <summary>
/// Entry point of the Seq2Seq console. Assigns the log file name, reads the options from the
/// command line (replaced entirely by a JSON config file when <c>ConfigFilePath</c> is set),
/// and then runs the task named by <c>TaskName</c>: Train, Valid, Test or DumpVocab.
/// Any other mode prints the usage text. Every exception is caught here and written to the
/// log as a message plus a call stack.
/// </summary>
/// <param name="args">Raw command line arguments.</param>
private static void Main(string[] args)
{
    try
    {
        Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{GetTimeStamp(DateTime.Now)}.log";
        ShowOptions(args);

        // Command line first; a config file, when given, supersedes it.
        Options opts = new Options();
        ArgParser argParser = new ArgParser(args, opts);
        if (!string.IsNullOrEmpty(opts.ConfigFilePath))
        {
            Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
            string configJson = File.ReadAllText(opts.ConfigFilePath);
            opts = JsonConvert.DeserializeObject<Options>(configJson);
        }

        // Translate the textual options into their enum counterparts.
        ProcessorTypeEnums processorType = (ProcessorTypeEnums)Enum.Parse(typeof(ProcessorTypeEnums), opts.ProcessorType);
        EncoderTypeEnums encoderType = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), opts.EncoderType);
        DecoderTypeEnums decoderType = (DecoderTypeEnums)Enum.Parse(typeof(DecoderTypeEnums), opts.DecoderType);
        ModeEnums mode = (ModeEnums)Enum.Parse(typeof(ModeEnums), opts.TaskName);
        ShuffleEnums shuffleType = (ShuffleEnums)Enum.Parse(typeof(ShuffleEnums), opts.ShuffleType);

        // Space separated compiler options; null when none were supplied.
        string[] cudaCompilerOptions = null;
        if (!String.IsNullOrEmpty(opts.CompilerOptions))
        {
            cudaCompilerOptions = opts.CompilerOptions.Split(' ', StringSplitOptions.RemoveEmptyEntries);
        }

        // Comma separated device ids.
        int[] deviceIds = opts.DeviceIds.Split(',').Select(id => int.Parse(id)).ToArray();

        switch (mode)
        {
            case ModeEnums.Train:
            {
                // Training corpus is mandatory, validation corpus is optional.
                ParallelCorpus trainCorpus = new ParallelCorpus(corpusFilePath: opts.TrainCorpusPath, srcLangName: opts.SrcLang, tgtLangName: opts.TgtLang, batchSize: opts.BatchSize, shuffleBlockSize: opts.ShuffleBlockSize,
                    maxSrcSentLength: opts.MaxSrcSentLength, maxTgtSentLength: opts.MaxTgtSentLength, shuffleEnums: shuffleType);

                ParallelCorpus validCorpus = null;
                if (!string.IsNullOrEmpty(opts.ValidCorpusPath))
                {
                    validCorpus = new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.ValBatchSize, opts.ShuffleBlockSize, opts.MaxSrcSentLength, opts.MaxTgtSentLength);
                }

                ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);
                AdamOptimizer optimizer = new AdamOptimizer(opts.GradClip, opts.Beta1, opts.Beta2);

                List<IMetric> metrics = new List<IMetric>();
                metrics.Add(new BleuMetric());
                metrics.Add(new LengthRatioMetric());

                AttentionSeq2Seq trainer;
                bool modelFileExists = !String.IsNullOrEmpty(opts.ModelFilePath) && File.Exists(opts.ModelFilePath);
                if (modelFileExists)
                {
                    // Incremental training: continue from the existing model file.
                    Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
                    trainer = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, dropoutRatio: opts.DropoutRatio, deviceIds: deviceIds,
                        isSrcEmbTrainable: opts.IsSrcEmbeddingTrainable, isTgtEmbTrainable: opts.IsTgtEmbeddingTrainable, isEncoderTrainable: opts.IsEncoderTrainable, isDecoderTrainable: opts.IsDecoderTrainable,
                        maxSrcSntSize: opts.MaxSrcSentLength, maxTgtSntSize: opts.MaxTgtSentLength, memoryUsageRatio: opts.MemoryUsageRatio, shuffleType: shuffleType, compilerOptions: cudaCompilerOptions);
                }
                else
                {
                    // New training: load the vocabulary files when both are given,
                    // otherwise build the vocabulary from the training corpus.
                    bool vocabFilesGiven = !string.IsNullOrEmpty(opts.SrcVocab) && !string.IsNullOrEmpty(opts.TgtVocab);
                    Vocab vocab = vocabFilesGiven
                        ? new Vocab(opts.SrcVocab, opts.TgtVocab)
                        : new Vocab(trainCorpus);

                    trainer = new AttentionSeq2Seq(embeddingDim: opts.WordVectorSize, hiddenDim: opts.HiddenSize, encoderLayerDepth: opts.EncoderLayerDepth, decoderLayerDepth: opts.DecoderLayerDepth,
                        srcEmbeddingFilePath: opts.SrcEmbeddingModelFilePath, tgtEmbeddingFilePath: opts.TgtEmbeddingModelFilePath, vocab: vocab, modelFilePath: opts.ModelFilePath,
                        dropoutRatio: opts.DropoutRatio, processorType: processorType, deviceIds: deviceIds, multiHeadNum: opts.MultiHeadNum, encoderType: encoderType, decoderType: decoderType,
                        maxSrcSntSize: opts.MaxSrcSentLength, maxTgtSntSize: opts.MaxTgtSentLength, enableCoverageModel: opts.EnableCoverageModel, memoryUsageRatio: opts.MemoryUsageRatio,
                        shuffleType: shuffleType, compilerOptions: cudaCompilerOptions);
                }

                // Progress monitoring, then start training.
                trainer.IterationDone += ss_IterationDone;
                trainer.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpus: validCorpus, learningRate: learningRate, optimizer: optimizer, metrics: metrics);
                break;
            }

            case ModeEnums.Valid:
            {
                Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPath}'");

                List<IMetric> metrics = new List<IMetric>();
                metrics.Add(new BleuMetric());
                metrics.Add(new LengthRatioMetric());

                ParallelCorpus validCorpus = new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.ValBatchSize, opts.ShuffleBlockSize, opts.MaxSrcSentLength, opts.MaxTgtSentLength);

                AttentionSeq2Seq evaluator = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds, memoryUsageRatio: opts.MemoryUsageRatio,
                    shuffleType: shuffleType, compilerOptions: cudaCompilerOptions);
                evaluator.Valid(validCorpus: validCorpus, metrics: metrics);
                break;
            }

            case ModeEnums.Test:
            {
                Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

                AttentionSeq2Seq translator = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds, memoryUsageRatio: opts.MemoryUsageRatio,
                    shuffleType: shuffleType, maxSrcSntSize: opts.MaxSrcSentLength, maxTgtSntSize: opts.MaxTgtSentLength, compilerOptions: cudaCompilerOptions);

                List<string> outputLines = new List<string>();
                foreach (string line in File.ReadAllLines(opts.InputTestFile))
                {
                    // Lower-case, trim and split on single spaces.
                    List<string> tokens = line.ToLower().Trim().Split(' ').ToList();

                    if (opts.BeamSearch > 1)
                    {
                        // Beam search path: every returned hypothesis becomes one output line.
                        List<List<string>> hypotheses = translator.Predict(tokens, opts.BeamSearch);
                        outputLines.AddRange(hypotheses.Select(words => string.Join(" ", words)));
                    }
                    else
                    {
                        var outputTokensBatch = translator.Test(ParallelCorpus.ConstructInputTokens(tokens));
                        outputLines.AddRange(outputTokensBatch.Select(words => String.Join(" ", words)));
                    }
                }

                File.WriteAllLines(opts.OutputTestFile, outputLines);
                break;
            }

            case ModeEnums.DumpVocab:
            {
                AttentionSeq2Seq loaded = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds, compilerOptions: cudaCompilerOptions);
                loaded.DumpVocabToFiles(opts.SrcVocab, opts.TgtVocab);
                break;
            }

            default:
                argParser.Usage();
                break;
        }
    }
    catch (Exception err)
    {
        Logger.WriteLine($"Exception: '{err.Message}'");
        Logger.WriteLine($"Call stack: '{err.StackTrace}'");
    }
}
/// <summary>
/// Entry point of the sequence labeling console. Reads the options from the command line
/// (replaced entirely by a JSON config file when <c>ConfigFilePath</c> is set) and runs the
/// task named by <c>TaskName</c>: Train, Valid or Test. Any other mode prints the usage text.
/// Exceptions are not caught here.
/// </summary>
/// <param name="args">Raw command line arguments.</param>
static void Main(string[] args)
{
    // NOTE(review): ShowOptions runs before the log file name is assigned. If it writes
    // through Logger, that output may not end up in the named log file - confirm the order is intended.
    ShowOptions(args);
    Logger.LogFile = $"{nameof(SeqLabelConsole)}_{GetTimeStamp(DateTime.Now)}.log";

    // Command line first; a config file, when given, supersedes it.
    Options opts = new Options();
    ArgParser argParser = new ArgParser(args, opts);
    if (!String.IsNullOrEmpty(opts.ConfigFilePath))
    {
        Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
        string configJson = File.ReadAllText(opts.ConfigFilePath);
        opts = JsonConvert.DeserializeObject<Options>(configJson);
    }

    // Translate the textual options into their enum counterparts.
    ProcessorTypeEnums processorType = (ProcessorTypeEnums)Enum.Parse(typeof(ProcessorTypeEnums), opts.ProcessorType);
    EncoderTypeEnums encoderType = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), opts.EncoderType);
    ModeEnums mode = (ModeEnums)Enum.Parse(typeof(ModeEnums), opts.TaskName);

    // Comma separated device ids.
    int[] deviceIds = opts.DeviceIds.Split(',').Select(id => int.Parse(id)).ToArray();

    switch (mode)
    {
        case ModeEnums.Train:
        {
            // Training corpus is mandatory, validation corpus is optional.
            ParallelCorpus trainCorpus = new ParallelCorpus(opts.TrainCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength, addBOSEOS: false);

            ParallelCorpus validCorpus = null;
            if (!String.IsNullOrEmpty(opts.ValidCorpusPath))
            {
                validCorpus = new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength, addBOSEOS: false);
            }

            // Load the vocabulary files when both are given, otherwise build it from the training corpus.
            bool vocabFilesGiven = !String.IsNullOrEmpty(opts.SrcVocab) && !String.IsNullOrEmpty(opts.TgtVocab);
            Vocab vocab = vocabFilesGiven
                ? new Vocab(opts.SrcVocab, opts.TgtVocab)
                : new Vocab(trainCorpus);

            ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);
            AdamOptimizer optimizer = new AdamOptimizer(opts.GradClip, opts.Beta1, opts.Beta2);

            // One F-score metric per entry of the target vocabulary.
            List<IMetric> metrics = new List<IMetric>();
            foreach (var label in vocab.TgtVocab)
            {
                metrics.Add(new SequenceLabelFscoreMetric(label));
            }

            SequenceLabel trainer;
            if (File.Exists(opts.ModelFilePath))
            {
                // Incremental training: continue from the existing model file.
                Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
                trainer = new SequenceLabel(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds, dropoutRatio: opts.DropoutRatio);
            }
            else
            {
                // New training.
                trainer = new SequenceLabel(hiddenDim: opts.HiddenSize, embeddingDim: opts.WordVectorSize, encoderLayerDepth: opts.EncoderLayerDepth, multiHeadNum: opts.MultiHeadNum,
                    encoderType: encoderType, dropoutRatio: opts.DropoutRatio, deviceIds: deviceIds, processorType: processorType, modelFilePath: opts.ModelFilePath, vocab: vocab);
            }

            // Progress monitoring, then start training.
            trainer.IterationDone += ss_IterationDone;
            trainer.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpus: validCorpus, learningRate: learningRate, optimizer: optimizer, metrics: metrics);
            break;
        }

        case ModeEnums.Valid:
        {
            Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPath}'");

            ParallelCorpus validCorpus = new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength, false);

            // The metric labels come from a vocabulary built out of the validation corpus itself.
            Vocab vocab = new Vocab(validCorpus);
            List<IMetric> metrics = new List<IMetric>();
            foreach (var label in vocab.TgtVocab)
            {
                metrics.Add(new SequenceLabelFscoreMetric(label));
            }

            SequenceLabel evaluator = new SequenceLabel(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);
            evaluator.Valid(validCorpus: validCorpus, metrics: metrics);
            break;
        }

        case ModeEnums.Test:
        {
            Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

            SequenceLabel tagger = new SequenceLabel(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);

            List<string> outputLines = new List<string>();
            foreach (string line in File.ReadAllLines(opts.InputTestFile))
            {
                // Lower-case, trim and split on single spaces.
                List<string> tokens = line.ToLower().Trim().Split(' ').ToList();
                var outputTokensBatch = tagger.Test(ParallelCorpus.ConstructInputTokens(tokens, false));
                outputLines.AddRange(outputTokensBatch.Select(words => String.Join(" ", words)));
            }

            File.WriteAllLines(opts.OutputTestFile, outputLines);
            break;
        }

        // NOTE: a VisualizeNetwork mode used to sit here as commented-out code that referenced
        // AttentionSeq2Seq; this console currently has no such mode.
        default:
            argParser.Usage();
            break;
    }
}