/// <summary>
/// Persists a model checkpoint: the optional network definition to
/// "{prefix}-symbol.json" and all parameters to "{prefix}-{epoch:D4}.params".
/// </summary>
/// <param name="prefix">Path prefix shared by the symbol and params files.</param>
/// <param name="epoch">Epoch number, zero-padded to four digits in the params file name.</param>
/// <param name="symbol">Network definition to save; skipped when null.</param>
/// <param name="arg_params">Argument parameters, stored under keys prefixed with "arg:".</param>
/// <param name="aux_params">Auxiliary parameters, stored under keys prefixed with "aux:".</param>
/// <param name="remove_amp_cast">Forwarded to <c>Symbol.Save</c>.</param>
public static void SaveCheckpoint(string prefix, int epoch, Symbol symbol, NDArrayDict arg_params, NDArrayDict aux_params, bool remove_amp_cast = true)
{
    // The symbol file is only written when a network definition was supplied.
    if (symbol != null)
    {
        string symbolFile = $"{prefix}-symbol.json";
        symbol.Save(symbolFile, remove_amp_cast);
    }

    // Both parameter groups share one file; the "arg:"/"aux:" tag lets the loader split them again.
    var combined = new NDArrayDict();
    foreach (var entry in arg_params)
    {
        combined.Add($"arg:{entry.Key}", entry.Value);
    }

    foreach (var entry in aux_params)
    {
        combined.Add($"aux:{entry.Key}", entry.Value);
    }

    string paramsFile = $"{prefix}-{epoch.ToString("D4")}.params";
    NDArray.Save(paramsFile, combined);
    Logger.Info($"Saved checkpoint to \"{paramsFile}\"");
}
/// <summary>
/// Loads a checkpoint written by <c>SaveCheckpoint</c>: the network from
/// "{prefix}-symbol.json" and the parameters from "{prefix}-{epoch:D4}.params".
/// </summary>
/// <param name="prefix">Path prefix shared by the symbol and params files.</param>
/// <param name="epoch">Epoch number, zero-padded to four digits in the params file name.</param>
/// <returns>
/// The loaded symbol, the argument parameters (keys stripped of "arg:"), and the
/// auxiliary parameters (keys stripped of "aux:"). Both dictionaries are empty when the
/// params file yields no data.
/// </returns>
public static (Symbol, NDArrayDict, NDArrayDict) LoadCheckpoint(string prefix, int epoch)
{
    const string ArgPrefix = "arg:";
    const string AuxPrefix = "aux:";

    Symbol sym = Symbol.Load($"{prefix}-symbol.json");
    string param_name = $"{prefix}-{epoch.ToString("D4")}.params";
    var save_dict = NDArray.Load(param_name);
    NDArrayDict arg_params = new NDArrayDict();
    NDArrayDict aux_params = new NDArrayDict();

    if (save_dict == null)
    {
        Logger.Warning($"Params file '{param_name}' is empty");
        return (sym, arg_params, aux_params);
    }

    foreach (var item in save_dict)
    {
        // Keys are machine identifiers: match ordinally, and strip only the leading
        // tag. string.Replace would also remove "arg:"/"aux:" appearing later in the name.
        if (item.Key.StartsWith(ArgPrefix, System.StringComparison.Ordinal))
        {
            arg_params.Add(item.Key.Substring(ArgPrefix.Length), item.Value);
        }
        else if (item.Key.StartsWith(AuxPrefix, System.StringComparison.Ordinal))
        {
            aux_params.Add(item.Key.Substring(AuxPrefix.Length), item.Value);
        }
        else
        {
            Logger.Warning($"Params file '{param_name}' contains unknown param '{item.Key}'");
        }
    }

    return (sym, arg_params, aux_params);
}
/// <summary>
/// Downloads the MNIST training (60000 samples) and test (10000 samples) sets
/// via <c>read_data</c> and returns them keyed as "train_data", "train_label",
/// "test_data" and "test_label".
/// </summary>
public static NDArrayDict GetMNIST()
{
    // All four archives live under this base URL.
    const string baseUrl = "http://data.mxnet.io/data/mnist/";

    var (trainLabels, trainImages) = read_data(
        baseUrl + "train-labels-idx1-ubyte.gz",
        baseUrl + "train-images-idx3-ubyte.gz",
        60000);

    var (testLabels, testImages) = read_data(
        baseUrl + "t10k-labels-idx1-ubyte.gz",
        baseUrl + "t10k-images-idx3-ubyte.gz",
        10000);

    var result = new NDArrayDict();
    result.Add("train_data", trainImages);
    result.Add("train_label", trainLabels);
    result.Add("test_data", testImages);
    result.Add("test_label", testLabels);
    return result;
}