コード例 #1
0
        /// <summary>
        /// Loads image <paramref name="imageNumber"/> from the decoder into <c>inputArray</c>,
        /// walking the image row by row and binarising each pixel: a pixel equal to 0 is
        /// stored as 0, any other value is stored as 1.
        /// </summary>
        /// <param name="_emnistDecoder">Decoder that holds the loaded images.</param>
        /// <param name="imageNumber">Index of the image to copy.</param>
        public void setInput(EMNistDecoder _emnistDecoder, int imageNumber)
        {
            int position = 0;
            int row = 0;

            while (row < _emnistDecoder.numRows)
            {
                for (int column = 0; column < _emnistDecoder.numCols; column++)
                {
                    bool isBlank = _emnistDecoder.images[imageNumber, row, column] == 0;
                    inputArray[position++] = isBlank ? 0 : 1;
                }

                row++;
            }
        }
コード例 #2
0
 /// <summary>
 /// Allocates <c>inputArray</c> with one zero-initialised slot per unit of the decoder's
 /// <c>dimensions</c> value (presumably numRows * numCols, the number of values
 /// setInput writes — TODO confirm against EMNistDecoder).
 /// </summary>
 /// <param name="_emnistDecoder">Decoder whose <c>dimensions</c> sizes the buffer.</param>
 public Input(EMNistDecoder _emnistDecoder)
 {
     int slotCount = _emnistDecoder.dimensions;
     inputArray = new double[slotCount];
 }
コード例 #3
0
        /// <summary>
        /// Console entry point. Loads the EMNIST training set, asks the user for the network
        /// shape, builds the network, and then makes one pass over every training image,
        /// running back-propagation only for images the network currently mispredicts.
        /// </summary>
        /// <remarks>
        /// The label and image file paths default to the hard-coded locations below. They can
        /// be overridden with the first and second command-line arguments respectively.
        /// </remarks>
        static void Main()
        {
            const string DefaultLabelFilePath  = @"C:\EMNist Dataset\gzip\gzip\emnist-mnist-train-labels-idx1-ubyte\emnist-mnist-train-labels-idx1-ubyte";
            const string DefaultImagesFilePath = @"C:\EMNist Dataset\gzip\gzip\emnist-mnist-train-images-idx3-ubyte\emnist-mnist-train-images-idx3-ubyte";

            // Element 0 is the program itself; user-supplied arguments start at index 1.
            string[] commandLineArgs = Environment.GetCommandLineArgs();
            string labelFilePath  = commandLineArgs.Length > 1 ? commandLineArgs[1] : DefaultLabelFilePath;
            string imagesFilePath = commandLineArgs.Length > 2 ? commandLineArgs[2] : DefaultImagesFilePath;

            EMNistDecoder emnistDecoder = new EMNistDecoder();
            emnistDecoder.EMNistDecoderInit(labelFilePath, imagesFilePath);

            // Kept as float so the progress/success-rate divisions below are floating point.
            float imageCount     = emnistDecoder.numImages;
            int   inputLayerSize = emnistDecoder.dimensions;

            int outputLayerSize = ReadIntFromConsole("Please enter classification output parameter", 1);

            Variables.OutputSize   = outputLayerSize;
            Variables.LearningRate = 0.05f;

            // Construction order relative to the Variables assignments matches the original,
            // in case these constructors read Variables.
            Output expectedOutput = new Output(Variables.OutputSize);
            Network network = new Network();

            Variables.LayerAmount = ReadIntFromConsole("Please Enter Hidden Layer Amount", 0);
            int[] hiddenLayerSizes = new int[Variables.LayerAmount];

            Variables.InputSize = inputLayerSize;
            Input input = new Input(emnistDecoder);

            for (int i = 0; i < hiddenLayerSizes.Length; i++)
            {
                hiddenLayerSizes[i] = ReadIntFromConsole("Please enter " + ToOrdinal(i + 1) + " Hidden Layer Size", 1);
            }

            Stopwatch stopwatch = Stopwatch.StartNew();

            // Second argument: 0 for the input layer, 1 for each hidden layer, 2 for the output layer.
            network.addLayerToNetwork(inputLayerSize, 0);
            foreach (int hiddenLayerSize in hiddenLayerSizes)
            {
                network.addLayerToNetwork(hiddenLayerSize, 1);
            }
            network.addLayerToNetwork(outputLayerSize, 2);

            int layerCount = network.networkSize;

            // Counted as if every pair of adjacent layers is fully connected.
            int connectionCount = 0;
            for (int i = 0; i < network.networkSize - 1; i++)
            {
                connectionCount += network.Layers[i].layerSize * network.Layers[i + 1].layerSize;
            }

            stopwatch.Stop();

            Console.WriteLine("Network has " + layerCount + " Layers");
            Console.WriteLine("Network has " + NodeCounter + " Neurons");
            Console.WriteLine("Network has " + connectionCount + " Connections");
            Console.WriteLine("Creation of the network took " + stopwatch.Elapsed);

            Console.WriteLine("Magic number of training set is =" + emnistDecoder.magic1);
            Console.WriteLine("Image Count in database is = " + emnistDecoder.numImages);

            // Pause until Enter. ReadLine (rather than Read) consumes the whole line, so no
            // leftover newline stays in the input buffer to satisfy the final ReadLine early.
            Console.ReadLine();
            Console.Clear();

            // NOTE(review): assumes the print* helpers take an index into network.Layers. The
            // original hard-coded 2, which is the output layer only with exactly one hidden
            // layer and does not exist at all with zero hidden layers.
            int outputLayerIndex = network.networkSize - 1;
            float successfulPredictions = 0;

            for (int imageIndex = 0; imageIndex < imageCount; imageIndex++)
            {
                input.setInput(emnistDecoder, imageIndex);
                var label = emnistDecoder.getCurrentImageLabel(imageIndex);
                expectedOutput.setExpectedOutputArray(label);

                network.setInputLayerInputs(input);
                network.feedForward(network);

                var prediction = network.getPrediction();
                Console.WriteLine("Network prediction was : " + prediction);

                if (prediction == label)
                {
                    Console.WriteLine("Succesful Prediction!");
                    successfulPredictions++;
                }
                else
                {
                    Console.WriteLine("Bad Prediction :(...");

                    // Only mispredicted images trigger error calculation and weight updates.
                    network.calculateOutputError(expectedOutput);
                    network.startBackPropogation();
                    network.updateWeightsAndBiases();
                }

                int processed = imageIndex + 1;
                Console.WriteLine("Success Rate is : %" + ((successfulPredictions / processed) * 100));
                Console.WriteLine("Training Progress : " + ((processed / imageCount) * 100));

                network.printNetworkLayerOutputDelta(outputLayerIndex);
                network.printNetworkLayerBias(outputLayerIndex);
                network.printNetworkLayerOutput(outputLayerIndex);

                expectedOutput.resetExpectedOutputArray();
                network.resetAllDeltaValues();

                // Move back to the top-left so the next iteration overwrites this status text.
                Console.SetCursorPosition(0, 0);
            }

            Console.ReadLine();
        }

        /// <summary>
        /// Writes <paramref name="prompt"/> once, then keeps reading lines until the user enters
        /// an integer that is greater than or equal to <paramref name="minimum"/>.
        /// </summary>
        /// <param name="prompt">Text shown before the first read.</param>
        /// <param name="minimum">Smallest accepted value.</param>
        /// <returns>The first valid integer entered.</returns>
        /// <exception cref="InvalidOperationException">Standard input reached end-of-stream.</exception>
        private static int ReadIntFromConsole(string prompt, int minimum)
        {
            Console.WriteLine(prompt);
            while (true)
            {
                string text = Console.ReadLine();
                if (text == null)
                {
                    // Without this, a closed stdin would make the loop spin forever.
                    throw new InvalidOperationException("Standard input was closed while waiting for a number.");
                }

                int value;
                if (int.TryParse(text, out value) && value >= minimum)
                {
                    return value;
                }

                Console.WriteLine("Please enter a whole number greater than or equal to " + minimum + ".");
            }
        }

        /// <summary>
        /// Formats a positive integer as an English ordinal, e.g. 1st, 2nd, 3rd, 4th, 11th, 21st.
        /// </summary>
        private static string ToOrdinal(int number)
        {
            // 11, 12 and 13 (and 111, 112, ...) take "th" despite their last digit.
            int lastTwoDigits = number % 100;
            if (lastTwoDigits >= 11 && lastTwoDigits <= 13)
            {
                return number + "th";
            }

            switch (number % 10)
            {
                case 1:  return number + "st";
                case 2:  return number + "nd";
                case 3:  return number + "rd";
                default: return number + "th";
            }
        }