Ejemplo n.º 1
0
        /// <summary>
        /// Saves a model checkpoint: the symbol (when given) to <c>{prefix}-symbol.json</c> and the
        /// parameters to <c>{prefix}-{epoch:D4}.params</c>.
        /// </summary>
        /// <param name="prefix">Path prefix shared by the symbol file and the params file.</param>
        /// <param name="epoch">Epoch number; written zero-padded to at least four digits in the params file name.</param>
        /// <param name="symbol">Network symbol to save; skipped when null.</param>
        /// <param name="arg_params">Argument parameters, stored under keys prefixed with "arg:". A null dictionary is treated as empty.</param>
        /// <param name="aux_params">Auxiliary parameters, stored under keys prefixed with "aux:". A null dictionary is treated as empty.</param>
        /// <param name="remove_amp_cast">Forwarded unchanged to <c>Symbol.Save</c>.</param>
        public static void SaveCheckpoint(string prefix, int epoch, Symbol symbol, NDArrayDict arg_params,
                                          NDArrayDict aux_params, bool remove_amp_cast = true)
        {
            if (symbol != null)
            {
                symbol.Save($"{prefix}-symbol.json", remove_amp_cast);
            }

            NDArrayDict save_dict = new NDArrayDict();

            // Tag every entry ("arg:" / "aux:") so the saved file can be split back into the two
            // parameter groups. A null source contributes nothing. Previously it threw a
            // NullReferenceException after the symbol file had already been written.
            void AddTagged(string tag, NDArrayDict source)
            {
                if (source == null)
                {
                    return;
                }

                foreach (var item in source)
                {
                    save_dict.Add($"{tag}:{item.Key}", item.Value);
                }
            }

            AddTagged("arg", arg_params);
            AddTagged("aux", aux_params);

            // Invariant culture keeps the generated file name independent of the machine's locale.
            string param_name = $"{prefix}-{epoch.ToString("D4", System.Globalization.CultureInfo.InvariantCulture)}.params";

            NDArray.Save(param_name, save_dict);
            Logger.Info($"Saved checkpoint to \"{param_name}\"");
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Loads a checkpoint: the symbol from <c>{prefix}-symbol.json</c> and the parameters from
        /// <c>{prefix}-{epoch:D4}.params</c>, split into argument and auxiliary dictionaries.
        /// </summary>
        /// <param name="prefix">Path prefix shared by the symbol file and the params file.</param>
        /// <param name="epoch">Epoch number; zero-padded to at least four digits in the params file name.</param>
        /// <returns>
        /// The loaded symbol, the parameters stored under "arg:" keys, and the parameters stored under
        /// "aux:" keys, each with its leading tag removed. Entries with any other key are logged and skipped.
        /// </returns>
        public static (Symbol, NDArrayDict, NDArrayDict) LoadCheckpoint(string prefix, int epoch)
        {
            const string ArgTag = "arg:";
            const string AuxTag = "aux:";

            Symbol      sym        = Symbol.Load($"{prefix}-symbol.json");
            // Invariant culture keeps the generated file name independent of the machine's locale.
            string      param_name = $"{prefix}-{epoch.ToString("D4", System.Globalization.CultureInfo.InvariantCulture)}.params";
            var         save_dict  = NDArray.Load(param_name);
            NDArrayDict arg_params = new NDArrayDict();
            NDArrayDict aux_params = new NDArrayDict();

            // NOTE(review): a null result is reported as an empty file; presumably NDArray.Load
            // returns null when the file holds no arrays. Confirm against NDArray.Load.
            if (save_dict == null)
            {
                Logger.Warning($"Params file '{param_name}' is empty");
                return (sym, arg_params, aux_params);
            }

            foreach (var item in save_dict)
            {
                string key = item.Key;

                // Strip only the leading tag. string.Replace would also delete any later occurrence
                // of the tag inside the parameter name, which could corrupt the name or make two
                // distinct keys collide. Ordinal comparison keeps the prefix test culture-independent.
                if (key.StartsWith(ArgTag, System.StringComparison.Ordinal))
                {
                    arg_params.Add(key.Substring(ArgTag.Length), item.Value);
                }
                else if (key.StartsWith(AuxTag, System.StringComparison.Ordinal))
                {
                    aux_params.Add(key.Substring(AuxTag.Length), item.Value);
                }
                else
                {
                    Logger.Warning($"Params file '{param_name}' contains unknown param '{item.Key}'");
                }
            }

            return (sym, arg_params, aux_params);
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Builds the MNIST dataset from the four gzipped idx files on the MXNet data mirror,
        /// reading the training split first and then the test split.
        /// </summary>
        /// <returns>
        /// A dictionary holding the entries "train_data", "train_label", "test_data" and "test_label".
        /// </returns>
        public static NDArrayDict GetMNIST()
        {
            const string mirror = "http://data.mxnet.io/data/mnist/";

            // read_data receives (label file, image file, count); its result deconstructs into
            // (labels, images). 60000 / 10000 are presumably the sample counts of each split.
            (var trainLabels, var trainImages) = read_data(
                mirror + "train-labels-idx1-ubyte.gz",
                mirror + "train-images-idx3-ubyte.gz",
                60000);
            (var testLabels, var testImages) = read_data(
                mirror + "t10k-labels-idx1-ubyte.gz",
                mirror + "t10k-images-idx3-ubyte.gz",
                10000);

            NDArrayDict mnist = new NDArrayDict();
            mnist.Add("train_data", trainImages);
            mnist.Add("train_label", trainLabels);
            mnist.Add("test_data", testImages);
            mnist.Add("test_label", testLabels);
            return mnist;
        }