Esempio n. 1
0
        /// <summary>
        /// Runs the training loop. For each batch drawn from <paramref name="train_iterator"/> it loads the
        /// images/labels into <paramref name="x"/>/<paramref name="t"/>, executes <paramref name="trainflow"/>
        /// and updates <paramref name="parameters"/>. Accuracy and loss are printed every
        /// <paramref name="log_interval"/> iterations, and the total elapsed time is printed at the end.
        /// </summary>
        /// <param name="train_iterator">Supplies the batch passed to <c>loader.GetTrain</c>; training stops once its <c>Epoch</c> reaches <paramref name="epochs"/>.</param>
        /// <param name="loader">Source of the training images and labels.</param>
        /// <param name="x">Field whose value tensor receives the batch images.</param>
        /// <param name="t">Field whose value tensor receives the batch labels.</param>
        /// <param name="accnode">Stored accuracy output; element 0 is printed.</param>
        /// <param name="lossnode">Stored loss output; element 0 is printed.</param>
        /// <param name="trainflow">Flow executed once per batch.</param>
        /// <param name="parameters">Parameters updated after every flow execution.</param>
        /// <param name="epochs">Number of epochs to train; defaults to 50, the previously hard-coded value.</param>
        /// <param name="log_interval">Progress is printed when <c>Iteration</c> is a multiple of this value; defaults to 100.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="log_interval"/> is zero or negative.</exception>
        static void Train(Iterator train_iterator, MnistLoader loader, VariableField x, VariableField t, StoreField accnode, StoreField lossnode, Flow trainflow, Parameters parameters, int epochs = 50, int log_interval = 100)
        {
            // A zero interval would make the modulo below throw DivideByZeroException; negative values make no sense.
            if (log_interval <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(log_interval), log_interval, "log_interval must be positive.");
            }

            Stopwatch sw = Stopwatch.StartNew();

            while (train_iterator.Epoch < epochs)
            {
                (float[] images, float[] labels) = loader.GetTrain(train_iterator.Next());

                x.ValueTensor.State = images;
                t.ValueTensor.State = labels;

                trainflow.Execute(enable_processing_time: true);
                parameters.Update();

                // Only report progress every log_interval iterations.
                if (train_iterator.Iteration % log_interval != 0)
                {
                    continue;
                }

                float train_acc  = accnode.State[0];
                float train_loss = lossnode.State[0];

                Console.WriteLine($"[{train_iterator.Iteration}] train acc: {train_acc:F3} train loss: {train_loss:E3}");
            }

            sw.Stop();

            long elapsed = sw.ElapsedMilliseconds;

            // Epoch can still be 0 here when epochs <= 0 and the loop never ran; avoid dividing by zero.
            if (train_iterator.Epoch > 0)
            {
                Console.WriteLine($"{elapsed} msec, {elapsed / train_iterator.Epoch} msec/epoch");
            }
            else
            {
                Console.WriteLine($"{elapsed} msec");
            }
        }
Esempio n. 2
0
        /// <summary>
        /// Evaluates the model on the test split. For each batch drawn from <paramref name="test_iterator"/>
        /// it loads images/labels, executes <paramref name="testflow"/> and records <c>accnode.State[0]</c>.
        /// It then prints the elapsed time and the mean of the recorded per-batch accuracies.
        /// </summary>
        /// <param name="test_iterator">Supplies the batch passed to <c>loader.GetTest</c>; evaluation stops once its <c>Epoch</c> reaches <paramref name="epochs"/>.</param>
        /// <param name="loader">Source of the test images and labels.</param>
        /// <param name="x">Field whose value tensor receives the batch images.</param>
        /// <param name="t">Field whose value tensor receives the batch labels.</param>
        /// <param name="testflow">Flow executed once per batch.</param>
        /// <param name="accnode">Stored accuracy output; element 0 is recorded per batch.</param>
        /// <param name="epochs">Number of passes over the test split; defaults to 1, the previously hard-coded value.</param>
        static void Test(Iterator test_iterator, MnistLoader loader, VariableField x, VariableField t, Flow testflow, StoreField accnode, int epochs = 1)
        {
            var test_acc_list = new List<float>();

            Stopwatch sw = Stopwatch.StartNew();

            while (test_iterator.Epoch < epochs)
            {
                (float[] images, float[] labels) = loader.GetTest(test_iterator.Next());

                x.ValueTensor.State = images;
                t.ValueTensor.State = labels;

                testflow.Execute();

                test_acc_list.Add(accnode.State[0]);
            }

            sw.Stop();

            Console.WriteLine($"{sw.ElapsedMilliseconds} msec");

            // Enumerable.Average throws InvalidOperationException on an empty sequence. That happens when
            // the iterator is already at or past the requested epoch count on entry.
            if (test_acc_list.Count == 0)
            {
                Console.WriteLine("[Result] acc: n/a (no test batches evaluated)");
                return;
            }

            // NOTE(review): the mean of per-batch accuracies equals overall accuracy only if all batches
            // are the same size; presumably they are (fixed num_batches) — confirm.
            Console.WriteLine($"[Result] acc: {test_acc_list.Average():F5}");
        }
Esempio n. 3
0
        /// <summary>
        /// Console entry point. Builds a 784-100-30-10 dense network and trains it on the MNIST training
        /// set in an endless loop. It prints a sample prediction every 1000 samples and, after each epoch,
        /// shows a smoothed error and the training throughput in the console title.
        /// </summary>
        /// <param name="args">Unused.</param>
        private static void Main(string[] args)
        {
            Console.Clear();

            // Time the processing-device selection.
            var watch = Stopwatch.StartNew();
            ProcessingDevice.Device = DeviceType.CPU_Parallel;
            watch.Stop();
            Console.WriteLine($"Device Time: {watch.ElapsedMilliseconds}ms");

            // Time the model construction. Previously the stopwatch was restarted and stopped only
            // after the model had been built, so this always reported ~0ms.
            watch = Stopwatch.StartNew();

            var model = new DenseModel();
            model.AddLayer(BuildedModels.DenseLeakRelu(784, 100, 1e-4f, 2e-1f, EnumOptimizerFunction.RmsProp));
            model.AddLayer(BuildedModels.DenseLeakRelu(100, 30, 1e-4f, 2e-1f, EnumOptimizerFunction.RmsProp));
            model.AddLayer(BuildedModels.DenseSoftMax(30, 10, 1e-4f, 2e-1f, EnumOptimizerFunction.RmsProp));
            model.SetLossFunction(new CrossEntropyLossFunction());

            watch.Stop();
            Console.WriteLine($"Sinapse Time: {watch.ElapsedMilliseconds}ms");

            // NOTE(review): `path` is not declared in this method; presumably a static field of the enclosing class.
            MnistLoader.DataPath = path;
            var trainingValues = MnistLoader.OpenMnist();
            int sizeTrain = trainingValues.Count;

            // Smoothed error shown in the console title; starts at 100.
            var err = 100f;

            while (true)
            {
                watch = Stopwatch.StartNew();
                var e = 0f; // summed loss over this epoch

                for (int i = 0; i < sizeTrain; i++)
                {
                    var sample = trainingValues[i];
                    var inputs = new FloatArray(ArrayMethods.ByteToArray(sample.pixels, 28, 28));
                    var target = new FloatArray(ArrayMethods.ByteToArray(sample.label, 10));

                    e += model.Learn(inputs, target);

                    // Print the sample and the current network output every 1000 samples.
                    if (i % 1000 == 0)
                    {
                        Write(sample.ToString());
                        Write(ArrayMethods.PrintArray(model.Output(inputs), 10));
                        Write("\n");
                    }
                }

                // Exponential moving average across epochs (weight 0.001 on the newest epoch).
                // NOTE(review): e is the summed, not mean, epoch loss — confirm this scale is intended.
                err = 0.999f * err + 0.001f * e;

                watch.Stop();
                var time = watch.ElapsedMilliseconds;
                // Throughput in samples per second: sizeTrain / (time / 1000).
                Console.Title =
                    $"Error: {err} --- TSPS (Training Sample per Second): {Math.Ceiling(1000d / ((double)time / (double)sizeTrain))}";
            }
        }
Esempio n. 4
0
        /// <summary>
        /// End-to-end MNIST example. Downloads the dataset, builds a CNN classifier with accuracy and
        /// loss outputs, trains it, evaluates it on the test split, and writes the learned parameters
        /// to <c>result/mnist.tss</c>.
        /// </summary>
        static void Main()
        {
            const string datasetDir = "mnist_dataset";
            const string resultDir  = "result";
            const int    numClasses = 10;

            Console.WriteLine("Download mnist...");
            MnistDownloader.Download(datasetDir);

            Console.WriteLine("Setup loader...");
            // One seeded generator is handed to both iterators and to the kernel initializer below.
            // Keep the order of these calls unchanged so runs stay reproducible.
            var rng = new Random(1234);

            var      loader        = new MnistLoader(datasetDir, num_batches: 1000);
            Iterator trainIterator = new ShuffleIterator(loader.NumBatches, loader.CountTrainDatas, rng);
            Iterator testIterator  = new ShuffleIterator(loader.NumBatches, loader.CountTestDatas, rng);

            Console.WriteLine("Create input tensor...");
            VariableField input = new Tensor(loader.BatchShape);
            VariableField label = new Tensor(Shape.Vector(loader.NumBatches));

            Console.WriteLine("Build model...");
            Field output   = CNN.Forward(input, numClasses);
            Field accuracy = Accuracy(output, label);
            Field loss     = Sum(SoftmaxCrossEntropy(output, OneHotVector(label, numClasses)), axes: new int[] { Axis.Map0D.Channels });

            StoreField accStore  = accuracy.Save();
            StoreField lossStore = Average(loss).Save();

            Console.WriteLine("Build optimize flow...");
            (Flow trainFlow, Parameters parameters) = Flow.Optimize(loss);

            Console.WriteLine("Initialize params...");
            parameters
                .Where(p => p.Category == ParameterCategory.Kernel)
                .InitializeTensor(tensor => new HeNormal(tensor, rng));
            parameters
                .Where(p => p.Category == ParameterCategory.Bias)
                .InitializeTensor(tensor => new Zero(tensor));

            Console.WriteLine("Set params updater...");
            parameters.AddUpdater(p => new Nadam(p, alpha: 0.01f));
            parameters.AddUpdater(p => new Ridge(p, decay: 1e-4f));

            Console.WriteLine("Training...");
            Train(trainIterator, loader, input, label, accStore, lossStore, trainFlow, parameters);

            Console.WriteLine("Build inference flow...");
            Flow testFlow = Flow.Inference(accStore);

            Console.WriteLine("Testing...");
            Test(testIterator, loader, input, label, testFlow, accStore);

            Console.WriteLine("Saving snapshot...");
            Snapshot      snapshot = parameters.Save();
            SnapshotSaver saver    = new ZippedBinaryShapshotSaver();

            // CreateDirectory does nothing when the directory already exists, so no Exists check is needed.
            Directory.CreateDirectory(resultDir);
            saver.Save($"{resultDir}/mnist.tss", snapshot);

            Console.WriteLine("END");
            Console.Read();
        }