Ejemplo n.º 1
0
        /// <summary>
        /// Marks the given NDArrays as autograd variables, attaching their
        /// gradient buffers and the requested gradient-accumulation mode.
        /// </summary>
        /// <param name="variables">Arrays to track for gradient computation.</param>
        /// <param name="gradients">Pre-allocated arrays that receive the gradients (parallel to <paramref name="variables"/>).</param>
        /// <param name="grad_reqs">Gradient mode (Write/Add/Null) applied to every variable.</param>
        public static void MarkVariables(NDArrayList variables, NDArrayList gradients,
                                         OpGradReq grad_reqs = OpGradReq.Write)
        {
            var gradReqs = new int[variables.Length];

            for (var i = 0; i < gradReqs.Length; i++)
            {
                // Bug fix: honor the caller-supplied grad_reqs; the original
                // hard-coded OpGradReq.Write, silently ignoring the parameter.
                gradReqs[i] = (int)grad_reqs;
            }

            NativeMethods.MXAutogradMarkVariables(variables.Length, MxUtil.GetNDArrayHandles(variables), gradReqs,
                                                  MxUtil.GetNDArrayHandles(gradients));
        }
Ejemplo n.º 2
0
 /// <summary>
 /// Applies <paramref name="src"/> to <paramref name="dst"/> according to the
 /// gradient request <paramref name="req"/> (no-op, overwrite, or accumulate).
 /// </summary>
 /// <param name="dst">Destination gradient array.</param>
 /// <param name="req">How the gradient should be stored.</param>
 /// <param name="src">Source gradient array.</param>
 public virtual void Assign(NDArray dst, OpGradReq req, NDArray src)
 {
     if (req == OpGradReq.Null)
     {
         // Gradient not requested: nothing to do.
         return;
     }
     else if (req == OpGradReq.Write)
     {
         // NOTE(review): this only reassigns the local parameter reference —
         // the caller's dst is unchanged. If write-through into dst's buffer
         // is intended, an in-place copy (e.g. src.CopyTo(dst)) is probably
         // needed — confirm against NDArray's API and the call sites.
         dst = src;
     }
     else if (req == OpGradReq.Add)
     {
         // NOTE(review): if operator+ returns a new NDArray rather than
         // mutating in place, this result is likewise lost to the caller —
         // confirm NDArray's += semantics.
         dst += src;
     }
 }
Ejemplo n.º 3
0
 /// <summary>
 /// Describes a trainable parameter: gradient mode, shape, dtype,
 /// learning-rate/weight-decay multipliers, initializer and storage types.
 /// </summary>
 /// <param name="name">Parameter name.</param>
 /// <param name="grad_req">Gradient write/accumulate mode (default: Write).</param>
 /// <param name="shape">Optional shape; may be resolved later via deferred init.</param>
 /// <param name="dtype">Element type; defaults to float32 when null.</param>
 /// <param name="lr_mult">Learning-rate multiplier for this parameter.</param>
 /// <param name="wd_mult">Weight-decay multiplier for this parameter.</param>
 /// <param name="init">Optional initializer.</param>
 /// <param name="allow_deferred_init">Allow initialization before the shape is known.</param>
 /// <param name="differentiable">Whether gradients are computed for this parameter.</param>
 /// <param name="stype">Storage type of the parameter itself.</param>
 /// <param name="grad_stype">Storage type of the parameter's gradient.</param>
 public Parameter(string name, OpGradReq grad_req = OpGradReq.Write, Shape shape = null, DType dtype = null,
                  float lr_mult           = 1.0f, float wd_mult      = 1.0f, Initializer init = null, bool allow_deferred_init = false,
                  bool differentiable     = true, StorageStype stype = StorageStype.Default,
                  StorageStype grad_stype = StorageStype.Default)
 {
     Name                     = name;
     Lr_Mult                  = lr_mult;
     Wd_Mult                  = wd_mult;
     Init                     = init;
     GradReg                  = grad_req;
     _shape                   = shape;
     DataType                 = dtype ?? DType.Float32; // float32 is the default dtype
     this.differentiable      = differentiable;
     Stype                    = stype;
     Grad_Stype               = grad_stype;
     this.allow_deferred_init = allow_deferred_init;
     // Removed the dead store "grad_req = OpGradReq.Null;" from the original:
     // it reassigned the by-value parameter AFTER GradReg had been set, so it
     // had no observable effect and only obscured the intended gradient mode.
     _ctx_map                 = new Dictionary <int, List <int> >();
 }
Ejemplo n.º 4
0
        /// <summary>
        /// Creates a symbolic module wrapping <paramref name="symbol"/>,
        /// classifying its arguments into inputs (data/label/state) versus
        /// trainable parameters and preparing the per-context workload split.
        /// </summary>
        /// <param name="symbol">Computation graph to execute.</param>
        /// <param name="data_names">Names of the data inputs (default: empty).</param>
        /// <param name="label_names">Names of the label inputs (default: empty).</param>
        /// <param name="context">Devices to run on (default: a single CPU).</param>
        /// <param name="work_load_list">Relative workload per context (default: equal shares of 1).</param>
        /// <param name="fixed_param_names">Parameters excluded from updates.</param>
        /// <param name="state_names">Names of internal-state inputs.</param>
        /// <param name="group2ctxs">Optional per-group context mappings.</param>
        /// <param name="compression_params">Gradient compression settings, if any.</param>
        /// <exception cref="Exception">If context and work-load lists differ in length.</exception>
        public Module(Symbol symbol, string[] data_names = null, string[] label_names = null,
                      Context[] context    = null, int[] work_load_list = null, string[] fixed_param_names = null,
                      string[] state_names = null,
                      Dictionary <string, Context>[] group2ctxs = null, Dictionary <string, object> compression_params = null)
        {
            // Default device: a single CPU context. (Fixes the original's
            // misplaced ';' — the assignment inside the if-block had none and
            // a stray empty statement followed the block.)
            if (context == null)
            {
                context = new[] { Context.Cpu() };
            }

            if (work_load_list == null)
            {
                // No explicit workloads: share work equally across contexts.
                work_load_list = new int[context.Length];
                for (var i = 0; i < work_load_list.Length; i++)
                {
                    work_load_list[i] = 1;
                }
            }

            if (context.Length != work_load_list.Length)
            {
                throw new Exception("Context and WorkLoadList length are not equal");
            }

            // Bug fix: the original computed/validated context and
            // work_load_list but never stored them, although Bind() reads
            // _context and _work_load_list.
            _context           = context;
            _work_load_list    = work_load_list;

            _group2ctxs        = group2ctxs;
            _symbol            = symbol;
            _data_names        = data_names ?? new string[0];
            _label_names       = label_names ?? new string[0];
            _state_names       = state_names ?? new string[0];
            _fixed_param_names = fixed_param_names ?? new string[0];
            CheckInputNames(symbol, _data_names, "data", true);
            CheckInputNames(symbol, _label_names, "label", false);
            CheckInputNames(symbol, _state_names, "state", true);
            CheckInputNames(symbol, _fixed_param_names, "fixed_param", true);

            var arg_names   = symbol.ListArguments();
            var input_names = new List <string>();

            input_names.AddRange(_data_names);
            input_names.AddRange(_label_names);
            input_names.AddRange(_state_names);

            // Parameters are the symbol arguments that are NOT inputs.
            // Bug fix: the original filtered with arg_names.Contains(x) —
            // trivially true for every element — so data/label/state inputs
            // were wrongly classified as trainable parameters and input_names
            // was never used.
            _param_names  = arg_names.Where(x => !input_names.Contains(x)).ToArray();
            _aux_names    = symbol.ListAuxiliaryStates().ToArray();
            OutputNames   = symbol.ListOutputs().ToArray();
            _arg_params   = null;
            _aux_params   = null;
            _params_dirty = false;

            _compression_params = compression_params;
            _optimizer          = null;
            _kvstore            = null;
            _update_on_kvstore  = null;
            _updater            = null;
            _preload_opt_states = null;
            _grad_req           = OpGradReq.Null;

            // Binding state: populated later by Bind().
            _exec_group   = null;
            _data_shapes  = null;
            _label_shapes = null;
        }
Ejemplo n.º 5
0
        /// <summary>
        /// Binds the symbol to executors over the configured contexts so memory
        /// is allocated and the module becomes ready for init/forward/backward.
        /// A repeated call is a warned no-op unless <paramref name="force_rebind"/> is set.
        /// </summary>
        /// <param name="data_shapes">Descriptors of the data inputs.</param>
        /// <param name="label_shapes">Optional descriptors of the label inputs.</param>
        /// <param name="for_training">Whether executors must support backward passes.</param>
        /// <param name="inputs_need_grad">Whether input gradients are required (training only).</param>
        /// <param name="force_rebind">Discard an existing binding and rebuild.</param>
        /// <param name="shared_module">Module to share executor memory/parameters with.</param>
        /// <param name="grad_req">Gradient write/accumulate mode for the executors.</param>
        /// <exception cref="Exception">
        /// On inconsistent arguments or an unusable <paramref name="shared_module"/>.
        /// </exception>
        public override void Bind(DataDesc[] data_shapes, DataDesc[] label_shapes = null, bool for_training = true,
                                  bool inputs_need_grad = false, bool force_rebind = false, Module shared_module = null,
                                  OpGradReq grad_req    = OpGradReq.Write)
        {
            if (force_rebind)
            {
                ResetBind();
            }

            if (Binded)
            {
                // Already bound and not forced: keep the existing executors.
                Logger.Warning("Already bound, ignoring bind()");
                return;
            }

            ForTraining    = for_training;
            InputsNeedGrad = inputs_need_grad;
            _grad_req      = grad_req;

            // Input gradients only make sense when a backward pass exists.
            if (!ForTraining)
            {
                if (InputsNeedGrad)
                {
                    throw new Exception("inputs_need_grad should be false if for_training=false");
                }
            }

            (_data_shapes, _label_shapes) = ParseDataDesc(DataNames, LabelNames, data_shapes, label_shapes);

            DataParallelExecutorGroup shared_group = null;

            if (shared_module != null)
            {
                // The donor module must be bound/initialized and offer at
                // least one executor per local context to share memory from.
                if (!shared_module.Binded && !shared_module.ParamsInitialized)
                {
                    throw new Exception("shared_module not bounded or initialized");
                }

                shared_group = shared_module._exec_group;
                if (shared_group.Execs.Count < _context.Length)
                {
                    throw new Exception("shared_group execs length is less than context length");
                }
            }

            _exec_group = new DataParallelExecutorGroup(_symbol, _context, _work_load_list, _data_shapes, _label_shapes,
                                                        _param_names,
                                                        ForTraining, InputsNeedGrad, shared_group, _fixed_param_names, grad_req,
                                                        _state_names, _group2ctxs);
            _total_exec_bytes = _exec_group._total_exec_bytes;
            if (shared_group != null)
            {
                // Sharing: borrow parameter arrays straight from the donor.
                ParamsInitialized = true;
                _arg_params       = shared_group.ArgParams;
                _aux_params       = shared_group.AuxParams;
            }
            else if (ParamsInitialized)
            {
                // Re-bound after init (e.g. forced rebind): push existing
                // parameters back into the new executors.
                _exec_group.SetParams(_arg_params, _aux_params);
            }
            else
            {
                // Fresh bind: allocate zeroed placeholders that mirror the
                // executor-side array shapes, to be filled by InitParams later.
                if (_arg_params != null && _aux_params != null)
                {
                    throw new Exception("arg and aux params should be null");
                }
                _arg_params = new NDArrayDict();
                _aux_params = new NDArrayDict();

                var param_arrays = _exec_group.ParamArrays.Select(x => nd.ZerosLike(x[0])).ToArray();
                for (var i = 0; i < _param_names.Length; i++)
                {
                    _arg_params[_param_names[i]] = param_arrays[i];
                }

                var aux_arrays = _exec_group.AuxArrays.Select(x => nd.ZerosLike(x[0])).ToArray();
                for (var i = 0; i < _aux_names.Length; i++)
                {
                    _aux_params[_aux_names[i]] = aux_arrays[i];
                }
            }

            if (shared_module != null && shared_module.ParamsInitialized)
            {
                BorrowOptimizer(shared_module);
            }

            Binded = true;
        }
Ejemplo n.º 6
0
 /// <summary>
 /// Binds the module's symbol to concrete input shapes, allocating executor
 /// resources; must be implemented by concrete module types.
 /// </summary>
 /// <param name="data_shapes">Descriptors of the data inputs.</param>
 /// <param name="label_shapes">Optional descriptors of the label inputs.</param>
 /// <param name="for_training">Whether executors must support backward passes.</param>
 /// <param name="inputs_need_grad">Whether input gradients are required.</param>
 /// <param name="force_rebind">Discard an existing binding and rebuild.</param>
 /// <param name="shared_module">Module to share executor memory/parameters with.</param>
 /// <param name="grad_req">Gradient write/accumulate mode for the executors.</param>
 public abstract void Bind(DataDesc[] data_shapes, DataDesc[] label_shapes = null, bool for_training = true,
                           bool inputs_need_grad = false, bool force_rebind = false, Module shared_module = null,
                           OpGradReq grad_req    = OpGradReq.Write);
Ejemplo n.º 7
0
 /// <summary>
 /// Applies <paramref name="src"/> to <paramref name="dst"/> according to the
 /// gradient request; this base implementation is a stub that always throws.
 /// </summary>
 /// <param name="dst">Destination gradient array.</param>
 /// <param name="req">How the gradient should be stored.</param>
 /// <param name="src">Source gradient array.</param>
 /// <exception cref="NotImplementedException">Always thrown; override in a subclass.</exception>
 public virtual void Assign(NDArray dst, OpGradReq req, NDArray src)
 {
     throw new NotImplementedException();
 }