示例#1
0
        /// <summary>
        /// Sets up the auto encoder: builds a featurizer from the auto encoder feature
        /// configuration file and uses it to create the decoder stored in <c>autoEncoder</c>.
        /// </summary>
        public void InitializeAutoEncoder()
        {
            Logger.WriteLine("Initialize auto encoder...");

            // The featurizer is built from the configuration file only; null is passed as its second argument.
            Featurizer autoEncoderFeaturizer = new Featurizer(autoEncoderFeatureConfigFile, null);

            // Wrap the auto encoder model file and the featurizer into a decoder instance.
            autoEncoder = new RNNSharp.RNNDecoder(autoEncoderModelFile, autoEncoderFeaturizer);
        }
示例#2
0
        /// <summary>
        /// Creates a decoder for the model stored in <paramref name="strModelFileName"/>.
        /// </summary>
        /// <param name="strModelFileName">Path of the model file that is inspected and then loaded.</param>
        /// <param name="featurizer">Featurizer kept by this decoder.</param>
        public RNNDecoder(string strModelFileName, Featurizer featurizer)
        {
            // Ask the helper which direction the stored model uses.
            MODELDIRECTION direction;
            RNNHelper.CheckModelFileType(strModelFileName, out direction);

            bool isBiDirectional = direction == MODELDIRECTION.BI_DIRECTIONAL;
            if (!isBiDirectional)
            {
                Logger.WriteLine("Model Structure: Simple RNN");
                m_Rnn = new ForwardRNN();
            }
            else
            {
                Logger.WriteLine("Model Structure: Bi-directional RNN");
                m_Rnn = new BiRNN();
            }

            // Load the network from the file, then report whether it is a CRF model.
            m_Rnn.LoadModel(strModelFileName);
            Logger.WriteLine("CRF Model: {0}", m_Rnn.IsCRFTraining);

            m_Featurizer = featurizer;
        }
示例#3
0
        /// <summary>
        /// Builds a decoder around the network found in <paramref name="strModelFileName"/>.
        /// </summary>
        /// <param name="strModelFileName">Path of the model file to inspect and load.</param>
        /// <param name="featurizer">Featurizer stored on the decoder.</param>
        public RNNDecoder(string strModelFileName, Featurizer featurizer)
        {
            MODELDIRECTION detectedDirection;
            RNNHelper.CheckModelFileType(strModelFileName, out detectedDirection);

            // Pick the network implementation matching the direction recorded in the model file.
            switch (detectedDirection)
            {
                case MODELDIRECTION.BI_DIRECTIONAL:
                    Logger.WriteLine("Model Structure: Bi-directional RNN");
                    m_Rnn = new BiRNN();
                    break;

                default:
                    Logger.WriteLine("Model Structure: Simple RNN");
                    m_Rnn = new ForwardRNN();
                    break;
            }

            m_Rnn.LoadModel(strModelFileName);
            Logger.WriteLine("CRF Model: {0}", m_Rnn.IsCRFTraining);

            m_Featurizer = featurizer;
        }
示例#4
0
        /// <summary>
        /// Creates a decoder for the model stored in <paramref name="strModelFileName"/>.
        /// </summary>
        /// <param name="strModelFileName">Path of the binary model file to inspect and load.</param>
        /// <param name="featurizer">Featurizer kept by this decoder.</param>
        public RNNDecoder(string strModelFileName, Featurizer featurizer)
        {
            bool isSimpleModel = RNN.CheckModelFileType(strModelFileName) == MODELTYPE.SIMPLE;

            // Any model type other than SIMPLE is handled as an LSTM network.
            if (!isSimpleModel)
            {
                Console.WriteLine("Model Structure: LSTM-RNN");
                m_Rnn = new LSTMRNN();
            }
            else
            {
                Console.WriteLine("Model Structure: Simple RNN");
                m_Rnn = new SimpleRNN();
            }

            m_Rnn.loadNetBin(strModelFileName);
            Console.WriteLine("CRF Model: {0}", m_Rnn.IsCRFModel());

            m_Featurizer = featurizer;
        }
示例#5
0
        /// <summary>
        /// Creates a decoder for the model stored in <paramref name="strModelFileName"/>,
        /// choosing the network class from the model type and direction reported for that file.
        /// </summary>
        /// <param name="strModelFileName">Path of the binary model file to inspect and load.</param>
        /// <param name="featurizer">Featurizer kept by this decoder.</param>
        public RNNDecoder(string strModelFileName, Featurizer featurizer)
        {
            MODELTYPE detectedType;
            MODELDIRECTION detectedDirection;
            RNN.CheckModelFileType(strModelFileName, out detectedType, out detectedDirection);

            bool isBiDirectional = detectedDirection == MODELDIRECTION.BI_DIRECTIONAL;
            bool isSimple = detectedType == MODELTYPE.SIMPLE;

            if (isBiDirectional && isSimple)
            {
                Logger.WriteLine(Logger.Level.info, "Model Structure: Bi-directional RNN");
                m_Rnn = new BiRNN(new SimpleRNN(), new SimpleRNN());
            }
            else if (isBiDirectional)
            {
                Logger.WriteLine(Logger.Level.info, "Model Structure: Bi-directional RNN");
                m_Rnn = new BiRNN(new LSTMRNN(), new LSTMRNN());
            }
            else if (isSimple)
            {
                Logger.WriteLine(Logger.Level.info, "Model Structure: Simple RNN");
                m_Rnn = new SimpleRNN();
            }
            else
            {
                Logger.WriteLine(Logger.Level.info, "Model Structure: LSTM-RNN");
                m_Rnn = new LSTMRNN();
            }

            m_Rnn.loadNetBin(strModelFileName);
            Logger.WriteLine(Logger.Level.info, "CRF Model: {0}", m_Rnn.IsCRFTraining);

            m_Featurizer = featurizer;
        }
示例#6
0
        /// <summary>
        /// Builds a decoder around the network found in <paramref name="strModelFileName"/>.
        /// The network class is selected from the model type and direction reported for the file.
        /// </summary>
        /// <param name="strModelFileName">Path of the model file to inspect and load.</param>
        /// <param name="featurizer">Featurizer stored on the decoder.</param>
        public RNNDecoder(string strModelFileName, Featurizer featurizer)
        {
            MODELTYPE kind;
            MODELDIRECTION direction;
            RNN.CheckModelFileType(strModelFileName, out kind, out direction);

            bool simple = kind == MODELTYPE.SIMPLE;

            if (direction != MODELDIRECTION.BI_DIRECTIONAL)
            {
                // Single-direction network: the log line names the concrete cell type.
                if (simple)
                {
                    Logger.WriteLine("Model Structure: Simple RNN");
                    m_Rnn = new SimpleRNN();
                }
                else
                {
                    Logger.WriteLine("Model Structure: LSTM-RNN");
                    m_Rnn = new LSTMRNN();
                }
            }
            else
            {
                // Bi-directional network: two networks of the same cell type are handed to BiRNN.
                Logger.WriteLine("Model Structure: Bi-directional RNN");
                if (simple)
                {
                    m_Rnn = new BiRNN(new SimpleRNN(), new SimpleRNN());
                }
                else
                {
                    m_Rnn = new BiRNN(new LSTMRNN(), new LSTMRNN());
                }
            }

            m_Rnn.LoadModel(strModelFileName);
            Logger.WriteLine("CRF Model: {0}", m_Rnn.IsCRFTraining);

            m_Featurizer = featurizer;
        }
示例#7
0
        /// <summary>
        /// Decodes every record of the test corpus with the configured model and writes the
        /// tokens, followed by the predicted tag name, to the output file.
        /// </summary>
        /// <remarks>
        /// Validates the tag file, model file, feature configuration file and output file
        /// settings first; on any failure an error is logged, the usage text is shown and the
        /// method returns. When <c>nBest</c> is 1 a single tag sequence is written per sentence,
        /// otherwise <c>nBest</c> tag sequences are written, separated by blank lines.
        /// </remarks>
        private static void Test()
        {
            if (String.IsNullOrEmpty(strTagFile))
            {
                Logger.WriteLine(Logger.Level.err, "FAILED: The tag mapping file {0} isn't specified.", strTagFile);
                UsageTest();
                return;
            }

            //Load tag name
            TagSet tagSet = new TagSet(strTagFile);

            if (String.IsNullOrEmpty(strModelFile))
            {
                Logger.WriteLine(Logger.Level.err, "FAILED: The model file {0} isn't specified.", strModelFile);
                UsageTest();
                return;
            }

            if (String.IsNullOrEmpty(strFeatureConfigFile))
            {
                Logger.WriteLine(Logger.Level.err, "FAILED: The feature configuration file {0} isn't specified.", strFeatureConfigFile);
                UsageTest();
                return;
            }

            //A null output file name is rejected the same way as an empty one instead of throwing
            if (String.IsNullOrEmpty(strOutputFile))
            {
                Logger.WriteLine(Logger.Level.err, "FAILED: The output file name should not be empty.");
                UsageTest();
                return;
            }

            //Create feature extractors from the feature configuration file
            Featurizer featurizer = new Featurizer(strFeatureConfigFile, tagSet);
            featurizer.ShowFeatureSize();

            //Create instance for decoder
            RNNSharp.RNNDecoder decoder = new RNNSharp.RNNDecoder(strModelFile, featurizer);

            if (File.Exists(strTestFile) == false)
            {
                Logger.WriteLine(Logger.Level.err, "FAILED: The test corpus {0} isn't existed.", strTestFile);
                UsageTest();
                return;
            }

            //The using blocks guarantee both files are flushed and closed even if decoding throws
            using (StreamReader sr = new StreamReader(strTestFile))
            using (StreamWriter sw = new StreamWriter(strOutputFile))
            {
                while (true)
                {
                    Sentence sent = new Sentence(ReadRecord(sr));
                    if (sent.TokensList.Count <= 2)
                    {
                        //No more record, it only contains <s> and </s>
                        break;
                    }

                    StringBuilder sb = new StringBuilder();
                    if (nBest == 1)
                    {
                        int[] output = decoder.Process(sent);

                        //Append the decoded tag to the end of the feature set of each token
                        for (int i = 0; i < sent.TokensList.Count; i++)
                        {
                            sb.Append(String.Join("\t", sent.TokensList[i]));
                            sb.Append("\t");
                            sb.Append(tagSet.GetTagName(output[i]));
                            sb.AppendLine();
                        }
                    }
                    else
                    {
                        int[][] output = decoder.ProcessNBest(sent, nBest);
                        for (int i = 0; i < nBest; i++)
                        {
                            for (int j = 0; j < sent.TokensList.Count; j++)
                            {
                                //Index tokens by j (token position); i is the hypothesis index
                                sb.Append(String.Join("\t", sent.TokensList[j]));
                                sb.Append("\t");
                                sb.Append(tagSet.GetTagName(output[i][j]));
                                sb.AppendLine();
                            }
                            sb.AppendLine();
                        }
                    }

                    sw.WriteLine(sb.ToString());
                }
            }
        }
示例#8
0
        /// <summary>
        /// Trains a model from the training and validation corpora using the settings held in
        /// the static option fields.
        /// </summary>
        /// <remarks>
        /// Each required input is checked in turn; when a check fails an error is logged,
        /// the training usage text is shown and the method returns without training.
        /// </remarks>
        private static void Train()
        {
            Logger.LogFile = "RNNSharpConsole.log";

            if (!File.Exists(strTagFile))
            {
                Logger.WriteLine(Logger.Level.err, "FAILED: The tag mapping file {0} isn't existed.", strTagFile);
                UsageTrain();
                return;
            }

            // Tag set loaded from the tag mapping file
            TagSet tags = new TagSet(strTagFile);

            // Collect all training options into one settings object
            ModelSetting settings = new ModelSetting
            {
                ModelFile = strModelFile,
                NumHidden = layersize,
                IsCRFTraining = iCRF == 1,
                ModelDirection = iDir,
                ModelType = modelType,
                MaxIteration = maxIter,
                SaveStep = savestep,
                LearningRate = alpha,
                Dropout = dropout,
                Bptt = bptt
            };

            // Dump the settings before any further validation happens
            settings.DumpSetting();

            if (!File.Exists(strFeatureConfigFile))
            {
                Logger.WriteLine(Logger.Level.err, "FAILED: The feature configuration file {0} doesn't exist.", strFeatureConfigFile);
                UsageTrain();
                return;
            }

            // Feature extractor built from the feature configuration file and the tag set
            Featurizer extractor = new Featurizer(strFeatureConfigFile, tags);
            extractor.ShowFeatureSize();

            bool runTimeFeatureWithBiDirection = extractor.IsRunTimeFeatureUsed() && iDir == 1;
            if (runTimeFeatureWithBiDirection)
            {
                Logger.WriteLine(Logger.Level.err, "FAILED: Run time feature is not available for bi-directional RNN model.");
                UsageTrain();
                return;
            }

            if (!File.Exists(strTrainFile))
            {
                Logger.WriteLine(Logger.Level.err, "FAILED: The training corpus doesn't exist.");
                UsageTrain();
                return;
            }

            if (!File.Exists(strValidFile))
            {
                Logger.WriteLine(Logger.Level.err, "FAILED: The validation corpus doesn't exist.");
                UsageTrain();
                return;
            }

            // Encoder that receives the settings and both data sets
            RNNEncoder trainer = new RNNEncoder(settings);

            // Training corpus
            trainer.TrainingSet = new DataSet(tags.GetSize());
            LoadDataset(strTrainFile, extractor, trainer.TrainingSet);

            // Validation corpus
            trainer.ValidationSet = new DataSet(tags.GetSize());
            LoadDataset(strValidFile, extractor, trainer.ValidationSet);

            if (iCRF == 1)
            {
                Logger.WriteLine(Logger.Level.info, "Initialize output tag bigram transition probability...");
                // Tag bigram transitions are built from the training set only
                trainer.TrainingSet.BuildLabelBigramTransition();
            }

            // Run the training
            trainer.Train();
        }
示例#9
0
 /// <summary>
 /// Stores the given featurizer in the <c>Featurizer</c> member of this instance.
 /// </summary>
 /// <param name="featurizer">Featurizer to assign.</param>
 public void SetFeaturizer(Featurizer featurizer)
 {
     this.Featurizer = featurizer;
 }
示例#10
0
        /// <summary>
        /// Reads every record of a corpus file, converts it into a labeled sequence and adds
        /// it to <paramref name="dataSet"/>.
        /// </summary>
        /// <param name="strFileName">Path of the corpus file to read.</param>
        /// <param name="featurizer">Featurizer used to extract features and to supply the tag set for labels.</param>
        /// <param name="dataSet">Data set whose sequence list receives the loaded sequences.</param>
        static void LoadDataset(string strFileName, Featurizer featurizer, DataSet dataSet)
        {
            CheckCorpus(strFileName);

            int RecordCount = 0;

            //The using block closes the reader even when feature extraction or labeling throws
            using (StreamReader sr = new StreamReader(strFileName))
            {
                while (true)
                {
                    //Read the next record and wrap it into a sentence
                    Sentence sent = new Sentence(ReadRecord(sr));
                    if (sent.TokensList.Count <= 2)
                    {
                        //No more record, it only contain <s> and </s>
                        break;
                    }

                    //Extract features from it and convert it into sequence
                    Sequence seq = featurizer.ExtractFeatures(sent);

                    //Set label for the sequence
                    seq.SetLabel(sent, featurizer.TagSet);

                    //Add the sequence into data set
                    dataSet.SequenceList.Add(seq);

                    //Show state at every 10000 records
                    RecordCount++;
                    if (RecordCount % 10000 == 0)
                    {
                        Logger.WriteLine(Logger.Level.info, "{0}...", RecordCount);
                    }
                }
            }
        }
示例#11
0
文件: Program.cs 项目: zxz/RNNSharp
        /// <summary>
        /// Trains a model from the training and validation corpora using the settings held in
        /// the static option fields.
        /// </summary>
        /// <remarks>
        /// Each required file is checked right before it is used; when a check fails a
        /// message is printed, the training usage text is shown and the method returns.
        /// </remarks>
        private static void Train()
        {
            if (!File.Exists(strTagFile))
            {
                Console.WriteLine("FAILED: The tag mapping file {0} isn't existed.", strTagFile);
                UsageTrain();
                return;
            }

            // Tag set loaded from the tag mapping file
            TagSet tags = new TagSet(strTagFile);

            // Copy every training option into the settings object
            ModelSetting settings = new ModelSetting();
            settings.SetModelFile(strModelFile);
            settings.SetNumHidden(layersize);
            settings.SetCRFTraining(iCRF == 1);
            settings.SetDir(iDir);
            settings.SetModelType(modelType);
            settings.SetMaxIteration(maxIter);
            settings.SetSaveStep(savestep);
            settings.SetLearningRate(alpha);
            settings.SetRegularization(beta);
            settings.SetBptt(bptt);

            // Print the settings before any further validation happens
            settings.DumpSetting();

            if (!File.Exists(strFeatureConfigFile))
            {
                Console.WriteLine("FAILED: The feature configuration file {0} isn't existed.", strFeatureConfigFile);
                UsageTrain();
                return;
            }

            // Feature extractor built from the feature configuration file and the tag set
            Featurizer extractor = new Featurizer(strFeatureConfigFile, tags);
            extractor.ShowFeatureSize();

            if (!File.Exists(strTrainFile))
            {
                Console.WriteLine("FAILED: The training corpus {0} isn't existed.", strTrainFile);
                UsageTrain();
                return;
            }

            // Training corpus
            DataSet trainingData = new DataSet(tags.GetSize());
            LoadDataset(strTrainFile, extractor, trainingData);

            if (!File.Exists(strValidFile))
            {
                Console.WriteLine("FAILED: The validated corpus {0} isn't existed.", strValidFile);
                UsageTrain();
                return;
            }

            // Validation corpus
            DataSet validationData = new DataSet(tags.GetSize());
            LoadDataset(strValidFile, extractor, validationData);

            // Encoder that receives the settings and both data sets
            RNNEncoder trainer = new RNNEncoder(settings);
            trainer.SetTrainingSet(trainingData);
            trainer.SetValidationSet(validationData);

            if (iCRF == 1)
            {
                Console.WriteLine("Initialize output tag bigram transition probability...");
                // Tag bigram transitions are built from the training data and handed to the encoder
                trainingData.BuildLabelBigramTransition();
                trainer.SetLabelBigramTransition(trainingData.GetLabelBigramTransition());
            }

            // Run the training
            trainer.Train();
        }
示例#12
0
 /// <summary>
 /// Sequence-to-sequence decoding is not implemented by this override.
 /// </summary>
 /// <param name="srcSentence">Unused.</param>
 /// <param name="featurizer">Unused.</param>
 /// <returns>Never returns normally.</returns>
 /// <exception cref="NotImplementedException">Always thrown.</exception>
 public override int[] TestSeq2Seq(Sentence srcSentence, Featurizer featurizer)
 {
     throw new NotImplementedException();
 }
示例#13
0
文件: Program.cs 项目: zxz/RNNSharp
        /// <summary>
        /// Decodes every record of the test corpus with the configured model and writes each
        /// token, followed by the predicted tag name, to the output file.
        /// </summary>
        /// <remarks>
        /// The tag file, model file, feature configuration file, output file name and test
        /// corpus are validated first; on failure a message is printed, the usage text is
        /// shown and the method returns. When <c>nBest</c> is 1 a single tag sequence is
        /// written per record, otherwise <c>nBest</c> sequences separated by blank lines.
        /// If N-best decoding returns null the current sentence is dumped and decoding stops;
        /// output written so far is still flushed to the file.
        /// </remarks>
        private static void Test()
        {
            if (File.Exists(strTagFile) == false)
            {
                Console.WriteLine("FAILED: The tag mapping file {0} isn't existed.", strTagFile);
                UsageTest();
                return;
            }

            //Load tag id and its name from file
            TagSet tagSet = new TagSet(strTagFile);

            if (File.Exists(strModelFile) == false)
            {
                Console.WriteLine("FAILED: The model file {0} isn't existed.", strModelFile);
                UsageTest();
                return;
            }

            if (File.Exists(strFeatureConfigFile) == false)
            {
                Console.WriteLine("FAILED: The feature configuration file {0} isn't existed.", strFeatureConfigFile);
                UsageTest();
                return;
            }

            //A null output file name is rejected the same way as an empty one instead of throwing
            if (String.IsNullOrEmpty(strOutputFile))
            {
                Console.WriteLine("FAILED: The output file name should not be empty.");
                UsageTest();
                return;
            }

            //Create feature extractors from the feature configuration file
            Featurizer featurizer = new Featurizer(strFeatureConfigFile, tagSet);
            featurizer.ShowFeatureSize();

            //Create instance for decoder
            RNNSharp.RNNDecoder decoder = new RNNSharp.RNNDecoder(strModelFile, featurizer);

            if (File.Exists(strTestFile) == false)
            {
                Console.WriteLine("FAILED: The test corpus {0} isn't existed.", strTestFile);
                UsageTest();
                return;
            }

            //The using blocks flush and close both files on every exit path,
            //including the early return below and any exception thrown while decoding
            using (StreamReader sr = new StreamReader(strTestFile))
            using (StreamWriter sw = new StreamWriter(strOutputFile))
            {
                while (true)
                {
                    List<string> tokenList = ReadRecord(sr);
                    if (tokenList.Count == 0)
                    {
                        //No more record
                        break;
                    }

                    Sentence sent = new Sentence();
                    sent.SetFeatures(tokenList);

                    StringBuilder sb = new StringBuilder();
                    if (nBest == 1)
                    {
                        int[] output = decoder.Process(sent);

                        //Append the decoded tag to the end of the feature set of each token
                        for (int i = 0; i < tokenList.Count; i++)
                        {
                            sb.Append(tokenList[i]);
                            sb.Append("\t");
                            sb.Append(tagSet.GetTagName(output[i]));
                            sb.AppendLine();
                        }
                    }
                    else
                    {
                        int[][] output = decoder.ProcessNBest(sent, nBest);
                        if (output == null)
                        {
                            Console.WriteLine("FAILED: decode failed. Dump current sentence...");
                            sent.DumpFeatures();
                            return;
                        }

                        for (int i = 0; i < nBest; i++)
                        {
                            for (int j = 0; j < tokenList.Count; j++)
                            {
                                sb.Append(tokenList[j]);
                                sb.Append("\t");
                                sb.Append(tagSet.GetTagName(output[i][j]));
                                sb.AppendLine();
                            }
                            sb.AppendLine();
                        }
                    }

                    sw.WriteLine(sb.ToString());
                }
            }
        }
示例#14
0
文件: Program.cs 项目: zxz/RNNSharp
        /// <summary>
        /// Reads every record of a corpus file, converts it into a labeled sequence and adds
        /// it to <paramref name="dataSet"/>. Records whose labels cannot be set are reported,
        /// dumped and skipped.
        /// </summary>
        /// <param name="strFileName">Path of the corpus file to read.</param>
        /// <param name="featurizer">Featurizer used to extract features and to supply the tag set for labels.</param>
        /// <param name="dataSet">Data set that receives the loaded sequences.</param>
        static void LoadDataset(string strFileName, Featurizer featurizer, DataSet dataSet)
        {
            CheckCorpus(strFileName);

            int RecordCount = 0;

            //The using block closes the reader even when feature extraction or labeling throws
            using (StreamReader sr = new StreamReader(strFileName))
            {
                while (true)
                {
                    List<string> tokenList = ReadRecord(sr);
                    if (tokenList.Count == 0)
                    {
                        //No more record
                        break;
                    }

                    //Extract features from it and convert it into sequence
                    Sentence sent = new Sentence();
                    sent.SetFeatures(tokenList);
                    Sequence seq = featurizer.ExtractFeatures(sent);

                    //Set label for the sequence; skip the record when labeling fails
                    if (seq.SetLabel(sent, featurizer.GetTagSet()) == false)
                    {
                        Console.WriteLine("Error: Invalidated record.");
                        sent.DumpFeatures();
                        continue;
                    }

                    //Add the sequence into data set
                    dataSet.Add(seq);

                    //Show state at every 10000 records
                    RecordCount++;
                    if (RecordCount % 10000 == 0)
                    {
                        Console.Write("{0}...", RecordCount);
                    }
                }

                //Terminate the progress line
                Console.WriteLine();
            }
        }
示例#15
0
        /// <summary>
        /// Generates a tag id sequence for <paramref name="srcSentence"/>, one id per step,
        /// starting from the "&lt;s&gt;" tag.
        /// </summary>
        /// <param name="srcSentence">Source sentence whose features are extracted through the featurizer's auto encoder.</param>
        /// <param name="featurizer">Featurizer supplying the tag set, the auto encoder and per-word features.</param>
        /// <returns>
        /// The generated tag ids, beginning with the id of "&lt;s&gt;". Generation stops once
        /// "&lt;/s&gt;" is produced or the list holds 100 ids.
        /// </returns>
        public override int[] TestSeq2Seq(Sentence srcSentence, Featurizer featurizer)
        {
            // Initial state is built from the sentence-start token
            State state = featurizer.ExtractFeatures(new string[] { "<s>" });
            state.Label = featurizer.TagSet.GetIndex("<s>");

            // Reset every hidden layer before decoding
            foreach (SimpleLayer hiddenLayer in HiddenLayerList)
            {
                hiddenLayer.netReset(false);
            }

            // Source-side features come from the auto encoder's own featurizer
            Sequence sourceSequence = featurizer.AutoEncoder.Featurizer.ExtractFeatures(srcSentence);

            double[] sourceHiddenAvg;
            Dictionary<int, float> sourceSparse;
            ExtractSourceSentenceFeature(featurizer.AutoEncoder, sourceSequence, state.SparseFeature.Length, out sourceHiddenAvg, out sourceSparse);

            int layerCount = HiddenLayerList.Count;

            List<int> generated = new List<int>();
            generated.Add(state.Label);

            string word;
            do
            {
                // Sparse input: current state's sparse features plus the source sparse features
                SparseVector sparseInput = new SparseVector();
                sparseInput.SetLength(state.SparseFeature.Length + sourceSequence.SparseFeatureSize);
                sparseInput.AddKeyValuePairData(state.SparseFeature);
                sparseInput.AddKeyValuePairData(sourceSparse);

                // First hidden layer takes the state's dense features concatenated with the source vector
                double[] denseInput = RNNHelper.ConcatenateVector(state.DenseFeature, sourceHiddenAvg);
                HiddenLayerList[0].computeLayer(sparseInput, denseInput, false);

                // Every following hidden layer takes the previous layer's output concatenated with the source vector
                for (int layer = 1; layer < layerCount; layer++)
                {
                    denseInput = RNNHelper.ConcatenateVector(HiddenLayerList[layer - 1].cellOutput, sourceHiddenAvg);
                    HiddenLayerList[layer].computeLayer(sparseInput, denseInput, false);
                }

                // Output layer is fed from the last hidden layer
                denseInput = RNNHelper.ConcatenateVector(HiddenLayerList[layerCount - 1].cellOutput, sourceHiddenAvg);
                OutputLayer.computeLayer(sparseInput, denseInput, false);
                OutputLayer.Softmax(false);

                int tagId = OutputLayer.GetBestOutputIndex(false);
                word = featurizer.TagSet.GetTagName(tagId);

                // The chosen word becomes the input state of the next step
                state = featurizer.ExtractFeatures(new string[] { word });
                state.Label = tagId;

                generated.Add(tagId);
            }
            while (word != "</s>" && generated.Count < 100);

            return generated.ToArray();
        }
示例#16
0
 /// <summary>
 /// Decodes <paramref name="srcSent"/> by forwarding to <c>TestSeq2Seq</c>.
 /// </summary>
 /// <param name="srcSent">Source sentence to decode.</param>
 /// <param name="feature">Featurizer passed through to <c>TestSeq2Seq</c>.</param>
 /// <returns>The array returned by <c>TestSeq2Seq</c>.</returns>
 public int[] DecodeSeq2Seq(Sentence srcSent, Featurizer feature)
 {
     int[] decoded = TestSeq2Seq(srcSent, feature);
     return decoded;
 }
示例#17
0
 /// <summary>
 /// Sequence-to-sequence decoding entry point that derived types must implement.
 /// </summary>
 /// <param name="srcSentence">Source sentence to decode.</param>
 /// <param name="featurizer">Featurizer made available to the implementation.</param>
 /// <returns>An integer array; presumably the predicted tag ids (confirm against implementations).</returns>
 public abstract int[] TestSeq2Seq(Sentence srcSentence, Featurizer featurizer);