/// <summary>
/// Small smoke run: builds a network through <c>LSTMNetwork.Create2(32, 5)</c>,
/// asks it to display itself via <c>DisplayByDot2</c>, and then runs it once on a
/// single five-element vector registered under the key <c>"inputs"</c>.
/// </summary>
public static void Demo()
{
    LSTMNetwork demoNet = LSTMNetwork.Create2(32, 5);
    demoNet.DisplayByDot2();

    // One named input holding the values 1..5; the initializer below calls Add,
    // exactly like an explicit inputs.Add("inputs", ...) would.
    Dictionary<string, double[]> feed = new Dictionary<string, double[]>
    {
        { "inputs", new double[] { 1, 2, 3, 4, 5 } },
    };

    demoNet.Run2(feed);
}
/// <summary>
/// Program entry point: configures the CPU device, calls <c>OpenText</c>, builds the
/// LSTM network and then runs the training loop for iterations 0 through 100000
/// (inclusive). Every 100th iteration it calls <c>Sample</c> and prints the running
/// smoothed loss.
/// </summary>
/// <param name="args">Command-line arguments; not read by this method.</param>
private static void Main(string[] args)
{
    const int lastIteration = 1000 * 100;
    const int reportEvery = 100;

    Console.Clear();
    ProcessingDevice.Device = DeviceType.CPU;

    // NOTE(review): OpenText presumably populates txt, data_size, vocab_size and
    // char_to_ix, which are read below — confirm against its definition.
    OpenText();

    hidden_size = 100;
    seq_length = 25;
    learning_rate = 1e-1f;
    net = new LSTMNetwork(vocab_size, vocab_size, hidden_size, learning_rate, 1e-1f);

    FloatArray hiddenState = new FloatArray(hidden_size);
    FloatArray cellState = new FloatArray(hidden_size);

    // Starting value for the running loss: cross-entropy of a uniform guess over
    // vocab_size symbols, multiplied by seq_length.
    var smoothedLoss = -Math.Log(1.0 / vocab_size) * seq_length;

    int cursor = 0;
    for (int iteration = 0; iteration <= lastIteration; iteration++)
    {
        // Start over with fresh state on the very first pass, or when the next
        // window (plus its one-character look-ahead) would reach data_size.
        bool restart = cursor + seq_length + 1 >= data_size || iteration == 0;
        if (restart)
        {
            hiddenState = new FloatArray(hidden_size);
            cellState = new FloatArray(hidden_size);
            cursor = 0;
        }

        var inputs = new int[seq_length];
        var targets = new int[seq_length];

        // inputs holds the indices of txt[cursor .. cursor + seq_length - 1].
        for (int k = 0; k < seq_length; k++)
        {
            inputs[k] = char_to_ix[txt[cursor + k]];
        }

        // targets holds the same window shifted one character to the right.
        for (int k = 0; k < seq_length; k++)
        {
            targets[k] = char_to_ix[txt[cursor + 1 + k]];
        }

        var (loss,
             gradWf, gradWi, gradWc, gradWo, gradWv,
             gradBf, gradBi, gradBc, gradBo, gradBv,
             nextHidden, nextCell) = net.BPTT(inputs, targets, hiddenState, cellState);

        net.UpdateParams(gradWf, gradWi, gradWc, gradWo, gradWv,
                         gradBf, gradBi, gradBc, gradBo, gradBv);

        if (iteration % reportEvery == 0)
        {
            // Sampling is seeded with the state from *before* this step and the
            // first input of the current window; the printed loss does not yet
            // include this step's loss.
            Sample(hiddenState, cellState, inputs[0], 200);
            Console.WriteLine($"iter {iteration}, loss: {smoothedLoss}");
        }

        // Carry the state returned by BPTT into the next window.
        hiddenState = nextHidden;
        cellState = nextCell;

        smoothedLoss = smoothedLoss * 0.999 + loss * 0.001;
        cursor += seq_length;
    }
}