/// <summary>
/// Builds a new wrapper from an existing <see cref="PairSizeTDouble"/> by handing
/// the source object's native handle to <c>new_PairSizeTDouble__SWIG_2</c>.
/// </summary>
/// <param name="p">The pair to construct from; its handle is obtained through <c>getCPtr</c>.</param>
/// <remarks>
/// If the native call left a pending SWIG exception, that exception is retrieved and thrown.
/// NOTE(review): the <c>true</c> argument forwarded to the other constructor is presumably
/// the "owns native memory" flag - confirm against that constructor.
/// </remarks>
public PairSizeTDouble(PairSizeTDouble p)
    : this(CNTKLibPINVOKE.new_PairSizeTDouble__SWIG_2(PairSizeTDouble.getCPtr(p)), true)
{
    // Nothing more to do when the native layer reported no error.
    if (!CNTKLibPINVOKE.SWIGPendingException.Pending)
    {
        return;
    }

    throw CNTKLibPINVOKE.SWIGPendingException.Retrieve();
}
/// <summary>
/// Returns the native handle held by <paramref name="obj"/>.
/// </summary>
/// <param name="obj">The wrapper whose handle is requested; may be null.</param>
/// <returns>
/// <c>obj.swigCPtr</c> when <paramref name="obj"/> is not null; otherwise a
/// <c>HandleRef</c> with a null wrapper and <c>IntPtr.Zero</c>.
/// </returns>
internal static global::System.Runtime.InteropServices.HandleRef getCPtr(PairSizeTDouble obj)
{
    // A null wrapper maps to an empty handle rather than an exception.
    if (obj == null)
    {
        return new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero);
    }

    return obj.swigCPtr;
}
/// <summary>
/// Train and evaluate an image classifier with CIFAR-10 data.
/// The classification model is saved after training.
/// For repeated runs, the caller may choose whether to retrain a model or
/// just validate an existing one.
/// </summary>
/// <remarks>
/// After training and validation, a one-line summary (elapsed seconds and the value
/// returned by <c>ValidateModel</c>) is printed to the console and written to the first
/// unused <c>../../../../log_N.txt</c> file (N starting at 1). The elapsed time covers
/// training, saving the model and validation.
/// </remarks>
/// <param name="device">CPU or GPU device to run</param>
/// <param name="forceRetrain">whether to override an existing model.
/// if true, any existing model will be overridden and the new one evaluated.
/// if false and there is an existing model, the existing model is evaluated.</param>
public static void TrainAndEvaluate(DeviceDescriptor device, bool forceRetrain)
{
    string modelFile = Path.Combine(CifarDataFolder, "CNTK-CSharp.model");

    // If a model already exists and not set to force retrain, validate the model and return.
    if (File.Exists(modelFile) && !forceRetrain)
    {
        ValidateModel(device, modelFile);
        return;
    }

    // prepare training data
    var minibatchSource = CreateMinibatchSource(
        Path.Combine(CifarDataFolder, "train_map.txt"),
        Path.Combine(CifarDataFolder, "CIFAR-10_mean.xml"),
        imageDim, numClasses, MaxEpochs);
    var imageStreamInfo = minibatchSource.StreamInfo("features");
    var labelStreamInfo = minibatchSource.StreamInfo("labels");

    // build a model
    var imageInput = CNTKLib.InputVariable(imageDim, imageStreamInfo.m_elementType, "Images");
    var labelsVar = CNTKLib.InputVariable(new int[] { numClasses }, labelStreamInfo.m_elementType, "Labels");
    var classifierOutput = ResNetClassifier(imageInput, numClasses, device, "classifierOutput");

    // prepare for training
    var trainingLoss = CNTKLib.CrossEntropyWithSoftmax(classifierOutput, labelsVar, "lossFunction");
    // NOTE(review): the literal 3 is presumably the top-N argument of ClassificationError - confirm.
    var prediction = CNTKLib.ClassificationError(classifierOutput, labelsVar, 3, "predictionError");

    // Learning rate policy.
    // Each schedule entry pairs a count with a learning rate; with the counts below the
    // rate is meant to change when the epoch reaches 80, 120, 160 and 180.
    // NOTE(review): assumes the count is expressed in units of samplesPerEpoch samples - confirm.
    double[] lrs = { 3e-2, 3e-3, 3e-4, 3e-4, 5e-5 }; // learning rates
    uint minibatchSize = 32;
    VectorPairSizeTDouble vp = new VectorPairSizeTDouble()
    {
        new PairSizeTDouble(80, lrs[0]),
        new PairSizeTDouble(40, lrs[1]),
        new PairSizeTDouble(40, lrs[2]),
        new PairSizeTDouble(20, lrs[3]),
        new PairSizeTDouble(20, lrs[4])
    };
    uint samplesPerEpoch = 50000;
    TrainingParameterScheduleDouble learningRateSchedule = new TrainingParameterScheduleDouble(vp, samplesPerEpoch);

    // Momentum
    var momentum = new TrainingParameterScheduleDouble(0.9, 1);

    // Adam learner over all parameters of the classifier.
    ParameterVector parameterVector = new ParameterVector();
    foreach (var parameter in classifierOutput.Parameters())
    {
        parameterVector.Add(parameter);
    }
    var adamLearner = CNTKLib.AdamLearner(parameterVector, learningRateSchedule, momentum);

    // Trainer
    var trainer = Trainer.CreateTrainer(classifierOutput, trainingLoss, prediction, new List<Learner> { adamLearner });

    int outputFrequencyInMinibatches = 20, miniBatchCount = 0;
    Stopwatch sw = new Stopwatch();
    sw.Start();

    // Feed data to the trainer for number of epochs.
    Console.WriteLine("*****************Train Start*****************");
    while (true)
    {
        var minibatchData = minibatchSource.GetNextMinibatch(minibatchSize, device);

        // Stop training once max epochs is reached.
        if (minibatchData.empty())
        {
            break;
        }

        trainer.TrainMinibatch(
            new Dictionary<Variable, MinibatchData>()
            {
                { imageInput, minibatchData[imageStreamInfo] },
                { labelsVar, minibatchData[labelStreamInfo] }
            },
            device);
        TestHelper.PrintTrainingProgress(trainer, adamLearner, miniBatchCount++, outputFrequencyInMinibatches);
    }

    // save the model
    var imageClassifier = Function.Combine(new List<Variable>() { trainingLoss, prediction, classifierOutput }, "ImageClassifier");
    imageClassifier.Save(modelFile);
    Console.WriteLine("*****************Train Stop*****************");

    // validate the model
    float acc = ValidateModel(device, modelFile);
    sw.Stop();
    TimeSpan ts2 = sw.Elapsed;
    Console.WriteLine("*****************Validate Stop*****************");

    string logstr = "Total time :" + ts2.TotalSeconds + "s. acc:" + acc;
    Console.WriteLine(logstr);

    // Pick the first log file name that is not taken yet so earlier runs are kept.
    int i = 1;
    while (File.Exists("../../../../log_" + i.ToString() + ".txt"))
    {
        i++;
    }

    // Dispose the stream so the summary is flushed and the file handle is released
    // as soon as the write completes.
    byte[] data = System.Text.Encoding.Default.GetBytes(logstr);
    using (var file = File.Create("../../../../log_" + i.ToString() + ".txt"))
    {
        file.Write(data, 0, data.Length);
    }
}