コード例 #1
0
        /// <summary>
        /// Command-line entry point. Parses seven positional arguments, echoes each one to the
        /// console, then constructs an <c>LSTMSequence</c> on the CPU device and calls
        /// <c>Train_predict</c> on it.
        /// </summary>
        /// <param name="args">
        /// Positional arguments, in order: siteId, advanced-input flag (exactly "1" enables it),
        /// M, numEpochs, inDim, cellDim, hiDim. Every value except the flag is parsed with
        /// <see cref="Convert.ToInt32(string)"/>, which throws on malformed or out-of-range input.
        /// </param>
        static void Main(string[] args)
        {
            const int expectedArgCount = 7;

            // Guard against missing arguments: indexing args[] below would otherwise fail with
            // an unhelpful IndexOutOfRangeException.
            if (args == null || args.Length < expectedArgCount)
            {
                Console.Error.WriteLine("Usage: <program> <siteId> <advancedInput:0|1> <M> <numEpochs> <inDim> <cellDim> <hiDim>");
                Environment.ExitCode = 1;
                return;
            }

            // Each value is echoed right after it is parsed, so a parse failure on a later
            // argument still shows which arguments were accepted.
            int siteId = Convert.ToInt32(args[0]);

            Console.WriteLine("SiteId = {0}", siteId);

            // Only the exact string "1" turns advanced input on; anything else means off.
            bool advanced_input = (args[1] == "1");

            Console.WriteLine("Advanced input: {0}", advanced_input);

            int M = Convert.ToInt32(args[2]);

            Console.WriteLine("M = {0}", M);

            int numEpochs = Convert.ToInt32(args[3]);

            Console.WriteLine("numEpochs = {0}", numEpochs);

            int inDim = Convert.ToInt32(args[4]);

            Console.WriteLine("inDim = {0}", inDim);

            int cellDim = Convert.ToInt32(args[5]);

            Console.WriteLine("cellDim = {0}", cellDim);

            int hiDim = Convert.ToInt32(args[6]);

            Console.WriteLine("hidim = {0}", hiDim);

            DeviceDescriptor device = DeviceDescriptor.CPUDevice;

            // Report the device instance that is actually handed to the network.
            Console.WriteLine($"======== running LSTMSequence.Train using {device} ========");

            // NOTE(review): `b` is not declared in this method; it is presumably a member of the
            // enclosing class defined outside this view - confirm.
            LSTMSequence network = new LSTMSequence(b, siteId, device, advanced_input);

            network.Train_predict(M, numEpochs, inDim, cellDim, hiDim);
        }
コード例 #2
0
        /// <summary>
        /// Loads the data set, builds the model through <c>CreateModel</c>, trains it against a
        /// squared-error loss with a momentum SGD learner, and finally runs <c>predict_test</c>
        /// and <c>predict</c> on the trained model.
        /// </summary>
        /// <param name="M">Forwarded unchanged to <c>predict_test</c> and <c>predict</c>.</param>
        /// <param name="numEpochs">Number of full passes made over the training split.</param>
        /// <param name="inDim">
        /// Base feature dimension; also passed to <c>loadData</c>. The model input is two entries
        /// wider when <c>advanced_input</c> is set.
        /// </param>
        /// <param name="cellDim">Forwarded to <c>CreateModel</c> as its fourth argument.</param>
        /// <param name="hiDim">Forwarded to <c>CreateModel</c> as its third argument.</param>
        public void Train_predict(int M, int numEpochs = 1500, int inDim = 30, int cellDim = 25, int hiDim = 5)
        {
            string featuresName = "features";
            string labelsName   = "label";

            const int outputDim = 1;

            // Load the data; the result is keyed by the feature and label stream names.
            Dictionary <string, Set> dataSet = loadData(inDim, featuresName, labelsName, fun);

            var featureData = dataSet[featuresName];
            var labelData   = dataSet[labelsName];

            // Width of one input sample: the base dimension plus two extra entries in advanced mode.
            int inputDim = inDim + (advanced_input ? 2 : 0);

            // ---- build the model ----
            var feature = Variable.InputVariable(new int[] { inputDim }, DataType.Float, featuresName, null, false /*isSparse*/);

            var labelAxes = new List <CNTK.Axis>()
            {
                CNTK.Axis.DefaultBatchAxis()
            };
            var label = Variable.InputVariable(new int[] { outputDim }, DataType.Float, labelsName, labelAxes, false);

            var lstmModel = CreateModel(feature, outputDim, hiDim, cellDim, "timeSeriesOutput");

            // The same squared-error measure serves as both the training loss and the evaluation metric.
            Function trainingLoss = CNTKLib.SquaredError(lstmModel, label, "squarederrorLoss");
            Function prediction   = CNTKLib.SquaredError(lstmModel, label, "squarederrorEval");

            // ---- prepare for training ----
            TrainingParameterScheduleDouble learningRatePerSample = new TrainingParameterScheduleDouble(0.0005, 1);
            TrainingParameterScheduleDouble momentumTimeConstant  = CNTKLib.MomentumAsTimeConstantSchedule(256);

            IList <Learner> parameterLearners = new List <Learner>()
            {
                Learner.MomentumSGDLearner(lstmModel.Parameters(), learningRatePerSample, momentumTimeConstant, /*unitGainMomentum = */ true)
            };

            var trainer = Trainer.CreateTrainer(lstmModel, trainingLoss, prediction, parameterLearners);

            // ---- train the model ----
            int batchSize = 20;
            int outputFrequencyInMinibatches = 50;
            int minibatchIndex = 0;

            for (int epoch = 1; epoch <= numEpochs; epoch++)
            {
                // Walk the training split one minibatch at a time.
                foreach (var miniBatchData in LSTMSequence.nextBatch(featureData.train, labelData.train, batchSize))
                {
                    var xValues = Value.CreateBatch <float>(new NDShape(1, inputDim), miniBatchData.X, device);
                    var yValues = Value.CreateBatch <float>(new NDShape(1, outputDim), miniBatchData.Y, device);

                    // Bind each input variable to the data for this minibatch.
                    var batchData = new Dictionary <Variable, Value>
                    {
                        { feature, xValues },
                        { label, yValues }
                    };

                    trainer.TrainMinibatch(batchData, device);

                    // Progress reporting receives the index of the minibatch that was just trained.
                    TestHelper.PrintTrainingProgress(trainer, minibatchIndex, outputFrequencyInMinibatches);
                    minibatchIndex++;
                }
            }

            // ---- evaluate with the trained model ----
            predict_test(dataSet, trainer.Model(), inDim, outputDim, batchSize, featuresName, labelsName, M);
            predict(dataSet, trainer.Model(), inDim, outputDim, batchSize, featuresName, labelsName, M);
        }