예제 #1
0
 /// <summary>
 /// Updates <paramref name="ref"/> by adding <paramref name="value"/> to it and
 /// outputs the variable after the update is done, so callers can chain operations
 /// that need the new value.
 /// </summary>
 /// <typeparam name="T">Type of the increment; per the original docstring it must
 /// match the variable's dtype.</typeparam>
 /// <param name="ref">The mutable variable to update.</param>
 /// <param name="value">The value to add to the variable.</param>
 /// <param name="use_locking">When true the addition is protected by a lock;
 /// otherwise behavior under concurrent updates is undefined but may exhibit less
 /// contention. Defaults to false.</param>
 /// <param name="name">Optional name for the operation.</param>
 /// <returns>Same as <paramref name="ref"/>, returned as a convenience for operations
 /// that want to use the new value after the variable has been updated.</returns>
 public static Tensor assign_add <T>(IVariableV1 @ref,
                                     T value,
                                     bool use_locking = false,
                                     string name      = null)
 {
     // Graph mode: build an explicit AssignAdd op through the generated op wrapper.
     if (!tf.executing_eagerly())
         return gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name);

     // Eager mode: let the variable apply the increment itself.
     return @ref.assign_add(value, use_locking: use_locking, name: name);
 }
예제 #2
0
        /// <summary>
        /// Accumulates metric statistics. Adds the (optionally weighted) sum of
        /// <paramref name="values"/> to <c>total</c>. For reductions that need a
        /// denominator, it also adds the number of contributing values (or the summed
        /// weights) to <c>count</c>, ordered after the <c>total</c> update.
        /// </summary>
        /// <param name="values">Values to accumulate.</param>
        /// <param name="sample_weight">Optional weights. When given, they are
        /// squeezed/expanded to line up with <paramref name="values"/>, cast to its
        /// dtype, and multiplied into it before summing.</param>
        /// <returns>The <c>total</c> update op for <c>SUM</c> reduction; otherwise the
        /// <c>count</c> update op, which has a control dependency on the <c>total</c>
        /// update.</returns>
        /// <exception cref="System.NotImplementedException">The configured reduction
        /// is not SUM, SUM_OVER_BATCH_SIZE or WEIGHTED_MEAN.</exception>
        public Tensor update_state(Tensor values, Tensor sample_weight = null)
        {
            if (sample_weight != null)
            {
                (values, sample_weight) = losses_utils.squeeze_or_expand_dimensions(
                    values, sample_weight: sample_weight);

                sample_weight = math_ops.cast(sample_weight, dtype: values.dtype);
                values        = math_ops.multiply(values, sample_weight);
            }

            Tensor update_total_op = null;
            var    value_sum       = math_ops.reduce_sum(values);

            tf_with(ops.control_dependencies(new[] { value_sum }), ctl =>
            {
                // Assign the OUTER variable. Declaring a new local here (as before)
                // shadowed it, leaving update_total_op null, so the control dependency
                // below did not order the count update after the total update.
                update_total_op = total.assign_add(value_sum);
            });

            // SUM has no denominator: `count` is never updated.
            if (_reduction == ReductionV2.SUM)
            {
                return update_total_op;
            }

            // Work out how much to add to `count` for reductions that need a denominator.
            Tensor num_values;
            if (_reduction == ReductionV2.SUM_OVER_BATCH_SIZE)
            {
                num_values = math_ops.cast(array_ops.size(values), _dtype);
            }
            else if (_reduction == ReductionV2.WEIGHTED_MEAN)
            {
                if (sample_weight == null)
                {
                    num_values = math_ops.cast(array_ops.size(values), _dtype);
                }
                else
                {
                    num_values = math_ops.reduce_sum(sample_weight);
                }
            }
            else
            {
                // Previously a null denominator was silently passed to count.assign_add.
                throw new System.NotImplementedException(
                    $"reduction [{_reduction}] not implemented");
            }

            return(tf_with(ops.control_dependencies(new[] { update_total_op }), ctl
                           => count.assign_add(num_values)));
        }
예제 #3
0
 /// <summary>
 /// Updates <paramref name="ref"/> by adding <paramref name="value"/> to it,
 /// delegating to the variable's own <c>assign_add</c>. The result lets callers
 /// chain operations that need the new value.
 /// </summary>
 /// <typeparam name="T">Type of the increment; per the original docstring it must
 /// match the variable's dtype.</typeparam>
 /// <param name="ref">The mutable variable to update.</param>
 /// <param name="value">The value to add to the variable.</param>
 /// <param name="use_locking">When true the addition is protected by a lock;
 /// otherwise behavior under concurrent updates is undefined but may exhibit less
 /// contention. Defaults to false.</param>
 /// <param name="name">Optional name for the operation.</param>
 /// <returns>Whatever <c>IVariableV1.assign_add</c> produces, exposed as an
 /// <c>ITensorOrOperation</c>.</returns>
 public static ITensorOrOperation assign_add <T>(IVariableV1 @ref,
                                                 T value,
                                                 bool use_locking = false,
                                                 string name      = null)
 {
     // No mode dispatch here: the variable implementation decides how to apply it.
     var updated = @ref.assign_add(value, use_locking: use_locking, name: name);
     return updated;
 }