Exemplo n.º 1
0
        /// <summary>
        /// Saves the bi-directional model.
        /// The shared output-layer weights and CRF tag-transition weights are first assigned
        /// to both the forward and backward networks. Each network is then saved to its own
        /// file (<c>filename + ".forward"</c> / <c>filename + ".backward"</c>). Finally, a
        /// small binary meta-data record is written to <paramref name="filename"/>.
        /// </summary>
        /// <param name="filename">
        /// Path of the meta-data file; also used as the prefix of the two per-direction model files.
        /// </param>
        public override void SaveModel(string filename)
        {
            // Both directions are given the same output-layer and CRF transition objects,
            // so push the current values into each sub-network before it is saved.
            forwardRNN.Hidden2OutputWeight  = Hidden2OutputWeight;
            backwardRNN.Hidden2OutputWeight = Hidden2OutputWeight;

            forwardRNN.CRFTagTransWeights  = CRFTagTransWeights;
            backwardRNN.CRFTagTransWeights = CRFTagTransWeights;

            forwardRNN.SaveModel(filename + ".forward");
            backwardRNN.SaveModel(filename + ".backward");

            // Save meta data. Open the file directly as a binary stream, using the same
            // mode/access/share flags as the StreamWriter(string) constructor, and dispose
            // the BinaryWriter so it is flushed and the file handle is released.
            using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write, FileShare.Read))
            using (var fo = new BinaryWriter(fs))
            {
                fo.Write((int)ModelType);
                fo.Write((int)ModelDirection);

                // Signature: 0 for a plain RNN, 1 for RNN-CRF.
                int crfFlag = IsCRFTraining ? 1 : 0;
                fo.Write(crfFlag);

                fo.Write(L0);
                fo.Write(L1);
                fo.Write(L2);
                fo.Write(DenseFeatureSize);
            }
        }
Exemplo n.º 2
0
        /// <summary>
        /// Creates (or, for incremental training, loads) the network and trains it iteratively
        /// until one of the stop conditions is reached:
        /// <list type="bullet">
        /// <item><c>ModelSettings.MaxIteration</c> iterations have run (a value &lt;= 0 means no limit).</item>
        /// <item>The training-set token error count fails to improve in the iteration right
        /// after the learning rate was halved.</item>
        /// </list>
        /// After each iteration the model is saved to <c>modelFilePath</c> when it improves:
        /// on the validation set if one exists, otherwise on the training set.
        /// When the training-set error does not improve, the learning rate is halved.
        /// </summary>
        public void Train()
        {
            //Create neural net work
            Logger.WriteLine("Create a new network according settings in configuration file.");
            Logger.WriteLine($"Processor Count = {Environment.ProcessorCount}");

            RNN<T> rnn = RNN<T>.CreateRNN(networkType);

            if (ModelSettings.IncrementalTrain)
            {
                Logger.WriteLine($"Loading previous trained model from {modelFilePath}.");
                rnn.LoadModel(modelFilePath, true);
            }
            else
            {
                Logger.WriteLine("Create a new network.");
                rnn.CreateNetwork(hiddenLayersConfig, outputLayerConfig, TrainingSet, featurizer);
                //Create tag-bigram transition probability matrix only for sequence RNN mode
                if (IsCRFTraining)
                {
                    Logger.WriteLine("Initialize bigram transition for CRF output layer.");
                    rnn.InitializeCRFWeights(TrainingSet.CRFLabelBigramTransition);
                }
            }

            rnn.MaxSeqLength  = maxSequenceLength;
            rnn.bVQ           = ModelSettings.VQ != 0;
            rnn.IsCRFTraining = IsCRFTraining;

            // The primary network plus (N - 1) clones, N = 2 x processor count.
            // NOTE(review): presumably one instance per worker in the Process(rnns, ...) overload.
            int N = Environment.ProcessorCount * 2;
            List<RNN<T>> rnns = new List<RNN<T>>(N) { rnn };
            for (int i = 1; i < N; i++)
            {
                rnns.Add(rnn.Clone());
            }

            //Initialize RNNHelper
            RNNHelper.LearningRate          = ModelSettings.LearningRate;
            RNNHelper.vecNormalLearningRate = new Vector<float>(RNNHelper.LearningRate);

            RNNHelper.GradientCutoff = ModelSettings.GradientCutoff;
            RNNHelper.vecMaxGrad     = new Vector<float>(RNNHelper.GradientCutoff);
            RNNHelper.vecMinGrad     = new Vector<float>(-RNNHelper.GradientCutoff);
            RNNHelper.IsConstAlpha   = ModelSettings.IsConstAlpha;
            RNNHelper.MiniBatchSize  = ModelSettings.MiniBatchSize;

            Logger.WriteLine("");

            Logger.WriteLine("Iterative training begins ...");
            var bestTrainTknErrCnt = long.MaxValue;
            var bestValidTknErrCnt = long.MaxValue;
            var lastAlpha          = RNNHelper.LearningRate;
            var iter               = 0;

            while (true)
            {
                // iter counts completed iterations (it starts at 0), so stop once exactly
                // MaxIteration iterations have run. MaxIteration <= 0 disables this limit.
                if (ModelSettings.MaxIteration > 0 && iter >= ModelSettings.MaxIteration)
                {
                    Logger.WriteLine($"We have trained this model {iter} iteration, exit.");
                    break;
                }

                // Stopwatch is monotonic, unlike DateTime.Now, which can jump with clock changes.
                var stopwatch = System.Diagnostics.Stopwatch.StartNew();
                Logger.WriteLine($"Start to training {iter} iteration. learning rate = {RNNHelper.LearningRate}");

                //Clean all RNN instances' status for training
                foreach (var r in rnns)
                {
                    r.CleanStatusForTraining();
                }
                Process(rnns, TrainingSet, RunningMode.Training);

                var duration = stopwatch.Elapsed;

                Logger.WriteLine($"End {iter} iteration. Time duration = {duration}");
                Logger.WriteLine("");

                // lastAlpha differs from the current rate only if the rate was halved at the end
                // of the previous iteration; if that still did not help, stop early.
                if (tknErrCnt >= bestTrainTknErrCnt && lastAlpha != RNNHelper.LearningRate)
                {
                    //Although we reduce alpha value, we still cannot get better result.
                    Logger.WriteLine(
                        $"Current token error count({(double)tknErrCnt / (double)processedWordCnt * 100.0}%) is larger than the previous one({(double)bestTrainTknErrCnt / (double)processedWordCnt * 100.0}%) on training set. End training early.");
                    Logger.WriteLine("Current alpha: {0}, the previous alpha: {1}", RNNHelper.LearningRate, lastAlpha);
                    break;
                }
                lastAlpha = RNNHelper.LearningRate;

                // Snapshot the training error now; the validation pass below updates tknErrCnt.
                int trainTknErrCnt = tknErrCnt;
                //Validate the model by validated corpus
                if (ValidationSet != null)
                {
                    Logger.WriteLine("Verify model on validated corpus.");
                    Process(rnns, ValidationSet, RunningMode.Validate);
                    Logger.WriteLine("End model verification.");
                    Logger.WriteLine("");

                    if (tknErrCnt < bestValidTknErrCnt)
                    {
                        //We got better result on validated corpus, save this model
                        Logger.WriteLine($"Saving better model into file {modelFilePath}, since we got a better result on validation set.");
                        Logger.WriteLine($"Error token percent: {(double)tknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%");

                        rnn.SaveModel(modelFilePath);
                        bestValidTknErrCnt = tknErrCnt;
                    }
                }
                else if (trainTknErrCnt < bestTrainTknErrCnt)
                {
                    //We don't have validate corpus, but we get a better result on training corpus, save this model
                    Logger.WriteLine($"Saving better model into file {modelFilePath}, although validation set doesn't exist, we have better result on training set.");
                    Logger.WriteLine($"Error token percent: {(double)trainTknErrCnt / (double)processedWordCnt * 100.0}%, Error sequence percent: {(double)sentErrCnt / (double)processedSequence * 100.0}%");

                    rnn.SaveModel(modelFilePath);
                }

                Logger.WriteLine("");

                if (trainTknErrCnt >= bestTrainTknErrCnt)
                {
                    //We don't have better result on training set, so reduce learning rate
                    RNNHelper.LearningRate          = RNNHelper.LearningRate / 2.0f;
                    RNNHelper.vecNormalLearningRate = new Vector<float>(RNNHelper.LearningRate);
                }
                else
                {
                    bestTrainTknErrCnt = trainTknErrCnt;
                }

                iter++;
            }
        }
Exemplo n.º 3
0
        /// <summary>
        /// Runs <paramref name="rnn"/> over every sequence of <paramref name="trainingSet"/> in
        /// the given <paramref name="runningMode"/>. Word, token-error, sentence-error and
        /// sequence counts are accumulated into the shared counters with <see cref="Interlocked"/>.
        /// Sequences longer than <c>rnn.MaxSeqLength</c> are skipped; for a SequencePair, both the
        /// source token count and the target length must fit.
        /// Progress is logged every 1000 sequences. When <c>ModelSettings.SaveStep</c> &gt; 0, the
        /// model is also saved to "model.tmp" every SaveStep sequences.
        /// </summary>
        /// <param name="rnn">Network instance used for this pass.</param>
        /// <param name="trainingSet">Data set to process; it is shuffled in place first.</param>
        /// <param name="runningMode">Forwarded to the network's sequence-processing call.</param>
        /// <param name="totalSequenceNum">Total sequence count, used only in progress messages.</param>
        public void Process(RNN<T> rnn, DataSet<T> trainingSet, RunningMode runningMode, int totalSequenceNum)
        {
            //Shuffle training corpus
            trainingSet.Shuffle();

            for (var i = 0; i < trainingSet.SequenceList.Count; i++)
            {
                var pSequence = trainingSet.SequenceList[i];

                // Cast once. Anything that is not a Sequence is treated as a SequencePair
                // (as before, any other type ends in a NullReferenceException below).
                var sequence = pSequence as Sequence;
                SequencePair sp = null;

                int wordCnt;
                if (sequence != null)
                {
                    wordCnt = sequence.States.Length;
                }
                else
                {
                    sp = pSequence as SequencePair;
                    if (sp.srcSentence.TokensList.Count > rnn.MaxSeqLength)
                    {
                        continue;
                    }

                    wordCnt = sp.tgtSequence.States.Length;
                }

                if (wordCnt > rnn.MaxSeqLength)
                {
                    continue;
                }

                Interlocked.Add(ref processedWordCnt, wordCnt);

                int[] predicted;
                if (IsCRFTraining)
                {
                    predicted = rnn.ProcessSequenceCRF(sequence, runningMode);
                }
                else
                {
                    Matrix<float> m;
                    predicted = rnn.ProcessSequence(pSequence, runningMode, false, out m);
                }

                int newTknErrCnt;
                if (sequence != null)
                {
                    newTknErrCnt = GetErrorTokenNum(sequence, predicted);
                }
                else
                {
                    newTknErrCnt = GetErrorTokenNum(sp.tgtSequence, predicted);
                }

                Interlocked.Add(ref tknErrCnt, newTknErrCnt);
                if (newTknErrCnt > 0)
                {
                    Interlocked.Increment(ref sentErrCnt);
                }

                // Use the value returned by Interlocked.Increment instead of re-reading the field.
                // If other threads bump the counter concurrently (presumably why Interlocked is
                // used), a re-read can skip a multiple or see the same one twice. That would drop
                // or duplicate progress logs and temporary-model saves.
                var seqNum = Interlocked.Increment(ref processedSequence);

                if (seqNum % 1000 == 0)
                {
                    Logger.WriteLine("Progress = {0} ", seqNum / 1000 + "K/" + totalSequenceNum / 1000.0 + "K");
                    Logger.WriteLine(" Error token ratio = {0}%", (double)tknErrCnt / (double)processedWordCnt * 100.0);
                    Logger.WriteLine(" Error sentence ratio = {0}%", (double)sentErrCnt / (double)seqNum * 100.0);
                }

                if (ModelSettings.SaveStep > 0 && seqNum % ModelSettings.SaveStep == 0)
                {
                    //After processed every m_SaveStep sentences, save current model into a temporary file
                    Logger.WriteLine("Saving temporary model into file...");
                    rnn.SaveModel("model.tmp");
                }
            }
        }