/// <summary>
/// Round-trips a two-entry <c>Snapshot</c> through <c>ZippedBinaryShapshotSaver</c>
/// (save to a memory stream, load from the captured bytes). Asserts that the key set,
/// each entry's shape, and each entry's state survive unchanged. Finally, writes the
/// serialized bytes to <c>zipdebug.tss</c> in the current working directory.
/// </summary>
public void ExecuteTest()
{
    // Two entries with different shapes; element counts match the shapes (2*3 and 3*4).
    var original = new Snapshot();
    original.Append("item1", Shape.Map0D(2, 3), new float[] { 0f, 1f, 2f, 3f, 4f, 5f });
    original.Append("item2", Shape.Map0D(3, 4), new float[] { 5f, 3f, 1f, 0f, 2f, 4f, 1f, 0f, 5f, 3f, 2f, 4f });

    var saver = new ZippedBinaryShapshotSaver();

    // Serialize into an in-memory buffer and capture the resulting bytes.
    byte[] serialized;
    using (var buffer = new MemoryStream())
    {
        saver.Save(buffer, original);
        serialized = buffer.ToArray();
    }

    // Deserialize from a fresh stream positioned at the start of the captured bytes.
    Snapshot restored;
    using (var buffer = new MemoryStream(serialized))
    {
        restored = saver.Load(buffer);
    }

    CollectionAssert.AreEquivalent(new string[] { "item1", "item2" }, restored.Keys.ToArray());

    // Per-entry comparison, item1 first then item2, shape before state.
    foreach (string key in new[] { "item1", "item2" })
    {
        Assert.AreEqual(original.Table[key].shape, restored.Table[key].shape);
        CollectionAssert.AreEqual(original.Table[key].state, restored.Table[key].state);
    }

    // Persist the raw serialized bytes (presumably kept for manual debugging of the format).
    File.WriteAllBytes("zipdebug.tss", serialized);
}
/// <summary>
/// End-to-end MNIST example.
/// <list type="bullet">
/// <item>Downloads the dataset into <c>mnist_dataset</c>.</item>
/// <item>Builds the model from <c>CNN.Forward</c>, with a softmax-cross-entropy loss summed over the channel axis.</item>
/// <item>Initializes kernels with HeNormal and biases with zeros.</item>
/// <item>Attaches Nadam and Ridge updaters and runs <c>Train</c>, then runs <c>Test</c> with an inference flow over the accuracy node.</item>
/// <item>Saves the parameters to <c>result/mnist.tss</c>.</item>
/// </list>
/// Waits for console input before returning.
/// </summary>
static void Main()
{
    const string datasetDir = "mnist_dataset";
    const string resultDir = "result";
    const int classCount = 10;

    // Progress messages go straight to the console.
    void Report(string message) => Console.WriteLine(message);

    Report("Download mnist...");
    MnistDownloader.Download(datasetDir);

    Report("Setup loader...");
    // One seeded Random is shared by both iterators and the HeNormal initializer below;
    // NOTE(review): they presumably draw from it, so statement order here is kept as-is.
    var random = new Random(1234);
    var loader = new MnistLoader(datasetDir, num_batches: 1000);
    Iterator trainIterator = new ShuffleIterator(loader.NumBatches, loader.CountTrainDatas, random);
    Iterator testIterator = new ShuffleIterator(loader.NumBatches, loader.CountTestDatas, random);

    Report("Create input tensor...");
    // Declared as VariableField (not var) so the same field object is used everywhere below.
    VariableField input = new Tensor(loader.BatchShape);
    VariableField label = new Tensor(Shape.Vector(loader.NumBatches));

    Report("Build model...");
    Field output = CNN.Forward(input, classCount);
    Field accuracy = Accuracy(output, label);
    Field loss = Sum(
        SoftmaxCrossEntropy(output, OneHotVector(label, classCount)),
        axes: new int[] { Axis.Map0D.Channels });
    StoreField accuracyNode = accuracy.Save();
    StoreField lossNode = Average(loss).Save();

    Report("Build optimize flow...");
    (Flow trainFlow, Parameters parameters) = Flow.Optimize(loss);

    Report("Initialize params...");
    parameters
        .Where(p => p.Category == ParameterCategory.Kernel)
        .InitializeTensor(tensor => new HeNormal(tensor, random));
    parameters
        .Where(p => p.Category == ParameterCategory.Bias)
        .InitializeTensor(tensor => new Zero(tensor));

    Report("Set params updater...");
    parameters.AddUpdater(p => new Nadam(p, alpha: 0.01f));
    parameters.AddUpdater(p => new Ridge(p, decay: 1e-4f));

    Report("Training...");
    Train(trainIterator, loader, input, label, accuracyNode, lossNode, trainFlow, parameters);

    Report("Build inference flow...");
    Flow testFlow = Flow.Inference(accuracyNode);

    Report("Testing...");
    Test(testIterator, loader, input, label, testFlow, accuracyNode);

    Report("Saving snapshot...");
    Snapshot snapshot = parameters.Save();
    SnapshotSaver saver = new ZippedBinaryShapshotSaver();
    if (!Directory.Exists(resultDir))
    {
        Directory.CreateDirectory(resultDir);
    }
    saver.Save($"{resultDir}/mnist.tss", snapshot);

    Report("END");
    // Keep the console window open until a key is read.
    Console.Read();
}