Inheritance: IExtDataset
Exemplo n.º 1
0
        public static RowDataset8 GetRowDataset8(MnistDatasource mds, int[] classes)
        {
            if (mds.NSamples() == 0)
                throw new Exception("MNIST database is empty!");

            // определим максимальный индекс
            byte maxLabel = 0;
            foreach (byte label in mds.Labels)
            {
                if (label > maxLabel)
                    maxLabel = label;
            }

            // проверим соответствие индексов Mnist указанному списку классов
            if (maxLabel >= classes.Length)
                throw new Exception("Classes do not correspond to the MNIST!");

            // создаем тренировочную базу
            RowDataset8 ds = new RowDataset8(mds.NSamples());
            // перебираем MNIST базу
            for (int i = 0; i < mds.NSamples(); i++)
            {
                int label = mds.Labels[i];
                StdInput stdInp = new StdInput(mds.ImagesData[i], mds.ImgHeight, mds.ImgWidth);
                ds.Add(stdInp.ToFloatarray(), classes[label]);
            }
            return ds;
        }
Exemplo n.º 2
0
        public void TestRowDataset()
        {
            DRandomizer.Default.init_drand(DateTime.Now.Millisecond);

            // load Mnist datasource
            MnistDatasource mds = new MnistDatasource();
            mds.LoadFromFile(mnistFileNamePrefix);

            // convert mnist to RowDataset8
            RowDataset8 ds8 = MnistDatasetConvert.GetRowDataset8(mds, classes);

            // show random sample to console
            Floatarray fa = new Floatarray();
            int isample = (int)DRandomizer.Default.drand(ds8.nSamples(), 0);
            ds8.Input(fa, isample);
            Console.WriteLine("Char is '{0}'", (char)ds8.Cls(isample));
            NarrayShow.ShowConsole(fa);

            // compare random float sample and original mnist
            StdInput inp1 = new StdInput(mds.ImagesData[isample], mds.ImgHeight, mds.ImgWidth);
            StdInput inp2 = new StdInput(fa);
            Console.WriteLine("Arrays is identical? {0}", Equals(inp1.GetDataBuffer(), inp2.GetDataBuffer()));

            // save RowDataset8 to file
            Console.WriteLine("Saving {0} samples..", ds8.nSamples());
            ds8.Save(mnistFileNamePrefix + dsExt);

            // load RowDataset8 from file
            RowDataset8 ds = new RowDataset8();
            ds.Load(mnistFileNamePrefix + dsExt);
            Console.WriteLine("Loaded {0} samples", ds.nSamples());
        }
Exemplo n.º 3
0
        public static RowDataset8 GetRowDataset8(MnistDatasource mds, int[] classes)
        {
            if (mds.NSamples() == 0)
            {
                throw new Exception("MNIST database is empty!");
            }

            // определим максимальный индекс
            byte maxLabel = 0;

            foreach (byte label in mds.Labels)
            {
                if (label > maxLabel)
                {
                    maxLabel = label;
                }
            }

            // проверим соответствие индексов Mnist указанному списку классов
            if (maxLabel >= classes.Length)
            {
                throw new Exception("Classes do not correspond to the MNIST!");
            }

            // создаем тренировочную базу
            RowDataset8 ds = new RowDataset8(mds.NSamples());

            // перебираем MNIST базу
            for (int i = 0; i < mds.NSamples(); i++)
            {
                int      label  = mds.Labels[i];
                StdInput stdInp = new StdInput(mds.ImagesData[i], mds.ImgHeight, mds.ImgWidth);
                ds.Add(stdInp.ToFloatarray(), classes[label]);
            }
            return(ds);
        }
Exemplo n.º 4
0
        public void TestTrainSimple()
        {
            // create lenet
            LenetClassifier classifier = new LenetClassifier();
            classifier.Set("junk", 0);      // disable junk
            classifier.SetExtractor("scaledfe");
            classifier.Initialize(classesNums);

            StringBuilder sbout;
            classifier.GetStdout(out sbout);
            Console.Write(sbout);

            // load RowDataset8 from file
            RowDataset8 ds = new RowDataset8();
            ds.Load(trainDatasetFileName);

            // do train
            classifier.Set("epochs", 3);
            classifier.XTrain(ds);

            // save classifier to file
            classifier.Save(trainNetworkFileName);

            // test recognize
            DoTestRecognize(classifier);
        }