Ejemplo n.º 1
0
        /// <summary>
        /// Creates a loader that groups the samples of <paramref name="dataset"/> into mini-batches.
        /// </summary>
        /// <param name="dataset">Source dataset; its <c>Length</c> sizes the default sampler.</param>
        /// <param name="batch_size">Samples per batch. Required unless <paramref name="batch_sampler"/> is supplied.</param>
        /// <param name="shuffle">When no sampler is supplied, use a <c>RandomSampler</c> instead of a
        /// <c>SequentialSampler</c>. Must not be combined with <paramref name="sampler"/>.</param>
        /// <param name="sampler">Custom sampler of sample indices; cannot be combined with <paramref name="shuffle"/>.</param>
        /// <param name="last_batch">Forwarded to <c>BatchSampler</c>; null or blank means <c>"keep"</c>.</param>
        /// <param name="batch_sampler">Pre-built batch sampler. When supplied, <paramref name="batch_size"/>,
        /// <paramref name="shuffle"/>, <paramref name="sampler"/> and <paramref name="last_batch"/> must be left unset.</param>
        /// <param name="batchify_fn">Batch assembly function. Defaults to <c>DataLoader.DefaultMPBatchifyFn</c>
        /// when workers are used, otherwise <c>DataLoader.DefaultBatchifyFn</c>.</param>
        /// <param name="num_workers">Number of workers; negative values are treated as 0.</param>
        /// <param name="pin_memory">Stored on the loader; not used by this constructor.</param>
        /// <param name="pin_device_id">Stored on the loader; not used by this constructor.</param>
        /// <exception cref="ArgumentException">The batching arguments are missing or conflict with each other.</exception>
        public DataLoaderV1(Dataset <NDArray> dataset, int?batch_size = null, bool shuffle = false,
                            Sampler sampler   = null,
                            string last_batch = null, BatchSampler batch_sampler = null,
                            Func <NDArrayList, NDArrayList> batchify_fn = null,
                            int num_workers = 0, bool pin_memory = false, int pin_device_id = 0)
        {
            _dataset       = dataset;
            _pin_memory    = pin_memory;
            _pin_device_id = pin_device_id;

            if (batch_sampler == null)
            {
                if (!batch_size.HasValue)
                {
                    throw new ArgumentException("batch_size must be specified unless " +
                                                "batch_sampler is specified");
                }

                if (sampler == null)
                {
                    if (shuffle)
                    {
                        sampler = new RandomSampler(_dataset.Length);
                    }
                    else
                    {
                        sampler = new SequentialSampler(_dataset.Length);
                    }
                }
                else if (shuffle)
                {
                    // A caller-supplied sampler is valid on its own; it only conflicts with shuffle.
                    throw new ArgumentException("shuffle must not be specified if sampler is specified");
                }

                batch_sampler = new BatchSampler(sampler, batch_size.Value,
                                                 !string.IsNullOrWhiteSpace(last_batch) ? last_batch : "keep");
            }
            else if (batch_size.HasValue || shuffle || sampler != null || !string.IsNullOrWhiteSpace(last_batch))
            {
                throw new ArgumentException("batch_size, shuffle, sampler and last_batch must " +
                                            "not be specified if batch_sampler is specified.");
            }

            _batch_sampler = batch_sampler;
            // Negative worker counts are meaningless; treat them as "no workers".
            _num_workers   = Math.Max(0, num_workers);

            if (batchify_fn == null)
            {
                if (_num_workers > 0)
                {
                    _batchify_fn = DataLoader.DefaultMPBatchifyFn;
                }
                else
                {
                    _batchify_fn = DataLoader.DefaultBatchifyFn;
                }
            }
            else
            {
                _batchify_fn = batchify_fn;
            }
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Starts <paramref name="num_workers"/> background worker threads and one background fetcher thread,
        /// then issues an initial <c>2 * num_workers</c> calls to <c>PushNext()</c>.
        /// </summary>
        /// <param name="num_workers">Number of worker threads; must be positive.</param>
        /// <param name="dataset">Dataset handed to every worker thread.</param>
        /// <param name="batchify_fn">Batch assembly function handed to every worker thread.</param>
        /// <param name="batch_sampler">Source of batch indices; enumerated via <c>GetEnumerator()</c>.</param>
        /// <param name="pin_memory">Forwarded to <c>DataLoader.FetcherLoopV1</c>.</param>
        /// <param name="pin_device_id">Forwarded to <c>DataLoader.FetcherLoopV1</c>.</param>
        /// <param name="worker_fn">Body run by each worker thread. Required despite the null default.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="num_workers"/> is zero or negative.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="worker_fn"/> or <paramref name="batch_sampler"/> is null.</exception>
        public _MultiWorkerIterV1(int num_workers, Dataset <NDArray> dataset, Func <NDArrayList, NDArrayList> batchify_fn,
                                  BatchSampler batch_sampler, bool pin_memory = false, int pin_device_id = 0, WorkerFn worker_fn = null)
        {
            // Validate before any thread is started. A non-positive count would start no workers yet still
            // launch the fetcher. A null worker_fn would only fail later, as an unhandled exception on a
            // background thread.
            if (num_workers <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(num_workers), num_workers,
                                                      $"_MultiWorkerIter is not for {num_workers} workers");
            }

            if (worker_fn == null)
            {
                throw new ArgumentNullException(nameof(worker_fn));
            }

            if (batch_sampler == null)
            {
                throw new ArgumentNullException(nameof(batch_sampler));
            }

            _num_workers      = num_workers;
            _dataset          = dataset;
            _batchify_fn      = batchify_fn;
            _batch_sampler    = batch_sampler;
            _key_queue        = new Queue <int>();
            _data_queue       = new Queue <NDArray>();
            _data_buffer      = new Dictionary <int, NDArrayList>();
            _data_buffer_lock = new object();
            _rcvd_idx         = 0;
            _sent_idx         = 0;
            _iter             = _batch_sampler.GetEnumerator();
            _shutdown         = false;

            // NOTE(review): Queue<T> is not thread-safe. Every worker thread receives _key_queue and
            // _data_queue, and the fetcher thread also receives _data_queue. Confirm that WorkerFn and
            // DataLoader.FetcherLoopV1 synchronize their access.
            _workers = new List <Thread>();
            for (var i = 0; i < _num_workers; i++)
            {
                var t = new Thread(obj => { worker_fn(_dataset, _key_queue, _data_queue, _batchify_fn); });

                t.IsBackground = true;
                t.Start();
                _workers.Add(t);
            }

            _fetcher = new Thread(obj =>
            {
                DataLoader.FetcherLoopV1(_data_queue, _data_buffer, pin_memory, pin_device_id, _data_buffer_lock);
            });

            _fetcher.IsBackground = true;
            _fetcher.Start();

            // Prime the pipeline with two outstanding requests per worker.
            for (var i = 0; i < 2 * _num_workers; i++)
            {
                PushNext();
            }
        }
Ejemplo n.º 3
0
 /// <summary>
 /// Sets up a batch iterator backed by <paramref name="worker_pool"/> and immediately issues
 /// <paramref name="prefetch"/> calls to <c>PushNext()</c>.
 /// </summary>
 /// <param name="worker_pool">Worker pool, stored as-is.</param>
 /// <param name="batchify_fn">Batch assembly function, stored as-is.</param>
 /// <param name="batch_sampler">Source of batch indices; enumerated here via <c>GetEnumerator()</c>.</param>
 /// <param name="pin_memory">Stored as-is.</param>
 /// <param name="pin_device_id">Stored as-is.</param>
 /// <param name="worker_fn">Stored as-is.</param>
 /// <param name="prefetch">Number of initial <c>PushNext()</c> calls. A negative value makes
 /// <c>Enumerable.Range</c> throw <c>ArgumentOutOfRangeException</c>.</param>
 /// <param name="dataset">Stored as-is.</param>
 /// <param name="data_loader">Stored as-is (presumably the owning loader).</param>
 public _MultiWorkerIter(WorkerPool worker_pool, Func <NDArrayList, NDArrayList> batchify_fn,
                         BatchSampler batch_sampler,
                         bool pin_memory = false, int pin_device_id     = 0, WorkerFn worker_fn        = null,
                         int prefetch    = 0, Dataset <NDArray> dataset = null, DataLoader data_loader = null)
 {
     // Collaborators supplied by the caller.
     _worker_pool   = worker_pool;
     _worker_fn     = worker_fn;
     _batchify_fn   = batchify_fn;
     _dataset       = dataset;
     _data_loader   = data_loader;
     _pin_memory    = pin_memory;
     _pin_device_id = pin_device_id;

     // Batch bookkeeping: request/receive cursors, pending results and the batch-index source.
     _sent_idx      = 0;
     _rcvd_idx      = 0;
     _data_buffer   = new Dictionary <int, NDArrayList>();
     _batch_sampler = batch_sampler;
     _iter          = _batch_sampler.GetEnumerator();

     // Enumerable.Range validates eagerly, so a negative prefetch fails before any request is issued.
     var warmup = Enumerable.Range(0, prefetch);
     foreach (var _ in warmup)
     {
         PushNext();
     }
 }
Ejemplo n.º 4
0
        /// <summary>
        /// Creates a loader that groups the samples of <paramref name="dataset"/> into mini-batches.
        /// When <paramref name="num_workers"/> is positive, it also creates a <c>WorkerPool</c>.
        /// </summary>
        /// <param name="dataset">Source dataset; its <c>Length</c> sizes the default sampler.</param>
        /// <param name="batch_size">Samples per batch. Required unless <paramref name="batch_sampler"/> is supplied.</param>
        /// <param name="shuffle">When no sampler is supplied, use a <c>RandomSampler</c> instead of a
        /// <c>SequentialSampler</c>. Must not be combined with <paramref name="sampler"/>.</param>
        /// <param name="sampler">Custom sampler of sample indices; cannot be combined with <paramref name="shuffle"/>.</param>
        /// <param name="last_batch">Forwarded to <c>BatchSampler</c>; null or blank means <c>"keep"</c>.</param>
        /// <param name="batch_sampler">Pre-built batch sampler. When supplied, <paramref name="batch_size"/>,
        /// <paramref name="shuffle"/>, <paramref name="sampler"/> and <paramref name="last_batch"/> must be left unset.</param>
        /// <param name="batchify_fn">Batch assembly function; defaults to <c>DefaultBatchifyFn</c>.</param>
        /// <param name="num_workers">Number of workers; negative values are treated as 0.</param>
        /// <param name="pin_memory">Stored on the loader; not used by this constructor.</param>
        /// <param name="pin_device_id">Stored on the loader; not used by this constructor.</param>
        /// <param name="prefetch">Number of batches to prefetch. Defaults to <c>2 * num_workers</c> and is clamped to at least 0.</param>
        /// <param name="thread_pool">If true, the worker pool is created without the dataset initializer.</param>
        /// <exception cref="ArgumentException">The batching arguments are missing or conflict with each other.</exception>
        public DataLoader(Dataset <NDArray> dataset, int?batch_size = null, bool shuffle = false,
                          Sampler sampler   = null,
                          string last_batch = null, BatchSampler batch_sampler = null,
                          Func <NDArrayList, NDArrayList> batchify_fn = null,
                          int num_workers  = 0, bool pin_memory = false, int pin_device_id = 0, int?prefetch = null,
                          bool thread_pool = false)
        {
            _dataset       = dataset;
            _pin_memory    = pin_memory;
            _pin_device_id = pin_device_id;
            _thread_pool   = thread_pool;

            if (batch_sampler == null)
            {
                if (!batch_size.HasValue)
                {
                    throw new ArgumentException("batch_size must be specified unless " +
                                                "batch_sampler is specified");
                }

                if (sampler == null)
                {
                    if (shuffle)
                    {
                        sampler = new RandomSampler(dataset.Length);
                    }
                    else
                    {
                        sampler = new SequentialSampler(dataset.Length);
                    }
                }
                else if (shuffle)
                {
                    throw new ArgumentException("shuffle must not be specified if sampler is specified");
                }

                batch_sampler = new BatchSampler(sampler, batch_size.Value,
                                                 !string.IsNullOrWhiteSpace(last_batch) ? last_batch : "keep");
            }
            else if (batch_size.HasValue || shuffle || sampler != null || !string.IsNullOrWhiteSpace(last_batch))
            {
                // A supplied batch_sampler fully defines batching. Reject the other batching arguments
                // instead of silently ignoring them.
                throw new ArgumentException("batch_size, shuffle, sampler and last_batch must " +
                                            "not be specified if batch_sampler is specified.");
            }

            _batch_sampler = batch_sampler;
            _num_workers   = Math.Max(0, num_workers);
            _prefetch      = Math.Max(0, prefetch ?? 2 * _num_workers);
            if (_num_workers > 0)
            {
                if (thread_pool)
                {
                    _worker_pool = new WorkerPool(_num_workers);
                }
                else
                {
                    _worker_pool = new WorkerPool(_num_workers, WorkerInitializer, dataset);
                }

                // SetMinThreads replaces the process-wide minimums. Only ever raise them, so a small
                // loader does not shrink the thread pool for unrelated code.
                int minWorkerThreads, minCompletionPortThreads;
                ThreadPool.GetMinThreads(out minWorkerThreads, out minCompletionPortThreads);
                ThreadPool.SetMinThreads(Math.Max(minWorkerThreads, _worker_pool.NumThreads),
                                         Math.Max(minCompletionPortThreads, _worker_pool.NumThreads));
            }

            // NOTE(review): DefaultBatchifyFn is used regardless of the worker count. DataLoaderV1 instead
            // picks DataLoader.DefaultMPBatchifyFn when workers are enabled; confirm this is intentional.
            if (batchify_fn == null)
            {
                _batchify_fn = DefaultBatchifyFn;
            }
            else
            {
                _batchify_fn = batchify_fn;
            }
        }