/// <summary>
/// Downloads the MNIST files into <c>mnistFolder</c> (via <c>DownloadFile</c>), loads the training and
/// testing sets, and populates <c>Train</c>, <c>Test</c> and <c>Validation</c>.
/// </summary>
/// <param name="validationSize">
/// Number of images taken from the end of the raw (non-augmented) training set and moved into
/// <c>Validation</c>. Must be non-negative and smaller than the number of training images.
/// </param>
/// <param name="dataAugmentation">
/// When <c>true</c>, the training images left after the validation split are passed through
/// <c>ImageTransformation.DataAugmentation</c>. The validation set is never augmented.
/// </param>
/// <returns>
/// <c>true</c> on success; <c>false</c> (after printing a message and waiting for a key press) when
/// the loaded training or testing set is empty.
/// </returns>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="validationSize"/> is negative, or is not smaller than the number of training images.
/// </exception>
public bool Load(int validationSize = 0, bool dataAugmentation = false)
{
    // Cheap argument check before any download/IO work.
    if (validationSize < 0)
    {
        throw new ArgumentOutOfRangeException(nameof(validationSize), validationSize,
            "Validation size cannot be negative.");
    }

    Directory.CreateDirectory(mnistFolder);

    var trainingLabelFilePath = Path.Combine(mnistFolder, trainingLabelFile);
    var trainingImageFilePath = Path.Combine(mnistFolder, trainingImageFile);
    var testingLabelFilePath = Path.Combine(mnistFolder, testingLabelFile);
    var testingImageFilePath = Path.Combine(mnistFolder, testingImageFile);

    // Download Mnist files if needed (whether existing files are skipped is decided by DownloadFile).
    Console.WriteLine("Downloading Mnist training files...");
    DownloadFile(urlMnist + trainingLabelFile, trainingLabelFilePath);
    DownloadFile(urlMnist + trainingImageFile, trainingImageFilePath);
    Console.WriteLine("Downloading Mnist testing files...");
    DownloadFile(urlMnist + testingLabelFile, testingLabelFilePath);
    DownloadFile(urlMnist + testingImageFile, testingImageFilePath);

    // Load data
    Console.WriteLine("Loading the datasets...");
    var trainImages = MnistReader.Load(trainingLabelFilePath, trainingImageFilePath);
    var testImages = MnistReader.Load(testingLabelFilePath, testingImageFilePath);

    // Check for missing data BEFORE slicing: GetRange on an empty list with validationSize > 0
    // would throw instead of reporting the missing files.
    if (trainImages.Count == 0 || testImages.Count == 0)
    {
        Console.WriteLine("Missing Mnist training/testing files.");
        Console.ReadKey();
        return false;
    }

    // At least one image must remain for training.
    if (validationSize >= trainImages.Count)
    {
        throw new ArgumentOutOfRangeException(nameof(validationSize), validationSize,
            $"Validation size must be smaller than the number of training images ({trainImages.Count}).");
    }

    // Split the raw images first, so that whatever DataAugmentation produces (presumably
    // transformed copies of training images) can never end up in the validation set.
    int trainCount = trainImages.Count - validationSize;
    var validationImages = trainImages.GetRange(trainCount, validationSize);
    trainImages = trainImages.GetRange(0, trainCount);

    if (dataAugmentation)
    {
        trainImages = new ImageTransformation().DataAugmentation(trainImages);
    }

    this.Train = new DataSet(trainImages);
    this.Test = new DataSet(testImages);
    this.Validation = new DataSet(validationImages);
    return true;
}
/// <summary>
/// Interactive demo: fetches the MNIST files into <c>mnistFolder</c>, loads the training and testing
/// sets into <c>training</c>/<c>testing</c>, builds a small convolutional network and trains it with
/// an Adadelta trainer until a key is pressed.
/// </summary>
/// <remarks>
/// If either set comes back empty, a message is printed, a key press is awaited, and the method
/// returns without building or training a network.
/// </remarks>
private void MnistDemo()
{
    Directory.CreateDirectory(mnistFolder);

    string trainLabelsPath = Path.Combine(mnistFolder, trainingLabelFile);
    string trainImagesPath = Path.Combine(mnistFolder, trainingImageFile);
    string testLabelsPath = Path.Combine(mnistFolder, testingLabelFile);
    string testImagesPath = Path.Combine(mnistFolder, testingImageFile);

    // Fetch all four files; skipping already-present files, if any, is DownloadFile's job.
    Console.WriteLine("Downloading Mnist training files...");
    DownloadFile(urlMnist + trainingLabelFile, trainLabelsPath);
    DownloadFile(urlMnist + trainingImageFile, trainImagesPath);
    Console.WriteLine("Downloading Mnist testing files...");
    DownloadFile(urlMnist + testingLabelFile, testLabelsPath);
    DownloadFile(urlMnist + testingImageFile, testImagesPath);

    Console.WriteLine("Loading the datasets...");
    training = MnistReader.Load(trainLabelsPath, trainImagesPath);
    testing = MnistReader.Load(testLabelsPath, testImagesPath);

    bool haveBothSets = training.Count != 0 && testing.Count != 0;
    if (!haveBothSets)
    {
        Console.WriteLine("Missing Mnist training/testing files.");
        Console.ReadKey();
        return;
    }

    // Two conv -> ReLU -> pool stages, then a fully connected layer and softmax with 10 outputs.
    net = new Net();
    net.AddLayer(new InputLayer(24, 24, 1));

    net.AddLayer(new ConvLayer(5, 5, 8) { Stride = 1, Pad = 2 });
    net.AddLayer(new ReluLayer());
    net.AddLayer(new PoolLayer(2, 2) { Stride = 2 });

    net.AddLayer(new ConvLayer(5, 5, 16) { Stride = 1, Pad = 2 });
    net.AddLayer(new ReluLayer());
    net.AddLayer(new PoolLayer(3, 3) { Stride = 3 });

    net.AddLayer(new FullyConnLayer(10));
    net.AddLayer(new SoftmaxLayer(10));

    trainer = new AdadeltaTrainer(net) { BatchSize = 20, L2Decay = 0.001 };

    Console.WriteLine("Convolutional neural network learning...[Press any key to stop]");
    // At least one training step always runs; afterwards, stop as soon as a key is waiting.
    while (true)
    {
        Step(SampleTrainingInstance());
        if (Console.KeyAvailable)
        {
            break;
        }
    }
}
/// <summary>
/// Interactive demo restricted to the digits listed in <c>numbers</c>. It downloads and filters the
/// MNIST sets, then optionally reloads a previously saved network for that digit subset. Next it
/// optionally pins the process to one CPU and trains until a key is pressed or <c>Step</c> returns
/// <c>true</c>. Finally it optionally saves the network.
/// </summary>
/// <remarks>
/// Every optional step is confirmed with ENTER; any other key declines it.
/// </remarks>
private void MnistDemo()
{
    Directory.CreateDirectory(mnistFolder);

    string trainingLabelFilePath = Path.Combine(mnistFolder, trainingLabelFile);
    string trainingImageFilePath = Path.Combine(mnistFolder, trainingImageFile);
    string testingLabelFilePath = Path.Combine(mnistFolder, testingLabelFile);
    string testingImageFilePath = Path.Combine(mnistFolder, testingImageFile);

    // Download Mnist files if needed (whether existing files are skipped is decided by DownloadFile).
    Console.WriteLine("Downloading Mnist training files...");
    DownloadFile(urlMnist + trainingLabelFile, trainingLabelFilePath);
    DownloadFile(urlMnist + trainingImageFile, trainingImageFilePath);
    Console.WriteLine("Downloading Mnist testing files...");
    DownloadFile(urlMnist + testingLabelFile, testingLabelFilePath);
    DownloadFile(urlMnist + testingImageFile, testingImageFilePath);

    // Load data, keeping only samples whose label is one of the selected digits.
    Console.WriteLine("Loading the datasets...");
    this.training = MnistReader.Load(trainingLabelFilePath, trainingImageFilePath)
        .Where(p => numbers.Contains(p.Label))
        .ToList();
    this.testing = MnistReader.Load(testingLabelFilePath, testingImageFilePath)
        .Where(p => numbers.Contains(p.Label))
        .ToList();

    if (this.training.Count == 0 || this.testing.Count == 0)
    {
        Console.WriteLine("Missing Mnist training/testing files.");
        Console.ReadKey();
        return;
    }

    Console.WriteLine($"datasets training: {this.training.Count}, testing: {this.testing.Count}");

    // One network file per digit subset (the selected digits are concatenated into the name).
    // This is already a full path and is used as-is for both loading and saving.
    var netFile = Path.Combine(mnistFolder, $"net{string.Join("", numbers)}.bin");

    if (File.Exists(netFile))
    {
        Console.WriteLine($"load {netFile}?");
        if (Console.ReadKey(true).Key == ConsoleKey.Enter)
        {
            Console.WriteLine($"loading...");
            this.net = Net.Load(netFile);
        }
    }

    if (this.net == null)
    {
        // No pooling or dropout layers in this configuration; output size = number of selected digits.
        this.net = new Net();
        this.net.AddLayer(new InputLayer(24, 24, 1));
        this.net.AddLayer(new ConvLayer(5, 5, 8) { Stride = 1, Pad = 2 });
        this.net.AddLayer(new ReluLayer());
        this.net.AddLayer(new ConvLayer(5, 5, 16) { Stride = 1, Pad = 2 });
        this.net.AddLayer(new ReluLayer());
        this.net.AddLayer(new FullyConnLayer(numbers.Length));
        this.net.AddLayer(new SoftmaxLayer(numbers.Length));
    }

    this.trainer = new AdadeltaTrainer(this.net) { BatchSize = 20, L2Decay = 0.001, };

    Console.WriteLine($"limit cpu cores to 1?");
    if (Console.ReadKey(true).Key == ConsoleKey.Enter)
    {
        using (Process process = Process.GetCurrentProcess())
        {
            try
            {
                // Affinity bit mask 0x1: allow only the first logical processor.
                process.ProcessorAffinity = (IntPtr)1;
            }
            catch (PlatformNotSupportedException)
            {
                // Setting affinity is not supported on every OS (e.g. macOS); keep training anyway.
                Console.WriteLine("cpu affinity is not supported on this platform.");
            }
        }
    }

    Console.WriteLine("Training...[Press any key to stop]");
    bool stopRequested = false;
    while (!stopRequested)
    {
        // Train until Step returns true (presumably a "target reached" signal — confirm in Step)
        // or until the user presses a key.
        bool stepFinished;
        do
        {
            var sample = this.SampleTrainingInstance();
            stepFinished = this.Step(sample);
        } while (!stepFinished && !Console.KeyAvailable);

        if (!stepFinished)
        {
            // Consume the key that interrupted training so it does not answer the prompt below.
            Console.ReadKey(true);
        }

        Console.WriteLine($"stop? [ENTER: continue]");
        stopRequested = Console.ReadKey(true).Key != ConsoleKey.Enter;
    }

    Console.WriteLine($"save {netFile}?");
    if (Console.ReadKey(true).Key == ConsoleKey.Enter)
    {
        Console.WriteLine($"saving...");
        // netFile already contains mnistFolder; combining it again would point to
        // mnistFolder/mnistFolder/... whenever mnistFolder is a relative path.
        this.net.Save(netFile);
    }

    Console.WriteLine("done.");
    Console.ReadKey();
}