public override void InitParams(Initializer initializer = null, NDArrayDict arg_params = null, NDArrayDict aux_params = null,
    bool allow_missing = false, bool force_init = false, bool allow_extra = false)
{
    if (ParamsInitialized && !force_init)
    {
        Logger.Warning("Parameters already initialized and force_init=false. InitParams call ignored.");
        return;
    }

    if (!Binded)
        throw new Exception("Call Bind before initializing the parameters");

    // Initialize a single parameter, preferring a cached value (e.g. from a loaded
    // checkpoint) over the initializer when one is available.
    void impl(InitDesc name, ref NDArray arr, NDArrayDict cache)
    {
        if (cache != null)
        {
            if (cache.Contains(name.Name))
            {
                var cache_arr = cache[name.Name];
                // Copy only if the cached array is not already the target itself.
                if (cache_arr != arr)
                    cache_arr.CopyTo(arr);
            }
            else
            {
                if (!allow_missing)
                    throw new Exception($"{name.Name} is not presented");

                if (initializer != null)
                    initializer.InitWeight(name.Name, ref arr);
            }
        }
        else
        {
            // No cache supplied: always fall back to the initializer.
            if (initializer != null)
                initializer.InitWeight(name.Name, ref arr);
        }
    }

    var attr = Symbol.AttrDict();
    foreach (var name in _arg_params.Keys.ToArray())
    {
        var arr = _arg_params[name];
        var desc = new InitDesc(name, attr.ContainsKey(name) ? attr[name] : null);
        impl(desc, ref arr, arg_params);
        // InitWeight takes the array by ref and may replace it, so write it back.
        _arg_params[name] = arr;
    }

    foreach (var name in _aux_params.Keys.ToArray())
    {
        var arr = _aux_params[name];
        var desc = new InitDesc(name, attr.ContainsKey(name) ? attr[name] : null);
        impl(desc, ref arr, aux_params);
        _aux_params[name] = arr;
    }

    ParamsInitialized = true;
    _params_dirty = false;

    // Copy the initialized parameters to the executors on each device.
    _exec_group.SetParams(_arg_params, _aux_params, allow_extra);
}
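// A minimal usage sketch for InitParams (illustrative, not part of this class): it assumes
// the module has already been bound and that a concrete Initializer such as Uniform is
// available in this library; adjust the names to whatever your build exposes.
//
//     // Fresh initialization: every parameter goes through the initializer.
//     module.InitParams(initializer: new Uniform(0.01f));
//
//     // Warm start: reuse arrays loaded from a checkpoint, and fall back to the
//     // initializer for any parameter missing from the dictionaries.
//     module.InitParams(initializer: new Uniform(0.01f),
//                       arg_params: loadedArgParams,
//                       aux_params: loadedAuxParams,
//                       allow_missing: true);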
public override void Bind(DataDesc[] data_shapes, DataDesc[] label_shapes = null, bool for_training = true,
    bool inputs_need_grad = false, bool force_rebind = false, Module shared_module = null, OpGradReq grad_req = OpGradReq.Write)
{
    if (force_rebind)
        ResetBind();

    if (Binded)
    {
        Logger.Warning("Already bound, ignoring Bind()");
        return;
    }

    ForTraining = for_training;
    InputsNeedGrad = inputs_need_grad;
    _grad_req = grad_req;

    if (!ForTraining && InputsNeedGrad)
        throw new Exception("inputs_need_grad should be false if for_training=false");

    (_data_shapes, _label_shapes) = ParseDataDesc(DataNames, LabelNames, data_shapes, label_shapes);

    DataParallelExecutorGroup shared_group = null;
    if (shared_module != null)
    {
        // The module we share executors with must itself be bound and initialized,
        // and must have at least one executor per context of this module.
        if (!shared_module.Binded || !shared_module.ParamsInitialized)
            throw new Exception("shared_module is not bound or initialized");

        shared_group = shared_module._exec_group;
        if (shared_group.Execs.Count < _context.Length)
            throw new Exception("shared_group execs length is less than context length");
    }

    _exec_group = new DataParallelExecutorGroup(_symbol, _context, _work_load_list, _data_shapes, _label_shapes,
        _param_names, ForTraining, InputsNeedGrad, shared_group, _fixed_param_names, grad_req, _state_names, _group2ctxs);
    _total_exec_bytes = _exec_group._total_exec_bytes;

    if (shared_group != null)
    {
        // Reuse the shared module's parameter arrays directly.
        ParamsInitialized = true;
        _arg_params = shared_group.ArgParams;
        _aux_params = shared_group.AuxParams;
    }
    else if (ParamsInitialized)
    {
        // Parameters already exist (e.g. after a forced rebind), so push them
        // into the freshly created executor group.
        _exec_group.SetParams(_arg_params, _aux_params);
    }
    else
    {
        if (_arg_params != null || _aux_params != null)
            throw new Exception("arg and aux params should be null");

        // Allocate zero-filled placeholders matching the executor group's shapes;
        // InitParams fills them in later.
        _arg_params = new NDArrayDict();
        _aux_params = new NDArrayDict();
        var param_arrays = _exec_group.ParamArrays.Select(x => nd.ZerosLike(x[0])).ToArray();
        for (var i = 0; i < _param_names.Length; i++)
            _arg_params[_param_names[i]] = param_arrays[i];

        var aux_arrays = _exec_group.AuxArrays.Select(x => nd.ZerosLike(x[0])).ToArray();
        for (var i = 0; i < _aux_names.Length; i++)
            _aux_params[_aux_names[i]] = aux_arrays[i];
    }

    if (shared_module != null && shared_module.ParamsInitialized)
        BorrowOptimizer(shared_module);

    Binded = true;
}
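// A minimal sketch of the Bind/InitParams sequence (illustrative only): the batch size,
// the input names ("data", "softmax_label") and the DataDesc/Shape constructor forms are
// assumptions about the surrounding API, not taken from this file.
//
//     var dataDesc = new DataDesc("data", new Shape(32, 3, 224, 224));
//     var labelDesc = new DataDesc("softmax_label", new Shape(32));
//
//     module.Bind(new[] { dataDesc }, new[] { labelDesc }, for_training: true);
//     module.InitParams(initializer: new Uniform(0.01f));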