Exemplo n.º 1
0
 /// <summary>
 /// Saves a checkpoint (via <c>MxModel.SaveCheckpoint</c>) whenever <c>epoch + 1</c> is a multiple of <c>_period</c>.
 /// </summary>
 /// <param name="epoch">Epoch index; the checkpoint is stored under <c>epoch + 1</c>.</param>
 /// <param name="symbol">Network symbol to persist.</param>
 /// <param name="arg_params">Argument parameters to persist.</param>
 /// <param name="aux_params">Auxiliary parameters to persist.</param>
 public void Invoke(int epoch, Symbol symbol, NDArrayDict arg_params, NDArrayDict aux_params)
 {
     // Clamp the period to at least 1 so the modulo below never divides by zero or a negative value.
     _period = _period < 1 ? 1 : _period;

     var completedEpochs = epoch + 1;
     if (completedEpochs % _period != 0)
     {
         return;
     }

     MxModel.SaveCheckpoint(_prefix, completedEpochs, symbol, arg_params, aux_params);
 }
Exemplo n.º 2
0
        /// <summary>
        /// Saves a checkpoint after letting every RNN cell transform the argument dictionary
        /// through <c>UnpackWeights</c>, applied in array order.
        /// </summary>
        /// <param name="cells">Cells whose weights are unpacked before saving.</param>
        /// <param name="prefix">Checkpoint file prefix.</param>
        /// <param name="epoch">Epoch number stored in the checkpoint.</param>
        /// <param name="symbol">Network symbol to persist.</param>
        /// <param name="arg_params">Argument parameters (before unpacking).</param>
        /// <param name="aux_params">Auxiliary parameters, saved as-is.</param>
        public static void SaveRNNCheckPoint(BaseRNNCell[] cells, string prefix, int epoch, Symbol symbol, NDArrayDict arg_params, NDArrayDict aux_params)
        {
            // Each cell receives the dictionary produced by the previous one.
            var unpacked = arg_params;
            for (var idx = 0; idx < cells.Length; idx++)
            {
                unpacked = cells[idx].UnpackWeights(unpacked);
            }

            MxModel.SaveCheckpoint(prefix, epoch, symbol, unpacked, aux_params);
        }
Exemplo n.º 3
0
        /// <summary>
        /// Loads a checkpoint and lets every RNN cell transform the argument dictionary
        /// through <c>PackWeights</c>, applied in array order (the inverse direction of SaveRNNCheckPoint).
        /// </summary>
        /// <param name="cells">Cells whose weights are packed after loading.</param>
        /// <param name="prefix">Checkpoint file prefix.</param>
        /// <param name="epoch">Epoch number of the checkpoint to load.</param>
        /// <returns>The loaded symbol, the packed argument parameters, and the auxiliary parameters.</returns>
        public static (Symbol, NDArrayDict, NDArrayDict) LoadRNNCheckPoint(BaseRNNCell[] cells, string prefix, int epoch)
        {
            var (loadedSymbol, loadedArgs, loadedAux) = MxModel.LoadCheckpoint(prefix, epoch);

            // Each cell receives the dictionary produced by the previous one.
            var packed = loadedArgs;
            for (var idx = 0; idx < cells.Length; idx++)
            {
                packed = cells[idx].PackWeights(packed);
            }

            return (loadedSymbol, packed, loadedAux);
        }
Exemplo n.º 4
0
        /// <summary>
        /// Creates a <see cref="Module"/> from a checkpoint loaded with <c>MxModel.LoadCheckpoint</c>.
        /// The loaded argument/auxiliary parameters are installed and the module is marked as
        /// having initialized parameters.
        /// </summary>
        /// <param name="prefix">Checkpoint file prefix.</param>
        /// <param name="epoch">Epoch number of the checkpoint to load.</param>
        /// <param name="load_optimizer_states">
        /// When true, records <c>{prefix}-{epoch:D4}.states</c> so the optimizer states are restored
        /// once the optimizer is initialized.
        /// </param>
        /// <param name="data_names">Forwarded to the <see cref="Module"/> constructor.</param>
        /// <param name="label_names">Forwarded to the <see cref="Module"/> constructor.</param>
        /// <param name="logging">Forwarded to the <see cref="Module"/> constructor.</param>
        /// <param name="context">Single device context; forwarded as a one-element context array when set.</param>
        /// <param name="work_load_list">Forwarded to the <see cref="Module"/> constructor.</param>
        /// <param name="fixed_param_names">Forwarded to the <see cref="Module"/> constructor.</param>
        /// <returns>The reconstructed module.</returns>
        public static Module Load(string prefix, int epoch, bool load_optimizer_states = false,
                                  string[] data_names = null, string[] label_names = null, Logger logging             = null,
                                  Context context     = null, int[] work_load_list = null, string[] fixed_param_names = null)
        {
            var (sym, args, auxs) = MxModel.LoadCheckpoint(prefix, epoch);

            // Previously only the symbol was passed and every other option was silently dropped.
            // Module takes an array of contexts, so wrap the single context when one is given.
            var mod = new Module(symbol: sym,
                                 data_names: data_names,
                                 label_names: label_names,
                                 logging: logging,
                                 context: context == null ? null : new[] { context },
                                 work_load_list: work_load_list,
                                 fixed_param_names: fixed_param_names);

            mod._arg_params       = args;
            mod._aux_params       = auxs;
            mod.ParamsInitialized = true;
            if (load_optimizer_states)
            {
                // File names are machine-facing, so format the epoch culture-independently.
                var epochTag = epoch.ToString("D4", System.Globalization.CultureInfo.InvariantCulture);
                mod._preload_opt_states = $"{prefix}-{epochTag}.states";
            }

            return mod;
        }
Exemplo n.º 5
0
        /// <summary>
        /// Updates the parameters from the current gradients, either on the KVStore
        /// (when <c>_update_on_kvstore</c> is true) or locally through <c>_updater</c>.
        /// Marks the parameters as dirty.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// The module is not bound, its parameters are not initialized, or its optimizer is not initialized.
        /// </exception>
        public override void Update()
        {
            // All three preconditions are required. The previous `&&` chain only threw when none held,
            // letting e.g. a bound module without an optimizer fall through and fail later.
            if (!Binded || !ParamsInitialized || !OptimizerInitialized)
            {
                throw new InvalidOperationException("Module not binded or param initialized or optimizer initialized");
            }

            _params_dirty = true;
            if (_update_on_kvstore == true)
            {
                MxModel.UpdateParamsOnKVStore(_exec_group.ParamArrays, _exec_group.GradArrays, _kvstore,
                                              _exec_group.ParamNames);
            }
            else
            {
                MxModel.UpdateParams(_exec_group.ParamArrays, _exec_group.GradArrays, _updater, _context.Length, _kvstore,
                                     _exec_group.ParamNames);
            }
        }
Exemplo n.º 6
0
        /// <summary>
        /// Loads a checkpoint into an inference-only <see cref="Module"/> bound to a single "data" input.
        /// </summary>
        /// <param name="prefix">Checkpoint file prefix.</param>
        /// <param name="epoch">Epoch number of the checkpoint to load.</param>
        /// <param name="gpu">When true, the module is created on GPU 0; otherwise the constructor's default context is used.</param>
        /// <returns>A bound module with the checkpoint parameters set.</returns>
        private static Module LoadModel(string prefix, int epoch = 0, bool gpu = true)
        {
            var (network, argParams, auxParams) = MxModel.LoadCheckpoint(prefix, epoch);

            // Insert empty NDArrays under the label argument names before handing the dictionary to SetParams.
            argParams["prob_label"]    = new NDArray(new float[0]);
            argParams["softmax_label"] = new NDArray(new float[0]);

            var mod = gpu
                ? new Module(symbol: network, context: new[] { mx.Gpu(0) }, data_names: new string[] { "data" })
                : new Module(symbol: network, data_names: new string[] { "data" });

            // NOTE(review): presumably batch 1 of 3-channel 224x224 images — confirm against the model.
            var inputShape = new Shape(1, 3, 224, 224);
            mod.Bind(for_training: false, data_shapes: new[] { new DataDesc("data", inputShape) });
            mod.SetParams(argParams, auxParams);
            return mod;
        }
Exemplo n.º 7
0
        /// <summary>
        /// Installs and initializes the optimizer for this module, creating the KVStore and
        /// either registering the optimizer on it or building a local updater.
        /// </summary>
        /// <param name="kv">KVStore type passed to <c>MxModel.CreateKVStore</c>.</param>
        /// <param name="optimizer">Optimizer to use; a default <c>SGD</c> is created when null.</param>
        /// <param name="optimizer_params">NOTE(review): not used by this method — confirm whether it should configure the optimizer.</param>
        /// <param name="force_init">Re-initialize even if an optimizer is already installed.</param>
        /// <exception cref="InvalidOperationException">The module is not bound or its parameters are not initialized.</exception>
        public override void InitOptimizer(string kv = "local", Optimizer optimizer = null,
                                           Dictionary <string, object> optimizer_params = null, bool force_init = false)
        {
            // Both preconditions are required. The previous `&&` check only threw when neither held.
            if (!Binded || !ParamsInitialized)
            {
                throw new InvalidOperationException("Module not binded and param initialized");
            }

            if (OptimizerInitialized && !force_init)
            {
                Logger.Warning("optimizer already initialized, ignoring...");
                return;
            }

            // NOTE(review): a default SGD keeps its own RescaleGrad, so the mismatch warning below may
            // fire for it too — confirm whether it should be constructed with rescale_grad.
            if (optimizer == null)
            {
                optimizer = new SGD();
            }

            // Refresh _arg_params from the devices if stale; it seeds the KVStore below.
            if (_params_dirty)
            {
                SyncParamsFromDevices();
            }

            var (kvstore, update_on_kvstore) = MxModel.CreateKVStore(kv, _context.Length, _arg_params);
            var batch_size = _exec_group.BatchSize;

            // For synchronous distributed KVStores normalize by batch_size * num_workers.
            if (kvstore != null && kvstore.Type.Contains("dist") && kvstore.Type.Contains("_sync"))
            {
                batch_size *= kvstore.NumWorkers;
            }

            var rescale_grad = 1.0 / batch_size;
            var numDevices   = _context.Length;
            var idx2name     = new Dictionary <int, string>();

            if (update_on_kvstore)
            {
                // Index = position of the parameter.
                var i = 0;
                foreach (var name in _exec_group.ParamNames)
                {
                    idx2name.Add(i, name);
                    i++;
                }
            }
            else
            {
                // Index = parameter position * device count + device index.
                for (var k = 0; k < numDevices; k++)
                {
                    var i = 0;
                    foreach (var name in _exec_group.ParamNames)
                    {
                        idx2name.Add(i * numDevices + k, name);
                        i++;
                    }
                }
            }

            // Relative-tolerance comparison: exact != warns spuriously when RescaleGrad is held at a
            // different precision than rescale_grad (e.g. 1/3 as float vs double). NaN still warns.
            var rescaleDiff = Math.Abs(optimizer.RescaleGrad - rescale_grad);
            if (!(rescaleDiff <= 1e-6 * Math.Abs(rescale_grad)))
            {
                Logger.Warning("Optimizer created manually outside Module but rescale_grad " +
                               $"is not normalized to 1.0/batch_size/num_workers ({optimizer.RescaleGrad} vs. {rescale_grad}). Is this intended?");
            }

            // Only supply the index -> name map when the caller did not provide a non-empty one.
            if (optimizer.Idx2Name == null || optimizer.Idx2Name.Count == 0)
            {
                optimizer.Idx2Name = idx2name;
            }

            _optimizer         = optimizer;
            _kvstore           = kvstore;
            _update_on_kvstore = update_on_kvstore;
            _updater           = null;
            if (kvstore != null)
            {
                if (_compression_params != null)
                {
                    kvstore.SetGradientCompression(_compression_params);
                }

                if (update_on_kvstore)
                {
                    kvstore.SetOptimizer(_optimizer);
                }

                MxModel.InitializeKVStore(kvstore, _exec_group.ParamArrays, _arg_params, _param_names, update_on_kvstore);
            }

            // Without KVStore-side updates, parameters are updated locally through this updater.
            if (!update_on_kvstore)
            {
                _updater = optimizer.GetUpdater();
            }

            OptimizerInitialized = true;

            // Restore optimizer states queued by Module.Load(load_optimizer_states: true), consuming the path once.
            if (!string.IsNullOrWhiteSpace(_preload_opt_states))
            {
                LoadOptimizerStates(_preload_opt_states);
                _preload_opt_states = "";
            }
        }
Exemplo n.º 8
0
 /// <summary>Creates the view model together with a fresh <c>MxModel</c> stored in <c>mxInfo</c>.</summary>
 public MxViewModel() => mxInfo = new MxModel();