Пример #1
0
    /// <summary>
    /// Runs one learning step.
    /// The eval network is copied into the target network every
    /// <c>replace_target_iteration</c> steps. Then a batch from <c>network_memory</c> is
    /// replayed through <c>network_eval.UpdateWeights</c>. Finally epsilon is advanced
    /// (capped at <c>max_epsilon</c>) and <c>learn_step</c> is incremented.
    /// Returns early, without touching epsilon or <c>learn_step</c>, when the memory
    /// batch is empty.
    /// </summary>
    private void Learn()
    {
        //Copy eval net to target net after `replace_target_iteration` iterations
        if (learn_step % replace_target_iteration == 0)
        {
            network_target.SetWeightsData(network_eval.GetWeightsData());
            log.Add("Replacing `target net` with `eval net`");
        }

        //Get batch from memory; nothing to learn from yet
        int[] batch_index = CreateMemoryBatch();
        if (batch_index.Length == 0)
        {
            return;
        }

        //Compute network error for every sampled transition
        for (int i = 0; i < batch_index.Length; i++)
        {
            var sample = network_memory[batch_index[i]];

            //Compute `q_eval` (current state) and `q_target` (next state)
            float[] q_target = network_target.Compute(sample.next_state);
            float[] q_eval   = network_eval.Compute(sample.current_state);

            //Index of the greedy action according to the target net
            int max_q_target = 0;
            for (int j = 1; j < q_target.Length; j++)
            {
                if (q_target[max_q_target] < q_target[j])
                {
                    max_q_target = j;
                }
            }

            //Add reward and reward decay to the greedy entry.
            //NOTE(review): standard DQN starts the target from a copy of q_eval and replaces
            //only the *taken* action with reward + reward_decay * max(q_next). Here the
            //target-net output for next_state is used for every action and only its max
            //entry is replaced, so non-max actions also receive an error. Confirm this is
            //intended (the stored action is not read here).
            q_target[max_q_target] = sample.reward + reward_decay * q_target[max_q_target];

            //Compute error (assumes the network output size equals actions.Length)
            float[] error = new float[actions.Length];
            for (int j = 0; j < actions.Length; j++)
            {
                error[j] = q_target[j] - q_eval[j];
            }

            //Update weights using RMS-PROP
            network_eval.UpdateWeights(error);
        }

        //Update epsilon, never letting it exceed `max_epsilon`
        epsilon = Math.Min(epsilon + epsilon_increment, max_epsilon);

        //Update learn step
        learn_step++;
    }
    /// <summary>
    /// Measures how far <paramref name="network"/> is from the targets in <paramref name="data"/>.
    /// Each row of <paramref name="data"/> holds the network inputs followed by the target
    /// outputs. Despite the name, the returned value is the TOTAL squared error, summed over
    /// all rows and all outputs. It is not divided by the row or output count.
    /// </summary>
    /// <param name="data">Rows laid out as [inputs..., targets...].</param>
    /// <param name="network">Network to evaluate.</param>
    /// <returns>Sum of squared (target - output) differences.</returns>
    private float MSR(float[][] data, MFNN network)
    {
        int n_in  = network.GetInputSize();
        int n_out = network.GetOutputSize();

        //Sanity checks on the data layout
        Debug.Assert(data.Length > 0);
        Debug.Assert(data[0].Length == n_in + n_out);

        float total = 0;
        foreach (float[] row in data)
        {
            float[] inputs  = new float[n_in];
            float[] targets = new float[n_out];
            Array.Copy(row, 0, inputs, 0, n_in);
            Array.Copy(row, n_in, targets, 0, n_out);

            float[] outputs = network.Compute(inputs);

            float row_error = 0;
            for (int k = 0; k < n_out; k++)
            {
                row_error += (targets[k] - outputs[k]) * (targets[k] - outputs[k]);
            }
            total += row_error;
        }

        return total;
    }
    /// <summary>
    /// Smoke test for <c>MFNN</c> on the Iris data.
    /// Builds a 4-7-3 network and shuffles <c>IrisData.dataset</c>. It trains for
    /// <c>epochs</c> passes on all but the last <c>test_samples</c> shuffled rows, logging
    /// <c>MSR</c> over the whole dataset after each pass.
    /// Each held-out row is then logged as GOOD/BAD, depending on whether the arg-max of
    /// the network output matches the arg-max of the target. An accuracy summary follows.
    /// </summary>
    private void Start()
    {
        const int epochs       = 100;
        const int test_samples = 20;

        MFNN network = new MFNN(new int[] { 4, 7, 3 }, new ActivationType[] {
            ActivationType.NONE,
            ActivationType.LOGISTIC_SIGMOID,
            ActivationType.LOGISTIC_SIGMOID
        });

        Debug.Assert(IrisData.dataset.Length > test_samples);

        int[] shuffle = ShuffleArray(IrisData.dataset.Length);
        Debug.Log("Initial Error: " + MSR(IrisData.dataset, network));

        int input_size  = network.GetInputSize();
        int output_size = network.GetOutputSize();

        //Shuffled indices [0, train_samples) are used for training and
        //[train_samples, Length) for testing, so the two sets never overlap.
        int train_samples = IrisData.dataset.Length - test_samples;

        for (int r = 0; r < epochs; r++)
        {
            for (int i = 0; i < train_samples; i++)
            {
                float[] x_values = new float[input_size];
                float[] t_values = new float[output_size];
                Array.Copy(IrisData.dataset[shuffle[i]], 0, x_values, 0, input_size);
                Array.Copy(IrisData.dataset[shuffle[i]], input_size, t_values, 0, output_size);

                float[] y_values = network.Compute(x_values);
                float[] errors   = new float[output_size];
                for (int j = 0; j < output_size; j++)
                {
                    errors[j] = t_values[j] - y_values[j];
                }
                network.UpdateWeights(errors, 0.01f, 0.0001f, 0.5f);
            }
            Debug.Log("Itr. " + r + " MSR: " + MSR(IrisData.dataset, network));
        }

        Debug.Log("Testing");
        int correct = 0;
        for (int i = train_samples; i < IrisData.dataset.Length; i++)
        {
            float[] x_values = new float[input_size];
            float[] t_values = new float[output_size];
            Array.Copy(IrisData.dataset[shuffle[i]], 0, x_values, 0, input_size);
            Array.Copy(IrisData.dataset[shuffle[i]], input_size, t_values, 0, output_size);
            float[] y_values = network.Compute(x_values);

            //Arg-max of target (expected class) and of output (predicted class)
            int max1 = 0;
            int max2 = 0;
            for (int j = 1; j < t_values.Length; j++)
            {
                if (t_values[max1] < t_values[j])
                {
                    max1 = j;
                }
                if (y_values[max2] < y_values[j])
                {
                    max2 = j;
                }
            }

            if (max1 == max2)
            {
                correct++;
                Debug.Log("GOOD");
            }
            else
            {
                Debug.Log("BAD");
            }
        }
        Debug.Log("Accuracy: " + correct + "/" + test_samples);
    }
    /// <summary>
    /// Coroutine that performs one control step for the current PSO particle.
    /// <list type="number">
    /// <item>Reads the car's rays.</item>
    /// <item>Lets <c>particles[working_particle]</c> choose an action.</item>
    /// <item>Waits 0.1 s.</item>
    /// <item>Scores the particle with the car's forward velocity plus the decayed greedy
    /// target-net value of the next state.</item>
    /// </list>
    /// It also resets a stuck car, advances to the next particle after
    /// <c>next_particle_wait</c> steps, and runs a PSO epoch once every particle has been
    /// evaluated. It then schedules itself again unless <c>abort_learning</c> is set.
    /// </summary>
    private IEnumerator TakeStep()
    {
        //Get current car state
        float[] current_state = car_camera.GetRays();

        //Get action from current state using the working particle's network
        float[] q_values = particles[working_particle].Compute(current_state);
        last_q_values = q_values;
        action_index  = SelectAction(q_values);

        //Wait for action to complete
        yield return new WaitForSeconds(0.1f);

        //Get next state
        float[] next_state = car_camera.GetRays();

        //Get max `a` of `q_target`: scan every action of the target-net output
        float[] q_target     = network_target.Compute(next_state);
        int     max_q_target = 0;
        for (int i = 1; i < q_target.Length; i++)
        {
            if (q_target[max_q_target] < q_target[i])
            {
                max_q_target = i;
            }
        }

        //Get reward for action: forward (local z) velocity of the car body
        float velocity = car_body.gameObject.transform.InverseTransformDirection(car_body.velocity).z;

        //NOTE(review): current_reward is overwritten every step (not accumulated) but is
        //explicitly reset below when switching particles — confirm `=` vs `+=` is intended.
        current_reward = velocity + reward_decay * q_target[max_q_target];
        particles[working_particle].SetNetworkScore(current_reward);

        //Reset car if nearly stopped and more than 100 steps have passed since the last reset
        if (car_body.velocity.magnitude < 0.3f && current_step - reset_step > 100)
        {
            reset_step = current_step;
            car_body.transform.position = car_spawner.transform.position;
            car_body.transform.rotation = car_spawner.transform.rotation;
            car_body.velocity           = Vector3.zero;
            car_body.angularVelocity    = Vector3.zero;
        }

        //After `next_particle_wait` steps go to next particle
        if (current_step - particle_step > next_particle_wait)
        {
            working_particle++;
            //reset reward
            current_reward = 0;
            particle_step  = current_step;
        }

        //Do a pso update after all particles
        if (working_particle == max_particles)
        {
            network_target.SetWeightsData(particle_swarm.GetBestWeights());
            //PSO Update Step
            particle_swarm.ComputeEpoch();
            particle_swarm.UpdateWeights();
            working_particle = 0;
        }

        current_step++;
        if (!abort_learning)
        {
            StartCoroutine(TakeStep());
        }
    }