static void RunAllExamples(DeviceDescriptor device)
{
    // Examples that are run on every device kind.
    Console.WriteLine($"======== running LogisticRegression.TrainAndEvaluate using {device.Type} ========");
    LogisticRegression.TrainAndEvaluate(device);

    Console.WriteLine($"======== running SimpleFeedForwardClassifier.TrainSimpleFeedForwardClassifier using {device.Type} ========");
    SimpleFeedForwardClassifierTest.TrainSimpleFeedForwardClassifier(device);

    // MNIST first with the MLP classifier (false), then with the convolutional network (true).
    Console.WriteLine($"======== running MNISTClassifier.TrainAndEvaluate using {device.Type} with MLP classifier ========");
    MNISTClassifier.TrainAndEvaluate(device, false, true);

    Console.WriteLine($"======== running MNISTClassifier.TrainAndEvaluate using {device.Type} with convolution neural network ========");
    MNISTClassifier.TrainAndEvaluate(device, true, true);

    // The CIFAR ResNet and transfer-learning examples are only run on a GPU device.
    if (device.Type == DeviceKind.GPU)
    {
        Console.WriteLine($"======== running CifarResNet.TrainAndEvaluate using {device.Type} ========");
        CifarResNetClassifier.TrainAndEvaluate(device, true);

        Console.WriteLine($"======== running TransferLearning.TrainAndEvaluateWithFlowerData using {device.Type} ========");
        TransferLearning.TrainAndEvaluateWithFlowerData(device, true);

        Console.WriteLine($"======== running TransferLearning.TrainAndEvaluateWithAnimalData using {device.Type} ========");
        TransferLearning.TrainAndEvaluateWithAnimalData(device, true);
    }

    // Sequence classification runs last, on every device kind.
    Console.WriteLine($"======== running LSTMSequenceClassifier.Train using {device.Type} ========");
    LSTMSequenceClassifier.Train(device);
}
static void Main(string[] args)
{
    // Todo: move to a separate unit test.
    Console.WriteLine("Test CNTKLibraryCSTrainingExamples");
#if CPUONLY
    Console.WriteLine("======== Train model using CPUOnly build ========");
#else
    Console.WriteLine("======== Train model using GPU build ========");
#endif

    // Devices to train on, in order: CPU (if enabled), then GPU 0 (if enabled).
    var targetDevices = new List<DeviceDescriptor>();
    if (ShouldRunOnCpu())
    {
        targetDevices.Add(DeviceDescriptor.CPUDevice);
    }
    if (ShouldRunOnGpu())
    {
        targetDevices.Add(DeviceDescriptor.GPUDevice(0));
    }

    // args[0] (optional): name of a single test to run; an empty name runs every example.
    // args[1] (optional): prefix used for the test data directories.
    string testName = args.Length > 0 ? args[0] : string.Empty;
    bool hasDataPrefix = args.Length > 1;

    if (!hasDataPrefix)
    {
        Console.WriteLine("-------- No data folder path found in input, using default paths.");
        TestCommon.TestDataDirPrefix = "../../";
    }
    else
    {
        string dataPrefix = args[1];
        Console.WriteLine($"-------- running with test data prefix : {dataPrefix} --------");
        TestCommon.TestDataDirPrefix = dataPrefix;
    }

    foreach (DeviceDescriptor device in targetDevices)
    {
        // Outside CNTK test runs the example classes use their own data folders.
        // CNTK test runs pass a test name, and any folder overrides that test needs are applied here.
        switch (testName)
        {
            case "":
                RunAllExamples(device);
                break;

            case "LogisticRegressionTest":
                Console.WriteLine($"======== running LogisticRegression.TrainAndEvaluate using {device.Type} ========");
                LogisticRegression.TrainAndEvaluate(device);
                break;

            case "SimpleFeedForwardClassifierTest":
                Console.WriteLine($"======== running SimpleFeedForwardClassifierTest.TrainSimpleFeedForwardClassifier using {device.Type} ========");
                SimpleFeedForwardClassifierTest.TrainSimpleFeedForwardClassifier(device);
                break;

            case "CifarResNetClassifierTest":
                Console.WriteLine($"======== running CifarResNet.TrainAndEvaluate using {device.Type} ========");
                if (hasDataPrefix)
                {
                    // This test reads its data from an external folder, so point it at the full data dir.
                    Console.WriteLine($"-------- running with test data in {args[1]} --------");
                    CifarResNetClassifier.CifarDataFolder = TestCommon.TestDataDirPrefix;
                }
                CifarResNetClassifier.TrainAndEvaluate(device, true);
                break;

            case "LSTMSequenceClassifierTest":
                Console.WriteLine($"======== running LSTMSequenceClassifier.Train using {device.Type} ========");
                LSTMSequenceClassifier.Train(device);
                break;

            case "MNISTClassifierTest":
                Console.WriteLine($"======== running MNISTClassifier.TrainAndEvaluate with Convnet using {device.Type} ========");
                MNISTClassifier.TrainAndEvaluate(device, true, true);
                break;

            case "TransferLearningTest":
                TransferLearning.BaseResnetModelFile = "ResNet_18.model";
                Console.WriteLine($"======== running TransferLearning.TrainAndEvaluate with animal data using {device.Type} ========");
                TransferLearning.TrainAndEvaluateWithAnimalData(device, true);
                break;

            default:
                Console.WriteLine("'{0}' is not a valid test name.", testName);
                break;
        }
    }

    Console.WriteLine("======== Train completes. ========");
}
static void Main(string[] args)
{
    // Todo: move to a separate unit test.
    Console.WriteLine("Test CNTKLibraryCSTrainingExamples");
#if CPUONLY
    Console.WriteLine("======== Train model using CPUOnly build ========");
#else
    Console.WriteLine("======== Train model using GPU build ========");
#endif

    // Devices to train on, in order: CPU (if ShouldRunOnCpu), then GPU 0 (if ShouldRunOnGpu).
    List<DeviceDescriptor> devices = new List<DeviceDescriptor>();
    if (ShouldRunOnCpu())
    {
        devices.Add(DeviceDescriptor.CPUDevice);
    }
    if (ShouldRunOnGpu())
    {
        devices.Add(DeviceDescriptor.GPUDevice(0));
    }

    // args[0] (optional) names a single CNTK test to run; without it every example runs.
    string runTest = args.Length == 0 ? string.Empty : args[0];

    foreach (var device in devices)
    {
        // Data folders of example classes are set for non-CNTK test runs.
        // In case of CNTK test runs (runTest is set to a test name) data folders need to be set accordingly.
        switch (runTest)
        {
            case "SimpleFeedForwardClassifierTest":
                SimpleFeedForwardClassifierTest.DataFolder = ".";
                Console.WriteLine($"======== running SimpleFeedForwardClassifierTest.TrainSimpleFeedForwardClassifier using {device.Type} ========");
                SimpleFeedForwardClassifierTest.TrainSimpleFeedForwardClassifier(device);
                break;

            case "CifarResNetClassifierTest":
                CifarResNetClassifier.CifarDataFolder = "./cifar-10-batches-py";
                Console.WriteLine($"======== running CifarResNet.TrainAndEvaluate using {device.Type} ========");
                CifarResNetClassifier.TrainAndEvaluate(device, true);
                break;

            case "LSTMSequenceClassifierTest":
                LSTMSequenceClassifier.DataFolder = "../../../Text/SequenceClassification/Data";
                Console.WriteLine($"======== running LSTMSequenceClassifier.Train using {device.Type} ========");
                LSTMSequenceClassifier.Train(device);
                break;

            case "MNISTClassifierTest":
                MNISTClassifier.ImageDataFolder = "../../../Image/Data/";
                // The banner names the class whose method actually runs (MNISTClassifier).
                Console.WriteLine($"======== running MNISTClassifier.TrainAndEvaluate with Convnet using {device.Type} ========");
                MNISTClassifier.TrainAndEvaluate(device, true, true);
                break;

            case "TransferLearningTest":
                TransferLearning.ExampleImageFoler = ".";
                TransferLearning.BaseResnetModelFile = "ResNet_18.model";
                Console.WriteLine($"======== running TransferLearning.TrainAndEvaluate with animal data using {device.Type} ========");
                TransferLearning.TrainAndEvaluateWithAnimalData(device, true);
                break;

            default:
                // No test name, or a name not listed above: run every example on this device.
                // NOTE(review): a misspelled test name silently falls through to the full run — confirm intended.
                RunAllExamples(device);
                break;
        }
    }

    Console.WriteLine("======== Train completes. ========");
}