Example #1
0
 /// <summary>
 /// Puts the module back into the unbound state: clears the <c>Binded</c> flag and
 /// discards the executor group together with the cached data/label shape descriptions,
 /// so that a subsequent <c>Bind</c> call builds everything again.
 /// </summary>
 private void ResetBind()
 {
     Binded = false;

     // Forget the executor group and the input descriptions it was built from.
     _exec_group = null;
     _data_shapes = null;
     _label_shapes = null;
 }
Example #2
0
        /// <summary>
        /// Creates a module around <paramref name="symbol"/>, recording which of its arguments are
        /// data, label, state and fixed inputs. No executors are created here; that happens in <c>Bind</c>.
        /// </summary>
        /// <param name="symbol">Network symbol to wrap. Must not be null.</param>
        /// <param name="data_names">Argument names fed as data; null means none.</param>
        /// <param name="label_names">Argument names fed as labels; null means none.</param>
        /// <param name="context">Devices to bind on; null means a single <c>Context.Cpu()</c>.</param>
        /// <param name="work_load_list">
        /// Work load for each entry of <paramref name="context"/>; null means 1 for every context.
        /// </param>
        /// <param name="fixed_param_names">
        /// Parameter names to keep fixed; checked with <c>CheckInputNames</c> and forwarded to the executor group.
        /// </param>
        /// <param name="state_names">Argument names fed as states; null means none.</param>
        /// <param name="group2ctxs">Stored and forwarded unchanged to the executor group when binding.</param>
        /// <param name="compression_params">Stored for later use; not interpreted by the constructor.</param>
        /// <exception cref="ArgumentNullException"><paramref name="symbol"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="context"/> and <paramref name="work_load_list"/> have different lengths.
        /// </exception>
        public Module(Symbol symbol, string[] data_names = null, string[] label_names = null,
                      Context[] context    = null, int[] work_load_list = null, string[] fixed_param_names = null,
                      string[] state_names = null,
                      Dictionary <string, Context>[] group2ctxs = null, Dictionary <string, object> compression_params = null)
        {
            if (symbol == null)
            {
                throw new ArgumentNullException(nameof(symbol));
            }

            // Default to a single CPU device.
            if (context == null)
            {
                context = new[] { Context.Cpu() };
            }

            // Default to an even split: every device gets a work load of 1.
            if (work_load_list == null)
            {
                work_load_list = Enumerable.Repeat(1, context.Length).ToArray();
            }

            if (context.Length != work_load_list.Length)
            {
                throw new ArgumentException("Context and WorkLoadList length are not equal", nameof(work_load_list));
            }

            // Bind() reads _context and _work_load_list, so they must be stored here.
            _context           = context;
            _work_load_list    = work_load_list;
            _group2ctxs        = group2ctxs;
            _symbol            = symbol;
            _data_names        = data_names ?? new string[0];
            _label_names       = label_names ?? new string[0];
            _state_names       = state_names ?? new string[0];
            _fixed_param_names = fixed_param_names ?? new string[0];
            CheckInputNames(symbol, _data_names, "data", true);
            CheckInputNames(symbol, _label_names, "label", false);
            CheckInputNames(symbol, _state_names, "state", true);
            CheckInputNames(symbol, _fixed_param_names, "fixed_param", true);

            // Every symbol argument that is not fed as data, label or state is a parameter.
            // (Previously the filter tested membership in the argument list itself, which kept everything.)
            var input_names = new HashSet<string>(_data_names);
            input_names.UnionWith(_label_names);
            input_names.UnionWith(_state_names);

            _param_names  = symbol.ListArguments().Where(x => !input_names.Contains(x)).ToArray();
            _aux_names    = symbol.ListAuxiliaryStates().ToArray();
            OutputNames   = symbol.ListOutputs().ToArray();
            _arg_params   = null;
            _aux_params   = null;
            _params_dirty = false;

            _compression_params = compression_params;
            _optimizer          = null;
            _kvstore            = null;
            _update_on_kvstore  = null;
            _updater            = null;
            _preload_opt_states = null;
            _grad_req           = OpGradReq.Null;

            // Nothing is bound yet.
            _exec_group   = null;
            _data_shapes  = null;
            _label_shapes = null;
        }
Example #3
0
        /// <summary>
        /// Binds the symbol for the given input descriptions by creating a
        /// <see cref="DataParallelExecutorGroup"/> over the module's contexts. If the module is
        /// already bound and <paramref name="force_rebind"/> is false, logs a warning and returns.
        /// </summary>
        /// <param name="data_shapes">Descriptions of the data inputs.</param>
        /// <param name="label_shapes">Descriptions of the label inputs, or null.</param>
        /// <param name="for_training">Whether the executors are bound for training.</param>
        /// <param name="inputs_need_grad">Whether input gradients are requested; must be false when not training.</param>
        /// <param name="force_rebind">Reset any existing binding before binding again.</param>
        /// <param name="shared_module">
        /// Optional module that must be bound and have initialized parameters; its executor group
        /// is shared and its parameter dictionaries are reused.
        /// </param>
        /// <param name="grad_req">Gradient request mode stored on the module and passed to the executor group.</param>
        /// <exception cref="Exception">
        /// Thrown when inputs_need_grad is set without for_training, when the shared module is not
        /// both bound and initialized, when its executor group has fewer executors than this module
        /// has contexts, or when parameter dictionaries exist although parameters are not initialized.
        /// </exception>
        public override void Bind(DataDesc[] data_shapes, DataDesc[] label_shapes = null, bool for_training = true,
                                  bool inputs_need_grad = false, bool force_rebind = false, Module shared_module = null,
                                  OpGradReq grad_req    = OpGradReq.Write)
        {
            if (force_rebind)
            {
                ResetBind();
            }

            if (Binded)
            {
                Logger.Warning("Already bound, ignoring bind()");
                return;
            }

            ForTraining    = for_training;
            InputsNeedGrad = inputs_need_grad;
            _grad_req      = grad_req;

            if (!ForTraining && InputsNeedGrad)
            {
                throw new Exception("inputs_need_grad should be false if for_training=false");
            }

            (_data_shapes, _label_shapes) = ParseDataDesc(DataNames, LabelNames, data_shapes, label_shapes);

            DataParallelExecutorGroup shared_group = null;

            if (shared_module != null)
            {
                // The shared module must be BOTH bound and initialized. The former '&&' only rejected a
                // module that was neither, so an unbound module got through and its null _exec_group
                // was dereferenced just below.
                if (!shared_module.Binded || !shared_module.ParamsInitialized)
                {
                    throw new Exception("shared_module not bounded or initialized");
                }

                shared_group = shared_module._exec_group;
                if (shared_group.Execs.Count < _context.Length)
                {
                    throw new Exception("shared_group execs length is less than context length");
                }
            }

            _exec_group = new DataParallelExecutorGroup(_symbol, _context, _work_load_list, _data_shapes, _label_shapes,
                                                        _param_names,
                                                        ForTraining, InputsNeedGrad, shared_group, _fixed_param_names, grad_req,
                                                        _state_names, _group2ctxs);
            _total_exec_bytes = _exec_group._total_exec_bytes;

            if (shared_group != null)
            {
                // Reuse the parameter dictionaries of the shared executor group.
                ParamsInitialized = true;
                _arg_params       = shared_group.ArgParams;
                _aux_params       = shared_group.AuxParams;
            }
            else if (ParamsInitialized)
            {
                // Parameters already exist (e.g. set before a rebind): copy them into the new executors.
                _exec_group.SetParams(_arg_params, _aux_params);
            }
            else
            {
                // Uninitialized parameters: NEITHER dictionary may exist yet. The former '&&'
                // only failed when both were set, silently overwriting a lone leftover dictionary.
                if (_arg_params != null || _aux_params != null)
                {
                    throw new Exception("arg and aux params should be null");
                }

                _arg_params = new NDArrayDict();
                _aux_params = new NDArrayDict();

                // Zero-filled placeholders shaped like x[0] of each executor-group entry
                // (presumably the copy on the first context — confirm against DataParallelExecutorGroup).
                var param_arrays = _exec_group.ParamArrays.Select(x => nd.ZerosLike(x[0])).ToArray();
                for (var i = 0; i < _param_names.Length; i++)
                {
                    _arg_params[_param_names[i]] = param_arrays[i];
                }

                var aux_arrays = _exec_group.AuxArrays.Select(x => nd.ZerosLike(x[0])).ToArray();
                for (var i = 0; i < _aux_names.Length; i++)
                {
                    _aux_params[_aux_names[i]] = aux_arrays[i];
                }
            }

            // NOTE(review): the MXNet Python reference borrows the optimizer only when the shared module's
            // optimizer is initialized. Given the ParamsInitialized check enforced above, this condition is
            // always true whenever shared_module is non-null — confirm BorrowOptimizer tolerates a shared
            // module without an initialized optimizer.
            if (shared_module != null && shared_module.ParamsInitialized)
            {
                BorrowOptimizer(shared_module);
            }

            Binded = true;
        }