示例#1
0
文件: DQN.cs 项目: mbithy/TemboRL
        /// <summary>
        /// Picks an action for <paramref name="state"/> with an epsilon-greedy policy
        /// and shifts the agent's remembered (state, action) pairs.
        /// </summary>
        /// <param name="state">Observation vector; its length must equal NumberOfStates.</param>
        /// <returns>Index of the selected action.</returns>
        public int Act(double[] state)
        {
            Tembo.Assert(state.Length == NumberOfStates, $"Current state({state.Length}) not equal to NS({NumberOfStates})");

            // Wrap the raw observation in a NumberOfStates x 1 column vector.
            var stateColumn = new Matrix(NumberOfStates, 1);
            stateColumn.Set(state);

            // With probability Epsilon take a random action (explore);
            // otherwise take the argmax of the Q-network output (exploit).
            bool explore = Tembo.Random() < Options.Epsilon;
            int action = explore
                ? Tembo.RandomInt(0, NumberOfActions)
                : Tembo.Maxi(ForwardQ(Network, stateColumn, false).W);

            // The previous "current" pair becomes the "previous" pair; the new one becomes current.
            this.s0 = this.s1;
            this.a0 = this.a1;
            this.s1 = stateColumn;
            this.a1 = action;

            return action;
        }
示例#2
0
文件: DQN.cs 项目: mbithy/TemboRL
        /// <summary>
        /// Runs one Q-learning step: computes the temporal-difference error for the taken
        /// action, clips it to +/- Options.ErrorClamp, back-propagates it and updates the network.
        /// </summary>
        /// <param name="experience">See Experience</param>
        /// <returns>The clipped TD error that was back-propagated.</returns>
        private double LearnFromExperience(Experience experience)
        {
            // NOTE(review): `experience` is never read. The update below uses the class members
            // s0, a0, r0 and s1 (s0/a0/s1 are the ones Act() shifts), which were the parameters of the
            // older signature (Matrix s0, int a0, double r0, Matrix s1, int a1). If this method is meant
            // to replay stored experiences, those values should come from `experience` — confirm.

            // Target value: Q(s,a) should equal r + gamma * max_a' Q(s',a').
            var nextQ = ForwardQ(Network, s1, false);
            var target = r0 + Options.Gamma * nextQ.W[Tembo.Maxi(nextQ.W)];

            // Prediction for the previous state. This is the last ForwardQ call before
            // LastGraph.Backward(), and the only one made with `true` — keep that ordering.
            var predicted = ForwardQ(Network, s0, true);
            var tdError = predicted.W[a0] - target;
            var limit = Options.ErrorClamp;

            // Huber-style robustification: clip the error into [-limit, limit].
            if (Math.Abs(tdError) > limit)
            {
                tdError = tdError > limit ? limit : tdError;
                tdError = tdError < -limit ? -limit : tdError;
            }

            // Only the output of the action actually taken receives a gradient.
            predicted.DW[a0] = tdError;
            LastGraph.Backward(); // compute gradients on net params
            Tembo.UpdateNetwork(Network, Options.Alpha);

            return tdError;
        }