public static RowDataset8 GetRowDataset8(MnistDatasource mds, int[] classes) { if (mds.NSamples() == 0) throw new Exception("MNIST database is empty!"); // определим максимальный индекс byte maxLabel = 0; foreach (byte label in mds.Labels) { if (label > maxLabel) maxLabel = label; } // проверим соответствие индексов Mnist указанному списку классов if (maxLabel >= classes.Length) throw new Exception("Classes do not correspond to the MNIST!"); // создаем тренировочную базу RowDataset8 ds = new RowDataset8(mds.NSamples()); // перебираем MNIST базу for (int i = 0; i < mds.NSamples(); i++) { int label = mds.Labels[i]; StdInput stdInp = new StdInput(mds.ImagesData[i], mds.ImgHeight, mds.ImgWidth); ds.Add(stdInp.ToFloatarray(), classes[label]); } return ds; }
public void TestRowDataset() { DRandomizer.Default.init_drand(DateTime.Now.Millisecond); // load Mnist datasource MnistDatasource mds = new MnistDatasource(); mds.LoadFromFile(mnistFileNamePrefix); // convert mnist to RowDataset8 RowDataset8 ds8 = MnistDatasetConvert.GetRowDataset8(mds, classes); // show random sample to console Floatarray fa = new Floatarray(); int isample = (int)DRandomizer.Default.drand(ds8.nSamples(), 0); ds8.Input(fa, isample); Console.WriteLine("Char is '{0}'", (char)ds8.Cls(isample)); NarrayShow.ShowConsole(fa); // compare random float sample and original mnist StdInput inp1 = new StdInput(mds.ImagesData[isample], mds.ImgHeight, mds.ImgWidth); StdInput inp2 = new StdInput(fa); Console.WriteLine("Arrays is identical? {0}", Equals(inp1.GetDataBuffer(), inp2.GetDataBuffer())); // save RowDataset8 to file Console.WriteLine("Saving {0} samples..", ds8.nSamples()); ds8.Save(mnistFileNamePrefix + dsExt); // load RowDataset8 from file RowDataset8 ds = new RowDataset8(); ds.Load(mnistFileNamePrefix + dsExt); Console.WriteLine("Loaded {0} samples", ds.nSamples()); }
public static RowDataset8 GetRowDataset8(MnistDatasource mds, int[] classes) { if (mds.NSamples() == 0) { throw new Exception("MNIST database is empty!"); } // определим максимальный индекс byte maxLabel = 0; foreach (byte label in mds.Labels) { if (label > maxLabel) { maxLabel = label; } } // проверим соответствие индексов Mnist указанному списку классов if (maxLabel >= classes.Length) { throw new Exception("Classes do not correspond to the MNIST!"); } // создаем тренировочную базу RowDataset8 ds = new RowDataset8(mds.NSamples()); // перебираем MNIST базу for (int i = 0; i < mds.NSamples(); i++) { int label = mds.Labels[i]; StdInput stdInp = new StdInput(mds.ImagesData[i], mds.ImgHeight, mds.ImgWidth); ds.Add(stdInp.ToFloatarray(), classes[label]); } return(ds); }
public void TestTrainSimple() { // create lenet LenetClassifier classifier = new LenetClassifier(); classifier.Set("junk", 0); // disable junk classifier.SetExtractor("scaledfe"); classifier.Initialize(classesNums); StringBuilder sbout; classifier.GetStdout(out sbout); Console.Write(sbout); // load RowDataset8 from file RowDataset8 ds = new RowDataset8(); ds.Load(trainDatasetFileName); // do train classifier.Set("epochs", 3); classifier.XTrain(ds); // save classifier to file classifier.Save(trainNetworkFileName); // test recognize DoTestRecognize(classifier); }