Example No. 1
 public void SetTrainingSet(DataSet train)
 {
     m_TrainingSet = train;
 }
Example No. 2
 public void SetValidationSet(DataSet validation)
 {
     m_ValidationSet = validation;
 }
Example No. 3
        public override void SetValidationSet(DataSet validation)
        {
            m_ValidationSet = validation;

            forwardRNN.SetValidationSet(validation);
            backwardRNN.SetValidationSet(validation);
        }
Example No. 4
        static void LoadDataset(string strFileName, Featurizer featurizer, DataSet dataSet)
        {
            CheckCorpus(strFileName);

            StreamReader sr = new StreamReader(strFileName);
            int RecordCount = 0;

            while (true)
            {
                //Read the next record and wrap it as a sentence
                Sentence sent = new Sentence(ReadRecord(sr));
                if (sent.TokensList.Count <= 2)
                {
                    //No more records: the sentence only contains <s> and </s>
                    break;
                }

                //Extract features from the sentence and convert it into a sequence
                Sequence seq = featurizer.ExtractFeatures(sent);

                //Set label for the sequence
                seq.SetLabel(sent, featurizer.TagSet);

                //Add the sequence to the data set
                dataSet.SequenceList.Add(seq);

                //Show progress for every 10,000 records
                RecordCount++;
                if (RecordCount % 10000 == 0)
                {
                    Logger.WriteLine(Logger.Level.info, "{0}...", RecordCount);
                }
            }

            sr.Close();

        }
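The ReadRecord helper used above is not shown in these examples. A minimal sketch, assuming a CoNLL-style corpus with one token line per row and a blank line between records (the format and implementation are assumptions, not the library's confirmed code):

        //Hypothetical sketch of ReadRecord (assumed corpus format).
        //Reads token lines until a blank line or end of file; a blank line separates records.
        static List<string> ReadRecord(StreamReader sr)
        {
            List<string> tokenList = new List<string>();
            string line;
            while ((line = sr.ReadLine()) != null)
            {
                if (line.Trim().Length == 0)
                {
                    //Blank line marks the end of the current record
                    break;
                }
                tokenList.Add(line);
            }
            return tokenList;
        }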
Example No. 5
        public override void SetTrainingSet(DataSet train)
        {
            m_TrainingSet = train;
            fea_size = m_TrainingSet.GetDenseDimension();
            L0 = m_TrainingSet.GetSparseDimension() + L1;
            L2 = m_TrainingSet.GetTagSize();

            forwardRNN.SetTrainingSet(train);
            backwardRNN.SetTrainingSet(train);
        }
Example No. 6
        private static void Train()
        {
            if (!File.Exists(strTagFile))
            {
                Console.WriteLine("FAILED: The tag mapping file {0} doesn't exist.", strTagFile);
                UsageTrain();
                return;
            }

            //Load tag ids and their names from file
            TagSet tagSet = new TagSet(strTagFile);

            //Create configuration instance and set parameters
            ModelSetting RNNConfig = new ModelSetting();
            RNNConfig.SetModelFile(strModelFile);
            RNNConfig.SetNumHidden(layersize);
            RNNConfig.SetCRFTraining(iCRF == 1);
            RNNConfig.SetDir(iDir);
            RNNConfig.SetModelType(modelType);
            RNNConfig.SetMaxIteration(maxIter);
            RNNConfig.SetSaveStep(savestep);
            RNNConfig.SetLearningRate(alpha);
            RNNConfig.SetRegularization(beta);
            RNNConfig.SetBptt(bptt);

            //Dump the RNN settings to the console
            RNNConfig.DumpSetting();

            if (!File.Exists(strFeatureConfigFile))
            {
                Console.WriteLine("FAILED: The feature configuration file {0} doesn't exist.", strFeatureConfigFile);
                UsageTrain();
                return;
            }
            //Create feature extractors and load word embedding data from file
            Featurizer featurizer = new Featurizer(strFeatureConfigFile, tagSet);
            featurizer.ShowFeatureSize();

            if (!File.Exists(strTrainFile))
            {
                Console.WriteLine("FAILED: The training corpus {0} doesn't exist.", strTrainFile);
                UsageTrain();
                return;
            }

            //Load the training corpus and extract its feature set
            DataSet dataSetTrain = new DataSet(tagSet.GetSize());
            LoadDataset(strTrainFile, featurizer, dataSetTrain);

            if (!File.Exists(strValidFile))
            {
                Console.WriteLine("FAILED: The validation corpus {0} doesn't exist.", strValidFile);
                UsageTrain();
                return;
            }

            //Load the validation corpus and extract its feature set
            DataSet dataSetValidation = new DataSet(tagSet.GetSize());
            LoadDataset(strValidFile, featurizer, dataSetValidation);

            //Create the RNN encoder and assign its training and validation sets
            RNNEncoder encoder = new RNNEncoder(RNNConfig);
            encoder.SetTrainingSet(dataSetTrain);
            encoder.SetValidationSet(dataSetValidation);

            if (iCRF == 1)
            {
                Console.WriteLine("Initialize output tag bigram transition probability...");
                //Build tag bigram transition matrix
                dataSetTrain.BuildLabelBigramTransition();
                encoder.SetLabelBigramTransition(dataSetTrain.GetLabelBigramTransition());
            }

            //Start training the model
            encoder.Train();
        }
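Train() reads its inputs from static fields (corpus paths, model file, layer size, learning rate, and so on) rather than from parameters. A minimal hypothetical driver, assuming those fields are assignable from an entry point; the field names come from the example above, while the Main signature and the concrete values are illustrative assumptions only:

        //Hypothetical driver: assign the static fields that Train() reads, then run training.
        //The argument layout and values below are assumptions for illustration.
        static void Main(string[] args)
        {
            strTagFile = args[0];            //tag mapping file
            strFeatureConfigFile = args[1];  //feature configuration file
            strTrainFile = args[2];          //training corpus
            strValidFile = args[3];          //validation corpus
            strModelFile = args[4];          //output model file

            layersize = 200;                 //hidden layer size
            iCRF = 1;                        //1 = train with a CRF output layer
            maxIter = 20;                    //maximum number of iterations
            savestep = 0;                    //0 = no intermediate model saving
            alpha = 0.1;                     //learning rate
            beta = 1e-7;                     //regularization
            bptt = 4;                        //BPTT steps

            Train();
        }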
Example No. 7
        static void LoadDataset(string strFileName, Featurizer featurizer, DataSet dataSet)
        {
            CheckCorpus(strFileName);

            StreamReader sr = new StreamReader(strFileName);
            int RecordCount = 0;

            while (true)
            {
                List<string> tokenList = ReadRecord(sr);
                if (tokenList.Count == 0)
                {
                    //No more records
                    break;
                }

                //Extract features from the tokens and convert them into a sequence
                Sentence sent = new Sentence();
                sent.SetFeatures(tokenList);
                Sequence seq = featurizer.ExtractFeatures(sent);

                //Set label for the sequence
                if (!seq.SetLabel(sent, featurizer.GetTagSet()))
                {
                    Console.WriteLine("Error: Invalid record.");
                    sent.DumpFeatures();
                    continue;
                }

                //Add the sequence to the data set
                dataSet.Add(seq);

                //Show progress for every 10,000 records
                RecordCount++;
                if (RecordCount % 10000 == 0)
                {
                    Console.Write("{0}...", RecordCount);
                }
            }

            Console.WriteLine();

            sr.Close();
        }
Example No. 8
        public bool ValidateNet(DataSet validationSet, int iter)
        {
            Logger.WriteLine("Start validation ...");
            int wordcn = 0;
            int tknErrCnt = 0;
            int sentErrCnt = 0;

            //Initialize variables
            logp = 0;
            int numSequence = validationSet.SequenceList.Count;
            for (int curSequence = 0; curSequence < numSequence; curSequence++)
            {
                Sequence pSequence = validationSet.SequenceList[curSequence];
                wordcn += pSequence.States.Length;

                int[] predicted;
                if (IsCRFTraining)
                {
                    predicted = ProcessSequenceCRF(pSequence, RunningMode.Validate);
                }
                else
                {
                    Matrix<double> m = ProcessSequence(pSequence, RunningMode.Validate);
                    predicted = GetBestResult(m);
                }

                int newTknErrCnt = GetErrorTokenNum(pSequence, predicted);
                tknErrCnt += newTknErrCnt;
                if (newTknErrCnt > 0)
                {
                    sentErrCnt++;
                }
            }

            double entropy = -logp / Math.Log10(2.0) / wordcn;
            double ppl = exp_10(-logp / wordcn);
            double tknErrRatio = (double)tknErrCnt / (double)wordcn * 100.0;
            double sentErrRatio = (double)sentErrCnt / (double)numSequence * 100.0;

            Logger.WriteLine("In validation: error token ratio = {0}% error sentence ratio = {1}%", tknErrRatio, sentErrRatio);
            Logger.WriteLine("In validation: log probability = " + logp + ", cross-entropy = " + entropy + ", perplexity = " + ppl);
            Logger.WriteLine("");

            bool bUpdate = false;
            if (tknErrRatio < minTknErrRatio)
            {
                //We got a better result on the validation set, so save this model
                bUpdate = true;
                minTknErrRatio = tknErrRatio;
            }

            return bUpdate;
        }
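The metrics above follow from logp accumulating base-10 log probabilities: cross-entropy is reported in bits per token, and perplexity is 10 raised to the negative average log probability (what exp_10 computes). A small sketch of the same formulas as standalone helpers, under that assumption:

        //Sketch of the validation metrics, assuming logp is a sum of base-10 log probabilities.
        static double CrossEntropy(double logp, int wordCount)
        {
            //Bits per token
            return -logp / Math.Log10(2.0) / wordCount;
        }

        static double Perplexity(double logp, int wordCount)
        {
            //Equivalent to exp_10(-logp / wordCount)
            return Math.Pow(10.0, -logp / wordCount);
        }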
Example No. 9
        public double TrainNet(DataSet trainingSet, int iter)
        {
            DateTime start = DateTime.Now;
            Logger.WriteLine("Iter " + iter + " begins with learning rate alpha = " + RNNHelper.LearningRate + " ...");

            //Initialize variables
            logp = 0;

            //Shuffle the training corpus
            trainingSet.Shuffle();

            int numSequence = trainingSet.SequenceList.Count;
            int wordCnt = 0;
            int tknErrCnt = 0;
            int sentErrCnt = 0;
            Logger.WriteLine("Progress = 0/" + numSequence / 1000.0 + "K\r");
            for (int curSequence = 0; curSequence < numSequence; curSequence++)
            {
                Sequence pSequence = trainingSet.SequenceList[curSequence];
                wordCnt += pSequence.States.Length;

                int[] predicted;
                if (IsCRFTraining)
                {
                    predicted = ProcessSequenceCRF(pSequence, RunningMode.Train);
                }
                else
                {
                    Matrix<double> m = ProcessSequence(pSequence, RunningMode.Train);
                    predicted = GetBestResult(m);
                }

                int newTknErrCnt = GetErrorTokenNum(pSequence, predicted);
                tknErrCnt += newTknErrCnt;
                if (newTknErrCnt > 0)
                {
                    sentErrCnt++;
                }

                if ((curSequence + 1) % 1000 == 0)
                {
                    Logger.WriteLine("Progress = {0} ", (curSequence + 1) / 1000 + "K/" + numSequence / 1000.0 + "K");
                    Logger.WriteLine(" Train cross-entropy = {0} ", -logp / Math.Log10(2.0) / wordCnt);
                    Logger.WriteLine(" Error token ratio = {0}%", (double)tknErrCnt / (double)wordCnt * 100.0);
                    Logger.WriteLine(" Error sentence ratio = {0}%", (double)sentErrCnt / (double)(curSequence + 1) * 100.0);
                }

                if (SaveStep > 0 && (curSequence + 1) % SaveStep == 0)
                {
                    //After processing every SaveStep sentences, save the current model to a temporary file
                    Logger.WriteLine("Saving temporary model into file...");
                    SaveModel(ModelTempFile);
                }
            }

            DateTime now = DateTime.Now;
            TimeSpan duration = now.Subtract(start);

            double entropy = -logp / Math.Log10(2.0) / wordCnt;
            double ppl = exp_10(-logp / wordCnt);
            Logger.WriteLine("Iter " + iter + " completed");
            Logger.WriteLine("Sentences = " + numSequence + ", time elapsed = " + duration + ", speed = " + numSequence / duration.TotalSeconds);
            Logger.WriteLine("In training: log probability = " + logp + ", cross-entropy = " + entropy + ", perplexity = " + ppl);

            return ppl;
        }
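TrainNet returns the training perplexity and ValidateNet reports whether the token error ratio improved, so a typical caller alternates the two once per iteration and keeps the model when validation improves. A hypothetical sketch; the TrainLoop name and the ModelFile field are assumptions for illustration, only TrainNet, ValidateNet, and SaveModel come from the examples above:

        //Hypothetical outer training loop: one TrainNet/ValidateNet pass per iteration,
        //persisting the model whenever the validation error ratio improves.
        public void TrainLoop(DataSet trainingSet, DataSet validationSet, int maxIter)
        {
            for (int iter = 0; iter < maxIter; iter++)
            {
                double ppl = TrainNet(trainingSet, iter);
                Logger.WriteLine("Iteration {0}: training perplexity = {1}", iter, ppl);

                if (ValidateNet(validationSet, iter))
                {
                    //Token error ratio improved on the validation set
                    SaveModel(ModelFile);
                }
            }
        }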