/// <summary>
/// Clears the binding state: marks the module as not bound and drops the
/// executor group together with the stored data and label shape descriptors.
/// </summary>
private void ResetBind()
{
    Binded = false;

    // Forget everything a previous bind produced.
    _label_shapes = null;
    _data_shapes = null;
    _exec_group = null;
}
/// <summary>
/// Creates a module around <paramref name="symbol"/>.
/// </summary>
/// <param name="symbol">Symbol whose arguments, auxiliary states and outputs are inspected.</param>
/// <param name="data_names">Names of the data inputs; an empty array is used when null.</param>
/// <param name="label_names">Names of the label inputs; an empty array is used when null.</param>
/// <param name="context">Contexts to run on; defaults to a single <c>Context.Cpu()</c> when null.</param>
/// <param name="work_load_list">Per-context work load; defaults to 1 for every context when null.</param>
/// <param name="fixed_param_names">Names of the fixed parameters; an empty array is used when null.</param>
/// <param name="state_names">Names of the state inputs; an empty array is used when null.</param>
/// <param name="group2ctxs">Stored as-is and later handed to the executor group in <c>Bind</c>.</param>
/// <param name="compression_params">Stored as-is in <c>_compression_params</c>.</param>
/// <exception cref="ArgumentNullException"><paramref name="symbol"/> is null.</exception>
/// <exception cref="Exception">
/// <paramref name="context"/> and <paramref name="work_load_list"/> have different lengths.
/// </exception>
public Module(Symbol symbol, string[] data_names = null, string[] label_names = null, Context[] context = null,
    int[] work_load_list = null, string[] fixed_param_names = null, string[] state_names = null,
    Dictionary<string, Context>[] group2ctxs = null, Dictionary<string, object> compression_params = null)
{
    if (symbol == null)
    {
        throw new ArgumentNullException(nameof(symbol));
    }

    if (context == null)
    {
        context = new[] { Context.Cpu() };
    }

    if (work_load_list == null)
    {
        // Equal share of work for every context.
        work_load_list = Enumerable.Repeat(1, context.Length).ToArray();
    }

    if (context.Length != work_load_list.Length)
    {
        throw new Exception("Context and WorkLoadList length are not equal");
    }

    // Bind() reads these fields, so the resolved values must be kept.
    // NOTE(review): assumes _context is Context[] and _work_load_list is int[] — confirm against the field declarations.
    _context = context;
    _work_load_list = work_load_list;

    _group2ctxs = group2ctxs;
    _symbol = symbol;

    _data_names = data_names ?? new string[0];
    _label_names = label_names ?? new string[0];
    _state_names = state_names ?? new string[0];
    _fixed_param_names = fixed_param_names ?? new string[0];

    CheckInputNames(symbol, _data_names, "data", true);
    CheckInputNames(symbol, _label_names, "label", false);
    CheckInputNames(symbol, _state_names, "state", true);
    CheckInputNames(symbol, _fixed_param_names, "fixed_param", true);

    // Parameters are the symbol arguments that are NOT data, label or state inputs.
    var arg_names = symbol.ListArguments();
    var input_names = new HashSet<string>(_data_names.Concat(_label_names).Concat(_state_names));
    _param_names = arg_names.Where(x => !input_names.Contains(x)).ToArray();

    _aux_names = symbol.ListAuxiliaryStates().ToArray();
    OutputNames = symbol.ListOutputs().ToArray();

    // Parameter state: nothing is allocated until Bind() runs.
    _arg_params = null;
    _aux_params = null;
    _params_dirty = false;

    _compression_params = compression_params;

    // Optimizer-related state starts out unset.
    _optimizer = null;
    _kvstore = null;
    _update_on_kvstore = null;
    _updater = null;
    _preload_opt_states = null;
    _grad_req = OpGradReq.Null;

    // Binding state, same fields that ResetBind() clears.
    _exec_group = null;
    _data_shapes = null;
    _label_shapes = null;
}
/// <summary>
/// Binds the module: parses the shape descriptors, builds the
/// <c>DataParallelExecutorGroup</c> and sets up the argument and auxiliary parameter dictionaries.
/// </summary>
/// <param name="data_shapes">Descriptors of the data inputs.</param>
/// <param name="label_shapes">Descriptors of the label inputs; may be null.</param>
/// <param name="for_training">Stored in <c>ForTraining</c> and forwarded to the executor group.</param>
/// <param name="inputs_need_grad">
/// Stored in <c>InputsNeedGrad</c>; must be false when <paramref name="for_training"/> is false.
/// </param>
/// <param name="force_rebind">When true, the current binding is reset before binding again.</param>
/// <param name="shared_module">
/// Optional module whose executor group and parameters are shared; it must be both bound and initialized.
/// </param>
/// <param name="grad_req">Gradient requirement stored in <c>_grad_req</c> and forwarded to the executor group.</param>
/// <exception cref="Exception">
/// Thrown when <paramref name="inputs_need_grad"/> is set without <paramref name="for_training"/>,
/// when <paramref name="shared_module"/> is not bound and initialized or has fewer executors than
/// there are contexts, or when parameters already exist on a module that is not marked initialized.
/// </exception>
public override void Bind(DataDesc[] data_shapes, DataDesc[] label_shapes = null, bool for_training = true,
    bool inputs_need_grad = false, bool force_rebind = false, Module shared_module = null,
    OpGradReq grad_req = OpGradReq.Write)
{
    if (force_rebind)
    {
        ResetBind();
    }

    if (Binded)
    {
        Logger.Warning("Already bound, ignoring bind()");
        return;
    }

    ForTraining = for_training;
    InputsNeedGrad = inputs_need_grad;
    _grad_req = grad_req;

    if (!ForTraining && InputsNeedGrad)
    {
        throw new Exception("inputs_need_grad should be false if for_training=false");
    }

    (_data_shapes, _label_shapes) = ParseDataDesc(DataNames, LabelNames, data_shapes, label_shapes);

    DataParallelExecutorGroup shared_group = null;
    if (shared_module != null)
    {
        // Both conditions are required: an unbound module has no executor group to share,
        // and an uninitialized one has no parameters to share.
        if (!shared_module.Binded || !shared_module.ParamsInitialized)
        {
            throw new Exception("shared_module not bounded or initialized");
        }

        shared_group = shared_module._exec_group;
        if (shared_group.Execs.Count < _context.Length)
        {
            throw new Exception("shared_group execs length is less than context length");
        }
    }

    _exec_group = new DataParallelExecutorGroup(_symbol, _context, _work_load_list, _data_shapes, _label_shapes,
        _param_names, ForTraining, InputsNeedGrad, shared_group, _fixed_param_names, grad_req, _state_names,
        _group2ctxs);
    _total_exec_bytes = _exec_group._total_exec_bytes;

    if (shared_group != null)
    {
        // Reuse the parameters exposed by the shared executor group.
        ParamsInitialized = true;
        _arg_params = shared_group.ArgParams;
        _aux_params = shared_group.AuxParams;
    }
    else if (ParamsInitialized)
    {
        // Rebinding an already initialized module: push the existing parameters into the new executors.
        _exec_group.SetParams(_arg_params, _aux_params);
    }
    else
    {
        // Neither dictionary may exist yet; a single leftover one is just as inconsistent as two.
        if (_arg_params != null || _aux_params != null)
        {
            throw new Exception("arg and aux params should be null");
        }

        _arg_params = new NDArrayDict();
        _aux_params = new NDArrayDict();

        // One placeholder per parameter, created with nd.ZerosLike from the first array of each entry.
        var param_arrays = _exec_group.ParamArrays.Select(x => nd.ZerosLike(x[0])).ToArray();
        for (var i = 0; i < _param_names.Length; i++)
        {
            _arg_params[_param_names[i]] = param_arrays[i];
        }

        var aux_arrays = _exec_group.AuxArrays.Select(x => nd.ZerosLike(x[0])).ToArray();
        for (var i = 0; i < _aux_names.Length; i++)
        {
            _aux_params[_aux_names[i]] = aux_arrays[i];
        }
    }

    // NOTE(review): the reference MXNet implementation borrows the optimizer only when the shared
    // module's optimizer is initialized; this checks ParamsInitialized instead — confirm intent.
    if (shared_module != null && shared_module.ParamsInitialized)
    {
        BorrowOptimizer(shared_module);
    }

    Binded = true;
}