////////////////////////////////////////////////////////////////////////////////////////////////////
/// <summary> Runs one training step on a batch: forward pass, loss evaluation, backward pass and,
///           optionally, a call to <c>functionStack.Update()</c>. </summary>
///
/// <param name="functionStack"> Stack of functions. This cannot be null. </param>
/// <param name="input">         The batch input passed to the forward pass. This may be null. </param>
/// <param name="teach">         The expected values passed to the loss function. This may be null. </param>
/// <param name="lossFunction">  The loss function. This cannot be null. </param>
/// <param name="isUpdate">      (Optional) When true, <c>functionStack.Update()</c> is called after the
///                              backward pass. </param>
/// <param name="verbose">       (Optional) When true, each phase is traced through RILogManager and the
///                              resulting loss is sent to the viewer as "Local Loss". </param>
///
/// <returns> The loss returned by <paramref name="lossFunction"/> for this batch. </returns>
////////////////////////////////////////////////////////////////////////////////////////////////////
public static Real Train([NotNull] SortedFunctionStack functionStack, [CanBeNull] NdArray input, [CanBeNull] NdArray teach, [NotNull] LossFunction lossFunction, bool isUpdate = true, bool verbose = true)
{
    if (verbose)
    {
        RILogManager.Default?.EnterMethod("Training " + functionStack.Name);
        RILogManager.Default?.SendDebug("Forward propagation");
    }

    NdArray[] predictions = functionStack.Forward(verbose, input);

    if (verbose)
    {
        RILogManager.Default?.SendDebug("Evaluating loss");
    }

    // The loss is computed before the backward pass, so it reflects the parameters used in this forward pass.
    Real loss = lossFunction.Evaluate(predictions, teach);

    if (verbose)
    {
        RILogManager.Default?.SendDebug("Backward propagation");
    }

    functionStack.Backward(verbose, predictions);

    if (verbose && isUpdate)
    {
        RILogManager.Default?.SendDebug("Updating stack");
    }

    if (isUpdate)
    {
        functionStack.Update();
    }

    if (!verbose)
    {
        return loss;
    }

    RILogManager.Default?.ExitMethod("Training " + functionStack.Name);
    RILogManager.Default?.ViewerSendWatch("Local Loss", loss.ToString(), loss);
    return loss;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// <summary> Determination of accuracy: the fraction of samples in the batch whose arg-max
///           prediction equals the label stored at the start of that sample's slice of
///           <paramref name="y"/>. </summary>
///
/// <param name="functionStack"> Stack of functions used for prediction. This cannot be null. </param>
/// <param name="x">             The batch of inputs to predict. This cannot be null. </param>
/// <param name="y">             The batch of labels (one class index per sample). This cannot be null;
///                              it is read for every sample. </param>
/// <param name="verbose">       (Optional) When true, timing and the accuracy are logged through
///                              RILogManager. </param>
///
/// <returns> matchCount / x.BatchCount, i.e. a value in [0, 1]; NaN when x has no samples. </returns>
///
/// <exception cref="System.ArgumentNullException"> Any argument is null. </exception>
////////////////////////////////////////////////////////////////////////////////////////////////////
public static double Accuracy([NotNull] SortedFunctionStack functionStack, [NotNull] NdArray x, [NotNull] NdArray y, bool verbose = true)
{
    // Fail fast with a descriptive exception instead of a NullReferenceException inside the scoring
    // loop (previously y was documented as nullable but dereferenced unconditionally).
    if (functionStack == null)
    {
        throw new System.ArgumentNullException(nameof(functionStack));
    }

    if (x == null)
    {
        throw new System.ArgumentNullException(nameof(x));
    }

    if (y == null)
    {
        throw new System.ArgumentNullException(nameof(y));
    }

    double matchCount = 0;
    Stopwatch sw = Stopwatch.StartNew();

    if (verbose)
    {
        RILogManager.Default?.SendDebug("Running Forecast for Accuracy Prediction on " + functionStack.Name);
    }

    NdArray forwardResult = functionStack.Predict(verbose, x)[0];

    // NOTE(review): assumes forwardResult.Length is the per-sample element count, as the indexing
    // (i + b * Length) implies.
    int sampleLength = forwardResult.Length;

    for (int b = 0; b < x.BatchCount; b++)
    {
        int offset = b * sampleLength;

        // Arg-max over this sample's outputs; the strict comparison keeps the first index on ties.
        Real maxval = forwardResult.Data[offset];
        int maxindex = 0;
        for (int i = 1; i < sampleLength; i++)
        {
            if (maxval < forwardResult.Data[offset + i])
            {
                maxval = forwardResult.Data[offset + i];
                maxindex = i;
            }
        }

        if (maxindex == (int)y.Data[b * y.Length])
        {
            matchCount++;
        }
    }

    sw.Stop();

    // double division: an empty batch yields NaN rather than throwing.
    double accuracy = matchCount / x.BatchCount;

    if (verbose)
    {
        RILogManager.Default?.SendDebug("Accuracy Prediction took " + Helpers.FormatTimeSpan(sw.Elapsed) + "ms");
        RILogManager.Default?.ViewerSendWatch("Accuracy", (accuracy * 100).ToString() + "%", accuracy);
    }

    return accuracy;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// <summary> Serializes a function stack to a file with NetDataContractSerializer, replacing any
///           existing file. I/O and serialization errors are logged through RILogManager and
///           are not rethrown. </summary>
///
/// <remarks> NetDataContractSerializer output should only ever be deserialized from trusted
///           sources. </remarks>
///
/// <param name="functionStack"> Stack of functions. This cannot be null. </param>
/// <param name="fileName">      Filename of the file. This cannot be null or whitespace. </param>
////////////////////////////////////////////////////////////////////////////////////////////////////
public static void Save([NotNull] SortedFunctionStack functionStack, [NotNull] string fileName)
{
    Ensure.Argument(fileName).NotNullOrWhiteSpace("fileName is null");
    Ensure.Argument(functionStack).NotNull("functionStack is null");

    NetDataContractSerializer serializer = new NetDataContractSerializer();
    RILogManager.Default?.SendDebug("Saving model " + functionStack.Name + " to " + fileName);

    try
    {
        // File.Create truncates an existing file. File.OpenWrite (OpenOrCreate) does not, so saving
        // a smaller model over a larger one would leave stale trailing bytes and corrupt the file.
        using (Stream stream = File.Create(fileName))
        {
            serializer.Serialize(stream, functionStack);
        }

        // Query the size without opening (and previously leaking) a second FileStream.
        long size = new FileInfo(fileName).Length;
        RILogManager.Default?.SendDebug("Model " + functionStack.Name + " saved, size is " + size.ToString("N0") + " bytes");
    }
    catch (Exception ex)
    {
        // Best-effort save: failures are reported to the log rather than propagated to the caller.
        RILogManager.Default?.SendException(ex.Message, ex);
    }
}
/// <summary>
/// Builds a stack of Linear/BatchNormalization/ReLU hidden layers plus a 10-way Linear output layer,
/// trains it with AdaGrad and SoftmaxCrossEntropy on random MNIST batches for three epochs,
/// measures accuracy every 20 batches, and finally saves the model to "test20.nn" under
/// Application.StartupPath.
/// </summary>
public static void Run()
{
    int neuronCount = 28;
    RILogManager.Default?.SendDebug("MNIST Data Loading...");
    MnistData mnistData = new MnistData(neuronCount);

    RILogManager.Default?.SendInformation("Training Start, creating function stack.");
    SortedFunctionStack nn = new SortedFunctionStack();

    // Hidden layers. Previously these were added to a local list that was never attached to nn, so
    // the network consisted of the output layer only.
    // NOTE(review): assumes SortedFunctionStack.Add keeps layers in call order (the output layer below
    // already relies on Add) — confirm.
    for (int x = 0; x < numLayers; x++)
    {
        Application.DoEvents();

        // The first layer consumes the flattened image; later layers consume the previous layer's N outputs.
        int inputCount = x == 0 ? neuronCount * neuronCount : N;
        nn.Add(new Linear(true, inputCount, N, name: $"l{x} Linear"));
        nn.Add(new BatchNormalization(true, N, name: $"l{x} BatchNorm"));
        nn.Add(new ReLU(name: $"l{x} ReLU"));
        RILogManager.Default?.ViewerSendWatch("Total Layers", (x + 1));
    }

    RILogManager.Default?.SendInformation("Adding Output Layer");
    Application.DoEvents();

    // With no hidden layers the output layer reads the flattened image directly.
    int outputInputCount = numLayers > 0 ? N : neuronCount * neuronCount;
    nn.Add(new Linear(true, outputInputCount, 10, noBias: false, name: $"l{numLayers + 1} Linear"));

    // Hidden layer groups plus the output layer.
    RILogManager.Default?.ViewerSendWatch("Total Layers", numLayers + 1);

    RILogManager.Default?.SendInformation("Setting Optimizer to AdaGrad");
    nn.SetOptimizer(new AdaGrad());
    Application.DoEvents();

    RunningStatistics stats = new RunningStatistics();
    Histogram lossHistogram = new Histogram();
    Histogram accuracyHistogram = new Histogram();

    // Each histogram needs its bucket only once; previously a new bucket was added on every batch.
    lossHistogram.AddBucket(new Bucket(-10, 10));
    accuracyHistogram.AddBucket(new Bucket(-10.0, 10));

    Real highestAccuracy = 0;

    // Start at the largest value so the first finite loss is recorded as the best. Starting at 0
    // meant a loss that is never negative could never register as an improvement.
    Real bestLocalLoss = double.MaxValue;
    Real bestTotalLoss = double.MaxValue;

    for (int epoch = 0; epoch < 3; epoch++)
    {
        RILogManager.Default?.SendDebug("epoch " + (epoch + 1));
        RILogManager.Default?.SendInformation("epoch " + (epoch + 1));
        RILogManager.Default?.ViewerSendWatch("epoch", (epoch + 1));
        Application.DoEvents();

        for (int i = 1; i <= TRAIN_DATA_COUNT; i++)
        {
            Application.DoEvents();

            TestDataSet datasetX = mnistData.GetRandomXSet(BATCH_DATA_COUNT, neuronCount, neuronCount);
            Real sumLoss = Trainer.Train(nn, datasetX.Data, datasetX.Label, new SoftmaxCrossEntropy());
            stats.Push(sumLoss);

            if (sumLoss < bestLocalLoss && !double.IsNaN(sumLoss))
            {
                bestLocalLoss = sumLoss;
            }

            if (stats.Mean < bestTotalLoss && !double.IsNaN(stats.Mean))
            {
                bestTotalLoss = stats.Mean;
            }

            try
            {
                lossHistogram.AddData(sumLoss);
            }
            catch (Exception ex)
            {
                // A histogram failure (presumably a value outside the bucket range) must not abort
                // training, but it should no longer be silently swallowed.
                RILogManager.Default?.SendDebug("Loss " + sumLoss + " not added to histogram: " + ex.Message);
            }

            if (i % 20 != 0)
            {
                continue;
            }

            RILogManager.Default?.ViewerSendWatch("Batch Count ", i);
            RILogManager.Default?.ViewerSendWatch("Total/Mean loss", stats.Mean);
            RILogManager.Default?.ViewerSendWatch("Local loss", sumLoss);

            // Parenthesized so the epoch number is added, not concatenated digit-by-digit ("epoch 01").
            RILogManager.Default?.SendInformation("Batch Count " + i + "/" + TRAIN_DATA_COUNT + ", epoch " + (epoch + 1));
            RILogManager.Default?.SendInformation("Total/Mean loss " + stats.Mean);
            RILogManager.Default?.SendInformation("Local loss " + sumLoss);
            Application.DoEvents();

            RILogManager.Default?.SendDebug("Testing...");
            TestDataSet datasetY = mnistData.GetRandomYSet(TEST_DATA_COUNT, neuronCount);
            Real accuracy = Trainer.Accuracy(nn, datasetY.Data, datasetY.Label);
            if (accuracy > highestAccuracy)
            {
                highestAccuracy = accuracy;
            }

            RILogManager.Default?.SendDebug("Accuracy: " + accuracy);
            RILogManager.Default?.ViewerSendWatch("Best Accuracy: ", highestAccuracy);
            RILogManager.Default?.ViewerSendWatch("Best Total Loss ", bestTotalLoss);
            RILogManager.Default?.ViewerSendWatch("Best Local Loss ", bestLocalLoss);
            Application.DoEvents();

            try
            {
                accuracyHistogram.AddData(accuracy);
            }
            catch (Exception ex)
            {
                RILogManager.Default?.SendDebug("Accuracy " + accuracy + " not added to histogram: " + ex.Message);
            }
        }
    }

    ModelIO.Save(nn, System.IO.Path.Combine(Application.StartupPath, "test20.nn"));

    RILogManager.Default?.SendDebug("Best Accuracy: " + highestAccuracy);
    RILogManager.Default?.SendDebug("Best Total Loss " + bestTotalLoss);
    RILogManager.Default?.SendDebug("Best Local Loss " + bestLocalLoss);
    RILogManager.Default?.ViewerSendWatch("Best Accuracy: ", highestAccuracy);
    RILogManager.Default?.ViewerSendWatch("Best Total Loss ", bestTotalLoss);
    RILogManager.Default?.ViewerSendWatch("Best Local Loss ", bestLocalLoss);
}