/// <summary>
/// Demo driver: builds a small character corpus by repeating a phrase, feeds
/// successive windows of it to <c>LSTM.TrainOnBatch</c>, and periodically prints
/// the smoothed loss, elapsed time and a sample produced by <c>LSTM.Predict</c>.
/// </summary>
static void RunLSTM()
{
    const int repeatCount = 3;
    const string phrase = "Hello World\n";
    string corpus = Enumerable.Repeat(phrase, repeatCount).Glue("");

    // Vocabulary: every distinct character gets an index in first-seen order,
    // plus the reverse lookup used to turn predictions back into text.
    var alphabet = corpus.Distinct().ToArray();
    var charToIndex = alphabet
        .Select((symbol, index) => (symbol, index))
        .ToDictionary(entry => entry.symbol, entry => entry.index);
    var indexToChar = charToIndex.ToDictionary(pair => pair.Value, pair => pair.Key);

    const int epochs = 4000;
    const int displayEvery = 250;

    int inputSize = alphabet.Length;
    int hiddenSize = inputSize * 2;
    int seqLength = phrase.Length;
    Console.WriteLine($"X_size:{inputSize} H_size:{hiddenSize} T_steps:{seqLength}");

    var lstm = new LSTM(inputSize, hiddenSize, seqLength);
    var sw = Stopwatch.StartNew();

    // Encodes the seqLength characters of the corpus starting at `start` as vocabulary indices.
    int[] EncodeWindow(int start)
    {
        var window = new int[seqLength];
        for (int k = 0; k < seqLength; ++k)
        {
            window[k] = charToIndex[corpus[start + k]];
        }

        return window;
    }

    for (int step = 0, cursor = 0; step <= epochs; ++step, cursor += seqLength)
    {
        // Wrap around before the target window (shifted by one) would run past the corpus end.
        if (cursor + seqLength >= corpus.Length)
        {
            cursor = 0;
        }

        int[] inputs = EncodeWindow(cursor);
        int[] targets = EncodeWindow(cursor + 1);

        // Third argument is true whenever the window restarts at the corpus beginning
        // (presumably a state-reset flag; confirm against LSTM.TrainOnBatch).
        lstm.TrainOnBatch(inputs, targets, cursor == 0);

        if (step % displayEvery != 0)
        {
            continue;
        }

        Console.WriteLine($"Epochs:{step,6}/{epochs} Loss:{lstm.smooth_loss:F6} Time:{sw.ElapsedMilliseconds,6} ms");

        // Ask the network for a sample seeded with the first input index and decode it to text.
        var sample = lstm.Predict(inputs[0], seqLength * 3 - 2);
        Console.WriteLine(sample.Select(idx => indexToChar[idx]).Glue(""));
    }
}
/// <summary>
/// Seeds the network's previous hidden and cell state, runs a single
/// <c>TrainOnBatch</c> call on a fixed three-step sequence, and writes the
/// resulting loss, state, and every W/b, dW/db and mW/mb field to the console
/// for manual inspection.
/// </summary>
public void Test3()
{
    lstm.g_h_prev = Ops.ARange(hs, 1, -6);
    lstm.g_C_prev = Ops.ARange(hs, 1, -7);

    int[] inputs = { 0, 1, 0 };
    int[] targets = { 1, 0, 0 };
    lstm.TrainOnBatch(inputs, targets);

    // Writes a label line, then the rendered value, optionally followed by a blank line.
    // The renderer is deferred so the label is written before the value is formatted.
    // NOTE(review): assumes Ops.Print returns string; confirm against Ops.
    void Dump(string label, Func<string> render, bool blankLineAfter = false)
    {
        Console.WriteLine(label);
        Console.WriteLine(render());
        if (blankLineAfter)
        {
            Console.WriteLine();
        }
    }

    Console.WriteLine("loss");
    Console.WriteLine(lstm.smooth_loss);
    Console.WriteLine();

    // Local alias taken after training so the lambdas below capture a local.
    var net = lstm;

    // Recurrent state after the batch.
    Dump("h_prev", () => Ops.Print(net.g_h_prev), blankLineAfter: true);
    Dump("C_prev", () => Ops.Print(net.g_C_prev), blankLineAfter: true);

    // W_* fields.
    Dump("W_f", () => Ops.Print(net.W_f));
    Dump("W_i", () => Ops.Print(net.W_i));
    Dump("W_C", () => Ops.Print(net.W_C));
    Dump("W_o", () => Ops.Print(net.W_o));
    Dump("W_y", () => Ops.Print(net.W_y));

    // b_* fields.
    Dump("b_f", () => Ops.Print(net.b_f));
    Dump("b_i", () => Ops.Print(net.b_i));
    Dump("b_C", () => Ops.Print(net.b_C));
    Dump("b_o", () => Ops.Print(net.b_o));
    Dump("b_y", () => Ops.Print(net.b_y));

    // dW_* fields.
    Dump("dW_f", () => Ops.Print(net.dW_f));
    Dump("dW_i", () => Ops.Print(net.dW_i));
    Dump("dW_C", () => Ops.Print(net.dW_C));
    Dump("dW_o", () => Ops.Print(net.dW_o));
    Dump("dW_y", () => Ops.Print(net.dW_y));

    // db_* fields.
    Dump("db_f", () => Ops.Print(net.db_f));
    Dump("db_i", () => Ops.Print(net.db_i));
    Dump("db_C", () => Ops.Print(net.db_C));
    Dump("db_o", () => Ops.Print(net.db_o));
    Dump("db_y", () => Ops.Print(net.db_y));

    // mW_* fields.
    Dump("mW_f", () => Ops.Print(net.mW_f));
    Dump("mW_i", () => Ops.Print(net.mW_i));
    Dump("mW_C", () => Ops.Print(net.mW_C));
    Dump("mW_o", () => Ops.Print(net.mW_o));
    Dump("mW_y", () => Ops.Print(net.mW_y));

    // mb_* fields.
    Dump("mb_f", () => Ops.Print(net.mb_f));
    Dump("mb_i", () => Ops.Print(net.mb_i));
    Dump("mb_C", () => Ops.Print(net.mb_C));
    Dump("mb_o", () => Ops.Print(net.mb_o));
    Dump("mb_y", () => Ops.Print(net.mb_y));
}