/// <summary>
/// Command-line entry point. Parses the run configuration from <paramref name="args"/>,
/// echoes every parsed value to the console, and runs
/// <c>LSTMSequence.Train_predict</c> on the CPU device.
/// </summary>
/// <param name="args">
/// Seven positional arguments, in order: siteId, advanced-input flag ("1" enables it,
/// anything else disables it), M, numEpochs, inDim, cellDim, hiDim.
/// If fewer are supplied, a usage line is written to stderr and the process exit code is set to 1.
/// </param>
static void Main(string[] args)
{
    const int expectedArgCount = 7;
    // Guard up front so a short command line yields a usage message
    // instead of an IndexOutOfRangeException halfway through parsing.
    if (args == null || args.Length < expectedArgCount)
    {
        Console.Error.WriteLine("Usage: <siteId> <advancedInput:0|1> <M> <numEpochs> <inDim> <cellDim> <hiDim>");
        Environment.ExitCode = 1;
        return;
    }

    int siteId = Convert.ToInt32(args[0]);
    Console.WriteLine("SiteId = {0}", siteId);

    // Only the literal "1" enables advanced input.
    bool advanced_input = (args[1] == "1");
    Console.WriteLine("Advanced input: {0}", advanced_input);

    int M = Convert.ToInt32(args[2]);
    Console.WriteLine("M = {0}", M);

    int numEpochs = Convert.ToInt32(args[3]);
    Console.WriteLine("numEpochs = {0}", numEpochs);

    int inDim = Convert.ToInt32(args[4]);
    Console.WriteLine("inDim = {0}", inDim);

    int cellDim = Convert.ToInt32(args[5]);
    Console.WriteLine("cellDim = {0}", cellDim);

    int hiDim = Convert.ToInt32(args[6]);
    Console.WriteLine("hidim = {0}", hiDim);

    DeviceDescriptor device = DeviceDescriptor.CPUDevice;
    // Report the device actually handed to the network, not a separately looked-up one.
    Console.WriteLine($"======== running LSTMSequence.Train using {device} ========");

    // NOTE(review): `b` is not declared in this method; since Main is static it must be a
    // static member of the enclosing type — confirm what it holds.
    LSTMSequence network = new LSTMSequence(b, siteId, device, advanced_input);
    network.Train_predict(M, numEpochs, inDim, cellDim, hiDim);
}
/// <summary>
/// Builds an LSTM regression model, trains it with momentum SGD on the training split
/// returned by <c>loadData</c>, and then runs <c>predict_test</c> and <c>predict</c>
/// with the trained model.
/// </summary>
/// <param name="M">Forwarded unchanged to <c>predict_test</c> and <c>predict</c>.</param>
/// <param name="numEpochs">Number of times the training split is iterated via <c>nextBatch</c>.</param>
/// <param name="inDim">Base feature dimension; the model input has 2 extra dimensions when <c>advanced_input</c> is set.</param>
/// <param name="cellDim">LSTM cell dimension, forwarded to <c>CreateModel</c>.</param>
/// <param name="hiDim">Hidden dimension, forwarded to <c>CreateModel</c>.</param>
/// <remarks>
/// Reads the instance members <c>device</c>, <c>advanced_input</c> and <c>fun</c>.
/// Training progress is reported through <c>TestHelper.PrintTrainingProgress</c>.
/// </remarks>
public void Train_predict(int M, int numEpochs = 1500, int inDim = 30, int cellDim = 25, int hiDim = 5)
{
    string featuresName = "features";
    string labelsName = "label";
    const int ouDim = 1;

    Dictionary<string, Set> dataSet = loadData(inDim, featuresName, labelsName, fun);
    var featureSet = dataSet[featuresName];
    var labelSet = dataSet[labelsName];

    // Single source of truth for the model's input width: the input variable and the
    // per-minibatch sample shape below must agree, so compute it once.
    int featureDim = inDim + (advanced_input ? 2 : 0);

    // ---- build the model ----
    var feature = Variable.InputVariable(new int[] { featureDim }, DataType.Float, featuresName, null, false /*isSparse*/);
    var label = Variable.InputVariable(new int[] { ouDim }, DataType.Float, labelsName, new List<CNTK.Axis>() { CNTK.Axis.DefaultBatchAxis() }, false);
    var lstmModel = CreateModel(feature, ouDim, hiDim, cellDim, "timeSeriesOutput");
    Function trainingLoss = CNTKLib.SquaredError(lstmModel, label, "squarederrorLoss");
    Function prediction = CNTKLib.SquaredError(lstmModel, label, "squarederrorEval");

    // ---- prepare for training ----
    TrainingParameterScheduleDouble learningRatePerSample = new TrainingParameterScheduleDouble(0.0005, 1);
    TrainingParameterScheduleDouble momentumTimeConstant = CNTKLib.MomentumAsTimeConstantSchedule(256);
    IList<Learner> parameterLearners = new List<Learner>()
    {
        Learner.MomentumSGDLearner(lstmModel.Parameters(), learningRatePerSample, momentumTimeConstant, /*unitGainMomentum = */ true)
    };
    var trainer = Trainer.CreateTrainer(lstmModel, trainingLoss, prediction, parameterLearners);

    // ---- train the model ----
    int batchSize = 20;
    int outputFrequencyInMinibatches = 50;
    int miniBatchCount = 0;

    // Sample shapes do not change between minibatches; build them once.
    var featureShape = new NDShape(1, featureDim);
    var labelShape = new NDShape(1, ouDim);

    for (int epoch = 1; epoch <= numEpochs; epoch++)
    {
        foreach (var miniBatchData in LSTMSequence.nextBatch(featureSet.train, labelSet.train, batchSize))
        {
            var xValues = Value.CreateBatch<float>(featureShape, miniBatchData.X, device);
            var yValues = Value.CreateBatch<float>(labelShape, miniBatchData.Y, device);

            // Bind the minibatch data to the model's input variables.
            var batchData = new Dictionary<Variable, Value>
            {
                { feature, xValues },
                { label, yValues },
            };

            trainer.TrainMinibatch(batchData, device);
            TestHelper.PrintTrainingProgress(trainer, miniBatchCount++, outputFrequencyInMinibatches);
        }
    }

    // ---- evaluate with the trained model ----
    predict_test(dataSet, trainer.Model(), inDim, ouDim, batchSize, featuresName, labelsName, M);
    predict(dataSet, trainer.Model(), inDim, ouDim, batchSize, featuresName, labelsName, M);
}