Esempio n. 1
0
        /// <summary>
        /// Loads an <see cref="LSTMLayer"/> that was previously written to disk with
        /// <see cref="BinaryFormatter"/>.
        /// </summary>
        /// <param name="path">Path of the serialized model file.</param>
        /// <returns>
        /// The deserialized layer, or <c>null</c> if the file's root object is not an
        /// <see cref="LSTMLayer"/> (the result is obtained with an <c>as</c> cast).
        /// </returns>
        /// <remarks>
        /// SECURITY: BinaryFormatter can execute arbitrary code when fed untrusted data and is
        /// removed in .NET 9. Only load model files produced by this program.
        /// </remarks>
        public static LSTMLayer readLSTM(string path)
        {
            // The using block releases the file handle even when Deserialize throws
            // (e.g. on a truncated or incompatible model file). Read-only access is
            // enough here and also works for read-only files.
            using (FileStream fs = new FileStream(path, FileMode.Open, FileAccess.Read))
            {
                BinaryFormatter bf = new BinaryFormatter();
                return bf.Deserialize(fs) as LSTMLayer;
            }
        }
Esempio n. 2
0
        /// <summary>
        /// Performs a full run:
        /// <list type="number">
        /// <item><description>Loads the word resources and sets the input/hidden dimensions to 100.</description></item>
        /// <item><description>Either restores all network layers from "model//deepNetwork" (when <c>Global.isRead == 1</c>) or creates fresh ones.</description></item>
        /// <item><description>Calls <c>Trainer.train</c>.</description></item>
        /// <item><description>Prints the total elapsed time and waits for console input.</description></item>
        /// </list>
        /// </summary>
        public static void Run()
        {
            // Stopwatch is monotonic; subtracting two DateTime.Now readings can be skewed
            // by system clock adjustments during a long training run.
            System.Diagnostics.Stopwatch stopwatch = System.Diagnostics.Stopwatch.StartNew();

            try
            {
                readword();

                Global.inputDim  = 100;
                Global.hiddenDim = 100;

                readfourword();

                Global.randn = new Normal();
                DataSet X = new DataSet();

                if (Global.isRead == 1)
                {
                    // Restore a previously saved network. File names must match the ones
                    // written by the training checkpoint code.
                    string path = "model//deepNetwork";

                    Global.upLSTMLayer      = LSTMLayer.readLSTM(path + "//lstmmodel.txt");
                    Global.upLSTMLayerr     = LSTMLayer.readLSTM(path + "//lstmmodelr.txt");
                    Global.GRNNLayer1       = GRNNLayer.readGRNN(path + "//grnnmodel1.txt");
                    Global.GRNNLayer2       = GRNNLayer.readGRNN(path + "//grnnmodel2.txt");
                    Global.GRNNLayer3       = GRNNLayer.readGRNN(path + "//grnnmodel3.txt");
                    Global.GRNNLayer4       = GRNNLayer.readGRNN(path + "//grnnmodel4.txt");
                    Global.feedForwardLayer = FeedForwardLayer.readFF(path + "//feedforwardmodel.txt");

                    // NOTE(review): training saves the embedding via
                    // SerializeWordembedding("model//deepNetwork//embedding") (no extension);
                    // confirm that call appends ".txt" so this path matches.
                    Global.wordEmbedding = LSTMLayer.getSerializeWordembedding(path + "//embedding.txt", Global.wordEmbedding);
                }
                else
                {
                    // Start from freshly constructed layers.
                    Global.GRNNLayer1       = new GRNNLayer();
                    Global.GRNNLayer2       = new GRNNLayer();
                    Global.GRNNLayer3       = new GRNNLayer();
                    Global.GRNNLayer4       = new GRNNLayer();
                    Global.upLSTMLayer      = new LSTMLayer();
                    Global.upLSTMLayerr     = new LSTMLayer();
                    Global.feedForwardLayer = new FeedForwardLayer();
                }

                Trainer.train(X);
            }
            finally
            {
                // Close (and flush) the log even if loading or training throws,
                // so the partial log is not lost.
                if (Global.swLog != null)
                {
                    Global.swLog.Close();
                }
            }

            stopwatch.Stop();
            Console.WriteLine(stopwatch.Elapsed);

            Console.Read();
        }
Esempio n. 3
0
        /// <summary>
        /// Executes the mode selected in <c>Global.mode</c>.
        /// <list type="bullet">
        /// <item><description>"test": evaluates once on <c>X.Testing</c>, logs the f-score and post-processes "data/temp/test_raw.txt".</description></item>
        /// <item><description>"train": runs <c>Global.trainIter</c> epochs; after each epoch it evaluates on <c>X.Testing</c> and saves every model part whenever the current score is at least as good as the best seen so far (<c>testAccuracy</c>).</description></item>
        /// </list>
        /// Any other mode value does nothing.
        /// </summary>
        /// <param name="X">Dataset supplying the Training and Testing sequences.</param>
        public static void train(DataSet X)
        {
            if (Global.mode == "test")
            {
                Console.WriteLine("begain testing......");
                testAccuracy1 = runtestIteration(X.Testing, false);
                Console.WriteLine("Epoch test  f-score: {0}", (testAccuracy1 * 100).ToString("f3"));
                Global.swLog.WriteLine("Epoch test fscore: {0}", (testAccuracy1 * 100).ToString("f3"));
                Postprocessing.transfer("data/temp/" + "test_raw.txt");

                Console.WriteLine("Finished");
            }
            else if (Global.mode == "train")
            {
                for (int iter = 0; iter < Global.trainIter; iter++)
                {
                    DateTime begin = DateTime.Now;

                    Console.WriteLine("\niter: {0}", iter + 1);
                    Global.swLog.WriteLine("\niter: {0}", iter + 1);

                    double trainAccuracy = runtrainIteration(X.Training, X.Testing, true, iter);

                    if (double.IsNaN(trainAccuracy) || double.IsInfinity(trainAccuracy))
                    {
                        Console.WriteLine("WARNING: invalid value for training loss. Try lowering learning rate.");
                    }

                    // End-of-epoch evaluation on the test split.
                    testAccuracy1 = runtestIteration(X.Testing, false);

                    // Checkpoint when the current score does not regress from the best so far.
                    if (testAccuracy <= testAccuracy1)
                    {
                        // NOTE(review): Run() reads this back from ".../embedding.txt";
                        // confirm SerializeWordembedding appends the extension.
                        LSTMLayer.SerializeWordembedding("model//deepNetwork//embedding");
                        //LSTMLayer.SerializeBigramWordembedding();
                        Global.upLSTMLayer.saveLSTM("model//deepNetwork//lstmmodel.txt");
                        Global.upLSTMLayerr.saveLSTM("model//deepNetwork//lstmmodelr.txt");
                        // Each GRNN layer needs its own file (grnnmodel1..4.txt), matching the names
                        // Run() loads; writing them all to grnnmodel1.txt would leave only layer 4's weights.
                        Global.GRNNLayer1.saveGRNN("model//deepNetwork//grnnmodel1.txt");
                        Global.GRNNLayer2.saveGRNN("model//deepNetwork//grnnmodel2.txt");
                        Global.GRNNLayer3.saveGRNN("model//deepNetwork//grnnmodel3.txt");
                        Global.GRNNLayer4.saveGRNN("model//deepNetwork//grnnmodel4.txt");
                        Global.feedForwardLayer.saveFFmodel("model//deepNetwork//feedforwardmodel.txt");
                        testAccuracy = testAccuracy1;
                    }

                    DateTime end = DateTime.Now;

                    // First line: best score so far (testAccuracy); second line: this epoch's score (testAccuracy1).
                    Console.WriteLine("test f-score: {0}", (testAccuracy * 100).ToString("f3"));
                    Console.WriteLine("test f-score: {0}", (testAccuracy1 * 100).ToString("f3"));
                    Console.WriteLine("time used: {0}", end - begin);

                    Global.swLog.WriteLine("test f-score: {0}", (testAccuracy * 100).ToString("f3"));
                    Global.swLog.WriteLine("test f-score: {0}", (testAccuracy1 * 100).ToString("f3"));
                    Global.swLog.WriteLine("time used: {0}", end - begin);
                }
            }
        }
Esempio n. 4
0
        /// <summary>
        /// Runs one pass over <paramref name="X"/>. Sequences are processed in batches of up to
        /// <c>Global.nThread</c> items on the thread pool via <c>TrainThread.run</c>.
        /// When <paramref name="train"/> is true:
        /// <list type="bullet">
        /// <item><description>The data is shuffled first.</description></item>
        /// <item><description><c>UpdateWeight_rmProp</c> is applied to each batch's data steps.</description></item>
        /// <item><description>The model is periodically evaluated on <paramref name="Xtest"/> and checkpointed when the score does not regress.</description></item>
        /// </list>
        /// </summary>
        /// <param name="X">Sequences to process.</param>
        /// <param name="Xtest">Sequences used for the periodic mid-epoch evaluation.</param>
        /// <param name="train">True to update weights; false to only run the sequences.</param>
        /// <param name="iter">Current epoch index (not used inside this method).</param>
        /// <returns><c>runThread.accword / runThread.totalword</c> accumulated over the pass.</returns>
        public static double runtrainIteration(List <DataSeq> X, List <DataSeq> Xtest, bool train, int iter)
        {
            // How many sequences to process between mid-epoch evaluations.
            const int evalInterval = 16;

            // Training visits the data in a fresh random order; a non-training pass uses X as
            // given. Without the non-training branch that pass would see no data and return 0/0.
            List <DataSeq> x = train ? shuffle(X) : X;

            TrainThread runThread = new TrainThread(train);

            List <ManualResetEvent> manualEvents = new List <ManualResetEvent>();
            List <DataStep>         temp         = new List <DataStep>();

            int i = 0;
            int length = x.Count;

            while (i < length)
            {
                // i advances by Global.nThread per batch, so this fires only when a batch
                // boundary lands on a multiple of evalInterval.
                if (train && i != 0 && i % evalInterval == 0)
                {
                    testAccuracy1 = runtestIteration(Xtest, false);

                    if (testAccuracy <= testAccuracy1)
                    {
                        // NOTE(review): confirm SerializeWordembedding appends ".txt" to match the reader.
                        LSTMLayer.SerializeWordembedding("model//deepNetwork//embedding");
                        //LSTMLayer.SerializeBigramWordembedding();
                        Global.upLSTMLayer.saveLSTM("model//deepNetwork//lstmmodel.txt");
                        Global.upLSTMLayerr.saveLSTM("model//deepNetwork//lstmmodelr.txt");
                        // One file per GRNN layer, matching the names the loader reads back.
                        Global.GRNNLayer1.saveGRNN("model//deepNetwork//grnnmodel1.txt");
                        Global.GRNNLayer2.saveGRNN("model//deepNetwork//grnnmodel2.txt");
                        Global.GRNNLayer3.saveGRNN("model//deepNetwork//grnnmodel3.txt");
                        Global.GRNNLayer4.saveGRNN("model//deepNetwork//grnnmodel4.txt");
                        Global.feedForwardLayer.saveFFmodel("model//deepNetwork//feedforwardmodel.txt");
                        testAccuracy = testAccuracy1;
                    }

                    Console.WriteLine("test f-score: {0}", (testAccuracy * 100).ToString("f3"));
                    Console.WriteLine("test1 f-score: {0}", (testAccuracy1 * 100).ToString("f3"));

                    Global.swLog.WriteLine("test f-score: {0}", (testAccuracy * 100).ToString("f3"));
                    Global.swLog.WriteLine("test f-score: {0}", (testAccuracy1 * 100).ToString("f3"));
                }

                // Queue up to Global.nThread sequences. Each worker signals its own event.
                // NOTE(review): WaitHandle.WaitAll accepts at most 64 handles, so Global.nThread must be <= 64.
                for (int k = 0; k < Global.nThread && i < length; k++, i++)
                {
                    DataSeq          seq = x[i];
                    ManualResetEvent mre = new ManualResetEvent(false);
                    param            pa  = new param(seq);
                    pa.mre      = mre;
                    pa.datastep = seq.datasteps;
                    manualEvents.Add(mre);
                    // Collect this batch's data steps for the weight update below.
                    for (int m = 0; m < seq.datasteps.Count; m++)
                    {
                        temp.Add(seq.datasteps[m]);
                    }
                    ThreadPool.QueueUserWorkItem(new WaitCallback(runThread.run), pa);
                }

                // Block until every worker in this batch has finished.
                WaitHandle.WaitAll(manualEvents.ToArray());
                if (train)
                {
                    UpdateWeight_rmProp(temp);
                }
                manualEvents.Clear();
                temp.Clear();
            }

            // NOTE(review): if accword/totalword are integer fields this is integer division — confirm their types.
            return(runThread.accword / runThread.totalword);
        }
Esempio n. 5
0
        /// <summary>
        /// Console entry point. It runs these steps:
        /// <list type="number">
        /// <item><description>Asks for the running mode (training/testing) and the bigram-feature mode.</description></item>
        /// <item><description>Loads the word resources and the saved model files under "model\".</description></item>
        /// <item><description>Runs <c>Trainer.train</c>.</description></item>
        /// </list>
        /// </summary>
        /// <param name="args">Command-line arguments (unused).</param>
        static void Main(string[] args)
        {
            Console.WriteLine("Choose running mode: 1. training, 2. testing");
            // Trim so stray whitespace still selects an option; ?? covers ReadLine
            // returning null when input is redirected and exhausted.
            string mode = (Console.ReadLine() ?? string.Empty).Trim();

            if (mode == "1")
            {
                Global.mode = "train";
            }
            else if (mode == "2")
            {
                Global.mode = "test";
            }
            else
            {
                // Without a warning, an unrecognised choice would leave Global.mode at its default silently.
                Console.WriteLine("Unrecognized running mode; keeping the default.");
            }

            Console.WriteLine("Choose the bigram feature mode: 1. read bigram features, 2. create bigram features");
            string bigramfeature = (Console.ReadLine() ?? string.Empty).Trim();

            if (bigramfeature == "1")
            {
                Global.isReadBigramfeature = true;
            }
            else if (bigramfeature == "2")
            {
                Global.isReadBigramfeature = false;
            }
            else
            {
                Console.WriteLine("Unrecognized bigram feature mode; keeping the default.");
            }

            try
            {
                readword();
                readbigramword();
                readidoimword();

                Global.randn = new Normal();

                // NOTE(review): these backslash paths are Windows-specific and use a different
                // layout ("model\") from the "model//deepNetwork" files the trainer saves.
                Global._UpLSTMLayer  = LSTMLayer.readLSTM("model\\lstmmodel.txt");
                Global._UpLSTMLayerr = LSTMLayer.readLSTM("model\\lstmmodelr.txt");
                Global._GRNNLayer1 = GRNNLayer.readGRNN("model\\grnnmodel1.txt");
                Global._GRNNLayer2 = GRNNLayer.readGRNN("model\\grnnmodel2.txt");
                Global._GRNNLayer3 = GRNNLayer.readGRNN("model\\grnnmodel3.txt");
                Global._GRNNLayer4 = GRNNLayer.readGRNN("model\\grnnmodel4.txt");

                Global._feedForwardLayer = FeedForwardLayer.readFF("model\\feedforwardmodel.txt");

                LSTMLayer.getSerializeWordembedding();
                LSTMLayer.getSerializeBigramWordembedding();

                DataSet X = new DataSet();

                Trainer.train(X);
            }
            finally
            {
                // Flush and close the log even when loading or training throws.
                if (Global.swLog != null)
                {
                    Global.swLog.Close();
                }
            }
        }