예제 #1
0
        /// <summary>
        /// Binds one training executor per device context and collects, per name/parameter,
        /// the arrays held by those executors.
        /// </summary>
        /// <param name="sym">Symbol bound on every device; validated with <c>Model.CheckArguments</c>.</param>
        /// <param name="argNames">All argument names of <paramref name="sym"/>; indices into this list
        /// are used to pick parameter, weight and gradient arrays out of each executor.</param>
        /// <param name="paramNames">Names (a subset of <paramref name="argNames"/>) treated as parameters.</param>
        /// <param name="ctx">Device contexts; one executor is bound per entry.</param>
        /// <param name="slices">One <c>(Item1, Item2)</c> range per context; <c>Item2 - Item1</c> becomes the
        /// leading dimension of every data/label shape bound on that device.</param>
        /// <param name="trainData">Source of the data and label names and shapes.</param>
        /// <param name="sharedGroup">Optional group whose shared-array dictionaries are reused and whose
        /// executor for device <c>i</c> is passed to <c>BindExec</c> as <c>baseExec</c>; <c>null</c>
        /// allocates a fresh, empty dictionary per device.</param>
        /// <exception cref="ArgumentNullException">A required argument is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="slices"/> does not have exactly one entry per context.</exception>
        public DataParallelExecutorGroup(Symbol sym,
                                         IList <string> argNames,
                                         IList <string> paramNames,
                                         IList <Context> ctx,
                                         IList <Tuple <int, int> > slices,
                                         IDataIProvide trainData,
                                         DataParallelExecutorGroup sharedGroup = null)
        {
            if (sym == null)
            {
                throw new ArgumentNullException(nameof(sym));
            }
            if (argNames == null)
            {
                throw new ArgumentNullException(nameof(argNames));
            }
            if (paramNames == null)
            {
                throw new ArgumentNullException(nameof(paramNames));
            }
            if (ctx == null)
            {
                throw new ArgumentNullException(nameof(ctx));
            }
            if (slices == null)
            {
                throw new ArgumentNullException(nameof(slices));
            }
            if (trainData == null)
            {
                throw new ArgumentNullException(nameof(trainData));
            }
            // slices[i] is paired with ctx[i] below; a mismatch would either fail with an opaque
            // index error or leave stray slices stored in _slices.
            if (slices.Count != ctx.Count)
            {
                throw new ArgumentException("Expected one batch slice per context, but got " + slices.Count +
                                            " slices for " + ctx.Count + " contexts.", nameof(slices));
            }

            Model.CheckArguments(sym);

            // Entry i of _sharedDataArrays is handed to BindExec for device i.
            if (sharedGroup == null)
            {
                this._sharedDataArrays = ctx.Select(s => new Dictionary <string, NdArray>()).ToList();
            }
            else
            {
                this._sharedDataArrays = sharedGroup._sharedDataArrays;
            }
            this._dataNames  = trainData.ProvideData.Select(s => s.Key).ToList();
            this._labelNames = trainData.ProvideLabel.Select(s => s.Key).ToList();
            this._auxNames   = sym.ListAuxiliaryStates();

            // Positions (in argNames order) of the arguments that are parameters.
            var paramNameSet = new HashSet <string>(paramNames);
            this._paramIdx   = argNames.Select((x, i) => new { x = x, i = i }).Where(w => paramNameSet.Contains(w.x)).Select(s => s.i).ToList();
            this._paramNames = _paramIdx.Select(s => argNames[s]).ToList();

            // Data and label shapes are identical for every device, so gather them once.
            var ioShapes = trainData.ProvideData.Concat(trainData.ProvideLabel).ToList();

            this.TrainExecs = new List <Executor>();
            for (int i = 0; i < ctx.Count; i++)
            {
                // Same shapes as reported by trainData, with the leading dimension replaced by
                // the length of this device's slice.
                var sliceLength = (uint)(slices[i].Item2 - slices[i].Item1);
                var dataShapes  = ioShapes.ToDictionary(kv => kv.Key,
                                                        kv => new[] { sliceLength }.Concat(kv.Value.Data().Skip(1)).ToArray());

                var sharedExec = sharedGroup == null ? null : sharedGroup.TrainExecs[i];

                var trainExec = BindExec(sym, ctx[i], dataShapes, this._paramNames,
                                         needGradInput: true, baseExec: sharedExec,
                                         sharedDataArrays: this._sharedDataArrays[i]);

                this.TrainExecs.Add(trainExec);
            }

            // For each data/label name: (slice, bound array) for every device's executor.
            this._dataArrays  = _dataNames.Select(name => TrainExecs.Select((e, i) => Tuple.Create(slices[i], e.ArgDict[name])).ToList()).ToList();
            this._labelArrays = _labelNames.Select(name => TrainExecs.Select((e, i) => Tuple.Create(slices[i], e.ArgDict[name])).ToList()).ToList();

            // For each parameter / auxiliary state: the corresponding array from every device.
            this.ParamArrays = _paramIdx.Select(i => TrainExecs.Select((e) => e.ArgArrays[i]).ToList()).ToList();

            this.GradArrays = _paramIdx.Select(i => TrainExecs.Select((e) => e.GradArrays[i]).ToList()).ToList();

            this.AuxArrays = Enumerable.Range(0, this._auxNames.Count).Select(i => TrainExecs.Select((e) => e.AuxArrays[i]).ToList()).ToList();

            this._slices = slices;
        }
예제 #2
0
        /// <summary>
        /// Binds one training executor per device context and collects, per name/parameter,
        /// the arrays held by those executors.
        /// </summary>
        /// <param name="sym">Symbol bound on every device; validated with <c>Util._check_arguments</c>.</param>
        /// <param name="arg_names">All argument names of <paramref name="sym"/>; indices into this list
        /// are used to pick parameter, weight and gradient arrays out of each executor.</param>
        /// <param name="param_names">Names (a subset of <paramref name="arg_names"/>) treated as parameters.</param>
        /// <param name="ctx">Device contexts; one executor is bound per entry.</param>
        /// <param name="slices">One <c>(Item1, Item2)</c> range per context; <c>Item2 - Item1</c> becomes the
        /// leading dimension of every data/label shape bound on that device.</param>
        /// <param name="train_data">Source of the data and label names and shapes.</param>
        /// <param name="shared_group">Optional group whose shared-array dictionaries are reused and whose
        /// executor for device <c>i</c> is passed to <c>_bind_exec</c> as <c>base_exec</c>; <c>null</c>
        /// allocates a fresh, empty dictionary per device.</param>
        /// <exception cref="ArgumentNullException">A required argument is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="slices"/> does not have exactly one entry per context.</exception>
        public DataParallelExecutorGroup(Symbol sym,
                                         List <string> arg_names,
                                         List <string> param_names,
                                         List <Context> ctx,
                                         List <Tuple <int, int> > slices,
                                         IDataIProvide train_data,
                                         DataParallelExecutorGroup shared_group = null)
        {
            if (sym == null)
            {
                throw new ArgumentNullException(nameof(sym));
            }
            if (arg_names == null)
            {
                throw new ArgumentNullException(nameof(arg_names));
            }
            if (param_names == null)
            {
                throw new ArgumentNullException(nameof(param_names));
            }
            if (ctx == null)
            {
                throw new ArgumentNullException(nameof(ctx));
            }
            if (slices == null)
            {
                throw new ArgumentNullException(nameof(slices));
            }
            if (train_data == null)
            {
                throw new ArgumentNullException(nameof(train_data));
            }
            // slices[i] is paired with ctx[i] below; a mismatch would either fail with an opaque
            // index error or leave stray slices stored in _slices.
            if (slices.Count != ctx.Count)
            {
                throw new ArgumentException("Expected one batch slice per context, but got " + slices.Count +
                                            " slices for " + ctx.Count + " contexts.", nameof(slices));
            }

            Util._check_arguments(sym);

            // Entry i of _shared_data_arrays is handed to _bind_exec for device i.
            if (shared_group == null)
            {
                this._shared_data_arrays = ctx.Select(s => new Dictionary <string, NDArray>()).ToList();
            }
            else
            {
                this._shared_data_arrays = shared_group._shared_data_arrays;
            }
            this._data_names  = train_data.provide_data.Select(s => s.Key).ToList();
            this._label_names = train_data.provide_label.Select(s => s.Key).ToList();
            this._aux_names   = sym.ListAuxiliaryStates();

            // Positions (in arg_names order) of the arguments that are parameters.
            var param_name_set = new HashSet <string>(param_names);
            this._param_idx   = arg_names.Select((x, i) => new { x = x, i = i }).Where(w => param_name_set.Contains(w.x)).Select(s => s.i).ToList();
            this._param_names = _param_idx.Select(s => arg_names[s]).ToList();

            // Data and label shapes are identical for every device, so gather them once.
            var io_shapes = train_data.provide_data.Concat(train_data.provide_label).ToList();

            this.train_execs = new List <Executor>();
            for (int i = 0; i < ctx.Count; i++)
            {
                // Same shapes as reported by train_data, with the leading dimension replaced by
                // the length of this device's slice.
                var slice_length = (uint)(slices[i].Item2 - slices[i].Item1);
                var data_shapes  = io_shapes.ToDictionary(kv => kv.Key,
                                                          kv => new[] { slice_length }.Concat(kv.Value.Data().Skip(1)).ToArray());

                var shared_exec = shared_group == null ? null : shared_group.train_execs[i];

                var train_exec = _bind_exec(sym, ctx[i], data_shapes, this._param_names,
                                            need_grad_input: true, base_exec: shared_exec,
                                            shared_data_arrays: this._shared_data_arrays[i]);

                this.train_execs.Add(train_exec);
            }

            // For each data/label name: (slice, bound array) for every device's executor.
            this._data_arrays  = _data_names.Select(name => train_execs.Select((e, i) => Tuple.Create(slices[i], e.arg_dict[name])).ToList()).ToList();
            this._label_arrays = _label_names.Select(name => train_execs.Select((e, i) => Tuple.Create(slices[i], e.arg_dict[name])).ToList()).ToList();

            // For each parameter / auxiliary state: the corresponding array from every device.
            this.param_arrays = _param_idx.Select(i => train_execs.Select((e) => e.arg_arrays[i]).ToList()).ToList();

            this.grad_arrays = _param_idx.Select(i => train_execs.Select((e) => e.grad_arrays[i]).ToList()).ToList();

            this._aux_arrays = Enumerable.Range(0, this._aux_names.Count).Select(i => train_execs.Select((e) => e.aux_arrays[i]).ToList()).ToList();

            this._slices = slices;
        }
예제 #3
0
        /// <summary>
        /// Splits each batch across the device contexts and binds the default executor group
        /// for <paramref name="symbol"/>.
        /// </summary>
        /// <param name="symbol">Symbol bound for the default executor group.</param>
        /// <param name="symGen">Optional bucket-key-to-symbol generator. When non-null, a bucket table
        /// is created with the default group registered under <c>trainData.DefaultBucketKey</c>.</param>
        /// <param name="ctx">Device contexts to train on.</param>
        /// <param name="trainData">Training iterator; its <c>BatchSize</c> is split across devices.</param>
        /// <param name="paramNames">Names of the parameters.</param>
        /// <param name="argNames">All argument names of the symbol.</param>
        /// <param name="auxNames">Auxiliary-state names.</param>
        /// <param name="workLoadList">Per-device work-load values passed to <c>SplitInputSlice</c>
        /// (presumably relative shares — confirm there); <c>null</c> means a value of 1 for every
        /// device. Must have one entry per context.</param>
        /// <param name="logger">Logger; <c>null</c> falls back to <c>LogManager.GetLogger("")</c>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="ctx"/> or <paramref name="trainData"/> is <c>null</c>.</exception>
        public DataParallelExecutorManager(Symbol symbol,
                                           SymbolGenerate symGen,
                                           IList <Context> ctx,
                                           IDataIter trainData,
                                           IList <string> paramNames,
                                           IList <string> argNames,
                                           IList <string> auxNames,
                                           IList <int> workLoadList,
                                           ILog logger)
        {
            if (ctx == null)
            {
                throw new ArgumentNullException(nameof(ctx));
            }
            if (trainData == null)
            {
                throw new ArgumentNullException(nameof(trainData));
            }

            if (logger == null)
            {
                logger = LogManager.GetLogger("");
            }
            this._logger = logger;

            var numDevice = ctx.Count;

            // Separate the device names; joining with "" ran them together (e.g. "cpu(0)gpu(0)").
            logger.Info("Start training with " + string.Join(", ", ctx));

            if (workLoadList == null)
            {
                workLoadList = Enumerable.Repeat(1, numDevice).ToList();
            }
            Util.Assert(workLoadList.Count == numDevice, "Invalid settings for work load. ");

            var slices = SplitInputSlice(trainData.BatchSize, workLoadList);

            this._slices = slices;

            this._argNames  = argNames;
            this.ParamNames = paramNames;
            this._auxNames  = auxNames;
            this._ctx       = ctx;

            this._execgrp = new DataParallelExecutorGroup(symbol, this._argNames, this.ParamNames, this._ctx,
                                                          this._slices, trainData);

            this._symbol = symbol;

            this._symGen      = symGen;
            // The current group is chosen when a batch is loaded.
            this._currExecgrp = null;
            if (this._symGen != null)
            {
                this._execgrpBucket = new Dictionary <string, DataParallelExecutorGroup>()
                {
                    { trainData.DefaultBucketKey, this._execgrp }
                };
            }
        }
예제 #4
0
        /// <summary>
        /// Splits each batch across the device contexts and binds the default executor group
        /// for <paramref name="symbol"/>.
        /// </summary>
        /// <param name="symbol">Symbol bound for the default executor group.</param>
        /// <param name="sym_gen">Optional bucket-key-to-symbol generator. When non-null, a bucket table
        /// is created with the default group registered under <c>train_data.default_bucket_key</c>.</param>
        /// <param name="ctx">Device contexts to train on.</param>
        /// <param name="train_data">Training iterator; its <c>batch_size</c> is split across devices.</param>
        /// <param name="param_names">Names of the parameters.</param>
        /// <param name="arg_names">All argument names of the symbol.</param>
        /// <param name="aux_names">Auxiliary-state names.</param>
        /// <param name="work_load_list">Per-device work-load values passed to <c>_split_input_slice</c>
        /// (presumably relative shares — confirm there); <c>null</c> means a value of 1 for every
        /// device. Must have one entry per context.</param>
        /// <param name="logger">Logger; <c>null</c> falls back to <c>LogManager.GetLogger("")</c>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="ctx"/> or <paramref name="train_data"/> is <c>null</c>.</exception>
        public DataParallelExecutorManager(Symbol symbol, SymbolGenerate sym_gen, List <Context> ctx,
                                           IDataIter train_data, List <string> param_names, List <string> arg_names,
                                           List <string> aux_names, List <int> work_load_list, ILog logger)
        {
            if (ctx == null)
            {
                throw new ArgumentNullException(nameof(ctx));
            }
            if (train_data == null)
            {
                throw new ArgumentNullException(nameof(train_data));
            }

            if (logger == null)
            {
                logger = LogManager.GetLogger("");
            }
            this._logger = logger;

            var num_device = ctx.Count;

            // Separate the device names; joining with "" ran them together (e.g. "cpu(0)gpu(0)").
            logger.Info("Start training with " + string.Join(", ", ctx));

            if (work_load_list == null)
            {
                work_load_list = Enumerable.Repeat(1, num_device).ToList();
            }
            Util.Assert(work_load_list.Count == num_device, "Invalid settings for work load. ");

            var slices = _split_input_slice(train_data.batch_size, work_load_list);

            this._slices = slices;

            this._arg_names  = arg_names;
            this.param_names = param_names;
            this._aux_names  = aux_names;
            this._ctx        = ctx;

            this._execgrp = new DataParallelExecutorGroup(symbol, this._arg_names, this.param_names, this._ctx,
                                                          this._slices, train_data);

            this._symbol = symbol;

            this._sym_gen      = sym_gen;
            // The current group is chosen when a batch is loaded.
            this._curr_execgrp = null;
            if (this._sym_gen != null)
            {
                this._execgrp_bucket = new Dictionary <string, DataParallelExecutorGroup>()
                {
                    { train_data.default_bucket_key, this._execgrp }
                };
            }
        }
예제 #5
0
 /// <summary>
 /// Selects the executor group for <paramref name="dataBatch"/>, makes it the current group,
 /// and loads the batch into it.
 /// </summary>
 /// <param name="dataBatch">Batch to load. Its <c>BucketKey</c> selects the group when a
 /// symbol generator is configured.</param>
 /// <remarks>
 /// With a symbol generator, groups are cached per bucket key. A key seen for the first time
 /// gets a new group bound to <c>_symGen(key)</c>, built with the default group as its
 /// <c>sharedGroup</c>. Without a generator, the default group is always used.
 /// </remarks>
 public void LoadDataBatch(IDataBatch dataBatch)
 {
     if (this._symGen != null)
     {
         var key = dataBatch.BucketKey;
         // Single lookup; the group is either fetched or created and cached below.
         DataParallelExecutorGroup execgrp;
         if (!this._execgrpBucket.TryGetValue(key, out execgrp))
         {
             // First batch for this bucket: bind a new group for the bucket's symbol.
             var symbol = this._symGen(key);
             execgrp = new DataParallelExecutorGroup(symbol, this._argNames,
                                                     this.ParamNames, this._ctx,
                                                     this._slices, dataBatch,
                                                     sharedGroup: this._execgrp);
             this._execgrpBucket[key] = execgrp;
         }
         this._currExecgrp = execgrp;
     }
     else
     {
         this._currExecgrp = this._execgrp;
     }
     this._currExecgrp.LoadDataBatch(dataBatch);
 }
예제 #6
0
 /// <summary>
 /// Selects the executor group for <paramref name="data_batch"/>, makes it the current group,
 /// and loads the batch into it.
 /// </summary>
 /// <param name="data_batch">Batch to load. Its <c>bucket_key</c> selects the group when a
 /// symbol generator is configured.</param>
 /// <remarks>
 /// With a symbol generator, groups are cached per bucket key. A key seen for the first time
 /// gets a new group bound to <c>_sym_gen(key)</c>, built with the default group as its
 /// <c>shared_group</c>. Without a generator, the default group is always used.
 /// </remarks>
 public void load_data_batch(IDataBatch data_batch)
 {
     if (this._sym_gen != null)
     {
         var key = data_batch.bucket_key;
         // Single lookup; the group is either fetched or created and cached below.
         DataParallelExecutorGroup execgrp;
         if (!this._execgrp_bucket.TryGetValue(key, out execgrp))
         {
             // First batch for this bucket: bind a new group for the bucket's symbol.
             var symbol = this._sym_gen(key);
             execgrp = new DataParallelExecutorGroup(symbol, this._arg_names,
                                                     this.param_names, this._ctx,
                                                     this._slices, data_batch,
                                                     shared_group: this._execgrp);
             this._execgrp_bucket[key] = execgrp;
         }
         this._curr_execgrp = execgrp;
     }
     else
     {
         this._curr_execgrp = this._execgrp;
     }
     this._curr_execgrp.load_data_batch(data_batch);
 }