Beispiel #1
0
        /// <summary>
        /// Entry point of the MNIST sample: loads 1000 training and 1000 test images from embedded
        /// resources, restores the model from "CNN.xml" when that file exists (otherwise builds and
        /// trains a new one), prints the test accuracy, writes the training log to "Log.csv" and
        /// serializes the model layers back to "CNN.xml".
        /// </summary>
        /// <param name="args">Command-line arguments; not used.</param>
        static void Main(string[] args)
        {
            const int sampleCount = 1000;
            const int classCount  = 10;

            Console.WriteLine("MNIST Test");

            int seed;

            // RandomNumberGenerator.Create() is the supported replacement for the obsolete
            // RNGCryptoServiceProvider type and lives in the same namespace.
            using (var rng = RandomNumberGenerator.Create())
            {
                var buffer = new byte[sizeof(int)];

                rng.GetBytes(buffer);
                seed = BitConverter.ToInt32(buffer, 0);
            }

            RandomProvider.SetSeed(seed);

            // NOTE(review): the returned generator is not used in this method. The call is kept in
            // case GetRandom() initialises provider state as a side effect - confirm and remove.
            RandomProvider.GetRandom();

            var assembly            = Assembly.GetExecutingAssembly();
            var filename            = "CNN.xml";
            var serializer          = new DataContractSerializer(typeof(IEnumerable<Layer>), new Type[] { typeof(Convolution), typeof(BatchNormalization), typeof(Activation), typeof(ReLU), typeof(MaxPooling), typeof(FullyConnected), typeof(Softmax) });
            var accuracyList        = new List<double>();
            var lossList            = new List<double>();
            var logPath             = "Log.csv";
            var channels            = 1;
            var imageWidth          = 28;
            var imageHeight         = 28;
            var filters             = 30;
            var filterWidth         = 5;
            var filterHeight        = 5;
            var poolWidth           = 2;
            var poolHeight          = 2;
            var activationMapWidth  = Convolution.GetActivationMapLength(imageWidth, filterWidth);
            var activationMapHeight = Convolution.GetActivationMapLength(imageHeight, filterHeight);
            var outputWidth         = MaxPooling.GetOutputLength(activationMapWidth, poolWidth);
            var outputHeight        = MaxPooling.GetOutputLength(activationMapHeight, poolHeight);
            var trainingList        = LoadOneHotSamples(assembly, "MNISTTest.train-images.idx3-ubyte", "MNISTTest.train-labels.idx1-ubyte", sampleCount, classCount);
            var testList            = LoadOneHotSamples(assembly, "MNISTTest.t10k-images.idx3-ubyte", "MNISTTest.t10k-labels.idx1-ubyte", sampleCount, classCount);
            Model model;

            if (File.Exists(filename))
            {
                // A previously saved layer graph exists: restore it and skip training.
                using (XmlReader xmlReader = XmlReader.Create(filename))
                {
                    model = new Model((IEnumerable<Layer>)serializer.ReadObject(xmlReader), new Adam(), new SoftmaxCrossEntropy());
                }
            }
            else
            {
                // Layers are chained through their constructors: each layer receives the next one
                // as its last argument.
                model = new Model(
                    new Convolution(channels, imageWidth, imageHeight, filters, filterWidth, filterHeight, (fanIn, fanOut) => Initializers.HeNormal(fanIn),
                                    new Activation(new ReLU(),
                                                   new MaxPooling(filters, activationMapWidth, activationMapHeight, poolWidth, poolHeight,
                                                                  new FullyConnected(filters * outputWidth * outputHeight, (fanIn, fanOut) => Initializers.HeNormal(fanIn),
                                                                                     new Activation(new ReLU(),
                                                                                                    new Softmax(100, 10, (fanIn, fanOut) => Initializers.GlorotNormal(fanIn, fanOut))))))),
                    new Adam(), new SoftmaxCrossEntropy());

                int epochs     = 50;
                int iterations = 1;

                // After every Stepped event, re-evaluate the model on the training samples and
                // record accuracy and loss for the CSV log.
                model.Stepped += (sender, e) =>
                {
                    int correct = 0;

                    foreach (var sample in trainingList)
                    {
                        var vector = model.Predicate(sample.Item1);
                        var i      = ArgMax(vector);
                        var j      = ArgMax(sample.Item2);

                        // A sample counts only when the predicted class matches the target class
                        // AND the predicted score rounds to the target value at that index.
                        if (i == j && Math.Round(vector[i]) == sample.Item2[j])
                        {
                            correct++;
                        }
                    }

                    var accuracy = (double)correct / trainingList.Count;

                    accuracyList.Add(accuracy);
                    lossList.Add(model.Loss);

                    Console.WriteLine("Epoch {0}/{1}", iterations, epochs);
                    Console.WriteLine("Accuracy: {0}, Loss: {1}", accuracy, model.Loss);

                    iterations++;
                };

                Console.WriteLine("Training...");

                var stopwatch = Stopwatch.StartNew();

                // NOTE(review): the third argument (100) is presumably the mini-batch size - confirm against Model.Fit.
                model.Fit(trainingList, epochs, 100);

                stopwatch.Stop();

                Console.WriteLine("Done ({0}).", stopwatch.Elapsed.ToString());
            }

            int testCorrect = 0;

            foreach (var sample in testList)
            {
                var vector = model.Predicate(sample.Item1);
                var i      = ArgMax(vector);
                var j      = ArgMax(sample.Item2);

                if (i == j && Math.Round(vector[i]) == sample.Item2[j])
                {
                    testCorrect++;
                }
            }

            Console.WriteLine("Accuracy: {0}", (double)testCorrect / testList.Count);

            // The log lists are only filled by the Stepped handler, i.e. when training ran.
            if (accuracyList.Count > 0)
            {
                var logDictionary = new Dictionary<string, IEnumerable<double>>
                {
                    { "Accuracy", accuracyList },
                    { "Loss", lossList }
                };

                ToCsv(logPath, logDictionary);

                Console.WriteLine("Saved log to {0}...", logPath);
            }

            var settings = new XmlWriterSettings
            {
                Indent   = true,
                Encoding = new System.Text.UTF8Encoding(false) // UTF-8 without a byte-order mark
            };

            using (XmlWriter xmlWriter = XmlWriter.Create(filename, settings))
            {
                serializer.WriteObject(xmlWriter, model.Layers);
                xmlWriter.Flush();
            }
        }

        /// <summary>
        /// Reads up to <paramref name="count"/> images from a pair of embedded MNIST resources and
        /// pairs the result of <c>image.Normalize()</c> with a one-hot vector built from <c>image.Label</c>.
        /// </summary>
        /// <param name="assembly">Assembly that contains the embedded resources.</param>
        /// <param name="imagesResource">Manifest resource name of the image file.</param>
        /// <param name="labelsResource">Manifest resource name of the label file.</param>
        /// <param name="count">Maximum number of images to read.</param>
        /// <param name="classCount">Length of the one-hot target vector.</param>
        /// <returns>The list of (input, target) pairs in load order.</returns>
        /// <exception cref="InvalidOperationException">An embedded resource could not be found.</exception>
        private static List<Tuple<double[], double[]>> LoadOneHotSamples(Assembly assembly, string imagesResource, string labelsResource, int count, int classCount)
        {
            var samples = new List<Tuple<double[], double[]>>();

            using (Stream
                   imagesStream = assembly.GetManifestResourceStream(imagesResource),
                   labelsStream = assembly.GetManifestResourceStream(labelsResource))
            {
                // GetManifestResourceStream returns null (it does not throw) when the resource
                // is missing, so fail here with a message that names the resource.
                if (imagesStream == null)
                {
                    throw new InvalidOperationException("Embedded resource not found: " + imagesResource);
                }

                if (labelsStream == null)
                {
                    throw new InvalidOperationException("Embedded resource not found: " + labelsResource);
                }

                foreach (var image in MnistImage.Load(imagesStream, labelsStream).Take(count))
                {
                    // New arrays are zero-initialised, so only the matching index is written.
                    var target = new double[classCount];

                    for (int i = 0; i < classCount; i++)
                    {
                        if (i == image.Label)
                        {
                            target[i] = 1.0;
                            break;
                        }
                    }

                    samples.Add(Tuple.Create<double[], double[]>(image.Normalize(), target));
                }
            }

            return samples;
        }
Beispiel #2
0
        /// <summary>
        /// Entry point of the MNIST sample: loads 1000 training and 1000 test images from embedded
        /// resources, builds and trains a convolutional network, prints the test accuracy and writes
        /// the per-step accuracy/loss log to "Log.csv".
        /// </summary>
        /// <param name="args">Command-line arguments; not used.</param>
        static void Main(string[] args)
        {
            const int sampleCount = 1000;
            const int classCount  = 10;

            Console.WriteLine("MNIST Test");

            int seed;

            // RandomNumberGenerator.Create() is the supported replacement for the obsolete
            // RNGCryptoServiceProvider type and lives in the same namespace.
            using (var rng = RandomNumberGenerator.Create())
            {
                var buffer = new byte[sizeof(int)];

                rng.GetBytes(buffer);
                seed = BitConverter.ToInt32(buffer, 0);
            }

            RandomProvider.SetSeed(seed);

            // NOTE(review): the returned generator is not used in this method. The call is kept in
            // case GetRandom() initialises provider state as a side effect - confirm and remove.
            RandomProvider.GetRandom();

            var assembly     = Assembly.GetExecutingAssembly();
            var accuracyList = new List<double>();
            var lossList     = new List<double>();
            var logPath      = "Log.csv";
            var channels     = 1;
            var imageWidth   = 28;
            var imageHeight  = 28;
            var filters      = 30;
            var filterWidth  = 5;
            var filterHeight = 5;
            var poolWidth    = 2;
            var poolHeight   = 2;
            var trainingList = LoadOneHotSamples(assembly, "MNISTTest.train-images.idx3-ubyte", "MNISTTest.train-labels.idx1-ubyte", sampleCount, classCount);
            var testList     = LoadOneHotSamples(assembly, "MNISTTest.t10k-images.idx3-ubyte", "MNISTTest.t10k-labels.idx1-ubyte", sampleCount, classCount);

            // Each layer after the first takes its predecessor as the first constructor argument.
            var inputLayer  = new ConvolutionalPoolingLayer(channels, imageWidth, imageHeight, filters, filterWidth, filterHeight, poolWidth, poolHeight, new ReLU(), (index, fanIn, fanOut) => Initializers.HeNormal(fanIn));
            var hiddenLayer = new FullyConnectedLayer(inputLayer, 100, new ReLU(), (index, fanIn, fanOut) => Initializers.HeNormal(fanIn));
            var outputLayer = new SoftmaxLayer(hiddenLayer, 10, (index, fanIn, fanOut) => Initializers.GlorotNormal(fanIn, fanOut));
            var network     = new Network(inputLayer, outputLayer, new Adam(), new SoftmaxCrossEntropy());
            int epochs      = 50;
            int iterations  = 1;

            // After every Stepped event, re-evaluate the network on the training samples and
            // record accuracy and loss for the CSV log.
            network.Stepped += (sender, e) =>
            {
                int correct = 0;

                foreach (var sample in trainingList)
                {
                    var vector = network.Predicate(sample.Item1);
                    var i      = ArgMax(vector);
                    var j      = ArgMax(sample.Item2);

                    // A sample counts only when the predicted class matches the target class
                    // AND the predicted score rounds to the target value at that index.
                    if (i == j && Math.Round(vector[i]) == sample.Item2[j])
                    {
                        correct++;
                    }
                }

                var accuracy = (double)correct / trainingList.Count;

                accuracyList.Add(accuracy);
                lossList.Add(network.Loss);

                Console.WriteLine("Epoch {0}/{1}", iterations, epochs);
                Console.WriteLine("Accuracy: {0}, Loss: {1}", accuracy, network.Loss);

                iterations++;
            };

            Console.WriteLine("Training...");

            var stopwatch = Stopwatch.StartNew();

            // NOTE(review): the third argument (100) is presumably the mini-batch size - confirm against Network.Train.
            network.Train(trainingList, epochs, 100);

            stopwatch.Stop();

            Console.WriteLine("Done ({0}).", stopwatch.Elapsed.ToString());

            int testCorrect = 0;

            foreach (var sample in testList)
            {
                var vector = network.Predicate(sample.Item1);
                var i      = ArgMax(vector);
                var j      = ArgMax(sample.Item2);

                if (i == j && Math.Round(vector[i]) == sample.Item2[j])
                {
                    testCorrect++;
                }
            }

            Console.WriteLine("Accuracy: {0}", (double)testCorrect / testList.Count);

            var logDictionary = new Dictionary<string, IEnumerable<double>>
            {
                { "Accuracy", accuracyList },
                { "Loss", lossList }
            };

            ToCsv(logPath, logDictionary);

            // WriteLine (not Write) so the last line of output is terminated.
            Console.WriteLine("Saved log to {0}...", logPath);
        }

        /// <summary>
        /// Reads up to <paramref name="count"/> images from a pair of embedded MNIST resources and
        /// pairs the result of <c>image.Normalize()</c> with a one-hot vector built from <c>image.Label</c>.
        /// </summary>
        /// <param name="assembly">Assembly that contains the embedded resources.</param>
        /// <param name="imagesResource">Manifest resource name of the image file.</param>
        /// <param name="labelsResource">Manifest resource name of the label file.</param>
        /// <param name="count">Maximum number of images to read.</param>
        /// <param name="classCount">Length of the one-hot target vector.</param>
        /// <returns>The list of (input, target) pairs in load order.</returns>
        /// <exception cref="InvalidOperationException">An embedded resource could not be found.</exception>
        private static List<Tuple<double[], double[]>> LoadOneHotSamples(Assembly assembly, string imagesResource, string labelsResource, int count, int classCount)
        {
            var samples = new List<Tuple<double[], double[]>>();

            using (Stream
                   imagesStream = assembly.GetManifestResourceStream(imagesResource),
                   labelsStream = assembly.GetManifestResourceStream(labelsResource))
            {
                // GetManifestResourceStream returns null (it does not throw) when the resource
                // is missing, so fail here with a message that names the resource.
                if (imagesStream == null)
                {
                    throw new InvalidOperationException("Embedded resource not found: " + imagesResource);
                }

                if (labelsStream == null)
                {
                    throw new InvalidOperationException("Embedded resource not found: " + labelsResource);
                }

                foreach (var image in MnistImage.Load(imagesStream, labelsStream).Take(count))
                {
                    // New arrays are zero-initialised, so only the matching index is written.
                    var target = new double[classCount];

                    for (int i = 0; i < classCount; i++)
                    {
                        if (i == image.Label)
                        {
                            target[i] = 1.0;
                            break;
                        }
                    }

                    samples.Add(Tuple.Create<double[], double[]>(image.Normalize(), target));
                }
            }

            return samples;
        }