Exemplo n.º 1
0
        /// <summary>
        /// Program entry point: builds the initial population of neural networks, runs
        /// <c>ITERATION_LIMIT</c> evolution iterations, prints a training summary and then
        /// lets the user query the first network of the population interactively.
        /// </summary>
        /// <param name="args">Command-line arguments (unused).</param>
        /// <exception cref="InvalidOperationException">
        /// Thrown when <c>NETWORK_SIZE</c> is not a multiple of 4.
        /// </exception>
        private static void Main(string[] args)
        {
            // The population is handled in quarters, so its size has to divide evenly by 4.
            // "!= 0" (rather than "> 0") also rejects negative remainders.
            if (NETWORK_SIZE % 4 != 0)
            {
                throw new InvalidOperationException("Network size must be a multiple of 4!");
            }

            // Force en-US formatting so numbers are parsed and printed with '.' as the decimal separator
            System.Threading.Thread.CurrentThread.CurrentCulture = System.Globalization.CultureInfo.GetCultureInfo("en-US");
            rand = new Random();

            DeepNeuralNetwork[] nn = new DeepNeuralNetwork[NETWORK_SIZE];

            // Construct initial neural networks
            for (int i = 0; i < nn.Length; i++)
            {
                nn[i] = CreateNewNN();
            }

            dtLastIteration = DateTime.Now;

            // Perform training iterations; the array is passed by reference because
            // PerformEvolution replaces it with a re-ordered copy
            for (int i = 0; i < ITERATION_LIMIT; i++)
            {
                PerformEvolution(ref nn, i);
            }

            // Print training summary
            Console.WriteLine(Environment.NewLine + "The top quarter neural networks have " + GetAverageDeviation(nn).ToString() + " average deviation.");
            Console.WriteLine("The best neural network has " + nn[0].AverageDeviation.ToString() + " deviation over " + (nn[0].IterationsAlive / SUB_ITERATION_COUNT).ToString() + " iterations.");

            // Let the user try the NN with their own input values
            while (true)
            {
                Console.Write(Environment.NewLine + "Enter input values: ");
                string strInput = Console.ReadLine();

                // Empty input is an exit request. Console.ReadLine returns null once the
                // input stream is exhausted (e.g. redirected or closed stdin), which must
                // end the loop as well instead of being handed to the parser.
                if (string.IsNullOrEmpty(strInput))
                {
                    break;
                }

                // Separate individual values on the line
                double[] userInput = CSVParser.ParseLine(strInput);

                // Get output values from the first network in the array
                double[] predictedOutput = nn[0].ProcessData(userInput);

                Console.WriteLine("Output: " + string.Join(" ", predictedOutput.Select(x => x.ToString())));
            }
        }
Exemplo n.º 2
0
        // Keep 1/4 of the NNs as survivors
        // Replace 2/4 of the NNs with modifications of the survivors
        // Fill the remaining 1/4 with completely new NNs
        private static void PerformEvolution(ref DeepNeuralNetwork[] networks, int iterationIdx)
        {
            for (int i = 0; i < SUB_ITERATION_COUNT; i++)
            {
                FeedInputToNetworks(networks, rand.Next(sourceData.Count));
            }

            // Order NNs by their average deviation
            networks = networks.OrderBy(x => x.AverageDeviation).ToArray();

            if ((DateTime.Now - dtLastIteration) >= new TimeSpan(0, 0, 1))
            {
                dtLastIteration = DateTime.Now;
                // Print average deviation for this iteration
                Console.WriteLine("Iteration: " + (iterationIdx + 1).ToString().PadLeft(ITERATION_LIMIT.ToString().Length) + " - Deviation: " + GetAverageDeviation(networks).ToString());
            }

            // Don't modify the NNs on the last iteration
            if (iterationIdx == ITERATION_LIMIT - 1)
            {
                return;
            }

            // Replace 2/4 of NNs with children of survivors
            for (int i = 0; i < 2; i++)
            {
                for (int j = 0; j < NETWORK_QUARTER; j++)
                {
                    // NN to be replaced with one of the children
                    ref DeepNeuralNetwork nn = ref networks[((i + 1) * NETWORK_QUARTER) + j];

                    // Don't use networks with no deviation as parents
                    // They would just produce their exact clones and that's pointless
                    if (networks[j].AverageDeviation == 0.0)
                    {
                        // Generate a completely new NN with random weights instead
                        nn = CreateNewNN();
                    }
                    else
                    {
                        // Replace the NN with a slightly modified version of one of the survivors
                        nn = networks[j].ShakeWeights();
                    }
                }
            }