Exemplo n.º 1
0
        /// <summary>
        /// Initializes the Q-network, its companion network (via <c>CreateNetworkQt</c>) and the
        /// pool of <c>AsyncDoubleQLearning</c> workers that train it.
        /// </summary>
        /// <param name="p_network">
        /// Optional pre-built Q-network. When null, a default network is built with layers
        /// "input" (STATE_DIM), "hidden0" (ReLU, built from SolverConfig.hidden_layer) and
        /// "output" (ACTION_DIM, linear), connected input -> hidden0 -> output with
        /// Glorot-uniform initialization.
        /// </param>
        /// <remarks>
        /// One worker is created per slot of <c>_learners.Capacity</c>. Learners left from a
        /// previous call are discarded first. Previously, a non-empty list made the loop grow
        /// Capacity on every Add, so the loop never terminated.
        /// </remarks>
        public override void Init(NeuralNetwork p_network = null)
        {
            var config = SolverConfig.GetInstance();

            _qtUpdateIndex = 0;
            _qtUpdateSize  = config.qtupdate_size;

            if (p_network == null)
            {
                _networkQ = new NeuralNetwork();
                _networkQ.AddLayer("input", new InputLayer(GetParam(STATE_DIM)), BaseLayer.TYPE.INPUT);
                _networkQ.AddLayer("hidden0", new CoreLayer(config.hidden_layer, ACTIVATION.RELU, BaseLayer.TYPE.HIDDEN), BaseLayer.TYPE.HIDDEN);
                _networkQ.AddLayer("output", new CoreLayer(GetParam(ACTION_DIM), ACTIVATION.LINEAR, BaseLayer.TYPE.OUTPUT), BaseLayer.TYPE.OUTPUT);

                // feed-forward connections
                _networkQ.AddConnection("input", "hidden0", Connection.INIT.GLOROT_UNIFORM);
                _networkQ.AddConnection("hidden0", "output", Connection.INIT.GLOROT_UNIFORM);
            }
            else
            {
                _networkQ = p_network;
            }

            // NOTE(review): presumably derives _networkQt from _networkQ, so it must run after
            // _networkQ is assigned. Confirm against CreateNetworkQt.
            CreateNetworkQt();

            // Drop learners from an earlier Init and snapshot the worker count. Capacity grows
            // once Add overflows it, so re-reading it each iteration on a non-empty list would
            // never let the loop finish.
            // NOTE(review): discarded learners are not stopped/disposed here; confirm none are running.
            _learners.Clear();
            int workerCount = _learners.Capacity;

            for (int i = 0; i < workerCount; i++)
            {
                // NOTE(review): 0.99f is presumably the discount factor. Confirm against the
                // AsyncDoubleQLearning constructor.
                AsyncDoubleQLearning worker = new AsyncDoubleQLearning(new ADAM(_networkQ), _networkQ, _networkQt, 0.99f, config.async_update);
                worker.Optimizer.InitAlpha(config.learning_rate, config.learning_rate / 10);
                _learners.Add(worker);
            }
        }
Exemplo n.º 2
0
        /// <summary>
        /// Rebuilds a <see cref="NeuralNetwork"/> from its serialized JSON text.
        /// </summary>
        /// <param name="p_data">
        /// JSON text with a "_network" descriptor plus "layers" and "connections" objects.
        /// </param>
        /// <returns>
        /// The reconstructed network, or null when "_network.type" is not "feedforward".
        /// </returns>
        /// <remarks>
        /// Layers whose "type" is <c>BaseLayer.LSTM</c> or unrecognised are skipped. Connections
        /// missing an "inlayer" or "outlayer" field are skipped. <c>Enum.Parse</c> throws if a
        /// layer's "layer_type" is not a valid <c>BaseLayer.TYPE</c> name.
        /// </remarks>
        public static NeuralNetwork LoadNetwork(string p_data)
        {
            JSONObject main    = new JSONObject(p_data);
            JSONObject network = main["_network"];

            // Only feed-forward networks are supported. The static, ordinal string.Equals also
            // treats a missing "type" string as "not feedforward" instead of throwing.
            if (!string.Equals(network["type"].str, "feedforward", StringComparison.Ordinal))
            {
                return null;
            }

            NeuralNetwork res = new NeuralNetwork();

            JSONObject layers = main["layers"];

            foreach (string key in layers.keys)
            {
                JSONObject layer = layers[key];

                switch (layer["type"].str)
                {
                case BaseLayer.INPUT:
                    res.AddLayer(layer["id"].str, new InputLayer(layer), ParseLayerType(layer));
                    break;

                case BaseLayer.CORE:
                    res.AddLayer(layer["id"].str, new CoreLayer(layer), ParseLayerType(layer));
                    break;

                case BaseLayer.RECURRENT:
                    res.AddLayer(layer["id"].str, new RecurrentLayer(layer), ParseLayerType(layer));
                    break;

                case BaseLayer.LSTM:
                    // LSTM deserialization is not implemented yet; such layers are skipped.
                    break;
                }
            }

            JSONObject connections = main["connections"];

            foreach (string key in connections.keys)
            {
                JSONObject connection = connections[key];

                // NOTE(review): if Layers is a dictionary, this indexer throws for ids that were
                // never added (e.g. skipped LSTM layers) rather than yielding null. Confirm.
                BaseLayer inLayer  = connection.HasField("inlayer") ? res.Layers[connection["inlayer"].str] : null;
                BaseLayer outLayer = connection.HasField("outlayer") ? res.Layers[connection["outlayer"].str] : null;

                if (inLayer != null && outLayer != null)
                {
                    Connection c = new Connection(key, inLayer.OutputGroup, outLayer.InputGroup, connection);

                    res.AddConnection(inLayer.Id, outLayer.Id, c);
                }
            }

            return res;
        }

        /// <summary>
        /// Parses a serialized layer's "layer_type" field into a <see cref="BaseLayer.TYPE"/> value.
        /// </summary>
        private static BaseLayer.TYPE ParseLayerType(JSONObject p_layer)
        {
            return (BaseLayer.TYPE)Enum.Parse(typeof(BaseLayer.TYPE), p_layer["layer_type"].str);
        }