Example #1
0
        /// <summary>
        /// Round-trips a two-entry <c>Snapshot</c> through <c>ZippedBinaryShapshotSaver</c>
        /// using in-memory streams. It then checks that the restored snapshot has the same keys,
        /// and the same shape and state array for each key, as the original.
        /// </summary>
        /// <remarks>
        /// NOTE(review): on success the serialized bytes are also written to
        /// <c>zipdebug.tss</c> in the current working directory, for offline inspection.
        /// </remarks>
        public void ExecuteTest()
        {
            string[] keys = new string[] { "item1", "item2" };

            var original = new Snapshot();
            original.Append("item1", Shape.Map0D(2, 3), new float[] { 0f, 1f, 2f, 3f, 4f, 5f });
            original.Append("item2", Shape.Map0D(3, 4), new float[] { 5f, 3f, 1f, 0f, 2f, 4f, 1f, 0f, 5f, 3f, 2f, 4f });

            var saver = new ZippedBinaryShapshotSaver();

            // Serialize into memory and capture the raw bytes.
            byte[] bytes;
            using (var buffer = new MemoryStream()) {
                saver.Save(buffer, original);
                bytes = buffer.ToArray();
            }

            // Deserialize from a fresh stream over the captured bytes.
            Snapshot restored;
            using (var buffer = new MemoryStream(bytes)) {
                restored = saver.Load(buffer);
            }

            CollectionAssert.AreEquivalent(keys, restored.Keys.ToArray());

            // Per entry: shape first, then the state values, in key order.
            foreach (string key in keys) {
                Assert.AreEqual(original.Table[key].shape, restored.Table[key].shape);
                CollectionAssert.AreEqual(original.Table[key].state, restored.Table[key].state);
            }

            // Debug dump of the exact serialized payload (raw bytes, no length prefix).
            File.WriteAllBytes("zipdebug.tss", bytes);
        }
Example #2
0
        /// <summary>
        /// End-to-end MNIST example. It fetches the dataset, builds a CNN with accuracy and
        /// softmax cross-entropy outputs, trains it with Nadam plus Ridge regularization,
        /// evaluates accuracy on the test split, and saves the learned parameters as a
        /// zipped binary snapshot at <c>result/mnist.tss</c>.
        /// </summary>
        static void Main()
        {
            const string dirpath_dataset = "mnist_dataset";
            const string dirpath_result  = "result";
            const string filename_result = "mnist.tss";
            const int    classes         = 10;
            const int    num_batches     = 1000;
            const int    seed            = 1234;
            const float  nadam_alpha     = 0.01f;
            const float  ridge_decay     = 1e-4f;

            Console.WriteLine("Download mnist...");
            MnistDownloader.Download(dirpath_dataset);

            Console.WriteLine("Setup loader...");
            // One seeded generator is shared by both shuffle iterators and the He-normal
            // initializer. The random draws therefore depend on the construction order
            // below, so keep that order unchanged.
            Random random = new Random(seed);

            MnistLoader loader         = new MnistLoader(dirpath_dataset, num_batches: num_batches);
            Iterator    train_iterator = new ShuffleIterator(loader.NumBatches, loader.CountTrainDatas, random);
            Iterator    test_iterator  = new ShuffleIterator(loader.NumBatches, loader.CountTestDatas, random);

            Console.WriteLine("Create input tensor...");
            // x holds one input batch. t holds one label per sample (presumably class
            // indices, since it is one-hot encoded below; confirm against MnistLoader).
            VariableField x = new Tensor(loader.BatchShape);
            VariableField t = new Tensor(Shape.Vector(loader.NumBatches));

            Console.WriteLine("Build model...");
            Field      y   = CNN.Forward(x, classes);
            Field      acc = Accuracy(y, t);
            // Softmax cross-entropy against one-hot targets, summed over the channel axis.
            Field      err = Sum(SoftmaxCrossEntropy(y, OneHotVector(t, classes)), axes: new int[] { Axis.Map0D.Channels });
            StoreField accnode = acc.Save(), lossnode = Average(err).Save();

            Console.WriteLine("Build optimize flow...");
            (Flow trainflow, Parameters parameters) = Flow.Optimize(err);

            Console.WriteLine("Initialize params...");
            // Kernels: He-normal (seeded). Biases: zero.
            parameters
            .Where((parameter) => parameter.Category == ParameterCategory.Kernel)
            .InitializeTensor((tensor) => new HeNormal(tensor, random));
            parameters
            .Where((parameter) => parameter.Category == ParameterCategory.Bias)
            .InitializeTensor((tensor) => new Zero(tensor));

            Console.WriteLine("Set params updater...");
            parameters.AddUpdater((parameter) => new Nadam(parameter, alpha: nadam_alpha));
            parameters.AddUpdater((parameter) => new Ridge(parameter, decay: ridge_decay));

            Console.WriteLine("Training...");
            Train(train_iterator, loader, x, t, accnode, lossnode, trainflow, parameters);

            Console.WriteLine("Build inference flow...");
            Flow testflow = Flow.Inference(accnode);

            Console.WriteLine("Testing...");
            Test(test_iterator, loader, x, t, testflow, accnode);

            Console.WriteLine("Saving snapshot...");
            Snapshot      snapshot = parameters.Save();
            SnapshotSaver saver    = new ZippedBinaryShapshotSaver();

            // CreateDirectory does nothing when the directory already exists, so a separate
            // Exists check (and the window between check and create) is unnecessary.
            Directory.CreateDirectory(dirpath_result);
            saver.Save(Path.Combine(dirpath_result, filename_result), snapshot);

            Console.WriteLine("END");
            // Block on console input so the window stays open when launched interactively.
            Console.Read();
        }