Exemplo n.º 1
1
        /// <summary>
        /// Builds the forward and backward networks of a bi-directional RNN.
        /// </summary>
        /// <param name="modeltype">
        /// 0 creates a SimpleRNN for each direction; any other value creates an LSTMRNN for each direction.
        /// </param>
        public BiRNN(int modeltype)
        {
            if (modeltype != 0)
            {
                forwardRNN = new LSTMRNN();
                backwardRNN = new LSTMRNN();
            }
            else
            {
                // Both directions receive the same BPTT settings.
                const int bpttDepth = 4 + 1;
                const int bpttBlock = 30;

                var forward = new SimpleRNN();
                var backward = new SimpleRNN();

                forward.setBPTT(bpttDepth);
                forward.setBPTTBlock(bpttBlock);

                backward.setBPTT(bpttDepth);
                backward.setBPTTBlock(bpttBlock);

                forwardRNN = forward;
                backwardRNN = backward;
            }

            m_modeldirection = MODELDIRECTION.BI_DIRECTIONAL;
        }
Exemplo n.º 2
0
        /// <summary>
        /// Wraps an existing pair of networks as a bi-directional RNN.
        /// </summary>
        /// <param name="s_forwardRNN">Network stored as the forward direction; its ModelType is copied to this instance.</param>
        /// <param name="s_backwardRNN">Network stored as the backward direction.</param>
        public BiRNN(RNN s_forwardRNN, RNN s_backwardRNN)
        {
            this.forwardRNN = s_forwardRNN;
            this.backwardRNN = s_backwardRNN;

            // The model type is taken from the forward network only.
            this.ModelType = this.forwardRNN.ModelType;
            this.ModelDirection = MODELDIRECTION.BI_DIRECTIONAL;
        }
Exemplo n.º 3
0
        /// <summary>
        /// Creates a bi-directional RNN from two already-constructed networks.
        /// </summary>
        /// <param name="s_forwardRNN">Assigned to forwardRNN; also supplies the ModelType of this instance.</param>
        /// <param name="s_backwardRNN">Assigned to backwardRNN.</param>
        public BiRNN(RNN s_forwardRNN, RNN s_backwardRNN)
        {
            // Store both directions first; ModelType is read back from the stored forward network.
            this.forwardRNN = s_forwardRNN;
            this.backwardRNN = s_backwardRNN;

            this.ModelType = this.forwardRNN.ModelType;
            this.ModelDirection = MODELDIRECTION.BI_DIRECTIONAL;
        }
Exemplo n.º 4
0
        /// <summary>
        /// Takes an RNN instance from qRNNs, retrying until one can be dequeued.
        /// </summary>
        /// <returns>The dequeued instance.</returns>
        private RNN<Sequence> GetRNNInstance()
        {
            RNN<Sequence> instance;

            for (;;)
            {
                if (qRNNs.TryDequeue(out instance))
                {
                    return instance;
                }

                // Nothing available yet: give up the time slice before trying again.
                Thread.Yield();
            }
        }
Exemplo n.º 5
0
        /// <summary>
        /// Takes an RNN instance from qRNNs, or clones <c>rnn</c> when nothing can be dequeued.
        /// </summary>
        /// <returns>A dequeued instance, or a fresh clone of <c>rnn</c>.</returns>
        private RNN<Sequence> GetRNNInstance()
        {
            RNN<Sequence> pooled;

            if (qRNNs.TryDequeue(out pooled))
            {
                return pooled;
            }

            // Queue had nothing to hand out: fall back to a new clone.
            return rnn.Clone();
        }
Exemplo n.º 6
0
        /// <summary>
        /// Creates a decoder from a configuration: picks the network structure from
        /// Config.ModelDirection, then loads the model from config.ModelFilePath.
        /// </summary>
        /// <param name="config">Configuration stored in Config and used to select and load the model.</param>
        public RNNDecoder(Config config)
        {
            Config = config;

            bool isBidirectional = Config.ModelDirection == MODELDIRECTION.BiDirectional;
            if (!isBidirectional)
            {
                Logger.WriteLine("Model Structure: Simple RNN");
                rnn = new ForwardRNN<Sequence>();
            }
            else
            {
                Logger.WriteLine("Model Structure: Bi-directional RNN");
                rnn = new BiRNN<Sequence>();
            }

            rnn.LoadModel(config.ModelFilePath);
            Logger.WriteLine("CRF Model: {0}", rnn.IsCRFTraining);
        }
Exemplo n.º 7
0
        /// <summary>
        /// Creates a decoder from a configuration: builds the network for Config.NetworkType,
        /// loads the model file, and fills qRNNs with Environment.ProcessorCount clones of it.
        /// </summary>
        /// <param name="config">Supplies the network type, model file path and maximum sequence length.</param>
        public RNNDecoder(Config config)
        {
            Config = config;

            // The loaded network is only the clone source; it is not enqueued itself.
            RNN<Sequence> prototype = RNN<Sequence>.CreateRNN(Config.NetworkType);
            prototype.LoadModel(config.ModelFilePath);
            prototype.MaxSeqLength = config.MaxSequenceLength;

            Logger.WriteLine("CRF Model: {0}", prototype.IsCRFTraining);
            Logger.WriteLine($"Max Sequence Length: {prototype.MaxSeqLength}");
            Logger.WriteLine($"Processor Count: {Environment.ProcessorCount}");

            qRNNs = new ConcurrentQueue<RNN<Sequence>>();

            int cloneCount = Environment.ProcessorCount;
            for (int created = 0; created < cloneCount; created++)
            {
                qRNNs.Enqueue(prototype.Clone());
            }
        }
Exemplo n.º 8
0
        /// <summary>
        /// Creates a decoder for the given model file. The model direction reported by
        /// RNNHelper.CheckModelFileType decides whether a BiRNN or a ForwardRNN is created
        /// before the model is loaded.
        /// </summary>
        /// <param name="strModelFileName">Path of the model file to inspect and load.</param>
        /// <param name="featurizer">Featurizer stored in m_Featurizer.</param>
        public RNNDecoder(string strModelFileName, Featurizer featurizer)
        {
            MODELDIRECTION direction;
            RNNHelper.CheckModelFileType(strModelFileName, out direction);

            if (direction != MODELDIRECTION.BI_DIRECTIONAL)
            {
                Logger.WriteLine("Model Structure: Simple RNN");
                m_Rnn = new ForwardRNN();
            }
            else
            {
                Logger.WriteLine("Model Structure: Bi-directional RNN");
                m_Rnn = new BiRNN();
            }

            m_Rnn.LoadModel(strModelFileName);
            Logger.WriteLine("CRF Model: {0}", m_Rnn.IsCRFTraining);

            m_Featurizer = featurizer;
        }
Exemplo n.º 9
0
        /// <summary>
        /// Builds a decoder from a model file: queries the model direction, creates the
        /// matching network (BiRNN or ForwardRNN), loads the model and keeps the featurizer.
        /// </summary>
        /// <param name="strModelFileName">Model file passed to CheckModelFileType and LoadModel.</param>
        /// <param name="featurizer">Featurizer kept for later use by the decoder.</param>
        public RNNDecoder(string strModelFileName, Featurizer featurizer)
        {
            MODELDIRECTION detectedDirection;
            RNNHelper.CheckModelFileType(strModelFileName, out detectedDirection);

            bool useBidirectional = detectedDirection == MODELDIRECTION.BI_DIRECTIONAL;
            if (useBidirectional)
            {
                Logger.WriteLine("Model Structure: Bi-directional RNN");
                m_Rnn = new BiRNN();
            }
            else
            {
                Logger.WriteLine("Model Structure: Simple RNN");
                m_Rnn = new ForwardRNN();
            }

            m_Rnn.LoadModel(strModelFileName);
            Logger.WriteLine("CRF Model: {0}", m_Rnn.IsCRFTraining);

            m_Featurizer = featurizer;
        }
Exemplo n.º 10
0
        /// <summary>
        /// Creates a decoder for the given model file. The direction and type reported by
        /// RNNHelper.CheckModelFileType select the network class and set ModelType before
        /// the model is loaded.
        /// </summary>
        /// <param name="strModelFileName">Path of the model file to inspect and load.</param>
        public RNNDecoder(string strModelFileName)
        {
            MODELDIRECTION direction;
            MODELTYPE detectedType;
            RNNHelper.CheckModelFileType(strModelFileName, out direction, out detectedType);

            if (direction != MODELDIRECTION.BI_DIRECTIONAL)
            {
                Logger.WriteLine("Model Structure: Simple RNN");
                rnn = new ForwardRNN<Sequence>();
            }
            else
            {
                Logger.WriteLine("Model Structure: Bi-directional RNN");
                rnn = new BiRNN<Sequence>();
            }

            ModelType = detectedType;

            rnn.LoadModel(strModelFileName);
            Logger.WriteLine("CRF Model: {0}", rnn.IsCRFTraining);
        }
Exemplo n.º 11
0
        /// <summary>
        /// Creates a decoder for the given model file: a SimpleRNN when the file's model type
        /// is MODELTYPE.SIMPLE, otherwise an LSTMRNN, then loads the network from the file.
        /// </summary>
        /// <param name="strModelFileName">Model file passed to CheckModelFileType and loadNetBin.</param>
        /// <param name="featurizer">Featurizer stored in m_Featurizer.</param>
        public RNNDecoder(string strModelFileName, Featurizer featurizer)
        {
            bool isSimple = RNN.CheckModelFileType(strModelFileName) == MODELTYPE.SIMPLE;

            if (!isSimple)
            {
                Console.WriteLine("Model Structure: LSTM-RNN");
                m_Rnn = new LSTMRNN();
            }
            else
            {
                Console.WriteLine("Model Structure: Simple RNN");
                m_Rnn = new SimpleRNN();
            }

            m_Rnn.loadNetBin(strModelFileName);
            Console.WriteLine("CRF Model: {0}", m_Rnn.IsCRFModel());

            m_Featurizer = featurizer;
        }
Exemplo n.º 12
0
        /// <summary>
        /// Creates a decoder for the given model file. The model type (simple or LSTM) and
        /// direction (forward or bi-directional) reported by RNN.CheckModelFileType select
        /// which network is constructed before the model is loaded with loadNetBin.
        /// </summary>
        /// <param name="strModelFileName">Path of the model file to inspect and load.</param>
        /// <param name="featurizer">Featurizer stored in m_Featurizer.</param>
        public RNNDecoder(string strModelFileName, Featurizer featurizer)
        {
            MODELTYPE kind;
            MODELDIRECTION direction;
            RNN.CheckModelFileType(strModelFileName, out kind, out direction);

            bool isSimple = kind == MODELTYPE.SIMPLE;

            if (direction != MODELDIRECTION.BI_DIRECTIONAL)
            {
                // Single-direction network.
                if (isSimple)
                {
                    Logger.WriteLine(Logger.Level.info, "Model Structure: Simple RNN");
                    m_Rnn = new SimpleRNN();
                }
                else
                {
                    Logger.WriteLine(Logger.Level.info, "Model Structure: LSTM-RNN");
                    m_Rnn = new LSTMRNN();
                }
            }
            else
            {
                // Bi-directional network: both directions use the same cell type.
                Logger.WriteLine(Logger.Level.info, "Model Structure: Bi-directional RNN");
                if (isSimple)
                {
                    m_Rnn = new BiRNN(new SimpleRNN(), new SimpleRNN());
                }
                else
                {
                    m_Rnn = new BiRNN(new LSTMRNN(), new LSTMRNN());
                }
            }

            m_Rnn.loadNetBin(strModelFileName);
            Logger.WriteLine(Logger.Level.info, "CRF Model: {0}", m_Rnn.IsCRFTraining);

            m_Featurizer = featurizer;
        }
Exemplo n.º 13
0
        /// <summary>
        /// Builds a decoder from a model file. RNN.CheckModelFileType reports the model type
        /// and direction; the matching network (SimpleRNN, LSTMRNN, or a BiRNN wrapping two of
        /// either) is created and the model is loaded into it.
        /// </summary>
        /// <param name="strModelFileName">Model file passed to CheckModelFileType and LoadModel.</param>
        /// <param name="featurizer">Featurizer kept for later use by the decoder.</param>
        public RNNDecoder(string strModelFileName, Featurizer featurizer)
        {
            MODELTYPE detectedType;
            MODELDIRECTION detectedDirection;
            RNN.CheckModelFileType(strModelFileName, out detectedType, out detectedDirection);

            bool bidirectional = detectedDirection == MODELDIRECTION.BI_DIRECTIONAL;
            bool simple = detectedType == MODELTYPE.SIMPLE;

            if (bidirectional)
            {
                Logger.WriteLine("Model Structure: Bi-directional RNN");
                if (!simple)
                {
                    m_Rnn = new BiRNN(new LSTMRNN(), new LSTMRNN());
                }
                else
                {
                    m_Rnn = new BiRNN(new SimpleRNN(), new SimpleRNN());
                }
            }
            else if (simple)
            {
                Logger.WriteLine("Model Structure: Simple RNN");
                m_Rnn = new SimpleRNN();
            }
            else
            {
                Logger.WriteLine("Model Structure: LSTM-RNN");
                m_Rnn = new LSTMRNN();
            }

            m_Rnn.LoadModel(strModelFileName);
            Logger.WriteLine("CRF Model: {0}", m_Rnn.IsCRFTraining);

            m_Featurizer = featurizer;
        }
Exemplo n.º 14
0
 /// <summary>
 /// Returns an RNN instance to qRNNs by enqueuing it.
 /// </summary>
 /// <param name="r">Instance to enqueue.</param>
 private void FreeRNNInstance(RNN<Sequence> r) => qRNNs.Enqueue(r);
Exemplo n.º 15
0
        /// <summary>
        /// Creates (or, for incremental training, reloads) the network and trains it iteratively.
        /// Each iteration processes the training set, optionally evaluates the validation set,
        /// saves the model when the tracked error count improves, and halves the learning rate
        /// when the training error does not improve. Training stops when the iteration limit is
        /// exceeded or when a reduced learning rate still fails to lower the training error.
        /// </summary>
        public void Train()
        {
            //Create neural network
            Logger.WriteLine("Create a new network according settings in configuration file.");
            Logger.WriteLine($"Processor Count = {Environment.ProcessorCount}");

            RNN<T> rnn = RNN<T>.CreateRNN(networkType);

            if (ModelSettings.IncrementalTrain)
            {
                Logger.WriteLine($"Loading previous trained model from {modelFilePath}.");
                rnn.LoadModel(modelFilePath, true);
            }
            else
            {
                Logger.WriteLine("Create a new network.");
                rnn.CreateNetwork(hiddenLayersConfig, outputLayerConfig, TrainingSet, featurizer);
                //Create tag-bigram transition probability matrix only for sequence RNN mode
                if (IsCRFTraining)
                {
                    Logger.WriteLine("Initialize bigram transition for CRF output layer.");
                    rnn.InitializeCRFWeights(TrainingSet.CRFLabelBigramTransition);
                }
            }

            rnn.MaxSeqLength = maxSequenceLength;
            rnn.bVQ = ModelSettings.VQ != 0;
            rnn.IsCRFTraining = IsCRFTraining;

            //The primary network plus clones, two instances per processor in total
            int instanceCount = Environment.ProcessorCount * 2;
            List<RNN<T>> rnns = new List<RNN<T>>();

            rnns.Add(rnn);
            for (int i = 1; i < instanceCount; i++)
            {
                rnns.Add(rnn.Clone());
            }

            //Initialize RNNHelper
            RNNHelper.LearningRate = ModelSettings.LearningRate;
            RNNHelper.vecNormalLearningRate = new Vector<float>(RNNHelper.LearningRate);

            RNNHelper.GradientCutoff = ModelSettings.GradientCutoff;
            RNNHelper.vecMaxGrad = new Vector<float>(RNNHelper.GradientCutoff);
            RNNHelper.vecMinGrad = new Vector<float>(-RNNHelper.GradientCutoff);
            RNNHelper.IsConstAlpha = ModelSettings.IsConstAlpha;
            RNNHelper.MiniBatchSize = ModelSettings.MiniBatchSize;

            Logger.WriteLine("");

            Logger.WriteLine("Iterative training begins ...");
            var bestTrainTknErrCnt = long.MaxValue;
            var bestValidTknErrCnt = long.MaxValue;
            var lastAlpha = RNNHelper.LearningRate;
            var iter = 0;

            while (true)
            {
                //NOTE(review): iter starts at 0 and the check uses '>', so MaxIteration + 1 passes
                //run before exiting - confirm this is intended.
                if (ModelSettings.MaxIteration > 0 && iter > ModelSettings.MaxIteration)
                {
                    //The format placeholder needs the iteration count as its argument
                    Logger.WriteLine("We have trained this model {0} iteration, exit.", iter);
                    break;
                }

                //Stopwatch measures elapsed time independently of wall-clock adjustments
                var stopwatch = System.Diagnostics.Stopwatch.StartNew();
                Logger.WriteLine($"Start to training {iter} iteration. learning rate = {RNNHelper.LearningRate}");

                //Clean all RNN instances' status for training
                foreach (var r in rnns)
                {
                    r.CleanStatusForTraining();
                }
                Process(rnns, TrainingSet, RunningMode.Training);

                var duration = stopwatch.Elapsed;

                Logger.WriteLine($"End {iter} iteration. Time duration = {duration}");
                Logger.WriteLine("");

                if (tknErrCnt >= bestTrainTknErrCnt && lastAlpha != RNNHelper.LearningRate)
                {
                    //Although we reduce alpha value, we still cannot get better result.
                    Logger.WriteLine(
                        $"Current token error count({(double)tknErrCnt / (double)processedWordCnt * 100.0}%) is larger than the previous one({(double)bestTrainTknErrCnt / (double)processedWordCnt * 100.0}%) on training set. End training early.");
                    Logger.WriteLine("Current alpha: {0}, the previous alpha: {1}", RNNHelper.LearningRate, lastAlpha);
                    break;
                }
                lastAlpha = RNNHelper.LearningRate;

                //Keep the training-set error count; the validation pass below reuses tknErrCnt
                int trainTknErrCnt = tknErrCnt;
                //Validate the model by validated corpus
                if (ValidationSet != null)
                {
                    Logger.WriteLine("Verify model on validated corpus.");
                    Process(rnns, ValidationSet, RunningMode.Validate);
                    Logger.WriteLine("End model verification.");
                    Logger.WriteLine("");

                    if (tknErrCnt < bestValidTknErrCnt)
                    {
                        //We got better result on validated corpus, save this model
                        Logger.WriteLine($"Saving better model into file {modelFilePath}, since we got a better result on validation set.");
                        Logger.WriteLine($"Error token percent: {(double)tknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%");

                        rnn.SaveModel(modelFilePath);
                        bestValidTknErrCnt = tknErrCnt;
                    }
                }
                else if (trainTknErrCnt < bestTrainTknErrCnt)
                {
                    //We don't have validate corpus, but we get a better result on training corpus,
                    //so save this model
                    Logger.WriteLine($"Saving better model into file {modelFilePath}, although validation set doesn't exist, we have better result on training set.");
                    Logger.WriteLine($"Error token percent: {(double)trainTknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%");

                    rnn.SaveModel(modelFilePath);
                }

                Logger.WriteLine("");

                if (trainTknErrCnt >= bestTrainTknErrCnt)
                {
                    //We don't have better result on training set, so reduce learning rate
                    RNNHelper.LearningRate = RNNHelper.LearningRate / 2.0f;
                    RNNHelper.vecNormalLearningRate = new Vector<float>(RNNHelper.LearningRate);
                }
                else
                {
                    bestTrainTknErrCnt = trainTknErrCnt;
                }

                iter++;
            }
        }
Exemplo n.º 16
0
        /// <summary>
        /// Shuffles the data set and runs every sequence in it through the given network,
        /// accumulating token/sentence error counters, logging progress every 1000 processed
        /// sequences and saving a temporary model every ModelSettings.SaveStep sequences.
        /// Sequences longer than rnn.MaxSeqLength are skipped.
        /// </summary>
        /// <param name="rnn">Network used to process each sequence.</param>
        /// <param name="trainingSet">Data set to shuffle and iterate.</param>
        /// <param name="runningMode">Mode forwarded to the network's process calls.</param>
        /// <param name="totalSequenceNum">Total sequence count, used only in the progress log line.</param>
        /// <exception cref="System.InvalidOperationException">
        /// A sequence item is neither a Sequence nor a SequencePair.
        /// </exception>
        public void Process(RNN<T> rnn, DataSet<T> trainingSet, RunningMode runningMode, int totalSequenceNum)
        {
            //Shuffle training corpus
            trainingSet.Shuffle();

            for (var i = 0; i < trainingSet.SequenceList.Count; i++)
            {
                var pSequence = trainingSet.SequenceList[i];

                //Resolve the concrete item kind once instead of re-casting at every use
                var sequence = pSequence as Sequence;
                SequencePair pair = null;

                int wordCnt;
                if (sequence != null)
                {
                    wordCnt = sequence.States.Length;
                }
                else
                {
                    pair = pSequence as SequencePair;
                    if (pair == null)
                    {
                        //Fail with a clear message instead of a NullReferenceException below
                        throw new System.InvalidOperationException("Sequence item is neither a Sequence nor a SequencePair.");
                    }

                    if (pair.srcSentence.TokensList.Count > rnn.MaxSeqLength)
                    {
                        continue;
                    }

                    wordCnt = pair.tgtSequence.States.Length;
                }

                if (wordCnt > rnn.MaxSeqLength)
                {
                    continue;
                }

                Interlocked.Add(ref processedWordCnt, wordCnt);

                int[] predicted;
                if (IsCRFTraining)
                {
                    predicted = rnn.ProcessSequenceCRF(sequence, runningMode);
                }
                else
                {
                    Matrix<float> m;
                    predicted = rnn.ProcessSequence(pSequence, runningMode, false, out m);
                }

                int newTknErrCnt;
                if (sequence != null)
                {
                    newTknErrCnt = GetErrorTokenNum(sequence, predicted);
                }
                else
                {
                    newTknErrCnt = GetErrorTokenNum(pair.tgtSequence, predicted);
                }

                Interlocked.Add(ref tknErrCnt, newTknErrCnt);
                if (newTknErrCnt > 0)
                {
                    Interlocked.Increment(ref sentErrCnt);
                }

                //Use the value returned by the atomic increment for the threshold checks:
                //re-reading the shared counter could observe another caller's update
                //(presumably this runs concurrently, given the Interlocked counters) and
                //miss or repeat a progress/save step.
                var processedSoFar = Interlocked.Increment(ref processedSequence);

                if (processedSoFar % 1000 == 0)
                {
                    Logger.WriteLine("Progress = {0} ", processedSoFar / 1000 + "K/" + totalSequenceNum / 1000.0 + "K");
                    Logger.WriteLine(" Error token ratio = {0}%", (double)tknErrCnt / (double)processedWordCnt * 100.0);
                    Logger.WriteLine(" Error sentence ratio = {0}%", (double)sentErrCnt / (double)processedSoFar * 100.0);
                }

                if (ModelSettings.SaveStep > 0 && processedSoFar % ModelSettings.SaveStep == 0)
                {
                    //After processed every SaveStep sentences, save current model into a temporary file
                    Logger.WriteLine("Saving temporary model into file...");
                    rnn.SaveModel("model.tmp");
                }
            }
        }