Ejemplo n.º 1
0
        /// <summary>
        /// Interactive console demo: creates and initializes the network, trains it
        /// on the four-row XOR truth table for a user-supplied number of iterations,
        /// then repeatedly feeds user-entered inputs through the network and prints
        /// the output until the user answers "y" to the exit prompt.
        /// </summary>
        /// <exception cref="FormatException">
        /// Thrown by <c>int.Parse</c> / <c>float.Parse</c> when the text entered on
        /// the console is not a valid number.
        /// </exception>
        public void Run()
        {
            // Constructor arguments (2, 3, 1) are presumably the layer sizes
            // (2 inputs, 3 hidden, 1 output) - confirm against NeuralNetwork.
            Network = new NeuralNetwork(NetworkProperties.Default, 2, 3, 1);
            Network.Initialize();

            Console.WriteLine("Enter number of training iterations:");
            int trainingIterations = int.Parse(Console.ReadLine());

            // XOR truth table: the output is 1 only when exactly one input is 1.
            TrainingPair[] trainingSet = new TrainingPair[]
            {
                new TrainingPair(new Vector(0, 0), new Vector(0f)),
                new TrainingPair(new Vector(0, 1), new Vector(1f)),
                new TrainingPair(new Vector(1, 0), new Vector(1f)),
                new TrainingPair(new Vector(1, 1), new Vector(0f))
            };

            // One generator for the whole run. Constructing a new Random on every
            // iteration is a known pitfall: on runtimes that seed the parameterless
            // constructor from the system clock, instances created in a tight loop
            // share a seed, so the same training pair gets picked over and over.
            Random random = new Random();

            Stopwatch time = Stopwatch.StartNew();

            for (int i = 0; i < trainingIterations; i++)
            {
                TrainingPair set = trainingSet[random.Next(0, trainingSet.Length)];
                Network.Train(set.Input, set.Output);

                Console.WriteLine();
                Network.Input.Nodes.Print("Input");
                Network.Output.Nodes.Print("Output");
            }

            // Note: the measured time includes the per-iteration console output above.
            long dt = time.ElapsedMilliseconds;

            Console.WriteLine("Training Time: " + dt + "ms");
            Console.WriteLine("Training Complete\n");

            while (true)
            {
                Vector input = new Vector(2);

                Console.WriteLine("Enter Input 1:");
                float input1 = float.Parse(Console.ReadLine());
                Console.WriteLine("Enter Input 2:");
                float input2 = float.Parse(Console.ReadLine());

                input[0] = input1;
                input[1] = input2;

                input.Print("Input Vector");

                // Load the user's values into the input layer, then propagate forward.
                Network.Input.Nodes = input;

                Network.FeedForward();
                Network.Output.Nodes.Print("Results");

                Console.WriteLine("Exit? (y/n)");
                if (Console.ReadLine() == "y")
                {
                    break;
                }
            }
        }
Ejemplo n.º 2
0
    /// <summary>
    /// Builds a <see cref="TrainingPair"/> for the current goal, populates it from
    /// the game state, and writes its string form as one line to <c>file</c>,
    /// flushing right after the write.
    /// </summary>
    /// <param name="debug">
    /// Currently has no effect: the logging this flag was meant to enable is
    /// disabled, so the method behaves the same for <c>true</c> and <c>false</c>.
    /// </param>
    public void WriteToFile(bool debug = false)
    {
        TrainingPair sample = new TrainingPair(this.goal);

        // The pair is filled in from this object and its "FirstPersonCharacter" child.
        var firstPersonCharacter = this.self.transform.FindChild("FirstPersonCharacter").gameObject;
        sample.InitializeFromGame(this.self, firstPersonCharacter);

        // One line per sample; flush so the record reaches the file immediately.
        string record = sample.ToString();
        file.WriteLine(record);
        file.Flush();
    }
Ejemplo n.º 3
0
    /// <summary>
    /// Builds the four XOR patterns (inputs and expected result encoded as -1 / +1),
    /// runs <c>Train</c> on them with 10000 iterations, then passes the same
    /// patterns to <c>Test</c>.
    /// </summary>
    void Start()
    {
        // Each row: { first input, second input, expected result }.
        double[][] truthTable =
        {
            new double[] { -1, -1, -1 },
            new double[] { -1, +1, +1 },
            new double[] { +1, -1, +1 },
            new double[] { +1, +1, -1 }
        };

        var patterns = new TrainingPair[truthTable.Length];
        for (int row = 0; row < truthTable.Length; row++)
        {
            double[] entry = truthTable[row];
            patterns[row] = new TrainingPair()
            {
                inputs = new double[] { entry[0], entry[1] },
                result = new double[] { entry[2] }
            };
        }

        Train(patterns, 10000);
        Test(patterns);
    }
Ejemplo n.º 4
0
 /// <summary>
 /// Runs <paramref name="iter"/> passes over <paramref name="patterns"/>; for each
 /// pattern it calls Calculate on the inputs and then BackPropagate on the targets.
 /// </summary>
 /// <param name="patterns">Training samples; each supplies <c>inputs</c> and <c>result</c>.</param>
 /// <param name="iter">Number of full passes over the pattern set.</param>
 /// <param name="N">Forwarded unchanged to BackPropagate (presumably the learning rate - confirm there).</param>
 /// <param name="M">Forwarded unchanged to BackPropagate (presumably the momentum factor - confirm there).</param>
 void Train(TrainingPair[] patterns, int iter=100, double N=0.5, double M=0.1)
 {
     var pass = 0;
     while (pass < iter) {
         // Sum of the values returned by BackPropagate during this pass.
         // It is accumulated but not reported or returned anywhere.
         var passError = 0.0;
         for (var k = 0; k < patterns.Length; k++) {
             var sample = patterns [k];
             var features = sample.inputs;
             var expected = sample.result;
             Calculate (features);
             passError += BackPropagate (expected, N, M);
         }
         pass++;
     }
 }
Ejemplo n.º 5
0
 /// <summary>
 /// For every pattern, logs the pattern itself and then an "Actual:" line holding
 /// the space-separated values returned by Calculate for that pattern's inputs.
 /// </summary>
 /// <param name="patterns">Patterns to evaluate, in order.</param>
 void Test(TrainingPair[] patterns)
 {
     foreach (var sample in patterns) {
         Debug.Log (sample);
         // Turn each output value into text so the values can be joined on one line.
         var outputs = Calculate (sample.inputs).Select (value => value.ToString ()).ToArray ();
         var message = "Actual:" + string.Join (" ", outputs);
         Debug.Log (message);
     }
 }
Ejemplo n.º 6
0
 /// <summary>
 /// Generates the four XOR patterns over {-1, +1} (result is +1 when the two
 /// inputs differ, -1 when they match), trains on them for 10000 iterations,
 /// then passes the same patterns to Test.
 /// </summary>
 void Start()
 {
     var patterns = new TrainingPair[4];
     var slot = 0;
     // Enumerates (-1,-1), (-1,+1), (+1,-1), (+1,+1) in that order.
     for (var a = -1; a <= 1; a += 2) {
         for (var b = -1; b <= 1; b += 2) {
             double expected = (a == b) ? -1 : +1;
             patterns [slot] = new TrainingPair () { inputs=new double[] {a, b}, result=new double[] { expected }};
             slot++;
         }
     }
     Train (patterns, 10000);
     Test (patterns);
 }
    /// <summary>
    /// Loads a previously serialized network from <paramref name="serializePath"/>
    /// when that file exists; otherwise reads recorded samples from
    /// <paramref name="inputPath"/>, trains the network on them for the configured
    /// number of epochs, and (when a path was given) serializes the result.
    /// </summary>
    /// <param name="inputPath">
    /// Text file of recorded samples. The first line is treated as a header and
    /// skipped; every following line is parsed with
    /// <c>TrainingPair.InitializeFromSaved</c>.
    /// </param>
    /// <param name="serializePath">
    /// Optional cache file. If it exists the network is loaded from it and no
    /// training happens; if it does not exist the trained network is written to it.
    /// </param>
    public void TrainAINetwork(string inputPath, string serializePath = null)
    {
        if (serializePath != null && File.Exists(serializePath))
        {
            // SECURITY NOTE(review): BinaryFormatter deserialization is unsafe on
            // untrusted data. Only load files this application wrote itself, and
            // consider migrating to a safer serialization format.
            using (FileStream fs = new FileStream(serializePath, FileMode.Open))
            {
                this.network = (ActivationNetwork) new BinaryFormatter().Deserialize(fs);
                return;
            }
        }

        List <List <double> > input  = new List <List <double> >();
        List <List <double> > output = new List <List <double> >();

        // 'using' guarantees the reader is closed even if parsing a line throws.
        using (System.IO.StreamReader file = new System.IO.StreamReader(inputPath))
        {
            string line;

            // skip the header (first line);
            file.ReadLine();
            while ((line = file.ReadLine()) != null)
            {
                TrainingPair tp = new TrainingPair(this.goal);
                tp.InitializeFromSaved(line.Trim());

                // Only samples where the player actually did something are kept.
                // NOTE(review): sprint/jump are not part of this filter - confirm
                // that a sample with only those buttons set should be dropped.
                if (tp.observedAction.xRotInput != 0.0f || tp.observedAction.yRotInput != 0.0f || tp.observedAction.forwardPan != 0.0f ||
                    tp.observedAction.horizontalPan != 0.0f || tp.observedAction.fireButtonDown != 0.0f)
                {
                    // only add non-zero examples for now
                    input.Add(new List <double> ()
                    {
                        tp.gameStateSummary.XZAngleToObj,
                        tp.gameStateSummary.YZAngletoObj,
                        tp.gameStateSummary.distToObj
                    });
                    output.Add(new List <double> ()
                    {
                        tp.observedAction.yRotInput,
                        tp.observedAction.xRotInput,
                        tp.observedAction.horizontalPan,
                        tp.observedAction.forwardPan,
                        tp.observedAction.fireButtonDown,
                        tp.observedAction.sprintButtonDown,
                        tp.observedAction.jumpButtonDown
                    });
                }
            }
        }
        Debug.Log(string.Format("Training List Length: {0}", input.Count));

        // The sample lists do not change during training, so convert them once
        // instead of rebuilding both arrays on every epoch.
        var inputArray  = NestedListToArray(input);
        var outputArray = NestedListToArray(output);
        var epochs      = Config.Instance.node["training_epochs"].AsInt;

        for (int i = 0; i < epochs; i++)
        {
            double error = this.teacher.RunEpoch(inputArray, outputArray);
            if (i % 50 == 0)
            {
                Debug.Log(string.Format("iteration: {0}, Error: {1}", i, error));
            }
        }

        if (serializePath != null)
        {
            using (FileStream fs = new FileStream(serializePath, FileMode.Create))
            {
                new BinaryFormatter().Serialize(fs, this.network);
            }
        }
    }