/// <summary>
/// Computes the gradient of <paramref name="target"/> with respect to a single
/// variable by delegating to the overload that accepts a list of variables.
/// </summary>
/// <param name="target">Tensor being differentiated.</param>
/// <param name="source">The single variable to differentiate with respect to.</param>
/// <returns>The first entry of the list overload's result.</returns>
public Tensor gradient(Tensor target, ResourceVariable source)
        {
            // Wrap the lone variable so the multi-source overload can be reused.
            var sources = new List<IVariableV1> { source };
            var grads = gradient(target, sources);

            // Only one source was passed in; presumably the result holds one entry per source.
            return grads[0];
        }
Example #2
0
        /// <summary>
        /// Applies a plain gradient-descent update to <paramref name="var"/>, using the
        /// "lr_t" entry stored in <paramref name="_apply_state"/> under the key whose
        /// Device and DType match the variable's device and base dtype.
        /// </summary>
        /// <param name="var">Variable to update.</param>
        /// <param name="grad">Gradient for <paramref name="var"/>.</param>
        /// <param name="_apply_state">Per-(device, dtype) coefficient tables.</param>
        /// <returns>The operation returned by <c>gen_training_ops.resource_apply_gradient_descent</c>.</returns>
        /// <exception cref="NotImplementedException">Momentum is enabled.</exception>
        /// <exception cref="InvalidOperationException">
        /// No entry in <paramref name="_apply_state"/> matches the variable's device and base dtype.
        /// </exception>
        protected override Operation _resource_apply_dense(ResourceVariable var, EagerTensor grad, Dictionary <DeviceDType, Dictionary <string, Tensor> > _apply_state)
        {
            if (_momentum)
            {
                throw new NotImplementedException("_resource_apply_dense");
            }

            var base_dtype = var.dtype.as_base_dtype();

            // Select the matching coefficient table directly (one pass, no second lookup).
            // A missing entry then produces a descriptive error instead of an opaque
            // KeyNotFoundException/ArgumentNullException from indexing with a default key.
            var coefficients = _apply_state
                .Where(kv => kv.Key.Device == var.Device && kv.Key.DType == base_dtype)
                .Select(kv => kv.Value)
                .FirstOrDefault();

            if (coefficients is null)
            {
                throw new InvalidOperationException(
                    $"No apply state found for device '{var.Device}' and dtype '{base_dtype}'.");
            }

            return gen_training_ops.resource_apply_gradient_descent(var.Handle as EagerTensor,
                                                                     coefficients["lr_t"] as EagerTensor,
                                                                     grad,
                                                                     use_locking: _use_locking);
        }
Example #3
0
        /// <summary>
        /// Returns the variables this tape is watching, as reported by
        /// <c>c_api.TFE_TapeWatchedVariables</c>.
        /// </summary>
        /// <returns>One <see cref="ResourceVariable"/> per native handle in the returned array.</returns>
        public unsafe ResourceVariable[] watched_variables()
        {
            BindingArray watched = c_api.TFE_TapeWatchedVariables(_handle);
            var output = new ResourceVariable[watched.length];

            // The native buffer is read as a contiguous run of IntPtr handles.
            var handles = (IntPtr*)watched.array;
            for (int idx = 0; idx < watched.length; idx++)
            {
                var variableHandle = handles[idx];
                output[idx] = new ResourceVariable(variableHandle, c_api.ResourceVariable_Handle(variableHandle));
            }

            return output;
        }
Example #4
0
        /// <summary>
        /// Builds an Adam update for <paramref name="var"/> from its "m" and "v" slots,
        /// the beta power accumulators, and the optimizer's hyper-parameter tensors.
        /// </summary>
        /// <param name="grad">Gradient for <paramref name="var"/>.</param>
        /// <param name="var">Variable to update.</param>
        /// <returns>The op of the tensor produced by <c>gen_training_ops.apply_adam</c>.</returns>
        public override Operation _apply_dense(Tensor grad, ResourceVariable var)
        {
            var first_moment = get_slot(var, "m");
            var second_moment = get_slot(var, "v");
            var (beta1_power, beta2_power) = _get_beta_accumulators();

            var var_handle = var.Handle;
            var m_handle = first_moment.Handle;
            var v_handle = second_moment.Handle;

            // Every scalar input is cast to the variable's base dtype, in the same order as before.
            var base_dtype = var.dtype.as_base_dtype();
            var beta1_power_cast = math_ops.cast(beta1_power.Handle, base_dtype);
            var beta2_power_cast = math_ops.cast(beta2_power.Handle, base_dtype);
            var lr_cast = math_ops.cast(_lr_t, base_dtype);
            var beta1_cast = math_ops.cast(_beta1_t, base_dtype);
            var beta2_cast = math_ops.cast(_beta2_t, base_dtype);
            var epsilon_cast = math_ops.cast(_epsilon_t, base_dtype);

            var adam_update = gen_training_ops.apply_adam(
                var_handle,
                m_handle,
                v_handle,
                beta1_power_cast,
                beta2_power_cast,
                lr_cast,
                beta1_cast,
                beta2_cast,
                epsilon_cast,
                grad,
                use_locking: _use_locking);

            return adam_update.op;
        }
Example #5
0
 /// <summary>
 /// Records an access to <paramref name="variable"/> by watching the id of its handle.
 /// </summary>
 /// <param name="variable">The variable that was read.</param>
 public void VariableAccessed(ResourceVariable variable)
     => Watch(variable.Handle.Id);
Example #6
0
 /// <summary>
 /// Variable-access notification hook; not implemented here.
 /// </summary>
 /// <param name="variable">The variable that was read (unused).</param>
 /// <exception cref="NotImplementedException">Always thrown.</exception>
 public void VariableAccessed(ResourceVariable variable)
     => throw new NotImplementedException();
Example #7
0
 /// <summary>
 /// Forwards a variable-access notification to the native tape through
 /// <c>c_api.TFE_TapeVariableAccessed</c>.
 /// </summary>
 /// <param name="variable">The variable that was read.</param>
 public static void variable_accessed(ResourceVariable variable)
     => c_api.TFE_TapeVariableAccessed(variable);
Example #8
0
        /// <summary>
        /// Computes the gradient of <paramref name="target"/> with respect to a single
        /// variable by delegating to the array-based eager overload.
        /// </summary>
        /// <param name="target">Tensor being differentiated; converted with <c>as EagerTensor</c>.</param>
        /// <param name="source">The single variable to differentiate with respect to.</param>
        /// <returns>The first entry of the array overload's result.</returns>
        /// <remarks>
        /// NOTE(review): <c>as</c> yields null for a non-eager <paramref name="target"/>;
        /// confirm the callee handles that case.
        /// </remarks>
        public Tensor gradient(Tensor target, ResourceVariable source)
        {
            var eager_target = target as EagerTensor;
            ResourceVariable[] sources = { source };
            var grads = gradient(eager_target, sources);

            return grads[0];
        }