public override void SaveModel(string filename)
{
    // Save the bi-directional model. Both sub-networks share the same output
    // weights and CRF tag-transition matrix, so propagate them to each
    // direction before saving the per-direction model files.
    forwardRNN.Hidden2OutputWeight = Hidden2OutputWeight;
    backwardRNN.Hidden2OutputWeight = Hidden2OutputWeight;

    forwardRNN.CRFTagTransWeights = CRFTagTransWeights;
    backwardRNN.CRFTagTransWeights = CRFTagTransWeights;

    forwardRNN.SaveModel(filename + ".forward");
    backwardRNN.SaveModel(filename + ".backward");

    // Save meta data. Write binary directly to the file stream: the previous
    // code wrapped an unused StreamWriter just to reach its BaseStream and
    // never disposed the BinaryWriter. Stacked usings guarantee flush/close.
    using (FileStream fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
    using (BinaryWriter fo = new BinaryWriter(fs))
    {
        fo.Write((int)ModelType);
        fo.Write((int)ModelDirection);

        // Signature: 0 is for RNN, 1 is for RNN-CRF
        int iflag = IsCRFTraining ? 1 : 0;
        fo.Write(iflag);

        // Layer sizes (L0/L1/L2 are fields of the enclosing network class),
        // followed by the dense feature width.
        fo.Write(L0);
        fo.Write(L1);
        fo.Write(L2);
        fo.Write(DenseFeatureSize);
    }
}
public void Train()
{
    // Create the neural network according to settings in the configuration file.
    Logger.WriteLine("Create a new network according settings in configuration file.");
    Logger.WriteLine($"Processor Count = {Environment.ProcessorCount}");
    RNN<T> rnn = RNN<T>.CreateRNN(networkType);
    if (ModelSettings.IncrementalTrain)
    {
        // Resume from a previously trained model instead of starting fresh.
        Logger.WriteLine($"Loading previous trained model from {modelFilePath}.");
        rnn.LoadModel(modelFilePath, true);
    }
    else
    {
        Logger.WriteLine("Create a new network.");
        rnn.CreateNetwork(hiddenLayersConfig, outputLayerConfig, TrainingSet, featurizer);
        // Create tag-bigram transition probability matrix only for sequence RNN mode.
        if (IsCRFTraining)
        {
            Logger.WriteLine("Initialize bigram transition for CRF output layer.");
            rnn.InitializeCRFWeights(TrainingSet.CRFLabelBigramTransition);
        }
    }

    rnn.MaxSeqLength = maxSequenceLength;
    rnn.bVQ = ModelSettings.VQ != 0; // FIX: redundant "? true : false" removed
    rnn.IsCRFTraining = IsCRFTraining;

    // Clone one RNN instance per worker thread so training can run in parallel.
    int N = Environment.ProcessorCount * 2;
    List<RNN<T>> rnns = new List<RNN<T>>();
    rnns.Add(rnn);
    for (int i = 1; i < N; i++)
    {
        rnns.Add(rnn.Clone());
    }

    // Initialize RNNHelper's shared (static) training state.
    RNNHelper.LearningRate = ModelSettings.LearningRate;
    RNNHelper.vecNormalLearningRate = new Vector<float>(RNNHelper.LearningRate);
    RNNHelper.GradientCutoff = ModelSettings.GradientCutoff;
    RNNHelper.vecMaxGrad = new Vector<float>(RNNHelper.GradientCutoff);
    RNNHelper.vecMinGrad = new Vector<float>(-RNNHelper.GradientCutoff);
    RNNHelper.IsConstAlpha = ModelSettings.IsConstAlpha;
    RNNHelper.MiniBatchSize = ModelSettings.MiniBatchSize;

    Logger.WriteLine("");
    Logger.WriteLine("Iterative training begins ...");
    var bestTrainTknErrCnt = long.MaxValue;
    var bestValidTknErrCnt = long.MaxValue;
    var lastAlpha = RNNHelper.LearningRate;
    var iter = 0;
    // NOTE: the previous version built a ParallelOptions local here that was
    // never used anywhere in this method; it has been removed.
    while (true)
    {
        if (ModelSettings.MaxIteration > 0 && iter > ModelSettings.MaxIteration)
        {
            // FIX: the {0} placeholder previously had no argument supplied.
            Logger.WriteLine("We have trained this model {0} iteration, exit.", ModelSettings.MaxIteration);
            break;
        }

        var start = DateTime.Now;
        Logger.WriteLine($"Start to training {iter} iteration. learning rate = {RNNHelper.LearningRate}");
        // Clean all RNN instances' status before each training pass.
        foreach (var r in rnns)
        {
            r.CleanStatusForTraining();
        }
        Process(rnns, TrainingSet, RunningMode.Training);
        var duration = DateTime.Now.Subtract(start);
        Logger.WriteLine($"End {iter} iteration. Time duration = {duration}");
        Logger.WriteLine("");

        if (tknErrCnt >= bestTrainTknErrCnt && lastAlpha != RNNHelper.LearningRate)
        {
            // Although we reduced alpha, we still cannot get a better result:
            // stop early rather than keep halving the learning rate.
            Logger.WriteLine(
                $"Current token error count({(double)tknErrCnt / (double)processedWordCnt * 100.0}%) is larger than the previous one({(double)bestTrainTknErrCnt / (double)processedWordCnt * 100.0}%) on training set. End training early.");
            Logger.WriteLine("Current alpha: {0}, the previous alpha: {1}", RNNHelper.LearningRate, lastAlpha);
            break;
        }
        lastAlpha = RNNHelper.LearningRate;
        int trainTknErrCnt = tknErrCnt;

        // Validate the model against the validation corpus, if one exists.
        if (ValidationSet != null)
        {
            Logger.WriteLine("Verify model on validated corpus.");
            Process(rnns, ValidationSet, RunningMode.Validate);
            Logger.WriteLine("End model verification.");
            Logger.WriteLine("");
            if (tknErrCnt < bestValidTknErrCnt)
            {
                // Better result on the validation corpus: persist this model.
                Logger.WriteLine($"Saving better model into file {modelFilePath}, since we got a better result on validation set.");
                Logger.WriteLine($"Error token percent: {(double)tknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%");
                rnn.SaveModel(modelFilePath);
                bestValidTknErrCnt = tknErrCnt;
            }
        }
        else if (trainTknErrCnt < bestTrainTknErrCnt)
        {
            // No validation corpus: fall back to saving whenever the training
            // set error improves.
            Logger.WriteLine($"Saving better model into file {modelFilePath}, although validation set doesn't exist, we have better result on training set.");
            Logger.WriteLine($"Error token percent: {(double)trainTknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%");
            rnn.SaveModel(modelFilePath);
        }

        Logger.WriteLine("");

        if (trainTknErrCnt >= bestTrainTknErrCnt)
        {
            // No improvement on the training set, so halve the learning rate.
            RNNHelper.LearningRate = RNNHelper.LearningRate / 2.0f;
            RNNHelper.vecNormalLearningRate = new Vector<float>(RNNHelper.LearningRate);
        }
        else
        {
            bestTrainTknErrCnt = trainTknErrCnt;
        }

        iter++;
    }
}
public void Process(RNN<T> rnn, DataSet<T> trainingSet, RunningMode runningMode, int totalSequenceNum)
{
    // Randomize the order of the corpus before each pass.
    trainingSet.Shuffle();

    for (var idx = 0; idx < trainingSet.SequenceList.Count; idx++)
    {
        var curSequence = trainingSet.SequenceList[idx];
        var singleSeq = curSequence as Sequence;

        // Determine the token count; sequence pairs are additionally skipped
        // when their source sentence exceeds the maximum sequence length.
        int tokenCount;
        if (singleSeq != null)
        {
            tokenCount = singleSeq.States.Length;
        }
        else
        {
            var seqPair = curSequence as SequencePair;
            if (seqPair.srcSentence.TokensList.Count > rnn.MaxSeqLength)
            {
                continue;
            }
            tokenCount = seqPair.tgtSequence.States.Length;
        }

        // Skip any sequence longer than the configured maximum.
        if (tokenCount > rnn.MaxSeqLength)
        {
            continue;
        }

        Interlocked.Add(ref processedWordCnt, tokenCount);

        // Run the sequence through the network (CRF decoding when enabled).
        int[] predicted;
        if (IsCRFTraining)
        {
            predicted = rnn.ProcessSequenceCRF(curSequence as Sequence, runningMode);
        }
        else
        {
            Matrix<float> unusedOutput;
            predicted = rnn.ProcessSequence(curSequence, runningMode, false, out unusedOutput);
        }

        // Accumulate token/sentence error statistics (shared across workers).
        var newErrCnt = singleSeq != null
            ? GetErrorTokenNum(singleSeq, predicted)
            : GetErrorTokenNum((curSequence as SequencePair).tgtSequence, predicted);

        Interlocked.Add(ref tknErrCnt, newErrCnt);
        if (newErrCnt > 0)
        {
            Interlocked.Increment(ref sentErrCnt);
        }

        Interlocked.Increment(ref processedSequence);

        // Periodic progress report, once per thousand processed sequences.
        if (processedSequence % 1000 == 0)
        {
            Logger.WriteLine("Progress = {0} ", processedSequence / 1000 + "K/" + totalSequenceNum / 1000.0 + "K");
            Logger.WriteLine(" Error token ratio = {0}%", (double)tknErrCnt / (double)processedWordCnt * 100.0);
            Logger.WriteLine(" Error sentence ratio = {0}%", (double)sentErrCnt / (double)processedSequence * 100.0);
        }

        if (ModelSettings.SaveStep > 0 && processedSequence % ModelSettings.SaveStep == 0)
        {
            // After every SaveStep sequences, snapshot the model to a temp file.
            Logger.WriteLine("Saving temporary model into file...");
            rnn.SaveModel("model.tmp");
        }
    }
}