Пример #1
0
 /// <summary>
 /// Builds a view model around <paramref name="image"/>: stores the image, renders its
 /// pixels to a bitmap, copies its label, and starts with no prediction.
 /// </summary>
 /// <param name="image">The image to present.</param>
 public MnistImageVM(MnistImage image)
 {
     Image = image;

     // Render the raw pixel data using the fixed image dimensions declared on MnistImage.
     Source = BitmapUtils.BitmapFromArray(
         image.Pixels,
         MnistImage.ImgWidth,
         MnistImage.ImgHeight);

     Label = (FashionLabel)image.Label;

     // Nothing has been predicted for this image yet.
     PredictedLabel = FashionLabel.Unknown;
 }
Пример #2
0
    /// <summary>
    /// Loads an image file and its matching label file and returns one
    /// <see cref="MnistImage"/> per entry, with pixel values scaled to the range [0, 1].
    /// </summary>
    /// <param name="imgPath">Path of the image file (4-byte magic, then count, rows, cols as big-endian 32-bit integers, then pixel bytes).</param>
    /// <param name="labelPath">Path of the label file (4-byte magic, 4-byte count, then one byte per label).</param>
    /// <param name="updaterCallback">Optional progress sink; invoked once before processing and once per image.</param>
    /// <returns>The decoded images, in file order.</returns>
    /// <exception cref="InvalidDataException">The image header contains a negative count or dimension.</exception>
    /// <exception cref="EndOfStreamException">Either file ends before the data announced by the header.</exception>
    public static MnistImage[] ProcessMNISTFile(string imgPath, string labelPath, Action <string, MessageSection> updaterCallback = null)
    {
        // The files are binary, so open them as plain streams and make sure every
        // handle is released even when parsing fails.
        using (FileStream imageStream = File.OpenRead(imgPath))
        using (FileStream labelStream = File.OpenRead(labelPath))
        using (BinaryReader imageReader = new BinaryReader(imageStream))
        using (BinaryReader labelReader = new BinaryReader(labelStream))
        {
            // Discard the magic numbers.
            imageReader.ReadBytes(4);
            labelReader.ReadBytes(4);

            // Header fields are stored most-significant byte first.
            int numImages = ReadMnistHeaderInt32(imageReader);
            // The label file's own count is skipped; the image file's count drives the loop.
            labelReader.ReadBytes(4);
            int rows = ReadMnistHeaderInt32(imageReader);
            int cols = ReadMnistHeaderInt32(imageReader);

            if (numImages < 0 || rows < 0 || cols < 0)
            {
                throw new InvalidDataException($"Invalid header in '{imgPath}': count={numImages}, rows={rows}, cols={cols}.");
            }

            updaterCallback?.Invoke($"Processing Training Set Images from path: {imgPath}", MessageSection.HEAD);

            int pixelsPerImage = rows * cols;

            byte[] labels       = labelReader.ReadBytes(numImages);
            byte[] imagesPixels = imageReader.ReadBytes(numImages * pixelsPerImage);

            // ReadBytes returns a shorter array at end of file instead of throwing.
            if (labels.Length != numImages)
            {
                throw new EndOfStreamException($"Label file '{labelPath}' holds {labels.Length} labels but {numImages} were expected.");
            }
            if (imagesPixels.Length != numImages * pixelsPerImage)
            {
                throw new EndOfStreamException($"Image file '{imgPath}' is truncated: pixel data for {numImages} images was expected.");
            }

            MnistImage[] images = new MnistImage[numImages];
            int sourceIndex = 0;

            for (int image = 0; image < numImages; image++)
            {
                updaterCallback?.Invoke($"Processing Image {image + 1} out of {numImages}", MessageSection.BODY);

                float[] pixels = new float[pixelsPerImage];
                for (int pixel = 0; pixel < pixelsPerImage; pixel++, sourceIndex++)
                {
                    // Scale each byte to a float in [0, 1].
                    pixels[pixel] = imagesPixels[sourceIndex] / 255.0f;
                }

                images[image]        = new MnistImage(cols, rows);
                images[image].Pixels = pixels;
                images[image].Label  = labels[image];
            }

            return images;
        }
    }

    /// <summary>Reads one big-endian 32-bit header field from <paramref name="reader"/>.</summary>
    private static int ReadMnistHeaderInt32(BinaryReader reader)
    {
        byte[] raw = reader.ReadBytes(4);
        if (raw.Length < 4)
        {
            throw new EndOfStreamException("Unexpected end of file while reading a 4-byte header field.");
        }

        return (raw[0] << 24) | (raw[1] << 16) | (raw[2] << 8) | raw[3];
    }
Пример #3
0
        /// <summary>
        /// Predicts the label of <paramref name="image"/> with the currently trained model.
        /// </summary>
        /// <param name="image">The image to classify.</param>
        /// <returns>
        /// The predicted label, or <see cref="FashionLabel.Unknown"/> when no model has been trained yet.
        /// </returns>
        public FashionLabel Predict(MnistImage image)
        {
            if (_trainedModel == null)
            {
                return FashionLabel.Unknown;
            }

            // A fresh engine is created for every call; dispose it so its resources are
            // released deterministically instead of waiting for the garbage collector.
            using (var engine = _context.Model.CreatePredictionEngine <MnistImage, MnistPrediction>(_trainedModel))
            {
                return Predict(image, engine);
            }
        }
Пример #4
0
        /// <summary>
        /// Initializes the form: loads the training and test sets from the <c>res</c> folder,
        /// builds the network, runs an initial test pass, shows the first test image and
        /// wires up the background worker used for training.
        /// </summary>
        public MnistCNNGUI()
        {
            InitializeComponent();

            string trainImgPath = "res/train-images.idx3-ubyte"
            , trainLblPath      = "res/train-labels.idx1-ubyte"
            , testImgPath       = "res/t10k-images.idx3-ubyte"
            , testLblPath       = "res/t10k-labels.idx1-ubyte";

            trainingImages = MnistImage.ProcessMNISTFile(trainImgPath, trainLblPath);
            testImages     = MnistImage.ProcessMNISTFile(testImgPath, testLblPath);

            mnistNetwork = new ConvolutionalNeuralNetwork(new ConvolutionalNeuralNetworkProps(28, 28, 90, 10, 6, 5, 5, 2, 2));
            mnistNetwork.UploadTrainingSet(trainingImages);
            mnistNetwork.UploadTestSet(testImages);
            mnistNetwork.NumEpochs = 1;

            //mnistNetwork.TestTrainingSet();
            mnistNetwork.Test(true);
            UpdateSensitivity();

            // Start out browsing the test set.
            selectedImageSet = testImages;

            LoadImage(0);
            UseTrainingSetCheckBox_CheckedChanged(null, null);

            // Training runs on a background worker so the UI stays responsive.
            NetworkWorker = new BackgroundWorker();
            NetworkWorker.WorkerReportsProgress = true;
            NetworkWorker.DoWork          += TrainNetwork;
            NetworkWorker.ProgressChanged += (object sender, ProgressChangedEventArgs e) =>
            {
                int percent = e.ProgressPercentage;
                TrainingProgressLabel.Text = $"Network Progress: {percent}%";
                TrainingProgressBar.Value  = percent;
            };
            NetworkWorker.RunWorkerCompleted += (object sender, RunWorkerCompletedEventArgs e) =>
            {
                UpdateSensitivity();
                SetNetworkActive(true);
                // Both outcomes reset the display to 0%. The cancelled text used to be a
                // non-interpolated string and showed a literal "{0}%".
                TrainingProgressLabel.Text = e.Cancelled
                    ? "Network Progress: 0% (cancelled)"
                    : "Network Progress: 0% (finished)";
                TrainingProgressBar.Value  = 0;
                TotalEpochLabel.Text       = $"Total Trained Epochs: {mnistNetwork.TotalEpochs.ToString("n2")}";
                TestImage();
            };
            NetworkWorker.WorkerSupportsCancellation = true;

            NumEpochsComboBox.SelectedIndex = 0;
        }
Пример #5
0
        /// <summary>
        /// The constructor must reject pixel data whose two dimensions differ.
        /// </summary>
        public void Constructor_ImageDataNotSquareMatrix()
        {
            // 2 rows by 3 columns of zeros: deliberately not square.
            Byte[,] nonSquarePixels = new Byte[2, 3];

            ArgumentException thrown = Assert.Throws <ArgumentException>(() =>
            {
                new MnistImage(nonSquarePixels, 0);
            });

            Assert.That(thrown.Message, Does.StartWith("Parameter 'imageData' must contain equal dimensions (i.e. be a square 2-dimensional matrix)."));
            Assert.AreEqual("imageData", thrown.ParamName);
        }
Пример #6
0
        /// <summary>
        /// The constructor must reject labels just outside the accepted 0..9 range,
        /// checked at both ends (-1 and 10).
        /// </summary>
        public void Constructor_LabelOutOfRange()
        {
            const string expectedMessagePrefix = "Parameter 'label' must be between 0 and 9 inclusive.";

            // One value below the lower bound, then one above the upper bound.
            foreach (Int32 invalidLabel in new[] { -1, 10 })
            {
                ArgumentOutOfRangeException thrown = Assert.Throws <ArgumentOutOfRangeException>(() =>
                {
                    new MnistImage(new Byte[0, 0], invalidLabel);
                });

                Assert.That(thrown.Message, Does.StartWith(expectedMessagePrefix));
                Assert.AreEqual("label", thrown.ParamName);
            }
        }
Пример #7
0
        /// <summary>
        /// Reads the first <paramref name="imageCount"/> images and their labels from a pair of
        /// image/label files.
        /// </summary>
        /// <param name="imageFile">File holding a 16-byte header followed by the pixel bytes.</param>
        /// <param name="labelFile">File holding an 8-byte header followed by one byte per label.</param>
        /// <param name="imageCount">Number of images to read.</param>
        /// <returns>The images that were read, in file order.</returns>
        /// <exception cref="EndOfStreamException">Either file ends before <paramref name="imageCount"/> entries were read.</exception>
        public IEnumerable <MnistImage> ReadDataset(FileInfo imageFile, FileInfo labelFile, int imageCount)
        {
            var result = new List <MnistImage>();

            // Request read access only: FileMode.Open on its own asks for read/write
            // access and fails on read-only files.
            using FileStream ifsLabels = new FileStream(labelFile.FullName, FileMode.Open, FileAccess.Read); // labels
            using FileStream ifsImages = new FileStream(imageFile.FullName, FileMode.Open, FileAccess.Read); // images

            using BinaryReader brLabels = new BinaryReader(ifsLabels);
            using BinaryReader brImages = new BinaryReader(ifsImages);

            // Skip the image file header (four 32-bit fields); the values are not used.
            brImages.ReadInt32();
            brImages.ReadInt32();
            brImages.ReadInt32();
            brImages.ReadInt32();
            // Skip the label file header (two 32-bit fields).
            brLabels.ReadInt32();
            brLabels.ReadInt32();

            // The image size comes from MnistImage's constants, not from the header.
            int bytesPerImage = MnistImage.ImgWidth * MnistImage.ImgHeight;

            for (int di = 0; di < imageCount; ++di)
            {
                byte[] pixels1d = brImages.ReadBytes(bytesPerImage);
                if (pixels1d.Length != bytesPerImage)
                {
                    // ReadBytes returns a short array at end of file; report it the way
                    // the byte-by-byte ReadByte loop used to.
                    throw new EndOfStreamException($"Image file '{imageFile.FullName}' ended while reading image {di}.");
                }

                byte label = brLabels.ReadByte();

                result.Add(new MnistImage(pixels1d, label));
            }

            return result;
        }
Пример #8
0
        /// <summary>
        /// Console entry point. Seeds the random provider, loads up to 1000 training and
        /// 1000 test images from embedded resources, then either restores the layers from
        /// "CNN.xml" or builds and fits a new model. Finally prints the test accuracy,
        /// writes the accuracy/loss log (when any values were recorded) and serializes the
        /// model's layers back to "CNN.xml".
        /// </summary>
        /// <param name="args">Command-line arguments; not used.</param>
        static void Main(string[] args)
        {
            Console.WriteLine("MNIST Test");

            int seed;

            // Draw the seed from the cryptographic RNG.
            using (var rng = new RNGCryptoServiceProvider())
            {
                var buffer = new byte[sizeof(int)];

                rng.GetBytes(buffer);
                seed = BitConverter.ToInt32(buffer, 0);
            }

            RandomProvider.SetSeed(seed);

            var   assembly            = Assembly.GetExecutingAssembly();
            var   filename            = "CNN.xml";
            // The serializer is told about every concrete layer type it may meet.
            var   serializer          = new DataContractSerializer(typeof(IEnumerable <Layer>), new Type[] { typeof(Convolution), typeof(BatchNormalization), typeof(Activation), typeof(ReLU), typeof(MaxPooling), typeof(FullyConnected), typeof(Softmax) });
            // NOTE(review): 'random' is not used later in this method.
            var   random              = RandomProvider.GetRandom();
            // Each tuple pairs an input vector (Item1) with its one-hot target (Item2).
            var   trainingList        = new List <Tuple <double[], double[]> >();
            var   testList            = new List <Tuple <double[], double[]> >();
            var   accuracyList        = new List <double>();
            var   lossList            = new List <double>();
            var   logPath             = "Log.csv";
            var   channels            = 1;
            var   imageWidth          = 28;
            var   imageHeight         = 28;
            var   filters             = 30;
            var   filterWidth         = 5;
            var   filterHeight        = 5;
            var   poolWidth           = 2;
            var   poolHeight          = 2;
            var   activationMapWidth  = Convolution.GetActivationMapLength(imageWidth, filterWidth);
            var   activationMapHeight = Convolution.GetActivationMapLength(imageHeight, filterHeight);
            var   outputWidth         = MaxPooling.GetOutputLength(activationMapWidth, poolWidth);
            var   outputHeight        = MaxPooling.GetOutputLength(activationMapHeight, poolHeight);
            Model model;

            // Training data: the first 1000 images from the embedded training resources.
            using (Stream
                   imagesStream = assembly.GetManifestResourceStream("MNISTTest.train-images.idx3-ubyte"),
                   labelsStream = assembly.GetManifestResourceStream("MNISTTest.train-labels.idx1-ubyte"))
            {
                foreach (var image in MnistImage.Load(imagesStream, labelsStream).Take(1000))
                {
                    // One-hot target: 1.0 at the label's index, 0.0 elsewhere.
                    var t = new double[10];

                    for (int i = 0; i < 10; i++)
                    {
                        if (i == image.Label)
                        {
                            t[i] = 1.0;
                        }
                        else
                        {
                            t[i] = 0.0;
                        }
                    }

                    trainingList.Add(Tuple.Create <double[], double[]>(image.Normalize(), t));
                }
            }

            // Test data: the first 1000 images from the embedded test resources, encoded the same way.
            using (Stream
                   imagesStream = assembly.GetManifestResourceStream("MNISTTest.t10k-images.idx3-ubyte"),
                   labelsStream = assembly.GetManifestResourceStream("MNISTTest.t10k-labels.idx1-ubyte"))
            {
                foreach (var image in MnistImage.Load(imagesStream, labelsStream).Take(1000))
                {
                    var t = new double[10];

                    for (int i = 0; i < 10; i++)
                    {
                        if (i == image.Label)
                        {
                            t[i] = 1.0;
                        }
                        else
                        {
                            t[i] = 0.0;
                        }
                    }

                    testList.Add(Tuple.Create <double[], double[]>(image.Normalize(), t));
                }
            }

            if (File.Exists(filename))
            {
                // A saved model exists: rebuild it from the serialized layers; no training happens.
                using (XmlReader xmlReader = XmlReader.Create(filename))
                {
                    model = new Model((IEnumerable <Layer>)serializer.ReadObject(xmlReader), new Adam(), new SoftmaxCrossEntropy());
                }
            }
            else
            {
                // The two commented-out blocks below are earlier ways of constructing the same
                // kind of model; only the nested-constructor form after them is active.
                /*model = new Model(new Layer[] {
                 *  new Convolutional(channels, imageWidth, imageHeight, filters, filterWidth, filterHeight, (fanIn, fanOut) => Initializers.HeNormal(fanIn)),
                 *  new Activation(filters * activationMapWidth * activationMapHeight, new ReLU()),
                 *  new MaxPooling(filters, activationMapWidth, activationMapHeight, poolWidth, poolHeight),
                 *  new FullyConnected(filters * outputWidth * outputHeight, 100, (fanIn, fanOut) => Initializers.HeNormal(fanIn)),
                 *  new Activation(100, new ReLU()),
                 * new Softmax(100, 10, (fanIn, fanOut) => Initializers.GlorotNormal(fanIn, fanOut))
                 * }, new Adam(), new SoftmaxCrossEntropy());*/
                /*var inputLayer = new Convolutional(channels, imageWidth, imageHeight, filters, filterWidth, filterHeight, (fanIn, fanOut) => Initializers.HeNormal(fanIn));
                 *
                 * new Softmax(
                 *  new Activation(
                 *      new FullyConnected(
                 *          new MaxPooling(
                 *              new Activation(inputLayer, new ReLU()),
                 *              filters, inputLayer.ActivationMapWidth, inputLayer.ActivationMapHeight, poolWidth, poolHeight),
                 *          100, (fanIn, fanOut) => Initializers.HeNormal(fanIn)),
                 *      new ReLU()),
                 *  10, (fanIn, fanOut) => Initializers.GlorotNormal(fanIn, fanOut));
                 *
                 * model = new Model(inputLayer, new Adam(), new SoftmaxCrossEntropy());*/
                // Layers are chained by passing each following layer as the last constructor argument:
                // Convolution -> ReLU -> MaxPooling -> FullyConnected -> ReLU -> Softmax.
                model = new Model(
                    new Convolution(channels, imageWidth, imageHeight, filters, filterWidth, filterHeight, (fanIn, fanOut) => Initializers.HeNormal(fanIn),
                                    new Activation(new ReLU(),
                                                   new MaxPooling(filters, activationMapWidth, activationMapHeight, poolWidth, poolHeight,
                                                                  new FullyConnected(filters * outputWidth * outputHeight, (fanIn, fanOut) => Initializers.HeNormal(fanIn),
                                                                                     new Activation(new ReLU(),
                                                                                                    new Softmax(100, 10, (fanIn, fanOut) => Initializers.GlorotNormal(fanIn, fanOut))))))),
                    new Adam(), new SoftmaxCrossEntropy());
                int epochs     = 50;
                // Counter printed as the current epoch number; incremented by the Stepped handler.
                int iterations = 1;

                // On every Stepped event: measure accuracy over the whole training list and record it with the loss.
                model.Stepped += (sender, e) =>
                {
                    double tptn = 0.0;

                    trainingList.ForEach(x =>
                    {
                        var vector = model.Predicate(x.Item1);
                        var i      = ArgMax(vector);
                        var j      = ArgMax(x.Item2);

                        // Counted as correct only when the arg-max indices agree AND the
                        // winning output, rounded, equals the target value at that index.
                        if (i == j && Math.Round(vector[i]) == x.Item2[j])
                        {
                            tptn += 1.0;
                        }
                    });

                    var accuracy = tptn / trainingList.Count;

                    accuracyList.Add(accuracy);
                    lossList.Add(model.Loss);

                    Console.WriteLine("Epoch {0}/{1}", iterations, epochs);
                    Console.WriteLine("Accuracy: {0}, Loss: {1}", accuracy, model.Loss);

                    iterations++;
                };

                Console.WriteLine("Training...");

                var stopwatch = Stopwatch.StartNew();

                // NOTE(review): the third argument (100) is presumably the batch size - confirm against Model.Fit.
                model.Fit(trainingList, epochs, 100);

                stopwatch.Stop();

                Console.WriteLine("Done ({0}).", stopwatch.Elapsed.ToString());
            }

            // Evaluate on the test list using the same correctness criterion as during training.
            double testTptn = 0.0;

            testList.ForEach(x =>
            {
                var vector = model.Predicate(x.Item1);
                var i      = ArgMax(vector);
                var j      = ArgMax(x.Item2);

                if (i == j && Math.Round(vector[i]) == x.Item2[j])
                {
                    testTptn += 1.0;
                }
            });

            Console.WriteLine("Accuracy: {0}", testTptn / testList.Count);

            // The log is only written when the Stepped handler recorded something,
            // i.e. not when the model was restored from file.
            if (accuracyList.Count > 0)
            {
                var logDictionary = new Dictionary <string, IEnumerable <double> >();

                logDictionary.Add("Accuracy", accuracyList);
                logDictionary.Add("Loss", lossList);

                ToCsv(logPath, logDictionary);

                Console.WriteLine("Saved log to {0}...", logPath);
            }

            // Persist the layers as indented UTF-8 XML without a byte-order mark.
            // This also runs when the model was just loaded from the same file.
            XmlWriterSettings settings = new XmlWriterSettings();

            settings.Indent   = true;
            settings.Encoding = new System.Text.UTF8Encoding(false);

            using (XmlWriter xmlWriter = XmlWriter.Create(filename, settings))
            {
                serializer.WriteObject(xmlWriter, model.Layers);
                xmlWriter.Flush();
            }
        }
Пример #9
0
        /// <summary>
        /// Console entry point. Seeds the random provider, loads up to 1000 training and
        /// 1000 test images from embedded resources, trains a three-layer network for 50
        /// epochs while recording accuracy and loss, prints the test accuracy and writes
        /// the recorded values to "Log.csv".
        /// </summary>
        /// <param name="args">Command-line arguments; not used.</param>
        static void Main(string[] args)
        {
            Console.WriteLine("MNIST Test");

            // Seed the shared random provider from the cryptographic RNG.
            var seedBytes = new byte[sizeof(int)];

            using (var rng = new RNGCryptoServiceProvider())
            {
                rng.GetBytes(seedBytes);
            }

            int seed = BitConverter.ToInt32(seedBytes, 0);

            RandomProvider.SetSeed(seed);

            Assembly assembly = Assembly.GetExecutingAssembly();
            // Not used below; the call is kept because it was part of the original sequence.
            var random = RandomProvider.GetRandom();

            // Each tuple pairs an input vector (Item1) with its one-hot target (Item2).
            var trainingList = new List <Tuple <double[], double[]> >();
            var testList     = new List <Tuple <double[], double[]> >();
            var accuracyList = new List <double>();
            var lossList     = new List <double>();
            string logPath   = "Log.csv";

            int channels     = 1;
            int imageWidth   = 28;
            int imageHeight  = 28;
            int filters      = 30;
            int filterWidth  = 5;
            int filterHeight = 5;
            int poolWidth    = 2;
            int poolHeight   = 2;

            // Training samples.
            using (Stream imagesStream = assembly.GetManifestResourceStream("MNISTTest.train-images.idx3-ubyte"))
            using (Stream labelsStream = assembly.GetManifestResourceStream("MNISTTest.train-labels.idx1-ubyte"))
            {
                foreach (var image in MnistImage.Load(imagesStream, labelsStream).Take(1000))
                {
                    // One-hot target: 1.0 at the label's index, 0.0 elsewhere.
                    var oneHot = new double[10];

                    for (int k = 0; k < 10; k++)
                    {
                        oneHot[k] = (k == image.Label) ? 1.0 : 0.0;
                    }

                    trainingList.Add(new Tuple <double[], double[]>(image.Normalize(), oneHot));
                }
            }

            // Test samples, encoded the same way.
            using (Stream imagesStream = assembly.GetManifestResourceStream("MNISTTest.t10k-images.idx3-ubyte"))
            using (Stream labelsStream = assembly.GetManifestResourceStream("MNISTTest.t10k-labels.idx1-ubyte"))
            {
                foreach (var image in MnistImage.Load(imagesStream, labelsStream).Take(1000))
                {
                    var oneHot = new double[10];

                    for (int k = 0; k < 10; k++)
                    {
                        oneHot[k] = (k == image.Label) ? 1.0 : 0.0;
                    }

                    testList.Add(new Tuple <double[], double[]>(image.Normalize(), oneHot));
                }
            }

            // Convolution+pooling -> fully connected (100) -> softmax (10).
            var inputLayer  = new ConvolutionalPoolingLayer(channels, imageWidth, imageHeight, filters, filterWidth, filterHeight, poolWidth, poolHeight, new ReLU(), (index, fanIn, fanOut) => Initializers.HeNormal(fanIn));
            var hiddenLayer = new FullyConnectedLayer(inputLayer, 100, new ReLU(), (index, fanIn, fanOut) => Initializers.HeNormal(fanIn));
            var outputLayer = new SoftmaxLayer(hiddenLayer, 10, (index, fanIn, fanOut) => Initializers.GlorotNormal(fanIn, fanOut));
            var network     = new Network(inputLayer, outputLayer, new Adam(), new SoftmaxCrossEntropy());

            int epochs     = 50;
            // Printed as the current epoch number; advanced by the Stepped handler.
            int iterations = 1;

            network.Stepped += (sender, e) =>
            {
                int hits = 0;

                foreach (var sample in trainingList)
                {
                    var scores    = network.Predicate(sample.Item1);
                    var predicted = ArgMax(scores);
                    var expected  = ArgMax(sample.Item2);

                    // Correct only when the arg-max indices agree and the winning
                    // score, rounded, equals the target value at that index.
                    if (predicted == expected && Math.Round(scores[predicted]) == sample.Item2[expected])
                    {
                        hits++;
                    }
                }

                double accuracy = (double)hits / trainingList.Count;

                accuracyList.Add(accuracy);
                lossList.Add(network.Loss);

                Console.WriteLine("Epoch {0}/{1}", iterations, epochs);
                Console.WriteLine("Accuracy: {0}, Loss: {1}", accuracy, network.Loss);

                iterations++;
            };

            Console.WriteLine("Training...");

            var stopwatch = Stopwatch.StartNew();

            network.Train(trainingList, epochs, 100);

            stopwatch.Stop();

            Console.WriteLine("Done ({0}).", stopwatch.Elapsed.ToString());

            // Evaluate on the test samples with the same criterion.
            int testHits = 0;

            foreach (var sample in testList)
            {
                var scores    = network.Predicate(sample.Item1);
                var predicted = ArgMax(scores);
                var expected  = ArgMax(sample.Item2);

                if (predicted == expected && Math.Round(scores[predicted]) == sample.Item2[expected])
                {
                    testHits++;
                }
            }

            Console.WriteLine("Accuracy: {0}", (double)testHits / testList.Count);

            var logDictionary = new Dictionary <string, IEnumerable <double> >
            {
                { "Accuracy", accuracyList },
                { "Loss", lossList },
            };

            ToCsv(logPath, logDictionary);

            Console.Write("Saved log to {0}...", logPath);
        }
Пример #10
0
        /// <summary>
        /// Runs <paramref name="engine"/> on <paramref name="img"/> and converts the result
        /// to a <see cref="FashionLabel"/>.
        /// </summary>
        /// <param name="img">The image to classify.</param>
        /// <param name="engine">The prediction engine to use.</param>
        /// <returns>The engine's predicted label, shifted down by one.</returns>
        private FashionLabel Predict(MnistImage img, PredictionEngine <MnistImage, MnistPrediction> engine)
        {
            MnistPrediction result = engine.Predict(img);

            // NOTE(review): the -1 presumably maps a 1-based predicted key onto the
            // 0-based enum - confirm against how the label column is encoded.
            var shiftedLabel = result.PredictedLabel - 1;

            return (FashionLabel)shiftedLabel;
        }
        /// <include file='InterfaceDocumentationComments.xml' path='doc/members/member[@name="M:MnistImageStore.Persistence.IMnistImageReader.Read"]/*'/>
        public IEnumerable <MnistImage> Read()
        {
            // Lazily yields one MnistImage per entry. Both files are opened read-only so
            // that read-only files and locations can be used; they stay open until the
            // enumeration finishes or the enumerator is disposed.
            using (FileStream imageDataFileStream = new FileStream(imageDataFilePath, FileMode.Open, FileAccess.Read))
            using (FileStream labelFileStream = new FileStream(labelFilePath, FileMode.Open, FileAccess.Read))
            using (BinaryReader imageDataReader = new BinaryReader(imageDataFileStream))
            using (BinaryReader labelReader = new BinaryReader(labelFileStream))
            {
                // Read the header information from each file
                Int32 imageCount = 0, labelCount = 0, imageRowPixelCount = 0, imageColumnPixelCount = 0;
                try
                {
                    // Consume the 'magic number'
                    imageDataReader.ReadInt32();
                    labelReader.ReadInt32();
                    // Read the item counts
                    imageCount = ConvertByteOrderFromBigEndian(imageDataReader.ReadInt32());
                    labelCount = ConvertByteOrderFromBigEndian(labelReader.ReadInt32());
                    // Read the dimensions of the images
                    imageRowPixelCount    = ConvertByteOrderFromBigEndian(imageDataReader.ReadInt32());
                    imageColumnPixelCount = ConvertByteOrderFromBigEndian(imageDataReader.ReadInt32());
                }
                catch (Exception e)
                {
                    throw new Exception($"Error reading header information from image data file '{imageDataFilePath}'.", e);
                }

                if (imageCount != labelCount)
                {
                    throw new Exception($"Number of images ({imageCount}) differs from number of labels ({labelCount}) in image file '{imageDataFilePath}' and label file '{labelFilePath}'.");
                }
                if (imageCount < 0)
                {
                    throw new Exception($"Number of images listed in header information of file '{imageDataFilePath}' is negative ({imageCount}).");
                }
                if (imageRowPixelCount < 1)
                {
                    throw new Exception($"Number of row pixels listed in header information of file '{imageDataFilePath}' is less than 1 ({imageRowPixelCount}).");
                }
                if (imageColumnPixelCount < 1)
                {
                    throw new Exception($"Number of column pixels listed in header information of file '{imageDataFilePath}' is less than 1 ({imageColumnPixelCount}).");
                }

                Int32 imageAndLabelReadCount = 0;
                while (imageAndLabelReadCount < imageCount)
                {
                    // Read the next image data
                    var nextImageData = new Byte[imageRowPixelCount, imageColumnPixelCount];
                    try
                    {
                        for (Int32 currentRowIndex = 0; currentRowIndex < imageRowPixelCount; currentRowIndex++)
                        {
                            // Each row holds one byte per column, so both the read length
                            // and the destination stride are the column count.
                            Byte[] currentRowData = imageDataReader.ReadBytes(imageColumnPixelCount);
                            if (currentRowData.Length != imageColumnPixelCount)
                            {
                                // ReadBytes returns a short array at end of file instead of throwing.
                                throw new EndOfStreamException($"Unexpected end of file while reading row {currentRowIndex}.");
                            }

                            Int32 destinationOffset = currentRowIndex * imageColumnPixelCount;
                            Buffer.BlockCopy(currentRowData, 0, nextImageData, destinationOffset, imageColumnPixelCount);
                        }
                    }
                    catch (Exception e)
                    {
                        throw new Exception($"Failed to read MNIST image with index {imageAndLabelReadCount} from file '{imageDataFilePath}'.", e);
                    }

                    // Read the next label
                    Int32 nextLabel;
                    try
                    {
                        nextLabel = labelReader.ReadByte();
                    }
                    catch (Exception e)
                    {
                        throw new Exception($"Failed to read MNIST label with index {imageAndLabelReadCount} from file '{labelFilePath}'.", e);
                    }

                    // The yield stays outside the try/catch blocks above: C# does not
                    // allow 'yield return' inside a try block that has a catch clause.
                    yield return new MnistImage(nextImageData, nextLabel);

                    imageAndLabelReadCount++;
                }
            }
        }
Пример #12
0
 /// <summary>Wraps <paramref name="img"/> in a new <see cref="MnistImageVM"/>.</summary>
 private MnistImageVM CreateVM(MnistImage img)
 {
     return new MnistImageVM(img);
 }
Пример #13
0
 /// <summary>
 /// Asks the trainer for a prediction on <paramref name="img"/> and stores it as the
 /// predicted label of the selected image.
 /// </summary>
 private void PredictLabel(MnistImage img) =>
     PredictedLabelForSelectedImage = _trainer.Predict(img);