Example #1
0
        /// <summary>
        /// Policy iteration for MDP: alternates policy improvement
        /// (<see cref="PolicyIteration"/>) and policy evaluation
        /// (<c>GetOptimalValue</c>) until the value estimate stops improving.
        /// </summary>
        /// <param name="basePolicy">Starting policy; cloned, never mutated.</param>
        /// <param name="optimalPolicy">Resulting optimal policy</param>
        /// <param name="tolerance">Convergence tolerance for policy evaluation</param>
        /// <returns>Optimal value (mean) of the converged policy</returns>
        public double GetOptimalValueViaPolicyIteration(
            IDeterministicPolicy <TState> basePolicy,
            out IDeterministicPolicy <TState> optimalPolicy,
            double tolerance = 0.0001)
        {
            // Work on a clone so the caller's policy is never mutated.
            var policy = (IDeterministicPolicy <TState>)basePolicy.Clone();
            var node   = Policies.AddFirst(basePolicy);
            var value  = GetOptimalValue(policy, tolerance);

            optimalPolicy = policy = PolicyIteration(policy);
            // States whose values are evaluated between sweeps: the initial state
            // plus every state reachable from it in one step under the policy.
            var outputStates = new List <TState> {
                _initialState
            };

            outputStates.AddRange(policy.GetAllowedActions(_initialState).Where(a => a != null)
                                  .SelectMany(a => a[_initialState]));
            // IsModified indicates the last improvement sweep changed at least one action.
            while (policy.IsModified)
            {
                node = Policies.AddAfter(node, policy);
                var nextValue = GetOptimalValue(policy, tolerance, false, outputStates.Distinct().ToArray());
                if (LogProgress)
                {
                    Log?.Info($"{value.Mean}->{nextValue.Mean} at variance {Math.Sqrt(value.Variance - (value.Mean * value.Mean))}->{Math.Sqrt(nextValue.Variance - (nextValue.Mean * nextValue.Mean))} with {policy.Modifications.Length} modifications");
                }
                // Stop when the absolute improvement is within 100x the evaluation
                // tolerance, or the relative change falls below RelativeOptimalTolerance.
                var ratio = value.Mean > 0 ? nextValue.Mean / value.Mean : value.Mean / nextValue.Mean;
                if ((nextValue.Mean - (tolerance * 100) < value.Mean) || Math.Abs(ratio) < 1 + RelativeOptimalTolerance)
                {
                    value = nextValue;
                    break;
                }

                value  = nextValue;
                policy = PolicyIteration(policy);
                if (node.Previous == null)
                {
                    continue;
                }

                // Drop the policy and cached value function from two iterations
                // back to keep memory bounded during long runs.
                ValueFunctions[node.Previous.Value].Clear();
                ValueFunctions.Remove(node.Previous.Value);
                Policies.Remove(node.Previous);
            }

            // Final full evaluation of the converged policy, then copy its actions
            // over the whole state space into a fresh clone for the caller.
            value         = GetOptimalValue(policy, tolerance);
            optimalPolicy = (IDeterministicPolicy <TState>)basePolicy.Clone();
            foreach (var state in AllStateSpace)
            {
                optimalPolicy[state] = policy[state];
            }

            ValueFunctions.Clear();
            Policies.Clear();

            return(value.Mean);
        }
Example #2
0
        /// <summary>
        /// One policy-improvement sweep: returns a clone of <paramref name="policy"/>
        /// whose action at each considered state is replaced by the optimal action,
        /// when one is found. When <c>UseReachableStateSpace</c> is false the sweep
        /// is delegated to <see cref="PartialPolicyIteration"/>; otherwise it runs
        /// over the (lazily cached) reachable state space.
        /// </summary>
        /// <param name="policy">Policy to improve; not mutated.</param>
        /// <returns>Improved clone of the input policy.</returns>
        private IDeterministicPolicy <TState> PolicyIteration(IDeterministicPolicy <TState> policy)
        {
            if (!UseReachableStateSpace)
            {
                return(PartialPolicyIteration(policy));
            }

            // Lazily compute and cache the states reachable from the initial state.
            if (ReachableStates == null)
            {
                ReachableStates = GetReachableStateSpace(policy, _initialState);
            }

            var nextPolicy = (IDeterministicPolicy <TState>)policy.Clone();

            var progress = 0.0;
            var step     = 1.0 / ReachableStates.Count;
            var logStep  = 0.01;

            // Fix: gate progress logging behind LogProgress, consistent with
            // GetOptimalValueViaPolicyIteration and PartialPolicyIteration.
            if (LogProgress)
            {
                Log?.Info("Calculating optimal actions...");
            }
            foreach (var state in ReachableStates.Keys)
            {
                var optimalAction = GetOptimalActionEx(policy, state);
                if (optimalAction != null)
                {
                    nextPolicy[state] = optimalAction;
                }
                // Emit a progress line roughly once per logStep (1%) of the sweep.
                if (LogProgress && (progress / logStep) - Math.Truncate(progress / logStep) < step / logStep)
                {
                    Log?.Info($"{Math.Truncate(progress * 10000) / 100}%");
                }
                progress += step;
            }

            return(nextPolicy);
        }
Example #3
0
        /// <summary>
        /// Policy improvement restricted to states reachable from the initial
        /// state with probability above <c>MinProbability</c>. Expands the state
        /// space breadth-first, tracking for each state an upper bound on its
        /// reach probability, and stops once the entire frontier falls below the
        /// cutoff.
        /// </summary>
        /// <param name="policy">Policy to improve; not mutated.</param>
        /// <returns>Improved clone of the input policy.</returns>
        private IDeterministicPolicy <TState> PartialPolicyIteration(IDeterministicPolicy <TState> policy)
        {
            var nextPolicy = (IDeterministicPolicy <TState>)policy.Clone();
            // dict[s] = best known (maximum over paths) probability of reaching s.
            var dict       = new Dictionary <TState, double> {
                [_initialState] = 1.0
            };
            // All states ever visited, to avoid re-expanding them.
            var allStates = new HashSet <TState> {
                _initialState
            };

            // {0} -> A{0}
            var nextStates = new HashSet <TState>();

            // Fix: only overwrite the initial state's action when an optimal one
            // was found — matches the null handling at every other call site below.
            var initialAction = GetOptimalActionEx(policy, _initialState);
            if (initialAction != null)
            {
                nextPolicy[_initialState] = initialAction;
            }
            // Seed the frontier with every state reachable in one step.
            foreach (var action in policy.GetAllowedActions(_initialState).Where(a => a != null))
            {
                var states = action[_initialState].ToImmutableHashSet();
                foreach (var state in states)
                {
                    dict[state] = action[_initialState, state];
                    var optimalAction = GetOptimalActionEx(policy, state);
                    if (optimalAction == null)
                    {
                        continue;
                    }
                    nextPolicy[state] = optimalAction;
                }

                nextStates.UnionWith(states);
                allStates.UnionWith(states);
            }

            var tol      = MinProbability;
            var maxProba = 1.0;
            var logStep  = 0.3;
            // Maps -log10(maxProba) onto [0, 1] for percentage reporting.
            var scale    = -1 / Math.Log(tol, 10);

            if (LogProgress)
            {
                Log?.Info("Calculating optimal actions...");
            }
            var prevMaxProba = 1.0;

            // Expand frontier layers until no frontier state is reachable with
            // probability above the cutoff.
            while (maxProba > tol)
            {
                maxProba = 0.0;
                var nextNextStates = new HashSet <TState>();
                foreach (var state in nextStates)
                {
                    var optimalAction = GetOptimalActionEx(policy, state);
                    if (optimalAction != null)
                    {
                        nextPolicy[state] = optimalAction;
                    }

                    var stateProba = dict[state];
                    if (stateProba > maxProba)
                    {
                        maxProba = stateProba;
                    }

                    // A state may have no action (e.g. terminal); do not expand it.
                    var action = nextPolicy[state];
                    if (action == null)
                    {
                        continue;
                    }

                    var actionStates = action[state].ToImmutableHashSet();
                    foreach (var nextState in actionStates)
                    {
                        var p = action[state, nextState];

                        // Keep the best (maximum) reach-probability bound per state.
                        if (dict.TryGetValue(nextState, out var current))
                        {
                            dict[nextState] = Math.Max(current, stateProba * p);
                        }
                        else
                        {
                            dict.Add(nextState, stateProba * p);
                        }
                    }

                    nextNextStates.UnionWith(actionStates);
                }

                // Log only when maxProba dropped by at least logStep decades.
                if (Math.Log(prevMaxProba, 10) - Math.Log(maxProba, 10) > logStep && LogProgress)
                {
                    Log?.Info(
                        $"{Math.Truncate(Math.Min(-Math.Log(maxProba, 10) * scale, 1) * 10000) / 100}%");
                    prevMaxProba = maxProba;
                }

                // Next frontier = newly discovered states only.
                nextStates.Clear();
                foreach (var state in nextNextStates)
                {
                    if (allStates.Add(state))
                    {
                        nextStates.Add(state);
                    }
                }
            }

            return(nextPolicy);
        }