示例#1
1
文件: BiRNN.cs 项目: zxz/RNNSharp
        /// <summary>
        /// Creates a bi-directional network whose forward and backward halves are the same kind of RNN.
        /// </summary>
        /// <param name="modeltype">
        /// 0 builds two <c>SimpleRNN</c> halves; any other value builds two <c>LSTMRNN</c> halves.
        /// </param>
        /// <param name="bptt">
        /// BPTT setting for each SimpleRNN half, passed to <c>setBPTT</c> as <c>bptt + 1</c>.
        /// Ignored for LSTM halves. Defaults to 4, the value previously hard-coded here.
        /// </param>
        /// <param name="bpttBlock">
        /// Value passed to <c>setBPTTBlock</c> on each SimpleRNN half. Ignored for LSTM halves.
        /// Defaults to 30, the value previously hard-coded here.
        /// </param>
        public BiRNN(int modeltype, int bptt = 4, int bpttBlock = 30)
        {
            if (modeltype == 0)
            {
                SimpleRNN forwardNet = new SimpleRNN();
                SimpleRNN backwardNet = new SimpleRNN();

                // Both directions share one BPTT configuration. Exposing it as parameters lets
                // callers forward their configured BPTT value instead of the fixed 4 + 1 / 30.
                forwardNet.setBPTT(bptt + 1);
                forwardNet.setBPTTBlock(bpttBlock);

                backwardNet.setBPTT(bptt + 1);
                backwardNet.setBPTTBlock(bpttBlock);

                forwardRNN = forwardNet;
                backwardRNN = backwardNet;
            }
            else
            {
                forwardRNN = new LSTMRNN();
                backwardRNN = new LSTMRNN();
            }

            m_modeldirection = MODELDIRECTION.BI_DIRECTIONAL;
        }
示例#2
0
        /// <summary>
        /// Loads a trained network from a binary model file so it can be used for decoding.
        /// </summary>
        /// <param name="strModelFileName">
        /// Path of the model file. <c>RNN.CheckModelFileType</c> inspects it to choose between
        /// a SimpleRNN and an LSTM-RNN before the weights are loaded.
        /// </param>
        /// <param name="featurizer">Featurizer kept by the decoder.</param>
        public RNNDecoder(string strModelFileName, Featurizer featurizer)
        {
            MODELTYPE storedType = RNN.CheckModelFileType(strModelFileName);

            // Instantiate the concrete network class that matches the stored model type.
            if (storedType != MODELTYPE.SIMPLE)
            {
                Console.WriteLine("Model Structure: LSTM-RNN");
                m_Rnn = new LSTMRNN();
            }
            else
            {
                Console.WriteLine("Model Structure: Simple RNN");
                m_Rnn = new SimpleRNN();
            }

            // Only once the right class exists can the binary weights be read into it.
            m_Rnn.loadNetBin(strModelFileName);
            Console.WriteLine("CRF Model: {0}", m_Rnn.IsCRFModel());

            m_Featurizer = featurizer;
        }
示例#3
0
        /// <summary>
        /// Builds the network described by <c>m_modelSetting</c>, configures it, and trains it until
        /// <c>ShouldTrainingStop()</c> reports true. After every training pass the model is validated,
        /// and it is written to the configured model file whenever <c>ValidateNet()</c> returns true.
        /// </summary>
        public void Train()
        {
            RNN rnn;

            if (m_modelSetting.GetModelDirection() != 0)
            {
                // Bi-directional: BiRNN creates its own forward/backward sub-networks.
                // NOTE(review): GetBptt() is only applied on the single-direction SimpleRNN path below;
                // the BiRNN constructor picks its own BPTT settings — confirm this is intended.
                rnn = new BiRNN(m_modelSetting.GetModelType());
            }
            else if (m_modelSetting.GetModelType() == 0)
            {
                SimpleRNN simpleNet = new SimpleRNN();
                simpleNet.setBPTT(m_modelSetting.GetBptt() + 1);
                simpleNet.setBPTTBlock(30);
                rnn = simpleNet;
            }
            else
            {
                rnn = new LSTMRNN();
            }

            // Copy training configuration onto the network before allocating its memory.
            rnn.SetModelDirection(m_modelSetting.GetModelDirection());
            rnn.SetTrainingSet(m_TrainingSet);
            rnn.SetValidationSet(m_ValidationSet);
            rnn.SetModelFile(m_modelSetting.GetModelFile());
            rnn.SetSaveStep(m_modelSetting.GetSaveStep());
            rnn.SetMaxIter(m_modelSetting.GetMaxIteration());
            rnn.SetCRFTraining(m_modelSetting.IsCRFTraining());
            rnn.SetLearningRate(m_modelSetting.GetLearningRate());
            rnn.SetGradientCutoff(15.0);
            rnn.SetRegularization(m_modelSetting.GetRegularization());
            rnn.SetHiddenLayerSize(m_modelSetting.GetNumHidden());
            rnn.SetTagBigramTransitionWeight(m_modelSetting.GetTagTransitionWeight());

            rnn.initMem();

            // The tag-bigram transition matrix is only used when training in CRF (sequence) mode.
            if (m_modelSetting.IsCRFTraining())
            {
                rnn.setTagBigramTransition(m_LabelBigramTransition);
            }

            Console.WriteLine();
            Console.WriteLine("[TRACE] Iterative training begins ...");

            while (!rnn.ShouldTrainingStop())
            {
                rnn.TrainNet();

                // Persist the model only when validation reports an improvement.
                if (rnn.ValidateNet())
                {
                    Console.Write("Saving better model into file {0}...", m_modelSetting.GetModelFile());
                    rnn.saveNetBin(m_modelSetting.GetModelFile());
                    Console.WriteLine("Done.");
                }
            }
        }
示例#4
0
        /// <summary>
        /// Builds the network described by <c>m_modelSetting</c>, trains it on <c>TrainingSet</c>,
        /// and saves improved models to <c>m_modelSetting.ModelFile</c>.
        /// </summary>
        /// <remarks>
        /// Each pass calls <c>TrainNet</c> and gets back a perplexity. If perplexity does not improve,
        /// the learning rate is halved. If it still does not improve on the pass right after a
        /// halving, training stops early. Training also stops after <c>MaxIter</c> passes when
        /// <c>MaxIter</c> is positive. A model is saved whenever <c>ValidateNet</c> returns true on
        /// <c>ValidationSet</c>, or, when there is no validation set, whenever training perplexity drops.
        /// </remarks>
        public void Train()
        {
            RNN rnn;

            if (m_modelSetting.ModelDirection == 0)
            {
                if (m_modelSetting.ModelType == 0)
                {
                    SimpleRNN sRNN = new SimpleRNN();

                    sRNN.setBPTT(m_modelSetting.Bptt + 1);
                    sRNN.setBPTTBlock(10);

                    rnn = sRNN;
                }
                else
                {
                    rnn = new LSTMRNN();
                }
            }
            else
            {
                if (m_modelSetting.ModelType == 0)
                {
                    SimpleRNN sForwardRNN  = new SimpleRNN();
                    SimpleRNN sBackwardRNN = new SimpleRNN();

                    // Both directions use the same BPTT configuration.
                    sForwardRNN.setBPTT(m_modelSetting.Bptt + 1);
                    sForwardRNN.setBPTTBlock(10);

                    sBackwardRNN.setBPTT(m_modelSetting.Bptt + 1);
                    sBackwardRNN.setBPTTBlock(10);

                    rnn = new BiRNN(sForwardRNN, sBackwardRNN);
                }
                else
                {
                    rnn = new BiRNN(new LSTMRNN(), new LSTMRNN());
                }
            }

            rnn.ModelDirection = (MODELDIRECTION)m_modelSetting.ModelDirection;
            rnn.bVQ            = m_modelSetting.VQ != 0;
            rnn.ModelFile      = m_modelSetting.ModelFile;
            rnn.SaveStep       = m_modelSetting.SaveStep;
            rnn.MaxIter        = m_modelSetting.MaxIteration;
            rnn.IsCRFTraining  = m_modelSetting.IsCRFTraining;
            rnn.LearningRate   = m_modelSetting.LearningRate;
            rnn.GradientCutoff = m_modelSetting.GradientCutoff;
            rnn.Dropout        = m_modelSetting.Dropout;
            rnn.L1             = m_modelSetting.NumHidden;

            // L0 / L2 are taken from the training corpus: sparse feature dimension and tag count.
            rnn.DenseFeatureSize = TrainingSet.DenseFeatureSize();
            rnn.L0 = TrainingSet.GetSparseDimension();
            rnn.L2 = TrainingSet.TagSize;

            rnn.InitMem();

            // The tag-bigram transition matrix is only needed in CRF (sequence) training mode.
            if (m_modelSetting.IsCRFTraining)
            {
                rnn.setTagBigramTransition(TrainingSet.CRFLabelBigramTransition);
            }

            Logger.WriteLine("");

            Logger.WriteLine("Iterative training begins ...");
            double lastPPL   = double.MaxValue;
            double lastAlpha = rnn.LearningRate;
            int    iter      = 0;

            while (true)
            {
                Logger.WriteLine("Cleaning training status...");
                rnn.CleanStatus();

                // iter starts at 0 and equals the number of completed passes here, so stop once
                // MaxIter passes have run (the former "iter > MaxIter" allowed one extra pass).
                if (rnn.MaxIter > 0 && iter >= rnn.MaxIter)
                {
                    // Pass iter so the {0} format item has a matching argument.
                    Logger.WriteLine("We have trained this model {0} iteration, exit.", iter);
                    break;
                }

                double ppl = rnn.TrainNet(TrainingSet, iter);

                // The learning rate was halved after the previous pass and perplexity still did
                // not improve: give up instead of halving again.
                if (ppl >= lastPPL && lastAlpha != rnn.LearningRate)
                {
                    Logger.WriteLine("Current perplexity({0}) is larger than the previous one({1}). End training early.", ppl, lastPPL);
                    Logger.WriteLine("Current alpha: {0}, the previous alpha: {1}", rnn.LearningRate, lastAlpha);
                    break;
                }
                lastAlpha = rnn.LearningRate;

                if (ValidationSet != null)
                {
                    Logger.WriteLine("Verify model on validated corpus.");
                    if (rnn.ValidateNet(ValidationSet, iter) == true)
                    {
                        // Better result on the validation corpus: save this model.
                        Logger.WriteLine("Saving better model into file {0}...", m_modelSetting.ModelFile);
                        rnn.SaveModel(m_modelSetting.ModelFile);
                    }
                }
                else if (ppl < lastPPL)
                {
                    // No validation corpus, but training perplexity improved: save this model.
                    Logger.WriteLine("Saving better model into file {0}...", m_modelSetting.ModelFile);
                    rnn.SaveModel(m_modelSetting.ModelFile);
                }

                if (ppl >= lastPPL)
                {
                    // No improvement on the training corpus, so reduce the learning rate.
                    rnn.LearningRate = rnn.LearningRate / 2.0f;
                }

                lastPPL = ppl;

                iter++;
            }
        }
示例#5
0
        /// <summary>
        /// Builds the network described by <c>m_modelSetting</c>, trains it on <c>TrainingSet</c>,
        /// and saves improved models to <c>m_modelSetting.ModelFile</c>.
        /// </summary>
        /// <remarks>
        /// After each <c>TrainNet</c> pass the model is checked for improvement. When
        /// <c>ValidationSet</c> is present, improvement means <c>ValidateNet</c> returned true.
        /// Otherwise it means training perplexity dropped. Improved models are saved. If there is
        /// no improvement, the learning rate is halved. Training stops early if perplexity does not
        /// improve on the pass right after the learning rate changed. It also stops after
        /// <c>MaxIter</c> passes when <c>MaxIter</c> is positive.
        /// </remarks>
        public void Train()
        {
            RNN rnn;

            if (m_modelSetting.ModelDirection == 0)
            {
                if (m_modelSetting.ModelType == 0)
                {
                    SimpleRNN sRNN = new SimpleRNN();

                    sRNN.setBPTT(m_modelSetting.Bptt + 1);
                    sRNN.setBPTTBlock(10);

                    rnn = sRNN;
                }
                else
                {
                    rnn = new LSTMRNN();
                }
            }
            else
            {
                if (m_modelSetting.ModelType == 0)
                {
                    SimpleRNN sForwardRNN = new SimpleRNN();
                    SimpleRNN sBackwardRNN = new SimpleRNN();

                    // Both directions use the same BPTT configuration.
                    sForwardRNN.setBPTT(m_modelSetting.Bptt + 1);
                    sForwardRNN.setBPTTBlock(10);

                    sBackwardRNN.setBPTT(m_modelSetting.Bptt + 1);
                    sBackwardRNN.setBPTTBlock(10);

                    rnn = new BiRNN(sForwardRNN, sBackwardRNN);
                }
                else
                {
                    rnn = new BiRNN(new LSTMRNN(), new LSTMRNN());
                }
            }

            rnn.ModelDirection = (MODELDIRECTION)m_modelSetting.ModelDirection;
            rnn.ModelFile = m_modelSetting.ModelFile;
            rnn.SaveStep = m_modelSetting.SaveStep;
            rnn.MaxIter = m_modelSetting.MaxIteration;
            rnn.IsCRFTraining = m_modelSetting.IsCRFTraining;
            rnn.LearningRate = m_modelSetting.LearningRate;
            rnn.GradientCutoff = 15.0f;
            rnn.Dropout = m_modelSetting.Dropout;
            rnn.L1 = m_modelSetting.NumHidden;

            // L0 / L2 are taken from the training corpus: sparse feature dimension and tag count.
            rnn.DenseFeatureSize = TrainingSet.DenseFeatureSize();
            rnn.L0 = TrainingSet.GetSparseDimension();
            rnn.L2 = TrainingSet.TagSize;

            rnn.initMem();

            // The tag-bigram transition matrix is only needed in CRF (sequence) training mode.
            if (m_modelSetting.IsCRFTraining)
            {
                rnn.setTagBigramTransition(TrainingSet.CRFLabelBigramTransition);
            }

            Logger.WriteLine(Logger.Level.info, "");

            Logger.WriteLine(Logger.Level.info, "[TRACE] Iterative training begins ...");
            double lastPPL = double.MaxValue;
            double lastAlpha = rnn.LearningRate;
            int iter = 0;
            while (true)
            {
                // iter starts at 0 and equals the number of completed passes here, so stop once
                // MaxIter passes have run (the former "iter > MaxIter" allowed one extra pass).
                if (rnn.MaxIter > 0 && iter >= rnn.MaxIter)
                {
                    // Pass iter so the {0} format item has a matching argument.
                    Logger.WriteLine(Logger.Level.info, "We have trained this model {0} iteration, exit.", iter);
                    break;
                }

                double ppl = rnn.TrainNet(TrainingSet, iter);

                // Decide whether this pass produced a better model. Without a validation corpus,
                // fall back to training perplexity instead of handing null to ValidateNet.
                bool isBetterModel;
                if (ValidationSet != null)
                {
                    isBetterModel = rnn.ValidateNet(ValidationSet, iter);
                }
                else
                {
                    isBetterModel = ppl < lastPPL;
                }

                if (isBetterModel)
                {
                    Logger.WriteLine(Logger.Level.info, "Saving better model into file {0}...", m_modelSetting.ModelFile);
                    rnn.saveNetBin(m_modelSetting.ModelFile);
                }

                // The learning rate changed after the previous pass and perplexity still did
                // not improve: stop instead of halving again.
                if (ppl >= lastPPL && lastAlpha != rnn.LearningRate)
                {
                    Logger.WriteLine(Logger.Level.info, "Current perplexity({0}) is larger than the previous one({1}). End training early.", ppl, lastPPL);
                    Logger.WriteLine(Logger.Level.info, "Current alpha: {0}, the previous alpha: {1}", rnn.LearningRate, lastAlpha);
                    break;
                }

                lastAlpha = rnn.LearningRate;
                if (!isBetterModel)
                {
                    rnn.LearningRate = rnn.LearningRate / 2.0f;
                }

                lastPPL = ppl;

                iter++;
            }
        }