Example #1
        ////////////////////////////////////////////////////////////////////////////////////////////////////
        /// <summary>   Runs one training step (forward pass, loss, backward pass, optional update) on a batch. </summary>
        ///
        /// <param name="functionStack">    Stack of functions. This cannot be null. </param>
        /// <param name="input">            The input batch. This may be null. </param>
        /// <param name="teach">            The target (teacher) batch. This may be null. </param>
        /// <param name="lossFunction">     The loss function. This cannot be null. </param>
        /// <param name="isUpdate">         (Optional) True to apply the optimizer update after the backward pass. </param>
        /// <param name="verbose">          (Optional) True to emit debug logging. </param>
        ///
        /// <returns>   The summed loss for the batch as a Real. </returns>
        ////////////////////////////////////////////////////////////////////////////////////////////////////

        public static Real Train([NotNull] SortedFunctionStack functionStack, [CanBeNull] NdArray input,
                                 [CanBeNull] NdArray teach, [NotNull] LossFunction lossFunction, bool isUpdate = true,
                                 bool verbose = true)
        {
            if (verbose)
            {
                RILogManager.Default?.EnterMethod("Training " + functionStack.Name);
            }
            // Forward pass; the result is kept so the loss can be evaluated against the teacher signal.
            if (verbose)
            {
                RILogManager.Default?.SendDebug("Forward propagation");
            }
            NdArray[] result = functionStack.Forward(verbose, input);
            if (verbose)
            {
                RILogManager.Default?.SendDebug("Evaluating loss");
            }
            Real sumLoss = lossFunction.Evaluate(result, teach);

            // Backward pass: propagate the loss gradients through the stack.
            if (verbose)
            {
                RILogManager.Default?.SendDebug("Backward propagation");
            }
            functionStack.Backward(verbose, result);

            if (isUpdate)
            {
                if (verbose)
                {
                    RILogManager.Default?.SendDebug("Updating stack");
                }
                functionStack.Update();
            }

            if (verbose)
            {
                RILogManager.Default?.ExitMethod("Training " + functionStack.Name);
                RILogManager.Default?.ViewerSendWatch("Local Loss", sumLoss.ToString(), sumLoss);
            }

            return sumLoss;
        }
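
A quick usage note: Example #4 below invokes this method as Trainer.Train(...). A minimal call for a single batch, assuming a SortedFunctionStack and a labelled batch (an object exposing NdArray Data/Label properties) are already in hand, might look like this:

        // Minimal usage sketch, based on the call site in Example #4.
        // "functionStack" and "batch" are placeholders for an existing network and data batch.
        Real loss = Trainer.Train(functionStack, batch.Data, batch.Label,
                                  new SoftmaxCrossEntropy(),   // loss function
                                  isUpdate: true,              // apply the optimizer after the backward pass
                                  verbose: false);             // suppress per-step logging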
Example #2
        ////////////////////////////////////////////////////////////////////////////////////////////////////
        /// <summary>   Computes the classification accuracy of the stack on a labelled batch. </summary>
        ///
        /// <param name="functionStack">    Stack of functions. This cannot be null. </param>
        /// <param name="x">                The input batch. This cannot be null. </param>
        /// <param name="y">                The expected labels. This cannot be null. </param>
        /// <param name="verbose">          (Optional) True to emit debug logging. </param>
        ///
        /// <returns>   The fraction of correctly classified samples as a double. </returns>
        ////////////////////////////////////////////////////////////////////////////////////////////////////

        public static double Accuracy([NotNull] SortedFunctionStack functionStack, [NotNull] NdArray x, [NotNull] NdArray y,
                                      bool verbose = true)
        {
            double    matchCount = 0;
            Stopwatch sw         = new Stopwatch();

            sw.Start();

            if (verbose)
            {
                RILogManager.Default?.SendDebug("Running Forecast for Accuracy Prediction on " + functionStack.Name);
            }
            NdArray forwardResult = functionStack.Predict(verbose, x)[0];

            for (int b = 0; b < x.BatchCount; b++)
            {
                // Arg-max over the class scores for sample b.
                Real maxval   = forwardResult.Data[b * forwardResult.Length];
                int  maxindex = 0;

                for (int i = 0; i < forwardResult.Length; i++)
                {
                    if (maxval < forwardResult.Data[i + b * forwardResult.Length])
                    {
                        maxval   = forwardResult.Data[i + b * forwardResult.Length];
                        maxindex = i;
                    }
                }

                if (maxindex == (int)y.Data[b * y.Length])
                {
                    matchCount++;
                }
            }

            sw.Stop();
            if (verbose)
            {
                RILogManager.Default?.SendDebug("Accuracy Prediction took " + Helpers.FormatTimeSpan(sw.Elapsed) + "ms");
                RILogManager.Default?.ViewerSendWatch("Accuracy", ((matchCount / x.BatchCount) * 100).ToString() + "%",
                                                      matchCount / x.BatchCount);
            }

            return matchCount / x.BatchCount;
        }
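
As with Train, Example #4 calls this method through a Trainer class. A minimal usage sketch, assuming a trained SortedFunctionStack and a held-out labelled batch:

        // Minimal usage sketch, mirroring the call site in Example #4.
        // "functionStack" and "testBatch" are placeholders for an existing network and test batch.
        double accuracy = Trainer.Accuracy(functionStack, testBatch.Data, testBatch.Label, verbose: false);
        RILogManager.Default?.SendDebug("Test accuracy: " + (accuracy * 100) + "%");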
Example #3
        ////////////////////////////////////////////////////////////////////////////////////////////////////
        /// <summary>   Serializes a function stack to a file using NetDataContractSerializer. </summary>
        ///
        /// <param name="functionStack">    Stack of functions. This cannot be null. </param>
        /// <param name="fileName">         Path of the file to write. This cannot be null or whitespace. </param>
        ////////////////////////////////////////////////////////////////////////////////////////////////////

        public static void Save([NotNull] SortedFunctionStack functionStack, [NotNull] string fileName)
        {
            Ensure.Argument(fileName).NotNullOrWhiteSpace("fileName is null");
            Ensure.Argument(functionStack).NotNull("functionStack is null");

            NetDataContractSerializer bf = new NetDataContractSerializer();

            RILogManager.Default?.SendDebug("Saving model " + functionStack.Name + " to " + fileName);

            try
            {
                // File.Create truncates any existing file, so overwriting with a smaller model leaves no stale bytes.
                using (Stream stream = File.Create(fileName))
                {
                    bf.Serialize(stream, functionStack);
                }

                // Report the size of the file just written; FileInfo avoids leaking an open stream.
                long sizeBytes = new FileInfo(fileName).Length;
                RILogManager.Default?.SendDebug("Model " + functionStack.Name + " saved, size is " + sizeBytes.ToString("N0") + " bytes");
            }
            catch (Exception ex)
            {
                RILogManager.Default?.SendException(ex.Message, ex);
            }
        }
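
Only serialization is shown here; Example #4 calls this method as ModelIO.Save. A hypothetical Load counterpart (an assumption, not part of the source) would simply reverse the process with NetDataContractSerializer.Deserialize:

        ////////////////////////////////////////////////////////////////////////////////////////////////////
        /// <summary>   Hypothetical counterpart to Save: loads a model serialized with NetDataContractSerializer. </summary>
        ///
        /// <param name="fileName"> Path of a file previously written by Save. This cannot be null. </param>
        ///
        /// <returns>   The deserialized function stack. </returns>
        ////////////////////////////////////////////////////////////////////////////////////////////////////

        public static SortedFunctionStack Load([NotNull] string fileName)
        {
            Ensure.Argument(fileName).NotNullOrWhiteSpace("fileName is null");

            NetDataContractSerializer serializer = new NetDataContractSerializer();

            using (Stream stream = File.OpenRead(fileName))
            {
                return (SortedFunctionStack)serializer.Deserialize(stream);
            }
        }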
Example #4
        public static void Run()
        {
            int neuronCount = 28;   // MNIST images are 28 x 28; neuronCount is the image side length

            RILogManager.Default?.SendDebug("MNIST Data Loading...");
            MnistData mnistData = new MnistData(neuronCount);

            RILogManager.Default.SendInformation("Training Start, creating function stack.");

            SortedFunctionStack nn = new SortedFunctionStack();

            for (int x = 0; x < numLayers; x++)
            {
                Application.DoEvents();

                // Add each hidden block (Linear -> BatchNorm -> ReLU) directly to the stack.
                // The original snippet collected these in a side list that was never attached to the network.
                nn.Add(new Linear(true, neuronCount * neuronCount, N, name: $"l{x} Linear"));
                nn.Add(new BatchNormalization(true, N, name: $"l{x} BatchNorm"));
                nn.Add(new ReLU(name: $"l{x} ReLU"));
                RILogManager.Default.ViewerSendWatch("Total Layers", (x + 1));
            }

            RILogManager.Default.SendInformation("Adding Output Layer");
            Application.DoEvents();
            nn.Add(new Linear(true, N, 10, noBias: false, name: $"l{numLayers + 1} Linear"));
            RILogManager.Default.ViewerSendWatch("Total Layers", numLayers);


            RILogManager.Default.SendInformation("Setting Optimizer to AdaGrad");
            nn.SetOptimizer(new AdaGrad());
            Application.DoEvents();

            RunningStatistics stats             = new RunningStatistics();
            Histogram         lossHistogram     = new Histogram();
            Histogram         accuracyHistogram = new Histogram();

            // Set up the histogram buckets once, rather than re-adding them on every batch.
            lossHistogram.AddBucket(new Bucket(-10, 10));
            accuracyHistogram.AddBucket(new Bucket(-10.0, 10));

            Real totalLoss        = 0;
            long totalLossCounter = 0;
            Real highestAccuracy  = 0;
            // Start the "best" loss trackers at the largest value so the first observed loss
            // becomes the initial best; starting at 0 meant they could never be updated.
            Real bestLocalLoss    = double.MaxValue;
            Real bestTotalLoss    = double.MaxValue;

            for (int epoch = 0; epoch < 3; epoch++)
            {
                RILogManager.Default?.SendDebug("epoch " + (epoch + 1));
                RILogManager.Default.SendInformation("epoch " + (epoch + 1));
                RILogManager.Default.ViewerSendWatch("epoch", (epoch + 1));
                Application.DoEvents();

                for (int i = 1; i < TRAIN_DATA_COUNT + 1; i++)
                {
                    Application.DoEvents();

                    TestDataSet datasetX = mnistData.GetRandomXSet(BATCH_DATA_COUNT, neuronCount, neuronCount);

                    Real sumLoss = Trainer.Train(nn, datasetX.Data, datasetX.Label, new SoftmaxCrossEntropy());
                    totalLoss += sumLoss;
                    totalLossCounter++;

                    stats.Push(sumLoss);

                    if (sumLoss < bestLocalLoss && !double.IsNaN(sumLoss))
                    {
                        bestLocalLoss = sumLoss;
                    }
                    if (stats.Mean < bestTotalLoss && !double.IsNaN(sumLoss))
                    {
                        bestTotalLoss = stats.Mean;
                    }

                    try
                    {
                        lossHistogram.AddData(sumLoss);
                    }
                    catch (Exception)
                    {
                    }

                    if (i % 20 == 0)
                    {
                        RILogManager.Default.ViewerSendWatch("Batch Count ", i);
                        RILogManager.Default.ViewerSendWatch("Total/Mean loss", stats.Mean);
                        RILogManager.Default.ViewerSendWatch("Local loss", sumLoss);
                        RILogManager.Default.SendInformation("Batch Count " + i + "/" + TRAIN_DATA_COUNT + ", epoch " + epoch + 1);
                        RILogManager.Default.SendInformation("Total/Mean loss " + stats.Mean);
                        RILogManager.Default.SendInformation("Local loss " + sumLoss);
                        Application.DoEvents();


                        RILogManager.Default?.SendDebug("Testing...");

                        TestDataSet datasetY = mnistData.GetRandomYSet(TEST_DATA_COUNT, neuronCount);
                        Real        accuracy = Trainer.Accuracy(nn, datasetY.Data, datasetY.Label);
                        if (accuracy > highestAccuracy)
                        {
                            highestAccuracy = accuracy;
                        }

                        RILogManager.Default?.SendDebug("Accuracy: " + accuracy);

                        RILogManager.Default.ViewerSendWatch("Best Accuracy: ", highestAccuracy);
                        RILogManager.Default.ViewerSendWatch("Best Total Loss ", bestTotalLoss);
                        RILogManager.Default.ViewerSendWatch("Best Local Loss ", bestLocalLoss);
                        Application.DoEvents();

                        try
                        {
                            accuracyHistogram.AddData(accuracy);
                        }
                        catch (Exception)
                        {
                        }
                    }
                }
            }

            ModelIO.Save(nn, Application.StartupPath + "\\test20.nn");
            RILogManager.Default?.SendDebug("Best Accuracy: " + highestAccuracy);
            RILogManager.Default?.SendDebug("Best Total Loss " + bestTotalLoss);
            RILogManager.Default?.SendDebug("Best Local Loss " + bestLocalLoss);
            RILogManager.Default.ViewerSendWatch("Best Accuracy: ", highestAccuracy);
            RILogManager.Default.ViewerSendWatch("Best Total Loss ", bestTotalLoss);
            RILogManager.Default.ViewerSendWatch("Best Local Loss ", bestLocalLoss);
        }
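
Run references several class-level fields that are not included in this snippet (numLayers, N, BATCH_DATA_COUNT, TRAIN_DATA_COUNT, TEST_DATA_COUNT). The values below are assumptions chosen only to make the example self-contained; in particular N is set to 28 * 28 so the hidden Linear layers and the output Linear(N, 10) agree on their dimensions.

        // Assumed class-level constants (not from the source).
        const int numLayers        = 3;     // number of Linear -> BatchNorm -> ReLU blocks
        const int N                = 784;   // hidden width; 28 * 28 keeps all layer dimensions consistent
        const int BATCH_DATA_COUNT = 20;    // samples per training batch
        const int TRAIN_DATA_COUNT = 300;   // training batches per epoch
        const int TEST_DATA_COUNT  = 100;   // samples used for each accuracy check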