/// <summary>
/// Returns an action from a state
/// </summary>
/// <param name="state">state size must be equal to NumberOfStates</param>
/// <returns>the index of the selected action</returns>
public int Act(double[] state)
{
    Tembo.Assert(state.Length == NumberOfStates, $"Current state({state.Length}) not equal to NS({NumberOfStates})");
    var a = 0;

    // convert to a Matrix column vector
    var s = new Matrix(NumberOfStates, 1);
    s.Set(state);

    // epsilon-greedy policy: explore with probability Epsilon, otherwise act greedily
    if (Tembo.Random() < Options.Epsilon)
    {
        a = Tembo.RandomInt(0, NumberOfActions);
    }
    else
    {
        // greedy w.r.t. the Q function
        var amat = ForwardQ(Network, s, false);
        a = Tembo.Maxi(amat.W); // returns index of the argmax action
    }

    // shift state memory: remember the previous and the current (state, action) pair
    this.s0 = this.s1;
    this.a0 = this.a1;
    this.s1 = s;
    this.a1 = a;
    return a;
}
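// Usage sketch (illustrative only, not part of the library): one interaction step of a
// hypothetical environment loop. "agent" stands for an instance of this class and "env"
// for a user-supplied environment with Observe()/Step(); those names, and the Learn(reward)
// call, are assumptions here rather than confirmed Tembo APIs.
//
//     double[] observation = env.Observe();   // length must equal agent.NumberOfStates
//     int action = agent.Act(observation);    // epsilon-greedy: random with prob. Epsilon, else argmax Q
//     double reward = env.Step(action);
//     agent.Learn(reward);                    // assumed companion method that records r0 and triggers learning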
/// <summary>
/// OOP advantages adopted during translation...
/// </summary>
/// <param name="experience">See Experience</param>
/// <returns>the (clamped) TD error for this update</returns>
private double LearnFromExperience(Experience experience /*Matrix s0, int a0, double r0, Matrix s1, int a1*/)
{
    // want: Q(s,a) = r + gamma * max_a' Q(s',a')

    // compute the target Q value from the next state s1
    var tmat = ForwardQ(Network, s1, false);
    var qmax = r0 + Options.Gamma * tmat.W[Tembo.Maxi(tmat.W)];

    // now predict Q(s0, .) and take the TD error on the action actually taken
    var pred = ForwardQ(Network, s0, true);
    var tderror = pred.W[a0] - qmax;

    // clamp the TD error to robustify against outliers (acts like the linear region of a Huber loss)
    var clamp = Options.ErrorClamp;
    if (Math.Abs(tderror) > clamp)
    {
        if (tderror > clamp)
        {
            tderror = clamp;
        }
        if (tderror < -clamp)
        {
            tderror = -clamp;
        }
    }

    pred.DW[a0] = tderror;
    LastGraph.Backward(); // compute gradients on network parameters

    // update the network weights
    Tembo.UpdateNetwork(Network, Options.Alpha);
    return tderror;
}
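// Update rule implemented above (standard one-step Q-learning):
//
//     target  = r0 + gamma * max_a' Q(s1, a')
//     tdError = Q(s0, a0) - target
//
// tdError is clamped to [-ErrorClamp, +ErrorClamp] before backpropagation, which bounds
// the gradient on the chosen action's output.
//
// Replay sketch (illustrative assumption about how this private method might be driven;
// "replayMemory" and "replaySteps" are hypothetical names, not confirmed fields of this class):
//
//     var replaySteps = 10;
//     for (var k = 0; k < replaySteps; k++)
//     {
//         var e = replayMemory[Tembo.RandomInt(0, replayMemory.Count)];
//         LearnFromExperience(e);
//     }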