/// <summary>
/// Builds a schedule from a native vector of (size_t, double) pairs by invoking the
/// SWIG-generated native constructor, then surfaces any exception the native call left pending.
/// </summary>
/// <param name="schedule">Pairs to build the schedule from; a null reference is forwarded as a null native pointer.</param>
public TrainingParameterScheduleDouble(VectorPairSizeTDouble schedule)
    : this(CNTKLibPINVOKE.new_TrainingParameterScheduleDouble__SWIG_4(VectorPairSizeTDouble.getCPtr(schedule)), true)
{
    bool nativeCallFailed = CNTKLibPINVOKE.SWIGPendingException.Pending;
    if (nativeCallFailed)
    {
        throw CNTKLibPINVOKE.SWIGPendingException.Retrieve();
    }
}
/// <summary>
/// Copy constructor: asks the native layer to copy <paramref name="other"/>, wraps the result,
/// and rethrows any exception the native copy left pending.
/// </summary>
/// <param name="other">Vector whose contents are copied; null is forwarded as a null native pointer.</param>
public VectorPairSizeTDouble(VectorPairSizeTDouble other)
    : this(CNTKLibPINVOKE.new_VectorPairSizeTDouble__SWIG_1(VectorPairSizeTDouble.getCPtr(other)), true)
{
    if (!CNTKLibPINVOKE.SWIGPendingException.Pending)
    {
        return;
    }

    throw CNTKLibPINVOKE.SWIGPendingException.Retrieve();
}
/// <summary>
/// Creates an enumerator over <paramref name="collection"/>. The index starts at -1, there is
/// no current object yet, and the collection's Count is captured at construction time.
/// </summary>
/// <param name="collection">Vector to enumerate; reading its Count throws if it is null.</param>
public VectorPairSizeTDoubleEnumerator(VectorPairSizeTDouble collection)
{
    this.collectionRef = collection;
    this.currentObject = null;
    this.currentIndex = -1;
    // Snapshot of the size, taken once so later comparisons use the size at creation.
    this.currentSize = this.collectionRef.Count;
}
/// <summary>
/// Forwards to the native VectorPairSizeTDouble_SetRange with this vector, <paramref name="index"/>
/// and <paramref name="values"/>, then rethrows any pending native exception.
/// </summary>
/// <param name="index">Start position passed to the native call.</param>
/// <param name="values">Source vector; null is forwarded as a null native pointer.</param>
public void SetRange(int index, VectorPairSizeTDouble values)
{
    var valuesHandle = VectorPairSizeTDouble.getCPtr(values);
    CNTKLibPINVOKE.VectorPairSizeTDouble_SetRange(swigCPtr, index, valuesHandle);

    if (!CNTKLibPINVOKE.SWIGPendingException.Pending)
    {
        return;
    }

    throw CNTKLibPINVOKE.SWIGPendingException.Retrieve();
}
/// <summary>
/// Asks the native layer for a vector holding <paramref name="value"/> repeated
/// <paramref name="count"/> times.
/// </summary>
/// <param name="value">Element to repeat.</param>
/// <param name="count">Repetition count passed to the native call.</param>
/// <returns>A wrapper that owns the native vector, or null if the native call returned a null pointer.</returns>
public static VectorPairSizeTDouble Repeat(PairSizeTDouble value, int count)
{
    global::System.IntPtr nativeVector = CNTKLibPINVOKE.VectorPairSizeTDouble_Repeat(PairSizeTDouble.getCPtr(value), count);

    VectorPairSizeTDouble wrapped = null;
    if (nativeVector != global::System.IntPtr.Zero)
    {
        wrapped = new VectorPairSizeTDouble(nativeVector, true);
    }

    if (CNTKLibPINVOKE.SWIGPendingException.Pending)
    {
        throw CNTKLibPINVOKE.SWIGPendingException.Retrieve();
    }

    return wrapped;
}
/// <summary>
/// Asks the native layer for a new vector built from this vector's elements starting at
/// <paramref name="index"/>, <paramref name="count"/> elements long.
/// </summary>
/// <param name="index">Start position passed to the native call.</param>
/// <param name="count">Number of elements requested.</param>
/// <returns>A wrapper that owns the native result, or null if the native call returned a null pointer.</returns>
public VectorPairSizeTDouble GetRange(int index, int count)
{
    global::System.IntPtr nativeRange = CNTKLibPINVOKE.VectorPairSizeTDouble_GetRange(swigCPtr, index, count);

    VectorPairSizeTDouble slice = null;
    if (nativeRange != global::System.IntPtr.Zero)
    {
        slice = new VectorPairSizeTDouble(nativeRange, true);
    }

    if (CNTKLibPINVOKE.SWIGPendingException.Pending)
    {
        throw CNTKLibPINVOKE.SWIGPendingException.Retrieve();
    }

    return slice;
}
/// <summary>
/// Creates an agent with two networks and a trainer for one of them. A "local" network and a
/// "target" network are built by separate <c>Model.CreateNetwork</c> calls with identical
/// arguments. An SGD trainer for the local network minimizes the element-wise squared error
/// against a target-output input variable. A <c>Memory</c> sized to the state is also created.
/// </summary>
/// <param name="stateSize">Size of the state vector fed to both networks.</param>
/// <param name="actionSize">Number of network outputs; also the size of the target-output input.</param>
/// <param name="layerSize">Layer size forwarded to <c>Model.CreateNetwork</c>.</param>
public Agent(int stateSize, int actionSize, int layerSize)
{
    m_stateSize = stateSize;
    m_actionSize = actionSize;

    // Local (trained) and target networks share an architecture but each gets its own state input.
    m_localNetwork = Model.CreateNetwork(m_stateSize, m_actionSize, layerSize, out m_stateInput);
    m_targetNetwork = Model.CreateNetwork(m_stateSize, m_actionSize, layerSize, out m_stateTargetInput);

    // Target values the local network's outputs are regressed towards (one per action).
    m_qTargetOutput = CNTKLib.InputVariable(new int[] { m_actionSize }, DataType.Float, "targetOutput");

    // Element-wise squared error between the local network's output and the target.
    // (A previously computed, identical and unused "meas" node was removed.)
    var loss = CNTKLib.Square(CNTKLib.Minus(m_localNetwork, m_qTargetOutput));

    // Learning-rate schedule: 0.02, then 0.01.
    // NOTE(review): each pair's count is presumably measured in units of the epoch size passed
    // below (4000 samples); the original comment said "per training batch" — confirm.
    var vp = new VectorPairSizeTDouble()
    {
        new PairSizeTDouble(1, 0.02),
        new PairSizeTDouble(1, 0.01),
    };
    var learningRate = new TrainingParameterScheduleDouble(vp, 4000);

    var learner = new List<Learner>() { Learner.SGDLearner(m_localNetwork.Parameters(), learningRate) };

    // No separate evaluation criterion is supplied (null).
    m_trainer = Trainer.CreateTrainer(m_localNetwork, loss, null, learner);

    m_memory = new Memory(m_stateSize);
}
/// <summary>
/// Returns the native handle wrapped by <paramref name="obj"/>, or a handle holding
/// <see cref="global::System.IntPtr.Zero"/> when <paramref name="obj"/> is null.
/// </summary>
internal static global::System.Runtime.InteropServices.HandleRef getCPtr(VectorPairSizeTDouble obj)
{
    if (obj != null)
    {
        return obj.swigCPtr;
    }

    return new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero);
}
/// <summary>
/// Train and evaluate an image classifier with CIFAR-10 data.
/// The classification model is saved after training.
/// For repeated runs, the caller may choose whether to retrain a model or
/// just validate an existing one.
/// </summary>
/// <param name="device">CPU or GPU device to run</param>
/// <param name="forceRetrain">whether to override an existing model.
/// if true, any existing model will be overridden and the new one evaluated.
/// if false and there is an existing model, the existing model is evaluated.</param>
/// <remarks>
/// After a fresh training run, the elapsed wall-clock time (training loop, saving and validation)
/// and the validation accuracy are written to the first unused "../../../../log_N.txt".
/// </remarks>
public static void TrainAndEvaluate(DeviceDescriptor device, bool forceRetrain)
{
    string modelFile = Path.Combine(CifarDataFolder, "CNTK-CSharp.model");

    // If a model already exists and not set to force retrain, validate the model and return.
    if (File.Exists(modelFile) && !forceRetrain)
    {
        ValidateModel(device, modelFile);
        return;
    }

    // prepare training data
    var minibatchSource = CreateMinibatchSource(
        Path.Combine(CifarDataFolder, "train_map.txt"),
        Path.Combine(CifarDataFolder, "CIFAR-10_mean.xml"),
        imageDim, numClasses, MaxEpochs);
    var imageStreamInfo = minibatchSource.StreamInfo("features");
    var labelStreamInfo = minibatchSource.StreamInfo("labels");

    // build a model
    var imageInput = CNTKLib.InputVariable(imageDim, imageStreamInfo.m_elementType, "Images");
    var labelsVar = CNTKLib.InputVariable(new int[] { numClasses }, labelStreamInfo.m_elementType, "Labels");
    var classifierOutput = ResNetClassifier(imageInput, numClasses, device, "classifierOutput");

    // prepare for training
    var trainingLoss = CNTKLib.CrossEntropyWithSoftmax(classifierOutput, labelsVar, "lossFunction");
    // NOTE(review): the literal 3 is presumably CNTK's topN argument (top-3 error) — confirm intended.
    var prediction = CNTKLib.ClassificationError(classifierOutput, labelsVar, 3, "predictionError");

    uint minibatchSize = 32;
    TrainingParameterScheduleDouble learningRateSchedule = BuildStepLearningRateSchedule();

    // Constant momentum of 0.9, handed to the Adam learner below.
    var momentum = new TrainingParameterScheduleDouble(0.9, 1);

    // Adam learner over every parameter of the classifier.
    ParameterVector parameterVector = new ParameterVector();
    foreach (var parameter in classifierOutput.Parameters())
    {
        parameterVector.Add(parameter);
    }
    var adamLearner = CNTKLib.AdamLearner(parameterVector, learningRateSchedule, momentum);

    var trainer = Trainer.CreateTrainer(classifierOutput, trainingLoss, prediction, new List<Learner> { adamLearner });

    int outputFrequencyInMinibatches = 20, miniBatchCount = 0;
    Stopwatch sw = new Stopwatch();
    sw.Start();

    // Feed data to the trainer for number of epochs.
    Console.WriteLine("*****************Train Start*****************");
    while (true)
    {
        var minibatchData = minibatchSource.GetNextMinibatch(minibatchSize, device);

        // Stop training once max epochs is reached.
        if (minibatchData.empty())
        {
            break;
        }

        trainer.TrainMinibatch(new Dictionary<Variable, MinibatchData>()
            {
                { imageInput, minibatchData[imageStreamInfo] },
                { labelsVar, minibatchData[labelStreamInfo] }
            }, device);
        TestHelper.PrintTrainingProgress(trainer, adamLearner, miniBatchCount++, outputFrequencyInMinibatches);
    }

    // save the model
    var imageClassifier = Function.Combine(new List<Variable>() { trainingLoss, prediction, classifierOutput }, "ImageClassifier");
    imageClassifier.Save(modelFile);
    Console.WriteLine("*****************Train Stop*****************");

    // validate the model
    float acc = ValidateModel(device, modelFile);
    sw.Stop();
    TimeSpan ts2 = sw.Elapsed;
    Console.WriteLine("*****************Validate Stop*****************");

    string logstr = "Total time :" + ts2.TotalSeconds + "s. acc:" + acc;
    Console.WriteLine(logstr);
    WriteTrainingLog(logstr);
}

/// <summary>
/// Builds the step-wise learning-rate schedule: each learning rate is used until the next epoch
/// checkpoint, and the last rate gets a fixed trailing stage. Each pair's count is in epochs;
/// the schedule's epoch size is the number of training samples per epoch.
/// </summary>
private static TrainingParameterScheduleDouble BuildStepLearningRateSchedule()
{
    // Learning rate for each training stage.
    double[] learningRates = { 3e-2, 3e-3, 3e-4, 3e-4, 5e-5 };
    // Epoch at which training moves on to the next learning rate (must be strictly increasing).
    uint[] epochCheckpoints = { 80, 120, 160, 180 };
    // Length, in epochs, of any stage after the last checkpoint.
    const uint trailingStageEpochs = 20;
    // Training samples per epoch (presumably the CIFAR-10 training-set size).
    const uint samplesPerEpoch = 50000;

    var pairs = new VectorPairSizeTDouble();
    uint stageStart = 0;
    for (int stage = 0; stage < learningRates.Length; stage++)
    {
        // Stage lengths come from checkpoint differences so the two cannot drift apart.
        uint stageLength = stage < epochCheckpoints.Length
            ? epochCheckpoints[stage] - stageStart
            : trailingStageEpochs;
        pairs.Add(new PairSizeTDouble(stageLength, learningRates[stage]));
        stageStart += stageLength;
    }

    return new TrainingParameterScheduleDouble(pairs, samplesPerEpoch);
}

/// <summary>
/// Writes <paramref name="message"/>, encoded with System.Text.Encoding.Default, to the first
/// "../../../../log_N.txt" (N = 1, 2, ...) that does not yet exist.
/// </summary>
private static void WriteTrainingLog(string message)
{
    int i = 1;
    while (File.Exists("../../../../log_" + i.ToString() + ".txt"))
    {
        i++;
    }

    byte[] data = System.Text.Encoding.Default.GetBytes(message);
    // WriteAllBytes opens, writes and closes the file, so the data is flushed and the handle
    // released (the previous File.Create stream was never disposed).
    File.WriteAllBytes("../../../../log_" + i.ToString() + ".txt", data);
}