/// <summary>
/// Fills the per-(device, dtype) apply-state for this optimizer.
/// The base implementation runs first; afterwards the "momentum"
/// hyper-parameter, read for <paramref name="device_dtype"/>'s dtype,
/// is wrapped in an identity op and stored under the "momentum" key.
/// </summary>
/// <param name="device_dtype">Key selecting the state dictionary to populate; its DType is used to read the hyper-parameter.</param>
/// <param name="_apply_state">Map from device/dtype key to named state tensors; mutated in place.</param>
protected override void _prepare_local(DeviceDType device_dtype,
    Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
{
    base._prepare_local(device_dtype, _apply_state);

    var slot = _apply_state[device_dtype];
    var momentumHyper = _get_hyper("momentum", device_dtype.DType);
    slot["momentum"] = array_ops.identity(momentumHyper);
}
/// <summary>
/// Fills the per-(device, dtype) apply-state for this optimizer.
/// After the base implementation runs, this stores the following entries:
/// <list type="bullet">
/// <item>the negated learning rate ("neg_lr_t");</item>
/// <item>epsilon as a tensor of the variable dtype;</item>
/// <item>the "rho" and "momentum" hyper-parameters;</item>
/// <item>the precomputed "one_minus_rho".</item>
/// </list>
/// </summary>
/// <param name="device_dtype">Key selecting the state dictionary to populate; its DType is used for hyper-parameters and constants.</param>
/// <param name="_apply_state">Map from device/dtype key to named state tensors; mutated in place. Expects "lr_t" to already be present (presumably set by the base class — confirm).</param>
protected override void _prepare_local(DeviceDType device_dtype,
    Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
{
    base._prepare_local(device_dtype, _apply_state);

    var dtype = device_dtype.DType;
    var rho = array_ops.identity(_get_hyper("rho", dtype));

    var slot = _apply_state[device_dtype];
    slot["neg_lr_t"] = -slot["lr_t"];
    slot["epsilon"] = ops.convert_to_tensor(args.Epsilon, dtype: dtype);
    slot["rho"] = rho;
    slot["momentum"] = array_ops.identity(_get_hyper("momentum", dtype));
    // Precomputed so the per-variable update step does not rebuild it.
    slot["one_minus_rho"] = 1.0f - rho;
}
/// <summary>
/// Fills the per-(device, dtype) apply-state for this optimizer.
/// After the base implementation runs, this computes the step-dependent
/// terms beta_1^t and beta_2^t, where t = iterations + 1. It then derives
/// the scaled learning rate lr = lr_t * sqrt(1 - beta_2^t) / (1 - beta_1^t)
/// and stores it together with the beta terms and epsilon.
/// </summary>
/// <param name="device_dtype">Key selecting the state dictionary to populate; its DType is used for every cast and constant.</param>
/// <param name="apply_state">Map from device/dtype key to named state tensors; mutated in place. Expects "lr_t" to already be present (presumably set by the base class — confirm).</param>
protected override void _prepare_local(DeviceDType device_dtype,
    Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
{
    base._prepare_local(device_dtype, apply_state);

    var var_dtype = device_dtype.DType;
    var local_step = math_ops.cast(iterations + 1, var_dtype);
    var beta_1_t = array_ops.identity(_get_hyper("beta_1", var_dtype));
    var beta_2_t = array_ops.identity(_get_hyper("beta_2", var_dtype));
    var beta_1_power = math_ops.pow(beta_1_t, local_step);
    var beta_2_power = math_ops.pow(beta_2_t, local_step);

    var state = apply_state[device_dtype];

    // Bias-corrected step size: lr_t * sqrt(1 - beta_2^t) / (1 - beta_1^t).
    var lr = state["lr_t"] * (math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power));

    // Update state.
    state["lr"] = lr;
    // Pin epsilon to the variable dtype. Without an explicit dtype it would be
    // inferred from the C# value, which can mismatch non-default var dtypes
    // (e.g. float64) in the later arithmetic.
    state["epsilon"] = ops.convert_to_tensor(epsilon, dtype: var_dtype);
    state["beta_1_t"] = beta_1_t;
    state["beta_1_power"] = beta_1_power;
    state["one_minus_beta_1_t"] = 1 - beta_1_t;
    state["beta_2_t"] = beta_2_t;
    state["beta_2_power"] = beta_2_power;
    state["one_minus_beta_2_t"] = 1 - beta_2_t;
}