// Entry point: trains and evaluates AlexNet on CIFAR-10, scaling batch sizes
// and epoch count up when a CUDA device is available.
static void Main(string[] args)
{
    // Fixed seed for reproducible runs.
    Torch.SetSeed(1);

    var device = Torch.IsCudaAvailable() ? Device.CUDA : Device.CPU;

    // On CUDA we can afford much larger batches and many more epochs.
    if (device.Type == DeviceType.CUDA)
    {
        _trainBatchSize *= 8;
        _testBatchSize *= 8;
        _epochs *= 16;
    }

    Console.WriteLine();
    Console.WriteLine($"\tRunning AlexNet with {_dataset} on {device.Type} for {_epochs} epochs");
    Console.WriteLine();

    var sourceDir = _dataLocation;
    var targetDir = Path.Combine(_dataLocation, "test_data");

    // Extract the CIFAR-10 archive once; subsequent runs re-use the extracted data.
    if (!Directory.Exists(targetDir))
    {
        Directory.CreateDirectory(targetDir);
        Utils.Decompress.ExtractTGZ(Path.Combine(sourceDir, "cifar-10-binary.tar.gz"), targetDir);
    }

    // Wall-clock budget: stop training after this many seconds even if epochs remain.
    // (Was a bare magic number at the break condition below.)
    const double timeoutSeconds = 3600;

    using (var train = new CIFARReader(targetDir, false, _trainBatchSize, shuffle: true, device: device))
    using (var test = new CIFARReader(targetDir, true, _testBatchSize, device: device))
    using (var model = new Model("model", _numClasses, device))
    using (var optimizer = NN.Optimizer.Adam(model.parameters(), 0.001))
    {
        var sw = new Stopwatch();
        sw.Start();

        for (var epoch = 1; epoch <= _epochs; epoch++)
        {
            Train(model, optimizer, nll_loss(), train, epoch, _trainBatchSize, train.Size);
            Test(model, nll_loss(), test, test.Size);

            // Encourage reclamation of native tensor memory between epochs.
            GC.Collect();

            if (sw.Elapsed.TotalSeconds > timeoutSeconds)
            {
                break;
            }
        }

        sw.Stop();
        Console.WriteLine($"Elapsed time {sw.Elapsed.TotalSeconds} s.");
        Console.ReadLine();
    }
}
// Default wall-clock budget for training, in seconds. One hour by default.
private readonly static int _timeout = 3600;

// Entry point: trains and evaluates the model named on the command line
// (default AlexNet) on CIFAR-10.
// Usage: [model-name] [epochs] [timeout-in-seconds]
internal static void Main(string[] args)
{
    // Fixed seed for reproducible runs.
    torch.random.manual_seed(1);

    var device =
        // This worked on a GeForce RTX 2080 SUPER with 8GB, for all the available network architectures.
        // It may not fit with less memory than that, but it's worth modifying the batch size to fit in memory.
        torch.cuda.is_available() ? torch.CUDA : torch.CPU;

    // On CUDA we can afford much larger batches and many more epochs.
    if (device.type == DeviceType.CUDA)
    {
        _trainBatchSize *= 8;
        _testBatchSize *= 8;
        _epochs *= 16;
    }

    // Optional command-line overrides for model, epoch count, and timeout.
    var modelName = args.Length > 0 ? args[0] : "AlexNet";
    var epochs = args.Length > 1 ? int.Parse(args[1]) : _epochs;
    var timeout = args.Length > 2 ? int.Parse(args[2]) : _timeout;

    Console.WriteLine();
    Console.WriteLine($"\tRunning {modelName} with {_dataset} on {device.type.ToString()} for {epochs} epochs, terminating after {TimeSpan.FromSeconds(timeout)}.");
    Console.WriteLine();

    var sourceDir = _dataLocation;
    var targetDir = Path.Combine(_dataLocation, "test_data");

    // Extract the CIFAR-10 archive once; subsequent runs re-use the extracted data.
    if (!Directory.Exists(targetDir))
    {
        Directory.CreateDirectory(targetDir);
        Utils.Decompress.ExtractTGZ(Path.Combine(sourceDir, "cifar-10-binary.tar.gz"), targetDir);
    }

    Console.WriteLine($"\tCreating the model...");

    // The deeper models need smaller batches to fit in GPU memory, hence the
    // per-model batch-size adjustments below.
    Module model = null;

    // ToLowerInvariant: avoid culture-sensitive casing surprises (e.g. Turkish 'I').
    switch (modelName.ToLowerInvariant())
    {
        case "alexnet":
            model = new AlexNet(modelName, _numClasses, device);
            break;
        case "mobilenet":
            model = new MobileNet(modelName, _numClasses, device);
            break;
        case "vgg11":
        case "vgg13":
        case "vgg16":
        case "vgg19":
            model = new VGG(modelName, _numClasses, device);
            break;
        case "resnet18":
            model = ResNet.ResNet18(_numClasses, device);
            break;
        case "resnet34":
            _testBatchSize /= 4;
            model = ResNet.ResNet34(_numClasses, device);
            break;
        case "resnet50":
            _trainBatchSize /= 6;
            _testBatchSize /= 8;
            model = ResNet.ResNet50(_numClasses, device);
            break;
#if false
        // The following is disabled, because they require big CUDA processors in order to run.
        case "resnet101":
            _trainBatchSize /= 6;
            _testBatchSize /= 8;
            model = ResNet.ResNet101(_numClasses, device);
            break;
        case "resnet152":
            _testBatchSize /= 4;
            model = ResNet.ResNet152(_numClasses, device);
            break;
#endif
        default:
            // FIX: an unrecognized model name previously left 'model' null and the
            // program crashed later with a NullReferenceException at model.parameters().
            throw new ArgumentException($"Unrecognized model name: {modelName}");
    }

    Console.WriteLine($"\tPreparing training and test data...");
    Console.WriteLine();

    // NOTE(review): the original also constructed HorizontalFlip/Grayscale/Rotate/
    // AdjustContrast transforms here but never used (or disposed) them — removed.
    // Add transforms to the array passed to CIFARReader below if augmentation is wanted.

    // FIX: 'model' is now part of the using chain so it is disposed even when an
    // exception escapes (previously a trailing model.Dispose() could be skipped).
    using (model)
    using (var train = new CIFARReader(targetDir, false, _trainBatchSize, shuffle: true, device: device, transforms: new torchvision.ITransform[] { }))
    using (var test = new CIFARReader(targetDir, true, _testBatchSize, device: device))
    using (var optimizer = torch.optim.Adam(model.parameters(), 0.001))
    {
        var totalSW = new Stopwatch();
        totalSW.Start();

        for (var epoch = 1; epoch <= epochs; epoch++)
        {
            var epochSW = new Stopwatch();
            epochSW.Start();

            Train(model, optimizer, nll_loss(), train.Data(), epoch, _trainBatchSize, train.Size);
            Test(model, nll_loss(), test.Data(), test.Size);

            epochSW.Stop();
            Console.WriteLine($"Elapsed time for this epoch: {epochSW.Elapsed.TotalSeconds} s.");

            // Stop early once the wall-clock budget is exhausted.
            if (totalSW.Elapsed.TotalSeconds > timeout)
            {
                break;
            }
        }

        totalSW.Stop();
        // FIX: print TotalSeconds to match the "s." unit in the message
        // (TimeSpan.ToString() formats as hh:mm:ss, not seconds).
        Console.WriteLine($"Elapsed training time: {totalSW.Elapsed.TotalSeconds} s.");
    }
}