コード例 #1
0
ファイル: Program.cs プロジェクト: radioman/ConvNetSharp
        /// <summary>
        /// Console demo: fetches the Mnist archives into <c>mnistFolder</c>, loads the training and
        /// testing sets, builds a small convolutional network and keeps training it with an
        /// Adadelta trainer until a key is pressed.
        /// If either data set comes back empty, a message is printed and the demo exits after a key press.
        /// </summary>
        private void MnistDemo()
        {
            Directory.CreateDirectory(mnistFolder);

            // Local destinations of the four Mnist archives.
            var trainLabelsPath = Path.Combine(mnistFolder, trainingLabelFile);
            var trainImagesPath = Path.Combine(mnistFolder, trainingImageFile);
            var testLabelsPath  = Path.Combine(mnistFolder, testingLabelFile);
            var testImagesPath  = Path.Combine(mnistFolder, testingImageFile);

            // Fetch each archive from urlMnist. Whether an already-present file is skipped
            // is decided inside DownloadFile (not visible here).
            Console.WriteLine("Downloading Mnist training files...");
            DownloadFile(urlMnist + trainingLabelFile, trainLabelsPath);
            DownloadFile(urlMnist + trainingImageFile, trainImagesPath);
            Console.WriteLine("Downloading Mnist testing files...");
            DownloadFile(urlMnist + testingLabelFile, testLabelsPath);
            DownloadFile(urlMnist + testingImageFile, testImagesPath);

            Console.WriteLine("Loading the datasets...");
            this.training = MnistReader.Load(trainLabelsPath, trainImagesPath);
            this.testing  = MnistReader.Load(testLabelsPath, testImagesPath);

            // Nothing to learn from or evaluate on: tell the user, wait for a key, and stop.
            bool dataMissing = this.training.Count == 0 || this.testing.Count == 0;
            if (dataMissing)
            {
                Console.WriteLine("Missing Mnist training/testing files.");
                Console.ReadKey();
                return;
            }

            // NOTE(review): the network takes a 24x24x1 input; presumably
            // SampleTrainingInstance produces samples of that size — confirm.
            this.net = FluentNet
                .Create(24, 24, 1)
                .Conv(5, 5, 8).Stride(1).Pad(2)     // first conv -> relu -> pool stage
                .Relu()
                .Pool(2, 2).Stride(2)
                .Conv(5, 5, 16).Stride(1).Pad(2)    // second conv -> relu -> pool stage
                .Relu()
                .Pool(3, 3).Stride(3)
                .FullyConn(10)                      // classifier head with 10 outputs
                .Softmax(10)
                .Build();

            var adadelta = new AdadeltaTrainer(this.net);
            adadelta.BatchSize = 20;
            adadelta.L2Decay   = 0.001;
            this.trainer = adadelta;

            Console.WriteLine("Convolutional neural network learning...[Press any key to stop]");

            // Train at least once, then keep going until a key is waiting in the input buffer.
            while (true)
            {
                this.Step(this.SampleTrainingInstance());

                if (Console.KeyAvailable)
                {
                    break;
                }
            }
        }
コード例 #2
0
        /// <summary>
        /// Downloads the Mnist archives into <c>mnistFolder</c>, loads them, and splits the last
        /// <paramref name="validationSize"/> training images off into a separate validation set.
        /// On success, <c>Train</c>, <c>Validation</c> and <c>Test</c> are populated.
        /// </summary>
        /// <param name="validationSize">
        /// Number of images taken from the end of the training set to form the validation set.
        /// It should be greater than zero and smaller than the number of training images.
        /// </param>
        /// <returns>
        /// <c>true</c> when the three data sets were populated. <c>false</c> when the training or
        /// testing data is empty, or when <paramref name="validationSize"/> would leave the
        /// training or validation set empty. In both <c>false</c> cases a message is printed and
        /// the method waits for a key press.
        /// </returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="validationSize"/> is negative, or larger than the number of training images.
        /// </exception>
        public bool Load(int validationSize = 1000)
        {
            // Reject an impossible split before doing any (potentially slow) downloads.
            if (validationSize < 0)
            {
                throw new ArgumentOutOfRangeException("validationSize", validationSize,
                    "Validation size must not be negative.");
            }

            Directory.CreateDirectory(mnistFolder);

            var trainingLabelFilePath = Path.Combine(mnistFolder, trainingLabelFile);
            var trainingImageFilePath = Path.Combine(mnistFolder, trainingImageFile);
            var testingLabelFilePath  = Path.Combine(mnistFolder, testingLabelFile);
            var testingImageFilePath  = Path.Combine(mnistFolder, testingImageFile);

            // Download Mnist files if needed
            Console.WriteLine("Downloading Mnist training files...");
            DownloadFile(urlMnist + trainingLabelFile, trainingLabelFilePath);
            DownloadFile(urlMnist + trainingImageFile, trainingImageFilePath);
            Console.WriteLine("Downloading Mnist testing files...");
            DownloadFile(urlMnist + testingLabelFile, testingLabelFilePath);
            DownloadFile(urlMnist + testingImageFile, testingImageFilePath);

            // Load data
            Console.WriteLine("Loading the datasets...");
            var trainingImages = MnistReader.Load(trainingLabelFilePath, trainingImageFilePath);
            var testingImages  = MnistReader.Load(testingLabelFilePath, testingImageFilePath);

            // Must run before slicing. With an empty training list, Count - validationSize is
            // negative and GetRange would throw instead of reporting the missing files.
            if (trainingImages.Count == 0 || testingImages.Count == 0)
            {
                Console.WriteLine("Missing Mnist training/testing files.");
                Console.ReadKey();
                return false;
            }

            if (validationSize > trainingImages.Count)
            {
                throw new ArgumentOutOfRangeException("validationSize", validationSize,
                    "Validation size exceeds the number of training images (" + trainingImages.Count + ").");
            }

            // Both halves of the split must be non-empty.
            if (validationSize == 0 || validationSize == trainingImages.Count)
            {
                Console.WriteLine("Validation size must leave at least one training image and one validation image.");
                Console.ReadKey();
                return false;
            }

            // The validation set is carved from the tail of the training set.
            var splitIndex       = trainingImages.Count - validationSize;
            var validationImages = trainingImages.GetRange(splitIndex, validationSize);
            trainingImages       = trainingImages.GetRange(0, splitIndex);

            this.Train      = new DataSet(trainingImages);
            this.Validation = new DataSet(validationImages);
            this.Test       = new DataSet(testingImages);

            return true;
        }