/// <summary>
/// Pairs each raw image with the label at the same position. Each image's bytes are converted
/// with <c>MINSTDataLoader.Normalize</c> (presumably scaling pixel values into a vector — confirm
/// against that method).
/// </summary>
/// <param name="images">Raw image byte arrays, one per sample.</param>
/// <param name="labels">Label vectors; element <c>i</c> belongs to <c>images[i]</c>.</param>
/// <returns>
/// A new list of (Image, Label) tuples in input order. Label vectors are not copied; the same
/// instances from <paramref name="labels"/> are placed in the result.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="images"/> or <paramref name="labels"/> is null.</exception>
/// <exception cref="ArgumentException">The two lists have different lengths.</exception>
public static List<(Vector<double> Image, Vector<double> Label)> CombineImagesAndLabels(List<byte[]> images, List<Vector<double>> labels)
{
    // Separate guards so the exception names the offending argument.
    if (images == null)
    {
        throw new ArgumentNullException(nameof(images));
    }
    if (labels == null)
    {
        throw new ArgumentNullException(nameof(labels));
    }
    if (images.Count != labels.Count)
    {
        throw new ArgumentException("List lengths don't match.", nameof(labels));
    }

    // Presize: the result holds exactly one entry per input pair.
    var ret = new List<(Vector<double> Image, Vector<double> Label)>(images.Count);
    for (int i = 0; i < images.Count; i++)
    {
        ret.Add((MINSTDataLoader.Normalize(images[i]), labels[i]));
    }
    return ret;
}
/// <summary>
/// Console entry point. Loads the MNIST training and test sets from <c>dataRoot</c> and prints a
/// random training sample. It then either loads a saved network (when a path is given in
/// <c>args[0]</c>) or builds and trains a new one. Finally it runs an interactive loop for
/// validating the network against the test set.
/// </summary>
/// <param name="args">
/// Optional. <c>args[0]</c> is the path of a network file, read with <c>Network.Load</c>, used
/// instead of training a new network.
/// </param>
/// <remarks>
/// Interactive commands: a test index in [0, Count-1] shows and classifies one image; <c>-1</c>
/// evaluates every test image; <c>-2</c> exits; <c>-3</c> (not listed in the prompt) saves the
/// network to a file. The loop also exits when standard input reaches end-of-stream.
/// </remarks>
static void Main(string[] args)
{
    const string dataRoot = @"C:\HaishiRooster\Data\MINST";

    // Load training data
    var trainingImages = MINSTDataLoader.LoadImages(Path.Combine(dataRoot, "train-images-idx3-ubyte.gz"));
    var trainingLabels = MINSTDataLoader.LoadLabels(Path.Combine(dataRoot, "train-labels-idx1-ubyte.gz"));
    var trainingSet = MINSTDataLoader.CombineImagesAndLabels(trainingImages, trainingLabels);

    // Load test data
    var testingImages = MINSTDataLoader.LoadImages(Path.Combine(dataRoot, "t10k-images-idx3-ubyte.gz"));
    var testingLabels = MINSTDataLoader.LoadLabels(Path.Combine(dataRoot, "t10k-labels-idx1-ubyte.gz"));
    var testingSet = MINSTDataLoader.CombineImagesAndLabels(testingImages, testingLabels);

    // Print training set information and a sample image
    Console.WriteLine("Training set size: " + trainingSet.Count);
    Console.WriteLine("Testing set size: " + testingSet.Count);
    if (trainingSet.Count > 0)
    {
        // Guard: mRand.Next(0) returns 0, and indexing an empty list would throw.
        int sampleIndex = mRand.Next(trainingSet.Count);
        Console.WriteLine(string.Format("Here's a random picture from the training set: #{0} ({1})", sampleIndex, convertToByte(trainingSet[sampleIndex].Label)));
        MINSTDataVisualizer.PrintImage(trainingSet[sampleIndex].Image, 28);
    }

    Network network;
    if (args.Length > 0)
    {
        using (StreamReader reader = new StreamReader(File.OpenRead(args[0])))
        {
            network = Network.Load(reader);
        }
        Console.WriteLine("\nNetwork is loaded!");
        dumpHyperParameters(network.HyperParameters);
    }
    else
    {
        // Earlier experiment configurations, kept for reference. Only the last (uncommented)
        // configuration is active.

        //// Quadratic cost
        //HyperParameters hyperParameters = new HyperParameters { CostFunctionName = "QuadraticCost", Epochs = 30, MiniBatchSize = 10, LearningRate = 3, TestSize = 10, AutoSave = true, AutoSaveThreshold = 0.951 };
        //network = new Network(hyperParameters, 784, 30, 10);

        //// Cross-entropy
        //HyperParameters hyperParameters = new HyperParameters { CostFunctionName = "CrossEntropyCost", Epochs = 30, MiniBatchSize = 10, LearningRate = 0.5, TestSize = testingSet.Count, AutoSave = true, AutoSaveThreshold = 0.967 };
        //network = new Network(hyperParameters, 784, 100, 10);

        //// Cross-entropy with regularization (RegulationLambda)
        //HyperParameters hyperParameters = new HyperParameters { CostFunctionName = "CrossEntropyCost", Epochs = 60, MiniBatchSize = 10, LearningRate = 0.1, RegulationLambda = 5.0, TestSize = testingSet.Count, AutoSave = true, AutoSaveThreshold = 0.98 };
        //network = new Network(hyperParameters, 784, 100, 10);

        //// Cross-entropy with regularization - 120 epochs
        //HyperParameters hyperParameters = new HyperParameters { CostFunctionName = "CrossEntropyCost", Epochs = 120, MiniBatchSize = 10, LearningRate = 0.1, RegulationLambda = 5.0, TestSize = testingSet.Count, AutoSave = true, AutoSaveThreshold = 0.98 };
        //network = new Network(hyperParameters, 784, 100, 10);

        // Active: cross-entropy with regularization - 240 epochs - with dropouts
        HyperParameters hyperParameters = new HyperParameters { CostFunctionName = "CrossEntropyCost", Epochs = 240, MiniBatchSize = 10, LearningRate = 1, RegulationLambda = 5.0, TestSize = testingSet.Count, AutoSave = true, AutoSaveThreshold = 0.98, UseDropouts = true };
        network = new Network(hyperParameters, 784, 100, 10);
        hookupEvents(network);
        dumpHyperParameters(hyperParameters);

        // Train the network; a prediction counts as correct when it maps to the same digit as the label.
        network.Train(trainingSet, (actual, expected) => convertToByte(actual) == convertToByte(expected), testingSet);
        Console.WriteLine("\nNetwork is trained!");
    }

    // Now validate
    while (true)
    {
        Console.Write(string.Format("\nPlease enter a test image index [0 - {0}]. Enter '-1' to evaluate all test images. Enter '-2' to exit:", testingSet.Count - 1));

        string input = Console.ReadLine();
        if (input == null)
        {
            // End of input (e.g. redirected stdin exhausted). Without this check the loop would spin
            // forever, because int.TryParse(null, ...) keeps returning false.
            break;
        }
        if (!int.TryParse(input, out int choice))
        {
            // Not a number: re-prompt.
            continue;
        }

        if (choice == -1)
        {
            // Evaluate every test image, printing '.' for a hit and 'X' for a miss.
            int count = 0;
            Console.Write(string.Format("Detecting {0} pictures", testingSet.Count));
            for (int i = 0; i < testingSet.Count; i++)
            {
                var detection = network.Detect(testingSet[i].Image);
                if (convertToByte(detection) == convertToByte(testingSet[i].Label))
                {
                    Console.Write(".");
                    count++;
                }
                else
                {
                    Console.Write("X");
                }
            }
            Console.WriteLine("\nDetected {0} out of {1} pictures, correct rate is: {2:0.0%}", count, testingSet.Count, count * 1.0 / testingSet.Count);
        }
        else if (choice >= 0 && choice < testingSet.Count)
        {
            // Show and classify a single test image.
            var expectedDigit = convertToByte(testingSet[choice].Label);
            Console.WriteLine(string.Format("Test image: #{0} ({1})", choice, expectedDigit));
            MINSTDataVisualizer.PrintImage(testingSet[choice].Image, 28);

            var detectedDigit = convertToByte(network.Detect(testingSet[choice].Image));
            // Print the decoded digit rather than the raw output vector.
            Console.WriteLine(string.Format("\nDetected number: {0} - {1}", detectedDigit, detectedDigit == expectedDigit ? "SUCCESS!" : "FAIL!"));
        }
        else if (choice == -2)
        {
            break;
        }
        else if (choice == -3)
        {
            // Save the current network (not advertised in the prompt).
            Console.Write("Please enter file name: ");
            string fileName = Console.ReadLine();
            if (string.IsNullOrWhiteSpace(fileName))
            {
                Console.WriteLine("No file name entered; network not saved.");
                continue;
            }
            try
            {
                using (StreamWriter writer = new StreamWriter(File.Create(fileName)))
                {
                    network.Save(writer);
                }
            }
            catch (Exception ex) when (ex is IOException || ex is UnauthorizedAccessException || ex is ArgumentException || ex is NotSupportedException)
            {
                // A bad path should not kill an interactive session (and the trained network with it).
                Console.WriteLine("Failed to save network: " + ex.Message);
            }
        }
    }
}