/// <summary>
/// Runs every training example on <paramref name="device"/>, printing a banner before each one.
/// The CIFAR ResNet and transfer-learning examples are only run when the device is a GPU.
/// </summary>
static void RunAllExamples(DeviceDescriptor device)
        {
            var kind = device.Type;

            // Examples that run on any device kind.
            Console.WriteLine($"======== running LogisticRegression.TrainAndEvaluate using {kind} ========");
            LogisticRegression.TrainAndEvaluate(device);

            Console.WriteLine($"======== running SimpleFeedForwardClassifier.TrainSimpleFeedForwardClassifier using {kind} ========");
            SimpleFeedForwardClassifierTest.TrainSimpleFeedForwardClassifier(device);

            // MNIST is trained twice; per the banners, the second argument picks the
            // convolutional network (true) instead of the MLP classifier (false).
            Console.WriteLine($"======== running MNISTClassifier.TrainAndEvaluate using {kind} with MLP classifier ========");
            MNISTClassifier.TrainAndEvaluate(device, false, true);

            Console.WriteLine($"======== running MNISTClassifier.TrainAndEvaluate using {kind} with convolution neural network ========");
            MNISTClassifier.TrainAndEvaluate(device, true, true);

            // GPU-only examples.
            if (kind == DeviceKind.GPU)
            {
                Console.WriteLine($"======== running CifarResNet.TrainAndEvaluate using {kind} ========");
                CifarResNetClassifier.TrainAndEvaluate(device, true);

                Console.WriteLine($"======== running TransferLearning.TrainAndEvaluateWithFlowerData using {kind} ========");
                TransferLearning.TrainAndEvaluateWithFlowerData(device, true);

                Console.WriteLine($"======== running TransferLearning.TrainAndEvaluateWithAnimalData using {kind} ========");
                TransferLearning.TrainAndEvaluateWithAnimalData(device, true);
            }

            Console.WriteLine($"======== running LSTMSequenceClassifier.Train using {kind} ========");
            LSTMSequenceClassifier.Train(device);
        }
예제 #2
0
        /// <summary>
        /// Entry point: selects the CPU as the global device and runs the MNIST classification example.
        /// The other examples are kept below as commented-out calls. Any exception is printed, and the
        /// program waits for Enter before exiting in both the success and failure paths.
        /// </summary>
        static void Main(string[] args)
        {
            try
            {
                // Optional: forward library log output to Logging_OnWriteLog.
                //Logging.OnWriteLog += Logging_OnWriteLog;

                // Global device used by the examples.
                GlobalParameters.Device = DeviceDescriptor.CPUDevice;

                // --- Disabled: XOR example ---
                //XORExample.LoadData();
                //XORExample.BuildModel();
                //XORExample.Train();

                // --- Disabled: housing regression example ---
                //HousingRegression.LoadData();
                //HousingRegression.BuildModel();
                //HousingRegression.Train();

                // --- Active: MNIST classification (load data, build model, train) ---
                MNISTClassifier.LoadData();
                MNISTClassifier.BuildModel();
                MNISTClassifier.Train();

                // --- Disabled: time series prediction ---
                //TimeSeriesPrediction.LoadData();
                //TimeSeriesPrediction.BuildModel();
                //TimeSeriesPrediction.Train();

                // --- Disabled: multivariate time series prediction ---
                //MiltiVariateTimeSeriesPrediction.LoadData();
                //MiltiVariateTimeSeriesPrediction.BuildModel();
                //MiltiVariateTimeSeriesPrediction.Train();

                // --- Disabled: CIFAR-10 classification ---
                //Cifar10Classification.LoadData();
                //Cifar10Classification.BuildModel();
                //Cifar10Classification.Train();

                // --- Disabled: pretrained image classification ---
                //Console.WriteLine("ResNet50 Prediction: " + ImageClassification.ImagenetTest(Common.ImageNetModel.ResNet50)[0].Name);
                //Console.WriteLine("Cifar 10 Prediction: " + ImageClassification.Cifar10Test(Common.Cifar10Model.ResNet110)[0].Name);

                // --- Disabled: object detection ---
                //ObjectDetection.PascalDetection();
                //ObjectDetection.GroceryDetection();

                // Keep the console window open until the user presses Enter.
                Console.ReadLine();
            }
            catch (Exception error)
            {
                // Report the failure (with stack trace) and keep the window open.
                Console.WriteLine(error.ToString());
                Console.ReadLine();
            }
        }
예제 #3
0
        /// <summary>
        /// Entry point: subscribes the log handler, then runs the XOR, housing-regression, MNIST and
        /// LSTM time-series examples in that order. CIFAR-10, image-classification and object-detection
        /// calls are left commented out. Any exception is printed, and the program waits for Enter
        /// before exiting in both the success and failure paths.
        /// </summary>
        static void Main(string[] args)
        {
            try
            {
                // Forward library log output to Logging_OnWriteLog.
                Logging.OnWriteLog += Logging_OnWriteLog;

                // Each example runs the same three steps: load data, build the model, train it.

                // XOR
                XORExample.LoadData();
                XORExample.BuildModel();
                XORExample.Train();

                // Housing regression
                HousingRegression.LoadData();
                HousingRegression.BuildModel();
                HousingRegression.Train();

                // MNIST classification
                MNISTClassifier.LoadData();
                MNISTClassifier.BuildModel();
                MNISTClassifier.Train();

                // LSTM time-series prediction
                TimeSeriesPrediction.LoadData();
                TimeSeriesPrediction.BuildModel();
                TimeSeriesPrediction.Train();

                // --- Disabled: CIFAR-10 classification ---
                //Cifar10Classification.LoadData();
                //Cifar10Classification.BuildModel();
                //Cifar10Classification.Train();

                // --- Disabled: pretrained image classification ---
                //Console.WriteLine("ResNet50 Prediction: " + ImageClassification.ImagenetTest(SiaNet.Common.ImageNetModel.ResNet50)[0].Name);
                //Console.WriteLine("Cifar 10 Prediction: " + ImageClassification.Cifar10Test(SiaNet.Common.Cifar10Model.ResNet110)[0].Name);

                // --- Disabled: object detection ---
                //ObjectDetection.PascalDetection();
                //ObjectDetection.GroceryDetection();

                // Keep the console window open until the user presses Enter.
                Console.ReadLine();
            }
            catch (Exception error)
            {
                // Report the failure (with stack trace) and keep the window open.
                Console.WriteLine(error.ToString());
                Console.ReadLine();
            }
        }
예제 #4
0
파일: Program.cs 프로젝트: prithule/SiaNet
        /// <summary>
        /// Entry point: subscribes the log handler, optionally loads a local CIFAR-10 copy into an
        /// ImageDataFrame, then runs the housing-regression, MNIST and pretrained image-classification
        /// examples. Any exception is printed, and the program waits for Enter before exiting in both
        /// the success and failure paths.
        /// </summary>
        static void Main(string[] args)
        {
            try
            {
                // Forward library log output to Logging_OnWriteLog.
                Logging.OnWriteLog += Logging_OnWriteLog;

                // Developer-local data-loading check. The frame is not used by anything below, and the
                // path is specific to the original author's machine. Loading a missing path would
                // presumably throw and, via the catch below, skip every example that follows, so the
                // load is only attempted when the path exists.
                // NOTE(review): the input variable is declared 28x28x1 (MNIST-sized) while the data is
                // CIFAR-10 (32x32 RGB images) — confirm the intended dimensions.
                const string cifar10Path = @"C:\BDK\Dataset\Cifar10";
                ImageDataFrame frame = new ImageDataFrame(Variable.InputVariable(new int[] { 28, 28, 1 }, DataType.Float), Variable.InputVariable(new int[] { 10 }, DataType.Float));
                //frame.ExtractCifar10();
                if (System.IO.Directory.Exists(cifar10Path) || System.IO.File.Exists(cifar10Path))
                {
                    frame.Load(cifar10Path);
                }
                else
                {
                    Console.WriteLine($"Skipping ImageDataFrame.Load: '{cifar10Path}' was not found.");
                }

                //Housing regression example
                HousingRegression.LoadData();
                HousingRegression.BuildModel();
                HousingRegression.Train();

                //MNIST Classification example
                MNISTClassifier.LoadData();
                MNISTClassifier.BuildModel();
                MNISTClassifier.Train();

                //Cifar - 10 Classification example
                //Cifar10Classification.LoadData();
                //Cifar10Classification.BuildModel();
                //Cifar10Classification.Train();

                //Image classification example: print the first prediction's name from each pretrained model.
                Console.WriteLine("ResNet50 Prediction: " + ImageClassification.ImagenetTest(SiaNet.Common.ImageNetModel.ResNet50)[0].Name);
                Console.WriteLine("Cifar 10 Prediction: " + ImageClassification.Cifar10Test(SiaNet.Common.Cifar10Model.ResNet110)[0].Name);

                //Object Detection
                //ObjectDetection.PascalDetection();
                //ObjectDetection.GroceryDetection();

                // Keep the console window open until the user presses Enter.
                Console.ReadLine();
            }
            catch (Exception ex)
            {
                // Report the failure (with stack trace) and keep the window open.
                Console.WriteLine(ex.ToString());
                Console.ReadLine();
            }
        }
        /// <summary>
        /// Test driver for the CNTK C# training examples.
        /// args[0] (optional) names a single test to run; when absent, every example runs.
        /// args[1] (optional) is the prefix for the test data directories; when absent, "../../" is used.
        /// Each selected test runs once per enabled device (CPU and/or GPU 0).
        /// </summary>
        static void Main(string[] args)
        {
            // Todo: move to a separate unit test.
            Console.WriteLine("Test CNTKLibraryCSTrainingExamples");
#if CPUONLY
            Console.WriteLine("======== Train model using CPUOnly build ========");
#else
            Console.WriteLine("======== Train model using GPU build ========");
#endif
            // Devices to train on: the CPU first (if enabled), then GPU 0 (if enabled).
            var devices = new List<DeviceDescriptor>();
            if (ShouldRunOnCpu())
            {
                devices.Add(DeviceDescriptor.CPUDevice);
            }
            if (ShouldRunOnGpu())
            {
                devices.Add(DeviceDescriptor.GPUDevice(0));
            }

            string runTest = args.Length > 0 ? args[0] : string.Empty;
            bool hasDataPrefix = args.Length > 1;

            if (hasDataPrefix)
            {
                Console.WriteLine($"-------- running with test data prefix : {args[1]} --------");
                TestCommon.TestDataDirPrefix = args[1];
            }
            else
            {
                Console.WriteLine("-------- No data folder path found in input, using default paths.");
                TestCommon.TestDataDirPrefix = "../../";
            }

            foreach (DeviceDescriptor device in devices)
            {
                // Data folders of example classes are set for non-CNTK test runs.
                // In CNTK test runs (runTest names a test), data folders must be set for that test.
                switch (runTest)
                {
                    case "LogisticRegressionTest":
                        Console.WriteLine($"======== running LogisticRegression.TrainAndEvaluate using {device.Type} ========");
                        LogisticRegression.TrainAndEvaluate(device);
                        break;

                    case "SimpleFeedForwardClassifierTest":
                        Console.WriteLine($"======== running SimpleFeedForwardClassifierTest.TrainSimpleFeedForwardClassifier using {device.Type} ========");
                        SimpleFeedForwardClassifierTest.TrainSimpleFeedForwardClassifier(device);
                        break;

                    case "CifarResNetClassifierTest":
                        Console.WriteLine($"======== running CifarResNet.TrainAndEvaluate using {device.Type} ========");
                        if (hasDataPrefix)
                        {
                            Console.WriteLine($"-------- running with test data in {args[1]} --------");
                            // This test reads data from an external folder, so point it at the full data prefix.
                            CifarResNetClassifier.CifarDataFolder = TestCommon.TestDataDirPrefix;
                        }
                        CifarResNetClassifier.TrainAndEvaluate(device, true);
                        break;

                    case "LSTMSequenceClassifierTest":
                        Console.WriteLine($"======== running LSTMSequenceClassifier.Train using {device.Type} ========");
                        LSTMSequenceClassifier.Train(device);
                        break;

                    case "MNISTClassifierTest":
                        Console.WriteLine($"======== running MNISTClassifier.TrainAndEvaluate with Convnet using {device.Type} ========");
                        MNISTClassifier.TrainAndEvaluate(device, true, true);
                        break;

                    case "TransferLearningTest":
                        TransferLearning.BaseResnetModelFile = "ResNet_18.model";
                        Console.WriteLine($"======== running TransferLearning.TrainAndEvaluate with animal data using {device.Type} ========");
                        TransferLearning.TrainAndEvaluateWithAnimalData(device, true);
                        break;

                    case "":
                        // No test name given: run the whole example suite.
                        RunAllExamples(device);
                        break;

                    default:
                        Console.WriteLine("'{0}' is not a valid test name.", runTest);
                        break;
                }
            }

            Console.WriteLine("======== Train completes. ========");
        }
예제 #6
0
파일: Program.cs 프로젝트: roymanoj/CNTK
        /// <summary>
        /// Test driver for the CNTK C# training examples.
        /// args[0] (optional) names a single test to run; when absent or empty, every example runs.
        /// An unrecognized test name is reported instead of silently running the whole suite.
        /// Each selected test runs once per enabled device (CPU and/or GPU 0).
        /// </summary>
        static void Main(string[] args)
        {
            // Todo: move to a separate unit test.
            Console.WriteLine("Test CNTKLibraryCSTrainingExamples");
#if CPUONLY
            Console.WriteLine("======== Train model using CPUOnly build ========");
#else
            Console.WriteLine("======== Train model using GPU build ========");
#endif

            List<DeviceDescriptor> devices = new List<DeviceDescriptor>();
            if (ShouldRunOnCpu())
            {
                devices.Add(DeviceDescriptor.CPUDevice);
            }
            if (ShouldRunOnGpu())
            {
                devices.Add(DeviceDescriptor.GPUDevice(0));
            }

            string runTest = args.Length == 0 ? string.Empty : args[0];

            foreach (var device in devices)
            {
                // Data folders of example classes are set for non-CNTK test runs.
                // In case of CNTK test runs (runTest is set to a test name) data folders need to be set accordingly.
                switch (runTest)
                {
                case "SimpleFeedForwardClassifierTest":
                    SimpleFeedForwardClassifierTest.DataFolder = ".";
                    Console.WriteLine($"======== running SimpleFeedForwardClassifierTest.TrainSimpleFeedForwardClassifier using {device.Type} ========");
                    SimpleFeedForwardClassifierTest.TrainSimpleFeedForwardClassifier(device);
                    break;

                case "CifarResNetClassifierTest":
                    CifarResNetClassifier.CifarDataFolder = "./cifar-10-batches-py";
                    Console.WriteLine($"======== running CifarResNet.TrainAndEvaluate using {device.Type} ========");
                    CifarResNetClassifier.TrainAndEvaluate(device, true);
                    break;

                case "LSTMSequenceClassifierTest":
                    LSTMSequenceClassifier.DataFolder = "../../../Text/SequenceClassification/Data";
                    Console.WriteLine($"======== running LSTMSequenceClassifier.Train using {device.Type} ========");
                    LSTMSequenceClassifier.Train(device);
                    break;

                case "MNISTClassifierTest":
                    MNISTClassifier.ImageDataFolder = "../../../Image/Data/";
                    Console.WriteLine($"======== running MNISTClassifierTest.TrainAndEvaluate with Convnet using {device.Type} ========");
                    MNISTClassifier.TrainAndEvaluate(device, true, true);
                    break;

                case "TransferLearningTest":
                    TransferLearning.ExampleImageFoler   = ".";
                    TransferLearning.BaseResnetModelFile = "ResNet_18.model";
                    Console.WriteLine($"======== running TransferLearning.TrainAndEvaluate with animal data using {device.Type} ========");
                    TransferLearning.TrainAndEvaluateWithAnimalData(device, true);
                    break;

                case "":
                    // No test name given: run the whole example suite.
                    RunAllExamples(device);
                    break;

                default:
                    // Previously any unknown name fell through to RunAllExamples, so a typo silently ran
                    // (and "passed") the entire suite. Report it instead.
                    Console.WriteLine("'{0}' is not a valid test name.", runTest);
                    break;
                }
            }

            Console.WriteLine("======== Train completes. ========");
        }