/// <summary>
/// Sets up the critic for this solver.
/// </summary>
/// <remarks>
/// When no network is supplied, a default feed-forward network is built:
/// an input layer sized by <c>GetParam(STATE_DIM)</c>, one ReLU hidden layer
/// sized by <c>SolverConfig.hidden_layer</c>, and a tanh output layer sized by
/// <c>GetParam(ACTION_DIM)</c>. The network is then wrapped in an ADAM
/// optimizer and handed to a new <c>DeepQLearning</c> instance stored in
/// <c>_critic</c>.
/// </remarks>
/// <param name="p_network">
/// Optional pre-built network to use as-is; when <c>null</c> the default
/// topology described above is created.
/// </param>
public override void Init(NeuralNetwork p_network = null)
{
    NeuralNetwork net = p_network;

    if (net == null)
    {
        // Caller supplied nothing: assemble the default three-layer network.
        net = new NeuralNetwork();

        InputLayer inputLayer = new InputLayer(GetParam(STATE_DIM));
        net.AddLayer("input", inputLayer, BaseLayer.TYPE.INPUT);

        CoreLayer hiddenLayer = new CoreLayer(
            SolverConfig.GetInstance().hidden_layer,
            ACTIVATION.RELU,
            BaseLayer.TYPE.HIDDEN);
        net.AddLayer("hidden0", hiddenLayer, BaseLayer.TYPE.HIDDEN);

        CoreLayer outputLayer = new CoreLayer(
            GetParam(ACTION_DIM),
            ACTIVATION.TANH,
            BaseLayer.TYPE.OUTPUT);
        net.AddLayer("output", outputLayer, BaseLayer.TYPE.OUTPUT);

        // Feed-forward wiring: input -> hidden0 -> output.
        net.AddConnection("input", "hidden0", Connection.INIT.GLOROT_UNIFORM);
        net.AddConnection("hidden0", "output", Connection.INIT.GLOROT_UNIFORM);
    }

    // ADAM is the active choice; RMSProp(net) and
    // BackProp(net, 1e-5f, 0.99f, true) were tried here previously.
    Optimizer adam = new ADAM(net);

    // Read the replay/batch settings in the same order the constructor takes them.
    var memorySize = SolverConfig.GetInstance().memory_size;
    var batchSize = SolverConfig.GetInstance().batch_size;
    var qtUpdateSize = SolverConfig.GetInstance().qtupdate_size;

    // NOTE(review): 0.99f is presumably the discount factor - confirm against DeepQLearning.
    _critic = new DeepQLearning(adam, net, 0.99f, memorySize, batchSize, qtUpdateSize);

    var learningRate = SolverConfig.GetInstance().learning_rate;
    _critic.SetAlpha(learningRate);
}
/// <summary>
/// Trains a small network on the four XOR samples and prints the results.
/// </summary>
/// <remarks>
/// Builds a 2-input, 8-unit sigmoid hidden, 1-unit sigmoid output network,
/// trains it with RMSProp (Alpha = 0.1, batch mode of 4) for 200 epochs, and
/// writes the summed value returned by <c>Train</c> after every epoch.
/// Afterwards it prints the network output for each sample, releases the
/// sample vectors, disposes the optimizer, runs <c>BasePool.Instance.Check()</c>
/// and prints the elapsed time in milliseconds.
/// </remarks>
public void Run()
{
    const int sampleCount = 4;
    const int epochCount = 200;

    Stopwatch timer = Stopwatch.StartNew();

    // Network topology: 2 inputs -> 8 sigmoid units -> 1 sigmoid output.
    NeuralNetwork net = new NeuralNetwork();

    InputLayer inputLayer = new InputLayer(2);
    net.AddLayer("input", inputLayer, BaseLayer.TYPE.INPUT);

    CoreLayer hiddenLayer = new CoreLayer(8, ACTIVATION.SIGMOID, BaseLayer.TYPE.HIDDEN);
    net.AddLayer("hidden", hiddenLayer, BaseLayer.TYPE.HIDDEN);

    CoreLayer outputLayer = new CoreLayer(1, ACTIVATION.SIGMOID, BaseLayer.TYPE.OUTPUT);
    net.AddLayer("output", outputLayer, BaseLayer.TYPE.OUTPUT);

    net.AddConnection("input", "hidden", Connection.INIT.GLOROT_UNIFORM);
    net.AddConnection("hidden", "output", Connection.INIT.GLOROT_UNIFORM);

    // RMSProp is the active optimizer; BackProp(net, 1e-5f, 0.99f, true)
    // with Alpha = 0.1f was the earlier alternative.
    Optimizer trainer = new RMSProp(net)
    {
        Alpha = 0.1f
    };
    trainer.InitBatchMode(sampleCount);

    // XOR truth table: rows of xorInputs map to the same index in xorTargets.
    float[][] xorInputs =
    {
        new float[] { 0f, 0f },
        new float[] { 0f, 1f },
        new float[] { 1f, 0f },
        new float[] { 1f, 1f }
    };
    float[] xorTargets = { 0f, 1f, 1f, 0f };

    Vector[] samples = new Vector[sampleCount];
    Vector[] labels = new Vector[sampleCount];

    // All input vectors are built before any target vector.
    for (int s = 0; s < sampleCount; s++)
    {
        samples[s] = Vector.Build(2, xorInputs[s]);
    }

    for (int s = 0; s < sampleCount; s++)
    {
        labels[s] = Vector.Build(1, new float[] { xorTargets[s] });
    }

    for (int epoch = 0; epoch < epochCount; epoch++)
    {
        // Accumulate whatever Train returns for each sample in this epoch.
        float epochError = 0;

        for (int s = 0; s < sampleCount; s++)
        {
            epochError += trainer.Train(samples[s], labels[s]);
        }

        Console.WriteLine(epochError);
    }

    Console.WriteLine();

    // Print the trained network's output per sample, then hand the vectors back.
    for (int s = 0; s < sampleCount; s++)
    {
        Console.WriteLine(net.Activate(samples[s])[0]);
        Vector.Release(samples[s]);
        Vector.Release(labels[s]);
    }

    trainer.Dispose();

    // NOTE(review): Check() presumably reports pool state after cleanup - confirm.
    Console.Write("Finish ");
    BasePool.Instance.Check();

    timer.Stop();
    Console.WriteLine(timer.ElapsedMilliseconds);
}