示例#1
0
        /// <summary>
        /// Creates a loader that serves <paramref name="dataset"/> in mini-batches, either through a
        /// caller-supplied <see cref="BatchSampler"/> or one built from
        /// <paramref name="batch_size"/>, <paramref name="sampler"/>/<paramref name="shuffle"/> and
        /// <paramref name="last_batch"/>.
        /// </summary>
        /// <param name="dataset">Source dataset; its <c>Length</c> sizes the default sampler.</param>
        /// <param name="batch_size">Samples per batch. Required unless <paramref name="batch_sampler"/> is given.</param>
        /// <param name="shuffle">
        /// When no <paramref name="sampler"/> is given, selects a <c>RandomSampler</c> instead of a
        /// <c>SequentialSampler</c>. Must be <c>false</c> when <paramref name="sampler"/> is given.
        /// </param>
        /// <param name="sampler">Custom index sampler; cannot be combined with <paramref name="shuffle"/>.</param>
        /// <param name="last_batch">Last-batch policy forwarded to <c>BatchSampler</c>; <c>"keep"</c> when null or blank.</param>
        /// <param name="batch_sampler">
        /// Pre-built batch sampler. When given, <paramref name="batch_size"/>, <paramref name="shuffle"/>,
        /// <paramref name="sampler"/> and <paramref name="last_batch"/> must all be left unset.
        /// </param>
        /// <param name="batchify_fn">Merges samples into a batch; when null a default is chosen from the worker count.</param>
        /// <param name="num_workers">Number of workers; negative values are treated as 0.</param>
        /// <param name="pin_memory">Stored in <c>_pin_memory</c>.</param>
        /// <param name="pin_device_id">Stored in <c>_pin_device_id</c>.</param>
        /// <exception cref="Exception">Thrown when the argument combination is inconsistent.</exception>
        public DataLoaderV1(Dataset <NDArray> dataset, int?batch_size = null, bool shuffle = false,
                            Sampler sampler   = null,
                            string last_batch = null, BatchSampler batch_sampler = null,
                            Func <NDArrayList, NDArrayList> batchify_fn = null,
                            int num_workers = 0, bool pin_memory = false, int pin_device_id = 0)
        {
            _dataset       = dataset;
            _pin_memory    = pin_memory;
            _pin_device_id = pin_device_id;

            if (batch_sampler == null)
            {
                if (!batch_size.HasValue)
                {
                    throw new Exception("batch_size must be specified unless " +
                                        "batch_sampler is specified");
                }

                if (sampler == null)
                {
                    if (shuffle)
                    {
                        sampler = new RandomSampler(_dataset.Length);
                    }
                    else
                    {
                        sampler = new SequentialSampler(_dataset.Length);
                    }
                }
                else if (shuffle)
                {
                    // A caller-supplied sampler already fixes the visiting order, so a shuffle request
                    // would be silently ignored. Reject only that conflict, not the sampler itself.
                    throw new Exception("shuffle must not be specified if sampler is specified");
                }

                batch_sampler = new BatchSampler(sampler, batch_size.Value,
                                                 !string.IsNullOrWhiteSpace(last_batch) ? last_batch : "keep");
            }
            else if (batch_size.HasValue || shuffle || sampler != null || !string.IsNullOrWhiteSpace(last_batch))
            {
                // These options only shape an internally built BatchSampler; with an explicit one they would be ignored.
                throw new Exception("batch_size, shuffle, sampler and last_batch must " +
                                    "not be specified if batch_sampler is specified.");
            }

            _batch_sampler = batch_sampler;

            // Negative worker counts are normalized to 0 (single-threaded loading).
            _num_workers = num_workers >= 0 ? num_workers : 0;

            if (batchify_fn == null)
            {
                // Choose the default collator from the normalized worker count.
                if (_num_workers > 0)
                {
                    _batchify_fn = DataLoader.DefaultMPBatchifyFn;
                }
                else
                {
                    _batchify_fn = DataLoader.DefaultBatchifyFn;
                }
            }
            else
            {
                _batchify_fn = batchify_fn;
            }
        }
示例#2
0
        /// <summary>
        /// Creates a loader that serves <paramref name="dataset"/> in mini-batches, optionally using a
        /// <c>WorkerPool</c> when <paramref name="num_workers"/> is positive.
        /// </summary>
        /// <param name="dataset">Source dataset; its <c>Length</c> sizes the default sampler.</param>
        /// <param name="batch_size">Samples per batch. Required unless <paramref name="batch_sampler"/> is given.</param>
        /// <param name="shuffle">
        /// When no <paramref name="sampler"/> is given, selects a <c>RandomSampler</c> instead of a
        /// <c>SequentialSampler</c>. Must be <c>false</c> when <paramref name="sampler"/> is given.
        /// </param>
        /// <param name="sampler">Custom index sampler; cannot be combined with <paramref name="shuffle"/>.</param>
        /// <param name="last_batch">Last-batch policy forwarded to <c>BatchSampler</c>; <c>"keep"</c> when null or blank.</param>
        /// <param name="batch_sampler">
        /// Pre-built batch sampler. When given, <paramref name="batch_size"/>, <paramref name="shuffle"/>,
        /// <paramref name="sampler"/> and <paramref name="last_batch"/> must all be left unset.
        /// </param>
        /// <param name="batchify_fn">Merges samples into a batch; defaults to <c>DefaultBatchifyFn</c>.</param>
        /// <param name="num_workers">Number of workers; negative values are treated as 0.</param>
        /// <param name="pin_memory">Stored in <c>_pin_memory</c>.</param>
        /// <param name="pin_device_id">Stored in <c>_pin_device_id</c>.</param>
        /// <param name="prefetch">Prefetch depth; defaults to <c>2 * num_workers</c> and is clamped to be non-negative.</param>
        /// <param name="thread_pool">
        /// Selects the <c>WorkerPool(int)</c> constructor instead of the one that receives
        /// <c>WorkerInitializer</c> and the dataset.
        /// </param>
        /// <exception cref="Exception">Thrown when the argument combination is inconsistent.</exception>
        public DataLoader(Dataset <NDArray> dataset, int?batch_size = null, bool shuffle = false,
                          Sampler sampler   = null,
                          string last_batch = null, BatchSampler batch_sampler = null,
                          Func <NDArrayList, NDArrayList> batchify_fn = null,
                          int num_workers  = 0, bool pin_memory = false, int pin_device_id = 0, int?prefetch = null,
                          bool thread_pool = false)
        {
            _dataset       = dataset;
            _pin_memory    = pin_memory;
            _pin_device_id = pin_device_id;
            _thread_pool   = thread_pool;

            if (batch_sampler == null)
            {
                if (!batch_size.HasValue)
                {
                    throw new Exception("batch_size must be specified unless " +
                                        "batch_sampler is specified");
                }

                if (sampler == null)
                {
                    if (shuffle)
                    {
                        sampler = new RandomSampler(dataset.Length);
                    }
                    else
                    {
                        sampler = new SequentialSampler(dataset.Length);
                    }
                }
                else if (shuffle)
                {
                    // A caller-supplied sampler already fixes the visiting order; shuffle would be ignored.
                    throw new Exception("shuffle must not be specified if sampler is specified");
                }

                batch_sampler = new BatchSampler(sampler, batch_size.Value,
                                                 !string.IsNullOrWhiteSpace(last_batch) ? last_batch : "keep");
            }
            else if (batch_size.HasValue || shuffle || sampler != null || !string.IsNullOrWhiteSpace(last_batch))
            {
                // These options only shape an internally built BatchSampler; with an explicit one they
                // would be silently ignored, so reject the conflict (same rule as DataLoaderV1).
                throw new Exception("batch_size, shuffle, sampler and last_batch must " +
                                    "not be specified if batch_sampler is specified.");
            }

            _batch_sampler = batch_sampler;
            _num_workers   = num_workers >= 0 ? num_workers : 0;
            _prefetch      = Math.Max(0, prefetch ?? 2 * _num_workers);

            if (_num_workers > 0)
            {
                if (thread_pool)
                {
                    _worker_pool = new WorkerPool(_num_workers);
                }
                else
                {
                    _worker_pool = new WorkerPool(_num_workers, WorkerInitializer, dataset);
                }

                // ThreadPool minimums are process-wide: only raise them so that a higher minimum
                // configured elsewhere in the host is never lowered by constructing a loader.
                int minWorkerThreads, minIoThreads;
                ThreadPool.GetMinThreads(out minWorkerThreads, out minIoThreads);
                ThreadPool.SetMinThreads(Math.Max(minWorkerThreads, _worker_pool.NumThreads),
                                         Math.Max(minIoThreads, _worker_pool.NumThreads));
            }

            if (batchify_fn == null)
            {
                // NOTE(review): the same default collator is used regardless of worker count, whereas
                // DataLoaderV1 picks DefaultMPBatchifyFn when workers > 0 — confirm this is intended.
                _batchify_fn = DefaultBatchifyFn;
            }
            else
            {
                _batchify_fn = batchify_fn;
            }
        }