Example #1
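Trains a Sarsa agent and a Q-learning agent on the cliff-walking environment for 10,000 episodes each, then prints the average and highest action-value estimates and the greedy path each learned policy takes.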
        static void Main(string[] args)
        {
            var env        = new CliffWalkingEnvironment();
            var sarsaAgent = new SarsaCliffWalker();

            var values = sarsaAgent.ImproveEstimates(env, out var tdDiags, 10000);

            System.Console.WriteLine("td 0 avg Values:");
            StateActionValuesConsoleRenderer.RenderAverageValues(values);
            System.Console.WriteLine("td 0 highest Values:");
            StateActionValuesConsoleRenderer.RenderHighestValues(values);
            System.Console.WriteLine("");
            System.Console.WriteLine("td 0 Greedy path:");
            ConsolePathRenderer.RenderPath(GreedyPath(env, values));

            var qAgent  = new QLearningCliffWalker();
            var qValues = qAgent.ImproveEstimates(env, out var qDiags, 10000);

            System.Console.WriteLine("");
            System.Console.WriteLine("q learning avg Values:");
            StateActionValuesConsoleRenderer.RenderAverageValues(qValues);
            System.Console.WriteLine("q learning highest Values:");
            StateActionValuesConsoleRenderer.RenderHighestValues(qValues);
            System.Console.WriteLine("");
            System.Console.WriteLine("q learning Greedy path:");
            ConsolePathRenderer.RenderPath(GreedyPath(env, qValues));
        }
Example #2
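For each candidate learning rate, averages the per-episode reward over the first NumEpisodesForInterim episodes, repeated across 50 independent runs to smooth out noise.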
        private static IEnumerable<double> GatherInterimPerformance(
            IEnumerable<double> learningRates,
            Func<double, ICliffWalkingAgent> createAgentFunc)
        {
            const int numRuns = 50;

            var env = new CliffWalkingEnvironment();

            foreach (var rate in learningRates)
            {
                var firstXEpisodeAverages = new List<double>();

                for (var i = 0; i < numRuns; i++)
                {
                    env.Reset();
                    var agent = createAgentFunc(rate);

                    agent.ImproveEstimates(env, out var diags, NumEpisodesForInterim);

                    firstXEpisodeAverages.Add(diags.RewardSumPerEpisode.Average());
                }

                yield return firstXEpisodeAverages.Average();
            }
        }
Example #3
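Trains Sarsa and Q-learning agents side by side (both constructed with 0.1 parameters, presumably learning rate and exploration rate) and plots their average total reward per episode on a single chart.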
        public static void Run()
        {
            const int numEpisodes    = 100;
            var       env            = new CliffWalkingEnvironment();
            var       sarsaAgent     = new SarsaCliffWalker(0.1, 0.1);
            var       qLearningAgent = new QLearningCliffWalker(0.1, 0.1);

            var tdAverageRewards        = CollectAverageRewardSums(sarsaAgent, env, numEpisodes);
            var qLearningAverageRewards = CollectAverageRewardSums(qLearningAgent, env, numEpisodes);

            var plotter = new Plotter();
            var plt     = plotter.Plt;

            plt.Title("Average total reward per episode");
            var dataX = Enumerable.Range(0, numEpisodes).Select(i => (double)i).ToArray();

            plt.PlotScatter(dataX, tdAverageRewards, label: "TD 0 (Sarsa)");
            plt.PlotScatter(dataX, qLearningAverageRewards, label: "Q learning");

            plt.XLabel("Episodes");
            plt.YLabel("Average total reward");
            plt.Legend();

            plotter.Show();
        }
Example #4
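A unit test: after training (with the default episode count), the starting position (0, 0) should have at least one action-value estimate.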
        public void AfterImprovingEstimates_StartingPosition_HasActionValues()
        {
            var env   = new CliffWalkingEnvironment();
            var agent = new SarsaCliffWalker();

            var values = agent.ImproveEstimates(env, out var diags);

            Assert.IsNotEmpty(values.ActionValues(new Position(0, 0)));
        }
Example #5
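Repeats training numRuns times (500 by default) and returns each episode's reward sum averaged across all runs.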
        private static double[] CollectAverageRewardSums(
            ICliffWalkingAgent agent, CliffWalkingEnvironment env, int numEpisodes, int numRuns = 500)
        {
            var rewardSumSums = new double[numEpisodes];

            for (var i = 0; i < numRuns; i++)
            {
                agent.ImproveEstimates(env, out var diags, numEpisodes);
                for (var episode = 0; episode < numEpisodes; episode++)
                {
                    rewardSumSums[episode] += diags.RewardSumPerEpisode[episode];
                }
            }

            return rewardSumSums.Select(s => s / numRuns).ToArray();
        }
Example #6
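For each candidate learning rate, runs one long training session of NumEpisodesForAsymptote episodes (timed with a Stopwatch) and yields the average per-episode reward as an estimate of asymptotic performance.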
        private static IEnumerable<double> GatherAsymptoticPerformance(
            IEnumerable<double> learningRates,
            Func<double, ICliffWalkingAgent> createAgentFunc)
        {
            var env = new CliffWalkingEnvironment();

            foreach (var rate in learningRates)
            {
                env.Reset();
                var agent = createAgentFunc(rate);
                var sw    = Stopwatch.StartNew();

                agent.ImproveEstimates(env, out var diags, NumEpisodesForAsymptote);

                Console.WriteLine($"ran {NumEpisodesForAsymptote} episodes in {sw.Elapsed}");

                yield return diags.RewardSumPerEpisode.Average();
            }
        }
Example #7
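Replays the learned policy greedily: starting from the environment's reset position, repeatedly takes the highest-valued action and yields every position visited until the episode ends.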
        private static IEnumerable<Position> GreedyPath(
            CliffWalkingEnvironment env, IStateActionValues values)
        {
            var currentPosition = env.Reset();
            var isDone          = false;

            while (!isDone)
            {
                yield return currentPosition;

                // pick the action with the highest estimated value at this position
                var bestAction = values
                                 .ActionValues(currentPosition)
                                 .OrderBy(av => av.Item2)
                                 .Last().Item1;

                var (observation, _, done) = env.Step(bestAction);
                currentPosition            = observation;
                isDone                     = done;
            }

            yield return currentPosition;
        }