Ejemplo n.º 1
0
        /// <summary>
        /// Trains <paramref name="network"/> on <paramref name="trainingSet"/> with mini-batch gradient descent,
        /// periodically evaluating it and annealing the learning rate when the validation loss stops improving.
        /// </summary>
        /// <remarks>
        /// <para>
        /// Configuration is read from static fields/properties declared outside this method (trainingMode,
        /// miniBatchSize, dropout rates, learningRate, patience, ...). Note that <c>learningRate</c> is divided
        /// in place on every annealing, so the reduced value persists after this method returns.
        /// </para>
        /// <para>
        /// An evaluation pass runs every <c>consoleOutputLag</c> epochs (and once before the first epoch when
        /// <c>evaluateBeforeTraining</c> is true). It appends a "loss\terror" line to
        /// <c>trainingEpochSavePath</c> and, when a validation set is given, to <c>validationEpochSavePath</c>.
        /// A new best validation loss saves the network to <c>networkOutputFilePath</c>. After <c>patience</c>
        /// evaluations without improvement, the learning rate is divided by <c>learningRateDecayFactor</c> and
        /// the network is reloaded from disk.
        /// </para>
        /// <para>
        /// Training stops after <c>maxTrainingEpochs</c> epochs, when patience runs out after
        /// <c>maxConsecutiveAnnealings</c> annealings in a row without progress, or when
        /// <c>CheckForKeyPress</c> raises the stop flag.
        /// </para>
        /// </remarks>
        /// <param name="network">Network to train. The reload performed on annealing only replaces this method's
        /// local reference (the parameter is not passed by ref), so the caller keeps its original object.</param>
        /// <param name="trainingSet">Examples used for the gradient updates.</param>
        /// <param name="validationSet">Optional. When null, no checkpointing, patience tracking or annealing
        /// happens.</param>
        /// <exception cref="InvalidOperationException"><c>trainingMode</c> is neither "new" nor "resume"
        /// (compared case-insensitively).</exception>
        public static void Train(NeuralNetwork network, DataSet trainingSet, DataSet validationSet)
        {
            // Initialize parameters or load them (mode names are matched case-insensitively)
            if (string.Equals(trainingMode, "new", StringComparison.OrdinalIgnoreCase))
            {
                network.Set("MiniBatchSize", miniBatchSize);
                network.InitializeParameters("random");
            }
            else if (string.Equals(trainingMode, "resume", StringComparison.OrdinalIgnoreCase))
            {
                // Mini-batches below always have miniBatchSize entries, so a resumed network must use the same
                // size. The reload-after-annealing path further down sets it in the same way before loading.
                network.Set("MiniBatchSize", miniBatchSize);
                network.InitializeParameters("load");
            }
            else
            {
                throw new InvalidOperationException("Please set TrainingMode to either ''New'' or ''Resume''.");
            }

            // Set dropout
            network.Set("DropoutFC", dropoutFC);
            network.Set("DropoutConv", dropoutConv);
            network.Set("DropoutInput", dropoutInput);

            Sequence indicesSequence = new Sequence(trainingSet.DataContainer.Count);

            // Timers
            Stopwatch stopwatch     = Stopwatch.StartNew();
            Stopwatch stopwatchFwd  = Stopwatch.StartNew();
            Stopwatch stopwatchGrad = Stopwatch.StartNew();
            Stopwatch stopwatchBwd  = Stopwatch.StartNew();

            int  epoch      = 0;
            int  nBadEpochs = 0;                  // evaluations in a row without a new best validation loss
            int  consecutiveAnnealingCounter = 0; // annealings in a row without a new best validation loss
            bool stopFlag = false;
            // 0 means "evaluate right away, before the first training epoch"
            int  epochsRemainingToOutput = (evaluateBeforeTraining == true) ? 0 : consoleOutputLag;

            while (!stopFlag) // begin loop over training epochs
            {
                if (epochsRemainingToOutput == 0)
                {
                    /**************
                    * Evaluation *
                    **************/

                    // Pre inference (for batch-norm)
                    //network.Set("PreInference", true);
                    //Console.WriteLine("Re-computing batch-norm means and variances...");
                    //NetworkEvaluator.PreEvaluateNetwork(network, trainingSet);

                    // Evaluate on training set...
                    network.Set("Inference", true);
                    Console.WriteLine("Evaluating on TRAINING set...");
                    stopwatch.Restart();
                    NetworkEvaluator.EvaluateNetwork(network, trainingSet, out lossTraining, out errorTraining);
                    Console.WriteLine("\tLoss = {0}\n\tError = {1}\n\tEval runtime = {2}ms\n",
                                      lossTraining, errorTraining, stopwatch.ElapsedMilliseconds);
                    // ...and append loss and error to file
                    using (System.IO.StreamWriter trainingEpochOutputFile = new System.IO.StreamWriter(trainingEpochSavePath, true))
                    {
                        trainingEpochOutputFile.WriteLine(lossTraining.ToString() + "\t" + errorTraining.ToString());
                    }

                    // Evaluate on validation set...
                    if (validationSet != null)
                    {
                        Console.WriteLine("Evaluating on VALIDATION set...");
                        stopwatch.Restart();
                        NetworkEvaluator.EvaluateNetwork(network, validationSet, out newLossValidation, out newErrorValidation);
                        Console.WriteLine("\tLoss = {0}\n\tError = {1}\n\tEval runtime = {2}ms\n",
                                          newLossValidation, newErrorValidation, stopwatch.ElapsedMilliseconds);
                        // ...append loss and error to file
                        using (System.IO.StreamWriter validationEpochOutputFile = new System.IO.StreamWriter(validationEpochSavePath, true))
                        {
                            validationEpochOutputFile.WriteLine(newLossValidation.ToString() + "\t" + newErrorValidation.ToString());
                        }

                        if (newLossValidation < minLossValidation)
                        {
                            // nice, validation loss is decreasing!
                            minLossValidation = newLossValidation;
                            errorValidation   = newErrorValidation;

                            // Save network to file
                            Utils.SaveNetworkToFile(network, networkOutputFilePath);

                            // and keep training
                            nBadEpochs = 0;
                            consecutiveAnnealingCounter = 0;
                        }
                        else
                        {
                            nBadEpochs++;
                            Console.WriteLine("Loss on the validation set has been increasing for {0} epoch(s)...", nBadEpochs);
                            if (patience - nBadEpochs > 0)
                            {
                                Console.WriteLine("...I'll be patient for {0} more epoch(s)!", patience - nBadEpochs); // keep training
                            }
                            else
                            {
                                Console.WriteLine("...and I've run out of patience!");

                                // Give up once maxConsecutiveAnnealings annealings in a row have not helped
                                // (">=": the counter equals the number of annealings already performed).
                                if (consecutiveAnnealingCounter >= maxConsecutiveAnnealings)
                                {
                                    Console.WriteLine("\nReached the number of maximum consecutive annealings without progress. \nTraining ends here.");
                                    break; // leaves the epoch loop
                                }

                                // Decrease learning rate
                                Console.WriteLine("\nI'm annealing the learning rate:\n\tWas {0}\n\tSetting it to {1}.", learningRate, learningRate / learningRateDecayFactor);
                                learningRate /= learningRateDecayFactor;
                                consecutiveAnnealingCounter++;

                                Console.WriteLine("\nAnd I'm loading the network saved {0} epochs ago and resume the training from there.", patience);

                                string networkName = network.Name;
                                // NOTE(review): dropping the local reference and forcing a collection is meant to free
                                // the old network before loading a new one. Because 'network' is a by-value
                                // parameter, the caller may still hold the old object, in which case nothing is freed.
                                network = null; // this is BAD PRACTICE
                                GC.Collect();   // this is BAD PRACTICE
                                // NOTE(review): this directory is hard-coded, while the save above uses
                                // networkOutputFilePath. Confirm both point to the same file.
                                network = Utils.LoadNetworkFromFile("../../../../Results/Networks/", networkName);
                                network.Set("MiniBatchSize", miniBatchSize);
                                network.InitializeParameters("load");

                                nBadEpochs = 0;
                            }
                        }
                    }

                    // Restore dropout
                    network.Set("DropoutFC", dropoutFC);
                    network.Set("DropoutConv", dropoutConv);
                    network.Set("DropoutInput", dropoutInput);

                    epochsRemainingToOutput = consoleOutputLag;
                }
                epochsRemainingToOutput--;

                epoch++;

                if (epoch > maxTrainingEpochs)
                {
                    break;
                }

                /************
                * Training *
                ************/

                network.Set("Training", true);
                network.Set("EpochBeginning", true);

                Console.WriteLine("\nEpoch {0}...", epoch);

                stopwatch.Restart();
                stopwatchFwd.Reset();
                stopwatchGrad.Reset();
                stopwatchBwd.Reset();

                indicesSequence.Shuffle(); // shuffle examples order at every epoch

                // Run over mini-batches
                for (int iStartMiniBatch = 0; iStartMiniBatch < trainingSet.DataContainer.Count; iStartMiniBatch += miniBatchSize)
                {
                    // Feed a mini-batch to the network
                    int[] miniBatch = indicesSequence.GetMiniBatchIndices(iStartMiniBatch, miniBatchSize);
                    network.InputLayer.FeedData(trainingSet, miniBatch);

                    // Forward pass
                    stopwatchFwd.Start();
                    network.ForwardPass("beginning", "end");
                    stopwatchFwd.Stop();

                    // Compute gradient of the loss at the output
                    stopwatchGrad.Start();
                    network.CrossEntropyGradient(trainingSet, miniBatch);
                    stopwatchGrad.Stop();

                    // Backpropagate gradient and update parameters
                    stopwatchBwd.Start();
                    network.BackwardPass(learningRate, momentumCoefficient, weightDecayCoeff, weightMaxNorm);
                    stopwatchBwd.Stop();

                    CheckForKeyPress(ref network, ref stopFlag);
                    if (stopFlag)
                    {
                        break; // the outer while condition then ends training
                    }
                } // end of training epoch

                Console.Write(" Training runtime = {0}ms\n", stopwatch.ElapsedMilliseconds);

                Console.WriteLine("Forward: {0}ms - Gradient: {1}ms - Backward: {2}ms\n",
                                  stopwatchFwd.ElapsedMilliseconds, stopwatchGrad.ElapsedMilliseconds, stopwatchBwd.ElapsedMilliseconds);

#if TIMING_LAYERS
                Console.WriteLine("\n Detailed runtimes::");

                Console.WriteLine("\nCONV: \n\tForward: {0}ms \n\tBackprop: {1}ms \n\tUpdateSpeeds: {2}ms \n\tUpdateParameters: {3}ms \n\tPadUnpad: {4}ms",
                                  Utils.ConvForwardTimer.ElapsedMilliseconds, Utils.ConvBackpropTimer.ElapsedMilliseconds,
                                  Utils.ConvUpdateSpeedsTimer.ElapsedMilliseconds, Utils.ConvUpdateParametersTimer.ElapsedMilliseconds, Utils.ConvPadUnpadTimer.ElapsedMilliseconds);

                Console.WriteLine("\nPOOLING: \n\tForward: {0}ms \n\tBackprop: {1}ms",
                                  Utils.PoolingForwardTimer.ElapsedMilliseconds, Utils.PoolingBackpropTimer.ElapsedMilliseconds);

                Console.WriteLine("\nNONLINEARITIES: \n\tForward: {0}ms \n\tBackprop: {1}ms",
                                  Utils.NonlinearityForwardTimer.ElapsedMilliseconds, Utils.NonlinearityBackpropTimer.ElapsedMilliseconds);

                Console.WriteLine("\nFULLY CONNECTED: \n\tForward: {0}ms \n\tBackprop: {1}ms \n\tUpdateSpeeds: {2}ms \n\tUpdateParameters: {3}ms",
                                  Utils.FCForwardTimer.ElapsedMilliseconds, Utils.FCBackpropTimer.ElapsedMilliseconds,
                                  Utils.FCUpdateSpeedsTimer.ElapsedMilliseconds, Utils.FCUpdateParametersTimer.ElapsedMilliseconds);

                Console.WriteLine("\nBATCHNORM FC \n\tForward: {0}ms \n\tBackprop: {1}ms \n\tUpdateSpeeds: {2}ms \n\tUpdateParameters: {3}ms",
                                  Utils.BNFCForwardTimer.ElapsedMilliseconds, Utils.BNFCBackpropTimer.ElapsedMilliseconds,
                                  Utils.BNFCUpdateSpeedsTimer.ElapsedMilliseconds, Utils.BNFCUpdateParametersTimer.ElapsedMilliseconds);

                Console.WriteLine("\nBATCHNORM CONV \n\tForward: {0}ms \n\tBackprop: {1}ms \n\tUpdateSpeeds: {2}ms \n\tUpdateParameters: {3}ms",
                                  Utils.BNConvForwardTimer.ElapsedMilliseconds, Utils.BNConvBackpropTimer.ElapsedMilliseconds,
                                  Utils.BNConvUpdateSpeedsTimer.ElapsedMilliseconds, Utils.BNConvUpdateParametersTimer.ElapsedMilliseconds);

                Console.WriteLine("\nSOFTMAX \n\tForward: {0}ms", Utils.SoftmaxTimer.ElapsedMilliseconds);

                Utils.ResetTimers();
#endif
            }

            stopwatch.Stop();
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Numerically verifies the backpropagated gradients of every layer whose <c>Type</c> equals
        /// <c>typeToCheck</c>: first with respect to the layer's parameters, then with respect to its input.
        /// </summary>
        /// <remarks>
        /// <para>
        /// The network is initialized with random parameters and run forward and backward once on a random
        /// mini-batch. The backward pass uses zero learning rate, momentum and weight decay.
        /// </para>
        /// <para>
        /// Each parameter/input entry is then shifted by -EPSILON and +EPSILON. The central-difference estimate
        /// of the mini-batch-averaged cross-entropy loss is compared with the backprop gradient, and entries
        /// whose relative error exceeds MAX_RELATIVE_ERROR are reported.
        /// </para>
        /// <para>
        /// This is interactive: it blocks on <see cref="Console.ReadKey()"/> after each report.
        /// </para>
        /// </remarks>
        /// <param name="network">Network whose gradients are checked. Its parameters are re-initialized.</param>
        /// <param name="dataSet">Data set the mini-batch is drawn from.</param>
        public static void Check(NeuralNetwork network, DataSet dataSet)
        {
            // Setup network
            network.Set("MiniBatchSize", miniBatchSize);
            network.InitializeParameters("random");
            network.Set("DropoutFC", 1.0);
            network.Set("Training", true);
            network.Set("EpochBeginning", true);

            // Get a random mini-batch of data
            Sequence indicesSequence = new Sequence(dataSet.DataContainer.Count);
            indicesSequence.Shuffle();
            int[] miniBatch = indicesSequence.GetMiniBatchIndices(0, miniBatchSize);

            // Run network forward and backward
            network.InputLayer.FeedData(dataSet, miniBatch);
            network.ForwardPass("beginning", "end");

            List<int> trueLabels = new List<int>();
            for (int m = 0; m < miniBatchSize; m++)
            {
                trueLabels.Add(dataSet.DataContainer[miniBatch[m]].Label);
            }
            network.CrossEntropyGradient(dataSet, miniBatch);
            // No learning rate, no momentum, no weight decay. 1e10 presumably disables the max-norm constraint.
            network.BackwardPass(0.0, 0.0, 0.0, 1e10);

            // Re-forward pass (in case there are batch-norm layers)
            network.Set("PreInference", true);
            network.ForwardPass("beginning", "end");
            network.Set("Inference", true);

            for (int iLayer = 1; iLayer < network.NumberOfLayers; iLayer++)
            {
                var layer = network.Layers[iLayer];
                if (layer.Type != typeToCheck)
                {
                    continue;
                }

                Console.WriteLine("\nChecking gradients in layer {0} ({1})...", iLayer, layer.Type);

                // First parameters: backprop parameter gradients are compared as they are.
                double[] parametersBackup   = layer.GetParameters();
                double[] parameterGradients = layer.GetParameterGradients();
                Console.WriteLine("\n...with respect to PARAMETERS");
                CompareWithFiniteDifferences(network, iLayer, trueLabels, parametersBackup, parameterGradients, 1.0,
                                             values => layer.SetParameters(values), "parameter");

                // Now inputs: the backprop input gradient is divided by miniBatchSize because the loss used
                // for the finite differences is averaged over the mini-batch.
                double[] inputBackup    = layer.GetInput();
                double[] inputGradients = layer.GetInputGradients();
                Console.WriteLine("\n...with respect to INPUT");
                CompareWithFiniteDifferences(network, iLayer, trueLabels, inputBackup, inputGradients, miniBatchSize,
                                             values => layer.SetInput(values), "input");
            }
        }

        /// <summary>
        /// Central finite-difference check of one array belonging to layer <paramref name="iLayer"/> (its
        /// parameters or its input). For each entry, it sets the array with that entry shifted by -EPSILON and
        /// then by +EPSILON, re-runs the network from the layer to the end, and compares the resulting gradient
        /// estimate with the backprop gradient. It then prints a summary and waits for a key press.
        /// </summary>
        /// <param name="backup">Original values. They are restored through <paramref name="setValues"/> after
        /// each entry.</param>
        /// <param name="backpropGradients">Backprop gradients, index-aligned with <paramref name="backup"/>.</param>
        /// <param name="backpropDivisor">Divisor applied to each backprop gradient before comparison.</param>
        /// <param name="setValues">Writes a full array of values back into the layer.</param>
        /// <param name="entryName">Word used in failure messages ("parameter" or "input").</param>
        private static void CompareWithFiniteDifferences(NeuralNetwork network, int iLayer, List<int> trueLabels,
                                                         double[] backup, double[] backpropGradients, double backpropDivisor,
                                                         Action<double[]> setValues, string entryName)
        {
            int    nChecks         = 0;
            int    nErrors         = 0;
            double cumulativeError = 0.0; // running mean of the relative errors

            for (int j = 0; j < backup.Length; j++)
            {
                // decrease jth entry by EPSILON, then run network forward and compute loss
                double[] valuesMinus = (double[])backup.Clone();
                valuesMinus[j] -= EPSILON;
                setValues(valuesMinus);
                network.ForwardPass(iLayer, "end");
                double lossMinus = MiniBatchCrossEntropy(network, trueLabels);

                // increase jth entry by EPSILON, then run network forward and compute loss
                double[] valuesPlus = (double[])backup.Clone();
                valuesPlus[j] += EPSILON;
                setValues(valuesPlus);
                network.ForwardPass(iLayer, "end");
                double lossPlus = MiniBatchCrossEntropy(network, trueLabels);

                // compare the numerical gradient with the one computed by backprop
                double gradientNumerical = (lossPlus - lossMinus) / (2 * EPSILON);
                double gradientBackprop  = backpropGradients[j] / backpropDivisor;

                nChecks++;
                double relativeError = SafeRelativeError(gradientNumerical, gradientBackprop);
                if (relativeError > MAX_RELATIVE_ERROR)
                {
                    Console.Write("\nGradient check failed for {0} {1}\n", entryName, j);
                    Console.WriteLine("\tBackpropagation gradient: {0}", gradientBackprop);
                    Console.WriteLine("\tFinite difference gradient: {0}", gradientNumerical);
                    Console.WriteLine("\tRelative error: {0}", relativeError);

                    nErrors++;
                }
                cumulativeError = (relativeError + (nChecks - 1) * cumulativeError) / nChecks;

                // restore original values before checking the next entry
                setValues(backup);
            }

            // nChecks counts every entry, so this branch is only reached for an empty array
            // (e.g. a layer without parameters).
            if (nChecks == 0)
            {
                Console.Write("\nAll gradients are zero... Something is probably wrong!");
            }
            else if (nErrors == 0)
            {
                Console.Write("\nGradient check 100% passed!");
                Console.Write("\nAverage error = {0}", cumulativeError);
            }
            else
            {
                Console.Write("\n{0} errors out of {1} checks.", nErrors, nChecks);
                Console.Write("\nAverage error = {0}", cumulativeError);
            }
            Console.Write("\n\n");
            Console.Write("Press any key to continue...");
            Console.Write("\n\n");
            Console.ReadKey();
        }

        /// <summary>
        /// Mean over the first miniBatchSize examples of -log(score of the true class), read from the network's
        /// current output class scores.
        /// </summary>
        private static double MiniBatchCrossEntropy(NeuralNetwork network, List<int> trueLabels)
        {
            List<double[]> outputClassScores = network.OutputLayer.OutputClassScores;
            double loss = 0;
            for (int m = 0; m < miniBatchSize; m++)
            {
                loss -= Math.Log(outputClassScores[m][trueLabels[m]]); // score of true class in example m
            }
            return loss / miniBatchSize;
        }

        /// <summary>
        /// Relative error |a - b| / max(|a|, |b|). Returns 0 when both values are exactly zero, instead of the
        /// 0/0 = NaN that would otherwise go unflagged and poison the running average error.
        /// </summary>
        private static double SafeRelativeError(double a, double b)
        {
            double scale = Math.Max(Math.Abs(a), Math.Abs(b));
            if (scale == 0.0) // exact comparison intended: only the both-exactly-zero case is special
            {
                return 0.0;
            }
            return Math.Abs(a - b) / scale;
        }