/// <summary>
/// Trains an <c>LSTM</c> on the phrase "Hello World\n" repeated three times and
/// periodically prints the smoothed loss, the elapsed time and a sampled sequence.
/// </summary>
/// <remarks>
/// Characters are mapped to integer ids in order of first appearance in the corpus.
/// Each iteration feeds one window of <c>phrase.Length</c> ids as inputs, with the
/// same window shifted by one character as targets.
/// </remarks>
static void RunLSTM()
{
    const int repeatCount = 3;
    const string phrase = "Hello World\n";
    const int totalIterations = 4000;
    const int reportInterval = 250;

    string corpus = Enumerable.Repeat(phrase, repeatCount).Glue("");

    // Vocabulary: distinct characters, indexed by order of first appearance.
    var alphabet = corpus.Distinct().ToArray();
    var charToIndex = Enumerable.Range(0, alphabet.Length).ToDictionary(k => alphabet[k], k => k);
    var indexToChar = Enumerable.Range(0, alphabet.Length).ToDictionary(k => k, k => alphabet[k]);

    int inputSize = alphabet.Length;
    int hiddenSize = inputSize * 2;
    int stepCount = phrase.Length;
    Console.WriteLine($"X_size:{inputSize} H_size:{hiddenSize} T_steps:{stepCount}");

    var lstm = new LSTM(inputSize, hiddenSize, stepCount);
    var timer = Stopwatch.StartNew();

    // The bound is inclusive so that the final report is printed at iteration == totalIterations.
    for (int iteration = 0, offset = 0; iteration <= totalIterations; ++iteration, offset += stepCount)
    {
        // Wrap around while a full window plus the one-character look-ahead still fits.
        if (offset + stepCount >= corpus.Length)
        {
            offset = 0;
        }

        var inputs = new int[stepCount];
        var targets = new int[stepCount];
        for (int t = 0; t < stepCount; ++t)
        {
            inputs[t] = charToIndex[corpus[offset + t]];
            targets[t] = charToIndex[corpus[offset + t + 1]];
        }

        // Third argument is true only for the window at the start of the corpus
        // (presumably a state-reset flag -- confirm against LSTM.TrainOnBatch).
        bool atCorpusStart = offset == 0;
        lstm.TrainOnBatch(inputs, targets, atCorpusStart);

        if (iteration % reportInterval != 0)
        {
            continue;
        }

        Console.WriteLine($"Epochs:{iteration,6}/{totalIterations} Loss:{lstm.smooth_loss:F6} Time:{timer.ElapsedMilliseconds,6} ms");

        // Sample a sequence seeded with the first input id of the current window.
        var sampled = lstm.Predict(inputs[0], stepCount * 3 - 2);
        Console.WriteLine(sampled.Select(id => indexToChar[id]).Glue(""));
    }
}
/// <summary>
/// Builds the fixture: creates an <c>LSTM</c> from the <c>xs</c>, <c>hs</c> and <c>ts</c>
/// members of this class and overwrites its weights, biases, gradients and
/// <c>m*</c> buffers with values produced by <c>Ops.ARange</c>.
/// </summary>
/// <remarks>
/// <c>zs</c> is set to <c>xs + hs</c>. Gate matrices are created as <c>hs x zs</c>,
/// output matrices as <c>xs x hs</c>, and biases as single-column arrays.
/// NOTE(review): the third <c>Ops.ARange</c> argument looks like a start offset used to
/// make each tensor distinguishable -- confirm against <c>Ops.ARange</c>.
/// </remarks>
public TestLstm()
{
    zs = xs + hs;

    lstm = new LSTM(xs, hs, ts)
    {
        // Parameters: forget, input, candidate and output gates, then the output layer.
        W_f = Ops.ARange(hs, zs, 0),
        b_f = Ops.ARange(hs, 1, 0),
        W_i = Ops.ARange(hs, zs, 1),
        b_i = Ops.ARange(hs, 1, 1),
        W_C = Ops.ARange(hs, zs, 2),
        b_C = Ops.ARange(hs, 1, 2),
        W_o = Ops.ARange(hs, zs, 3),
        b_o = Ops.ARange(hs, 1, 3),
        W_y = Ops.ARange(xs, hs, 4),
        b_y = Ops.ARange(xs, 1, 4),

        // Gradients, seeded with negative third arguments.
        dW_f = Ops.ARange(hs, zs, -1),
        db_f = Ops.ARange(hs, 1, -1),
        dW_i = Ops.ARange(hs, zs, -2),
        db_i = Ops.ARange(hs, 1, -2),
        dW_C = Ops.ARange(hs, zs, -3),
        db_C = Ops.ARange(hs, 1, -3),
        dW_o = Ops.ARange(hs, zs, -4),
        db_o = Ops.ARange(hs, 1, -4),
        dW_y = Ops.ARange(xs, hs, -5),
        db_y = Ops.ARange(xs, 1, -5),

        // m* buffers (presumably optimizer memory -- confirm against LSTM).
        mW_f = Ops.ARange(hs, zs, 1),
        mb_f = Ops.ARange(hs, 1, 1),
        mW_i = Ops.ARange(hs, zs, 2),
        mb_i = Ops.ARange(hs, 1, 2),
        mW_C = Ops.ARange(hs, zs, 3),
        mb_C = Ops.ARange(hs, 1, 3),
        mW_o = Ops.ARange(hs, zs, 4),
        mb_o = Ops.ARange(hs, 1, 4),
        mW_y = Ops.ARange(xs, hs, 5),
        mb_y = Ops.ARange(xs, 1, 5),
    };
}