/// <summary>
/// Creates a loader that draws mini-batches from <paramref name="dataset"/> using either a
/// caller-supplied batch sampler or one built from <paramref name="batch_size"/>,
/// <paramref name="shuffle"/>/<paramref name="sampler"/> and <paramref name="last_batch"/>.
/// </summary>
/// <param name="dataset">Dataset to load samples from; its <c>Length</c> sizes the default sampler.</param>
/// <param name="batch_size">Mini-batch size. Required unless <paramref name="batch_sampler"/> is given.</param>
/// <param name="shuffle">
/// When no <paramref name="sampler"/> is given, selects a <c>RandomSampler</c> instead of a
/// <c>SequentialSampler</c>. Must be false when <paramref name="sampler"/> is given.
/// </param>
/// <param name="sampler">Custom sampler; mutually exclusive with <paramref name="shuffle"/>.</param>
/// <param name="last_batch">Passed to the <c>BatchSampler</c>; "keep" is used when null or whitespace.</param>
/// <param name="batch_sampler">
/// Custom batch sampler; mutually exclusive with batch_size, shuffle, sampler and last_batch.
/// </param>
/// <param name="batchify_fn">
/// Function that merges samples into a batch. When null, <c>DataLoader.DefaultMPBatchifyFn</c> is used
/// if there are workers, otherwise <c>DataLoader.DefaultBatchifyFn</c>.
/// </param>
/// <param name="num_workers">Number of workers; negative values are treated as 0.</param>
/// <param name="pin_memory">Stored on the instance.</param>
/// <param name="pin_device_id">Stored on the instance.</param>
/// <exception cref="Exception">
/// Thrown when batch_size is missing without a batch sampler, or when mutually exclusive arguments are combined.
/// </exception>
public DataLoaderV1(Dataset<NDArray> dataset, int? batch_size = null, bool shuffle = false, Sampler sampler = null,
    string last_batch = null, BatchSampler batch_sampler = null, Func<NDArrayList, NDArrayList> batchify_fn = null,
    int num_workers = 0, bool pin_memory = false, int pin_device_id = 0)
{
    _dataset = dataset;
    _pin_memory = pin_memory;
    _pin_device_id = pin_device_id;

    if (batch_sampler == null)
    {
        if (!batch_size.HasValue)
        {
            throw new Exception("batch_size must be specified unless " + "batch_sampler is specified");
        }

        if (sampler == null)
        {
            if (shuffle)
            {
                sampler = new RandomSampler(_dataset.Length);
            }
            else
            {
                sampler = new SequentialSampler(_dataset.Length);
            }
        }
        else if (shuffle)
        {
            // Only reject the combination that is actually ambiguous: a custom sampler AND shuffle.
            // A custom sampler on its own is valid.
            throw new Exception("shuffle must not be specified if sampler is specified");
        }

        batch_sampler = new BatchSampler(sampler, batch_size.Value,
            !string.IsNullOrWhiteSpace(last_batch) ? last_batch : "keep");
    }
    else if (batch_size.HasValue || shuffle || sampler != null || !string.IsNullOrWhiteSpace(last_batch))
    {
        // These arguments only configure the internally built BatchSampler, so they would be ignored here.
        throw new Exception("batch_size, shuffle, sampler and last_batch must " +
                            "not be specified if batch_sampler is specified.");
    }

    _batch_sampler = batch_sampler;

    // Clamp negative worker counts to 0 (the previous expression returned num_workers unchanged).
    _num_workers = Math.Max(0, num_workers);

    if (batchify_fn != null)
    {
        _batchify_fn = batchify_fn;
    }
    else if (_num_workers > 0)
    {
        _batchify_fn = DataLoader.DefaultMPBatchifyFn;
    }
    else
    {
        _batchify_fn = DataLoader.DefaultBatchifyFn;
    }
}
/// <summary>
/// Creates a loader that draws mini-batches from <paramref name="dataset"/>, optionally backed by a
/// <c>WorkerPool</c> when <paramref name="num_workers"/> is positive.
/// </summary>
/// <param name="dataset">Dataset to load samples from; its <c>Length</c> sizes the default sampler.</param>
/// <param name="batch_size">Mini-batch size. Required unless <paramref name="batch_sampler"/> is given.</param>
/// <param name="shuffle">
/// When no <paramref name="sampler"/> is given, selects a <c>RandomSampler</c> instead of a
/// <c>SequentialSampler</c>. Must be false when <paramref name="sampler"/> is given.
/// </param>
/// <param name="sampler">Custom sampler; mutually exclusive with <paramref name="shuffle"/>.</param>
/// <param name="last_batch">Passed to the <c>BatchSampler</c>; "keep" is used when null or whitespace.</param>
/// <param name="batch_sampler">
/// Custom batch sampler; mutually exclusive with batch_size, shuffle, sampler and last_batch.
/// </param>
/// <param name="batchify_fn">Function that merges samples into a batch; <c>DefaultBatchifyFn</c> when null.</param>
/// <param name="num_workers">Number of workers; negative values are treated as 0.</param>
/// <param name="pin_memory">Stored on the instance.</param>
/// <param name="pin_device_id">Stored on the instance.</param>
/// <param name="prefetch">Prefetch count; defaults to <c>2 * num_workers</c> and is clamped to be non-negative.</param>
/// <param name="thread_pool">
/// When true, the worker pool is created with only a thread count; otherwise it also receives
/// <c>WorkerInitializer</c> and the dataset.
/// </param>
/// <exception cref="Exception">
/// Thrown when batch_size is missing without a batch sampler, or when mutually exclusive arguments are combined.
/// </exception>
public DataLoader(Dataset<NDArray> dataset, int? batch_size = null, bool shuffle = false, Sampler sampler = null,
    string last_batch = null, BatchSampler batch_sampler = null, Func<NDArrayList, NDArrayList> batchify_fn = null,
    int num_workers = 0, bool pin_memory = false, int pin_device_id = 0, int? prefetch = null,
    bool thread_pool = false)
{
    _dataset = dataset;
    _pin_memory = pin_memory;
    _pin_device_id = pin_device_id;
    _thread_pool = thread_pool;

    if (batch_sampler == null)
    {
        if (!batch_size.HasValue)
        {
            throw new Exception("batch_size must be specified unless " + "batch_sampler is specified");
        }

        if (sampler == null)
        {
            if (shuffle)
            {
                sampler = new RandomSampler(dataset.Length);
            }
            else
            {
                sampler = new SequentialSampler(dataset.Length);
            }
        }
        else if (shuffle)
        {
            throw new Exception("shuffle must not be specified if sampler is specified");
        }

        batch_sampler = new BatchSampler(sampler, batch_size.Value,
            !string.IsNullOrWhiteSpace(last_batch) ? last_batch : "keep");
    }
    else if (batch_size.HasValue || shuffle || sampler != null || !string.IsNullOrWhiteSpace(last_batch))
    {
        // These arguments only configure the internally built BatchSampler. Reject them rather than
        // silently ignoring them when the caller supplies its own batch sampler.
        throw new Exception("batch_size, shuffle, sampler and last_batch must " +
                            "not be specified if batch_sampler is specified.");
    }

    _batch_sampler = batch_sampler;
    _num_workers = Math.Max(0, num_workers);
    _prefetch = Math.Max(0, prefetch ?? 2 * _num_workers);

    if (_num_workers > 0)
    {
        _worker_pool = thread_pool
            ? new WorkerPool(_num_workers)
            : new WorkerPool(_num_workers, WorkerInitializer, dataset);

        // The thread-pool minimums are process-wide: only raise them to fit the workers, never lower
        // them below the current values (the defaults are typically the processor count).
        ThreadPool.GetMinThreads(out var minWorkerThreads, out var minIoThreads);
        ThreadPool.SetMinThreads(
            Math.Max(minWorkerThreads, _worker_pool.NumThreads),
            Math.Max(minIoThreads, _worker_pool.NumThreads));
    }

    // NOTE(review): the original chose DefaultBatchifyFn in both the num_workers > 0 and == 0 branches,
    // whereas DataLoaderV1 uses DataLoader.DefaultMPBatchifyFn when there are workers. Behavior is kept
    // as-is; confirm whether the multi-worker variant was intended here.
    if (batchify_fn != null)
    {
        _batchify_fn = batchify_fn;
    }
    else
    {
        _batchify_fn = DefaultBatchifyFn;
    }
}