private void FixedUpdate()
    {
        List <double> states = new List <double>();
        List <double> qs     = new List <double>();

        states.Add(this.transform.position.y);
        states.Add(this.GetComponent <Rigidbody2D>().velocity.y);
        qs = ann.CalcOutput(states);
        this.GetComponent <Rigidbody2D>().AddForce(Vector2.up * force * (float)qs[0]);
        if (dead)
        {
            reward = -1;
        }
        else
        {
            reward = 0.1f;
        }
        replay lastmemory = new replay(this.transform.position.y, this.GetComponent <Rigidbody2D>().velocity.y, reward);

        if (replaymemory.Count > mcapacity)
        {
            replaymemory.RemoveAt(0);
        }
        replaymemory.Add(lastmemory);
        //Training and Q-learning
        if (dead)
        {
            for (int i = replaymemory.Count - 1; i >= 0; i--)
            {
                List <double> toutputs_old = new List <double>();
                List <double> toutputs_new = new List <double>();
                toutputs_old = ann.CalcOutput(replaymemory[i].states);

                double feedback;
                if (i == replaymemory.Count - 1 || replaymemory[i].reward == -1)
                {
                    feedback = replaymemory[i].reward;
                }
                else
                {
                    toutputs_new = ann.CalcOutput(replaymemory[i + 1].states);
                    double maxQ = toutputs_new[0];
                    feedback = (replaymemory[i].reward + discount * maxQ);  //BELLMAN EQUATION
                }
                toutputs_old[0] = feedback;
                ann.Train(replaymemory[i].states, toutputs_old);
            }
            dead = false;
            Reset();
            replaymemory.Clear();
        }
    }
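A note on the pattern above: the target written at the "//BELLMAN EQUATION" comment is feedback = reward + discount * maxQ(nextState), and each step is stored in a replay buffer before training. The replay record itself never appears in this listing; a minimal sketch consistent with the constructor calls (a few state values plus a reward) might look like the following. This is a reconstruction, not the original class; the snippet above spells it lowercase (replay), later ones use Replay, and the number of state fields varies per project.

    using System.Collections.Generic;

    public class Replay
    {
        public List <double> states;   // observed state vector at this step
        public double reward;          // reward received in that state

        public Replay(double state0, double state1, double r)
        {
            states = new List <double> { state0, state1 };
            reward = r;
        }
    }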
Example #2
    private List <double> Run(
        double ballXPosition,
        double ballYPosition,
        double ballXVelocity,
        double ballYVelocity,
        double paddleXPosition,
        double paddleYPosition,
        double paddleVelocity,
        bool train)
    {
        var inputs = new List <double>
        {
            ballXPosition,
            ballYPosition,
            ballXVelocity,
            ballYVelocity,
            paddleXPosition,
            paddleYPosition
        };

        var outputs = new List <double>
        {
            paddleVelocity
        };

        if (train)
        {
            return(_ann.Train(inputs, outputs));
        }
        else
        {
            return(_ann.CalcOutput(inputs, outputs));
        }
    }
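A hedged example of how a Run() wrapper like this is typically driven. The method and variable names below are illustrative placeholders, not names from the original project; the only assumption is the Run() signature above:

    private void Step(double ballX, double ballY, double ballVX, double ballVY,
                      double paddleX, double paddleY, double observedPaddleV, bool stillTraining)
    {
        if (stillTraining)
        {
            // Learn the mapping from game state to the observed paddle velocity.
            Run(ballX, ballY, ballVX, ballVY, paddleX, paddleY, observedPaddleV, true);
        }
        else
        {
            // Query the trained network; the last output argument is a placeholder.
            double paddleVelocity = Run(ballX, ballY, ballVX, ballVY, paddleX, paddleY, 0, false)[0];
            // ...apply paddleVelocity to the paddle here
        }
    }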
    private void Update()
    {
        if (!finishedTraining)
        {
            return;
        }

        List <double> inputs  = new List <double>();
        List <double> outputs = new List <double>();

        float fDist   = 0;
        float rDist   = 0;
        float lDist   = 0;
        float r45Dist = 0;
        float l45Dist = 0;

        Raycast(this.transform.forward, ref fDist);
        Raycast(this.transform.right, ref rDist);
        Raycast(-this.transform.right, ref lDist);
        Raycast(Quaternion.AngleAxis(-45, Vector3.up) * this.transform.right, ref r45Dist);
        Raycast(Quaternion.AngleAxis(45, Vector3.up) * -this.transform.right, ref l45Dist);

        inputs.Add(fDist);
        inputs.Add(rDist);
        inputs.Add(lDist);
        inputs.Add(r45Dist);
        inputs.Add(l45Dist);
        List <double> calcOutputs = ann.CalcOutput(inputs);

        translationInput = Map(-1, 1, 0, 1, (float)calcOutputs[0]);
        rotationInput    = Map(-1, 1, 0, 1, (float)calcOutputs[1]);
    }
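Map() is used here and in many of the following examples to remap the network's sigmoid-range output onto a movement input, but it is never defined in this listing. Given the call shape Map(newFrom, newTo, origFrom, origTo, value), a plausible clamped linear remap is sketched below; treat it as an assumption about the helper, not its actual source:

    static float Map(float newFrom, float newTo, float origFrom, float origTo, float value)
    {
        // Clamp at the edges, then interpolate linearly between the two ranges.
        if (value <= origFrom)
        {
            return newFrom;
        }
        if (value >= origTo)
        {
            return newTo;
        }
        return newFrom + (newTo - newFrom) * ((value - origFrom) / (origTo - origFrom));
    }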
    /// <summary>
    /// Calculate outputs without updating weight values as opposed to Train().
    /// </summary>
    private List <double> CalculateOutputs(List <double> inputs, List <double> desiredOutputs)
    {
        var calculatedOutputs = new List <double>();

        calculatedOutputs = ann.CalcOutput(inputs, desiredOutputs);
        return(calculatedOutputs);
    }
    List <double> Run(double bx, double by, double bvx, double bvy, double px, double py, double pv, bool train)
    {
        List <double> inputs  = new List <double>();
        List <double> outputs = new List <double>();

        //Populate the inputs
        inputs.Add(bx);
        inputs.Add(by);
        inputs.Add(bvx);
        inputs.Add(bvy);
        inputs.Add(px);
        inputs.Add(py);

        //We put pv in the outputs list so that we can train the network on it
        outputs.Add(pv);

        if (train)
        {
            return(ann.Train(inputs, outputs));
        }
        else
        {
            return(ann.CalcOutput(inputs, outputs));
        }
    }
Example #6
    //Run the ANN; it either trains or just calculates output, based on a boolean flag
    List <double> Run(double bx, double by, double bvx, double bvy, double px, double py, double pv, bool train)
    {
        List <double> inputs  = new List <double>();
        List <double> outputs = new List <double>();

        //Add the inputs to the inputs list
        inputs.Add(bx);
        inputs.Add(by);
        inputs.Add(bvx);
        inputs.Add(bvy);
        inputs.Add(px);
        inputs.Add(py);
        //Add the expected output to the outputs list
        outputs.Add(pv);

        //Call function according to boolean flag
        if (train)
        {
            return(ann.Train(inputs, outputs));
        }
        else
        {
            return(ann.CalcOutput(inputs, outputs));
        }
    }
Example #7
    void Update()
    {
        if (!trainingDone)
        {
            return;
        }

        List <double> calcOutputs = new List <double>();
        List <double> inputs      = new List <double>();
        List <double> outputs     = new List <double>();

        float fDist = 0, rDist = 0, lDist = 0, r45Dist = 0, l45Dist = 0;

        kart.PerformRayCasts(out fDist, out rDist, out lDist, out r45Dist, out l45Dist, this.transform);

        inputs.Add(fDist);
        inputs.Add(rDist);
        inputs.Add(lDist);
        inputs.Add(r45Dist);
        inputs.Add(l45Dist);
        outputs.Add(0);
        outputs.Add(0);
        calcOutputs = ann.CalcOutput(inputs, outputs);

        float translationInput = Utils.Map(-1, 1, 0, 1, (float)calcOutputs[0]);
        float rotationInput    = Utils.Map(-1, 1, 0, 1, (float)calcOutputs[1]);

        kart.Move(this.transform, translationInput, rotationInput);
    }
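PerformRayCasts() is not shown, either here (as a kart method) or later (as Utils.PerformRayCasts). Reconstructed from the inlined raycast code in the next example, a plausible static version is sketched below: each whisker distance is normalized by visibleDistance and inverted, so 0 means nothing in range and values near 1 mean an obstacle is very close. This is an assumption from the call sites, not the original helper:

    using UnityEngine;

    public static class Utils
    {
        // Casts five whisker rays (forward, right, left, right-45, left-45)
        // and returns normalized proximities in [0, 1].
        public static void PerformRayCasts(out float fDist, out float rDist, out float lDist,
                                           out float r45Dist, out float l45Dist,
                                           Transform t, float visibleDistance)
        {
            fDist   = Proximity(t, t.forward, visibleDistance);
            rDist   = Proximity(t, t.right, visibleDistance);
            lDist   = Proximity(t, -t.right, visibleDistance);
            r45Dist = Proximity(t, Quaternion.AngleAxis(-45, Vector3.up) * t.right, visibleDistance);
            l45Dist = Proximity(t, Quaternion.AngleAxis(45, Vector3.up) * -t.right, visibleDistance);
        }

        static float Proximity(Transform t, Vector3 direction, float visibleDistance)
        {
            RaycastHit hit;
            if (Physics.Raycast(t.position, direction, out hit, visibleDistance))
            {
                return 1f - hit.distance / visibleDistance;
            }
            return 0f;   // nothing within visibleDistance
        }
    }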
Example #8
    // Update is called once per frame
    void Update()
    {
        if (!trainingDone)
        {
            return;
        }

        List <double> calcOutputs = new List <double>();
        List <double> inputs      = new List <double>();
        List <double> outputs     = new List <double>();

        //raycasts
        RaycastHit hit;
        float      fDist = 0, rDist = 0, lDist = 0, r45Dist = 0, l45Dist = 0;

        //forward
        if (Physics.Raycast(transform.position, this.transform.forward, out hit, visibleDistance))
        {
            fDist = 1 - Round(hit.distance / visibleDistance);
        }
        //right
        if (Physics.Raycast(transform.position, this.transform.right, out hit, visibleDistance))
        {
            rDist = 1 - Round(hit.distance / visibleDistance);
        }
        //left
        if (Physics.Raycast(transform.position, -this.transform.right, out hit, visibleDistance))
        {
            lDist = 1 - Round(hit.distance / visibleDistance);
        }
        //45 degrees right
        if (Physics.Raycast(transform.position,
                            Quaternion.AngleAxis(-45, Vector3.up) * this.transform.right, out hit, visibleDistance))
        {
            r45Dist = 1 - Round(hit.distance / visibleDistance);
        }
        //45 degrees left
        if (Physics.Raycast(transform.position,
                            Quaternion.AngleAxis(45, Vector3.up) * -this.transform.right, out hit, visibleDistance))
        {
            l45Dist = 1 - Round(hit.distance / visibleDistance);
        }

        inputs.Add(fDist);
        inputs.Add(rDist);
        inputs.Add(lDist);
        inputs.Add(r45Dist);
        inputs.Add(l45Dist);
        outputs.Add(0);
        outputs.Add(0);
        calcOutputs = ann.CalcOutput(inputs, outputs);
        float translationInput = Map(-1, 1, 0, 1, (float)calcOutputs[0]);
        float rotationInput    = Map(-1, 1, 0, 1, (float)calcOutputs[1]);

        translation = translationInput * speed * Time.deltaTime;
        rotation    = rotationInput * rotationSpeed * Time.deltaTime;
        this.transform.Translate(0, 0, translation);
        this.transform.Rotate(0, rotation, 0);
    }
Example #9
    public int CalculateAction(List <double> states)
    {
        Debug.Log("states[0]: " + states[0]);
        qs = SoftMax(ann.CalcOutput(states));
        double maxQ      = qs.Max();
        int    maxQIndex = qs.ToList().IndexOf(maxQ);

        /*
         * exploreRate = Mathf.Clamp(exploreRate - exploreDecay, minExploreRate, maxExploreRate);
         *
         * if (Random.Range(0, 100) < exploreRate)
         * {
         *  maxQIndex = Random.Range(0, output);
         * }
         */

        return(maxQIndex);
    }
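SoftMax() is applied to nearly every CalcOutput result in the Q-learning snippets but never defined in this listing. A standard, numerically stable implementation consistent with these call sites would be:

    using System.Collections.Generic;
    using System.Linq;

    static List <double> SoftMax(List <double> values)
    {
        // Subtract the max before exponentiating for numerical stability;
        // the result is a probability distribution over the raw Q outputs.
        double max   = values.Max();
        double scale = values.Sum(v => System.Math.Exp(v - max));

        List <double> result = new List <double>();
        foreach (double v in values)
        {
            result.Add(System.Math.Exp(v - max) / scale);
        }
        return result;
    }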
Example #10
    private void FixedUpdate()
    {
        counter += Time.deltaTime;
        // Init
        List <double> states = new List <double>();
        List <double> qs     = new List <double>();

        // Observe states
        states = CollectObservations();

        // Get Action
        qs = SoftMax(ann.CalcOutput(states));
        double maxQ      = qs.Max();
        int    maxQIndex = qs.ToList().IndexOf(maxQ);

        exploreRate = Mathf.Clamp(exploreRate - exploreDecay, minExploreRate, maxExploreRate);

        //if (Random.Range(0, 100) < exploreRate)
        //  maxQIndex = Random.Range(0, 2);

        // Perform Action
        AgentAction(maxQIndex);

        // Access Replay Memory
        List <double> newStates = CollectObservations();
        // Replay lastMemory = new Replay(newStates[0], newStates[1], newStates[2], newStates[3], newStates[4], reward);
        //Replay lastMemory = new Replay(states[0], states[1], states[2], states[3], states[4], screenPressed ? 1f : -1, states[5], states[6], reward);
        Replay lastMemory = new Replay(states[0], states[1], states[2], states[3], reward);

        if (replayMemory.Count > mCapacity)
        {
            replayMemory.RemoveAt(0);
        }

        replayMemory.Add(lastMemory);

        cumulative_reward += reward;

        if (dead)
        {
            TrainAfterDead();
        }
    }
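TrainAfterDead() is not shown; judging from the sibling FixedUpdate examples, it almost certainly replays the memory backwards and trains each step toward the Bellman target. A condensed sketch under that assumption, reusing the fields from the snippet above (replayMemory, ann, discount, dead) and assuming System.Linq for Max():

    void TrainAfterDead()
    {
        // Walk the episode backwards so later, better-informed targets
        // are learned first, as in the other examples in this listing.
        for (int i = replayMemory.Count - 1; i >= 0; i--)
        {
            List <double> qOld   = SoftMax(ann.CalcOutput(replayMemory[i].states));
            int           action = qOld.IndexOf(qOld.Max());

            double feedback;
            if (i == replayMemory.Count - 1)
            {
                feedback = replayMemory[i].reward;   // terminal step: no future Q
            }
            else
            {
                List <double> qNew = SoftMax(ann.CalcOutput(replayMemory[i + 1].states));
                feedback = replayMemory[i].reward + discount * qNew.Max();   // Bellman target
            }

            qOld[action] = feedback;
            ann.Train(replayMemory[i].states, qOld);
        }

        dead = false;
        replayMemory.Clear();
    }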
Example #11
    // Wrapper method for train and calculate output //
    List <double> Run(double bx, double by, double bvx, double bvy, double px, double py, double pv, bool train)
    {
        List <double> inputs = new List <double>()
        {
            bx, by, bvx, bvy, px, py
        };
        List <double> outputs = new List <double>()
        {
            pv
        };

        return(train ? ann.Train(inputs, outputs) : ann.CalcOutput(inputs, outputs));
    }
    List <double> Prediction(double i0, double i1, double i2, double i3, double i4, double i5, double output = 0)
    {
        List <double> inputs  = new List <double>();
        List <double> outputs = new List <double>();

        inputs.Add(i0);
        inputs.Add(i1);
        inputs.Add(i2);
        inputs.Add(i3);
        inputs.Add(i4);
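        // note: i5 is accepted by this method but never added to the inputs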

        outputs.Add(output);

        return(ann.CalcOutput(inputs, outputs));
    }
Example #13
    List <double> Train(int input1, int input2, int desiredOutput, bool updateWeights = true)
    {
        List <double> inputs  = new List <double>();
        List <double> outputs = new List <double>();

        inputs.Add(input1);
        inputs.Add(input2);
        outputs.Add(desiredOutput);
        if (updateWeights)
        {
            return(ann.Train(inputs, outputs));
        }
        else
        {
            return(ann.CalcOutput(inputs, outputs));
        }
    }
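A hedged usage example for this Train() wrapper. In these tutorials the two-int-input, one-output shape is classically used for XOR; that target is an assumption here, but the calling pattern (many epochs over the truth table, then a read-out pass with updateWeights: false) is how such a wrapper is normally exercised:

    for (int epoch = 0; epoch < 10000; epoch++)
    {
        Train(0, 0, 0);   // 0 XOR 0 = 0
        Train(0, 1, 1);   // 0 XOR 1 = 1
        Train(1, 0, 1);   // 1 XOR 0 = 1
        Train(1, 1, 0);   // 1 XOR 1 = 0
    }
    // Read the prediction back without disturbing the learned weights.
    List <double> prediction = Train(1, 0, 1, updateWeights: false);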
Example #14
    //Called every frame
    void Update()
    {
        //If the ANN has not been trained, return
        if (!trainingDone)
        {
            return;
        }

        //Create lists for the calculated outputs, the inputs, and the outputs (the outputs list is a placeholder required by our ANN implementation)
        List <double> calcOutputs = new List <double>();
        List <double> inputs      = new List <double>();
        List <double> outputs     = new List <double>();

        //raycasts
        float fDist = 0, rDist = 0, lDist = 0, r45Dist = 0, l45Dist = 0;

        Utils.PerformRayCasts(out fDist, out rDist, out lDist, out r45Dist, out l45Dist, this.transform, visibleDistance);

        //Add the raycast hit distances returned to the list as ANN inputs
        inputs.Add(fDist);
        inputs.Add(rDist);
        inputs.Add(lDist);
        inputs.Add(r45Dist);
        inputs.Add(l45Dist);

        //Add zeros to the outputs list; these are placeholders for the values we will receive from the trained ANN.
        outputs.Add(0);
        outputs.Add(0);

        // Run the ANN and calculate output values for movement from the trained ANN
        calcOutputs = ann.CalcOutput(inputs, outputs);

        //Standard movement code, using ANN output values
        float translationInput = Map(-1, 1, 0, 1, (float)calcOutputs[0]);
        float rotationInput    = Map(-1, 1, 0, 1, (float)calcOutputs[1]);

        translation = translationInput * speed * Time.deltaTime;
        rotation    = rotationInput * rotationSpeed * Time.deltaTime;
        this.transform.Translate(0, 0, translation);
        this.transform.Rotate(0, rotation, 0);
    }
Example #15
    List <double> Run(double bx, double by, double bvx, double bvy, double px, double py, double pv, bool train)
    {
        List <double> inputs  = new List <double>();
        List <double> outputs = new List <double>();

        inputs.Add(bx);  //ball x pos
        inputs.Add(by);  //ball y pos
        inputs.Add(bvx); //ball velocity x
        inputs.Add(bvy); //ball velocity y
        inputs.Add(px);  //paddle x pos
        inputs.Add(py);  //paddle y pos
        outputs.Add(pv); //paddle y velocity
        if (train)
        {
            return(ann.Train(inputs, outputs));
        }
        else
        {
            return(ann.CalcOutput(inputs, outputs));
        }
    }
Example #16
File: Brain.cs, Project: dolevp/PongAI
    /**
     * Calculate and return a list of output from given parameters
     * @param bx - ball x position
     * @param by - ball y position
     * @param bvx - the ball's x velocity
     * @param bvy - the ball's y velocity
     * @param px - paddle's x position
     * @param py - paddle's y position
     * @param pv - the distance between the paddle and the expected hit (expected output)
     * @param train - whether we should train the ANN or not
     * @return a list of outputs that represent the paddle's new position relative to the current one
     */
    List <double> Run(double bx, double by, double bvx, double bvy, double px, double py, double pv, bool train)
    {
        List <double> inputs  = new List <double>();
        List <double> outputs = new List <double>();

        inputs.Add(bx);
        inputs.Add(by);
        inputs.Add(bvx);
        inputs.Add(bvy);
        inputs.Add(px);
        inputs.Add(py);
        outputs.Add(pv);
        if (train)
        {
            return(ann.Train(inputs, outputs));
        }
        else
        {
            return(ann.CalcOutput(inputs));
        }
    }
    List <double> Run(double ballX, double ballY, double ballVelX, double ballVelY, double paddleX, double paddleY, double paddleVel, bool train)
    {
        List <double> inputs  = new List <double>();
        List <double> outputs = new List <double>();

        inputs.Add(ballX);
        inputs.Add(ballY);
        inputs.Add(ballVelX);
        inputs.Add(ballVelY);
        inputs.Add(paddleX);
        inputs.Add(paddleY);
        outputs.Add(paddleVel);

        if (train)
        {
            return(ann.Train(inputs, outputs));
        }
        else
        {
            return(ann.CalcOutput(inputs, outputs));
        }
    }
Example #18
    List <double> Run(double ballX, double ballY, double ballVelocX, double ballVelocY, double paddleX, double paddleY, double paddleVelocity, bool train)
    {
        List <double> inputs  = new List <double>();
        List <double> outputs = new List <double>();

        inputs.Add(ballX);
        inputs.Add(ballY);
        inputs.Add(ballVelocX);
        inputs.Add(ballVelocY);
        inputs.Add(paddleX);
        inputs.Add(paddleY);
        outputs.Add(paddleVelocity);

        if (train)
        {
            return(_artificialNeuronNetwork.Train(inputs, outputs));
        }
        else
        {
            return(_artificialNeuronNetwork.CalcOutput(inputs, outputs));
        }
    }
    // METHOD - This method either does training or does calculations without affecting the training
    List <double> Run(double bx, double by, double bvx, double bvy, double px, double py, double pv, bool train)
    {
        // Six inputs and One Output
        List <double> inputs  = new List <double>();
        List <double> outputs = new List <double>();

        inputs.Add(bx);  // Ball x position
        inputs.Add(by);  // Ball y position
        inputs.Add(bvx); // Ball x velocity
        inputs.Add(bvy); // Ball y velocity
        inputs.Add(px);  // Paddle x position
        inputs.Add(py);  // Paddle y position
        outputs.Add(pv); // Paddle velocity, this is ignored when we are calculating and not training
        // If training is selected, then go ahead and perform the training
        if (train)
        {
            return(ann.Train(inputs, outputs));
        }
        else
        {
            // Otherwise, only calculate the output without affecting the training
            return(ann.CalcOutput(inputs, outputs));
        }
    }
Example #20
    void FixedUpdate()
    {
        frames++;
        // seeGround = true;
        // isOnGround = Physics2D.OverlapCircle(groundCheck.position, groundCheckRadius, whatIsGround);
        Debug.DrawRay(theEyes.transform.position, theEyes.transform.right * 20, Color.green);
        RaycastHit2D hit = Physics2D.Raycast(theEyes.transform.position, theEyes.transform.right * 20);

        if (hit && hit.collider.tag == "Killbox")
        {
            seeGround = false;
            Debug.DrawRay(theEyes.transform.position, theEyes.transform.right * 20, Color.red);
        }
        // double[] distancesFromObjects = new double[platforms.Length];
        // for(int i = 0; i < platforms.Length; i++)
        // {
        //  Vector3 heading = transform.position - platforms[i].transform.position;
        //  distancesFromObjects[i] = heading.magnitude;
        // }
        // // second closest, to be honest
        // System.Array.Sort(distancesFromObjects);
        // double closestPlatform = distancesFromObjects[1];
        // int indexOfClosest = distancesFromObjects.ToList().IndexOf(closestPlatform);
        // Vector3 closestPoint = platforms[indexOfClosest].transform.position;

        timer += Time.deltaTime;
        List <double> states = new List <double>();
        List <double> qs     = new List <double>();

        GameObject[] platforms    = GameObject.FindGameObjectsWithTag("platform");
        Vector3      bestPoint    = GetClosestEnemy(platforms);
        Vector3      closestPoint = GetClosestGap(platforms);

        Vector3 directionToNextPlatform = bestPoint - transform.position;
        Vector3 directionToNextGap      = closestPoint - transform.position;

        // states.Add(transform.position.y);
        // states.Add(rb.velocity.y);
        states.Add(directionToNextPlatform.x);
        // states.Add(directionToNextPlatform.y);
        states.Add(directionToNextGap.x);
        // Debug.Log(directionToNextGap.x);

        qs = SoftMax(ann.CalcOutput(states));
        double maxQ      = qs.Max();
        int    maxQIndex = qs.ToList().IndexOf(maxQ);

        exploreRate = Mathf.Clamp(exploreRate - exploreDecay, minExploreRate, maxExploreRate);

        if (Random.Range(0, 100) < exploreRate)
        {
            maxQIndex = Random.Range(0, 2);
        }

        if (maxQIndex == 1)
        {
            sumOfJumps++;
        }

        if (maxQIndex == 0)
        {
            sumOfStays++;
        }

        if (frames % 8 == 0)
        {
            if (sumOfJumps > sumOfStays)
            {
                robotAccess.RobotJump();
            }
            sumOfStays = 0;
            sumOfJumps = 0;
            frames     = 0;
        }

        if (rb.velocity.x < 0.5)
        {
            robotAccess.RobotJump();
        }

        if (hitObstacle)
        {
            reward = -5.0f;
        }
        else
        {
            reward = 0.1f;
        }



        Replay lastMemory = new Replay(
            // transform.position.y,
            // rb.velocity.y,
            directionToNextPlatform.x,
            // directionToNextPlatform.y,
            directionToNextGap.x,
            reward);

        if (replayMemory.Count > mCapacity)
        {
            replayMemory.RemoveAt(0);
        }

        replayMemory.Add(lastMemory);

        if (hitObstacle)
        {
            for (int i = replayMemory.Count - 1; i >= 0; i--)
            {
                List <double> tOutputsOld = new List <double>();
                List <double> tOutputsNew = new List <double>();
                tOutputsOld = SoftMax(ann.CalcOutput(replayMemory[i].states));

                double maxQOld = tOutputsOld.Max();
                int    action  = tOutputsOld.ToList().IndexOf(maxQOld);

                double feedback;
                if (i == replayMemory.Count - 1 || replayMemory[i].reward < 0)
                {
                    feedback = replayMemory[i].reward;
                }
                else
                {
                    tOutputsNew = SoftMax(ann.CalcOutput(replayMemory[i + 1].states));
                    maxQ        = tOutputsNew.Max();
                    feedback    = (replayMemory[i].reward + discount * maxQ);
                }

                tOutputsOld[action] = feedback;
                ann.Train(replayMemory[i].states, tOutputsOld);
            }
            if (timer > maxBalanceTime)
            {
                maxBalanceTime = timer;
            }

            timer = 0;

            hitObstacle = false;
            theGameManager.Reset();
            replayMemory.Clear();
            failCount++;
        }
    }
Example #21
    private void QLearning()
    {
        /// Count down the reset timer
        resetTimer--;

        qs.Clear();
        qs = SoftMax(ann.CalcOutput(states));

        float maxQ      = qs.Max();
        int   maxQIndex = qs.ToList().IndexOf(maxQ);

        /// Epsilon-greedy exploration: decay the explore rate
        exploreRate = Mathf.Clamp(exploreRate - exploreDecay, minExploreRate, maxExploreRate);
        if (Random.Range(0, 100) < exploreRate)
        {
            maxQIndex = Random.Range(0, 4);
        }

        /// Moving the vehicle using output values
        /// Move forward
        if (maxQIndex == 0)
        {
            rb.AddForce(this.transform.forward * Mathf.Clamp(qs[maxQIndex], 0f, 1f) * 300f);
        }

        if (maxQIndex == 1)
        {
            rb.AddForce(this.transform.forward * Mathf.Clamp(qs[maxQIndex], 0f, 1f) * -300f);
        }

        /// Turning
        if (maxQIndex == 2)
        {
            this.transform.Rotate(0, Mathf.Clamp(qs[maxQIndex], 0f, 1f) * 2f, 0);
        }

        if (maxQIndex == 3)
        {
            this.transform.Rotate(0, Mathf.Clamp(qs[maxQIndex], 0f, 1f) * -2f, 0);
        }

        RewardFunction();

        /// Setting replay memory
        Replay lastMemory = new Replay(states, reward);

        Rewards.Add(reward);

        if (replayMemory.Count > mCapacity)
        {
            replayMemory.RemoveAt(0);
        }

        replayMemory.Add(lastMemory);

        /// Training through replay memory
        if (collisionFail || resetTimer == 0 || win || backFail)
        {
            List <float> QOld = new List <float>();
            List <float> QNew = new List <float>();

            for (int i = replayMemory.Count - 1; i >= 0; i--)
            {
                List <float> toutputsOld = new List <float>();                                      // List of actions at time [t] (present)
                List <float> toutputsNew = new List <float>();                                      // List of actions at time [t + 1] (future)
                toutputsOld = SoftMax(ann.CalcOutput(replayMemory[i].states));                      // Action in time [t] is equal to NN output for [i] step states in replay memory

                float maxQOld = toutputsOld.Max();                                                  // maximum Q value at [i] step is equal to maximum Q value in the list of actions in time [t]
                int   action  = toutputsOld.ToList().IndexOf(maxQOld);                              // index of the action with the maximum Q value at time [t]
                QOld.Add(toutputsOld[action]);

                float feedback;
                if (i == replayMemory.Count - 1)                                                    // if it's the end of replay memory, or the sequence ended for any reason (here:
                {                                                                                   // collision fail, the timer running out, or reaching the light source), then the
                    feedback = replayMemory[i].reward;                                              // feedback (new reward) is just the reward stored at step [i], since there is no
                }                                                                                   // later step to feed Bellman's equation

                else
                {
                    toutputsNew = SoftMax(ann.CalcOutput(replayMemory[i + 1].states));              // otherwise the action at time [t + 1] is equal to NN output for [i + 1] step states
                    maxQ        = toutputsNew.Max();                                                // in replay memory and then feedback is equal to the Bellman's Equation
                    feedback    = (replayMemory[i].reward +
                                   discount * maxQ);
                }
                QNew.Add(feedback);

                if (save == true)
                {
                    SaveToFile(QOld, QNew, Rewards, "QValues");
                }

                float thisError = 0f;
                currentWeights = ann.PrintWeights();

                toutputsOld[action] = feedback;                                                     // the best action at time [t] is set to the computed feedback value,
                List <float> calcOutputs = ann.Train(replayMemory[i].states, toutputsOld);          // which is then used to train the NN on the [i]-th state
                for (int j = 0; j < calcOutputs.Count; j++)
                {
                    thisError += (Mathf.Pow((toutputsOld[j] - calcOutputs[j]), 2));
                }
                thisError = thisError / calcOutputs.Count;
                sse      += thisError;
            }
            sse /= replayMemory.Count;

            if (lastRewardSum < Rewards.Sum())
            {
                //ann.LoadWeights(currentWeights);
                ann.eta = Mathf.Clamp(ann.eta - 0.001f, 0.1f, 0.4f);
            }
            else
            {
                ann.eta       = Mathf.Clamp(ann.eta + 0.001f, 0.1f, 0.4f);
                lastRewardSum = Rewards.Sum();
            }

            replayMemory.Clear();
            ResetVehicle();
            Rewards.Clear();
        }
    }
Example #22
    private void FixedUpdate()
    {
        // If the brain wasn't initialized, return
        if (ann == null)
        {
            return;
        }
        if (deadInPopulation)
        {
            return;
        }

        timeAlive += Time.deltaTime;

        List <double> states = new List <double>();
        List <double> qs;

        GameObject currColumn = GameController.instance.GetCurrentColumn();

        if (!currColumn)
        {
            return;
        }

        // Get vertical and horizontal distance to the current column
        float yDist = currColumn.transform.position.y - transform.position.y;
        float xDist = currColumn.transform.position.x - transform.position.x;

        // Normalize and round
        float vertDist = (float)System.Math.Round((Map(-1.0f, 1.0f, -halfScreen, halfScreen, yDist)), 2);
        float horDist  = 1 - (float)System.Math.Round((Map(0.0f, 1.0f, 0.0f, maxDistanceToColumn, xDist)), 2);

        // Add values to the states
        states.Add(vertDist);
        states.Add(horDist);

        // Calc output for states
        qs = SoftMax(ann.CalcOutput(states));
        double maxQ      = qs.Max();
        int    maxQIndex = qs.ToList().IndexOf(maxQ);

        // Explore
        if (PopulationManager.instance.isExploring)
        {
            PopulationManager.instance.exploreRate = Mathf.Clamp(PopulationManager.instance.exploreRate - PopulationManager.instance.exploreDecay, PopulationManager.instance.minExploreRate, PopulationManager.instance.maxExploreRate);
            if (Random.Range(0, 100) < PopulationManager.instance.exploreRate)
            {
                maxQIndex = Random.Range(0, 2);
            }
        }

        // Action
        if (maxQIndex == 0)
        {
            Flap();
        }

        // Reward
        if (isDead)
        {
            reward = -1f;
        }
        else
        {
            reward = 0.1f;
        }

        // Add a new memory
        Replay lastMemory = new Replay(vertDist, horDist, reward);

        replayMemory.Add(lastMemory);

        if (replayMemory.Count > maxMemoryCapacity)
        {
            replayMemory.RemoveAt(0);
        }


        if (isDead)
        {
            if (PopulationManager.instance.qLearning)
            {
                TrainFromMemories();
            }

            if (timeAlive > maxFlightTime)
            {
                maxFlightTime = timeAlive;
            }

            if (score > maxScore)
            {
                maxScore = score;
            }

            replayMemory.Clear();

            // After training with memories
            if (!deadInPopulation)
            {
                deadInPopulation = true;
                PopulationManager.instance.BirdDied();
            }
        }
    }
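Flap() is not shown in this example. In a Flappy Bird-style agent it is usually a single vertical impulse; a hypothetical stand-in, assuming a Rigidbody2D and an invented flap speed:

    void Flap()
    {
        // Replace the vertical velocity rather than adding force, so every
        // flap gives the same hop regardless of the current fall speed.
        GetComponent <Rigidbody2D>().velocity = Vector2.up * 5f;   // 5f is an assumed value
    }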
Example #23
    void FixedUpdate()
    {
        timer += Time.deltaTime;
        List <double> states = new List <double>();
        List <double> qs     = new List <double>();

        states.Add(this.transform.rotation.x);
        states.Add(this.transform.position.z);
        states.Add(ball.GetComponent <Rigidbody>().angularVelocity.x); // reflection not working in VCS

        qs = SoftMax(ann.CalcOutput(states));                          // why a softmax?
        double maxQ      = qs.Max();                                   // cost: O(L), where L is length of the list
        int    maxQIndex = qs.ToList().IndexOf(maxQ);                  // cost is O(L)

        // in my opinion, exploreRate should decrease after each fail and not after each fixedUpdate
        exploreRate = Mathf.Clamp(exploreRate - exploreDecay, minExploreRate, maxExploreRate);

        // Udemy: removing these lines will accelerate convergence
        // (more exploration early on, and less later on)
        if (Random.Range(0, 100) < exploreRate)
        {
            maxQIndex = Random.Range(0, 2);              // choose either 0 or 1
        }
        if (maxQIndex == 0)
        {
            // public void Rotate(Vector3 eulerAngles, Space relativeTo = Space.Self);
            // public void Rotate(Vector3 axis, float angle, Space relativeTo = Space.Self);
            this.transform.Rotate(Vector3.right, tiltSpeed * (float)qs[maxQIndex]);
        }
        else if (maxQIndex == 1)
        {
            this.transform.Rotate(Vector3.right, -tiltSpeed * (float)qs[maxQIndex]);
        }

        if (ball.GetComponent <BallState>().dropped)
        {
            reward = -1.0f;
        }
        else
        {
            reward = 0.1f;               // [0.1f]
        }
        Replay lastMemory = new Replay(this.transform.rotation.x,
                                       ball.transform.position.z,
                                       ball.GetComponent <Rigidbody>().angularVelocity.x,
                                       reward);

        if (replayMemory.Count > mCapacity)
        {
            replayMemory.RemoveAt(0);
        }

        replayMemory.Add(lastMemory);

        if (ball.GetComponent <BallState>().dropped)
        {
            for (int i = replayMemory.Count - 1; i >= 0; i--)
            {
                List <double> toutputsOld = new List <double>();
                List <double> toutputsNew = new List <double>();
                toutputsOld = SoftMax(ann.CalcOutput(replayMemory[i].states));                  // why a softmax?

                double maxQOld = toutputsOld.Max();
                int    action  = toutputsOld.ToList().IndexOf(maxQOld);

                double feedback;
                if (i == replayMemory.Count - 1 || replayMemory[i].reward == -1)
                {
                    feedback = replayMemory[i].reward;
                }
                else
                {
                    toutputsNew = SoftMax(ann.CalcOutput(replayMemory[i + 1].states));
                    maxQ        = toutputsNew.Max();
                    feedback    = (replayMemory[i].reward +
                                   discount * maxQ);
                }

                toutputsOld[action] = feedback;
                ann.Train(replayMemory[i].states, toutputsOld);
            }

            if (timer > maxBalanceTime)
            {
                maxBalanceTime = timer;
            }

            timer = 0;

            ball.GetComponent <BallState>().dropped = false;
            this.transform.rotation = Quaternion.identity;
            ResetBall();
            replayMemory.Clear();
            failCount++;
        }
    }
Example #24
    private void ActivateDriver()
    {
        if (!trainingDone)
        {
            return;
        }

        List <double> calcOutputs = new List <double>();
        List <double> inputs      = new List <double>();
        List <double> outputs     = new List <double>();

        Debug.DrawRay(transform.position, transform.forward * visibleDistance, Color.red);
        Debug.DrawRay(transform.position, transform.right * visibleDistance, Color.blue);

        // Raycasts
        RaycastHit hit;
        float      fDist   = 0;
        float      rDist   = 0;
        float      lDist   = 0;
        float      r45Dist = 0;
        float      l45Dist = 0;

        // Forward
        if (Physics.Raycast(transform.position, transform.forward, out hit, visibleDistance))
        {
            fDist = 1 - Round(hit.distance / visibleDistance);
        }

        // Right
        if (Physics.Raycast(transform.position, transform.right, out hit, visibleDistance))
        {
            rDist = 1 - Round(hit.distance / visibleDistance);
        }

        // Left
        if (Physics.Raycast(transform.position, -transform.right, out hit, visibleDistance))
        {
            lDist = 1 - Round(hit.distance / visibleDistance);
        }

        // Right 45
        if (Physics.Raycast(transform.position, Quaternion.AngleAxis(-45, Vector3.up) * transform.right, out hit, visibleDistance))
        {
            r45Dist = 1 - Round(hit.distance / visibleDistance);
        }

        // Left 45
        if (Physics.Raycast(transform.position, Quaternion.AngleAxis(45, Vector3.up) * -transform.right, out hit, visibleDistance))
        {
            l45Dist = 1 - Round(hit.distance / visibleDistance);
        }

        inputs.Add(fDist);
        inputs.Add(rDist);
        inputs.Add(lDist);
        inputs.Add(r45Dist);
        inputs.Add(l45Dist);

        outputs.Add(0);
        outputs.Add(0);

        calcOutputs = ann.CalcOutput(inputs, outputs);
        float translationInput = Map(-1, 1, 0, 1, (float)calcOutputs[0]);
        float rotationInput    = Map(-1, 1, 0, 1, (float)calcOutputs[1]);

        translation = translationInput * speed * Time.deltaTime;
        rotation    = rotationInput * rotationSpeed * Time.deltaTime;

        transform.Translate(0, 0, translation);
        transform.Rotate(0, rotation, 0);
    }
Example #25
//------------------------------------------------------------------------------------------
//------------------------------------------------------------------------------------------

    void FixedUpdate()
    {
        timer += Time.deltaTime;
        List <double> states = new List <double>();
        List <double> qs     = new List <double>();

        states.Add(this.transform.rotation.z);
        states.Add(pole.transform.position.x);
        // states.Add(pole.GetComponent<Rigidbody>().angularVelocity.x);
        states.Add(pole.GetComponent <Rigidbody2D>().angularVelocity);

        qs = SoftMax(nn.CalcOutput(states));
        double maxQ      = qs.Max();
        int    maxQIndex = qs.ToList().IndexOf(maxQ);

        exploreRate = Mathf.Clamp(exploreRate - exploreDecay, minExploreRate, maxExploreRate);

        //Exploration taking a random action
        // if(Random.Range(0, 100) < exploreRate) maxQIndex = Random.Range(0, 2);

        /* ===================================================== */
        //WARNING: modify this block for each environment
        /* ===================================================== */
        if (maxQIndex == 0)
        {
            this.transform.position = new Vector3(this.transform.position.x * maxSpeed * (float)qs[maxQIndex], this.transform.position.y, this.transform.position.z);
        }
        else if (maxQIndex == 1)
        {
            this.transform.position = new Vector3(-this.transform.position.x * maxSpeed * (float)qs[maxQIndex], this.transform.position.y, this.transform.position.z);
        }
        // if(maxQIndex == 0) this.transform.Translate(Vector3.right * maxSpeed * (float)qs[maxQIndex]);
        // else if(maxQIndex == 1) this.transform.Translate(Vector3.right * -maxSpeed * (float)qs[maxQIndex]);

        // if(maxQIndex == 0) this.transform.Rotate(Vector3.right, tiltSpeed * (float)qs[maxQIndex]);
        // else if(maxQIndex == 1) this.transform.Rotate(Vector3.right, -tiltSpeed * (float)qs[maxQIndex]);

        if (pole.GetComponent <PoleState>().dropped)
        {
            reward = -1.0f;
        }
        else
        {
            reward = 0.1f;
        }

        /* ===================================================== */
        //WARNING: modify this block for each environment
        /* ===================================================== */
        Replay lastMemory = new Replay(this.transform.position.x,
                                       pole.transform.position.x,
                                       pole.GetComponent <Rigidbody2D>().angularVelocity,
                                       reward);

        if (replayMemory.Count > mCapacity)
        {
            replayMemory.RemoveAt(0);
        }

        replayMemory.Add(lastMemory);

        //Q-Learning part
        if (pole.GetComponent <PoleState>().dropped)
        {
            //Loop backwards
            for (int i = replayMemory.Count - 1; i >= 0; i--)
            {
                List <double> toutputsOld = new List <double>();
                List <double> toutputsNew = new List <double>();
                toutputsOld = SoftMax(nn.CalcOutput(replayMemory[i].states));

                double maxQOld = toutputsOld.Max();
                int    action  = toutputsOld.ToList().IndexOf(maxQOld);

                //Bellman's Equation
                double feedback;
                if (i == replayMemory.Count - 1 || replayMemory[i].reward == -1)
                {
                    feedback = replayMemory[i].reward;
                }
                else
                {
                    toutputsNew = SoftMax(nn.CalcOutput(replayMemory[i + 1].states));
                    maxQ        = toutputsNew.Max();
                    feedback    = (replayMemory[i].reward + discount * maxQ);
                }

                toutputsOld[action] = feedback;
                nn.Train(replayMemory[i].states, toutputsOld);
            }
            if (timer > maxBalanceTime)
            {
                maxBalanceTime = timer;
            }

            timer = 0;

            /* ===================================================== */
            //WARNING: modify this block for each environment
            /* ===================================================== */
            pole.GetComponent <PoleState>().dropped = false;
            this.transform.position = Vector2.zero;
            ResetPole();
            replayMemory.Clear();
            failCount++;
        }
    }
Example #26
    void FixedUpdate()
    {
        timer += Time.deltaTime;
        List <double> states = new List <double>();
        List <double> qs     = new List <double>();

        RaycastHit hit;

        float fDist = visibleDistance, rDist = visibleDistance, lDist = visibleDistance, r45Dist = visibleDistance, l45Dist = visibleDistance;



        if (Physics.Raycast(transform.position, this.transform.forward, out hit, visibleDistance, terrainLayer))
        {
            fDist = Vector3.Distance(transform.position, hit.point);
        }

        if (Physics.Raycast(transform.position, this.transform.right, out hit, visibleDistance, terrainLayer))
        {
            rDist = Vector3.Distance(transform.position, hit.point);
        }

        if (Physics.Raycast(transform.position, -this.transform.right, out hit, visibleDistance, terrainLayer))
        {
            lDist = Vector3.Distance(transform.position, hit.point);
        }

        if (Physics.Raycast(transform.position, Quaternion.AngleAxis(-45, Vector3.up) * this.transform.right, out hit, visibleDistance, terrainLayer))
        {
            r45Dist = Vector3.Distance(transform.position, hit.point);
        }

        if (Physics.Raycast(transform.position, Quaternion.AngleAxis(45, Vector3.up) * -this.transform.right, out hit, visibleDistance, terrainLayer))
        {
            l45Dist = hit.distance;
        }

        // Debug.Log("Frontal: " + fDist + ", Derecha: " + rDist + ", Izquierda: " + lDist + ", Derecha45: " + r45Dist + ", Izquierda45: " + l45Dist);

        states.Add(fDist);
        states.Add(rDist);
        states.Add(lDist);
        states.Add(r45Dist);
        states.Add(l45Dist);

        qs = SoftMax(ann.CalcOutput(states));
        double maxQ      = qs.Max();
        int    maxQIndex = qs.ToList().IndexOf(maxQ);

        //exploreRate = Mathf.Clamp(exploreRate - exploreDecay, minExploreRate, maxExploreRate);


        //if(Random.Range(0,100) < exploreRate)
        //	maxQIndex = Random.Range(0,2);


        float translation = speed * Time.deltaTime;

        this.transform.Translate(0, 0, translation);



        if (maxQIndex == 0)
        {
            this.transform.Rotate(Vector3.up, tiltSpeed * (float)qs[maxQIndex]);
        }
        else if (maxQIndex == 1)
        {
            this.transform.Rotate(Vector3.up, -tiltSpeed * (float)qs[maxQIndex]);
        }


        if (ball.GetComponent <BallState>().dropped)
        {
            reward = -1.0f;
            //reward = 0;
        }

        else if (ball.GetComponent <BallState>().point)
        {
            reward = 0.5f;
        }
        else
        {
            reward = 0.1f;// + 0.01f;
        }
        Replay lastMemory = new Replay(fDist, rDist, lDist, r45Dist, l45Dist,
                                       reward);

        if (replayMemory.Count > mCapacity)
        {
            replayMemory.RemoveAt(0);
        }

        replayMemory.Add(lastMemory);

        if (ball.GetComponent <BallState>().dropped)
        {
            ResetBall();    // So it doesn't get stuck when there is no weights file.

            for (int i = replayMemory.Count - 1; i >= 0; i--)
            {
                List <double> toutputsOld = new List <double>();
                List <double> toutputsNew = new List <double>();
                toutputsOld = SoftMax(ann.CalcOutput(replayMemory[i].states));

                double maxQOld = toutputsOld.Max();
                int    action  = toutputsOld.ToList().IndexOf(maxQOld);

                double feedback;
                if (i == replayMemory.Count - 1 || replayMemory[i].reward == -1)
                {
                    feedback = replayMemory[i].reward;
                }
                else
                {
                    toutputsNew = SoftMax(ann.CalcOutput(replayMemory[i + 1].states));
                    maxQ        = toutputsNew.Max();
                    feedback    = (replayMemory[i].reward +
                                   discount * maxQ);
                }

                toutputsOld[action] = feedback;
                ann.Train(replayMemory[i].states, toutputsOld);
            }



            timer = 0;

            ball.GetComponent <BallState>().dropped = false;
            this.transform.rotation = Quaternion.identity;
            ResetBall();
            replayMemory.Clear();
            failCount++;
            if (_isAleatoryCircuit)
            {
                if (failCount == 1000)
                {
                    flowchart.ExecuteBlock("LOSE");
                }
                //Debug.Log( "Fails: " + failCount);
                onFail?.Invoke(failCount);
            }

            //reward = 0;/////////////////////////////////
        }

        if (ball.GetComponent <BallState>().meta)
        {
            ball.GetComponent <BallState>().meta = false;
            //string pesos = PlayerPrefs.GetString("Weights");
            string pesos = ann.PrintWeights();



            if (_isAleatoryCircuit)
            {
                if (!WIN && failCount <= 1000)
                {
                    if (flowchart != null)
                    {
                        flowchart.ExecuteBlock("WIN");
                    }
                    WIN = true;
                }

                /*List<string> saveFileContent = new List<string>();
                 * saveFileContent.Add(currentAleatoryCircuitName);
                 * saveFileContent.Add(pesos);
                 * SaveAndLoad.Save(saveFileContent, currentAleatoryCircuitName + ".txt");*/

                manager.SaveNewDataDictionary(currentAleatoryCircuitName, pesos);
                Debug.Log(currentAleatoryCircuitName);
            }
            else
            {
                managerCircuits.SaveNewDataDictionary(circuitName, pesos);
            }

            /*
             * if (!_isAleatoryCircuit && maxBalanceTime <= 0)
             * {
             *  pesos = ann.PrintWeights();
             *  SaveAndLoad.Save(pesos, CIRCUITO1);
             * }
             * else if(!_isAleatoryCircuit && maxBalanceTime > timer)
             * {
             *  pesos = ann.PrintWeights();
             *  SaveAndLoad.Save(pesos, CIRCUITO1);
             * }*/
            maxBalanceTime = timer;
            Debug.Log(maxBalanceTime);
            timer = 0;
        }
    }
Example #27
    private void FixedUpdate()
    {
        timer += Time.deltaTime;
        List <double> states = new List <double>();
        List <double> qs     = new List <double>();

        states.Add(this.transform.rotation.x);
        states.Add(ball.transform.position.z);
        states.Add(ball.GetComponent <Rigidbody>().angularVelocity.x);

        qs = SoftMax(ann.CalcOutput(states));
        double maxQ      = qs.Max();
        int    maxQindex = qs.ToList().IndexOf(maxQ);

        explorerate = Mathf.Clamp(explorerate - exploredecay, minexplorerate, maxeplorerate);

        /* No need for exploration in this case, as the environment is really very small:
         * if(Random.Range(0,100)<explorerate)
         * {
         *   maxQindex = Random.Range(0, 2);
         * }*/
        if (maxQindex == 0)
        {
            this.transform.Rotate(Vector3.right, tiltspeed * (float)qs[maxQindex]);
        }
        else if (maxQindex == 1)
        {
            this.transform.Rotate(Vector3.right, -tiltspeed * (float)qs[maxQindex]);
        }
        if (ball.GetComponent <BallState>().dropped)
        {
            reward = -1;
        }
        else
        {
            reward = 0.1f;
        }
        Replay lastmemory = new Replay(this.transform.rotation.x, ball.transform.position.z, ball.GetComponent <Rigidbody>().angularVelocity.x, reward);

        if (replaymemory.Count > mcapacity)
        {
            replaymemory.RemoveAt(0);
        }
        replaymemory.Add(lastmemory);

        //Training and Q-learning

        if (ball.GetComponent <BallState>().dropped)
        {
            for (int i = replaymemory.Count - 1; i >= 0; i--)
            {
                List <double> toutputs_old = new List <double>();
                List <double> toutputs_new = new List <double>();
                toutputs_old = SoftMax(ann.CalcOutput(replaymemory[i].states));
                double maxQ_old = toutputs_old.Max();
                int    action   = toutputs_old.ToList().IndexOf(maxQ_old);

                double feedback;
                if (i == replaymemory.Count - 1 || replaymemory[i].reward == -1)
                {
                    feedback = replaymemory[i].reward;
                }
                else
                {
                    toutputs_new = SoftMax(ann.CalcOutput(replaymemory[i + 1].states));
                    maxQ         = toutputs_new.ToList().Max();
                    feedback     = (replaymemory[i].reward + discount * maxQ); //BELLMAN EQUATION
                }
                toutputs_old[action] = feedback;
                ann.Train(replaymemory[i].states, toutputs_old);
            }
            if (timer > maxBalanceTime)
            {
                maxBalanceTime = timer;
            }
            timer = 0;
            ball.GetComponent <BallState>().dropped = false;
            this.transform.rotation = Quaternion.identity;
            Reset();
            replaymemory.Clear();
            fallcount++;
        }
    }
Example #28
    void Update()
    {
        if (!trainingDone)
        {
            return;
        }

        List <double> calcOutputs = new List <double>();
        List <double> inputs      = new List <double>();
        List <double> outputs     = new List <double>();

        //raycasts
        //RaycastHit hit;
        float fDist = 0, rDist = 0, lDist = 0, r45Dist = 0, l45Dist = 0;

        Utils.PerformRayCasts(out fDist, out rDist, out lDist, out r45Dist, out l45Dist, this.transform, visibleDistance);

        //forward

        /*if (Physics.Raycast(transform.position, this.transform.forward, out hit, visibleDistance))
         * {
         *  fDist = 1-Round(hit.distance/visibleDistance);
         * }
         *
         * //right
         * if (Physics.Raycast(transform.position, this.transform.right, out hit, visibleDistance))
         * {
         *  rDist = 1-Round(hit.distance/visibleDistance);
         * }
         *
         * //left
         * if (Physics.Raycast(transform.position, -this.transform.right, out hit, visibleDistance))
         * {
         *  lDist = 1-Round(hit.distance/visibleDistance);
         * }
         *
         * //right 45
         * if (Physics.Raycast(transform.position,
         *                  Quaternion.AngleAxis(-45, Vector3.up) * this.transform.right, out hit, visibleDistance))
         * {
         *  r45Dist = 1-Round(hit.distance/visibleDistance);
         * }
         *
         * //left 45
         * if (Physics.Raycast(transform.position,
         *                  Quaternion.AngleAxis(45, Vector3.up) * -this.transform.right, out hit, visibleDistance))
         * {
         *  l45Dist = 1-Round(hit.distance/visibleDistance);
         * }*/

        inputs.Add(fDist);
        inputs.Add(rDist);
        inputs.Add(lDist);
        inputs.Add(r45Dist);
        inputs.Add(l45Dist);
        outputs.Add(0);
        outputs.Add(0);
        calcOutputs = ann.CalcOutput(inputs, outputs);
        float translationInput = Map(-1, 1, 0, 1, (float)calcOutputs[0]);
        float rotationInput    = Map(-1, 1, 0, 1, (float)calcOutputs[1]);

        translation = translationInput * speed * Time.deltaTime;
        rotation    = rotationInput * rotationSpeed * Time.deltaTime;
        this.transform.Translate(0, 0, translation);
        this.transform.Rotate(0, rotation, 0);
    }
Example #29
    private void FixedUpdate()
    {
        timer += Time.deltaTime;
        List <double> states = new List <double>();
        List <double> qs     = new List <double>();

        states.Add(this.transform.rotation.x);
        states.Add(ball.transform.position.z);
        states.Add(ball.GetComponent <Rigidbody>().angularVelocity.x);

        qs = SoftMax(ann.CalcOutput(states));
        double maxQ      = qs.Max();
        int    maxQIndex = qs.ToList().IndexOf(maxQ);

        exploreRate = Mathf.Clamp(exploreRate - exploreDecay, minExploreRate, maxExploreRate);

        if (Random.Range(0, 10000) < exploreRate)
        {
            maxQIndex = Random.Range(0, 2);
        }

        if (maxQIndex == 0)
        {
            this.transform.Rotate(Vector3.right, tiltSpeed * (float)qs[maxQIndex]);
        }
        else if (maxQIndex == 1)
        {
            this.transform.Rotate(Vector3.right, -tiltSpeed * (float)qs[maxQIndex]);
        }

        if (ball.GetComponent <BallState>().dropped)
        {
            reward = -5.0f;
        }
        else
        {
            reward = 0.1f;
        }

        Replay lastMemory = new Replay(this.transform.rotation.x,
                                       ball.transform.position.z,
                                       ball.GetComponent <Rigidbody>().angularVelocity.x,
                                       reward);

        if (replayMemory.Count > mCapacity)
        {
            replayMemory.RemoveAt(0);
        }

        replayMemory.Add(lastMemory);

        if (ball.GetComponent <BallState>().dropped)
        {
            for (int i = replayMemory.Count - 1; i >= 0; i--)
            {
                List <double> toutputOld = new List <double>();
                List <double> toutputNew = new List <double>();
                toutputOld = SoftMax(ann.CalcOutput(replayMemory[i].states));

                double maxQOld = toutputOld.Max();
                int    action  = toutputOld.ToList().IndexOf(maxQOld);

                double feedback;
                if (i == replayMemory.Count - 1 || replayMemory[i].reward < 0)
                {
                    feedback = replayMemory[i].reward;
                }
                else
                {
                    toutputNew = SoftMax(ann.CalcOutput(replayMemory[i + 1].states));
                    maxQ       = toutputNew.Max();
                    feedback   = (replayMemory[i].reward + discount * maxQ);
                }

                toutputOld[action] = feedback;
                ann.Train(replayMemory[i].states, toutputOld);
            }

            if (timer > maxBalanceTime)
            {
                maxBalanceTime = timer;
            }

            timer = 0;

            ball.GetComponent <BallState>().dropped = false;
            this.transform.rotation = Quaternion.identity;
            ResetBall();
            replayMemory.Clear();
            failCount++;
        }
    }
Example #30
    void FixedUpdate()
    {
        timer += Time.deltaTime;
        List <double> states = new List <double>();
        List <double> qs     = new List <double>();

        states.Add(Vector3.Distance(this.transform.position, topBeam.transform.position));
        states.Add(Vector3.Distance(this.transform.position, bottomBeam.transform.position));

        qs = SoftMax(ann.CalcOutput(states));
        double maxQ      = qs.Max();
        int    maxQIndex = qs.ToList().IndexOf(maxQ);

        exploreRate = Mathf.Clamp(exploreRate - exploreDecay, minExploreRate, maxExploreRate);

        //if(Random.Range(0,100) < exploreRate)
        //	maxQIndex = Random.Range(0,2);

        if (maxQIndex == 0)
        {
            rb.AddForce(Vector3.up * moveForce * (float)qs[maxQIndex]);
        }
        else if (maxQIndex == 1)
        {
            rb.AddForce(Vector3.up * -moveForce * (float)qs[maxQIndex]);
        }

        if (crashed)
        {
            reward = -1.0f;
        }
        else
        {
            reward = 0.1f;
        }

        Replay lastMemory = new Replay(Vector3.Distance(this.transform.position, topBeam.transform.position),
                                       Vector3.Distance(this.transform.position, bottomBeam.transform.position),
                                       reward);

        if (replayMemory.Count > mCapacity)
        {
            replayMemory.RemoveAt(0);
        }

        replayMemory.Add(lastMemory);

        if (crashed)
        {
            for (int i = replayMemory.Count - 1; i >= 0; i--)
            {
                List <double> toutputsOld = new List <double>();
                List <double> toutputsNew = new List <double>();
                toutputsOld = SoftMax(ann.CalcOutput(replayMemory[i].states));

                double maxQOld = toutputsOld.Max();
                int    action  = toutputsOld.ToList().IndexOf(maxQOld);

                double feedback;
                if (i == replayMemory.Count - 1 || replayMemory[i].reward == -1)
                {
                    feedback = replayMemory[i].reward;
                }
                else
                {
                    toutputsNew = SoftMax(ann.CalcOutput(replayMemory[i + 1].states));
                    maxQ        = toutputsNew.Max();
                    feedback    = (replayMemory[i].reward +
                                   discount * maxQ);
                }

                toutputsOld[action] = feedback;
                ann.Train(replayMemory[i].states, toutputsOld);
            }

            if (timer > maxBalanceTime)
            {
                maxBalanceTime = timer;
            }

            timer = 0;

            crashed = false;
            ResetBird();
            replayMemory.Clear();
            failCount++;
        }
    }
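Taken together, every snippet in this listing assumes an ANN class with roughly the surface below. This interface is reconstructed from the call sites only, not from the actual implementation (and the vehicle Q-learning example in Example #21 apparently uses a float-based variant of the same API):

    using System.Collections.Generic;

    public interface IAnnSurface
    {
        // Forward pass; several projects pass a desiredOutputs placeholder
        // list even though it does not affect the result.
        List <double> CalcOutput(List <double> inputs);
        List <double> CalcOutput(List <double> inputs, List <double> desiredOutputs);

        // One backpropagation step toward desiredOutputs; returns the
        // calculated outputs so callers can measure the remaining error.
        List <double> Train(List <double> inputs, List <double> desiredOutputs);

        // Weight persistence, used by the circuit-saving example.
        string PrintWeights();
        void LoadWeights(string weights);

        // Learning rate, clamped at runtime in the vehicle example.
        double eta { get; set; }
    }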