Ejemplo n.º 1
0
        /// <summary>
        /// Downloads the training and test text, trains a small recurrent network on it one
        /// sentence at a time, and then prints <c>2 ^ (sum / wnum)</c> computed over the
        /// leading part of the test data.
        /// </summary>
        /// <remarks>
        /// The layers of the <c>FunctionStack</c> are invoked individually (indices 0 to 3) so that
        /// the hidden state can be fed back through "l2 Linear" at every time step. The trailing
        /// Softmax layer is constructed as part of the model but is not invoked by this method.
        /// </remarks>
        public static void Run()
        {
            // Upper bound on the number of test tokens that are scored.
            const int MaxTestTokens = 1000;

            Console.WriteLine("Build Vocabulary.");

            Vocabulary vocabulary = new Vocabulary();
            string trainPath = InternetFileDownloader.Donwload(DOWNLOAD_URL + TRAIN_FILE, TRAIN_FILE, TRAIN_FILE_HASH);
            string testPath = InternetFileDownloader.Donwload(DOWNLOAD_URL + TEST_FILE, TEST_FILE, TEST_FILE_HASH);

            int[] trainData = vocabulary.LoadData(trainPath);
            int[] testData = vocabulary.LoadData(testPath);

            // Read the vocabulary size only after both files were loaded through the same vocabulary.
            int nVocab = vocabulary.Length;

            Console.WriteLine("Done.");

            Console.WriteLine("Network Initilizing.");
            FunctionStack<Real> model = new FunctionStack<Real>(
                new EmbedID<Real>(nVocab, N_UNITS, name: "l1 EmbedID"),
                new Linear<Real>(N_UNITS, N_UNITS, name: "l2 Linear"),
                new TanhActivation<Real>("l2 Tanh"),
                new Linear<Real>(N_UNITS, nVocab, name: "l3 Linear"),
                new Softmax<Real>("l3 Sonftmax")
                );

            Adam<Real> adam = new Adam<Real>();
            adam.SetUp(model);

            // Token ids of the sentence currently being collected.
            List<int> s = new List<int>();

            Console.WriteLine("Train Start.");
            SoftmaxCrossEntropy<Real> softmaxCrossEntropy = new SoftmaxCrossEntropy<Real>();

            for (int epoch = 0; epoch < TRAINING_EPOCHS; epoch++)
            {
                for (int pos = 0; pos < trainData.Length; pos++)
                {
                    int id = trainData[pos];
                    s.Add(id);

                    // A sentence is trained as soon as its end-of-sentence token arrives.
                    if (id == vocabulary.EosID)
                    {
                        // The hidden state starts from a fresh array for every sentence. It is only
                        // needed here, so it is allocated here instead of once per token.
                        NdArray<Real> h = new NdArray<Real>(new Real[N_UNITS]);

                        Real accumloss = 0;
                        Stack<NdArray<Real>> tmp = new Stack<NdArray<Real>>();

                        for (int i = 0; i < s.Count; i++)
                        {
                            // The target is the following token; the last position targets EOS.
                            int nextId = i == s.Count - 1 ? vocabulary.EosID : s[i + 1];
                            NdArray<int> tx = new NdArray<int>(new[] { nextId });

                            //l1 EmbedID
                            NdArray<Real> l1 = model.Functions[0].Forward(s[i])[0];

                            //l2 Linear
                            NdArray<Real> l2 = model.Functions[1].Forward(h)[0];

                            //Add
                            NdArray<Real> xK = l1 + l2;

                            //l2 Tanh
                            h = model.Functions[2].Forward(xK)[0];

                            //l3 Linear
                            NdArray<Real> h2 = model.Functions[3].Forward(h)[0];

                            Real loss = softmaxCrossEntropy.Evaluate(h2, tx);
                            tmp.Push(h2);
                            accumloss += loss;
                        }

                        Console.WriteLine(accumloss);

                        // Back-propagate the per-step outputs in reverse order (last step first).
                        for (int i = 0; i < s.Count; i++)
                        {
                            model.Backward(tmp.Pop());
                        }

                        adam.Update();

                        s.Clear();
                    }

                    if (pos % 100 == 0)
                    {
                        Console.WriteLine(pos + "/" + trainData.Length + " finished");
                    }
                }
            }

            Console.WriteLine("Test Start.");

            Real sum = 0;
            int wnum = 0;
            List<int> ts = new List<int>();
            bool unkWord = false;

            // Never index past the end of a test set that is shorter than the scoring window.
            int testLength = Math.Min(MaxTestTokens, testData.Length);

            for (int pos = 0; pos < testLength; pos++)
            {
                int id = testData[pos];
                ts.Add(id);

                // NOTE(review): this compares a word id against the number of training tokens,
                // not against a vocabulary size - confirm this is the intended unknown-word test.
                if (id > trainData.Length)
                {
                    unkWord = true;
                }

                if (id == vocabulary.EosID)
                {
                    if (!unkWord)
                    {
                        Console.WriteLine("pos" + pos);
                        Console.WriteLine("tsLen" + ts.Count);
                        Console.WriteLine("sum" + sum);
                        Console.WriteLine("wnum" + wnum);

                        sum += CalPs(model, ts);
                        wnum += ts.Count - 1;
                    }
                    else
                    {
                        // Sentences flagged as containing an unknown word are skipped, not scored.
                        unkWord = false;
                    }

                    ts.Clear();
                }
            }

            Console.WriteLine(Math.Pow(2.0f, sum / wnum));
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Downloads the training, validation and test text, trains a two-layer LSTM network
        /// with dropout on the training data, and prints the validation perplexity after every
        /// pass over the data and the test perplexity at the end.
        /// </summary>
        /// <remarks>
        /// Each iteration feeds one token per batch lane; lane <c>j</c> reads the training data
        /// starting at offset <c>jump * j</c>. Parameters are updated and the recurrent state is
        /// reset every <c>BPROP_LEN</c> iterations.
        /// </remarks>
        public static void Run()
        {
            Console.WriteLine("Build Vocabulary.");

            Vocabulary vocabulary = new Vocabulary();

            string trainPath = InternetFileDownloader.Donwload(DOWNLOAD_URL + TRAIN_FILE, TRAIN_FILE, TRAIN_FILE_HASH);
            string validPath = InternetFileDownloader.Donwload(DOWNLOAD_URL + VALID_FILE, VALID_FILE, VALID_FILE_HASH);
            string testPath = InternetFileDownloader.Donwload(DOWNLOAD_URL + TEST_FILE, TEST_FILE, TEST_FILE_HASH);

            int[] trainData = vocabulary.LoadData(trainPath);
            int[] validData = vocabulary.LoadData(validPath);
            int[] testData = vocabulary.LoadData(testPath);

            // Read the vocabulary size only after all three files were loaded.
            int nVocab = vocabulary.Length;

            Console.WriteLine("Network Initilizing.");
            FunctionStack model = new FunctionStack(
                new EmbedID(nVocab, N_UNITS, name: "l1 EmbedID"),
                new Dropout(),
                new LSTM(N_UNITS, N_UNITS, name: "l2 LSTM"),
                new Dropout(),
                new LSTM(N_UNITS, N_UNITS, name: "l3 LSTM"),
                new Dropout(),
                new Linear(N_UNITS, nVocab, name: "l4 Linear")
                );

            // Overwrite every parameter value with a random number in [-0.1, 0.1).
            for (int i = 0; i < model.Functions.Length; i++)
            {
                for (int j = 0; j < model.Functions[i].Parameters.Length; j++)
                {
                    for (int k = 0; k < model.Functions[i].Parameters[j].Data.Length; k++)
                    {
                        model.Functions[i].Parameters[j].Data[k] = (Mother.Dice.NextDouble() * 2.0 - 1.0) / 10.0;
                    }
                }
            }

            // Rather than hard-capping at the given threshold, a rate is derived from the
            // L2 norm of all parameters and used to apply the correction.
            GradientClipping gradientClipping = new GradientClipping(threshold: GRAD_CLIP);
            SGD sgd = new SGD(learningRate: 0.1);

            model.SetOptimizer(gradientClipping, sgd);

            // All index arithmetic is done with integers so that large offsets are never
            // rounded by a floating-point Real before they are used as array indices.
            int wholeLen = trainData.Length;
            int jump = wholeLen / BATCH_SIZE;
            int epoch = 0;

            // One loss function instance is reused for every iteration.
            SoftmaxCrossEntropy softmaxCrossEntropy = new SoftmaxCrossEntropy();

            Console.WriteLine("Train Start.");

            for (int i = 0; i < jump * N_EPOCH; i++)
            {
                NdArray x = new NdArray(new[] { 1 }, BATCH_SIZE);
                NdArray t = new NdArray(new[] { 1 }, BATCH_SIZE);

                for (int j = 0; j < BATCH_SIZE; j++)
                {
                    // Input is the current token of lane j; the target is the token after it.
                    // Both positions wrap around at the end of the training data.
                    int offset = jump * j + i;
                    x.Data[j] = trainData[offset % wholeLen];
                    t.Data[j] = trainData[(offset + 1) % wholeLen];
                }

                NdArray result = model.Forward(x)[0];
                Real sumLoss = softmaxCrossEntropy.Evaluate(result, t);
                Console.WriteLine("[{0}/{1}] Loss: {2}", i + 1, jump, sumLoss);
                model.Backward(result);

                //Run truncated BPTT
                if ((i + 1) % BPROP_LEN == 0)
                {
                    model.Update();
                    model.ResetState();
                }

                // Every 'jump' iterations each lane has advanced through one lane-sized slice,
                // which is counted as one epoch.
                if ((i + 1) % jump == 0)
                {
                    epoch++;
                    Console.WriteLine("evaluate");
                    Console.WriteLine("validation perplexity: {0}", Evaluate(model, validData));

                    // From the sixth epoch on, shrink the learning rate after every epoch.
                    if (epoch >= 6)
                    {
                        sgd.LearningRate /= 1.2;
                        Console.WriteLine("learning rate =" + sgd.LearningRate);
                    }
                }
            }

            Console.WriteLine("test start");
            Console.WriteLine("test perplexity:" + Evaluate(model, testData));
        }