コード例 #1
0
ファイル: Trainers.DQL.cs プロジェクト: floAr/CNTKUnityTools
        /// <summary>
        /// Builds a simple DQL trainer: wires the online/target networks, allocates
        /// per-actor episode history buffers, the experience data buffer, and the
        /// CNTK learner/trainer pair.
        /// </summary>
        /// <param name="model">Online Q network.</param>
        /// <param name="modelTarget">Optional target network; initialized from <paramref name="model"/>'s parameters when non-null.</param>
        /// <param name="learner">Factory for the CNTK learner applied to the model's parameters.</param>
        /// <param name="numberOfActor">Number of parallel actors to keep history for.</param>
        /// <param name="bufferSize">Capacity of the experience data buffer.</param>
        /// <param name="maxStepHorizon">Maximum steps per episode horizon.</param>
        /// <param name="endRandomChance">Final exploration (epsilon) probability.</param>
        /// <param name="randomChanceDropStepInterval">Step interval over which the random chance is decayed.</param>
        public TrainerDQLSimple(DQLModel model, DQLModel modelTarget, LearnerDefs.LearnerDef learner, int numberOfActor = 1, int bufferSize = 50000, int maxStepHorizon = 2048, float endRandomChance = 0.05f, int randomChanceDropStepInterval = 10000)
        {
            Model = model;
            ModelTarget = modelTarget;

            // Target network (when present) starts as a copy of the online network's parameters.
            if (ModelTarget != null)
            {
                ModelTarget.OutputLoss.ToFunction().RestoreParametersByName(Model.OutputLoss.ToFunction());
            }

            NumberOfActor = numberOfActor;
            MaxStepHorizon = maxStepHorizon;
            EndRandomChance = endRandomChance;
            RandomChanceDropStepInterval = randomChanceDropStepInterval;
            Steps = 0;

            // Initialize the per-actor single-episode history buffers.
            statesEpisodeHistory = new Dictionary<int, List<float>>();
            rewardsEpisodeHistory = new Dictionary<int, List<float>>();
            actionsEpisodeHistory = new Dictionary<int, List<float>>();
            gameEndEpisodeHistory = new Dictionary<int, List<float>>();
            for (int actor = 0; actor < numberOfActor; actor++)
            {
                statesEpisodeHistory.Add(actor, new List<float>());
                rewardsEpisodeHistory.Add(actor, new List<float>());
                actionsEpisodeHistory.Add(actor, new List<float>());
                gameEndEpisodeHistory.Add(actor, new List<float>());
            }

            LastAction = new Dictionary<int, int>();
            LastState = new Dictionary<int, float[]>();

            // Experience buffer layout: one state vector per entry plus scalar
            // action, reward and game-end flag channels.
            dataBuffer = new DataBuffer(
                bufferSize,
                new DataBuffer.DataInfo("State", DataBuffer.DataType.Float, Model.StateSize),
                new DataBuffer.DataInfo("Action", DataBuffer.DataType.Float, 1),
                new DataBuffer.DataInfo("Reward", DataBuffer.DataType.Float, 1),
                new DataBuffer.DataInfo("GameEnd", DataBuffer.DataType.Float, 1));

            // One learner over all trainable parameters of the online network.
            var parameters = new List<Parameter>(Model.OutputLoss.ToFunction().Parameters());
            learners = new List<Learner> { learner.Create(parameters) };

            trainer = Trainer.CreateTrainer(Model.CNTKFunction, Model.OutputLoss, null, learners);
        }
コード例 #2
0
        /// <summary>
        /// Builds a simple PPO trainer: allocates per-actor episode histories
        /// (states, rewards, actions, values, action probabilities), the rollout
        /// data buffer, and the CNTK learner/trainer pair.
        /// </summary>
        /// <param name="model">The PPO policy/value model to train.</param>
        /// <param name="learner">Factory for the CNTK learner applied to the model's parameters.</param>
        /// <param name="numberOfActor">Number of parallel actors to keep history for.</param>
        /// <param name="bufferSize">Capacity of the rollout data buffer.</param>
        /// <param name="maxStepHorizon">Maximum steps per episode horizon.</param>
        public TrainerPPOSimple(PPOModel model, LearnerDefs.LearnerDef learner, int numberOfActor = 1, int bufferSize = 2048, int maxStepHorizon = 2048)
        {
            Model = model;
            NumberOfActor = numberOfActor;
            MaxStepHorizon = maxStepHorizon;

            // Initialize the per-actor single-episode history buffers.
            statesEpisodeHistory = new Dictionary<int, List<float>>();
            rewardsEpisodeHistory = new Dictionary<int, List<float>>();
            actionsEpisodeHistory = new Dictionary<int, List<float>>();
            valuesEpisodeHistory = new Dictionary<int, List<float>>();
            actionprobsEpisodeHistory = new Dictionary<int, List<float>>();
            for (int actor = 0; actor < numberOfActor; actor++)
            {
                statesEpisodeHistory.Add(actor, new List<float>());
                rewardsEpisodeHistory.Add(actor, new List<float>());
                actionsEpisodeHistory.Add(actor, new List<float>());
                valuesEpisodeHistory.Add(actor, new List<float>());
                actionprobsEpisodeHistory.Add(actor, new List<float>());
            }

            LastState = new Dictionary<int, float[]>();
            LastAction = new Dictionary<int, float[]>();
            LastActionProbs = new Dictionary<int, float[]>();
            LastValue = new Dictionary<int, float>();

            // Rollout buffer layout. Action/ActionProb are vectors for continuous
            // action spaces, scalars (index / probability) for discrete ones.
            int actionWidth = Model.IsActionContinuous ? Model.ActionSize : 1;
            dataBuffer = new DataBuffer(
                bufferSize,
                new DataBuffer.DataInfo("State", DataBuffer.DataType.Float, Model.StateSize),
                new DataBuffer.DataInfo("Action", DataBuffer.DataType.Float, actionWidth),
                new DataBuffer.DataInfo("ActionProb", DataBuffer.DataType.Float, actionWidth),
                new DataBuffer.DataInfo("TargetValue", DataBuffer.DataType.Float, 1),
                new DataBuffer.DataInfo("Advantage", DataBuffer.DataType.Float, 1));

            // One learner over all trainable parameters of the model.
            var parameters = new List<Parameter>(Model.OutputLoss.ToFunction().Parameters());
            learners = new List<Learner> { learner.Create(parameters) };

            // NOTE(review): unlike the DQL variant (which passes the model's
            // CNTKFunction and a null evaluation function), this passes OutputLoss
            // for all three function arguments and was marked "test" by the
            // original author — confirm against PPOModel's intended usage.
            trainer = Trainer.CreateTrainer(Model.OutputLoss, Model.OutputLoss, Model.OutputLoss, learners);
        }