Example #1
        /// <summary>
        /// Calculate the discounted advantages for the current sequence of data and add them to the data buffer.
        /// </summary>
        protected void ProcessEpisodeHistory(float nextValue, int actorNum)
        {
            var advantages = RLUtils.GeneralAdvantageEst(rewardsEpisodeHistory[actorNum].ToArray(),
                                                         valuesEpisodeHistory[actorNum].ToArray(), RewardDiscountFactor, RewardGAEFactor, nextValue);

            float[] targetValues = new float[advantages.Length];
            for (int i = 0; i < targetValues.Length; ++i)
            {
                targetValues[i] = advantages[i] + valuesEpisodeHistory[actorNum][i];

                // test-only override, kept commented out:
                //advantages[i] = 1;
            }
            // test-only override, kept commented out:
            //targetValues = RLUtils.DiscountedRewards(rewardsEpisodeHistory[actorNum].ToArray(), RewardDiscountFactor);


            dataBuffer.AddData(Tuple.Create <string, Array>("State", statesEpisodeHistory[actorNum].ToArray()),
                               Tuple.Create <string, Array>("Action", actionsEpisodeHistory[actorNum].ToArray()),
                               Tuple.Create <string, Array>("ActionProb", actionprobsEpisodeHistory[actorNum].ToArray()),
                               Tuple.Create <string, Array>("TargetValue", targetValues),
                               Tuple.Create <string, Array>("Advantage", advantages)
                               );

            statesEpisodeHistory[actorNum].Clear();
            rewardsEpisodeHistory[actorNum].Clear();
            actionsEpisodeHistory[actorNum].Clear();
            actionprobsEpisodeHistory[actorNum].Clear();
            valuesEpisodeHistory[actorNum].Clear();
        }
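
For reference, the call to RLUtils.GeneralAdvantageEst above implements generalized advantage estimation (GAE): it walks the episode backwards, accumulating discounted TD errors with the reward discount factor and the GAE lambda factor, and bootstrapping the tail of a truncated episode with nextValue. The targetValues derived from it (advantage plus old value) are the lambda-returns used as the critic's regression target. The sketch below is a minimal, self-contained illustration of that computation; the method name and signature are assumptions for illustration and may differ from the repo's actual RLUtils implementation.

    // Minimal GAE sketch, assuming rewards[t] and values[t] are aligned per step and
    // nextValue is the critic's estimate for the state after the last step (0 if terminal).
    static float[] GeneralAdvantageEstSketch(float[] rewards, float[] values,
                                             float discountFactor, float gaeFactor, float nextValue)
    {
        float[] advantages    = new float[rewards.Length];
        float   lastAdvantage = 0f;

        // Walk backwards: delta_t = r_t + gamma * V_{t+1} - V_t,  A_t = delta_t + gamma * lambda * A_{t+1}.
        for (int t = rewards.Length - 1; t >= 0; --t)
        {
            float nextVal = (t == rewards.Length - 1) ? nextValue : values[t + 1];
            float delta   = rewards[t] + discountFactor * nextVal - values[t];
            lastAdvantage = delta + discountFactor * gaeFactor * lastAdvantage;
            advantages[t] = lastAdvantage;
        }
        return advantages;
    }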
    public void EvaluateEpisode(List <float> vectorObsEpisodeHistory, List <List <float[, , ]> > visualEpisodeHistory, List <float> actionsEpisodeHistory, List <float> rewardsEpisodeHistory, List <List <float> > actionMasksEpisodeHistory,
                                out float[] values, out float[,] actionProbs, out float[] targetValues, out float[] advantages,
                                bool isDone, List <float> finalVectorObs = null, List <float[, , ]> finalVisualObs = null)
    {
        int obsSize = BrainToTrain.brainParameters.vectorObservationSize * BrainToTrain.brainParameters.numStackedVectorObservations;

        float[,] vectorObs = vectorObsEpisodeHistory.Reshape(obsSize);
        var visualObs   = CreateVisualInputBatch(visualEpisodeHistory, BrainToTrain.brainParameters.cameraResolutions);
        var actionMasks = CreateActionMasks(actionMasksEpisodeHistory, BrainToTrain.brainParameters.vectorActionSize);

        // Continuous action spaces use one float per action dimension; discrete spaces use one value per branch.
        int actionSize = BrainToTrain.brainParameters.vectorActionSpaceType == SpaceType.continuous
                         ? BrainToTrain.brainParameters.vectorActionSize[0]
                         : BrainToTrain.brainParameters.vectorActionSize.Length;

        values      = iModelHPPO.EvaluateValue(vectorObs, visualObs);
        actionProbs = iModelHPPO.EvaluateProbability(vectorObs, actionsEpisodeHistory.Reshape(actionSize), visualObs, actionMasks);



        // Process the episode data for PPO.
        float nextValue = 0;

        if (isDone)
        {
            nextValue = 0;  // important: a terminal state is not bootstrapped, its value is zero
        }
        else
        {
            // Wrap each final visual observation into a batch of size one so the critic can evaluate it.
            List <List <float[, , ]> > visualTemp = new List <List <float[, , ]> >();
            foreach (var v in finalVisualObs)
            {
                var t = new List <float[, , ]>();
                t.Add(v);
                visualTemp.Add(t);
            }
            // Bootstrap the truncated episode with the critic's value estimate of the final observation.
            nextValue = iModelHPPO.EvaluateValue(finalVectorObs.Reshape(obsSize), CreateVisualInputBatch(visualTemp, BrainToTrain.brainParameters.cameraResolutions))[0];
        }

        advantages   = RLUtils.GeneralAdvantageEst(rewardsEpisodeHistory.ToArray(), values, parametersPPO.rewardDiscountFactor, parametersPPO.rewardGAEFactor, nextValue);
        targetValues = new float[advantages.Length];
        for (int i = 0; i < targetValues.Length; ++i)
        {
            targetValues[i] = advantages[i] + values[i];
        }
    }
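
EvaluateEpisode leans on a Reshape extension that folds a flat per-episode List<float> into a [steps, featureSize] matrix before it is fed to the model. The helper below is only a hedged sketch of that behaviour (name and signature assumed for illustration, not the repo's actual extension); it shows the row-major layout the calls above imply, with one row per time step.

    // Sketch of a row-major reshape: flat.Count is assumed to be a multiple of featureSize.
    static float[,] ReshapeSketch(System.Collections.Generic.List<float> flat, int featureSize)
    {
        int steps  = flat.Count / featureSize;
        var result = new float[steps, featureSize];
        for (int i = 0; i < steps; ++i)
        {
            for (int j = 0; j < featureSize; ++j)
            {
                result[i, j] = flat[i * featureSize + j];
            }
        }
        return result;
    }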
    public override void ProcessExperience(Dictionary <Agent, AgentInfoInternal> currentInfo, Dictionary <Agent, AgentInfoInternal> newInfo)
    {
        var agentList = currentInfo.Keys;

        foreach (var agent in agentList)
        {
            var agentNewInfo = newInfo[agent];
            if (agentNewInfo.done || agentNewInfo.maxStepReached || rewardsEpisodeHistory[agent].Count > parametersPPO.timeHorizon)
            {
                // Process the episode data for PPO.
                float nextValue = 0;

                if (agentNewInfo.done && !agentNewInfo.maxStepReached)
                {
                    nextValue = 0;  // important: a genuine terminal state is not bootstrapped, its value is zero
                }
                else
                {
                    // The episode was cut off (max step reached or time horizon exceeded), so bootstrap from the critic's value of the latest observation.
                    nextValue = iModelPPO.EvaluateValue(
                        Matrix.Reshape(agentNewInfo.stackedVectorObservation.ToArray(), 1, agentNewInfo.stackedVectorObservation.Count),
                        CreateVisualInputBatch(newInfo, new List <Agent>() { agent }, BrainToTrain.brainParameters.cameraResolutions))[0];
                }

                var valueHistory = valuesEpisodeHistory[agent].ToArray();
                var advantages   = RLUtils.GeneralAdvantageEst(rewardsEpisodeHistory[agent].ToArray(),
                                                               valueHistory, parametersPPO.rewardDiscountFactor, parametersPPO.rewardGAEFactor, nextValue);
                float[] targetValues = new float[advantages.Length];
                for (int i = 0; i < targetValues.Length; ++i)
                {
                    targetValues[i] = advantages[i] + valueHistory[i];
                }

                // Add the processed data to the buffer.

                List <ValueTuple <string, Array> > dataToAdd = new List <ValueTuple <string, Array> >();
                dataToAdd.Add(ValueTuple.Create <string, Array>("Action", actionsEpisodeHistory[agent].ToArray()));
                dataToAdd.Add(ValueTuple.Create <string, Array>("ActionProb", actionprobsEpisodeHistory[agent].ToArray()));
                dataToAdd.Add(ValueTuple.Create <string, Array>("TargetValue", targetValues));
                dataToAdd.Add(ValueTuple.Create <string, Array>("OldValue", valueHistory));
                dataToAdd.Add(ValueTuple.Create <string, Array>("Advantage", advantages));
                if (statesEpisodeHistory[agent].Count > 0)
                {
                    dataToAdd.Add(ValueTuple.Create <string, Array>("VectorObservation", statesEpisodeHistory[agent].ToArray()));
                }
                for (int i = 0; i < visualEpisodeHistory[agent].Count; ++i)
                {
                    dataToAdd.Add(ValueTuple.Create <string, Array>("VisualObservation" + i, DataBuffer.ListToArray(visualEpisodeHistory[agent][i])));
                }
                for (int i = 0; i < actionMasksEpisodeHistory[agent].Count; ++i)
                {
                    dataToAdd.Add(ValueTuple.Create <string, Array>("ActionMask" + i, actionMasksEpisodeHistory[agent][i].ToArray()));
                }

                dataBuffer.AddData(dataToAdd.ToArray());

                // Clear the temporary per-episode data records.
                statesEpisodeHistory[agent].Clear();
                rewardsEpisodeHistory[agent].Clear();
                actionsEpisodeHistory[agent].Clear();
                actionprobsEpisodeHistory[agent].Clear();
                valuesEpisodeHistory[agent].Clear();
                for (int i = 0; i < visualEpisodeHistory[agent].Count; ++i)
                {
                    visualEpisodeHistory[agent][i].Clear();
                }
                for (int i = 0; i < actionMasksEpisodeHistory[agent].Count; ++i)
                {
                    actionMasksEpisodeHistory[agent][i].Clear();
                }



                // Update stats if the agent is not using a heuristic decision.
                if (agentNewInfo.done || agentNewInfo.maxStepReached)
                {
                    var agentDecision = agent.GetComponent <AgentDependentDecision>();
                    if (!(isTraining && agentDecision != null && agentDecision.useDecision))// && parametersPPO.useHeuristicChance > 0
                    {
                        stats.AddData("accumulatedRewards", accumulatedRewards[agent]);
                        stats.AddData("episodeSteps", episodeSteps.ContainsKey(agent) ? episodeSteps[agent] : 0);
                    }


                    accumulatedRewards[agent] = 0;
                    episodeSteps[agent]       = 0;
                }
            }
        }
    }
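
A note on the bootstrap rule used in both examples above: a trajectory is cut either when the episode genuinely ends or when it is truncated (max step reached or the time horizon exceeded). Only the genuine terminal state gets a zero value; a truncated trajectory bootstraps the remaining return from the critic. The helper below is hypothetical (not part of the repo) and merely restates that rule in isolation.

    // Hypothetical helper: returns the bootstrap value for the cut point of a trajectory.
    // System.Func<float> stands in for a deferred critic evaluation of the latest observation.
    static float BootstrapValue(bool done, bool maxStepReached, System.Func<float> evaluateCriticValue)
    {
        // A genuine terminal state has value 0; a truncated trajectory uses the critic's estimate.
        return (done && !maxStepReached) ? 0f : evaluateCriticValue();
    }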