Example #1
0
        /// <summary>
        /// Loads a bi-directional model.
        /// </summary>
        /// <remarks>
        /// Load steps:
        /// 1. The forward sub-network is read from <c>filename + ".forward"</c>.
        /// 2. The backward sub-network is read from <c>filename + ".backward"</c>.
        /// 3. <c>Hidden2OutputWeight</c> and <c>CRFTagTransWeights</c> are taken from the forward sub-network.
        /// 4. The binary header stored in <paramref name="filename"/> itself is read.
        /// </remarks>
        /// <param name="filename">Base path of the model; the sub-network files use it as a prefix.</param>
        public override void LoadModel(string filename)
        {
            Logger.WriteLine(Logger.Level.info, "Loading bi-directional model: {0}", filename);

            forwardRNN.LoadModel(filename + ".forward");
            backwardRNN.LoadModel(filename + ".backward");

            // The combined model reuses the forward network's output and CRF transition weights.
            Hidden2OutputWeight = forwardRNN.Hidden2OutputWeight;
            CRFTagTransWeights  = forwardRNN.CRFTagTransWeights;

            using (var headerStream = new StreamReader(filename))
            {
                var reader = new BinaryReader(headerStream.BaseStream);

                // Header layout (all Int32): model type, model direction, CRF flag (1 == CRF),
                // then the layer sizes L0, L1, L2 and the dense feature size.
                ModelType      = (MODELTYPE)reader.ReadInt32();
                ModelDirection = (MODELDIRECTION)reader.ReadInt32();
                IsCRFTraining  = reader.ReadInt32() == 1;

                L0               = reader.ReadInt32();
                L1               = reader.ReadInt32();
                L2               = reader.ReadInt32();
                DenseFeatureSize = reader.ReadInt32();
            }
        }
Example #2
0
        /// <summary>
        /// Creates a decoder for the model described by <paramref name="config"/>.
        /// </summary>
        /// <remarks>
        /// A <c>BiRNN</c> is used when <c>ModelDirection</c> is <c>BiDirectional</c>; otherwise a
        /// <c>ForwardRNN</c> is used. The model is then loaded from <c>config.ModelFilePath</c>.
        /// </remarks>
        /// <param name="config">Decoder configuration; stored in <c>Config</c>.</param>
        public RNNDecoder(Config config)
        {
            Config = config;

            bool isBiDirectional = Config.ModelDirection == MODELDIRECTION.BiDirectional;
            if (!isBiDirectional)
            {
                Logger.WriteLine("Model Structure: Simple RNN");
                rnn = new ForwardRNN<Sequence>();
            }
            else
            {
                Logger.WriteLine("Model Structure: Bi-directional RNN");
                rnn = new BiRNN<Sequence>();
            }

            rnn.LoadModel(config.ModelFilePath);
            Logger.WriteLine("CRF Model: {0}", rnn.IsCRFTraining);
        }
Example #3
0
        /// <summary>
        /// Creates a decoder from <paramref name="config"/>.
        /// </summary>
        /// <remarks>
        /// Steps:
        /// 1. Build an RNN of the configured network type.
        /// 2. Load its weights from <c>config.ModelFilePath</c>.
        /// 3. Apply the maximum sequence length.
        /// 4. Fill <c>qRNNs</c> with one clone per processor (presumably so concurrent decode
        ///    calls each take their own instance — confirm against the decode methods).
        /// </remarks>
        /// <param name="config">Decoder configuration; stored in <c>Config</c>.</param>
        public RNNDecoder(Config config)
        {
            Config = config;
            RNN<Sequence> template = RNN<Sequence>.CreateRNN(Config.NetworkType);

            template.LoadModel(config.ModelFilePath);
            template.MaxSeqLength = config.MaxSequenceLength;

            Logger.WriteLine("CRF Model: {0}", template.IsCRFTraining);
            Logger.WriteLine($"Max Sequence Length: {template.MaxSeqLength}");
            Logger.WriteLine($"Processor Count: {Environment.ProcessorCount}");

            // The loaded template itself is not queued; only its clones are.
            qRNNs = new ConcurrentQueue<RNN<Sequence>>();
            int clonesCreated = 0;
            while (clonesCreated < Environment.ProcessorCount)
            {
                qRNNs.Enqueue(template.Clone());
                clonesCreated++;
            }
        }
Example #4
0
        /// <summary>
        /// Creates a decoder for the model stored in <paramref name="strModelFileName"/>.
        /// </summary>
        /// <remarks>
        /// The model direction is reported by <c>RNNHelper.CheckModelFileType</c>. Bi-directional
        /// models use a <c>BiRNN</c>; all others use a <c>ForwardRNN</c>. The weights are then
        /// loaded from the same file.
        /// </remarks>
        /// <param name="strModelFileName">Path of the trained model file.</param>
        /// <param name="featurizer">Featurizer kept for later decoding; stored in <c>m_Featurizer</c>.</param>
        public RNNDecoder(string strModelFileName, Featurizer featurizer)
        {
            MODELDIRECTION direction;
            RNNHelper.CheckModelFileType(strModelFileName, out direction);

            if (direction != MODELDIRECTION.BI_DIRECTIONAL)
            {
                Logger.WriteLine("Model Structure: Simple RNN");
                m_Rnn = new ForwardRNN();
            }
            else
            {
                Logger.WriteLine("Model Structure: Bi-directional RNN");
                m_Rnn = new BiRNN();
            }

            m_Rnn.LoadModel(strModelFileName);
            Logger.WriteLine("CRF Model: {0}", m_Rnn.IsCRFTraining);
            m_Featurizer = featurizer;
        }
Example #5
0
        /// <summary>
        /// Creates a decoder for the model stored in <paramref name="strModelFileName"/>.
        /// </summary>
        /// <remarks>
        /// Direction and model type come from <c>RNNHelper.CheckModelFileType</c>. The direction
        /// selects a <c>BiRNN</c> or a <c>ForwardRNN</c>. The type is stored in <c>ModelType</c>
        /// before the weights are loaded.
        /// </remarks>
        /// <param name="strModelFileName">Path of the trained model file.</param>
        public RNNDecoder(string strModelFileName)
        {
            MODELDIRECTION direction;
            MODELTYPE      detectedType;
            RNNHelper.CheckModelFileType(strModelFileName, out direction, out detectedType);

            if (direction != MODELDIRECTION.BI_DIRECTIONAL)
            {
                Logger.WriteLine("Model Structure: Simple RNN");
                rnn = new ForwardRNN<Sequence>();
            }
            else
            {
                Logger.WriteLine("Model Structure: Bi-directional RNN");
                rnn = new BiRNN<Sequence>();
            }
            ModelType = detectedType;

            rnn.LoadModel(strModelFileName);
            Logger.WriteLine("CRF Model: {0}", rnn.IsCRFTraining);
        }
Example #6
0
        /// <summary>
        /// Creates a decoder for the model stored in <paramref name="strModelFileName"/>.
        /// </summary>
        /// <remarks>
        /// The network is chosen from the model type and direction reported by
        /// <c>RNN.CheckModelFileType</c>:
        /// <list type="bullet">
        /// <item>Bi-directional models wrap two <c>SimpleRNN</c> or two <c>LSTMRNN</c> instances in a <c>BiRNN</c>.</item>
        /// <item>Forward models use a single <c>SimpleRNN</c> or <c>LSTMRNN</c>.</item>
        /// </list>
        /// </remarks>
        /// <param name="strModelFileName">Path of the trained model file.</param>
        /// <param name="featurizer">Featurizer kept for later decoding; stored in <c>m_Featurizer</c>.</param>
        public RNNDecoder(string strModelFileName, Featurizer featurizer)
        {
            MODELTYPE      cellType;
            MODELDIRECTION direction;
            RNN.CheckModelFileType(strModelFileName, out cellType, out direction);

            bool useSimpleCells = cellType == MODELTYPE.SIMPLE;

            if (direction == MODELDIRECTION.BI_DIRECTIONAL)
            {
                Logger.WriteLine("Model Structure: Bi-directional RNN");
                m_Rnn = useSimpleCells
                    ? new BiRNN(new SimpleRNN(), new SimpleRNN())
                    : new BiRNN(new LSTMRNN(), new LSTMRNN());
            }
            else if (useSimpleCells)
            {
                Logger.WriteLine("Model Structure: Simple RNN");
                m_Rnn = new SimpleRNN();
            }
            else
            {
                Logger.WriteLine("Model Structure: LSTM-RNN");
                m_Rnn = new LSTMRNN();
            }

            m_Rnn.LoadModel(strModelFileName);
            Logger.WriteLine("CRF Model: {0}", m_Rnn.IsCRFTraining);
            m_Featurizer = featurizer;
        }
Example #7
0
        /// <summary>
        /// Runs the iterative training loop.
        /// </summary>
        /// <remarks>
        /// Setup:
        /// <list type="bullet">
        /// <item>The network is either loaded from <c>modelFilePath</c> (incremental training) or created fresh.
        /// A fresh network also gets CRF transition weights initialised when <c>IsCRFTraining</c> is set.</item>
        /// <item>A pool of <c>Environment.ProcessorCount * 2</c> instances is built for <c>Process</c>:
        /// the network itself plus clones.</item>
        /// </list>
        /// Each iteration trains on <c>TrainingSet</c>. The model is saved when:
        /// <list type="bullet">
        /// <item>the <c>ValidationSet</c> token error count improves, or</item>
        /// <item>there is no validation set and the training token error count improves.</item>
        /// </list>
        /// The learning rate is halved whenever the training token error count does not improve.
        /// Training stops when either:
        /// <list type="bullet">
        /// <item><c>ModelSettings.MaxIteration</c> iterations have run (a value &lt;= 0 means no limit), or</item>
        /// <item>a halved learning rate still fails to improve the training error.</item>
        /// </list>
        /// </remarks>
        public void Train()
        {
            // Create the neural network
            Logger.WriteLine("Create a new network according settings in configuration file.");
            Logger.WriteLine($"Processor Count = {Environment.ProcessorCount}");

            RNN <T> rnn = RNN <T> .CreateRNN(networkType);

            if (ModelSettings.IncrementalTrain)
            {
                Logger.WriteLine($"Loading previous trained model from {modelFilePath}.");
                rnn.LoadModel(modelFilePath, true);
            }
            else
            {
                Logger.WriteLine("Create a new network.");
                rnn.CreateNetwork(hiddenLayersConfig, outputLayerConfig, TrainingSet, featurizer);
                // Create the tag-bigram transition matrix only for CRF (sequence) output
                if (IsCRFTraining)
                {
                    Logger.WriteLine("Initialize bigram transition for CRF output layer.");
                    rnn.InitializeCRFWeights(TrainingSet.CRFLabelBigramTransition);
                }
            }

            rnn.MaxSeqLength  = maxSequenceLength;
            rnn.bVQ           = ModelSettings.VQ != 0;
            rnn.IsCRFTraining = IsCRFTraining;

            // rnns[0] is `rnn` itself (the instance saved below); the rest are clones of it.
            int             N    = Environment.ProcessorCount * 2;
            List <RNN <T> > rnns = new List <RNN <T> >();

            rnns.Add(rnn);

            for (int i = 1; i < N; i++)
            {
                rnns.Add(rnn.Clone());
            }

            // Initialize RNNHelper
            RNNHelper.LearningRate          = ModelSettings.LearningRate;
            RNNHelper.vecNormalLearningRate = new Vector <float>(RNNHelper.LearningRate);

            RNNHelper.GradientCutoff = ModelSettings.GradientCutoff;
            RNNHelper.vecMaxGrad     = new Vector <float>(RNNHelper.GradientCutoff);
            RNNHelper.vecMinGrad     = new Vector <float>(-RNNHelper.GradientCutoff);
            RNNHelper.IsConstAlpha   = ModelSettings.IsConstAlpha;
            RNNHelper.MiniBatchSize  = ModelSettings.MiniBatchSize;

            Logger.WriteLine("");

            Logger.WriteLine("Iterative training begins ...");
            long bestTrainTknErrCnt = long.MaxValue;
            long bestValidTknErrCnt = long.MaxValue;
            var  lastAlpha          = RNNHelper.LearningRate;
            int  iter               = 0;

            while (true)
            {
                // Iterations are numbered from 0, so after iterations 0 .. MaxIteration-1 exactly
                // MaxIteration passes have run. MaxIteration <= 0 means "no limit".
                if (ModelSettings.MaxIteration > 0 && iter >= ModelSettings.MaxIteration)
                {
                    Logger.WriteLine($"We have trained this model {iter} iteration, exit.");
                    break;
                }

                // UtcNow: the measured duration must not be skewed by local clock/DST changes.
                var start = DateTime.UtcNow;
                Logger.WriteLine($"Start to training {iter} iteration. learning rate = {RNNHelper.LearningRate}");

                // Clean all RNN instances' status for training
                foreach (var r in rnns)
                {
                    r.CleanStatusForTraining();
                }
                Process(rnns, TrainingSet, RunningMode.Training);

                var duration = DateTime.UtcNow.Subtract(start);

                Logger.WriteLine($"End {iter} iteration. Time duration = {duration}");
                Logger.WriteLine("");

                // lastAlpha differs from the current rate only if the rate was halved in the previous iteration.
                if (tknErrCnt >= bestTrainTknErrCnt && lastAlpha != RNNHelper.LearningRate)
                {
                    // Although we reduced the learning rate, we still cannot get a better result.
                    Logger.WriteLine(
                        $"Current token error count({(double)tknErrCnt / (double)processedWordCnt * 100.0}%) is larger than the previous one({(double)bestTrainTknErrCnt / (double)processedWordCnt * 100.0}%) on training set. End training early.");
                    Logger.WriteLine("Current alpha: {0}, the previous alpha: {1}", RNNHelper.LearningRate, lastAlpha);
                    break;
                }
                lastAlpha = RNNHelper.LearningRate;

                // Snapshot the training error: Process() on the validation set overwrites tknErrCnt.
                int trainTknErrCnt = tknErrCnt;

                if (ValidationSet != null)
                {
                    Logger.WriteLine("Verify model on validated corpus.");
                    Process(rnns, ValidationSet, RunningMode.Validate);
                    Logger.WriteLine("End model verification.");
                    Logger.WriteLine("");

                    if (tknErrCnt < bestValidTknErrCnt)
                    {
                        // We got a better result on the validation corpus, so save this model
                        Logger.WriteLine($"Saving better model into file {modelFilePath}, since we got a better result on validation set.");
                        Logger.WriteLine($"Error token percent: {(double)tknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%");

                        rnn.SaveModel(modelFilePath);
                        bestValidTknErrCnt = tknErrCnt;
                    }
                }
                else if (trainTknErrCnt < bestTrainTknErrCnt)
                {
                    // No validation corpus, but we got a better result on the training corpus, so save this model
                    Logger.WriteLine($"Saving better model into file {modelFilePath}, although validation set doesn't exist, we have better result on training set.");
                    Logger.WriteLine($"Error token percent: {(double)trainTknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%");

                    rnn.SaveModel(modelFilePath);
                }

                Logger.WriteLine("");

                if (trainTknErrCnt >= bestTrainTknErrCnt)
                {
                    // No better result on the training set, so reduce the learning rate
                    RNNHelper.LearningRate          = RNNHelper.LearningRate / 2.0f;
                    RNNHelper.vecNormalLearningRate = new Vector <float>(RNNHelper.LearningRate);
                }
                else
                {
                    bestTrainTknErrCnt = trainTknErrCnt;
                }

                iter++;
            }
        }