コード例 #1
0
ファイル: Program.cs プロジェクト: radioman/ConvNetSharp
        /// <summary>
        /// MNIST demo: fetches the four MNIST files into <c>mnistFolder</c>, loads the training and
        /// testing sets, builds a small convolutional network with the fluent API and trains it with
        /// an Adadelta trainer until a key is pressed.
        /// </summary>
        /// <remarks>
        /// If either dataset comes back empty, a message is printed and the method returns after a
        /// key press. The key that stops training is not read here, so it stays in the console buffer.
        /// </remarks>
        private void MnistDemo()
        {
            Directory.CreateDirectory(mnistFolder);

            var trainLabelsPath = Path.Combine(mnistFolder, trainingLabelFile);
            var trainImagesPath = Path.Combine(mnistFolder, trainingImageFile);
            var testLabelsPath  = Path.Combine(mnistFolder, testingLabelFile);
            var testImagesPath  = Path.Combine(mnistFolder, testingImageFile);

            // Fetch the raw files. NOTE(review): DownloadFile presumably skips files that already exist — confirm.
            Console.WriteLine("Downloading Mnist training files...");
            DownloadFile(urlMnist + trainingLabelFile, trainLabelsPath);
            DownloadFile(urlMnist + trainingImageFile, trainImagesPath);

            Console.WriteLine("Downloading Mnist testing files...");
            DownloadFile(urlMnist + testingLabelFile, testLabelsPath);
            DownloadFile(urlMnist + testingImageFile, testImagesPath);

            Console.WriteLine("Loading the datasets...");
            this.training = MnistReader.Load(trainLabelsPath, trainImagesPath);
            this.testing  = MnistReader.Load(testLabelsPath, testImagesPath);

            bool datasetsLoaded = this.training.Count != 0 && this.testing.Count != 0;
            if (!datasetsLoaded)
            {
                Console.WriteLine("Missing Mnist training/testing files.");
                Console.ReadKey();
                return;
            }

            // 24x24x1 input -> conv 5x5x8 (pad 2) -> relu -> pool 2x2 (stride 2)
            //               -> conv 5x5x16 (pad 2) -> relu -> pool 3x3 (stride 3)
            //               -> fully connected (10) -> softmax (10)
            this.net = FluentNet
                .Create(24, 24, 1)
                .Conv(5, 5, 8).Stride(1).Pad(2)
                .Relu()
                .Pool(2, 2).Stride(2)
                .Conv(5, 5, 16).Stride(1).Pad(2)
                .Relu()
                .Pool(3, 3).Stride(3)
                .FullyConn(10)
                .Softmax(10)
                .Build();

            this.trainer = new AdadeltaTrainer(this.net)
            {
                BatchSize = 20,
                L2Decay   = 0.001,
            };

            Console.WriteLine("Convolutional neural network learning...[Press any key to stop]");
            while (true)
            {
                this.Step(this.SampleTrainingInstance());

                if (Console.KeyAvailable)
                {
                    break;
                }
            }
        }
コード例 #2
0
        /// <summary>
        /// Trains <c>net</c> with an Adadelta trainer until the <c>loss</c> field is no longer
        /// greater than <paramref name="availableLoss"/>.
        /// </summary>
        /// <param name="availableLoss">Acceptable loss; training repeats while <c>loss</c> exceeds it.</param>
        /// <remarks>
        /// At least one training step always runs. NOTE(review): this relies on <c>TrainingStep</c>
        /// updating the <c>loss</c> field; if the target is never reached the loop does not end.
        /// </remarks>
        private void TrainNetworkForTactile(double availableLoss)
        {
            var adadelta = new AdadeltaTrainer(net)
            {
                // Number of samples processed per batch.
                BatchSize = 15,
                // L2 regularization (weight decay): penalizes large weights.
                L2Decay = 0.001,
            };
            trainer = adadelta;

            bool keepTraining;
            do
            {
                TrainingStep(PrepareTrainingSample());
                keepTraining = loss > availableLoss;
            }
            while (keepTraining);
        }
コード例 #3
0
        /// <summary>
        /// MNIST demo: fetches the four MNIST files into <c>mnistFolder</c>, loads the training and
        /// testing sets, builds a small convolutional network layer by layer and trains it with an
        /// Adadelta trainer until a key is pressed.
        /// </summary>
        /// <remarks>
        /// If either dataset comes back empty, a message is printed and the method returns after a
        /// key press. The key that stops training is not read here, so it stays in the console buffer.
        /// </remarks>
        private void MnistDemo()
        {
            Directory.CreateDirectory(mnistFolder);

            var trainLabelsPath = Path.Combine(mnistFolder, trainingLabelFile);
            var trainImagesPath = Path.Combine(mnistFolder, trainingImageFile);
            var testLabelsPath  = Path.Combine(mnistFolder, testingLabelFile);
            var testImagesPath  = Path.Combine(mnistFolder, testingImageFile);

            // Fetch the raw files. NOTE(review): DownloadFile presumably skips files that already exist — confirm.
            Console.WriteLine("Downloading Mnist training files...");
            DownloadFile(urlMnist + trainingLabelFile, trainLabelsPath);
            DownloadFile(urlMnist + trainingImageFile, trainImagesPath);

            Console.WriteLine("Downloading Mnist testing files...");
            DownloadFile(urlMnist + testingLabelFile, testLabelsPath);
            DownloadFile(urlMnist + testingImageFile, testImagesPath);

            Console.WriteLine("Loading the datasets...");
            this.training = MnistReader.Load(trainLabelsPath, trainImagesPath);
            this.testing  = MnistReader.Load(testLabelsPath, testImagesPath);

            bool datasetsLoaded = this.training.Count != 0 && this.testing.Count != 0;
            if (!datasetsLoaded)
            {
                Console.WriteLine("Missing Mnist training/testing files.");
                Console.ReadKey();
                return;
            }

            // 24x24x1 input -> conv 5x5x8 (pad 2) -> relu -> pool 2x2 (stride 2)
            //               -> conv 5x5x16 (pad 2) -> relu -> pool 3x3 (stride 3)
            //               -> fully connected (10) -> softmax (10)
            var network = this.net = new Net();
            network.AddLayer(new InputLayer(24, 24, 1));

            network.AddLayer(new ConvLayer(5, 5, 8) { Stride = 1, Pad = 2 });
            network.AddLayer(new ReluLayer());
            network.AddLayer(new PoolLayer(2, 2) { Stride = 2 });

            network.AddLayer(new ConvLayer(5, 5, 16) { Stride = 1, Pad = 2 });
            network.AddLayer(new ReluLayer());
            network.AddLayer(new PoolLayer(3, 3) { Stride = 3 });

            network.AddLayer(new FullyConnLayer(10));
            network.AddLayer(new SoftmaxLayer(10));

            this.trainer = new AdadeltaTrainer(this.net)
            {
                BatchSize = 20,
                L2Decay   = 0.001,
            };

            Console.WriteLine("Convolutional neural network learning...[Press any key to stop]");
            while (true)
            {
                this.Step(this.SampleTrainingInstance());

                if (Console.KeyAvailable)
                {
                    break;
                }
            }
        }
コード例 #4
0
        /// <summary>
        /// Interactive MNIST training restricted to the digits listed in <c>numbers</c>.
        /// Downloads and loads the MNIST files (filtered to those labels), optionally reloads a
        /// previously saved network, trains with Adadelta until stopped, and optionally saves the
        /// network back to the same file it would be loaded from.
        /// </summary>
        /// <remarks>
        /// Console prompts: ENTER accepts "load", "limit cpu cores to 1" and "save"; any other key
        /// declines. After each training pause, ENTER continues training and any other key stops.
        /// </remarks>
        private void MnistDemo()
        {
            Directory.CreateDirectory(mnistFolder);

            string trainingLabelFilePath = Path.Combine(mnistFolder, trainingLabelFile);
            string trainingImageFilePath = Path.Combine(mnistFolder, trainingImageFile);
            string testingLabelFilePath  = Path.Combine(mnistFolder, testingLabelFile);
            string testingImageFilePath  = Path.Combine(mnistFolder, testingImageFile);

            // Download Mnist files if needed
            Console.WriteLine("Downloading Mnist training files...");
            DownloadFile(urlMnist + trainingLabelFile, trainingLabelFilePath);
            DownloadFile(urlMnist + trainingImageFile, trainingImageFilePath);

            Console.WriteLine("Downloading Mnist testing files...");
            DownloadFile(urlMnist + testingLabelFile, testingLabelFilePath);
            DownloadFile(urlMnist + testingImageFile, testingImageFilePath);

            // Load data, keeping only entries whose label is one of the selected digits.
            Console.WriteLine("Loading the datasets...");

            this.training = MnistReader.Load(trainingLabelFilePath, trainingImageFilePath).Where(p => numbers.Contains(p.Label)).ToList();
            this.testing  = MnistReader.Load(testingLabelFilePath, testingImageFilePath).Where(p => numbers.Contains(p.Label)).ToList();

            if (this.training.Count == 0 || this.testing.Count == 0)
            {
                Console.WriteLine("Missing Mnist training/testing files.");
                Console.ReadKey();
                return;
            }

            Console.WriteLine($"datasets training: {this.training.Count}, testing: {this.testing.Count}");

            // The saved network's file name is built from the concatenated selected digits, so each
            // digit subset gets its own file. This full path is used for both loading and saving.
            var netFile = Path.Combine(mnistFolder, $"net{string.Join("", numbers)}.bin");

            if (File.Exists(netFile))
            {
                Console.WriteLine($"load {netFile}?");
                if (Console.ReadKey(true).Key == ConsoleKey.Enter)
                {
                    Console.WriteLine("loading...");
                    this.net = Net.Load(netFile);
                }
            }

            // No network loaded: build a fresh one (pooling layers intentionally left disabled).
            if (this.net == null)
            {
                this.net = new Net();
                this.net.AddLayer(new InputLayer(24, 24, 1));

                this.net.AddLayer(new ConvLayer(5, 5, 8)
                {
                    Stride = 1, Pad = 2
                });
                this.net.AddLayer(new ReluLayer());

                this.net.AddLayer(new ConvLayer(5, 5, 16)
                {
                    Stride = 1, Pad = 2
                });
                this.net.AddLayer(new ReluLayer());

                // One output per selected digit.
                this.net.AddLayer(new FullyConnLayer(numbers.Length));
                this.net.AddLayer(new SoftmaxLayer(numbers.Length));
            }

            this.trainer = new AdadeltaTrainer(this.net)
            {
                BatchSize = 20,
                L2Decay   = 0.001,
            };

            Console.WriteLine("limit cpu cores to 1?");
            if (Console.ReadKey(true).Key == ConsoleKey.Enter)
            {
                using (Process proc = Process.GetCurrentProcess())
                {
                    // Affinity mask with only bit 0 set: the process may run on the first processor only.
                    proc.ProcessorAffinity = (IntPtr)1;
                }
            }

            Console.WriteLine("Training...[Press any key to stop]");

            // NOTE(review): Step appears to return true when training should end on its own
            // (e.g. a target is reached) — confirm in Step.
            bool ok = false;

            while (!ok)
            {
                do
                {
                    var sample = this.SampleTrainingInstance();

                    ok = this.Step(sample);
                }
                while (!ok && !Console.KeyAvailable);

                if (!ok)
                {
                    // Stopped by a key press: consume it so it does not answer the prompt below.
                    Console.ReadKey(true);
                }

                Console.WriteLine("stop? [ENTER: continue]");
                ok = Console.ReadKey(true).Key != ConsoleKey.Enter;
            }

            Console.WriteLine($"save {netFile}?");
            if (Console.ReadKey(true).Key == ConsoleKey.Enter)
            {
                Console.WriteLine("saving...");

                // netFile already contains mnistFolder. Combining it with mnistFolder again would, for
                // a relative folder such as "Mnist", produce "Mnist\Mnist\net….bin": a location that
                // differs from the load path and may not exist.
                this.net.Save(netFile);
            }
            Console.WriteLine("done.");
            Console.ReadKey();
        }
コード例 #5
0
        // Relative path to the character images. The "..\" segments climb out of the build
        // output ('debug') folder so the training data is found when running from there.
        private const string folder = @"..\..\..\char\";

        /// <summary>
        /// Loads the character datasets, shuffles them, builds a small CNN and trains it with an
        /// Adadelta trainer until a key is pressed.
        /// </summary>
        /// <remarks>
        /// The datasets are read from the serialized caches "dataset_training.dat" and
        /// "dataset_testing.dat" when BOTH exist. Otherwise they are loaded from the images under
        /// <see cref="folder"/> and, if both are non-empty, written to those caches for the next run.
        /// </remarks>
        private void MnistTrain()
        {
            const string trainingCacheFile = "dataset_training.dat";
            const string testingCacheFile  = "dataset_testing.dat";

            // Load the characters.
            Console.WriteLine("carregando datasets...");
            if (File.Exists(trainingCacheFile) && File.Exists(testingCacheFile))
            {
                // Both preprocessed datasets are cached: load them directly. Requiring both files
                // avoids a FileNotFoundException when only one of them exists.
                training = useful.ReadObject<List<MnistEntry>>(File.ReadAllBytes(trainingCacheFile));
                testing  = useful.ReadObject<List<MnistEntry>>(File.ReadAllBytes(testingCacheFile));
            }
            else
            {
                // Otherwise load from the images. NOTE(review): the bool argument appears to select
                // the training (true) vs. testing (false) split — confirm in MnistReader.Load.
                training = MnistReader.Load(folder, true);
                testing  = MnistReader.Load(folder, false);

                // Only cache usable data. An empty cache would be reloaded on every later run,
                // even after the image folder has been fixed.
                if (training.Count != 0 && testing.Count != 0)
                {
                    File.WriteAllBytes(trainingCacheFile, useful.WriteObject<List<MnistEntry>>(training));
                    File.WriteAllBytes(testingCacheFile, useful.WriteObject<List<MnistEntry>>(testing));
                }
            }

            // Shuffle both sets into a random order.
            Random rnd = new Random();

            training = training.OrderBy(x => rnd.Next()).ToList();
            testing  = testing.OrderBy(x => rnd.Next()).ToList();

            if (training.Count == 0 || testing.Count == 0)
            {
                Console.WriteLine("ajuste o diretório dos arquivos de treino/teste.");
                Console.ReadKey();
                return;
            }

            // Build a simple CNN.
            net = new Net();
            net.AddLayer(new InputLayer(24, 24, 1)); // size the images were scaled to
            net.AddLayer(new ConvLayer(5, 5, 20)
            {
                Stride = 1, Pad = 2, Activation = Activation.Relu
            });
            net.AddLayer(new PoolLayer(2, 2)
            {
                Stride = 2
            });
            net.AddLayer(new ConvLayer(5, 5, 35)
            {
                Stride = 1, Pad = 2, Activation = Activation.Relu
            });
            net.AddLayer(new PoolLayer(3, 3)
            {
                Stride = 3
            });
            net.AddLayer(new FullyConnLayer(50));
            // NOTE(review): 33 is presumably the number of character classes — confirm it matches the dataset.
            net.AddLayer(new SoftmaxLayer(33));

            // Optimizer. Background on AdaDelta:
            // http://int8.io/comparison-of-optimization-techniques-stochastic-gradient-descent-momentum-adagrad-and-adadelta/#AdaDelta
            trainer = new AdadeltaTrainer(net)
            {
                BatchSize = 20,
                L2Decay   = 0.001,
            };
            Console.WriteLine("CNN treinando... aperte uma tecla para parar");

            // Training loop: one sampled step per iteration until a key is pressed.
            do
            {
                var sample = SampleTrainingInstance();
                Step(sample);
            } while (!Console.KeyAvailable);
        }