Example #1
 public TrainingParameterScheduleDouble(VectorPairSizeTDouble schedule) : this(CNTKLibPINVOKE.new_TrainingParameterScheduleDouble__SWIG_4(VectorPairSizeTDouble.getCPtr(schedule)), true)
 {
     if (CNTKLibPINVOKE.SWIGPendingException.Pending)
     {
         throw CNTKLibPINVOKE.SWIGPendingException.Retrieve();
     }
 }
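A minimal usage sketch for this constructor, built only from calls shown elsewhere on this page (the pair values are illustrative):

    // Sketch: a piecewise schedule of (count, value) pairs, as in the training examples further down.
    var schedule = new VectorPairSizeTDouble()
    {
        new PairSizeTDouble(1, 0.02),
        new PairSizeTDouble(1, 0.01),
    };
    var learningRate = new TrainingParameterScheduleDouble(schedule);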
Example #2
 public VectorPairSizeTDouble(VectorPairSizeTDouble other) : this(CNTKLibPINVOKE.new_VectorPairSizeTDouble__SWIG_1(VectorPairSizeTDouble.getCPtr(other)), true)
 {
     if (CNTKLibPINVOKE.SWIGPendingException.Pending)
     {
         throw CNTKLibPINVOKE.SWIGPendingException.Retrieve();
     }
 }
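A short sketch of the copy constructor in use (illustrative values; the copy is independent of the source vector):

    var original = new VectorPairSizeTDouble() { new PairSizeTDouble(1, 0.01) };
    var copy     = new VectorPairSizeTDouble(original);   // copies the schedule entries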
Example #3
 public VectorPairSizeTDoubleEnumerator(VectorPairSizeTDouble collection)
 {
     collectionRef = collection;
     currentIndex  = -1;
     currentObject = null;
     currentSize   = collectionRef.Count;
 }
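A hedged iteration sketch; only the constructor is shown above, so the MoveNext/Current members used here are assumed from the usual SWIG enumerator pattern:

    var vp = new VectorPairSizeTDouble() { new PairSizeTDouble(1, 0.01) };
    var it = new VectorPairSizeTDoubleEnumerator(vp);
    while (it.MoveNext())           // assumed standard IEnumerator surface
    {
        var pair = it.Current;      // one (size_t, double) schedule entry
        // ... use pair ...
    }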
Example #4
 public void SetRange(int index, VectorPairSizeTDouble values)
 {
     CNTKLibPINVOKE.VectorPairSizeTDouble_SetRange(swigCPtr, index, VectorPairSizeTDouble.getCPtr(values));
     if (CNTKLibPINVOKE.SWIGPendingException.Pending)
     {
         throw CNTKLibPINVOKE.SWIGPendingException.Retrieve();
     }
 }
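A brief SetRange sketch that overwrites part of an existing vector (illustrative values):

    var target = new VectorPairSizeTDouble()
    {
        new PairSizeTDouble(80, 3e-2),
        new PairSizeTDouble(40, 3e-3),
    };
    var replacement = new VectorPairSizeTDouble() { new PairSizeTDouble(40, 3e-4) };
    target.SetRange(1, replacement);   // replaces the entry at index 1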
Example #5
        public static VectorPairSizeTDouble Repeat(PairSizeTDouble value, int count)
        {
            global::System.IntPtr cPtr = CNTKLibPINVOKE.VectorPairSizeTDouble_Repeat(PairSizeTDouble.getCPtr(value), count);
            VectorPairSizeTDouble ret  = (cPtr == global::System.IntPtr.Zero) ? null : new VectorPairSizeTDouble(cPtr, true);

            if (CNTKLibPINVOKE.SWIGPendingException.Pending)
            {
                throw CNTKLibPINVOKE.SWIGPendingException.Retrieve();
            }
            return(ret);
        }
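A one-line Repeat sketch, e.g. to seed a constant schedule (values illustrative); note that the wrapper returns null when the native call produces a null pointer:

    var constant = VectorPairSizeTDouble.Repeat(new PairSizeTDouble(1, 0.01), 5);   // five identical pairs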
Example #6
        public VectorPairSizeTDouble GetRange(int index, int count)
        {
            global::System.IntPtr cPtr = CNTKLibPINVOKE.VectorPairSizeTDouble_GetRange(swigCPtr, index, count);
            VectorPairSizeTDouble ret  = (cPtr == global::System.IntPtr.Zero) ? null : new VectorPairSizeTDouble(cPtr, true);

            if (CNTKLibPINVOKE.SWIGPendingException.Pending)
            {
                throw CNTKLibPINVOKE.SWIGPendingException.Retrieve();
            }
            return(ret);
        }
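A short GetRange sketch extracting a sub-schedule (illustrative values):

    var vp = new VectorPairSizeTDouble()
    {
        new PairSizeTDouble(80, 3e-2),
        new PairSizeTDouble(40, 3e-3),
        new PairSizeTDouble(40, 3e-4),
    };
    var firstTwo = vp.GetRange(0, 2);   // copies the first two (count, value) pairs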
Example #7
        public Agent(int stateSize, int actionSize, int layerSize)
        {
            m_stateSize  = stateSize;
            m_actionSize = actionSize;

            m_localNetwork  = Model.CreateNetwork(m_stateSize, m_actionSize, layerSize, out m_stateInput);
            m_targetNetwork = Model.CreateNetwork(m_stateSize, m_actionSize, layerSize, out m_stateTargetInput);

            m_qTargetOutput = CNTKLib.InputVariable(new int[] { m_actionSize }, DataType.Float, "targetOutput");

            var loss = CNTKLib.Square(CNTKLib.Minus(m_localNetwork, m_qTargetOutput));
            var meas = CNTKLib.Square(CNTKLib.Minus(m_localNetwork, m_qTargetOutput));

            //learning rate schedule
            var vp = new VectorPairSizeTDouble()
            {
                //new PairSizeTDouble(2, 0.2),
                //new PairSizeTDouble(1, 0.1),
                //new PairSizeTDouble(1, 0.05),
                //new PairSizeTDouble(1, 0.02),
                new PairSizeTDouble(1, 0.02),
                new PairSizeTDouble(1, 0.01),
            };

            //per training batch
            var learningRate = new TrainingParameterScheduleDouble(vp, 4000);

            var learner = new List <Learner>()
            {
                Learner.SGDLearner(m_localNetwork.Parameters(), learningRate)
            };

            m_trainer = Trainer.CreateTrainer(m_localNetwork, loss, null, learner);

            m_memory = new Memory(m_stateSize);
        }
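On the usual reading of CNTK's schedule semantics (stated here as an assumption, not something this snippet spells out), the two pairs above combine with the epoch size of 4000 passed to TrainingParameterScheduleDouble as follows:

    // (1, 0.02) -> learning rate 0.02 for the first 1 * 4000 samples,
    // (1, 0.01) -> learning rate 0.01 for the next 4000 samples and beyond
    //              (the last listed value stays in effect once the pairs are exhausted).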
Example #8
 internal static global::System.Runtime.InteropServices.HandleRef getCPtr(VectorPairSizeTDouble obj)
 {
     return((obj == null) ? new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero) : obj.swigCPtr);
 }
Example #9
            /// <summary>
            /// Train and evaluate an image classifier with CIFAR-10 data.
            /// The classification model is saved after training.
            /// For repeated runs, the caller may choose whether to retrain a model or
            /// just validate an existing one.
            /// </summary>
            /// <param name="device">CPU or GPU device to run</param>
            /// <param name="forceRetrain">whether to override an existing model.
            /// if true, any existing model will be overridden and the new one evaluated.
            /// if false and there is an existing model, the existing model is evaluated.</param>
            public static void TrainAndEvaluate(DeviceDescriptor device, bool forceRetrain)
            {
                string modelFile = Path.Combine(CifarDataFolder, "CNTK-CSharp.model");

                // If a model already exists and not set to force retrain, validate the model and return.
                if (File.Exists(modelFile) && !forceRetrain)
                {
                    ValidateModel(device, modelFile);
                    return;
                }

                // prepare training data
                var minibatchSource = CreateMinibatchSource(Path.Combine(CifarDataFolder, "train_map.txt"),
                                                            Path.Combine(CifarDataFolder, "CIFAR-10_mean.xml"), imageDim, numClasses, MaxEpochs);
                var imageStreamInfo = minibatchSource.StreamInfo("features");
                var labelStreamInfo = minibatchSource.StreamInfo("labels");

                // build a model
                var imageInput       = CNTKLib.InputVariable(imageDim, imageStreamInfo.m_elementType, "Images");
                var labelsVar        = CNTKLib.InputVariable(new int[] { numClasses }, labelStreamInfo.m_elementType, "Labels");
                var classifierOutput = ResNetClassifier(imageInput, numClasses, device, "classifierOutput");

                // prepare for training
                var trainingLoss = CNTKLib.CrossEntropyWithSoftmax(classifierOutput, labelsVar, "lossFunction");
                var prediction   = CNTKLib.ClassificationError(classifierOutput, labelsVar, 3, "predictionError");

                //learning rate schedule
                double[]        lrs           = { 3e-2, 3e-3, 3e-4, 3e-4, 5e-5 }; //learning rates
                int[]           check_point   = { 80, 120, 160, 180 };            //epochs at which the learning rate is updated
                uint            minibatchSize = 32;
                PairSizeTDouble p1            = new PairSizeTDouble(80, lrs[0]);
                PairSizeTDouble p2            = new PairSizeTDouble(40, lrs[1]);
                PairSizeTDouble p3            = new PairSizeTDouble(40, lrs[2]);
                PairSizeTDouble p4            = new PairSizeTDouble(20, lrs[3]);
                PairSizeTDouble p5            = new PairSizeTDouble(20, lrs[4]);

                VectorPairSizeTDouble vp = new VectorPairSizeTDouble()
                {
                    p1, p2, p3, p4, p5
                };
                int sample_num_in_a_epoch = 50000;
                TrainingParameterScheduleDouble learningRateSchedule = new TrainingParameterScheduleDouble(vp, (uint)sample_num_in_a_epoch);
                //momentum
                var momentum = new TrainingParameterScheduleDouble(0.9, 1);
                //SGD Learner
                //var sgdLearner = Learner.SGDLearner(classifierOutput.Parameters(), learningRateSchedule);
                //Adam Learner
                ParameterVector parameterVector = new ParameterVector();

                foreach (var parameter in classifierOutput.Parameters())
                {
                    parameterVector.Add(parameter);
                }
                var adamLearner = CNTKLib.AdamLearner(parameterVector, learningRateSchedule, momentum);
                //Trainer
                var trainer = Trainer.CreateTrainer(classifierOutput, trainingLoss, prediction, new List <Learner> {
                    adamLearner
                });

                int       outputFrequencyInMinibatches = 20, miniBatchCount = 0;
                Stopwatch sw = new Stopwatch();

                sw.Start();
                // Feed data to the trainer for number of epochs.
                Console.WriteLine("*****************Train Start*****************");
                while (true)
                {
                    var minibatchData = minibatchSource.GetNextMinibatch(minibatchSize, device);

                    // Stop training once max epochs is reached.
                    if (minibatchData.empty())
                    {
                        break;
                    }

                    trainer.TrainMinibatch(new Dictionary <Variable, MinibatchData>()
                    {
                        { imageInput, minibatchData[imageStreamInfo] },
                        { labelsVar, minibatchData[labelStreamInfo] }
                    }, device);

                    TestHelper.PrintTrainingProgress(trainer, adamLearner, miniBatchCount++, outputFrequencyInMinibatches);
                }

                // save the model
                var imageClassifier = Function.Combine(new List <Variable>()
                {
                    trainingLoss, prediction, classifierOutput
                }, "ImageClassifier");

                imageClassifier.Save(modelFile);
                Console.WriteLine("*****************Train Stop*****************");

                // validate the model
                float acc = ValidateModel(device, modelFile);

                sw.Stop();
                TimeSpan ts2 = sw.Elapsed;

                Console.WriteLine("*****************Validate Stop*****************");
                string logstr = "Total time :" + ts2.TotalSeconds + "s. acc:" + acc;

                Console.WriteLine(logstr);

                int i = 1;

                while (System.IO.File.Exists("../../../../log_" + i.ToString() + ".txt"))
                {
                    i++;
                }

                // Write the log line and dispose the file stream when done.
                using (var file = System.IO.File.Create("../../../../log_" + i.ToString() + ".txt"))
                {
                    byte[] data = System.Text.Encoding.Default.GetBytes(logstr);
                    file.Write(data, 0, data.Length);
                }
            }
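For completeness, a hedged invocation sketch; DeviceDescriptor.CPUDevice is the standard CNTK CPU device, and forceRetrain matches the parameter documented above:

    // Sketch: retrain from scratch on the CPU; pass false to reuse a previously saved model.
    TrainAndEvaluate(DeviceDescriptor.CPUDevice, forceRetrain: true);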