public BiRNN(int modeltype)
{
    if (modeltype == 0)
    {
        // modeltype == 0 builds the bi-directional network from two simple RNNs
        // with truncated BPTT (5 steps, block size 30); otherwise two LSTMs are used.
        SimpleRNN s_forwardRNN = new SimpleRNN();
        SimpleRNN s_backwardRNN = new SimpleRNN();

        s_forwardRNN.setBPTT(4 + 1);
        s_forwardRNN.setBPTTBlock(30);
        s_backwardRNN.setBPTT(4 + 1);
        s_backwardRNN.setBPTTBlock(30);

        forwardRNN = s_forwardRNN;
        backwardRNN = s_backwardRNN;
    }
    else
    {
        forwardRNN = new LSTMRNN();
        backwardRNN = new LSTMRNN();
    }

    m_modeldirection = MODELDIRECTION.BI_DIRECTIONAL;
}
public BiRNN(RNN s_forwardRNN, RNN s_backwardRNN)
{
    forwardRNN = s_forwardRNN;
    backwardRNN = s_backwardRNN;
    ModelType = forwardRNN.ModelType;
    ModelDirection = MODELDIRECTION.BI_DIRECTIONAL;
}
private RNN<Sequence> GetRNNInstance()
{
    RNN<Sequence> r = null;

    // Spin until another thread returns an instance to the pool.
    while (qRNNs.TryDequeue(out r) == false)
    {
        Thread.Yield();
    }

    return r;
}
private RNN<Sequence> GetRNNInstance()
{
    RNN<Sequence> r = null;

    // If the pool is empty, clone a fresh instance instead of blocking.
    if (qRNNs.TryDequeue(out r) == false)
    {
        r = rnn.Clone();
    }

    return r;
}
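The two variants above embody different pool policies: the first spin-waits until a decoder instance is returned, while the second clones a new one on a miss, trading memory for lower latency. A minimal standalone sketch of the clone-on-miss policy as a generic pool; the InstancePool type and its members are illustrative, not part of this codebase:

using System;
using System.Collections.Concurrent;

// Hypothetical generic pool illustrating the clone-on-miss policy used above.
public sealed class InstancePool<T>
{
    private readonly ConcurrentQueue<T> queue = new ConcurrentQueue<T>();
    private readonly Func<T> factory;

    public InstancePool(Func<T> factory)
    {
        this.factory = factory;
    }

    public T Rent()
    {
        // Fall back to the factory when no pooled instance is available.
        T item;
        return queue.TryDequeue(out item) ? item : factory();
    }

    public void Return(T item)
    {
        queue.Enqueue(item);
    }
}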
public RNNDecoder(Config config)
{
    Config = config;

    if (Config.ModelDirection == MODELDIRECTION.BiDirectional)
    {
        Logger.WriteLine("Model Structure: Bi-directional RNN");
        rnn = new BiRNN<Sequence>();
    }
    else
    {
        Logger.WriteLine("Model Structure: Simple RNN");
        rnn = new ForwardRNN<Sequence>();
    }

    rnn.LoadModel(config.ModelFilePath);
    Logger.WriteLine("CRF Model: {0}", rnn.IsCRFTraining);
}
public RNNDecoder(Config config)
{
    Config = config;

    RNN<Sequence> rnn = RNN<Sequence>.CreateRNN(Config.NetworkType);
    rnn.LoadModel(config.ModelFilePath);
    rnn.MaxSeqLength = config.MaxSequenceLength;

    Logger.WriteLine("CRF Model: {0}", rnn.IsCRFTraining);
    Logger.WriteLine($"Max Sequence Length: {rnn.MaxSeqLength}");
    Logger.WriteLine($"Processor Count: {Environment.ProcessorCount}");

    // Pre-populate the pool with one clone per logical processor, so
    // concurrent decoding requests never share model state.
    qRNNs = new ConcurrentQueue<RNN<Sequence>>();
    for (var i = 0; i < Environment.ProcessorCount; i++)
    {
        qRNNs.Enqueue(rnn.Clone());
    }
}
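A hedged usage sketch for this constructor, assuming Config exposes settable properties for the fields referenced above; the NETWORKTYPE value and model path are placeholders:

// Illustrative only: the enum value and file path are assumptions.
var config = new Config
{
    NetworkType = NETWORKTYPE.BiDirectional,
    ModelFilePath = "seqlabel.model",
    MaxSequenceLength = 256
};
var decoder = new RNNDecoder(config);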
public RNNDecoder(string strModelFileName, Featurizer featurizer)
{
    MODELDIRECTION modelDir = MODELDIRECTION.FORWARD;
    RNNHelper.CheckModelFileType(strModelFileName, out modelDir);

    if (modelDir == MODELDIRECTION.BI_DIRECTIONAL)
    {
        Logger.WriteLine("Model Structure: Bi-directional RNN");
        m_Rnn = new BiRNN();
    }
    else
    {
        Logger.WriteLine("Model Structure: Simple RNN");
        m_Rnn = new ForwardRNN();
    }

    m_Rnn.LoadModel(strModelFileName);
    Logger.WriteLine("CRF Model: {0}", m_Rnn.IsCRFTraining);
    m_Featurizer = featurizer;
}
public RNNDecoder(string strModelFileName)
{
    MODELDIRECTION modelDir = MODELDIRECTION.FORWARD;
    MODELTYPE modelType;
    RNNHelper.CheckModelFileType(strModelFileName, out modelDir, out modelType);

    if (modelDir == MODELDIRECTION.BI_DIRECTIONAL)
    {
        Logger.WriteLine("Model Structure: Bi-directional RNN");
        rnn = new BiRNN<Sequence>();
    }
    else
    {
        Logger.WriteLine("Model Structure: Simple RNN");
        rnn = new ForwardRNN<Sequence>();
    }

    ModelType = modelType;
    rnn.LoadModel(strModelFileName);
    Logger.WriteLine("CRF Model: {0}", rnn.IsCRFTraining);
}
public RNNDecoder(string strModelFileName, Featurizer featurizer)
{
    MODELTYPE model_type = RNN.CheckModelFileType(strModelFileName);

    if (model_type == MODELTYPE.SIMPLE)
    {
        Console.WriteLine("Model Structure: Simple RNN");
        m_Rnn = new SimpleRNN();
    }
    else
    {
        Console.WriteLine("Model Structure: LSTM-RNN");
        m_Rnn = new LSTMRNN();
    }

    m_Rnn.loadNetBin(strModelFileName);
    Console.WriteLine("CRF Model: {0}", m_Rnn.IsCRFModel());
    m_Featurizer = featurizer;
}
public RNNDecoder(string strModelFileName, Featurizer featurizer)
{
    MODELTYPE modelType = MODELTYPE.SIMPLE;
    MODELDIRECTION modelDir = MODELDIRECTION.FORWARD;
    RNN.CheckModelFileType(strModelFileName, out modelType, out modelDir);

    if (modelDir == MODELDIRECTION.BI_DIRECTIONAL)
    {
        Logger.WriteLine(Logger.Level.info, "Model Structure: Bi-directional RNN");
        if (modelType == MODELTYPE.SIMPLE)
        {
            m_Rnn = new BiRNN(new SimpleRNN(), new SimpleRNN());
        }
        else
        {
            m_Rnn = new BiRNN(new LSTMRNN(), new LSTMRNN());
        }
    }
    else
    {
        if (modelType == MODELTYPE.SIMPLE)
        {
            Logger.WriteLine(Logger.Level.info, "Model Structure: Simple RNN");
            m_Rnn = new SimpleRNN();
        }
        else
        {
            Logger.WriteLine(Logger.Level.info, "Model Structure: LSTM-RNN");
            m_Rnn = new LSTMRNN();
        }
    }

    m_Rnn.loadNetBin(strModelFileName);
    Logger.WriteLine(Logger.Level.info, "CRF Model: {0}", m_Rnn.IsCRFTraining);
    m_Featurizer = featurizer;
}
public RNNDecoder(string strModelFileName, Featurizer featurizer)
{
    MODELTYPE modelType = MODELTYPE.SIMPLE;
    MODELDIRECTION modelDir = MODELDIRECTION.FORWARD;
    RNN.CheckModelFileType(strModelFileName, out modelType, out modelDir);

    if (modelDir == MODELDIRECTION.BI_DIRECTIONAL)
    {
        Logger.WriteLine("Model Structure: Bi-directional RNN");
        if (modelType == MODELTYPE.SIMPLE)
        {
            m_Rnn = new BiRNN(new SimpleRNN(), new SimpleRNN());
        }
        else
        {
            m_Rnn = new BiRNN(new LSTMRNN(), new LSTMRNN());
        }
    }
    else
    {
        if (modelType == MODELTYPE.SIMPLE)
        {
            Logger.WriteLine("Model Structure: Simple RNN");
            m_Rnn = new SimpleRNN();
        }
        else
        {
            Logger.WriteLine("Model Structure: LSTM-RNN");
            m_Rnn = new LSTMRNN();
        }
    }

    m_Rnn.LoadModel(strModelFileName);
    Logger.WriteLine("CRF Model: {0}", m_Rnn.IsCRFTraining);
    m_Featurizer = featurizer;
}
private void FreeRNNInstance(RNN<Sequence> r)
{
    // Return the instance to the shared pool for reuse.
    qRNNs.Enqueue(r);
}
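Together, GetRNNInstance and FreeRNNInstance form a check-out/check-in protocol around qRNNs. A hedged sketch of how a decoding method might use them; the method name Predict, the RunningMode.Test value, and the exact ProcessSequence call shape are assumptions taken from the Process method shown further below:

// Hypothetical decoder method; only GetRNNInstance/FreeRNNInstance come from the code above.
public int[] Predict(Sequence seq)
{
    RNN<Sequence> r = GetRNNInstance();
    try
    {
        Matrix<float> rawOutput;
        return r.ProcessSequence(seq, RunningMode.Test, false, out rawOutput);
    }
    finally
    {
        // Always return the instance to the pool, even if decoding throws.
        FreeRNNInstance(r);
    }
}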
public void Train()
{
    //Create neural network
    Logger.WriteLine("Create a new network according to settings in configuration file.");
    Logger.WriteLine($"Processor Count = {Environment.ProcessorCount}");

    RNN<T> rnn = RNN<T>.CreateRNN(networkType);
    if (ModelSettings.IncrementalTrain)
    {
        Logger.WriteLine($"Loading previously trained model from {modelFilePath}.");
        rnn.LoadModel(modelFilePath, true);
    }
    else
    {
        Logger.WriteLine("Create a new network.");
        rnn.CreateNetwork(hiddenLayersConfig, outputLayerConfig, TrainingSet, featurizer);

        //Create tag-bigram transition probability matrix only for sequence RNN mode
        if (IsCRFTraining)
        {
            Logger.WriteLine("Initialize bigram transition for CRF output layer.");
            rnn.InitializeCRFWeights(TrainingSet.CRFLabelBigramTransition);
        }
    }

    rnn.MaxSeqLength = maxSequenceLength;
    rnn.bVQ = ModelSettings.VQ != 0;
    rnn.IsCRFTraining = IsCRFTraining;

    //Create one clone per worker so training can run in parallel without shared model state
    int N = Environment.ProcessorCount * 2;
    List<RNN<T>> rnns = new List<RNN<T>>();
    rnns.Add(rnn);
    for (int i = 1; i < N; i++)
    {
        rnns.Add(rnn.Clone());
    }

    //Initialize RNNHelper
    RNNHelper.LearningRate = ModelSettings.LearningRate;
    RNNHelper.vecNormalLearningRate = new Vector<float>(RNNHelper.LearningRate);
    RNNHelper.GradientCutoff = ModelSettings.GradientCutoff;
    RNNHelper.vecMaxGrad = new Vector<float>(RNNHelper.GradientCutoff);
    RNNHelper.vecMinGrad = new Vector<float>(-RNNHelper.GradientCutoff);
    RNNHelper.IsConstAlpha = ModelSettings.IsConstAlpha;
    RNNHelper.MiniBatchSize = ModelSettings.MiniBatchSize;

    Logger.WriteLine("");
    Logger.WriteLine("Iterative training begins ...");

    var bestTrainTknErrCnt = long.MaxValue;
    var bestValidTknErrCnt = long.MaxValue;
    var lastAlpha = RNNHelper.LearningRate;
    var iter = 0;

    ParallelOptions parallelOption = new ParallelOptions();
    parallelOption.MaxDegreeOfParallelism = -1;

    while (true)
    {
        if (ModelSettings.MaxIteration > 0 && iter > ModelSettings.MaxIteration)
        {
            Logger.WriteLine("We have trained this model for {0} iterations, exiting.", iter);
            break;
        }

        var start = DateTime.Now;
        Logger.WriteLine($"Start training iteration {iter}. Learning rate = {RNNHelper.LearningRate}");

        //Clean all RNN instances' status for training
        foreach (var r in rnns)
        {
            r.CleanStatusForTraining();
        }

        Process(rnns, TrainingSet, RunningMode.Training);

        var duration = DateTime.Now.Subtract(start);
        Logger.WriteLine($"End iteration {iter}. Time duration = {duration}");
        Logger.WriteLine("");

        if (tknErrCnt >= bestTrainTknErrCnt && lastAlpha != RNNHelper.LearningRate)
        {
            //Although we reduced the learning rate, we still did not get a better result.
            Logger.WriteLine($"Current token error count ({(double)tknErrCnt / (double)processedWordCnt * 100.0}%) is larger than the previous one ({(double)bestTrainTknErrCnt / (double)processedWordCnt * 100.0}%) on the training set. End training early.");
            Logger.WriteLine("Current alpha: {0}, the previous alpha: {1}", RNNHelper.LearningRate, lastAlpha);
            break;
        }
        lastAlpha = RNNHelper.LearningRate;

        int trainTknErrCnt = tknErrCnt;

        //Validate the model on the validation corpus
        if (ValidationSet != null)
        {
            Logger.WriteLine("Verify model on validation corpus.");
            Process(rnns, ValidationSet, RunningMode.Validate);
            Logger.WriteLine("End model verification.");
            Logger.WriteLine("");

            if (tknErrCnt < bestValidTknErrCnt)
            {
                //We got a better result on the validation corpus, so save this model
                Logger.WriteLine($"Saving better model into file {modelFilePath}, since we got a better result on the validation set.");
                Logger.WriteLine($"Error token percent: {(double)tknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%");
                rnn.SaveModel(modelFilePath);
                bestValidTknErrCnt = tknErrCnt;
            }
        }
        else if (trainTknErrCnt < bestTrainTknErrCnt)
        {
            //No validation corpus exists, but we got a better result on the training corpus, so save this model
            Logger.WriteLine($"Saving better model into file {modelFilePath}; although no validation set exists, we have a better result on the training set.");
            Logger.WriteLine($"Error token percent: {(double)trainTknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%");
            rnn.SaveModel(modelFilePath);
        }

        Logger.WriteLine("");

        if (trainTknErrCnt >= bestTrainTknErrCnt)
        {
            //No better result on the training set, so halve the learning rate
            RNNHelper.LearningRate = RNNHelper.LearningRate / 2.0f;
            RNNHelper.vecNormalLearningRate = new Vector<float>(RNNHelper.LearningRate);
        }
        else
        {
            bestTrainTknErrCnt = trainTknErrCnt;
        }

        iter++;
    }
}
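The loop above implements a plateau-based schedule: the learning rate halves whenever the training token error fails to improve, and training stops early if an iteration after a rate cut still fails to improve. A distilled, illustrative sketch of just that policy; runOneIteration is a hypothetical stand-in for a full training pass and the initial rate is an assumption:

// Illustrative only; runOneIteration stands in for one full training pass.
static void TrainWithPlateauSchedule(Func<float, long> runOneIteration)
{
    float learningRate = 0.1f;           // assumed initial value
    float lastRate = learningRate;
    long bestErr = long.MaxValue;

    while (true)
    {
        long err = runOneIteration(learningRate);

        // If the error still failed to improve after a rate cut, stop early.
        if (err >= bestErr && lastRate != learningRate)
            break;
        lastRate = learningRate;

        if (err >= bestErr)
            learningRate /= 2.0f;        // plateau: halve the learning rate
        else
            bestErr = err;               // improvement: record the new best
    }
}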
public void Process(RNN<T> rnn, DataSet<T> trainingSet, RunningMode runningMode, int totalSequenceNum)
{
    //Shuffle training corpus
    trainingSet.Shuffle();

    for (var i = 0; i < trainingSet.SequenceList.Count; i++)
    {
        var pSequence = trainingSet.SequenceList[i];

        int wordCnt = 0;
        if (pSequence is Sequence)
        {
            wordCnt = (pSequence as Sequence).States.Length;
        }
        else
        {
            SequencePair sp = pSequence as SequencePair;
            if (sp.srcSentence.TokensList.Count > rnn.MaxSeqLength)
            {
                continue;
            }
            wordCnt = sp.tgtSequence.States.Length;
        }

        if (wordCnt > rnn.MaxSeqLength)
        {
            continue;
        }

        Interlocked.Add(ref processedWordCnt, wordCnt);

        int[] predicted;
        if (IsCRFTraining)
        {
            predicted = rnn.ProcessSequenceCRF(pSequence as Sequence, runningMode);
        }
        else
        {
            Matrix<float> m;
            predicted = rnn.ProcessSequence(pSequence, runningMode, false, out m);
        }

        int newTknErrCnt;
        if (pSequence is Sequence)
        {
            newTknErrCnt = GetErrorTokenNum(pSequence as Sequence, predicted);
        }
        else
        {
            newTknErrCnt = GetErrorTokenNum((pSequence as SequencePair).tgtSequence, predicted);
        }

        Interlocked.Add(ref tknErrCnt, newTknErrCnt);
        if (newTknErrCnt > 0)
        {
            Interlocked.Increment(ref sentErrCnt);
        }

        Interlocked.Increment(ref processedSequence);
        if (processedSequence % 1000 == 0)
        {
            Logger.WriteLine($"Progress = {processedSequence / 1000}K/{totalSequenceNum / 1000.0}K");
            Logger.WriteLine(" Error token ratio = {0}%", (double)tknErrCnt / (double)processedWordCnt * 100.0);
            Logger.WriteLine(" Error sentence ratio = {0}%", (double)sentErrCnt / (double)processedSequence * 100.0);
        }

        if (ModelSettings.SaveStep > 0 && processedSequence % ModelSettings.SaveStep == 0)
        {
            //After processing every SaveStep sequences, save the current model into a temporary file
            Logger.WriteLine("Saving temporary model into file...");
            rnn.SaveModel("model.tmp");
        }
    }
}
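Train() calls Process with the whole list of clones, while the overload above handles a single instance; the Interlocked counters suggest it is designed to be called from multiple threads at once. A hedged sketch of how such a fan-out might look inside the same trainer class; the signature, the ProcessOneSequence helper, and the interleaved partitioning are assumptions, not the project's actual implementation:

// Hypothetical fan-out over the cloned instances; illustrative only.
private void Process(List<RNN<T>> rnns, DataSet<T> dataSet, RunningMode runningMode)
{
    dataSet.Shuffle();

    Parallel.For(0, rnns.Count, i =>
    {
        // Each worker owns one clone, so no model state is shared between threads;
        // shared counters (tknErrCnt, processedSequence, ...) stay safe via Interlocked.
        for (var j = i; j < dataSet.SequenceList.Count; j += rnns.Count)
        {
            ProcessOneSequence(rnns[i], dataSet.SequenceList[j], runningMode);
        }
    });
}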