// function VALUE-ITERATION(mdp, ε) returns a utility function
/**
 * The value iteration algorithm for calculating the utility of states.
 *
 * @param mdp
 *            an MDP with states S, actions A(s), transition model P(s' | s, a),
 *            rewards R(s)
 * @param epsilon
 *            the maximum error allowed in the utility of any state
 * @return a vector of utilities for states in S
 */
public IMap<S, double> valueIteration(IMarkovDecisionProcess<S, A> mdp, double epsilon)
{
    // local variables: U, U', vectors of utilities for states in S, initially zero
    IMap<S, double> U = Util.create(mdp.states(), 0D);
    IMap<S, double> Udelta = Util.create(mdp.states(), 0D);
    // δ, the maximum change in the utility of any state in an iteration
    double delta = 0;
    // Note: Just calculate this once for efficiency purposes: ε(1 - γ)/γ
    double minDelta = epsilon * (1 - gamma) / gamma;
    // repeat
    do
    {
        // U <- U'; δ <- 0
        U.PutAll(Udelta);
        delta = 0;
        // for each state s in S do
        foreach (S s in mdp.states())
        {
            // max_{a ∈ A(s)}
            ISet<A> actions = mdp.actions(s);
            // Handle terminal states (i.e. no actions).
            double aMax = 0;
            if (actions.Size() > 0)
            {
                aMax = double.NegativeInfinity;
            }
            foreach (A a in actions)
            {
                // Σ_{s'} P(s' | s, a) U[s']
                double aSum = 0;
                foreach (S sDelta in mdp.states())
                {
                    aSum += mdp.transitionProbability(sDelta, s, a) * U.Get(sDelta);
                }
                if (aSum > aMax)
                {
                    aMax = aSum;
                }
            }
            // U'[s] <- R(s) + γ max_{a ∈ A(s)} Σ_{s'} P(s' | s, a) U[s']
            Udelta.Put(s, mdp.reward(s) + gamma * aMax);
            // if |U'[s] - U[s]| > δ then δ <- |U'[s] - U[s]|
            double aDiff = System.Math.Abs(Udelta.Get(s) - U.Get(s));
            if (aDiff > delta)
            {
                delta = aDiff;
            }
        }
        // until δ < ε(1 - γ)/γ
    } while (delta > minDelta);
    // return U
    return U;
}
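// Hedged usage sketch (not part of the source above). It assumes the 4x3
// cell-world helpers that accompany the AIMA code base; the exact names
// (CellWorldFactory.CreateCellWorldForFig17_1, MDPFactory.createMDPForFigure17_3,
// and a ValueIteration(gamma) constructor) are assumptions and may differ in this port.
CellWorld<double> cw = CellWorldFactory.CreateCellWorldForFig17_1();
IMarkovDecisionProcess<Cell<double>, CellWorldAction> mdp =
    MDPFactory.createMDPForFigure17_3(cw);
// With γ = 1.0 the ε(1 - γ)/γ bound is zero, so iteration runs until U stops changing.
ValueIteration<Cell<double>, CellWorldAction> vi =
    new ValueIteration<Cell<double>, CellWorldAction>(1.0);
IMap<Cell<double>, double> U = vi.valueIteration(mdp, 0.0001);
// For R(s) = -0.04, the book's Figure 17.3 gives U(1, 1) ≈ 0.705.
System.Console.WriteLine(U.Get(cw.GetCellAt(1, 1)));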
// function POLICY-EVALUATION(π_i, U, mdp) returns a utility function
/**
 * Policy evaluation: given a policy π_i and a utility estimate U, return the
 * utility of each state if π_i were to be executed, computed as k iterations
 * of simplified value iteration (simplified because the policy is fixed, so
 * no max over actions is needed).
 *
 * @param pi_i
 *            a policy vector indexed by state
 * @param U
 *            a vector of utilities for states in S
 * @param mdp
 *            an MDP with states S, actions A(s), transition model P(s' | s, a)
 * @return the next utility estimate U_{i+1}
 */
public IMap<S, double> evaluate(IMap<S, A> pi_i, IMap<S, double> U, IMarkovDecisionProcess<S, A> mdp)
{
    IMap<S, double> U_i = CollectionFactory.CreateMap<S, double>(U);
    IMap<S, double> U_ip1 = CollectionFactory.CreateMap<S, double>(U);
    // repeat k times to produce the next utility estimate
    for (int i = 0; i < k; ++i)
    {
        // U_{i+1}(s) <- R(s) + γ Σ_{s'} P(s' | s, π_i(s)) U_i(s')
        foreach (S s in U.GetKeys())
        {
            A ap_i = pi_i.Get(s);
            double aSum = 0;
            // Handle terminal states (i.e. no actions).
            if (null != ap_i)
            {
                foreach (S sDelta in U.GetKeys())
                {
                    aSum += mdp.transitionProbability(sDelta, s, ap_i) * U_i.Get(sDelta);
                }
            }
            U_ip1.Put(s, mdp.reward(s) + gamma * aSum);
        }
        U_i.PutAll(U_ip1);
    }
    return U_ip1;
}
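// Hedged illustration (not from the source): the fixed-policy Bellman update that
// evaluate() applies k times, written out for a toy two-state chain so the
// arithmetic is visible. All names below are hypothetical.
public static class PolicyEvaluationSketch
{
    public static void Main()
    {
        // State 0 is nonterminal; under the fixed policy it reaches terminal
        // state 1 with probability 0.9 and stays put with probability 0.1.
        double[] R = { -0.04, 1.0 };
        double gamma = 0.9;
        double[] U = { 0.0, 0.0 };
        for (int i = 0; i < 20; ++i)     // k = 20 sweeps
        {
            double U0 = R[0] + gamma * (0.9 * U[1] + 0.1 * U[0]);
            double U1 = R[1];            // terminal state: no successor term
            U[0] = U0;
            U[1] = U1;
        }
        // With the policy fixed there is no max, so each sweep is a linear update;
        // U[0] approaches the fixed point (R[0] + 0.81) / 0.91 ≈ 0.846.
        System.Console.WriteLine($"U(0) = {U[0]:F4}, U(1) = {U[1]:F4}");
    }
}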
public void testMDPTransitionModel()
{
    // From (1,1), Up: 0.8 to the intended cell (1,2); 0.1 for each perpendicular
    // slip (left bumps the wall, so the agent stays at (1,1); right reaches (2,1)).
    Assert.AreEqual(0.8, mdp.transitionProbability(cw.GetCellAt(1, 2), cw.GetCellAt(1, 1), CellWorldAction.Up));
    Assert.AreEqual(0.1, mdp.transitionProbability(cw.GetCellAt(1, 1), cw.GetCellAt(1, 1), CellWorldAction.Up));
    Assert.AreEqual(0.1, mdp.transitionProbability(cw.GetCellAt(2, 1), cw.GetCellAt(1, 1), CellWorldAction.Up));
    Assert.AreEqual(0.0, mdp.transitionProbability(cw.GetCellAt(1, 3), cw.GetCellAt(1, 1), CellWorldAction.Up));

    // From (1,1), Down: the intended move and the left slip both bump walls
    // (0.8 + 0.1 = 0.9 to stay at (1,1)); the right slip reaches (2,1).
    Assert.AreEqual(0.9, mdp.transitionProbability(cw.GetCellAt(1, 1), cw.GetCellAt(1, 1), CellWorldAction.Down));
    Assert.AreEqual(0.1, mdp.transitionProbability(cw.GetCellAt(2, 1), cw.GetCellAt(1, 1), CellWorldAction.Down));
    Assert.AreEqual(0.0, mdp.transitionProbability(cw.GetCellAt(3, 1), cw.GetCellAt(1, 1), CellWorldAction.Down));
    Assert.AreEqual(0.0, mdp.transitionProbability(cw.GetCellAt(1, 2), cw.GetCellAt(1, 1), CellWorldAction.Down));

    // From (1,1), Left: the intended move and the down slip both bump walls
    // (0.9 to stay at (1,1)); the up slip reaches (1,2).
    Assert.AreEqual(0.9, mdp.transitionProbability(cw.GetCellAt(1, 1), cw.GetCellAt(1, 1), CellWorldAction.Left));
    Assert.AreEqual(0.0, mdp.transitionProbability(cw.GetCellAt(2, 1), cw.GetCellAt(1, 1), CellWorldAction.Left));
    Assert.AreEqual(0.0, mdp.transitionProbability(cw.GetCellAt(3, 1), cw.GetCellAt(1, 1), CellWorldAction.Left));
    Assert.AreEqual(0.1, mdp.transitionProbability(cw.GetCellAt(1, 2), cw.GetCellAt(1, 1), CellWorldAction.Left));

    // From (1,1), Right: 0.8 to (2,1); the down slip bumps a wall (0.1 to stay
    // at (1,1)); the up slip reaches (1,2).
    Assert.AreEqual(0.8, mdp.transitionProbability(cw.GetCellAt(2, 1), cw.GetCellAt(1, 1), CellWorldAction.Right));
    Assert.AreEqual(0.1, mdp.transitionProbability(cw.GetCellAt(1, 1), cw.GetCellAt(1, 1), CellWorldAction.Right));
    Assert.AreEqual(0.1, mdp.transitionProbability(cw.GetCellAt(1, 2), cw.GetCellAt(1, 1), CellWorldAction.Right));
    Assert.AreEqual(0.0, mdp.transitionProbability(cw.GetCellAt(1, 3), cw.GetCellAt(1, 1), CellWorldAction.Right));
}
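// Hedged companion check (not in the source): for every state and applicable
// action, the outgoing transition probabilities should sum to 1. It reuses only
// the MDP interface exercised above; the state type Cell<double> is an assumption
// about how the cell world is parameterized in this port.
public void testTransitionProbabilitiesSumToOne()
{
    foreach (Cell<double> s in mdp.states())
    {
        foreach (CellWorldAction a in mdp.actions(s))
        {
            double total = 0;
            foreach (Cell<double> sDelta in mdp.states())
            {
                total += mdp.transitionProbability(sDelta, s, a);
            }
            Assert.AreEqual(1.0, total, 0.000001);
        }
    }
}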
// function POLICY-ITERATION(mdp) returns a policy
/**
 * The policy iteration algorithm for calculating an optimal policy.
 *
 * @param mdp
 *            an MDP with states S, actions A(s), transition model P(s' | s, a)
 * @return an optimal policy
 */
public IPolicy<S, A> policyIteration(IMarkovDecisionProcess<S, A> mdp)
{
    // local variables: U, a vector of utilities for states in S, initially zero
    IMap<S, double> U = Util.create(mdp.states(), 0D);
    // π, a policy vector indexed by state, initially random
    IMap<S, A> pi = initialPolicyVector(mdp);
    bool unchanged;
    // repeat
    do
    {
        // U <- POLICY-EVALUATION(π, U, mdp)
        U = policyEvaluation.evaluate(pi, U, mdp);
        // unchanged? <- true
        unchanged = true;
        // for each state s in S do
        foreach (S s in mdp.states())
        {
            // calculate max_{a ∈ A(s)} Σ_{s'} P(s' | s, a) U[s']
            double aMax = double.NegativeInfinity, piVal = 0;
            A aArgmax = pi.Get(s);
            foreach (A a in mdp.actions(s))
            {
                double aSum = 0;
                foreach (S sDelta in mdp.states())
                {
                    aSum += mdp.transitionProbability(sDelta, s, a) * U.Get(sDelta);
                }
                if (aSum > aMax)
                {
                    aMax = aSum;
                    aArgmax = a;
                }
                // track Σ_{s'} P(s' | s, π[s]) U[s']
                if (a.Equals(pi.Get(s)))
                {
                    piVal = aSum;
                }
            }
            // if max_{a ∈ A(s)} Σ_{s'} P(s' | s, a) U[s']
            //    > Σ_{s'} P(s' | s, π[s]) U[s'] then do
            if (aMax > piVal)
            {
                // π[s] <- argmax_{a ∈ A(s)} Σ_{s'} P(s' | s, a) U[s']
                pi.Put(s, aArgmax);
                // unchanged? <- false
                unchanged = false;
            }
        }
        // until unchanged?
    } while (!unchanged);
    // return π
    return new LookupPolicy<S, A>(pi);
}
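// Hedged usage sketch (not part of the source above). The factory helpers, the
// PolicyIteration/PolicyEvaluation constructors, and the policy.action accessor
// mirror the AIMA Java code base and are assumptions; this port may name them
// differently.
CellWorld<double> cw = CellWorldFactory.CreateCellWorldForFig17_1();
IMarkovDecisionProcess<Cell<double>, CellWorldAction> mdp =
    MDPFactory.createMDPForFigure17_3(cw);
// 50 evaluation sweeps per improvement step, discount γ = 1.0
PolicyIteration<Cell<double>, CellWorldAction> pi =
    new PolicyIteration<Cell<double>, CellWorldAction>(
        new PolicyEvaluation<Cell<double>, CellWorldAction>(50, 1.0));
IPolicy<Cell<double>, CellWorldAction> policy = pi.policyIteration(mdp);
// For the standard 4x3 world with R(s) = -0.04, the optimal action at (1,1) is Up.
System.Console.WriteLine(policy.action(cw.GetCellAt(1, 1)));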