예제 #1
0
        /// <summary>
        /// Loads an IDX-format (MNIST-style) image file and its matching label file,
        /// both resolved relative to <c>Folder</c>, into a <see cref="DatasetByte"/>.
        /// Images are stored one byte per pixel, image after image, with rows written
        /// bottom-up ("Unity style": the first stored row is the last row in the file).
        /// </summary>
        /// <param name="imgPath">Image file path, relative to <c>Folder</c>.</param>
        /// <param name="lblPath">Label file path, relative to <c>Folder</c>.</param>
        /// <returns>The populated dataset (one label and one image per entry in the image file).</returns>
        /// <exception cref="System.IO.InvalidDataException">
        /// A header carries an unexpected magic number, or the label file holds fewer
        /// labels than the image file holds images.
        /// </exception>
        private static DatasetByte LoadAsBytes(string imgPath, string lblPath)
        {
            // IDX magic numbers: 0x00000801 = 1-D unsigned-byte array (labels),
            //                    0x00000803 = 3-D unsigned-byte array (images).
            const int LabelMagic = 2049;
            const int ImageMagic = 2051;

            BigEndianBinaryReader lblReader = null;
            BigEndianBinaryReader imgReader = null;
            try
            {
                // Read-only access: the files are never written, and FileMode.Open alone
                // would request ReadWrite access and fail on read-only data files.
                lblReader = new BigEndianBinaryReader(new FileStream(Path.Combine(Folder, lblPath), FileMode.Open, FileAccess.Read, FileShare.Read));
                imgReader = new BigEndianBinaryReader(new FileStream(Path.Combine(Folder, imgPath), FileMode.Open, FileAccess.Read, FileShare.Read));

                int lblMagic = lblReader.ReadInt32();
                int numLbls  = lblReader.ReadInt32();

                int imgMagic = imgReader.ReadInt32();
                int numImgs  = imgReader.ReadInt32();
                int rows     = imgReader.ReadInt32();
                int cols     = imgReader.ReadInt32();
                int imgDims  = rows * cols;

                if (lblMagic != LabelMagic)
                {
                    throw new System.IO.InvalidDataException("MNIST: unexpected magic number " + lblMagic + " in label file " + lblPath);
                }
                if (imgMagic != ImageMagic)
                {
                    throw new System.IO.InvalidDataException("MNIST: unexpected magic number " + imgMagic + " in image file " + imgPath);
                }
                if (numLbls < numImgs)
                {
                    throw new System.IO.InvalidDataException("MNIST: label file " + lblPath + " has " + numLbls + " labels but image file " + imgPath + " has " + numImgs + " images");
                }

                Debug.Log("MNIST: Loading " + imgPath + ", Imgs: " + numImgs + " Rows: " + rows + " Cols: " + cols);

                var dataset = new DatasetByte(numImgs);

                for (int i = 0; i < numImgs; i++)
                {
                    byte lbl = lblReader.ReadByte();
                    dataset.Labels[i] = (int)lbl;
                }

                // IDX pixel data is row-major: `rows` rows of `cols` pixels each.
                // Rows are written bottom-up to flip the image to Unity style.
                // (Identical to the previous loop for square images; correct for non-square ones.)
                for (int i = 0; i < numImgs; i++)
                {
                    int imgStart = i * imgDims;
                    for (int row = 0; row < rows; row++)
                    {
                        int dstRowStart = imgStart + (rows - 1 - row) * cols;
                        for (int col = 0; col < cols; col++)
                        {
                            dataset.Images[dstRowStart + col] = imgReader.ReadByte();
                        }
                    }
                }

                return dataset;
            }
            finally
            {
                // Always release the file handles, including when a header check or read fails.
                if (imgReader != null) imgReader.Close();
                if (lblReader != null) lblReader.Close();
            }
        }
예제 #2
0
        /// <summary>
        /// Fills <paramref name="batch"/> with dataset indices popped from the end of
        /// <c>set.Indices</c>. When fewer indices remain than the batch needs, the pool
        /// is cleared (discarding any leftovers), refilled with every index
        /// 0..<c>set.NumImgs</c>-1 and shuffled via <c>Ramjet.Utils.Shuffle</c>.
        /// </summary>
        /// <param name="batch">Destination for the sampled indices; its length is the batch size.</param>
        /// <param name="set">Dataset whose <c>Indices</c> pool is consumed (and refilled when low).</param>
        /// <param name="rng">Random generator, passed by ref to the shuffle.</param>
        /// <exception cref="System.ArgumentOutOfRangeException">
        /// <paramref name="batch"/> is longer than the dataset, so it can never be filled.
        /// </exception>
        public static void GetBatch(NativeArray<int> batch, DatasetByte set, ref Rng rng)
        {
            // Todo: can transform dataset to create additional variation

            // Without this guard the refilled pool runs dry mid-batch and the list
            // indexer throws after the batch has been partially overwritten.
            if (batch.Length > set.NumImgs)
            {
                throw new System.ArgumentOutOfRangeException(nameof(batch), "Batch size " + batch.Length + " exceeds dataset size " + set.NumImgs + ".");
            }

            UnityEngine.Profiling.Profiler.BeginSample("GetBatch");
            try
            {
                if (set.Indices.Count < batch.Length)
                {
                    set.Indices.Clear();
                    for (int i = 0; i < set.NumImgs; i++)
                    {
                        set.Indices.Add(i);
                    }
                    Ramjet.Utils.Shuffle(set.Indices, ref rng);
                }

                // Pop from the end so each removal is O(1).
                for (int i = 0; i < batch.Length; i++)
                {
                    int last = set.Indices.Count - 1;
                    batch[i] = set.Indices[last];
                    set.Indices.RemoveAt(last);
                }
            }
            finally
            {
                // Keep Begin/EndSample balanced even if the shuffle or indexing throws.
                UnityEngine.Profiling.Profiler.EndSample();
            }
        }
예제 #3
0
 /// <summary>
 /// Loads the training split into <c>TrainBytes</c>, then the test split into
 /// <c>TestBytes</c>, using <c>LoadAsBytes</c> for each image/label file pair.
 /// </summary>
 public static void LoadByteData()
 {
     TrainBytes = LoadAsBytes(imgPath: TrainImagePath, lblPath: TrainLabelPath);
     TestBytes = LoadAsBytes(imgPath: TestImagePath, lblPath: TestLabelPath);
 }