/// <summary>
/// Epoch-end callback that writes a model checkpoint every <c>_period</c> completed epochs.
/// </summary>
/// <param name="epoch">Zero-based index of the epoch that just finished.</param>
/// <param name="symbol">Network symbol to store in the checkpoint.</param>
/// <param name="arg_params">Argument parameters to store.</param>
/// <param name="aux_params">Auxiliary parameters to store.</param>
public void Invoke(int epoch, Symbol symbol, NDArrayDict arg_params, NDArrayDict aux_params)
{
    // A period below 1 is normalised to 1 (checkpoint every epoch); the field keeps the clamped value.
    _period = Math.Max(1, _period);

    // Checkpoints are numbered by completed-epoch count, i.e. one past the zero-based index.
    var completedEpochs = epoch + 1;
    if (completedEpochs % _period != 0)
    {
        return;
    }

    MxModel.SaveCheckpoint(_prefix, completedEpochs, symbol, arg_params, aux_params);
}
/// <summary>
/// Saves a checkpoint for a network built from RNN cells, first passing the argument parameters
/// through each cell's <c>UnpackWeights</c> in array order.
/// </summary>
/// <param name="cells">Cells whose weights are unpacked before saving.</param>
/// <param name="prefix">Checkpoint path prefix.</param>
/// <param name="epoch">Epoch number used for the checkpoint file name.</param>
/// <param name="symbol">Network symbol to store.</param>
/// <param name="arg_params">Argument parameters (packed form).</param>
/// <param name="aux_params">Auxiliary parameters, stored unchanged.</param>
public static void SaveRNNCheckPoint(BaseRNNCell[] cells, string prefix, int epoch, Symbol symbol, NDArrayDict arg_params, NDArrayDict aux_params)
{
    var unpacked = arg_params;
    for (var idx = 0; idx < cells.Length; idx++)
    {
        // Each cell receives the output of the previous one.
        unpacked = cells[idx].UnpackWeights(unpacked);
    }

    MxModel.SaveCheckpoint(prefix, epoch, symbol, unpacked, aux_params);
}
/// <summary>
/// Loads a checkpoint written by <c>SaveRNNCheckPoint</c> and passes the argument parameters
/// through each cell's <c>PackWeights</c> in array order.
/// </summary>
/// <param name="cells">Cells whose weights are packed after loading.</param>
/// <param name="prefix">Checkpoint path prefix.</param>
/// <param name="epoch">Epoch number of the checkpoint to load.</param>
/// <returns>The loaded symbol, the packed argument parameters and the auxiliary parameters.</returns>
public static (Symbol, NDArrayDict, NDArrayDict) LoadRNNCheckPoint(BaseRNNCell[] cells, string prefix, int epoch)
{
    var loaded = MxModel.LoadCheckpoint(prefix, epoch);
    var symbol = loaded.Item1;
    var packed = loaded.Item2;
    var auxiliary = loaded.Item3;

    for (var idx = 0; idx < cells.Length; idx++)
    {
        // Each cell receives the output of the previous one.
        packed = cells[idx].PackWeights(packed);
    }

    return (symbol, packed, auxiliary);
}
/// <summary>
/// Creates a <see cref="Module"/> from a checkpoint loaded with <c>MxModel.LoadCheckpoint</c>.
/// The loaded argument/auxiliary parameters are attached and the module is marked as having
/// initialized parameters.
/// </summary>
/// <param name="prefix">Checkpoint path prefix.</param>
/// <param name="epoch">Epoch number of the checkpoint to load.</param>
/// <param name="load_optimizer_states">When true, the path <c>{prefix}-{epoch:D4}.states</c> is recorded in
/// <c>_preload_opt_states</c> so the optimizer state can be loaded once the optimizer is initialized.</param>
/// <param name="data_names">Forwarded to the <see cref="Module"/> constructor.</param>
/// <param name="label_names">Forwarded to the <see cref="Module"/> constructor.</param>
/// <param name="logging">Forwarded to the <see cref="Module"/> constructor.</param>
/// <param name="context">Single device context; forwarded as a one-element context array when not null.</param>
/// <param name="work_load_list">Forwarded to the <see cref="Module"/> constructor.</param>
/// <param name="fixed_param_names">Forwarded to the <see cref="Module"/> constructor.</param>
/// <returns>The module with loaded parameters (not yet bound).</returns>
public static Module Load(string prefix, int epoch, bool load_optimizer_states = false, string[] data_names = null, string[] label_names = null, Logger logging = null, Context context = null, int[] work_load_list = null, string[] fixed_param_names = null)
{
    var (sym, args, auxs) = MxModel.LoadCheckpoint(prefix, epoch);

    // Previously every optional argument was accepted and then dropped (the module was built with
    // `new Module(sym)`). Forward them instead; a null reference-type argument is identical to
    // omitting it, so callers relying on the defaults see no change.
    var mod = new Module(
        symbol: sym,
        data_names: data_names,
        label_names: label_names,
        logging: logging,
        context: context != null ? new[] { context } : null,
        work_load_list: work_load_list,
        fixed_param_names: fixed_param_names);

    mod._arg_params = args;
    mod._aux_params = auxs;
    mod.ParamsInitialized = true;

    if (load_optimizer_states)
    {
        mod._preload_opt_states = $"{prefix}-{epoch.ToString("D4")}.states";
    }

    return mod;
}
/// <summary>
/// Applies one parameter update using the executor group's current parameter and gradient arrays,
/// either on the kvstore or locally through <c>_updater</c>.
/// </summary>
/// <exception cref="InvalidOperationException">
/// Thrown unless the module is bound, its parameters are initialized AND its optimizer is initialized.
/// </exception>
public override void Update()
{
    // Every precondition is required. The previous `&&` only rejected a module that had none of them,
    // letting e.g. a bound module without an optimizer reach UpdateParams with a null _updater.
    if (!Binded || !ParamsInitialized || !OptimizerInitialized)
    {
        throw new InvalidOperationException("Module not binded or param initialized or optimizer initialized");
    }

    // Parameters on the devices are about to change; flag them so a later sync
    // (e.g. SyncParamsFromDevices) refreshes the host-side copies.
    _params_dirty = true;

    if (_update_on_kvstore == true)
    {
        // The optimizer was handed to the kvstore during InitOptimizer, so the kvstore performs the update.
        MxModel.UpdateParamsOnKVStore(_exec_group.ParamArrays, _exec_group.GradArrays, _kvstore, _exec_group.ParamNames);
    }
    else
    {
        // Local update through _updater; the kvstore (possibly null) is passed along as well.
        MxModel.UpdateParams(_exec_group.ParamArrays, _exec_group.GradArrays, _updater, _context.Length, _kvstore, _exec_group.ParamNames);
    }
}
/// <summary>
/// Loads a checkpoint and returns a module bound for inference on a single
/// <c>data</c> input of shape (1, 3, 224, 224) with the checkpoint parameters applied.
/// </summary>
/// <param name="prefix">Checkpoint path prefix.</param>
/// <param name="epoch">Epoch number of the checkpoint to load.</param>
/// <param name="gpu">When true the module runs on GPU 0; otherwise the Module's default context is used.</param>
/// <returns>The bound module.</returns>
private static Module LoadModel(string prefix, int epoch = 0, bool gpu = true)
{
    var (sym, arg_params, aux_params) = MxModel.LoadCheckpoint(prefix, epoch);

    // Empty placeholder arrays for the label inputs.
    // NOTE(review): presumably so SetParams finds an entry for every symbol argument; confirm.
    foreach (var labelName in new[] { "prob_label", "softmax_label" })
    {
        arg_params[labelName] = new NDArray(new float[0]);
    }

    var dataNames = new string[] { "data" };
    var mod = gpu
        ? new Module(symbol: sym, context: new[] { mx.Gpu(0) }, data_names: dataNames)
        : new Module(symbol: sym, data_names: dataNames);

    var inputDesc = new DataDesc("data", new Shape(1, 3, 224, 224));
    mod.Bind(for_training: false, data_shapes: new[] { inputDesc });
    mod.SetParams(arg_params, aux_params);

    return mod;
}
/// <summary>
/// Creates the kvstore and installs the optimizer (on the kvstore or as a local updater),
/// then loads any optimizer state recorded in <c>_preload_opt_states</c>.
/// </summary>
/// <param name="kv">KVStore type passed to <c>MxModel.CreateKVStore</c>.</param>
/// <param name="optimizer">Optimizer to use; a default <c>SGD</c> is created when null.</param>
/// <param name="optimizer_params">
/// NOTE(review): currently unused by this implementation; kept for API compatibility.
/// </param>
/// <param name="force_init">Re-initialize even if an optimizer is already installed.</param>
/// <exception cref="InvalidOperationException">Thrown unless the module is bound and its parameters are initialized.</exception>
public override void InitOptimizer(string kv = "local", Optimizer optimizer = null, Dictionary <string, object> optimizer_params = null, bool force_init = false)
{
    // Both preconditions are required; the previous `&&` only rejected a module that had neither.
    if (!Binded || !ParamsInitialized)
    {
        throw new InvalidOperationException("Module not binded and param initialized");
    }

    if (OptimizerInitialized && !force_init)
    {
        Logger.Warning("optimizer already initialized, ignoring...");
        return;
    }

    if (optimizer == null)
    {
        optimizer = new SGD();
    }

    // Bring _arg_params up to date before the kvstore is created/seeded from them.
    if (_params_dirty)
    {
        SyncParamsFromDevices();
    }

    var (kvstore, update_on_kvstore) = MxModel.CreateKVStore(kv, _context.Length, _arg_params);

    // Synchronous distributed training averages over the global batch of all workers.
    var batch_size = _exec_group.BatchSize;
    if (kvstore != null && kvstore.Type.Contains("dist") && kvstore.Type.Contains("_sync"))
    {
        batch_size *= kvstore.NumWorkers;
    }

    var rescale_grad = 1.0 / batch_size;

    // Index -> parameter-name map for the optimizer. With local updates each device gets its own
    // index per parameter, laid out as paramIndex * deviceCount + deviceIndex.
    var idx2name = new Dictionary <int, string>();
    if (update_on_kvstore)
    {
        var i = 0;
        foreach (var name in _exec_group.ParamNames)
        {
            idx2name.Add(i, name);
            i++;
        }
    }
    else
    {
        for (var k = 0; k < _context.Length; k++)
        {
            var i = 0;
            foreach (var name in _exec_group.ParamNames)
            {
                idx2name.Add(i * _context.Length + k, name);
                i++;
            }
        }
    }

    // Relative tolerance: an exact `!=` spuriously fires when the optimizer stores the factor in
    // lower precision than the double computed here (e.g. 1/3 as float vs. double).
    if (Math.Abs(optimizer.RescaleGrad - rescale_grad) > 1e-6 * Math.Abs(rescale_grad))
    {
        // The message was previously split by a raw newline inside the string literal, which does not compile.
        Logger.Warning("Optimizer created manually outside Module but rescale_grad " +
                       $"is not normalized to 1.0/batch_size/num_workers ({optimizer.RescaleGrad} vs. {rescale_grad}). " +
                       "Is this intended?");
    }

    if (optimizer.Idx2Name == null)
    {
        optimizer.Idx2Name = idx2name;
    }

    _optimizer = optimizer;
    _kvstore = kvstore;
    _update_on_kvstore = update_on_kvstore;
    _updater = null;

    if (kvstore != null)
    {
        if (_compression_params != null)
        {
            kvstore.SetGradientCompression(_compression_params);
        }

        if (update_on_kvstore)
        {
            kvstore.SetOptimizer(_optimizer);
        }

        MxModel.InitializeKVStore(kvstore, _exec_group.ParamArrays, _arg_params, _param_names, update_on_kvstore);
    }

    // Without kvstore-side updates, parameters are updated locally through this updater.
    if (!update_on_kvstore)
    {
        _updater = optimizer.GetUpdater();
    }

    OptimizerInitialized = true;

    // Optimizer state requested earlier (e.g. by Load with load_optimizer_states) is applied once, then cleared.
    if (!string.IsNullOrWhiteSpace(_preload_opt_states))
    {
        LoadOptimizerStates(_preload_opt_states);
        _preload_opt_states = "";
    }
}
/// <summary>
/// Creates the view model with a fresh <see cref="MxModel"/> instance stored in <c>mxInfo</c>.
/// </summary>
public MxViewModel() => mxInfo = new MxModel();