/// <summary>
        /// Creates the CIFAR training and testing data sources from the batch files named in the
        /// dataset configuration, then registers a dataset record linking the two sources.
        /// </summary>
        /// <remarks>
        /// Exceptions raised inside the conversion are written to the log as "ERROR: ..." and are
        /// not rethrown.  Once the conversion has started, the file names read so far are always
        /// saved to the application settings and <c>OnCompleted</c> is always raised on
        /// <paramref name="progress"/>.
        /// </remarks>
        /// <param name="config">Dataset configuration providing the dataset name, the creator ID and
        /// the "Training Data File 1".."Training Data File 5" and "Testing Data File" settings.</param>
        /// <param name="progress">Progress sink that receives the completion notification.</param>
        public void Create(DatasetConfiguration config, IXDatasetCreatorProgress progress)
        {
            // Seed the file names with the values saved by the previous run; each one is replaced
            // below as its configuration setting is read, and all are saved again in 'finally'.
            string[] rgstrTrainingFile = new string[]
            {
                Properties.Settings.Default.TrainingDataFile1,
                Properties.Settings.Default.TrainingDataFile2,
                Properties.Settings.Default.TrainingDataFile3,
                Properties.Settings.Default.TrainingDataFile4,
                Properties.Settings.Default.TrainingDataFile5
            };
            string strTestingBatchFile = Properties.Settings.Default.TestingDataFile;
            string strDsName           = config.Name;
            string strTrainingSrc      = config.Name + ".training";
            string strTestingSrc       = config.Name + ".testing";
            int    nIdx                = 0;
            int    nTotal              = 50000;

            m_bCancel   = false;
            m_iprogress = progress;
            m_factory.DeleteSources(strTrainingSrc, strTestingSrc);

            Log log = new Log("CIFAR Dataset Creator");

            log.OnWriteLine += new EventHandler <LogArg>(log_OnWriteLine);

            try
            {
                // Look up every setting first, then read and validate them in order.
                DataConfigSetting[] rgTrainingSetting = new DataConfigSetting[rgstrTrainingFile.Length];

                for (int i = 0; i < rgTrainingSetting.Length; i++)
                {
                    rgTrainingSetting[i] = config.Settings.Find("Training Data File " + (i + 1).ToString());
                }

                DataConfigSetting dsTestingDataFile = config.Settings.Find("Testing Data File");

                for (int i = 0; i < rgTrainingSetting.Length; i++)
                {
                    string strFileNo = (i + 1).ToString();

                    rgstrTrainingFile[i] = getSettingText(rgTrainingSetting[i], "Training Data File " + strFileNo);
                    if (rgstrTrainingFile[i].Length == 0)
                    {
                        throw new Exception("Training data file #" + strFileNo + " name not specified!");
                    }
                }

                strTestingBatchFile = getSettingText(dsTestingDataFile, "Testing Data File");
                if (strTestingBatchFile.Length == 0)
                {
                    throw new Exception("Testing data file name not specified!");
                }

                log.WriteLine("Loading the data files...");

                if (m_bCancel)
                {
                    return;
                }

                // NOTE(review): the early returns and exceptions below leave m_factory open;
                // confirm whether Close() may safely be called from the finally block.
                int nTrainSrcId = m_factory.AddSource(strTrainingSrc, 3, 32, 32, false, 0);
                m_factory.Open(nTrainSrcId, 500, Database.FORCE_LOAD.FROM_FILE); // use file based data.

                log.WriteLine("Deleting existing data from '" + m_factory.OpenSource.Name + "'.");
                m_factory.DeleteSourceData();

                // nIdx is shared across the five training files and advanced by loadFile.
                for (int i = 0; i < rgTrainingSetting.Length; i++)
                {
                    if (!loadFile(log, rgTrainingSetting[i].Name, rgstrTrainingFile[i], m_factory, nTotal, true, ref nIdx))
                    {
                        return;
                    }
                }

                m_factory.UpdateSourceCounts();
                updateLabels(m_factory);

                log.WriteLine("Creating the image mean...");

                // The event is never signaled here; it is only passed to CalculateMean as its wait handle.
                using (ManualResetEvent evtCancel = new ManualResetEvent(false))
                {
                    SimpleDatum dMean = SimpleDatum.CalculateMean(log, m_rgImages.ToArray(), new WaitHandle[] { evtCancel });
                    m_factory.PutRawImageMean(dMean, true);
                }

                m_rgImages.Clear();

                m_factory.Close();

                int nTestSrcId = m_factory.AddSource(strTestingSrc, 3, 32, 32, false, 0);
                m_factory.Open(nTestSrcId, 500, Database.FORCE_LOAD.FROM_FILE); // use file based data.

                log.WriteLine("Deleting existing data from '" + m_factory.OpenSource.Name + "'.");
                m_factory.DeleteSourceData();

                nIdx   = 0;
                nTotal = 10000;

                if (!loadFile(log, dsTestingDataFile.Name, strTestingBatchFile, m_factory, nTotal, false, ref nIdx))
                {
                    return;
                }

                // The testing source reuses the mean calculated from the training images.
                m_factory.CopyImageMean(strTrainingSrc, strTestingSrc);

                m_factory.UpdateSourceCounts();
                updateLabels(m_factory);
                m_factory.Close();

                log.WriteLine("Done loading training and testing data.");

                using (DNNEntities entities = EntitiesConnection.CreateEntities())
                {
                    List <Source> rgSrcTraining = entities.Sources.Where(p => p.Name == strTrainingSrc).ToList();
                    List <Source> rgSrcTesting  = entities.Sources.Where(p => p.Name == strTestingSrc).ToList();

                    if (rgSrcTraining.Count == 0)
                    {
                        throw new Exception("Could not find the training source '" + strTrainingSrc + "'.");
                    }

                    if (rgSrcTesting.Count == 0)
                    {
                        throw new Exception("Could not find the testing source '" + strTestingSrc + "'.");
                    }

                    Source srcTraining = rgSrcTraining[0];
                    Source srcTesting  = rgSrcTesting[0];

                    int nSrcTestingCount  = srcTesting.ImageCount.GetValueOrDefault();
                    int nSrcTrainingCount = srcTraining.ImageCount.GetValueOrDefault();
                    int nSrcTotalCount    = nSrcTestingCount + nSrcTrainingCount;

                    // Guard on the actual divisor so that an empty training source still
                    // yields the true testing fraction instead of 0.
                    double dfTestingPct = (nSrcTotalCount == 0) ? 0.0 : nSrcTestingCount / (double)nSrcTotalCount;

                    Dataset ds = new Dataset();
                    ds.ImageHeight      = srcTraining.ImageHeight;
                    ds.ImageWidth       = srcTraining.ImageWidth;
                    ds.Name             = strDsName;
                    ds.ImageEncoded     = srcTesting.ImageEncoded;
                    ds.ImageChannels    = srcTesting.ImageChannels;
                    ds.TestingPercent   = (decimal)dfTestingPct;
                    ds.TestingSourceID  = srcTesting.ID;
                    ds.TestingTotal     = srcTesting.ImageCount;
                    ds.TrainingSourceID = srcTraining.ID;
                    ds.TrainingTotal    = srcTraining.ImageCount;
                    ds.DatasetCreatorID = config.ID;
                    ds.DatasetGroupID   = 0;
                    ds.ModelGroupID     = 0;

                    entities.Datasets.Add(ds);
                    entities.SaveChanges();
                }
            }
            catch (Exception excpt)
            {
                log.WriteLine("ERROR: " + excpt.Message);
            }
            finally
            {
                // Remember the file names for the next run, whether or not the conversion succeeded.
                Properties.Settings.Default.TrainingDataFile1 = rgstrTrainingFile[0];
                Properties.Settings.Default.TrainingDataFile2 = rgstrTrainingFile[1];
                Properties.Settings.Default.TrainingDataFile3 = rgstrTrainingFile[2];
                Properties.Settings.Default.TrainingDataFile4 = rgstrTrainingFile[3];
                Properties.Settings.Default.TrainingDataFile5 = rgstrTrainingFile[4];
                Properties.Settings.Default.TestingDataFile   = strTestingBatchFile;
                Properties.Settings.Default.Save();

                // NOTE(review): a logged error with m_bCancel == false is still reported as
                // "COMPLETED." -- confirm whether failures should be surfaced to the caller.
                if (m_bCancel)
                {
                    log.WriteLine("ABORTED converting CIFAR data files.");
                    m_iprogress.OnCompleted(new CreateProgressArgs(nIdx, nTotal, "ABORTED!", null, true));
                }
                else
                {
                    log.WriteLine("Done converting CIFAR data files.");
                    m_iprogress.OnCompleted(new CreateProgressArgs(1, "COMPLETED."));
                }
            }
        }

        /// <summary>
        /// Returns the text of a configuration setting, failing with a descriptive message when
        /// the setting or its value is missing.
        /// </summary>
        /// <param name="setting">Setting returned by the configuration lookup; may be null.</param>
        /// <param name="strSettingName">Name used for the lookup, reported in the error message.</param>
        /// <returns>The setting value converted with <c>ToString()</c>; may be empty.</returns>
        private static string getSettingText(DataConfigSetting setting, string strSettingName)
        {
            if (setting == null)
            {
                throw new Exception("The dataset configuration does not contain the '" + strSettingName + "' setting.");
            }

            object objValue = setting.Value;
            if (objValue == null)
            {
                throw new Exception("The '" + strSettingName + "' setting has no value.");
            }

            return objValue.ToString();
        }
示例#2
0
        /// <summary>
        /// Reads an image file and its matching label file and stores every item in the data
        /// source named <paramref name="strDBPath"/>, optionally creating or copying the image mean.
        /// </summary>
        /// <param name="strImageFile">Path of the image file; a ".gz" file is expanded first.</param>
        /// <param name="strLabelFile">Path of the label file; a ".gz" file is expanded first.</param>
        /// <param name="strDBPath">Name of the data source that receives the items.</param>
        /// <param name="strDBPathMean">Name of the data source that owns the image mean.  When it
        /// equals <paramref name="strDBPath"/> the mean is calculated from the loaded items,
        /// otherwise it is copied from that source.</param>
        /// <param name="bCreateImgMean">When true, the image mean is created or copied after loading.</param>
        /// <param name="bGetItemCountOnly">When true, only the file headers are read and the item
        /// count is returned without touching the data source.</param>
        /// <param name="nChannels">Channel count of the stored items; when 3, the single pixel
        /// plane read from the file is repeated for each channel.</param>
        /// <returns>The item count read from the image file header.</returns>
        /// <exception cref="Exception">The image and label files report different item counts.</exception>
        public uint ConvertData(string strImageFile, string strLabelFile, string strDBPath, string strDBPathMean, bool bCreateImgMean, bool bGetItemCountOnly = false, int nChannels = 1)
        {
            List <SimpleDatum> rgImg      = new List <SimpleDatum>();
            BinaryFile         image_file = null;
            BinaryFile         label_file = null;

            strImageFile = expandIfGz(strImageFile);
            strLabelFile = expandIfGz(strLabelFile);

            try
            {
                // Both files are opened inside the try so that the first one is released
                // even when opening the second one fails.
                image_file = new BinaryFile(strImageFile);
                label_file = new BinaryFile(strLabelFile);

                uint magicImg = image_file.ReadUInt32();
                uint magicLbl = label_file.ReadUInt32();

                // NOTE(review): when m_log is null a magic mismatch only raises OnLoadError and
                // loading continues -- confirm whether the load should stop here instead.
                if (magicImg != 2051)
                {
                    if (m_log != null)
                    {
                        m_log.FAIL("Incorrect image file magic.");
                    }

                    if (OnLoadError != null)
                    {
                        OnLoadError(this, new LoadErrorArgs("Incorrect image file magic."));
                    }
                }

                if (magicLbl != 2049)
                {
                    if (m_log != null)
                    {
                        m_log.FAIL("Incorrect label file magic.");
                    }

                    if (OnLoadError != null)
                    {
                        OnLoadError(this, new LoadErrorArgs("Incorrect label file magic."));
                    }
                }

                uint num_items  = image_file.ReadUInt32();
                uint num_labels = label_file.ReadUInt32();

                if (num_items != num_labels)
                {
                    if (m_log != null)
                    {
                        m_log.FAIL("The number of items must equal the number of labels.");
                    }

                    throw new Exception("The number of items must equal the number of labels." + Environment.NewLine + "  Label File: '" + strLabelFile + "'" + Environment.NewLine + "  Image File: '" + strImageFile + "'.");
                }

                if (bGetItemCountOnly)
                {
                    return(num_items);
                }

                uint rows = image_file.ReadUInt32();
                uint cols = image_file.ReadUInt32();

                int nSrcId = m_factory.AddSource(strDBPath, nChannels, (int)cols, (int)rows, false, 0, true);
                m_factory.Open(nSrcId, 500, Database.FORCE_LOAD.FROM_FILE); // use file based data.
                m_factory.DeleteSourceData();

                // Storing to db
                int   nPixelCount = (int)(rows * cols);
                Datum datum       = new Datum(false, nChannels, (int)cols, (int)rows);

                if (m_log != null)
                {
                    m_log.WriteHeader("LOADING " + strDBPath + " items.");
                    m_log.WriteLine("A total of " + num_items.ToString() + " items.");
                    m_log.WriteLine("Rows: " + rows.ToString() + " Cols: " + cols.ToString());
                }

                if (OnLoadStart != null)
                {
                    OnLoadStart(this, new LoadStartArgs((int)num_items));
                }

                for (int item_id = 0; item_id < num_items; item_id++)
                {
                    byte[] rgPixels = image_file.ReadBytes(nPixelCount);
                    byte[] rgLabel  = label_file.ReadBytes(1);

                    List <byte> rgData = new List <byte>(rgPixels);

                    if (nChannels == 3)
                    {
                        // Repeat the single pixel plane for the second and third channels.
                        rgData.AddRange(rgPixels);
                        rgData.AddRange(rgPixels);
                    }

                    datum.SetData(rgData, (int)rgLabel[0]);

                    if (m_bmpTargetOverlay != null)
                    {
                        datum = createTargetOverlay(datum);
                    }

                    m_factory.PutRawImageCache(item_id, datum);

                    if (bCreateImgMean)
                    {
                        rgImg.Add(new SimpleDatum(datum));
                    }

                    // Report progress (and honor cancellation) once every 1000 items.
                    if ((item_id % 1000) == 0)
                    {
                        if (m_log != null)
                        {
                            // "N0" keeps the group separators without the fractional digits "N" adds.
                            m_log.WriteLine("Loaded " + item_id.ToString("N0") + " items...");
                            m_log.Progress = (double)item_id / (double)num_items;
                        }

                        if (OnLoadProgress != null)
                        {
                            LoadArgs args = new LoadArgs(item_id);
                            OnLoadProgress(this, args);

                            if (args.Cancel)
                            {
                                break;
                            }
                        }
                    }
                }

                m_factory.ClearImageCache(true);
                m_factory.UpdateSourceCounts();

                if (bCreateImgMean)
                {
                    if (strDBPath != strDBPathMean)
                    {
                        m_factory.CopyImageMean(strDBPathMean, strDBPath);
                    }
                    else
                    {
                        if (m_log != null)
                        {
                            m_log.WriteLine("Creating image mean...");
                        }

                        // The event is never signaled here; it is only passed to CalculateMean as its wait handle.
                        using (ManualResetEvent evtCancel = new ManualResetEvent(false))
                        {
                            SimpleDatum dMean = SimpleDatum.CalculateMean(m_log, rgImg.ToArray(), new WaitHandle[] { evtCancel });
                            m_factory.PutRawImageMean(dMean, true);
                        }
                    }
                }

                if (OnLoadProgress != null)
                {
                    LoadArgs args = new LoadArgs((int)num_items);
                    OnLoadProgress(this, args);
                }

                return(num_items);
            }
            finally
            {
                if (image_file != null)
                {
                    image_file.Dispose();
                }

                if (label_file != null)
                {
                    label_file.Dispose();
                }
            }
        }

        /// <summary>
        /// Expands the file with expandFile when its extension is ".gz" (compared ignoring case);
        /// any other file name is returned unchanged.
        /// </summary>
        /// <param name="strFile">Path of the file to check.</param>
        /// <returns>The value returned by expandFile for ".gz" files, otherwise <paramref name="strFile"/>.</returns>
        private string expandIfGz(string strFile)
        {
            if (!string.Equals(Path.GetExtension(strFile), ".gz", StringComparison.OrdinalIgnoreCase))
            {
                return strFile;
            }

            if (m_log != null)
            {
                m_log.WriteLine("Unpacking '" + strFile + "'...");
            }

            return expandFile(strFile);
        }