Exemple #1
0
        /// <summary>
        /// Trains the multi-layer perceptron model on the MNIST training set on GPU 0
        /// using RMSprop with global-norm gradient clipping, reports status periodically
        /// against the validation set, and finally prints the result on the test set.
        /// </summary>
        [Test] public void MultiLayerPerceptron()
        {
            CleanMem_();
            const double clipNorm  = 3.0;
            const long   batchSize = 1000L;
            const long   epochs    = 3;

            // float.Epsilon is the smallest positive subnormal Single (~1.4e-45), not
            // machine epsilon, so it gives no real protection as a stabilising term.
            // Use the same small-but-normal value as the ConvolutionalNeuralNetwork test.
            const double epsilon = 1e-9;

            var model = MultiLayerPerceptronModel();
            var ctx   = Context.GpuContext(0);
            var opt   = new RMSpropOptimizer(ctx, model.Loss.Loss, 0.005, 0.9, epsilon,
                                             new GlobalNormGradientClipper(clipNorm));

            opt.Initalize();

            var mnist   = new MNIST();
            var batcher = new Batcher(ctx, mnist.TrainImages, mnist.TrainLabels);

            for (var e = 1; e <= epochs; ++e)
            {
                // i counts the batches processed within the current epoch (1-based).
                var i = 0;
                while (batcher.Next(batchSize, opt, model.Images, model.Labels))
                {
                    i++;
                    opt.Forward();
                    opt.Backward();
                    opt.Optimize();

                    // Report every 10th batch, plus the very first batch of the first epoch.
                    if ((i % 10 == 0) || ((i == 1) && (e == 1)))
                    {
                        PrintStatus(e, i, opt, model, mnist.ValidationImages, mnist.ValidationLabels);
                    }
                }
            }
            PrintResult(opt, model, mnist.TestImages, mnist.TestLabels);

            // Mirror the ConvolutionalNeuralNetwork test, which also cleans up after itself.
            CleanMem_();
        }
Exemple #2
0
        /// <summary>
        /// Trains the convolutional model on the MNIST training set on GPU 0 with RMSprop,
        /// reporting status against the validation set at a fixed batch interval and
        /// printing the result on the test set once training has finished.
        /// </summary>
        [Test] public void ConvolutionalNeuralNetwork()
        {
            CleanMem_();
            const long samplesPerBatch = 500L;
            const long epochCount      = 2;
            const int  reportInterval  = 20;

            var network   = ConvolutionalNeuralNetworkModel();
            var gpu       = Context.GpuContext(0);
            var optimizer = new RMSpropOptimizer(gpu, network.Loss.Loss, 0.005, 0.9, 1e-9);

            optimizer.Initalize();

            var dataset = new MNIST();
            var feeder  = new Batcher(gpu, dataset.TrainImages, dataset.TrainLabels);

            var epoch = 1;
            while (epoch <= epochCount)
            {
                // step is the 1-based index of the batch within the current epoch.
                for (var step = 1;
                     feeder.Next(samplesPerBatch, optimizer, network.Images, network.Labels);
                     ++step)
                {
                    optimizer.Forward();
                    optimizer.Backward();
                    optimizer.Optimize();

                    // Always report the very first batch of training, then every Nth batch.
                    var isVeryFirstBatch = (epoch == 1) && (step == 1);
                    if (isVeryFirstBatch || (step % reportInterval == 0))
                    {
                        PrintStatus(epoch, step, optimizer, network,
                                    dataset.ValidationImages, dataset.ValidationLabels);
                    }
                }

                ++epoch;
            }

            PrintResult(optimizer, network, dataset.TestImages, dataset.TestLabels);

            CleanMem_();
        }