Esempio n. 1
0
        /// <summary>
        /// Demo driver: trains an <c>LSTM</c> on consecutive character windows of a short
        /// repeated phrase and, at a fixed iteration interval, prints the smoothed loss,
        /// the elapsed time and a sequence produced by <c>Predict</c>.
        /// </summary>
        static void RunLSTM()
        {
            const int repeatCount   = 3;
            const int lastIteration = 4000;
            const int reportEvery   = 250;

            string phrase = "Hello World\n";
            string corpus = Enumerable.Repeat(phrase, repeatCount).Glue("");

            // Vocabulary: the distinct characters of the corpus. A character's position
            // in this array is the integer id handed to the network.
            char[] vocabulary  = corpus.Distinct().ToArray();
            var    charToIndex = Enumerable.Range(0, vocabulary.Length).ToDictionary(k => vocabulary[k], k => k);
            var    indexToChar = charToIndex.ToDictionary(entry => entry.Value, entry => entry.Key);

            int vocabSize  = vocabulary.Length;
            int hiddenSize = 2 * vocabSize;
            int windowSize = phrase.Length;

            Console.WriteLine($"X_size:{vocabSize} H_size:{hiddenSize} T_steps:{windowSize}");

            var model = new LSTM(vocabSize, hiddenSize, windowSize);
            var timer = Stopwatch.StartNew();

            // Encodes windowSize consecutive corpus characters, starting at 'start', as ids.
            int[] Encode(int start) =>
                Enumerable.Range(start, windowSize).Select(pos => charToIndex[corpus[pos]]).ToArray();

            int offset = 0;
            for (int iteration = 0; iteration <= lastIteration; ++iteration, offset += windowSize)
            {
                // The target window is shifted one character to the right, so it reads
                // index offset + windowSize; rewind before that index leaves the corpus.
                if (offset + windowSize >= corpus.Length)
                {
                    offset = 0;
                }

                int[] inputs  = Encode(offset);
                int[] targets = Encode(offset + 1);

                // NOTE(review): the third argument is true whenever the window rewinds to
                // the start of the corpus; presumably it resets the recurrent state — confirm
                // against LSTM.TrainOnBatch.
                model.TrainOnBatch(inputs, targets, offset == 0);

                if (iteration % reportEvery != 0)
                {
                    continue;
                }

                Console.WriteLine($"Epochs:{iteration,6}/{lastIteration} Loss:{model.smooth_loss:F6} Time:{timer.ElapsedMilliseconds,6} ms");

                // Seed the prediction with the first id of the current window and decode the result to text.
                var sampled = model.Predict(inputs[0], windowSize * 3 - 2);
                Console.WriteLine(sampled.Select(id => indexToChar[id]).Glue(""));
            }
        }
Esempio n. 2
0
        /// <summary>
        /// Creates the <c>LSTM</c> under test and overwrites its weights, biases,
        /// gradient buffers (<c>d*</c>) and <c>m*</c> buffers with values built by
        /// <c>Ops.ARange</c>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): presumably this gives every test a fixed, known starting state
        /// instead of whatever the LSTM constructor initialises — confirm against
        /// <c>LSTM</c> and <c>Ops.ARange</c>.
        /// </remarks>
        public TestLstm()
        {
            // zs is the sum of the input size and the hidden size; it is the second
            // ARange argument for every gate weight below.
            zs = xs + hs;

            lstm = new LSTM(xs, hs, ts)
            {
                // Parameters: gate weights use (hs, zs), gate biases (hs, 1),
                // the output layer uses (xs, hs) and (xs, 1). Third argument: 0..4.
                W_f = Ops.ARange(hs, zs, 0),
                b_f = Ops.ARange(hs, 1, 0),
                W_i = Ops.ARange(hs, zs, 1),
                b_i = Ops.ARange(hs, 1, 1),
                W_C = Ops.ARange(hs, zs, 2),
                b_C = Ops.ARange(hs, 1, 2),
                W_o = Ops.ARange(hs, zs, 3),
                b_o = Ops.ARange(hs, 1, 3),
                W_y = Ops.ARange(xs, hs, 4),
                b_y = Ops.ARange(xs, 1, 4),

                // d* buffers: same first two arguments as the matching parameter,
                // third argument -1..-5.
                dW_f = Ops.ARange(hs, zs, -1),
                db_f = Ops.ARange(hs, 1, -1),
                dW_i = Ops.ARange(hs, zs, -2),
                db_i = Ops.ARange(hs, 1, -2),
                dW_C = Ops.ARange(hs, zs, -3),
                db_C = Ops.ARange(hs, 1, -3),
                dW_o = Ops.ARange(hs, zs, -4),
                db_o = Ops.ARange(hs, 1, -4),
                dW_y = Ops.ARange(xs, hs, -5),
                db_y = Ops.ARange(xs, 1, -5),

                // m* buffers (presumably optimizer memory — TODO confirm in LSTM):
                // third argument 1..5.
                mW_f = Ops.ARange(hs, zs, 1),
                mb_f = Ops.ARange(hs, 1, 1),
                mW_i = Ops.ARange(hs, zs, 2),
                mb_i = Ops.ARange(hs, 1, 2),
                mW_C = Ops.ARange(hs, zs, 3),
                mb_C = Ops.ARange(hs, 1, 3),
                mW_o = Ops.ARange(hs, zs, 4),
                mb_o = Ops.ARange(hs, 1, 4),
                mW_y = Ops.ARange(xs, hs, 5),
                mb_y = Ops.ARange(xs, 1, 5),
            };
        }