Ejemplo n.º 1
0
    /// <summary>
    /// Runs the network once on a single observation.
    /// The six values bx, by, bvx, bvy, px, py become the input vector, in that order.
    /// pv is the single desired output.
    /// NOTE(review): the names suggest ball position/velocity and paddle position/target; confirm with caller.
    /// </summary>
    /// <param name="train">When true, calls <c>ann.Train</c>; otherwise calls <c>ann.CalculateOutput</c>.</param>
    /// <returns>Whatever the selected <c>ann</c> call returns.</returns>
    List <double> Run(double bx,
                      double by,
                      double bvx,
                      double bvy,
                      double px,
                      double py,
                      double pv,
                      bool train)
    {
        var features = new List <double> { bx, by, bvx, bvy, px, py };
        var target   = new List <double> { pv };

        if (!train)
        {
            return ann.CalculateOutput(features, target);
        }

        return ann.Train(features, target);
    }
Ejemplo n.º 2
0
        /// <summary>
        /// Entry point. Steps performed:
        /// 1. Optimises the network parameters with an elimination genetic algorithm.
        /// 2. Classifies every sample in the dataset with the best chromosome found.
        /// 3. Prints each thresholded prediction next to its expected label.
        /// 4. Reports how many samples matched exactly.
        /// 5. Writes the optimised parameters to <c>ParametersFilePath</c>.
        /// </summary>
        /// <param name="args">Command-line arguments (unused).</param>
        public static void Main(string[] args)
        {
            var dataset = new Dataset.Dataset(DatasetFilePath);
            var ann     = new ANN(Architecture);

            var fitness = new FitnessFunction(ann, dataset);

            var mutation   = new GaussMutation(MutationThreshold, MutationProbability, Sigma1, Sigma2);
            var selection  = new KTournamentSelection <DoubleArrayChromosome>(TournamentSize);
            var crossovers = new List <ICrossover <DoubleArrayChromosome> >()
            {
                new ArithmeticCrossover(),
                new HeuristicCrossover(),
                new UniformCrossover()
            };

            var geneticAlgorithm =
                new EliminationGeneticAlgorithm(mutation, selection, crossovers, fitness, IterationLimit, ErrorLimit, PopulationSize);

            var optimum = geneticAlgorithm.FindOptimum();
            var correctClassification = 0;

            foreach (var sample in dataset)
            {
                var classification = ann.CalculateOutput(sample.Input, optimum.Values);
                var correct        = true;

                for (int i = 0; i < classification.Length; i++)
                {
                    // Threshold each raw output to a hard 0/1 label (in place, so the print below shows labels).
                    classification[i] = classification[i] < 0.5 ? 0 : 1;

                    // Note: 10e-9 is 1e-8; any small tolerance works because both sides are compared after thresholding.
                    if (Math.Abs(classification[i] - sample.Classification[i]) > 10e-9)
                    {
                        correct = false;
                    }
                }

                // Print every output instead of a hard-coded three.
                // A narrower output layer no longer throws, and a wider one is no longer truncated.
                Console.WriteLine(string.Join(" ", classification) +
                                  " <=> " + string.Join(" ", sample.Classification) + " ");

                if (correct)
                {
                    correctClassification++;
                }
            }

            Console.WriteLine("Correct => " + correctClassification + ", Total => " + dataset.Count());

            ann.WriteNeuronLayerParametersToFile(ParametersFilePath, 1, optimum.Values);
        }
Ejemplo n.º 3
0
    /// <summary>
    /// One physics step of the Q-learning balancing loop. Steps performed:
    /// 1. Reads the current state and picks an action. The action is the arg-max of the
    ///    network's softmax output, or a random one depending on <c>exploreRate</c>.
    /// 2. Tilts the platform accordingly.
    /// 3. Stores the step in <c>replayMemory</c>.
    /// 4. If the ball has dropped, replays the episode backwards to train the network,
    ///    then resets the episode.
    /// </summary>
    private void FixedUpdate()
    {
        timer += Time.deltaTime;

        // Cache component lookups that were previously repeated several times per step.
        Rigidbody ballBody  = ball.GetComponent <Rigidbody>();
        BallState ballState = ball.GetComponent <BallState>();

        // State vector. It must contain the same features, in the same order, as the Replay
        // stored below. Otherwise the network is trained on different inputs than it acts on,
        // which is why this reads the ball's position.z rather than the platform's.
        // transform.rotation is a Quaternion, so .x/.z are quaternion components, not Euler angles.
        List <double> states = new List <double>
        {
            this.transform.rotation.x,
            this.transform.rotation.z,
            ball.transform.position.z,
            ballBody.angularVelocity.x,
            ballBody.angularVelocity.z
        };

        List <double> qs     = ANN.SoftMax(ann.CalculateOutput(states));
        double        maxQ   = qs.Max();
        int           action = qs.IndexOf(maxQ);

        exploreRate = Mathf.Clamp(exploreRate - exploreDecay, minExploreRate, maxExploreRate);

        // Integer Random.Range excludes its max, so the roll is 1..99.
        // exploreRate is presumably a percentage.
        if (UnityEngine.Random.Range(1, 100) < exploreRate)
        {
            // Random action in 0..3.
            action = UnityEngine.Random.Range(0, 4);
        }

        // 0: tilt right, 1: tilt left, 2: tilt forward, 3: tilt backward.
        // The tilt amount is scaled by the chosen action's softmax output.
        switch (action)
        {
            case 0:
                this.transform.Rotate(Vector3.right, tiltSpeed * (float)qs[action]);
                break;
            case 1:
                this.transform.Rotate(Vector3.right, -tiltSpeed * (float)qs[action]);
                break;
            case 2:
                this.transform.Rotate(Vector3.forward, tiltSpeed * (float)qs[action]);
                break;
            case 3:
                this.transform.Rotate(Vector3.forward, -tiltSpeed * (float)qs[action]);
                break;
        }

        reward = ballState.dropped ? -1f : 0.1f;

        // Rotation is read again here, so the memory stores the post-action platform rotation.
        Replay lastMemory = new Replay(this.transform.rotation.x,
                                       this.transform.rotation.z,
                                       ball.transform.position.z,
                                       ballBody.angularVelocity.x,
                                       ballBody.angularVelocity.z,
                                       reward);

        // Hold at most mCapacity memories by evicting the oldest before adding when full.
        // A plain '> mCapacity' check let the list grow to mCapacity + 1.
        if (replayMemory.Count > 0 && replayMemory.Count >= mCapacity)
        {
            replayMemory.RemoveAt(0);
        }

        replayMemory.Add(lastMemory);

        // Q-learning: once the episode ends (ball dropped), train on the stored memories.
        if (ballState.dropped)
        {
            // Walk backwards so the terminal penalty propagates to earlier steps
            // through each successor's discounted max-Q.
            for (int i = replayMemory.Count - 1; i >= 0; --i)
            {
                List <double> currentQs = ANN.SoftMax(ann.CalculateOutput(replayMemory[i].states));

                // NOTE(review): Replay does not record the action actually taken.
                // The current network's greedy action is used as a stand-in.
                double currentMaxQ  = currentQs.Max();
                int    replayAction = currentQs.IndexOf(currentMaxQ);

                double feedback;
                // Terminal memory (the last one, or a drop with reward -1): there is no
                // successor to bootstrap from, so the target is the raw reward.
                if (i == replayMemory.Count - 1 || replayMemory[i].reward == -1f)
                {
                    feedback = replayMemory[i].reward;
                }
                else
                {
                    List <double> nextQs = ANN.SoftMax(ann.CalculateOutput(replayMemory[i + 1].states));
                    feedback = replayMemory[i].reward + discount * nextQs.Max();
                }

                // Replace the chosen action's value with the target and train toward it.
                currentQs[replayAction] = feedback;
                ann.Train(replayMemory[i].states, currentQs);
            }

            if (timer > maxBalanceTime)
            {
                maxBalanceTime = timer;
            }

            timer = 0;

            ballState.dropped = false;
            this.transform.rotation = Quaternion.identity;
            ResetBall();
            replayMemory.Clear();
            failCount++;
        }
    }