示例#1
0
        /// <summary>
        /// Fills every entry of the module's argument and auxiliary parameter dictionaries, either by
        /// copying the value from the matching supplied dictionary or by running
        /// <paramref name="initializer"/> on it, then pushes the parameters to the executor group.
        /// </summary>
        /// <param name="initializer">
        /// Initializer applied to parameters that are not copied from a dictionary. When null, such
        /// parameters are left with whatever values they currently hold.
        /// </param>
        /// <param name="arg_params">Optional source values for argument parameters, looked up by name.</param>
        /// <param name="aux_params">Optional source values for auxiliary parameters, looked up by name.</param>
        /// <param name="allow_missing">
        /// When false, a parameter that is absent from a supplied (non-null) dictionary causes an exception.
        /// When true, such a parameter is handed to <paramref name="initializer"/> instead.
        /// </param>
        /// <param name="force_init">When true, re-initializes even if parameters were already initialized.</param>
        /// <param name="allow_extra">Passed through unchanged to the executor group's SetParams call.</param>
        /// <exception cref="Exception">
        /// Thrown when the module has not been bound, or when a parameter is missing from a supplied
        /// dictionary and <paramref name="allow_missing"/> is false.
        /// </exception>
        public override void InitParams(Initializer initializer = null, NDArrayDict arg_params = null,
                                        NDArrayDict aux_params  = null, bool allow_missing     = false, bool force_init = false,
                                        bool allow_extra        = false)
        {
            if (ParamsInitialized && !force_init)
            {
                Logger.Warning("Parameters already initialized and force_init=False. init_params call ignored.");
                return;
            }

            if (!Binded)
            {
                throw new Exception("call bind before initializing the parameters");
            }

            // Initializes a single parameter array from the cache dictionary or from the initializer.
            void impl(InitDesc name, ref NDArray arr, NDArrayDict cache)
            {
                if (cache == null)
                {
                    // No dictionary was supplied for this parameter group, so the initializer is the
                    // only source of values. Previously this case was silently skipped, leaving the
                    // parameters untouched while still reporting them as initialized.
                    if (initializer != null)
                    {
                        initializer.InitWeight(name.Name, ref arr);
                    }

                    return;
                }

                if (cache.Contains(name.Name))
                {
                    var cache_arr = cache[name.Name];

                    // Identity check: skip the copy when the cached array is the target itself.
                    if (!ReferenceEquals(cache_arr, arr))
                    {
                        cache_arr.CopyTo(arr);
                    }

                    return;
                }

                if (!allow_missing)
                {
                    throw new Exception($"{name.Name} is not presented");
                }

                if (initializer != null)
                {
                    initializer.InitWeight(name.Name, ref arr);
                }
            }

            var attr = Symbol.AttrDict();

            foreach (var name in _arg_params.Keys)
            {
                var arr  = _arg_params[name];
                var desc = new InitDesc(name, attr.ContainsKey(name) ? attr[name] : null);
                impl(desc, ref arr, arg_params);
            }

            foreach (var name in _aux_params.Keys)
            {
                var arr  = _aux_params[name];
                var desc = new InitDesc(name, attr.ContainsKey(name) ? attr[name] : null);
                impl(desc, ref arr, aux_params);
            }

            ParamsInitialized = true;
            _params_dirty     = false;
            _exec_group.SetParams(_arg_params, _aux_params, allow_extra);
        }
示例#2
0
        /// <summary>
        /// Binds the symbol to executors for the given input shapes by creating the executor group,
        /// and prepares the module's parameter dictionaries.
        /// </summary>
        /// <param name="data_shapes">Descriptions of the data inputs.</param>
        /// <param name="label_shapes">Descriptions of the label inputs; may be null.</param>
        /// <param name="for_training">Whether the executors are bound for training.</param>
        /// <param name="inputs_need_grad">
        /// Whether gradients with respect to the inputs are required. Must be false when
        /// <paramref name="for_training"/> is false.
        /// </param>
        /// <param name="force_rebind">When true, the existing binding is reset before binding again.</param>
        /// <param name="shared_module">
        /// Optional module whose executor group and parameters are shared with this one. It must
        /// already be bound and have its parameters initialized.
        /// </param>
        /// <param name="grad_req">Gradient requirement passed to the executor group.</param>
        /// <exception cref="Exception">
        /// Thrown when <paramref name="inputs_need_grad"/> is set without <paramref name="for_training"/>,
        /// when <paramref name="shared_module"/> is not both bound and initialized or has fewer executors
        /// than this module has contexts, or when parameter dictionaries already exist on a module whose
        /// parameters are not marked as initialized.
        /// </exception>
        public override void Bind(DataDesc[] data_shapes, DataDesc[] label_shapes = null, bool for_training = true,
                                  bool inputs_need_grad = false, bool force_rebind = false, Module shared_module = null,
                                  OpGradReq grad_req    = OpGradReq.Write)
        {
            if (force_rebind)
            {
                ResetBind();
            }

            if (Binded)
            {
                Logger.Warning("Already bound, ignoring bind()");
                return;
            }

            ForTraining    = for_training;
            InputsNeedGrad = inputs_need_grad;
            _grad_req      = grad_req;

            if (!ForTraining && InputsNeedGrad)
            {
                throw new Exception("inputs_need_grad should be false if for_training=false");
            }

            (_data_shapes, _label_shapes) = ParseDataDesc(DataNames, LabelNames, data_shapes, label_shapes);

            DataParallelExecutorGroup shared_group = null;

            if (shared_module != null)
            {
                // Both conditions are required: the shared module supplies the executor group
                // (needs to be bound) and the parameter dictionaries (needs to be initialized).
                if (!shared_module.Binded || !shared_module.ParamsInitialized)
                {
                    throw new Exception("shared_module not bounded or initialized");
                }

                shared_group = shared_module._exec_group;
                if (shared_group.Execs.Count < _context.Length)
                {
                    throw new Exception("shared_group execs length is less than context length");
                }
            }

            _exec_group = new DataParallelExecutorGroup(_symbol, _context, _work_load_list, _data_shapes, _label_shapes,
                                                        _param_names,
                                                        ForTraining, InputsNeedGrad, shared_group, _fixed_param_names, grad_req,
                                                        _state_names, _group2ctxs);
            _total_exec_bytes = _exec_group._total_exec_bytes;
            if (shared_group != null)
            {
                // Reuse the shared group's parameter dictionaries instead of allocating new ones.
                ParamsInitialized = true;
                _arg_params       = shared_group.ArgParams;
                _aux_params       = shared_group.AuxParams;
            }
            else if (ParamsInitialized)
            {
                // Parameters were initialized before this (re)bind; push them into the new executors.
                _exec_group.SetParams(_arg_params, _aux_params);
            }
            else
            {
                // Neither dictionary may exist yet on a module whose parameters are not initialized.
                if (_arg_params != null || _aux_params != null)
                {
                    throw new Exception("arg and aux params should be null");
                }

                _arg_params = new NDArrayDict();
                _aux_params = new NDArrayDict();

                // Allocate zero-filled placeholders shaped like the first array of each parameter.
                var param_arrays = _exec_group.ParamArrays.Select(x => nd.ZerosLike(x[0])).ToArray();
                for (var i = 0; i < _param_names.Length; i++)
                {
                    _arg_params[_param_names[i]] = param_arrays[i];
                }

                var aux_arrays = _exec_group.AuxArrays.Select(x => nd.ZerosLike(x[0])).ToArray();
                for (var i = 0; i < _aux_names.Length; i++)
                {
                    _aux_params[_aux_names[i]] = aux_arrays[i];
                }
            }

            // NOTE(review): this is gated on the shared module's ParamsInitialized flag; the reference
            // Python implementation appears to gate on the optimizer being initialized instead.
            // Confirm which flag is intended before changing it.
            if (shared_module != null && shared_module.ParamsInitialized)
            {
                BorrowOptimizer(shared_module);
            }

            Binded = true;
        }