Пример #1
0
        // Populates the per-(device, dtype) optimizer state. After the base class
        // has filled in its own entries, the "momentum" hyperparameter is fetched
        // for this dtype, passed through array_ops.identity, and stored under the
        // "momentum" key of the state dictionary for device_dtype.
        //
        // device_dtype : key selecting which state dictionary to update; its DType
        //                is used when fetching the hyperparameter.
        // _apply_state : map from device/dtype key to the named tensors for that key.
        //                The entry for device_dtype must already exist once the base
        //                call returns (it is indexed, not created, here).
        protected override void _prepare_local(DeviceDType device_dtype,
                                               Dictionary <DeviceDType, Dictionary <string, Tensor> > _apply_state)
        {
            base._prepare_local(device_dtype, _apply_state);

            // Resolve the target dictionary first, then build the tensor to store.
            var local_state   = _apply_state[device_dtype];
            var momentum_hyper = _get_hyper("momentum", device_dtype.DType);

            local_state["momentum"] = array_ops.identity(momentum_hyper);
        }
Пример #2
0
        // Populates the per-(device, dtype) optimizer state. The base class runs
        // first; this override then adds five entries to the dictionary for
        // device_dtype:
        //   "neg_lr_t"      : negation of the existing "lr_t" entry
        //   "epsilon"       : args.Epsilon converted to a tensor of this dtype
        //   "rho"           : identity of the "rho" hyperparameter
        //   "momentum"      : identity of the "momentum" hyperparameter
        //   "one_minus_rho" : 1 - rho
        //
        // device_dtype : key selecting which state dictionary to update.
        // _apply_state : map from device/dtype key to named tensors; the entry for
        //                device_dtype and its "lr_t" value must exist after the
        //                base call (both are indexed, not created, here).
        protected override void _prepare_local(DeviceDType device_dtype, Dictionary <DeviceDType, Dictionary <string, Tensor> > _apply_state)
        {
            base._prepare_local(device_dtype, _apply_state);

            var dtype     = device_dtype.DType;
            var rho_hyper = _get_hyper("rho", dtype);
            var rho       = array_ops.identity(rho_hyper);

            // All remaining writes target the same inner dictionary.
            var local_state = _apply_state[device_dtype];

            var learning_rate = local_state["lr_t"];
            local_state["neg_lr_t"] = -learning_rate;

            local_state["epsilon"] = ops.convert_to_tensor(args.Epsilon, dtype: dtype);
            local_state["rho"]     = rho;

            var momentum_hyper = _get_hyper("momentum", dtype);
            local_state["momentum"] = array_ops.identity(momentum_hyper);

            local_state["one_minus_rho"] = 1.0f - rho;
        }
Пример #3
0
        // Populates the per-(device, dtype) optimizer state. The base class runs
        // first; this override then derives the step-dependent quantities and
        // stores them in the dictionary for device_dtype:
        //   "lr"                 : lr_t * sqrt(1 - beta_2^t) / (1 - beta_1^t),
        //                          where t = iterations + 1 cast to the variable dtype
        //   "epsilon"            : epsilon converted to a tensor of the variable dtype
        //   "beta_1_t"/"beta_2_t": identity of the corresponding hyperparameters
        //   "beta_1_power"/"beta_2_power" : beta_i raised to the power t
        //   "one_minus_beta_1_t"/"one_minus_beta_2_t" : 1 - beta_i
        //
        // device_dtype : key selecting which state dictionary to update; its DType
        //                is the dtype every tensor here is built in.
        // apply_state  : map from device/dtype key to named tensors; the entry for
        //                device_dtype and its "lr_t" value must exist after the
        //                base call (both are indexed, not created, here).
        protected override void _prepare_local(DeviceDType device_dtype, Dictionary <DeviceDType, Dictionary <string, Tensor> > apply_state)
        {
            base._prepare_local(device_dtype, apply_state);
            var var_dtype    = device_dtype.DType;
            var local_step   = math_ops.cast(iterations + 1, var_dtype);
            var beta_1_t     = array_ops.identity(_get_hyper("beta_1", var_dtype));
            var beta_2_t     = array_ops.identity(_get_hyper("beta_2", var_dtype));
            var beta_1_power = math_ops.pow(beta_1_t, local_step);
            var beta_2_power = math_ops.pow(beta_2_t, local_step);
            var lr           = apply_state[device_dtype]["lr_t"] * (math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power));

            // update state
            apply_state[device_dtype]["lr"]                 = lr;
            // Convert epsilon in the variable's dtype so it matches the other
            // entries, which are all built from var_dtype.
            apply_state[device_dtype]["epsilon"]            = ops.convert_to_tensor(epsilon, dtype: var_dtype);
            apply_state[device_dtype]["beta_1_t"]           = beta_1_t;
            apply_state[device_dtype]["beta_1_power"]       = beta_1_power;
            apply_state[device_dtype]["one_minus_beta_1_t"] = 1 - beta_1_t;
            apply_state[device_dtype]["beta_2_t"]           = beta_2_t;
            apply_state[device_dtype]["beta_2_power"]       = beta_2_power;
            apply_state[device_dtype]["one_minus_beta_2_t"] = 1 - beta_2_t;
        }