public void Train() { //Create neural net work Logger.WriteLine("Create a new network according settings in configuration file."); Logger.WriteLine($"Processor Count = {Environment.ProcessorCount}"); RNN <T> rnn = RNN <T> .CreateRNN(networkType); if (ModelSettings.IncrementalTrain) { Logger.WriteLine($"Loading previous trained model from {modelFilePath}."); rnn.LoadModel(modelFilePath, true); } else { Logger.WriteLine("Create a new network."); rnn.CreateNetwork(hiddenLayersConfig, outputLayerConfig, TrainingSet, featurizer); //Create tag-bigram transition probability matrix only for sequence RNN mode if (IsCRFTraining) { Logger.WriteLine("Initialize bigram transition for CRF output layer."); rnn.InitializeCRFWeights(TrainingSet.CRFLabelBigramTransition); } } rnn.MaxSeqLength = maxSequenceLength; rnn.bVQ = ModelSettings.VQ != 0 ? true : false; rnn.IsCRFTraining = IsCRFTraining; int N = Environment.ProcessorCount * 2; List <RNN <T> > rnns = new List <RNN <T> >(); rnns.Add(rnn); for (int i = 1; i < N; i++) { rnns.Add(rnn.Clone()); } //Initialize RNNHelper RNNHelper.LearningRate = ModelSettings.LearningRate; RNNHelper.vecNormalLearningRate = new Vector <float>(RNNHelper.LearningRate); RNNHelper.GradientCutoff = ModelSettings.GradientCutoff; RNNHelper.vecMaxGrad = new Vector <float>(RNNHelper.GradientCutoff); RNNHelper.vecMinGrad = new Vector <float>(-RNNHelper.GradientCutoff); RNNHelper.IsConstAlpha = ModelSettings.IsConstAlpha; RNNHelper.MiniBatchSize = ModelSettings.MiniBatchSize; Logger.WriteLine(""); Logger.WriteLine("Iterative training begins ..."); var bestTrainTknErrCnt = long.MaxValue; var bestValidTknErrCnt = long.MaxValue; var lastAlpha = RNNHelper.LearningRate; var iter = 0; ParallelOptions parallelOption = new ParallelOptions(); parallelOption.MaxDegreeOfParallelism = -1; while (true) { if (ModelSettings.MaxIteration > 0 && iter > ModelSettings.MaxIteration) { Logger.WriteLine("We have trained this model {0} iteration, exit."); break; } var start = DateTime.Now; Logger.WriteLine($"Start to training {iter} iteration. learning rate = {RNNHelper.LearningRate}"); //Clean all RNN instances' status for training foreach (var r in rnns) { r.CleanStatusForTraining(); } Process(rnns, TrainingSet, RunningMode.Training); var duration = DateTime.Now.Subtract(start); Logger.WriteLine($"End {iter} iteration. Time duration = {duration}"); Logger.WriteLine(""); if (tknErrCnt >= bestTrainTknErrCnt && lastAlpha != RNNHelper.LearningRate) { //Although we reduce alpha value, we still cannot get better result. Logger.WriteLine( $"Current token error count({(double)tknErrCnt / (double)processedWordCnt * 100.0}%) is larger than the previous one({(double)bestTrainTknErrCnt / (double)processedWordCnt * 100.0}%) on training set. End training early."); Logger.WriteLine("Current alpha: {0}, the previous alpha: {1}", RNNHelper.LearningRate, lastAlpha); break; } lastAlpha = RNNHelper.LearningRate; int trainTknErrCnt = tknErrCnt; //Validate the model by validated corpus if (ValidationSet != null) { Logger.WriteLine("Verify model on validated corpus."); Process(rnns, ValidationSet, RunningMode.Validate); Logger.WriteLine("End model verification."); Logger.WriteLine(""); if (tknErrCnt < bestValidTknErrCnt) { //We got better result on validated corpus, save this model Logger.WriteLine($"Saving better model into file {modelFilePath}, since we got a better result on validation set."); Logger.WriteLine($"Error token percent: {(double)tknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%"); rnn.SaveModel(modelFilePath); bestValidTknErrCnt = tknErrCnt; } } else if (trainTknErrCnt < bestTrainTknErrCnt) { //We don't have validate corpus, but we get a better result on training corpus //We got better result on validated corpus, save this model Logger.WriteLine($"Saving better model into file {modelFilePath}, although validation set doesn't exist, we have better result on training set."); Logger.WriteLine($"Error token percent: {(double)trainTknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%"); rnn.SaveModel(modelFilePath); } Logger.WriteLine(""); if (trainTknErrCnt >= bestTrainTknErrCnt) { //We don't have better result on training set, so reduce learning rate RNNHelper.LearningRate = RNNHelper.LearningRate / 2.0f; RNNHelper.vecNormalLearningRate = new Vector <float>(RNNHelper.LearningRate); } else { bestTrainTknErrCnt = trainTknErrCnt; } iter++; } }