public static Datasets<DataSetMnist> read_data_sets(string train_dir,
    bool one_hot = false,
    TF_DataType dtype = TF_DataType.TF_FLOAT,
    bool reshape = true,
    int validation_size = 5000,
    int? train_size = null,
    int? test_size = null,
    string source_url = DEFAULT_SOURCE_URL)
{
    if (train_size != null && validation_size >= train_size)
        throw new ArgumentException("Validation set should be smaller than training set");

    // Download and decompress the training images, then parse the raw IDX file
    // (the file name without its .gz extension) into an NDArray.
    Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES);
    Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir);
    var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]), limit: train_size);

    // Download and decompress the training labels.
    Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS);
    Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir);
    var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot, limit: train_size);

    // Download and decompress the test images.
    Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES);
    Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir);
    var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]), limit: test_size);

    // Download and decompress the test labels.
    Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS);
    Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir);
    var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot, limit: test_size);

    // Carve the first `validation_size` examples out of the training data to
    // serve as the validation set; the remainder stays in the training set.
    int end = train_images.shape[0];
    var validation_images = train_images[np.arange(validation_size)];
    var validation_labels = train_labels[np.arange(validation_size)];
    train_images = train_images[np.arange(validation_size, end)];
    train_labels = train_labels[np.arange(validation_size, end)];

    var train = new DataSetMnist(train_images, train_labels, dtype, reshape);
    var validation = new DataSetMnist(validation_images, validation_labels, dtype, reshape);
    var test = new DataSetMnist(test_images, test_labels, dtype, reshape);

    return new Datasets<DataSetMnist>(train, validation, test);
}
public static Datasets read_data_sets(string train_dir,
    bool one_hot = false,
    TF_DataType dtype = TF_DataType.TF_FLOAT,
    bool reshape = true,
    int validation_size = 5000,
    string source_url = DEFAULT_SOURCE_URL)
{
    // Download and decompress the training images, then parse the raw IDX file.
    Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES);
    Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir);
    var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]));

    // Download and decompress the training labels.
    Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS);
    Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir);
    var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot);

    // Download and decompress the test images.
    Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES);
    Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir);
    var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]));

    // Download and decompress the test labels.
    Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS);
    Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir);
    var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot);

    // Split off the first `validation_size` training examples as the validation set.
    int end = train_images.shape[0];
    var validation_images = train_images[np.arange(validation_size)];
    var validation_labels = train_labels[np.arange(validation_size)];
    train_images = train_images[np.arange(validation_size, end)];
    train_labels = train_labels[np.arange(validation_size, end)];

    var train = new DataSet(train_images, train_labels, dtype, reshape);
    var validation = new DataSet(validation_images, validation_labels, dtype, reshape);
    var test = new DataSet(test_images, test_labels, dtype, reshape);

    return new Datasets(train, validation, test);
}
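// A minimal usage sketch of the generic overload above. The directory name and the
// size arguments are illustrative; `mnist.Train` assumes the enclosing
// Datasets<DataSetMnist> type exposes the three splits as properties, which is not
// shown in this listing. Passing train_size/test_size caps how many examples are
// parsed, while validation_size must stay below train_size or the guard clause throws.
//
// var mnist = read_data_sets("mnist_data",
//     one_hot: true,
//     train_size: 10000,
//     validation_size: 2000,
//     test_size: 1000);
// var trainSet = mnist.Train; // assumed accessor; adjust to the actual Datasets<T> API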