Exemplo n.º 1
0
        /// <summary>
        /// Experiment driver: sets up OpenCL, optionally imports one preprocessed GTSRB variant, loads a saved
        /// network, reloads it as "best network" and saves its weights via <c>SaveWeights("first", ...)</c>.
        /// Gradient checking, training and evaluation are kept as commented-out sections to be toggled by hand.
        /// </summary>
        /// <param name="args">
        /// Optional: <c>args[0]</c> overrides the root directory containing the <c>ConvDotNet</c>, <c>GTSRB</c>
        /// and <c>Results</c> folders. Without it the original hard-coded directory is used.
        /// </param>
        static void Main(string[] args)
        {
            string dirPath = (args != null && args.Length > 0 && !String.IsNullOrWhiteSpace(args[0]))
                             ? args[0]
                             : "C:/Users/jacopo/Dropbox/Chalmers/MSc thesis";

            // Every result artefact (networks, loss curves, filters, misclassified examples) is read from / written
            // below this single directory. Previously some paths were built from dirPath and others from
            // "../../../../Results/", which only coincide when the working directory is four levels below dirPath.
            string resultsPath = dirPath + "/Results/";

            /*****************************************************
             * (0) Setup OpenCL
             ****************************************************/
            Console.WriteLine("\n=========================================");
            Console.WriteLine("    OpenCL setup");
            Console.WriteLine("=========================================\n");

            OpenCLSpace.SetupSpace(4);
            OpenCLSpace.KernelsPath = dirPath + "/ConvDotNet/Kernels";
            OpenCLSpace.LoadKernels();


            /*****************************************************
             * (1) Load data
             ******************************************************/

            // Selects which preprocessed GTSRB variant to import. Any value without a matching case
            // (e.g. "none") skips the import and leaves the three data sets empty.
            string imageColor = "none";

            #region DataImport

            Console.WriteLine("\n=========================================");
            Console.WriteLine("    Importing data");
            Console.WriteLine("=========================================\n");

            DataSet trainingSet   = new DataSet(43);
            DataSet validationSet = new DataSet(43);
            DataSet testSet       = new DataSet(43);

            string preprocessedPath = dirPath + "/GTSRB/Preprocessed/";

            switch (imageColor)
            {
                case "GS1":     // grayscale variant, files "14_*"
                    ImportGtsrbSets(preprocessedPath, "14", false, trainingSet, validationSet, testSet);
                    break;
                case "GS2":     // files "18_*"; the original comments called this RGB, the key says grayscale — confirm.
                                // Data and labels are read with two separate calls, as the original branch did.
                    ImportGtsrbSets(preprocessedPath, "18", true, trainingSet, validationSet, testSet);
                    break;
                case "RGB1":    // RGB variant, files "16_*"
                    ImportGtsrbSets(preprocessedPath, "16", false, trainingSet, validationSet, testSet);
                    break;
                case "RGB2":    // RGB variant, files "20_*"
                    ImportGtsrbSets(preprocessedPath, "20", false, trainingSet, validationSet, testSet);
                    break;
                case "RGB_16":  // RGB variant, files "22_*"
                    ImportGtsrbSets(preprocessedPath, "22", false, trainingSet, validationSet, testSet);
                    break;
                default:
                    Console.WriteLine("No data set imported (imageColor = \"{0}\").", imageColor);
                    break;
            }
            #endregion

            /*****************************************************
             * (2) Instantiate a neural network and add layers
             *
             * OPTIONS:
             * ConvolutionalLayer(filterSize, numberOfFilters, strideLength, zeroPadding)
             * ResidualModule(filterSize, numberOfFilters, strideLength, zeroPadding, nonlinearityType)
             * FullyConnectedLayer(numberOfUnits)
             * MaxPooling(2, 2)
             * AveragePooling()
             * ReLU()
             * ELU(alpha)
             * SoftMax()
             ****************************************************/

            #region NeuralNetworkCreation

            Console.WriteLine("\n=========================================");
            Console.WriteLine("    Neural network creation");
            Console.WriteLine("=========================================\n");

            // OPTION 1: Create a new network

            /*
             * NeuralNetwork network = new NeuralNetwork("SimplerLeNet_WD1e-4");
             *
             * network.AddLayer(new InputLayer(1, 32, 32));
             *
             * network.AddLayer(new ConvolutionalLayer(5, 32, 1, 0));
             * network.AddLayer(new ELU(1.0f));
             *
             * network.AddLayer(new MaxPooling(2, 2));
             *
             * network.AddLayer(new ConvolutionalLayer(5, 64, 1, 0));
             * network.AddLayer(new ELU(1.0f));
             *
             * network.AddLayer(new MaxPooling(2, 2));
             *
             * network.AddLayer(new FullyConnectedLayer(100));
             * network.AddLayer(new ELU(1.0f));
             *
             * network.AddLayer(new FullyConnectedLayer(100));
             * network.AddLayer(new ELU(1.0f));
             *
             * network.AddLayer(new FullyConnectedLayer(43));
             * network.AddLayer(new SoftMax());
             *
             * NetworkTrainer.TrainingMode = "new";
             */

            // OPTION 2: Load a network from file

            NeuralNetwork network = Utils.LoadNetworkFromFile(resultsPath + "Networks/", "LeNet_RGBb_Dropout");
            //network.Set("MiniBatchSize", 64); // this SHOULDN'T matter!
            //network.InitializeParameters("load");
            //NetworkTrainer.TrainingMode = "resume";

            #endregion


            /*****************************************************
             * (3) Gradient check  (disabled: uncomment to run)
             ******************************************************/
            //GradientChecker.Check(network, validationSet);


            /*****************************************************
             * (4) Train network  (disabled: uncomment to train)
             ******************************************************/

            //Console.WriteLine("\n=========================================");
            //Console.WriteLine("    Network training");
            //Console.WriteLine("=========================================\n");
            //
            //#region Training
            //
            //// Set output files save paths
            //string trainingSavePath = resultsPath + "LossError/";
            //NetworkTrainer.TrainingEpochSavePath = trainingSavePath + network.Name + "_trainingEpochs.txt";
            //NetworkTrainer.ValidationEpochSavePath = trainingSavePath + network.Name + "_validationEpochs.txt";
            //NetworkTrainer.NetworkOutputFilePath = resultsPath + "Networks/";
            //
            //NetworkTrainer.MomentumCoefficient = 0.9;
            //NetworkTrainer.WeightDecayCoeff = 0.0001;
            //NetworkTrainer.MaxTrainingEpochs = 200;
            //NetworkTrainer.EpochsBeforeRegularization = 0;
            //NetworkTrainer.MiniBatchSize = 64;
            //NetworkTrainer.ConsoleOutputLag = 1; // 1 = print every epoch, N = print every N epochs
            //NetworkTrainer.EvaluateBeforeTraining = true;
            //NetworkTrainer.DropoutFullyConnected = 0.5;
            //NetworkTrainer.DropoutConvolutional = 1.0;
            //NetworkTrainer.DropoutInput = 1.0;
            //NetworkTrainer.Patience = 1000;
            //NetworkTrainer.LearningRateDecayFactor = Math.Sqrt(10.0);
            //NetworkTrainer.MaxConsecutiveAnnealings = 3;
            //NetworkTrainer.WeightMaxNorm = Double.PositiveInfinity;
            //
            //NetworkTrainer.LearningRate = 0.00002;
            //NetworkTrainer.Train(network, trainingSet, validationSet);
            //
            //#endregion


            /*****************************************************
             * (5) Test network
             *****************************************************/

            #region Testing

            Console.WriteLine("\nFINAL EVALUATION:");

            // Load best network from file (the same directory training saves networks to)
            NeuralNetwork bestNetwork = Utils.LoadNetworkFromFile(resultsPath + "Networks/", network.Name);
            bestNetwork.Set("MiniBatchSize", 64); // this SHOULDN'T matter!
            bestNetwork.InitializeParameters("load");
            bestNetwork.Set("Inference", true);

            //double loss;
            //double error;

            //NetworkEvaluator.EvaluateNetwork(bestNetwork, trainingSet, out loss, out error);
            //Console.WriteLine("\nTraining set:\n\tLoss = {0}\n\tError = {1}", loss, error);

            //NetworkEvaluator.EvaluateNetwork(bestNetwork, validationSet, out loss, out error);
            //Console.WriteLine("\nValidation set:\n\tLoss = {0}\n\tError = {1}\n\tAccuracy = {2}", loss, error, 100*(1-error));

            //NetworkEvaluator.EvaluateNetwork(bestNetwork, testSet, out loss, out error);
            //Console.WriteLine("\nTest set:\n\tLoss = {0}\n\tError = {1}\n\tAccuracy = {2}", loss, error, 100 * (1 - error));

            // Save misclassified examples
            //NetworkEvaluator.SaveMisclassifiedExamples(bestNetwork, trainingSet, resultsPath + "MisclassifiedExamples/" + network.Name + "_training.txt");
            //NetworkEvaluator.SaveMisclassifiedExamples(bestNetwork, validationSet, resultsPath + "MisclassifiedExamples/" + network.Name + "_validation.txt");
            //NetworkEvaluator.SaveMisclassifiedExamples(bestNetwork, testSet, resultsPath + "MisclassifiedExamples/" + network.Name + "_test.txt");

            // Save filters to file ("first" is passed through to SaveWeights; presumably selects the first layer — confirm)
            bestNetwork.SaveWeights("first", resultsPath + "Filters/");

            #endregion
        }

        /// <summary>
        /// Imports the preprocessed GTSRB training, validation and test sets whose files share a numeric prefix,
        /// e.g. prefix "14" reads "14_training_images.dat" / "14_training_classes.dat", "14_validation_*" and
        /// "14_test_images.dat". Test labels are always read from "test_labels_full.dat".
        /// </summary>
        /// <param name="preprocessedPath">Directory holding the preprocessed files, including the trailing '/'.</param>
        /// <param name="prefix">File-name prefix identifying the preprocessing variant.</param>
        /// <param name="readLabelsSeparately">
        /// True to read data and labels with separate <c>ReadData</c>/<c>ReadLabels</c> calls; false to use the
        /// two-argument <c>ReadData</c>.
        /// </param>
        /// <param name="trainingSet">Receives the training examples.</param>
        /// <param name="validationSet">Receives the validation examples.</param>
        /// <param name="testSet">Receives the test examples.</param>
        private static void ImportGtsrbSets(string preprocessedPath, string prefix, bool readLabelsSeparately,
                                            DataSet trainingSet, DataSet validationSet, DataSet testSet)
        {
            string filePrefix = preprocessedPath + prefix;

            Console.WriteLine("Importing training set...");
            ReadGtsrbSet(trainingSet, filePrefix + "_training_images.dat", filePrefix + "_training_classes.dat",
                         readLabelsSeparately);

            Console.WriteLine("Importing validation set...");
            ReadGtsrbSet(validationSet, filePrefix + "_validation_images.dat", filePrefix + "_validation_classes.dat",
                         readLabelsSeparately);

            Console.WriteLine("Importing test set...");
            ReadGtsrbSet(testSet, filePrefix + "_test_images.dat", preprocessedPath + "test_labels_full.dat",
                         readLabelsSeparately);
        }

        /// <summary>
        /// Reads one data file and its label file into <paramref name="set"/>, either with one combined call
        /// or with separate data/label calls.
        /// </summary>
        private static void ReadGtsrbSet(DataSet set, string dataPath, string labelsPath, bool readLabelsSeparately)
        {
            if (readLabelsSeparately)
            {
                set.ReadData(dataPath);
                set.ReadLabels(labelsPath);
            }
            else
            {
                set.ReadData(dataPath, labelsPath);
            }
        }
Exemplo n.º 2
0
        /// <summary>
        /// Trains <paramref name="network"/> with mini-batch gradient descent on <paramref name="trainingSet"/>,
        /// evaluating on the training set (and on <paramref name="validationSet"/> when given) every
        /// <c>consoleOutputLag</c> epochs.
        /// </summary>
        /// <remarks>
        /// Hyper-parameters come from this class's static settings (miniBatchSize, learningRate, momentumCoefficient,
        /// weightDecayCoeff, weightMaxNorm, dropout*, patience, learningRateDecayFactor, maxConsecutiveAnnealings,
        /// maxTrainingEpochs). Loss/error of each evaluation are appended to <c>trainingEpochSavePath</c> and
        /// <c>validationEpochSavePath</c>. Whenever the validation loss reaches a new minimum the network is saved to
        /// <c>networkOutputFilePath</c>. When it has not improved for <c>patience</c> evaluations, the learning rate is
        /// divided by <c>learningRateDecayFactor</c> and the last saved network is reloaded from that same directory.
        /// Training stops after <c>maxConsecutiveAnnealings</c> annealings without improvement, after
        /// <c>maxTrainingEpochs</c> epochs, or when <c>CheckForKeyPress</c> raises the stop flag.
        /// NOTE: after an annealing step the reloaded network is a new object; the caller's reference is not updated.
        /// </remarks>
        /// <param name="network">Network to train.</param>
        /// <param name="trainingSet">Examples used for the parameter updates.</param>
        /// <param name="validationSet">Examples used for model selection and annealing; may be null.</param>
        /// <exception cref="InvalidOperationException">TrainingMode is neither "new" nor "resume" (case-insensitive).</exception>
        public static void Train(NeuralNetwork network, DataSet trainingSet, DataSet validationSet)
        {
            // Initialize parameters or load them
            if (String.Equals(trainingMode, "new", StringComparison.OrdinalIgnoreCase))
            {
                // Setup miniBatchSize
                network.Set("MiniBatchSize", miniBatchSize);
                network.InitializeParameters("random");
            }
            else if (String.Equals(trainingMode, "resume", StringComparison.OrdinalIgnoreCase))
            {
                network.InitializeParameters("load");
            }
            else
            {
                throw new InvalidOperationException("Please set TrainingMode to either ''New'' or ''Resume''.");
            }

            // Set dropout
            network.Set("DropoutFC", dropoutFC);
            network.Set("DropoutConv", dropoutConv);
            network.Set("DropoutInput", dropoutInput);

            Sequence indicesSequence = new Sequence(trainingSet.DataContainer.Count);

            // Timers (each is reset/restarted before it is read)
            Stopwatch stopwatch     = new Stopwatch();
            Stopwatch stopwatchFwd  = new Stopwatch();
            Stopwatch stopwatchGrad = new Stopwatch();
            Stopwatch stopwatchBwd  = new Stopwatch();

            int  epoch      = 0;
            int  nBadEpochs = 0;
            int  consecutiveAnnealingCounter = 0;
            bool stopFlag = false;
            // 0 forces an evaluation before the first training epoch
            int  epochsRemainingToOutput = (evaluateBeforeTraining == true) ? 0 : consoleOutputLag;


            while (!stopFlag) // begin loop over training epochs
            {
                if (epochsRemainingToOutput == 0)
                {
                    /**************
                    * Evaluation *
                    **************/

                    // Pre inference (for batch-norm)
                    //network.Set("PreInference", true);
                    //Console.WriteLine("Re-computing batch-norm means and variances...");
                    //NetworkEvaluator.PreEvaluateNetwork(network, trainingSet);

                    // Evaluate on training set...
                    network.Set("Inference", true);
                    Console.WriteLine("Evaluating on TRAINING set...");
                    stopwatch.Restart();
                    NetworkEvaluator.EvaluateNetwork(network, trainingSet, out lossTraining, out errorTraining);
                    Console.WriteLine("\tLoss = {0}\n\tError = {1}\n\tEval runtime = {2}ms\n",
                                      lossTraining, errorTraining, stopwatch.ElapsedMilliseconds);
                    // ...and append loss and error to file
                    using (System.IO.StreamWriter trainingEpochOutputFile = new System.IO.StreamWriter(trainingEpochSavePath, true))
                    {
                        trainingEpochOutputFile.WriteLine(lossTraining.ToString() + "\t" + errorTraining.ToString());
                    }

                    // Evaluate on validation set...
                    if (validationSet != null)
                    {
                        Console.WriteLine("Evaluating on VALIDATION set...");
                        stopwatch.Restart();
                        NetworkEvaluator.EvaluateNetwork(network, validationSet, out newLossValidation, out newErrorValidation);
                        Console.WriteLine("\tLoss = {0}\n\tError = {1}\n\tEval runtime = {2}ms\n",
                                          newLossValidation, newErrorValidation, stopwatch.ElapsedMilliseconds);
                        // ...append loss and error to file
                        using (System.IO.StreamWriter validationEpochOutputFile = new System.IO.StreamWriter(validationEpochSavePath, true))
                        {
                            validationEpochOutputFile.WriteLine(newLossValidation.ToString() + "\t" + newErrorValidation.ToString());
                        }

                        if (newLossValidation < minLossValidation)
                        {
                            // nice, validation loss is decreasing!
                            minLossValidation = newLossValidation;
                            errorValidation   = newErrorValidation;

                            // Save network to file (this is the copy reloaded when annealing below)
                            Utils.SaveNetworkToFile(network, networkOutputFilePath);

                            // and keep training
                            nBadEpochs = 0;
                            consecutiveAnnealingCounter = 0;
                        }
                        else
                        {
                            nBadEpochs++;
                            Console.WriteLine("Loss on the validation set has been increasing for {0} epoch(s)...", nBadEpochs);
                            if (patience - nBadEpochs > 0)
                            {
                                Console.WriteLine("...I'll be patient for {0} more epoch(s)!", patience - nBadEpochs); // keep training
                            }
                            else
                            {
                                Console.WriteLine("...and I've run out of patience!");

                                // Stop once maxConsecutiveAnnealings annealings have been performed without any
                                // improvement in between (">" previously allowed one annealing too many).
                                if (consecutiveAnnealingCounter >= maxConsecutiveAnnealings)
                                {
                                    Console.WriteLine("\nReached the number of maximum consecutive annealings without progress. \nTraining ends here.");
                                    break;
                                }

                                // Decrease learning rate
                                Console.WriteLine("\nI'm annealing the learning rate:\n\tWas {0}\n\tSetting it to {1}.", learningRate, learningRate / learningRateDecayFactor);
                                learningRate /= learningRateDecayFactor;
                                consecutiveAnnealingCounter++;

                                Console.WriteLine("\nAnd I'm loading the network saved {0} epochs ago and resume the training from there.", patience);

                                string networkName = network.Name;
                                // Drop the old instance and force a collection before loading a second copy.
                                // NOTE(review): presumably this frees the old network's (OpenCL) buffers early — confirm.
                                network = null;
                                GC.Collect();
                                // Reload from the directory the network was saved to above, not a hard-coded path
                                network = Utils.LoadNetworkFromFile(networkOutputFilePath, networkName);
                                network.Set("MiniBatchSize", miniBatchSize);
                                network.InitializeParameters("load");

                                nBadEpochs = 0;
                            }
                        }
                    }

                    // Restore dropout (also applies it to a network that was just reloaded)
                    network.Set("DropoutFC", dropoutFC);
                    network.Set("DropoutConv", dropoutConv);
                    network.Set("DropoutInput", dropoutInput);

                    epochsRemainingToOutput = consoleOutputLag;
                }
                epochsRemainingToOutput--;

                epoch++;

                if (epoch > maxTrainingEpochs)
                {
                    break;
                }

                /************
                * Training *
                ************/

                network.Set("Training", true);
                network.Set("EpochBeginning", true);

                Console.WriteLine("\nEpoch {0}...", epoch);


                stopwatch.Restart();
                stopwatchFwd.Reset();
                stopwatchGrad.Reset();
                stopwatchBwd.Reset();

                indicesSequence.Shuffle(); // shuffle examples order at every epoch

                // Run over mini-batches
                for (int iStartMiniBatch = 0; iStartMiniBatch < trainingSet.DataContainer.Count; iStartMiniBatch += miniBatchSize)
                {
                    // Feed a mini-batch to the network
                    int[] miniBatch = indicesSequence.GetMiniBatchIndices(iStartMiniBatch, miniBatchSize);
                    network.InputLayer.FeedData(trainingSet, miniBatch);

                    // Forward pass
                    stopwatchFwd.Start();
                    network.ForwardPass("beginning", "end");
                    stopwatchFwd.Stop();

                    // Compute gradient of the cross-entropy loss
                    stopwatchGrad.Start();
                    network.CrossEntropyGradient(trainingSet, miniBatch);
                    stopwatchGrad.Stop();

                    // Backpropagate gradient and update parameters
                    stopwatchBwd.Start();
                    network.BackwardPass(learningRate, momentumCoefficient, weightDecayCoeff, weightMaxNorm);
                    stopwatchBwd.Stop();

                    CheckForKeyPress(ref network, ref stopFlag);
                    if (stopFlag)
                    {
                        break;
                    }
                } // end of training epoch

                Console.Write(" Training runtime = {0}ms\n", stopwatch.ElapsedMilliseconds);

                Console.WriteLine("Forward: {0}ms - Gradient: {1}ms - Backward: {2}ms\n",
                                  stopwatchFwd.ElapsedMilliseconds, stopwatchGrad.ElapsedMilliseconds, stopwatchBwd.ElapsedMilliseconds);

#if TIMING_LAYERS
                Console.WriteLine("\n Detailed runtimes::");

                Console.WriteLine("\nCONV: \n\tForward: {0}ms \n\tBackprop: {1}ms \n\tUpdateSpeeds: {2}ms \n\tUpdateParameters: {3}ms \n\tPadUnpad: {4}ms",
                                  Utils.ConvForwardTimer.ElapsedMilliseconds, Utils.ConvBackpropTimer.ElapsedMilliseconds,
                                  Utils.ConvUpdateSpeedsTimer.ElapsedMilliseconds, Utils.ConvUpdateParametersTimer.ElapsedMilliseconds, Utils.ConvPadUnpadTimer.ElapsedMilliseconds);

                Console.WriteLine("\nPOOLING: \n\tForward: {0}ms \n\tBackprop: {1}ms",
                                  Utils.PoolingForwardTimer.ElapsedMilliseconds, Utils.PoolingBackpropTimer.ElapsedMilliseconds);

                Console.WriteLine("\nNONLINEARITIES: \n\tForward: {0}ms \n\tBackprop: {1}ms",
                                  Utils.NonlinearityForwardTimer.ElapsedMilliseconds, Utils.NonlinearityBackpropTimer.ElapsedMilliseconds);

                Console.WriteLine("\nFULLY CONNECTED: \n\tForward: {0}ms \n\tBackprop: {1}ms \n\tUpdateSpeeds: {2}ms \n\tUpdateParameters: {3}ms",
                                  Utils.FCForwardTimer.ElapsedMilliseconds, Utils.FCBackpropTimer.ElapsedMilliseconds,
                                  Utils.FCUpdateSpeedsTimer.ElapsedMilliseconds, Utils.FCUpdateParametersTimer.ElapsedMilliseconds);

                Console.WriteLine("\nBATCHNORM FC \n\tForward: {0}ms \n\tBackprop: {1}ms \n\tUpdateSpeeds: {2}ms \n\tUpdateParameters: {3}ms",
                                  Utils.BNFCForwardTimer.ElapsedMilliseconds, Utils.BNFCBackpropTimer.ElapsedMilliseconds,
                                  Utils.BNFCUpdateSpeedsTimer.ElapsedMilliseconds, Utils.BNFCUpdateParametersTimer.ElapsedMilliseconds);

                Console.WriteLine("\nBATCHNORM CONV \n\tForward: {0}ms \n\tBackprop: {1}ms \n\tUpdateSpeeds: {2}ms \n\tUpdateParameters: {3}ms",
                                  Utils.BNConvForwardTimer.ElapsedMilliseconds, Utils.BNConvBackpropTimer.ElapsedMilliseconds,
                                  Utils.BNConvUpdateSpeedsTimer.ElapsedMilliseconds, Utils.BNConvUpdateParametersTimer.ElapsedMilliseconds);

                Console.WriteLine("\nSOFTMAX \n\tForward: {0}ms", Utils.SoftmaxTimer.ElapsedMilliseconds);

                Utils.ResetTimers();
#endif
            }

            stopwatch.Stop();
        }
Exemplo n.º 3
0
        /// <summary>
        /// Gradient check: compares the gradients computed by backpropagation with central finite-difference
        /// estimates of the mini-batch cross-entropy loss, for every layer whose <c>Type</c> equals
        /// <c>typeToCheck</c>. Each such layer is checked first with respect to its parameters, then with
        /// respect to its input. Results are printed to the console, and the method waits for a key press
        /// after each report.
        /// </summary>
        /// <remarks>
        /// Uses the class-level settings <c>miniBatchSize</c>, <c>EPSILON</c>, <c>MAX_RELATIVE_ERROR</c> and
        /// <c>typeToCheck</c>. The network's parameters are re-initialized with <c>"random"</c>.
        /// </remarks>
        /// <param name="network">Network whose gradients are checked.</param>
        /// <param name="dataSet">Data set from which a shuffled mini-batch is drawn.</param>
        public static void Check(NeuralNetwork network, DataSet dataSet)
        {
            // Setup network
            network.Set("MiniBatchSize", miniBatchSize);
            network.InitializeParameters("random");
            network.Set("DropoutFC", 1.0);
            network.Set("Training", true);
            network.Set("EpochBeginning", true);

            // Get a mini-batch of data
            Sequence indicesSequence = new Sequence(dataSet.DataContainer.Count);
            indicesSequence.Shuffle();
            int[] miniBatch = indicesSequence.GetMiniBatchIndices(0, miniBatchSize);

            // Run network forward and backward to obtain the backprop gradients
            network.InputLayer.FeedData(dataSet, miniBatch);
            network.ForwardPass("beginning", "end");

            List<int> trueLabels = new List<int>();
            for (int m = 0; m < miniBatchSize; m++)
            {
                trueLabels.Add(dataSet.DataContainer[miniBatch[m]].Label);
            }
            network.CrossEntropyGradient(dataSet, miniBatch);
            network.BackwardPass(0.0, 0.0, 0.0, 1e10); // no momentum, no learning rate, no weight decay

            // Re-forward pass (in case there are batch-norm layers), then switch to inference mode
            // for the perturbed forward passes performed below
            network.Set("PreInference", true);
            network.ForwardPass("beginning", "end");
            network.Set("Inference", true);

            for (int iLayer = 1; iLayer < network.NumberOfLayers; iLayer++)
            {
                if (network.Layers[iLayer].Type != typeToCheck)
                {
                    continue;
                }

                var layer = network.Layers[iLayer];
                Console.WriteLine("\nChecking gradients in layer {0} ({1})...", iLayer, layer.Type);

                // First parameters
                double[] parametersBackup   = layer.GetParameters();
                double[] parameterGradients = layer.GetParameterGradients();
                Console.WriteLine("\n...with respect to PARAMETERS");
                CheckGradientsByFiniteDifferences(network, iLayer, trueLabels, "parameter",
                                                  parametersBackup, parameterGradients, 1.0,
                                                  values => layer.SetParameters(values));

                // Now inputs
                double[] inputBackup    = layer.GetInput();
                double[] inputGradients = layer.GetInputGradients();
                Console.WriteLine("\n...with respect to INPUT");
                // Backprop input gradients are divided by miniBatchSize because HERE the loss is defined as Loss / miniBatchSize
                CheckGradientsByFiniteDifferences(network, iLayer, trueLabels, "input",
                                                  inputBackup, inputGradients, miniBatchSize,
                                                  values => layer.SetInput(values));
            }
        }

        /// <summary>
        /// Perturbs each entry of <paramref name="backup"/> by -EPSILON and +EPSILON. After each
        /// perturbation it writes the values back through <paramref name="setValues"/>, runs the network
        /// forward from layer <paramref name="iLayer"/>, and evaluates the loss. It then compares the
        /// central-difference gradient with <c>backpropGradients[j] / backpropScale</c>.
        /// Failures and a summary are printed, and the method waits for a key press at the end.
        /// </summary>
        /// <param name="network">Network being checked.</param>
        /// <param name="iLayer">Index of the layer that owns the perturbed values.</param>
        /// <param name="trueLabels">True class of each mini-batch example.</param>
        /// <param name="what">Singular noun used in messages ("parameter" or "input").</param>
        /// <param name="backup">Unperturbed values; restored after every perturbation.</param>
        /// <param name="backpropGradients">Gradients produced by backpropagation, one per value.</param>
        /// <param name="backpropScale">Divisor applied to each backprop gradient before comparison.</param>
        /// <param name="setValues">Writes a full value array into the layer.</param>
        private static void CheckGradientsByFiniteDifferences(NeuralNetwork network, int iLayer, List<int> trueLabels,
                                                              string what, double[] backup, double[] backpropGradients,
                                                              double backpropScale, Action<double[]> setValues)
        {
            int    nValues         = backup.Length;
            int    nChecks         = 0;
            int    nErrors         = 0;
            double cumulativeError = 0.0;

            for (int j = 0; j < nValues; j++)
            {
                // decrease jth value by EPSILON, then run network forward and compute loss
                double[] valuesMinus = new double[nValues];
                Array.Copy(backup, valuesMinus, nValues);
                valuesMinus[j] -= EPSILON;
                setValues(valuesMinus);
                network.ForwardPass(iLayer, "end");
                double lossMinus = GradientCheckLoss(network, trueLabels);

                // increase jth value by EPSILON, then run network forward and compute loss
                double[] valuesPlus = new double[nValues];
                Array.Copy(backup, valuesPlus, nValues);
                valuesPlus[j] += EPSILON;
                setValues(valuesPlus);
                network.ForwardPass(iLayer, "end");
                double lossPlus = GradientCheckLoss(network, trueLabels);

                // compare the central-difference gradient with the one computed by backprop
                double gradientNumerical = (lossPlus - lossMinus) / (2 * EPSILON);
                double gradientBackprop  = backpropGradients[j] / backpropScale;

                nChecks++;
                double relativeError = GradientCheckRelativeError(gradientNumerical, gradientBackprop);
                if (relativeError > MAX_RELATIVE_ERROR)
                {
                    Console.Write("\nGradient check failed for " + what + " {0}\n", j);
                    Console.WriteLine("\tBackpropagation gradient: {0}", gradientBackprop);
                    Console.WriteLine("\tFinite difference gradient: {0}", gradientNumerical);
                    Console.WriteLine("\tRelative error: {0}", relativeError);

                    nErrors++;
                }
                // running average of the relative error
                cumulativeError = (relativeError + (nChecks - 1) * cumulativeError) / nChecks;

                // restore original values before checking the next gradient
                setValues(backup);
            }

            // The last forward pass ran with a perturbed value, so every layer after iLayer still holds
            // perturbed activations. Recompute them from the restored values so later checks start from
            // the true baseline.
            network.ForwardPass(iLayer, "end");

            if (nChecks == 0)
            {
                Console.Write("\nNothing to check: this layer has no {0}s.", what);
            }
            else if (nErrors == 0)
            {
                Console.Write("\nGradient check 100% passed!");
                Console.Write("\nAverage error = {0}", cumulativeError);
            }
            else
            {
                Console.Write("\n{0} errors out of {1} checks.", nErrors, nChecks);
                Console.Write("\nAverage error = {0}", cumulativeError);
            }
            Console.Write("\n\n");
            Console.Write("Press any key to continue...");
            Console.Write("\n\n");
            Console.ReadKey();
        }

        /// <summary>
        /// Mean cross-entropy loss of the current output scores over the mini-batch:
        /// -(1/miniBatchSize) * sum over m of log(score of the true class of example m).
        /// </summary>
        private static double GradientCheckLoss(NeuralNetwork network, List<int> trueLabels)
        {
            List<double[]> outputClassScores = network.OutputLayer.OutputClassScores;
            double loss = 0.0;
            for (int m = 0; m < miniBatchSize; m++)
            {
                loss -= Math.Log(outputClassScores[m][trueLabels[m]]); // score of true class in example m
            }
            return loss / miniBatchSize;
        }

        /// <summary>
        /// Relative error |a - b| / max(|a|, |b|). Returns 0 when both values are exactly zero, which
        /// would otherwise give 0/0 = NaN and corrupt the running average. NaN inputs still yield NaN.
        /// </summary>
        private static double GradientCheckRelativeError(double a, double b)
        {
            double scale = Math.Max(Math.Abs(a), Math.Abs(b));
            return scale == 0.0 ? 0.0 : Math.Abs(a - b) / scale;
        }