예제 #1
0
        /// <summary>
        /// Entry point: trains and evaluates the AlexNet model on CIFAR-10, extracting the dataset
        /// archive on first run, and stops early once more than an hour of wall-clock time has elapsed.
        /// </summary>
        /// <param name="args">Command-line arguments (not used).</param>
        static void Main(string[] args)
        {
            // Fixed seed so runs are repeatable.
            Torch.SetSeed(1);

            var device = Torch.IsCudaAvailable() ? Device.CUDA : Device.CPU;
            bool onGpu = device.Type == DeviceType.CUDA;

            // On CUDA, use larger batches and run more epochs than on the CPU.
            if (onGpu)
            {
                _trainBatchSize *= 8;
                _testBatchSize  *= 8;
                _epochs         *= 16;
            }

            Console.WriteLine();
            Console.WriteLine($"\tRunning AlexNet with {_dataset} on {device.Type} for {_epochs} epochs");
            Console.WriteLine();

            string dataRoot = _dataLocation;
            string extractedDir = Path.Combine(dataRoot, "test_data");

            // Extract the CIFAR-10 binary archive only when the target directory does not exist yet.
            if (!Directory.Exists(extractedDir))
            {
                Directory.CreateDirectory(extractedDir);
                string archivePath = Path.Combine(dataRoot, "cifar-10-binary.tar.gz");
                Utils.Decompress.ExtractTGZ(archivePath, extractedDir);
            }

            // Wall-clock budget for the whole training run, in seconds (one hour).
            const double timeLimitSeconds = 3600;

            using (var train = new CIFARReader(extractedDir, false, _trainBatchSize, shuffle: true, device: device))
            using (var test = new CIFARReader(extractedDir, true, _testBatchSize, device: device))
            using (var model = new Model("model", _numClasses, device))
            using (var optimizer = NN.Optimizer.Adam(model.parameters(), 0.001))
            {
                var stopwatch = Stopwatch.StartNew();

                int epoch = 1;
                while (epoch <= _epochs)
                {
                    Train(model, optimizer, nll_loss(), train, epoch, _trainBatchSize, train.Size);
                    Test(model, nll_loss(), test, test.Size);

                    // Force a collection between epochs; presumably to release memory held by
                    // tensors that are no longer referenced (TODO confirm intent).
                    GC.Collect();

                    if (stopwatch.Elapsed.TotalSeconds > timeLimitSeconds)
                    {
                        break;
                    }

                    epoch++;
                }

                stopwatch.Stop();
                Console.WriteLine($"Elapsed time {stopwatch.Elapsed.TotalSeconds} s.");

                // Block until the user presses Enter.
                Console.ReadLine();
            }
        }
예제 #2
0
        /// <summary>Default training time limit, in seconds (one hour).</summary>
        private static readonly int _timeout = 3600;

        /// <summary>
        /// Entry point: trains and evaluates a CIFAR-10 classifier, stopping after the requested
        /// number of epochs or once the time limit has been exceeded (checked after each epoch).
        /// </summary>
        /// <param name="args">
        /// Optional positional arguments:
        /// [0] model name, case-insensitive (default "AlexNet"; one of AlexNet, MobileNet,
        ///     VGG11, VGG13, VGG16, VGG19, ResNet18, ResNet34, ResNet50);
        /// [1] number of epochs (default: the configured epoch count, multiplied by 16 on CUDA);
        /// [2] time limit in seconds (default: <c>_timeout</c>).
        /// </param>
        /// <exception cref="ArgumentException">The model name is not one of the supported architectures.</exception>
        /// <exception cref="FormatException">The epoch count or time limit is not a valid integer.</exception>
        internal static void Main(string[] args)
        {
            torch.random.manual_seed(1);

            var device =
                // This worked on a GeForce RTX 2080 SUPER with 8GB, for all the available network architectures.
                // It may not fit with less memory than that, but it's worth modifying the batch size to fit in memory.
                torch.cuda.is_available() ? torch.CUDA :
                torch.CPU;

            if (device.type == DeviceType.CUDA)
            {
                _trainBatchSize *= 8;
                _testBatchSize  *= 8;
                _epochs         *= 16;
            }

            var modelName = args.Length > 0 ? args[0] : "AlexNet";
            // Parse numeric arguments culture-independently so the command line behaves the same in every locale.
            var epochs    = args.Length > 1 ? int.Parse(args[1], System.Globalization.CultureInfo.InvariantCulture) : _epochs;
            var timeout   = args.Length > 2 ? int.Parse(args[2], System.Globalization.CultureInfo.InvariantCulture) : _timeout;

            Console.WriteLine();
            Console.WriteLine($"\tRunning {modelName} with {_dataset} on {device.type} for {epochs} epochs, terminating after {TimeSpan.FromSeconds(timeout)}.");
            Console.WriteLine();

            var sourceDir = _dataLocation;
            var targetDir = Path.Combine(_dataLocation, "test_data");

            // Extract the CIFAR-10 binary archive only when the target directory does not exist yet.
            if (!Directory.Exists(targetDir))
            {
                Directory.CreateDirectory(targetDir);
                Utils.Decompress.ExtractTGZ(Path.Combine(sourceDir, "cifar-10-binary.tar.gz"), targetDir);
            }

            Console.WriteLine($"\tCreating the model...");

            // Every case assigns or throws, so the compiler enforces that 'model' is set below.
            Module model;

            // ToLowerInvariant, not ToLower: culture-sensitive lowering breaks matching under some
            // locales (e.g. tr-TR lowers "I" to a dotless "ı").
            switch (modelName.ToLowerInvariant())
            {
            case "alexnet":
                model = new AlexNet(modelName, _numClasses, device);
                break;

            case "mobilenet":
                model = new MobileNet(modelName, _numClasses, device);
                break;

            case "vgg11":
            case "vgg13":
            case "vgg16":
            case "vgg19":
                model = new VGG(modelName, _numClasses, device);
                break;

            case "resnet18":
                model = ResNet.ResNet18(_numClasses, device);
                break;

            // Deeper ResNets shrink the batch sizes so they fit in device memory.
            case "resnet34":
                _testBatchSize /= 4;
                model           = ResNet.ResNet34(_numClasses, device);
                break;

            case "resnet50":
                _trainBatchSize /= 6;
                _testBatchSize  /= 8;
                model            = ResNet.ResNet50(_numClasses, device);
                break;

#if false
            // The following is disabled, because they require big CUDA processors in order to run.
            case "resnet101":
                _trainBatchSize /= 6;
                _testBatchSize  /= 8;
                model            = ResNet.ResNet101(_numClasses, device);
                break;

            case "resnet152":
                _testBatchSize /= 4;
                model           = ResNet.ResNet152(_numClasses, device);
                break;
#endif

            default:
                // Previously an unknown name left 'model' null and crashed later at model.parameters().
                throw new ArgumentException(
                    $"Unknown model name '{modelName}'. Supported models: AlexNet, MobileNet, VGG11, VGG13, VGG16, VGG19, ResNet18, ResNet34, ResNet50.",
                    nameof(args));
            }

            Console.WriteLine($"\tPreparing training and test data...");
            Console.WriteLine();

            // 'using (model)' releases the model even if data loading or training throws; disposal
            // order is optimizer, test, train, model. No augmentation is applied: the transforms
            // array is empty (transforms such as torchvision.transforms.HorizontalFlip() could be supplied there).
            using (model)
            using (var train = new CIFARReader(targetDir, false, _trainBatchSize, shuffle: true, device: device, transforms: new torchvision.ITransform[] { }))
            using (var test = new CIFARReader(targetDir, true, _testBatchSize, device: device))
            using (var optimizer = torch.optim.Adam(model.parameters(), 0.001))
            {
                var totalSW = Stopwatch.StartNew();

                for (var epoch = 1; epoch <= epochs; epoch++)
                {
                    var epochSW = Stopwatch.StartNew();

                    Train(model, optimizer, nll_loss(), train.Data(), epoch, _trainBatchSize, train.Size);
                    Test(model, nll_loss(), test.Data(), test.Size);

                    epochSW.Stop();
                    Console.WriteLine($"Elapsed time for this epoch: {epochSW.Elapsed.TotalSeconds} s.");

                    // The time limit is checked between epochs, so a run may overshoot it by up to one epoch.
                    if (totalSW.Elapsed.TotalSeconds > timeout)
                    {
                        break;
                    }
                }

                totalSW.Stop();
                // Report seconds to match the " s." suffix (previously printed a TimeSpan followed by " s.").
                Console.WriteLine($"Elapsed training time: {totalSW.Elapsed.TotalSeconds} s.");
            }
        }