static void RunAllExamples(DeviceDescriptor device)
        {
            Console.WriteLine($"======== running LogisticRegression.TrainAndEvaluate using {device.Type} ========");
            LogisticRegression.TrainAndEvaluate(device);

            Console.WriteLine($"======== running SimpleFeedForwardClassifier.TrainSimpleFeedForwardClassifier using {device.Type} ========");
            SimpleFeedForwardClassifierTest.TrainSimpleFeedForwardClassifier(device);

            Console.WriteLine($"======== running MNISTClassifier.TrainAndEvaluate using {device.Type} with MLP classifier ========");
            MNISTClassifier.TrainAndEvaluate(device, false, true);

            Console.WriteLine($"======== running MNISTClassifier.TrainAndEvaluate using {device.Type} with convolution neural network ========");
            MNISTClassifier.TrainAndEvaluate(device, true, true);

            if (device.Type == DeviceKind.GPU)
            {
                Console.WriteLine($"======== running CifarResNet.TrainAndEvaluate using {device.Type} ========");
                CifarResNetClassifier.TrainAndEvaluate(device, true);

                Console.WriteLine($"======== running TransferLearning.TrainAndEvaluateWithFlowerData using {device.Type} ========");
                TransferLearning.TrainAndEvaluateWithFlowerData(device, true);

                Console.WriteLine($"======== running TransferLearning.TrainAndEvaluateWithAnimalData using {device.Type} ========");
                TransferLearning.TrainAndEvaluateWithAnimalData(device, true);
            }

            Console.WriteLine($"======== running LSTMSequenceClassifier.Train using {device.Type} ========");
            LSTMSequenceClassifier.Train(device);
        }
Example #2
        public void lstm_test01()
        {
            //define values and variables
            Variable x       = Variable.InputVariable(new int[] { 4 }, DataType.Float, "input");
            var      xValues = Value.CreateBatchOfSequences <float>(new int[] { 4 }, mData, device);

            //lstm implementation, reference 00
            var lstm00 = RNN.RecurrenceLSTM(x, 3, 3, DataType.Float, device, false, Activation.TanH, true, true, 1);

            //
            LSTMReccurentNN lstmNN = new LSTMReccurentNN(1, 1, device);
            //lstm implementation, reference 01
            var lstmCell = lstmNN.CreateLSTM(x, "output1");
            var lstm01   = CNTKLib.SequenceLast(lstmCell.h);

            //lstm implementation, reference 02
            var lstm02 = LSTMSequenceClassifier.LSTMNet(x, 1, device, "output1");

            //collect the parameter inputs of each implementation
            var wParams00 = lstm00.Inputs.Where(p => p.Uid.Contains("Parameter")).ToList();
            var wParams01 = lstm01.Inputs.Where(p => p.Uid.Contains("Parameter")).ToList();
            var wParams02 = lstm02.Inputs.Where(p => p.Uid.Contains("Parameter")).ToList();

            //parameter count
            Assert.Equal(wParams00.Count, wParams01.Count);
            Assert.Equal(wParams00.Count, wParams02.Count);

            //structure of parameters test
            Assert.Equal(wParams00.Where(p => p.Name.Contains("_b")).Count(), wParams01.Where(p => p.Name.Contains("_b")).Count());
            Assert.Equal(wParams00.Where(p => p.Name.Contains("_w")).Count(), wParams01.Where(p => p.Name.Contains("_w")).Count());
            Assert.Equal(wParams00.Where(p => p.Name.Contains("_u")).Count(), wParams01.Where(p => p.Name.Contains("_u")).Count());
            Assert.Equal(wParams00.Where(p => p.Name.Contains("peep")).Count(), wParams01.Where(p => p.Name.Contains("peep")).Count());
            Assert.Equal(wParams00.Where(p => p.Name.Contains("stabilize")).Count(), wParams01.Where(p => p.Name.Contains("stabilize")).Count());


            //check structure of parameters against the originally developed lstm
            //check arguments
            Assert.True(lstm01.Arguments.Count == lstm02.Arguments.Count);
            for (int i = 0; i < lstm01.Arguments.Count; i++)
            {
                testVariable(lstm01.Arguments[i], lstm02.Arguments[i]);
            }

            //check inputs
            Assert.True(lstm01.Inputs.Count == lstm02.Inputs.Count);
            for (int i = 0; i < lstm01.Inputs.Count; i++)
            {
                testVariable(lstm01.Inputs[i], lstm02.Inputs[i]);
            }

            //check outputs
            Assert.True(lstm01.Outputs.Count == lstm02.Outputs.Count);
            for (int i = 0; i < lstm01.Outputs.Count; i++)
            {
                testVariable(lstm01.Outputs[i], lstm02.Outputs[i]);
            }
        }
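
The testVariable helper used in the loops above is not included in this listing. A minimal sketch of what such a comparison could look like, assuming it only verifies that two CNTK variables agree on shape, data type, and variable kind (the checks below are an assumption, not the original implementation):

        static void testVariable(Variable expected, Variable actual)
        {
            //hypothetical helper, assumed for illustration: compares two CNTK Variables
            //on the structural properties the assertions above rely on
            Assert.Equal(expected.Shape.TotalSize, actual.Shape.TotalSize);
            Assert.Equal(expected.DataType, actual.DataType);
            Assert.Equal(expected.IsParameter, actual.IsParameter);
            Assert.Equal(expected.IsConstant, actual.IsConstant);
            Assert.Equal(expected.IsInput, actual.IsInput);
        }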
        static void Main(string[] args)
        {
            // Todo: move to a separate unit test.
            Console.WriteLine("Test CNTKLibraryCSTrainingExamples");
#if CPUONLY
            Console.WriteLine("======== Train model using CPUOnly build ========");
#else
            Console.WriteLine("======== Train model using GPU build ========");
#endif
            List <DeviceDescriptor> devices = new List <DeviceDescriptor>();
            if (ShouldRunOnCpu())
            {
                devices.Add(DeviceDescriptor.CPUDevice);
            }
            if (ShouldRunOnGpu())
            {
                devices.Add(DeviceDescriptor.GPUDevice(0));
            }

            string runTest = args.Length == 0 ? string.Empty : args[0];

            if (args.Length > 1)
            {
                Console.WriteLine($"-------- running with test data prefix : {args[1]} --------");
                TestCommon.TestDataDirPrefix = args[1];
            }
            else
            {
                Console.WriteLine("-------- No data folder path found in input, using default paths.");
                TestCommon.TestDataDirPrefix = "../../";
            }

            foreach (var device in devices)
            {
                // Data folders of example classes are set for non-CNTK test runs.
                // In case of CNTK test runs (runTest is set to a test name) data folders need to be set accordingly.
                switch (runTest)
                {
                case "LogisticRegressionTest":
                    Console.WriteLine($"======== running LogisticRegression.TrainAndEvaluate using {device.Type} ========");
                    LogisticRegression.TrainAndEvaluate(device);
                    break;

                case "SimpleFeedForwardClassifierTest":
                    Console.WriteLine($"======== running SimpleFeedForwardClassifierTest.TrainSimpleFeedForwardClassifier using {device.Type} ========");
                    SimpleFeedForwardClassifierTest.TrainSimpleFeedForwardClassifier(device);
                    break;

                case "CifarResNetClassifierTest":
                    Console.WriteLine($"======== running CifarResNet.TrainAndEvaluate using {device.Type} ========");

                    if (args.Length > 1)
                    {
                        Console.WriteLine($"-------- running with test data in {args[1]} --------");
                        // this test uses data from external folder, we execute this test with full data dir.
                        CifarResNetClassifier.CifarDataFolder = TestCommon.TestDataDirPrefix;
                    }

                    CifarResNetClassifier.TrainAndEvaluate(device, true);
                    break;

                case "LSTMSequenceClassifierTest":
                    Console.WriteLine($"======== running LSTMSequenceClassifier.Train using {device.Type} ========");
                    LSTMSequenceClassifier.Train(device);
                    break;

                case "MNISTClassifierTest":
                    Console.WriteLine($"======== running MNISTClassifier.TrainAndEvaluate with Convnet using {device.Type} ========");
                    MNISTClassifier.TrainAndEvaluate(device, true, true);
                    break;

                case "TransferLearningTest":
                    TransferLearning.BaseResnetModelFile = "ResNet_18.model";
                    Console.WriteLine($"======== running TransferLearning.TrainAndEvaluate with animal data using {device.Type} ========");
                    TransferLearning.TrainAndEvaluateWithAnimalData(device, true);
                    break;

                case "":
                    RunAllExamples(device);
                    break;

                default:
                    Console.WriteLine("'{0}' is not a valid test name.", runTest);
                    break;
                }
            }

            Console.WriteLine("======== Train completes. ========");
        }
Example #4
        public void LSTM_Test_Params_Count_with_peep_selfstabilize()
        {
            //define values and variables
            Variable x = Variable.InputVariable(new int[] { 2 }, DataType.Float, "input");
            Variable y = Variable.InputVariable(new int[] { 3 }, DataType.Float, "output");

            #region lstm originally implemented in cntk, for reference
            //lstm implementation, reference 02
            var lstmTest02 = LSTMSequenceClassifier.LSTMNet(x, 3, device, "output1");
            var ft2        = lstmTest02.Inputs.Where(l => l.Uid.StartsWith("Parameter")).ToList();
            var totalSize  = ft2.Sum(p => p.Shape.TotalSize);
            //bias params
            var bs2      = ft2.Where(p => p.Name.Contains("_b")).ToList();
            var totalBs2 = bs2.Sum(v => v.Shape.TotalSize);

            //weights
            var ws2      = ft2.Where(p => p.Name.Contains("_w")).ToList();
            var totalWs2 = ws2.Sum(v => v.Shape.TotalSize);

            //update
            var us2      = ft2.Where(p => p.Name.Contains("_u")).ToList();
            var totalUs2 = us2.Sum(v => v.Shape.TotalSize);

            //peephole
            var ph2      = ft2.Where(p => p.Name.Contains("_peep")).ToList();
            var totalph2 = ph2.Sum(v => v.Shape.TotalSize);

            //stabilize
            var st2      = ft2.Where(p => p.Name.Contains("_stabilize")).ToList();
            var totalst2 = st2.Sum(v => v.Shape.TotalSize);
            #endregion

            #region anndotnet old implementation
            //
            //LSTMReccurentNN lstmNN = new LSTMReccurentNN(3, 3, device);
            ////lstm implme reference 01
            //var lstmCell11 = lstmNN.CreateLSTM(x, "output1");
            //var lstmTest01 = CNTKLib.SequenceLast(lstmCell11.h);
            //var ft1 = lstmTest01.Inputs.Where(l => l.Uid.StartsWith("Parameter")).ToList();
            //var consts1 = lstmTest01.Inputs.Where(l => l.Uid.StartsWith("Constant")).ToList();
            //var inp1 = lstmTest01.Inputs.Where(l => l.Uid.StartsWith("Input")).ToList();
            //var pparams1 = ft1.Sum(v => v.Shape.TotalSize);
            #endregion

            //Number of LSTM parameters
            var lstm1 = RNN.RecurrenceLSTM(x, 3, 3, DataType.Float, device, false, Activation.TanH, true, true, 1);

            var ft     = lstm1.Inputs.Where(l => l.Uid.StartsWith("Parameter")).ToList();
            var consts = lstm1.Inputs.Where(l => l.Uid.StartsWith("Constant")).ToList();
            var inp    = lstm1.Inputs.Where(l => l.Uid.StartsWith("Input")).ToList();

            //bias params
            var bs      = ft.Where(p => p.Name.Contains("_b")).ToList();
            var totalBs = bs.Sum(v => v.Shape.TotalSize);
            Assert.Equal(12, totalBs);
            //weights
            var ws      = ft.Where(p => p.Name.Contains("_w")).ToList();
            var totalWs = ws.Sum(v => v.Shape.TotalSize);
            Assert.Equal(24, totalWs);
            //update
            var us      = ft.Where(p => p.Name.Contains("_u")).ToList();
            var totalUs = us.Sum(v => v.Shape.TotalSize);
            Assert.Equal(36, totalUs);
            //peephole
            var ph      = ft.Where(p => p.Name.Contains("_peep")).ToList();
            var totalPh = ph.Sum(v => v.Shape.TotalSize);
            Assert.Equal(9, totalPh);
            //stabilize
            var st      = ft.Where(p => p.Name.Contains("_stabilize")).ToList();
            var totalst = st.Sum(v => v.Shape.TotalSize);
            Assert.Equal(6, totalst);

            var totalOnly          = totalBs + totalWs + totalUs;
            var totalWithSTabilize = totalOnly + totalst;
            var totalWithPeep      = totalOnly + totalPh;

            var totalP      = totalOnly + totalst + totalPh;
            var totalParams = ft.Sum(v => v.Shape.TotalSize);
            Assert.Equal(totalP, totalParams);
            //72 - without peephole and without stabilization
            //75 - without peephole, with stabilization (+3 x m)
            //81 - with peephole, without stabilization
            //87 - with peephole and stabilization (3 + 9)
        }
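
The totals asserted above follow from the standard LSTM parameterization this test assumes: for input dimension n = 2 and cell/output dimension m = 3, the four gates contribute 4*m bias values, 4*m*n input weights, and 4*m*m recurrent weights, plus 3*m peephole parameters when peepholes are enabled. A small illustrative sketch of that bookkeeping (the helper below is for clarity only and is not part of the library):

        static (int bias, int w, int u, int peephole) ExpectedLstmParamCounts(int n, int m)
        {
            int bias = 4 * m;     //one bias vector per gate (i, f, c, o)     -> 12 for m = 3
            int w    = 4 * m * n; //input-to-gate weights                     -> 24 for n = 2, m = 3
            int u    = 4 * m * m; //recurrent (hidden-to-gate) weights        -> 36 for m = 3
            int peep = 3 * m;     //peephole connections on the i, f, o gates ->  9 for m = 3
            return (bias, w, u, peep);
        }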
Example #5
File: Program.cs Project: roymanoj/CNTK
        static void Main(string[] args)
        {
            // Todo: move to a separate unit test.
            Console.WriteLine("Test CNTKLibraryCSTrainingExamples");
#if CPUONLY
            Console.WriteLine("======== Train model using CPUOnly build ========");
#else
            Console.WriteLine("======== Train model using GPU build ========");
#endif

            List <DeviceDescriptor> devices = new List <DeviceDescriptor>();
            if (ShouldRunOnCpu())
            {
                devices.Add(DeviceDescriptor.CPUDevice);
            }
            if (ShouldRunOnGpu())
            {
                devices.Add(DeviceDescriptor.GPUDevice(0));
            }

            string runTest = args.Length == 0 ? string.Empty : args[0];


            foreach (var device in devices)
            {
                // Data folders of example classes are set for non-CNTK test runs.
                // In case of CNTK test runs (runTest is set to a test name) data folders need to be set accordingly.
                switch (runTest)
                {
                case "SimpleFeedForwardClassifierTest":
                    SimpleFeedForwardClassifierTest.DataFolder = ".";
                    Console.WriteLine($"======== running SimpleFeedForwardClassifierTest.TrainSimpleFeedForwardClassifier using {device.Type} ========");
                    SimpleFeedForwardClassifierTest.TrainSimpleFeedForwardClassifier(device);
                    break;

                case "CifarResNetClassifierTest":
                    CifarResNetClassifier.CifarDataFolder = "./cifar-10-batches-py";
                    Console.WriteLine($"======== running CifarResNet.TrainAndEvaluate using {device.Type} ========");
                    CifarResNetClassifier.TrainAndEvaluate(device, true);
                    break;

                case "LSTMSequenceClassifierTest":
                    LSTMSequenceClassifier.DataFolder = "../../../Text/SequenceClassification/Data";
                    Console.WriteLine($"======== running LSTMSequenceClassifier.Train using {device.Type} ========");
                    LSTMSequenceClassifier.Train(device);
                    break;

                case "MNISTClassifierTest":
                    MNISTClassifier.ImageDataFolder = "../../../Image/Data/";
                    Console.WriteLine($"======== running MNISTClassifierTest.TrainAndEvaluate with Convnet using {device.Type} ========");
                    MNISTClassifier.TrainAndEvaluate(device, true, true);
                    break;

                case "TransferLearningTest":
                    TransferLearning.ExampleImageFoler   = ".";
                    TransferLearning.BaseResnetModelFile = "ResNet_18.model";
                    Console.WriteLine($"======== running TransferLearning.TrainAndEvaluate with animal data using {device.Type} ========");
                    TransferLearning.TrainAndEvaluateWithAnimalData(device, true);
                    break;

                default:
                    RunAllExamples(device);
                    break;
                }
            }

            Console.WriteLine("======== Train completes. ========");
        }
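
ShouldRunOnCpu and ShouldRunOnGpu are called in both Main methods but are not shown in this listing. A minimal sketch of how such helpers could be implemented, assuming the CPUONLY build symbol and the CNTK DeviceDescriptor API (the original project may decide this differently, e.g. via an environment variable):

        static bool ShouldRunOnCpu()
        {
            //assumption: the CPU device is always exercised
            return true;
        }

        static bool ShouldRunOnGpu()
        {
#if CPUONLY
            return false;
#else
            //assumption: run on GPU only when the CNTK runtime reports at least one GPU device
            return DeviceDescriptor.AllDevices().Any(d => d.Type == DeviceKind.GPU);
#endif
        }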