Example #1
        public static Datasets<DataSetMnist> read_data_sets(string train_dir,
                                                             bool one_hot        = false,
                                                             TF_DataType dtype   = TF_DataType.TF_FLOAT,
                                                             bool reshape        = true,
                                                             int validation_size = 5000,
                                                             int? train_size     = null,
                                                             int? test_size      = null,
                                                             string source_url   = DEFAULT_SOURCE_URL)
        {
            if (train_size != null && validation_size >= train_size)
            {
                throw new ArgumentException("Validation set should be smaller than training set");
            }

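            // For each of the four MNIST files: download the gzip archive into train_dir,
            // extract it, and parse the decompressed file (the archive name with everything
            // after the first '.' stripped off).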
            Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES);
            Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir);
            var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]), limit: train_size);

            Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS);
            Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir);
            var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot, limit: train_size);

            Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES);
            Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir);
            var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]), limit: test_size);

            Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS);
            Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir);
            var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot, limit: test_size);

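            // Split off the first `validation_size` training examples as the validation set
            // and keep the remaining examples as the actual training set.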
            int end = train_images.shape[0];
            var validation_images = train_images[np.arange(validation_size)];
            var validation_labels = train_labels[np.arange(validation_size)];

            train_images = train_images[np.arange(validation_size, end)];
            train_labels = train_labels[np.arange(validation_size, end)];

            var train      = new DataSetMnist(train_images, train_labels, dtype, reshape);
            var validation = new DataSetMnist(validation_images, validation_labels, dtype, reshape);
            var test       = new DataSetMnist(test_images, test_labels, dtype, reshape);

            return new Datasets<DataSetMnist>(train, validation, test);
        }
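
A minimal call sketch for this overload. The static class that hosts read_data_sets is not shown above, so the name MnistLoader below is a placeholder, and the cache directory is arbitrary; only the parameter names and defaults come from the signature itself.

            // Hypothetical call site; "MnistLoader" stands in for whatever static class hosts read_data_sets.
            var mnist = MnistLoader.read_data_sets(
                train_dir: "mnist_data",     // local directory for the downloaded and extracted archives
                one_hot: true,               // encode labels as one-hot vectors
                validation_size: 5000,       // first 5,000 training examples become the validation split
                train_size: 20000,           // optional cap on training examples (must exceed validation_size)
                test_size: 2000);            // optional cap on test examples
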
Example #2
        public static Datasets read_data_sets(string train_dir,
                                              bool one_hot        = false,
                                              TF_DataType dtype   = TF_DataType.TF_FLOAT,
                                              bool reshape        = true,
                                              int validation_size = 5000,
                                              string source_url   = DEFAULT_SOURCE_URL)
        {
            Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES);
            Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir);
            var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]));

            Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS);
            Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir);
            var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot);

            Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES);
            Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir);
            var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]));

            Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS);
            Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir);
            var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot);

            int end = train_images.shape[0];
            var validation_images = train_images[np.arange(validation_size)];
            var validation_labels = train_labels[np.arange(validation_size)];

            train_images = train_images[np.arange(validation_size, end)];
            train_labels = train_labels[np.arange(validation_size, end)];

            var train      = new DataSet(train_images, train_labels, dtype, reshape);
            var validation = new DataSet(validation_images, validation_labels, dtype, reshape);
            var test       = new DataSet(test_images, test_labels, dtype, reshape);

            return new Datasets(train, validation, test);
        }
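
This overload differs from Example #1 only in that it returns the non-generic Datasets/DataSet types and does not accept the optional train_size/test_size caps. A minimal call sketch under the same assumption about the hosting class name:

            // Hypothetical call site; "MnistLoader" is again a placeholder for the hosting static class.
            var mnist = MnistLoader.read_data_sets("mnist_data", one_hot: true);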