/// <summary>
/// Runs one learning step: periodically copies the eval network's weights into the
/// target network, updates the eval network once per sample of a batch drawn from
/// memory, then advances epsilon and the learn-step counter.
/// </summary>
/// <remarks>
/// When <c>CreateMemoryBatch</c> returns an empty batch the method returns early,
/// leaving both <c>epsilon</c> and <c>learn_step</c> untouched.
/// </remarks>
private void Learn() {
    // Copy eval net to target net every `replace_target_iteration` learn steps
    if (learn_step % replace_target_iteration == 0) {
        network_target.SetWeightsData(network_eval.GetWeightsData());
        log.Add("Replacing `target net` with `eval net`");
    }

    // Get batch from memory; nothing to learn from if it is empty
    int[] batch_index = CreateMemoryBatch();
    if (batch_index.Length == 0) {
        return;
    }

    // Compute the network error and update the eval net, one sample at a time
    for (int i = 0; i < batch_index.Length; i++) {
        int memory_index = batch_index[i];

        // `q_target` comes from the target net on the next state,
        // `q_eval` from the eval net on the current state
        float[] q_target = network_target.Compute(network_memory[memory_index].next_state);
        float[] q_eval = network_eval.Compute(network_memory[memory_index].current_state);

        float reward = network_memory[memory_index].reward;

        // Locate the largest next-state estimate (first index wins on ties)
        int max_q_target = 0;
        for (int j = 0; j < q_target.Length; j++) {
            if (q_target[max_q_target] < q_target[j]) {
                max_q_target = j;
            }
        }

        // Replace that entry with reward + discounted max estimate
        q_target[max_q_target] = reward + reward_decay * q_target[max_q_target];

        // NOTE(review): the error below spans every action output, and the entry that
        // receives the reward is the arg-max of the next-state estimate rather than the
        // action stored with the transition. Textbook Q-learning updates only the action
        // that was taken -- confirm this formulation is intended.
        float[] error = new float[actions.Length];
        for (int j = 0; j < actions.Length; j++) {
            error[j] = q_target[j] - q_eval[j];
        }

        // Update eval-net weights from the error (original author's note: RMS-PROP)
        network_eval.UpdateWeights(error);
    }

    // Advance epsilon, clamping in the same step so it never exceeds max_epsilon
    // (previously it could overshoot by up to one increment until the next call)
    if (epsilon < max_epsilon) {
        epsilon += epsilon_increment;
    }
    if (epsilon > max_epsilon) {
        epsilon = max_epsilon;
    }

    // Update learn step
    learn_step++;
}
/// <summary>
/// Trains a 4-7-3 feed-forward network on a shuffled view of <c>IrisData.dataset</c>
/// and then logs GOOD/BAD for each held-out sample depending on whether the
/// predicted arg-max class matches the target arg-max class.
/// </summary>
/// <remarks>
/// The last <c>holdout_count</c> shuffled samples are kept out of training and used
/// only for the final test pass. The per-epoch MSR is computed over the whole dataset.
/// </remarks>
private void Start() {
    const int epoch_count = 100;   // number of passes over the training split
    const int holdout_count = 20;  // shuffled samples reserved for testing

    MFNN network = new MFNN(
        new int[] { 4, 7, 3 },
        new ActivationType[] { ActivationType.NONE, ActivationType.LOGISTIC_SIGMOID, ActivationType.LOGISTIC_SIGMOID });

    int sample_count = IrisData.dataset.Length;
    int train_count = sample_count - holdout_count;
    int[] shuffle = ShuffleArray(sample_count);

    Debug.Log("Initial Error: " + MSR(IrisData.dataset, network));

    int input_size = network.GetInputSize();
    int output_size = network.GetOutputSize();

    // Training: one weight update per training sample, repeated for each epoch
    for (int r = 0; r < epoch_count; r++) {
        for (int i = 0; i < train_count; i++) {
            // Each dataset row holds the inputs followed by the target values
            float[] x_values = new float[input_size];
            float[] t_values = new float[output_size];
            Array.Copy(IrisData.dataset[shuffle[i]], 0, x_values, 0, input_size);
            Array.Copy(IrisData.dataset[shuffle[i]], input_size, t_values, 0, output_size);

            float[] y_values = network.Compute(x_values);

            float[] errors = new float[output_size];
            for (int j = 0; j < output_size; j++) {
                errors[j] = t_values[j] - y_values[j];
            }

            network.UpdateWeights(errors, 0.01f, 0.0001f, 0.5f);
        }

        Debug.Log("Itr. " + r + " MSR: " + MSR(IrisData.dataset, network));
    }

    Debug.Log("Testing");

    // Testing: start at the first index the training loop did NOT visit.
    // (The loop previously began one index earlier, re-scoring a training sample.)
    for (int i = train_count; i < sample_count; i++) {
        float[] x_values = new float[input_size];
        float[] t_values = new float[output_size];
        Array.Copy(IrisData.dataset[shuffle[i]], 0, x_values, 0, input_size);
        Array.Copy(IrisData.dataset[shuffle[i]], input_size, t_values, 0, output_size);

        float[] y_values = network.Compute(x_values);

        // Arg-max of target (max1) and prediction (max2); first index wins on ties
        int max1 = 0;
        int max2 = 0;
        for (int j = 0; j < t_values.Length; j++) {
            if (t_values[max1] < t_values[j]) {
                max1 = j;
            }
            if (y_values[max2] < y_values[j]) {
                max2 = j;
            }
        }

        if (max1 == max2) {
            Debug.Log("GOOD");
        } else {
            Debug.Log("BAD");
        }
    }
}