Example #1
0
        /// <summary>
        /// Ensures the MNIST files are present in <c>mnistFolder</c> (fetching them with <c>DownloadFile</c>),
        /// loads them, and fills <c>Train</c>, <c>Test</c> and <c>Validation</c>.
        /// </summary>
        /// <param name="validationSize">
        /// Number of samples moved from the end of the training list into the validation set.
        /// Must be non-negative and strictly smaller than the number of training samples.
        /// </param>
        /// <param name="dataAugmentation">
        /// When <c>true</c>, the training list is passed through <c>ImageTransformation.DataAugmentation</c>
        /// before the validation samples are split off.
        /// </param>
        /// <returns>
        /// <c>true</c> when the data sets were created; <c>false</c> when the training or testing files
        /// yielded no samples, or when <paramref name="validationSize"/> would leave no training samples.
        /// </returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="validationSize"/> is negative.</exception>
        public bool Load(int validationSize = 0, bool dataAugmentation = false)
        {
            // Fail fast, before any download work, instead of letting List.GetRange throw later on.
            if (validationSize < 0)
            {
                throw new ArgumentOutOfRangeException(nameof(validationSize), validationSize, "Validation size must be non-negative.");
            }

            Directory.CreateDirectory(mnistFolder);

            var trainingLabelFilePath = Path.Combine(mnistFolder, trainingLabelFile);
            var trainingImageFilePath = Path.Combine(mnistFolder, trainingImageFile);
            var testingLabelFilePath  = Path.Combine(mnistFolder, testingLabelFile);
            var testingImageFilePath  = Path.Combine(mnistFolder, testingImageFile);

            // Download Mnist files if needed
            Console.WriteLine("Downloading Mnist training files...");
            DownloadFile(urlMnist + trainingLabelFile, trainingLabelFilePath);
            DownloadFile(urlMnist + trainingImageFile, trainingImageFilePath);
            Console.WriteLine("Downloading Mnist testing files...");
            DownloadFile(urlMnist + testingLabelFile, testingLabelFilePath);
            DownloadFile(urlMnist + testingImageFile, testingImageFilePath);

            // Load data
            Console.WriteLine("Loading the datasets...");
            var trainImages   = MnistReader.Load(trainingLabelFilePath, trainingImageFilePath);
            var testingImages = MnistReader.Load(testingLabelFilePath, testingImageFilePath);

            // Check for missing data BEFORE splitting: with an empty training list and validationSize > 0,
            // GetRange would throw instead of reporting the missing files.
            if (trainImages.Count == 0 || testingImages.Count == 0)
            {
                Console.WriteLine("Missing Mnist training/testing files.");
                Console.ReadKey();
                return false;
            }

            if (dataAugmentation)
            {
                // NOTE(review): augmentation runs before the validation split, so validation samples are
                // taken from the tail of the augmented list. Whether that mixes augmented samples into
                // validation depends on ImageTransformation.DataAugmentation - confirm.
                trainImages = new ImageTransformation().DataAugmentation(trainImages);
            }

            // At least one sample must remain for training; otherwise GetRange would throw (size > count)
            // or training would be empty (size == count).
            if (validationSize >= trainImages.Count)
            {
                Console.WriteLine($"Validation size ({validationSize}) leaves no Mnist training samples ({trainImages.Count} available).");
                Console.ReadKey();
                return false;
            }

            // Split: the last `validationSize` samples become the validation set.
            var trainCount       = trainImages.Count - validationSize;
            var validationImages = trainImages.GetRange(trainCount, validationSize);
            trainImages          = trainImages.GetRange(0, trainCount);

            this.Train      = new DataSet(trainImages);
            this.Test       = new DataSet(testingImages);
            this.Validation = new DataSet(validationImages);

            return true;
        }
Example #2
0
        /// <summary>
        /// Console demo: makes sure the MNIST files are available, loads the training and testing sets,
        /// builds a small convolutional network (24x24x1 input, two conv/ReLU/pool stages, a fully
        /// connected layer and a 10-way softmax) and trains it with an Adadelta trainer until a key is pressed.
        /// </summary>
        /// <remarks>
        /// Returns early, after waiting for a key, when either data set comes back empty.
        /// </remarks>
        private void MnistDemo()
        {
            Directory.CreateDirectory(mnistFolder);

            string trainingLabelFilePath = Path.Combine(mnistFolder, trainingLabelFile);
            string trainingImageFilePath = Path.Combine(mnistFolder, trainingImageFile);
            string testingLabelFilePath  = Path.Combine(mnistFolder, testingLabelFile);
            string testingImageFilePath  = Path.Combine(mnistFolder, testingImageFile);

            // Download Mnist files if needed
            Console.WriteLine("Downloading Mnist training files...");
            DownloadFile(urlMnist + trainingLabelFile, trainingLabelFilePath);
            DownloadFile(urlMnist + trainingImageFile, trainingImageFilePath);
            Console.WriteLine("Downloading Mnist testing files...");
            DownloadFile(urlMnist + testingLabelFile, testingLabelFilePath);
            DownloadFile(urlMnist + testingImageFile, testingImageFilePath);

            // Load data
            Console.WriteLine("Loading the datasets...");
            this.training = MnistReader.Load(trainingLabelFilePath, trainingImageFilePath);
            this.testing  = MnistReader.Load(testingLabelFilePath, testingImageFilePath);

            if (this.training.Count == 0 || this.testing.Count == 0)
            {
                Console.WriteLine("Missing Mnist training/testing files.");
                Console.ReadKey();
                return;
            }

            // Create network
            this.net = new Net();
            this.net.AddLayer(new InputLayer(24, 24, 1));
            this.net.AddLayer(new ConvLayer(5, 5, 8)
            {
                Stride = 1, Pad = 2
            });
            this.net.AddLayer(new ReluLayer());
            this.net.AddLayer(new PoolLayer(2, 2)
            {
                Stride = 2
            });
            this.net.AddLayer(new ConvLayer(5, 5, 16)
            {
                Stride = 1, Pad = 2
            });
            this.net.AddLayer(new ReluLayer());
            this.net.AddLayer(new PoolLayer(3, 3)
            {
                Stride = 3
            });
            this.net.AddLayer(new FullyConnLayer(10));
            this.net.AddLayer(new SoftmaxLayer(10));

            this.trainer = new AdadeltaTrainer(this.net)
            {
                BatchSize = 20,
                L2Decay   = 0.001,
            };

            Console.WriteLine("Convolutional neural network learning...[Press any key to stop]");
            do
            {
                var sample = this.SampleTrainingInstance();
                this.Step(sample);
            } while (!Console.KeyAvailable);

            // Consume the key that stopped training so it does not stay in the input buffer and
            // satisfy a later console read by mistake.
            Console.ReadKey(true);
        }
Example #3
0
        /// <summary>
        /// Interactive console demo: makes sure the MNIST files are available, loads the training and
        /// testing sets restricted to the digits listed in <c>numbers</c>, optionally reloads a previously
        /// saved network, trains it with an Adadelta trainer, and optionally saves the result.
        /// </summary>
        /// <remarks>
        /// Prompts: ENTER at "load …?" reloads the saved net; ENTER at "limit cpu cores to 1?" pins the
        /// process to the first logical processor; ENTER at "stop?" resumes training (any other key stops);
        /// ENTER at "save …?" writes the net to the same file that is checked on load.
        /// </remarks>
        private void MnistDemo()
        {
            Directory.CreateDirectory(mnistFolder);

            string trainingLabelFilePath = Path.Combine(mnistFolder, trainingLabelFile);
            string trainingImageFilePath = Path.Combine(mnistFolder, trainingImageFile);
            string testingLabelFilePath  = Path.Combine(mnistFolder, testingLabelFile);
            string testingImageFilePath  = Path.Combine(mnistFolder, testingImageFile);

            // Download Mnist files if needed
            Console.WriteLine("Downloading Mnist training files...");
            DownloadFile(urlMnist + trainingLabelFile, trainingLabelFilePath);
            DownloadFile(urlMnist + trainingImageFile, trainingImageFilePath);

            Console.WriteLine("Downloading Mnist testing files...");
            DownloadFile(urlMnist + testingLabelFile, testingLabelFilePath);
            DownloadFile(urlMnist + testingImageFile, testingImageFilePath);

            // Load data, keeping only samples whose label appears in `numbers`.
            Console.WriteLine("Loading the datasets...");

            this.training = MnistReader.Load(trainingLabelFilePath, trainingImageFilePath).Where(p => numbers.Contains(p.Label)).ToList();
            this.testing  = MnistReader.Load(testingLabelFilePath, testingImageFilePath).Where(p => numbers.Contains(p.Label)).ToList();

            if (this.training.Count == 0 || this.testing.Count == 0)
            {
                Console.WriteLine("Missing Mnist training/testing files.");
                Console.ReadKey();
                return;
            }

            Console.WriteLine($"datasets training: {this.training.Count}, testing: {this.testing.Count}");

            //ExtractImages();

            // Full path of the saved network (already rooted in mnistFolder); used for both load and save.
            var netFile = Path.Combine(mnistFolder, $"net{string.Join("", numbers)}.bin");

            if (File.Exists(netFile))
            {
                Console.WriteLine($"load {netFile}?");
                if (Console.ReadKey(true).Key == ConsoleKey.Enter)
                {
                    Console.WriteLine($"loading...");
                    this.net = Net.Load(netFile);
                }
            }

            // Build a fresh network when none was loaded; output size follows the number of selected digits.
            if (this.net == null)
            {
                this.net = new Net();
                this.net.AddLayer(new InputLayer(24, 24, 1));

                this.net.AddLayer(new ConvLayer(5, 5, 8)
                {
                    Stride = 1, Pad = 2
                });
                this.net.AddLayer(new ReluLayer());
                //this.net.AddLayer(new PoolLayer(2, 2) { Stride = 2 });

                this.net.AddLayer(new ConvLayer(5, 5, 16)
                {
                    Stride = 1, Pad = 2
                });
                this.net.AddLayer(new ReluLayer());
                //this.net.AddLayer(new PoolLayer(3, 3) { Stride = 3 });

                this.net.AddLayer(new FullyConnLayer(numbers.Length));
                //this.net.AddLayer(new DropOutLayer());
                this.net.AddLayer(new SoftmaxLayer(numbers.Length));
            }

            this.trainer = new AdadeltaTrainer(this.net)
            {
                BatchSize = 20,
                L2Decay   = 0.001,
            };

            Console.WriteLine($"limit cpu cores to 1?");
            if (Console.ReadKey(true).Key == ConsoleKey.Enter)
            {
                using (Process process = Process.GetCurrentProcess())
                {
                    try
                    {
                        // Affinity bit mask 1 = first logical processor only.
                        process.ProcessorAffinity = (IntPtr)1;
                    }
                    catch (PlatformNotSupportedException)
                    {
                        // Setting processor affinity is not supported on every OS (e.g. macOS).
                        Console.WriteLine("cpu affinity is not supported on this platform.");
                    }
                }
            }

            Console.WriteLine("Training...[Press any key to stop]");

            bool ok = false;

            while (!ok)
            {
                // Train until Step returns true (presumably a completion criterion - confirm in Step)
                // or a key is pressed.
                do
                {
                    var sample = this.SampleTrainingInstance();

                    ok = this.Step(sample);
                } while (!ok && !Console.KeyAvailable);

                if (!ok)
                {
                    // Consume the key that interrupted training.
                    Console.ReadKey(true);
                }

                Console.WriteLine($"stop? [ENTER: continue]");
                ok = Console.ReadKey(true).Key != ConsoleKey.Enter;
            }

            Console.WriteLine($"save {netFile}?");
            if (Console.ReadKey(true).Key == ConsoleKey.Enter)
            {
                Console.WriteLine($"saving...");

                // Save to netFile as-is: it already contains mnistFolder, so combining it with mnistFolder
                // again would give e.g. "mnist/mnist/net….bin" for a relative folder and break reloading.
                this.net.Save(netFile);
            }
            Console.WriteLine("done.");
            Console.ReadKey();
        }