/// <summary>
        /// Performs one update step on <paramref name="weight"/> using <paramref name="grad"/>
        /// and the per-index accumulator stored in <c>_History</c>.
        /// </summary>
        /// <remarks>
        /// The arithmetic mirrors the reference formulation:
        /// <code>
        /// history += grad * grad;
        /// weight  -= (grad / sqrt(history + eps) + weight * wd) * lr;
        /// </code>
        /// The optional <c>rescale_grad</c> and <c>clip_gradient</c> parameters are applied to the
        /// gradient before the accumulator is touched.
        /// </remarks>
        /// <param name="index">Key identifying the weight; used to look up state, learning rate and weight decay.</param>
        /// <param name="weight">Weight array that the computed step is subtracted from.</param>
        /// <param name="grad">Gradient array for <paramref name="weight"/>.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="weight"/> or <paramref name="grad"/> is <c>null</c>.
        /// </exception>
        public override void Update(int index, NDArray weight, NDArray grad)
        {
            if (weight == null)
            {
                throw new ArgumentNullException(nameof(weight));
            }

            if (grad == null)
            {
                throw new ArgumentNullException(nameof(grad));
            }

            // State is created on first use of an index.
            var hasState = this._History.ContainsKey(index);
            if (!hasState)
            {
                this.CreateState(index, weight);
            }

            // Hyper-parameters are read before UpdateCount so they reflect the count
            // as it was on entry to this call.
            var epsilon      = float.Parse(this.Params["eps"]);
            var learningRate = this.GetLearningRate(index);
            var weightDecay  = this.GetWeightDecay(index);

            this.UpdateCount(index);

            if (this.Params.ContainsKey("rescale_grad"))
            {
                var rescale = float.Parse(this.Params["rescale_grad"]);
                grad *= rescale;
            }

            if (this.Params.ContainsKey("clip_gradient"))
            {
                var clipLimit = float.Parse(this.Params["clip_gradient"]);
                Clip(ref grad, clipLimit);
            }

            var accumulator = this._History[index];

            using (var squaredGrad = grad * grad)
            {
                // history += grad * grad
                accumulator.Add(squaredGrad);

                // weight -= (grad / sqrt(history + eps) + weight * wd) * lr
                // Every intermediate array is disposed as soon as the step has been applied.
                using (var shifted = accumulator + epsilon)
                using (var denominator = Sqrt(shifted))
                using (var decayTerm = weight * weightDecay)
                using (var scaledGrad = grad / denominator)
                using (var direction = scaledGrad + decayTerm)
                using (var step = direction * learningRate)
                {
                    weight.Subtract(step);
                }
            }
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Performs one update step on <paramref name="weight"/> using <paramref name="grad"/>
        /// and the two per-index running averages stored in <c>_AccG</c> and <c>_AccDelta</c>.
        /// </summary>
        /// <remarks>
        /// The arithmetic follows the reference formulation:
        /// <code>
        /// acc_g     *= rho;
        /// acc_g     += grad * grad * (1 - rho);
        /// delta      = sqrt(acc_delta + epsilon) / sqrt(acc_g + epsilon) * grad;
        /// acc_delta *= rho;
        /// acc_delta += delta * delta * (1 - rho);
        /// weight    *= 1 - wd;
        /// weight    -= delta;
        /// </code>
        /// The optional <c>rescale_grad</c> and <c>clip_gradient</c> parameters are applied to the
        /// gradient before either running average is touched.
        /// </remarks>
        /// <param name="index">Key identifying the weight; used to look up state and weight decay.</param>
        /// <param name="weight">Weight array that is decayed and then has the computed delta subtracted.</param>
        /// <param name="grad">Gradient array for <paramref name="weight"/>.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="weight"/> or <paramref name="grad"/> is <c>null</c>.
        /// </exception>
        public override void Update(int index, NDArray weight, NDArray grad)
        {
            if (weight == null)
            {
                throw new ArgumentNullException(nameof(weight));
            }
            if (grad == null)
            {
                throw new ArgumentNullException(nameof(grad));
            }

            // State is created on first use of an index.
            if (!this._AccG.ContainsKey(index))
            {
                this.CreateState(index, weight);
            }

            var rho     = float.Parse(this.Params["rho"]);
            var epsilon = float.Parse(this.Params["epsilon"]);
            var wd      = this.GetWeightDecay(index);

            this.UpdateCount(index);

            if (this.Params.ContainsKey("rescale_grad"))
            {
                grad *= float.Parse(this.Params["rescale_grad"]);
            }

            if (this.Params.ContainsKey("clip_gradient"))
            {
                Clip(ref grad, float.Parse(this.Params["clip_gradient"]));
            }

            var accG     = this._AccG[index];
            var accDelta = this._AccDelta[index];

            // Both running averages share the same complement coefficient (1 - rho).
            var oneMinusRho = 1.0f - rho;

            using (var gradSquared = grad * grad)
            using (var gradContribution = gradSquared * oneMinusRho)
            {
                // acc_g = rho * acc_g + (1 - rho) * grad^2
                accG.Multiply(rho);
                accG.Add(gradContribution);

                // delta = sqrt(acc_delta + epsilon) / sqrt(acc_g + epsilon) * grad
                // acc_delta is read here before it is updated below, so delta is based on
                // the previous running average of squared deltas.
                using (var deltaShifted = accDelta + epsilon)
                using (var gradShifted = accG + epsilon)
                using (var rmsDelta = Sqrt(deltaShifted))
                using (var rmsGrad = Sqrt(gradShifted))
                using (var ratio = rmsDelta / rmsGrad)
                using (var delta = ratio * grad)
                using (var deltaSquared = delta * delta)
                using (var deltaContribution = deltaSquared * oneMinusRho)
                {
                    // acc_delta = rho * acc_delta + (1 - rho) * delta^2
                    // The averaging coefficient is rho, not the weight decay.
                    accDelta.Multiply(rho);
                    accDelta.Add(deltaContribution);

                    // Weight decay is applied multiplicatively, then the step is taken.
                    weight.Multiply(1.0f - wd);
                    weight.Subtract(delta);
                }
            }
        }