public override void LoadModel(string filename)
{
    Logger.WriteLine(Logger.Level.info, "Loading bi-directional model: {0}", filename);

    // The forward and backward sub-networks are stored in separate files
    forwardRNN.LoadModel(filename + ".forward");
    backwardRNN.LoadModel(filename + ".backward");

    // Both directions share the same output weights and CRF transition weights
    Hidden2OutputWeight = forwardRNN.Hidden2OutputWeight;
    CRFTagTransWeights = forwardRNN.CRFTagTransWeights;

    using (StreamReader sr = new StreamReader(filename))
    {
        BinaryReader br = new BinaryReader(sr.BaseStream);

        ModelType = (MODELTYPE)br.ReadInt32();
        ModelDirection = (MODELDIRECTION)br.ReadInt32();
        IsCRFTraining = br.ReadInt32() == 1;

        //Load basic parameters
        L0 = br.ReadInt32();
        L1 = br.ReadInt32();
        L2 = br.ReadInt32();
        DenseFeatureSize = br.ReadInt32();
    }
}
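Loading a bi-directional model therefore expects three files on disk: the header file passed in, plus the .forward and .backward sub-network files read by the two inner RNNs. A minimal usage sketch, assuming the generic BiRNN<Sequence> variant used elsewhere in this code and a hypothetical model path:

// Hypothetical file layout implied by LoadModel above:
//   model.bin           header: model type, direction, CRF flag, layer sizes
//   model.bin.forward   forward sub-network weights
//   model.bin.backward  backward sub-network weights
var rnn = new BiRNN<Sequence>();
rnn.LoadModel("model.bin"); // reads the header and both sub-network files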
public RNNDecoder(Config config)
{
    Config = config;

    if (Config.ModelDirection == MODELDIRECTION.BiDirectional)
    {
        Logger.WriteLine("Model Structure: Bi-directional RNN");
        rnn = new BiRNN<Sequence>();
    }
    else
    {
        Logger.WriteLine("Model Structure: Simple RNN");
        rnn = new ForwardRNN<Sequence>();
    }

    rnn.LoadModel(config.ModelFilePath);
    Logger.WriteLine("CRF Model: {0}", rnn.IsCRFTraining);
}
public RNNDecoder(Config config)
{
    Config = config;

    RNN<Sequence> rnn = RNN<Sequence>.CreateRNN(Config.NetworkType);
    rnn.LoadModel(config.ModelFilePath);
    rnn.MaxSeqLength = config.MaxSequenceLength;

    Logger.WriteLine("CRF Model: {0}", rnn.IsCRFTraining);
    Logger.WriteLine($"Max Sequence Length: {rnn.MaxSeqLength}");
    Logger.WriteLine($"Processor Count: {Environment.ProcessorCount}");

    // Pre-clone one RNN instance per logical processor, so concurrent
    // decoding requests never share mutable state
    qRNNs = new ConcurrentQueue<RNN<Sequence>>();
    for (var i = 0; i < Environment.ProcessorCount; i++)
    {
        qRNNs.Enqueue(rnn.Clone());
    }
}
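The ConcurrentQueue here acts as a simple object pool. A minimal sketch of the borrow-and-return pattern a decode call would follow; the actual decode entry point is not shown above, so it is only indicated by a comment:

// Requires: using System.Threading;
RNN<Sequence> worker;
// Borrow a thread-private clone; spin-waiting is a simplification of
// whatever back-pressure the real decoder applies when the pool is empty.
while (!qRNNs.TryDequeue(out worker))
{
    Thread.Sleep(1);
}
try
{
    // ... run decoding on 'worker' here (entry point not shown above) ...
}
finally
{
    // Always return the instance so later requests can reuse it.
    qRNNs.Enqueue(worker);
}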
public RNNDecoder(string strModelFileName, Featurizer featurizer)
{
    MODELDIRECTION modelDir;
    RNNHelper.CheckModelFileType(strModelFileName, out modelDir);

    if (modelDir == MODELDIRECTION.BI_DIRECTIONAL)
    {
        Logger.WriteLine("Model Structure: Bi-directional RNN");
        m_Rnn = new BiRNN();
    }
    else
    {
        Logger.WriteLine("Model Structure: Simple RNN");
        m_Rnn = new ForwardRNN();
    }

    m_Rnn.LoadModel(strModelFileName);
    Logger.WriteLine("CRF Model: {0}", m_Rnn.IsCRFTraining);

    m_Featurizer = featurizer;
}
public RNNDecoder(string strModelFileName)
{
    MODELDIRECTION modelDir;
    MODELTYPE modelType;
    RNNHelper.CheckModelFileType(strModelFileName, out modelDir, out modelType);

    if (modelDir == MODELDIRECTION.BI_DIRECTIONAL)
    {
        Logger.WriteLine("Model Structure: Bi-directional RNN");
        rnn = new BiRNN<Sequence>();
    }
    else
    {
        Logger.WriteLine("Model Structure: Simple RNN");
        rnn = new ForwardRNN<Sequence>();
    }

    ModelType = modelType;
    rnn.LoadModel(strModelFileName);
    Logger.WriteLine("CRF Model: {0}", rnn.IsCRFTraining);
}
public RNNDecoder(string strModelFileName, Featurizer featurizer)
{
    MODELTYPE modelType;
    MODELDIRECTION modelDir;
    RNN.CheckModelFileType(strModelFileName, out modelType, out modelDir);

    if (modelDir == MODELDIRECTION.BI_DIRECTIONAL)
    {
        Logger.WriteLine("Model Structure: Bi-directional RNN");
        if (modelType == MODELTYPE.SIMPLE)
        {
            m_Rnn = new BiRNN(new SimpleRNN(), new SimpleRNN());
        }
        else
        {
            m_Rnn = new BiRNN(new LSTMRNN(), new LSTMRNN());
        }
    }
    else
    {
        if (modelType == MODELTYPE.SIMPLE)
        {
            Logger.WriteLine("Model Structure: Simple RNN");
            m_Rnn = new SimpleRNN();
        }
        else
        {
            Logger.WriteLine("Model Structure: LSTM-RNN");
            m_Rnn = new LSTMRNN();
        }
    }

    m_Rnn.LoadModel(strModelFileName);
    Logger.WriteLine("CRF Model: {0}", m_Rnn.IsCRFTraining);

    m_Featurizer = featurizer;
}
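For reference, a typical call site for the featurizer-based constructors looks like the sketch below. Only the RNNDecoder signatures are confirmed by the code above; the Featurizer constructor arguments are an assumption.

// Assumed setup: the Featurizer constructor shown here is hypothetical.
Featurizer featurizer = new Featurizer("features.config");
RNNDecoder decoder = new RNNDecoder("model.bin", featurizer);
// The decoder routes decoding calls through m_Rnn internally.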
public void Train()
{
    //Create neural network
    Logger.WriteLine("Create a new network according to settings in configuration file.");
    Logger.WriteLine($"Processor Count = {Environment.ProcessorCount}");

    RNN<T> rnn = RNN<T>.CreateRNN(networkType);
    if (ModelSettings.IncrementalTrain)
    {
        Logger.WriteLine($"Loading previously trained model from {modelFilePath}.");
        rnn.LoadModel(modelFilePath, true);
    }
    else
    {
        Logger.WriteLine("Create a new network.");
        rnn.CreateNetwork(hiddenLayersConfig, outputLayerConfig, TrainingSet, featurizer);

        //Create tag-bigram transition probability matrix only for sequence RNN mode
        if (IsCRFTraining)
        {
            Logger.WriteLine("Initialize bigram transition for CRF output layer.");
            rnn.InitializeCRFWeights(TrainingSet.CRFLabelBigramTransition);
        }
    }

    rnn.MaxSeqLength = maxSequenceLength;
    rnn.bVQ = ModelSettings.VQ != 0;
    rnn.IsCRFTraining = IsCRFTraining;

    //Create one RNN instance per worker so training can run in parallel
    int N = Environment.ProcessorCount * 2;
    List<RNN<T>> rnns = new List<RNN<T>>();
    rnns.Add(rnn);
    for (int i = 1; i < N; i++)
    {
        rnns.Add(rnn.Clone());
    }

    //Initialize RNNHelper
    RNNHelper.LearningRate = ModelSettings.LearningRate;
    RNNHelper.vecNormalLearningRate = new Vector<float>(RNNHelper.LearningRate);
    RNNHelper.GradientCutoff = ModelSettings.GradientCutoff;
    RNNHelper.vecMaxGrad = new Vector<float>(RNNHelper.GradientCutoff);
    RNNHelper.vecMinGrad = new Vector<float>(-RNNHelper.GradientCutoff);
    RNNHelper.IsConstAlpha = ModelSettings.IsConstAlpha;
    RNNHelper.MiniBatchSize = ModelSettings.MiniBatchSize;

    Logger.WriteLine("");
    Logger.WriteLine("Iterative training begins ...");

    var bestTrainTknErrCnt = long.MaxValue;
    var bestValidTknErrCnt = long.MaxValue;
    var lastAlpha = RNNHelper.LearningRate;
    var iter = 0;

    ParallelOptions parallelOption = new ParallelOptions();
    parallelOption.MaxDegreeOfParallelism = -1;

    while (true)
    {
        if (ModelSettings.MaxIteration > 0 && iter > ModelSettings.MaxIteration)
        {
            Logger.WriteLine("We have trained this model for {0} iterations, exiting.", ModelSettings.MaxIteration);
            break;
        }

        var start = DateTime.Now;
        Logger.WriteLine($"Start training iteration {iter}. learning rate = {RNNHelper.LearningRate}");

        //Reset all RNN instances' status for training
        foreach (var r in rnns)
        {
            r.CleanStatusForTraining();
        }

        Process(rnns, TrainingSet, RunningMode.Training);

        var duration = DateTime.Now.Subtract(start);
        Logger.WriteLine($"End iteration {iter}. Time duration = {duration}");
        Logger.WriteLine("");

        if (tknErrCnt >= bestTrainTknErrCnt && lastAlpha != RNNHelper.LearningRate)
        {
            //Although we reduced the learning rate, we still could not get a better result.
            Logger.WriteLine(
                $"Current token error count ({(double)tknErrCnt / (double)processedWordCnt * 100.0}%) is larger than the previous one ({(double)bestTrainTknErrCnt / (double)processedWordCnt * 100.0}%) on training set. End training early.");
            Logger.WriteLine("Current alpha: {0}, the previous alpha: {1}", RNNHelper.LearningRate, lastAlpha);
            break;
        }
        lastAlpha = RNNHelper.LearningRate;

        int trainTknErrCnt = tknErrCnt;

        //Validate the model on the validation corpus
        if (ValidationSet != null)
        {
            Logger.WriteLine("Verify model on validation corpus.");
            Process(rnns, ValidationSet, RunningMode.Validate);
            Logger.WriteLine("End model verification.");
            Logger.WriteLine("");

            if (tknErrCnt < bestValidTknErrCnt)
            {
                //We got a better result on the validation corpus, so save this model
                Logger.WriteLine($"Saving better model into file {modelFilePath}, since we got a better result on validation set.");
                Logger.WriteLine($"Error token percent: {(double)tknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%");
                rnn.SaveModel(modelFilePath);
                bestValidTknErrCnt = tknErrCnt;
            }
        }
        else if (trainTknErrCnt < bestTrainTknErrCnt)
        {
            //No validation corpus, but we got a better result on the training corpus, so save this model
            Logger.WriteLine($"Saving better model into file {modelFilePath}; although validation set doesn't exist, we have a better result on training set.");
            Logger.WriteLine($"Error token percent: {(double)trainTknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%");
            rnn.SaveModel(modelFilePath);
        }

        Logger.WriteLine("");

        if (trainTknErrCnt >= bestTrainTknErrCnt)
        {
            //No better result on the training set, so halve the learning rate
            RNNHelper.LearningRate = RNNHelper.LearningRate / 2.0f;
            RNNHelper.vecNormalLearningRate = new Vector<float>(RNNHelper.LearningRate);
        }
        else
        {
            bestTrainTknErrCnt = trainTknErrCnt;
        }

        iter++;
    }
}