예제 #1
0
        /// <summary>
        /// Applies one optimizer step to <paramref name="weight"/> using <paramref name="grad"/>.
        /// The handle of <paramref name="weight"/> is passed both as the first input and as the
        /// single output of the native invocation.
        /// </summary>
        /// <param name="index">Key used to look up the per-index state, learning rate and weight decay.</param>
        /// <param name="weight">Array whose handle is used as input 0 and as the output.</param>
        /// <param name="grad">Array whose handle is used as input 1.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="weight"/> or <paramref name="grad"/> is <c>null</c>.
        /// </exception>
        public override void Update(int index, NDArray weight, NDArray grad)
        {
            if (weight == null)
            {
                throw new ArgumentNullException(nameof(weight));
            }
            if (grad == null)
            {
                throw new ArgumentNullException(nameof(grad));
            }

            // First time this index is seen: let CreateState register an entry for it.
            var alreadyTracked = this._States.ContainsKey(index);
            if (!alreadyTracked)
            {
                this.CreateState(index, weight);
            }

            // Refresh the per-index hyper-parameters (culture-invariant text) before the
            // key/value lists are fetched below.
            var learningRate = this.GetLearningRate(index);
            this.Params["lr"] = learningRate.ToString(CultureInfo.InvariantCulture);
            var weightDecay = this.GetWeightDecay(index);
            this.Params["wd"] = weightDecay.ToString(CultureInfo.InvariantCulture);
            this.UpdateCount(index);

            var paramKeys = this.GetParamKeys_();
            var paramValues = this.GetParamValues_();
            Logging.CHECK_EQ(paramKeys.Length, paramValues.Length);

            // Three slots are always allocated; the third is only filled when a state exists.
            var inputHandles = new NDArrayHandle[3];
            inputHandles[0] = weight.GetHandle();
            inputHandles[1] = grad.GetHandle();

            var outputCount = 1;
            var outputHandles = new[] { weight.GetHandle() };

            // NOTE(review): whatever MXImperativeInvoke returns is discarded in both calls
            // below; if it reports failure through a status code, that is not checked here.
            var state = this._States[index];
            if (state == null)
            {
                // No state stored for this index: two-input invocation via _UpdateHandle.
                NativeMethods.MXImperativeInvoke(this._UpdateHandle,
                                                 2,
                                                 inputHandles,
                                                 ref outputCount,
                                                 ref outputHandles,
                                                 paramKeys.Length,
                                                 paramKeys,
                                                 paramValues);
                return;
            }

            // State present: pass it as the third input and invoke via _MomUpdateHandle.
            inputHandles[2] = state.GetHandle();
            NativeMethods.MXImperativeInvoke(this._MomUpdateHandle,
                                             3,
                                             inputHandles,
                                             ref outputCount,
                                             ref outputHandles,
                                             paramKeys.Length,
                                             paramKeys,
                                             paramValues);
        }
예제 #2
0
        /// <summary>
        /// Applies one optimizer step to <paramref name="weight"/> using <paramref name="grad"/>
        /// together with the per-index entries of <c>_Mean</c> and <c>_Var</c>.
        /// The handle of <paramref name="weight"/> is passed both as the first input and as the
        /// single output of the native invocation.
        /// </summary>
        /// <param name="index">Key used to look up the per-index state, learning rate and weight decay.</param>
        /// <param name="weight">Array whose handle is used as input 0 and as the output.</param>
        /// <param name="grad">Array whose handle is used as input 1.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="weight"/> or <paramref name="grad"/> is <c>null</c>.
        /// </exception>
        public override void Update(int index, NDArray weight, NDArray grad)
        {
            if (weight == null)
            {
                throw new ArgumentNullException(nameof(weight));
            }
            if (grad == null)
            {
                throw new ArgumentNullException(nameof(grad));
            }

            // First time this index is seen: let CreateState register the state for it.
            // Only _Mean is probed; _Var is presumably populated alongside it by CreateState
            // (not visible here, confirm).
            var alreadyTracked = this._Mean.ContainsKey(index);
            if (!alreadyTracked)
            {
                this.CreateState(index, weight);
            }

            // Refresh the per-index hyper-parameters (culture-invariant text) before the
            // key/value lists are fetched below.
            var learningRate = this.GetLearningRate(index);
            this.Params["lr"] = learningRate.ToString(CultureInfo.InvariantCulture);
            var weightDecay = this.GetWeightDecay(index);
            this.Params["wd"] = weightDecay.ToString(CultureInfo.InvariantCulture);
            this.UpdateCount(index);

            var paramKeys = this.GetParamKeys_();
            var paramValues = this.GetParamValues_();
            Logging.CHECK_EQ(paramKeys.Length, paramValues.Length);

            // NOTE(review): a host-side rescaling of "lr" by sqrt(1 - beta2^t) / (1 - beta1^t)
            // was previously sketched here as commented-out code and is not applied; "lr" is
            // passed exactly as returned by GetLearningRate.

            // Input order: weight, gradient, mean state, variance state.
            var inputHandles = new NDArrayHandle[]
            {
                weight.GetHandle(),
                grad.GetHandle(),
                this._Mean[index].GetHandle(),
                this._Var[index].GetHandle()
            };

            var outputCount = 1;
            var outputHandles = new[] { weight.GetHandle() };

            // NOTE(review): whatever MXImperativeInvoke returns is discarded; if it reports
            // failure through a status code, that is not checked here.
            NativeMethods.MXImperativeInvoke(this._UpdateHandle,
                                             4,
                                             inputHandles,
                                             ref outputCount,
                                             ref outputHandles,
                                             paramKeys.Length,
                                             paramKeys,
                                             paramValues);
        }
예제 #3
0
        /// <summary>
        /// Appends the handle of <paramref name="ndarray"/> to this operator's input list.
        /// </summary>
        /// <param name="ndarray">Array whose handle is added to <c>_InputNdarrays</c>.</param>
        /// <returns>This same instance, which allows calls to be chained.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="ndarray"/> is <c>null</c>.</exception>
        public Operator PushInput(NDArray ndarray)
        {
            if (ndarray == null)
            {
                throw new ArgumentNullException(nameof(ndarray));
            }

            var handle = ndarray.GetHandle();
            this._InputNdarrays.Add(handle);

            return this;
        }