Exemplo n.º 1
0
        /// <summary>
        /// Creates a tensor of the given <paramref name="shape"/> whose elements are all zero.
        /// </summary>
        /// <param name="shape">Static shape of the result.</param>
        /// <param name="dtype">Element type; reduced to its base dtype first. Only bool, double, float and int32 are handled.</param>
        /// <param name="name">Optional op name; a "zeros" name scope is opened around it.</param>
        /// <returns>The tensor produced by <c>_constant_if_small</c> for the matching zero literal.</returns>
        /// <exception cref="TypeError">Thrown for any other dtype.</exception>
        public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
        {
            var baseType = dtype.as_base_dtype();
            return with(ops.name_scope(name, "zeros", shape), scope =>
            {
                string scopeName = scope;
                return baseType switch
                {
                    TF_DataType.TF_BOOL => _constant_if_small(false, shape, baseType, scopeName),
                    TF_DataType.TF_DOUBLE => _constant_if_small(0.0D, shape, baseType, scopeName),
                    TF_DataType.TF_FLOAT => _constant_if_small(0.0F, shape, baseType, scopeName),
                    TF_DataType.TF_INT32 => _constant_if_small(0, shape, baseType, scopeName),
                    _ => throw new TypeError("can't find type for zeros")
                };
            });
        }
Exemplo n.º 2
0
        /// <summary>
        /// Creates a tensor of the given <paramref name="shape"/> whose elements are all zero.
        /// </summary>
        /// <param name="shape">Static shape of the result.</param>
        /// <param name="dtype">Element type; reduced to its base dtype first. Only bool, double, float and int32 are handled.</param>
        /// <param name="name">Optional op name; a "zeros" name scope is opened around it.</param>
        /// <returns>The tensor produced by <c>_constant_if_small</c> for the matching zero literal.</returns>
        /// <exception cref="NotSupportedException">Thrown for any other dtype.</exception>
        public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "")
        {
            Tensor output = null;

            dtype = dtype.as_base_dtype();
            Python.with(new ops.name_scope(name, "zeros", shape), self =>
            {
                name = self as ops.name_scope;
                switch (dtype)
                {
                case TF_DataType.TF_BOOL:
                    output = _constant_if_small(false, shape, dtype, name);
                    break;

                case TF_DataType.TF_DOUBLE:
                    output = _constant_if_small(0.0D, shape, dtype, name);
                    break;

                case TF_DataType.TF_FLOAT:
                    output = _constant_if_small(0.0F, shape, dtype, name);
                    break;

                case TF_DataType.TF_INT32:
                    output = _constant_if_small(0, shape, dtype, name);
                    break;

                default:
                    // Previously this fell through and returned null, which only failed later
                    // at some unrelated use site. Report the unsupported dtype right here.
                    throw new NotSupportedException("can't find type for zeros");
                }
            });

            return(output);
        }
Exemplo n.º 3
0
        /// <summary>
        /// Creates a zero-filled tensor whose shape is given at run time by <paramref name="shape"/>.
        /// </summary>
        /// <param name="shape">Tensor holding the dimensions of the result.</param>
        /// <param name="dtype">Element type; reduced to its base dtype first. Only bool, double, float and int32 are handled.</param>
        /// <param name="name">Optional op name; a "zeros" name scope is opened around it.</param>
        /// <returns>The output of a Fill op using a zero scalar of the requested type.</returns>
        /// <exception cref="TypeError">Thrown for any other dtype.</exception>
        public static Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
        {
            var baseType = dtype.as_base_dtype();
            return tf_with(ops.name_scope(name, "zeros", shape), scope =>
            {
                string scopeName = scope;
                Tensor zero = baseType switch
                {
                    TF_DataType.TF_BOOL => tf.constant(false, dtype: baseType),
                    TF_DataType.TF_DOUBLE => tf.constant(0.0D, dtype: baseType),
                    TF_DataType.TF_FLOAT => tf.constant(0.0F, dtype: baseType),
                    TF_DataType.TF_INT32 => tf.constant(0, dtype: baseType),
                    _ => throw new TypeError("can't find type for zeros")
                };
                return gen_array_ops.fill(shape, zero, name: scopeName);
            });
        }
Exemplo n.º 4
0
        /// <summary>
        /// Creates a tensor of the given static <paramref name="shape"/> whose elements are all zero.
        /// </summary>
        /// <param name="shape">Static shape of the result.</param>
        /// <param name="dtype">Element type; reduced to its base dtype first.</param>
        /// <param name="name">Optional op name; a "zeros" name scope is opened around it.</param>
        /// <returns>
        /// In eager mode, the output of <c>fill</c> over the converted shape. In graph mode, the
        /// tensor produced by <c>_constant_if_small</c>.
        /// </returns>
        /// <exception cref="TypeError">Graph mode only: thrown for a dtype without a zero literal below.</exception>
        public static Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
        {
            dtype = dtype.as_base_dtype();

            if (tf.executing_eagerly())
            {
                return(tf_with(ops.name_scope(name, "zeros", shape), scope =>
                {
                    name = scope;
                    var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
                    // Fill takes its element type from this scalar, so the scalar must already have
                    // the requested dtype. An int32 fallback would silently give bool/int64/int8/uint8
                    // requests an int32 result.
                    Tensor zeros;
                    switch (dtype)
                    {
                    case TF_DataType.TF_BOOL:
                        zeros = constant(false);
                        break;

                    case TF_DataType.TF_DOUBLE:
                        zeros = constant(0d);
                        break;

                    case TF_DataType.TF_FLOAT:
                        zeros = constant(0f);
                        break;

                    case TF_DataType.TF_INT64:
                        zeros = constant(0L);
                        break;

                    case TF_DataType.TF_INT8:
                        zeros = constant((sbyte)0);
                        break;

                    case TF_DataType.TF_UINT8:
                        zeros = constant((byte)0);
                        break;

                    default:
                        zeros = constant(0);
                        break;
                    }
                    return fill(shape_tensor, zeros, name: name);
                }));
            }
            else
            {
                return(tf_with(ops.name_scope(name, "zeros", shape), scope =>
                {
                    name = scope;
                    switch (dtype)
                    {
                    case TF_DataType.TF_BOOL:
                        return _constant_if_small(false, shape, dtype, name);

                    case TF_DataType.TF_DOUBLE:
                        return _constant_if_small(0.0D, shape, dtype, name);

                    case TF_DataType.TF_FLOAT:
                        return _constant_if_small(0.0F, shape, dtype, name);

                    case TF_DataType.TF_INT64:
                        return _constant_if_small(0L, shape, dtype, name);

                    case TF_DataType.TF_INT32:
                        return _constant_if_small(0, shape, dtype, name);

                    case TF_DataType.TF_INT8:
                        // TF_INT8 is signed (sbyte); byte is the unsigned TF_UINT8.
                        return _constant_if_small <sbyte>(0, shape, dtype, name);

                    case TF_DataType.TF_UINT8:
                        return _constant_if_small <byte>(0, shape, dtype, name);

                    default:
                        throw new TypeError("can't find type for zeros");
                    }
                }));
            }
        }
Exemplo n.º 5
0
        /// <summary>
        /// Casts a tensor to a new type.
        /// </summary>
        /// <param name="x">Input tensor.</param>
        /// <param name="dtype">Requested type; only its base dtype is compared.</param>
        /// <param name="name">Op name (not used by the current implementation).</param>
        /// <returns><paramref name="x"/> itself when it is non-null and already has the requested base type.</returns>
        /// <exception cref="NotImplementedException">Thrown whenever a real conversion would be required, or when <paramref name="x"/> is null.</exception>
        public static Tensor __case__(Tensor x, TF_DataType dtype, string name = null)
        {
            var targetType = dtype.as_base_dtype();
            bool alreadyTargetType = x is Tensor && x.dtype == targetType;
            if (!alreadyTargetType)
            {
                // The port of math_ops.py `cast` is still pending.
                throw new NotImplementedException();
            }

            return x;
        }
Exemplo n.º 6
0
 /// <summary>
 /// Creates a tensor whose shape is given at run time by <paramref name="shape"/>, filled with ones.
 /// </summary>
 /// <param name="shape">Tensor holding the dimensions of the result.</param>
 /// <param name="dtype">Element type; reduced to its base dtype first.</param>
 /// <param name="name">Optional op name; a "ones" name scope is opened around it.</param>
 /// <returns>The output of a Fill op whose scalar is 1 in the requested type.</returns>
 public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
 {
     var baseType = dtype.as_base_dtype();
     return with(ops.name_scope(name, "ones", new { shape }), scope =>
     {
         string scopeName = scope;
         var one = constant_op.constant(1.0f, dtype: baseType);
         return gen_array_ops.fill(shape, one, name: scopeName);
     });
 }
Exemplo n.º 7
0
        /// <summary>
        /// Creates the variable registered under <paramref name="name"/> and records it in <c>_vars</c>.
        /// </summary>
        /// <param name="name">Variable name; must not already be present in <c>_vars</c>.</param>
        /// <param name="shape">Shape passed to the initializer.</param>
        /// <param name="dtype">Requested dtype; its base dtype is handed to the variable creator.</param>
        /// <param name="initializer">Produces the initial value; required.</param>
        /// <param name="reuse">Not yet honoured: an existing name is always rejected.</param>
        /// <param name="trainable">Forwarded to the variable creator.</param>
        /// <param name="validate_shape">Forwarded to the variable creator.</param>
        /// <param name="use_resource">Defaults to false when null; not otherwise consulted here.</param>
        /// <param name="synchronization">Forwarded to the variable creator.</param>
        /// <param name="aggregation">Forwarded to the variable creator.</param>
        /// <returns>The newly created variable.</returns>
        /// <exception cref="NotImplementedException">The name is already registered.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="initializer"/> is null.</exception>
        private RefVariable _get_single_variable(string name,
                                                 TensorShape shape        = null,
                                                 TF_DataType dtype        = TF_DataType.DtInvalid,
                                                 IInitializer initializer = null,
                                                 bool reuse          = false,
                                                 bool? trainable     = null,
                                                 bool validate_shape = false,
                                                 bool? use_resource  = null,
                                                 VariableSynchronization synchronization = VariableSynchronization.AUTO,
                                                 VariableAggregation aggregation         = VariableAggregation.NONE)
        {
            bool initializing_from_value = false;

            if (_vars.ContainsKey(name))
            {
                // Reuse / duplicate-name handling is not ported yet: any repeated name is rejected.
                throw new NotImplementedException("_get_single_variable");
            }

            if (!initializing_from_value && initializer == null)
            {
                // Fail with a clear message instead of a NullReferenceException from initializer.call.
                throw new ArgumentNullException(nameof(initializer), $"An initializer is required to create variable {name}.");
            }

            Tensor init_val = null;
            var variable_dtype = dtype.as_base_dtype();

            ops.init_scope();
            {
                if (initializing_from_value)
                {
                    // Initializing from an existing value is not ported yet.
                }
                else
                {
                    init_val = initializer.call(shape, dtype);
                }
            }

            // Create the variable.
            if (use_resource == null)
            {
                use_resource = false;
            }

            // Pass the resolved base dtype (it was previously computed and then ignored in favour of DtInvalid).
            var v = variable_scope.default_variable_creator(init_val,
                                                            name: name,
                                                            trainable: trainable,
                                                            dtype: variable_dtype,
                                                            validate_shape: validate_shape,
                                                            synchronization: synchronization,
                                                            aggregation: aggregation);

            _vars[name] = v;

            return(v);
        }
Exemplo n.º 8
0
 /// <summary>
 /// Creates a tensor with the dimensions listed in <paramref name="dims"/>, filled with ones.
 /// </summary>
 /// <param name="dims">Dimensions of the result; converted to an int32 shape tensor.</param>
 /// <param name="dtype">Element type; reduced to its base dtype first.</param>
 /// <param name="name">Optional op name; a "ones" name scope is opened around it.</param>
 /// <returns>The output of a Fill op whose scalar is 1 in the requested type.</returns>
 public static Tensor ones(int[] dims, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
 {
     var baseType = dtype.as_base_dtype();
     return with(ops.name_scope(name, "ones", new { dims }), scope =>
     {
         string scopeName = scope;
         var shapeTensor = ops.convert_to_tensor(dims, dtype: TF_DataType.TF_INT32);
         var one = constant_op.constant(1.0f, dtype: baseType);
         return gen_array_ops.fill(shapeTensor, one, name: scopeName);
     });
 }
Exemplo n.º 9
0
        /// <summary>
        /// Casts <paramref name="x"/> to the base type of <paramref name="dtype"/>.
        /// </summary>
        /// <param name="x">Input tensor.</param>
        /// <param name="dtype">Target type; only its base dtype is compared.</param>
        /// <param name="name">Op name (not used by the current implementation).</param>
        /// <returns><paramref name="x"/> unchanged when it already has the target type.</returns>
        /// <exception cref="NotImplementedException">Thrown when an actual conversion would be needed.</exception>
        public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
        {
            var targetType = dtype.as_base_dtype();

            if (targetType != x.dtype)
                throw new NotImplementedException("math_ops.cast");

            return x;
        }
Exemplo n.º 10
0
 /// <summary>
 /// Builds an eager-mode tensor array whose elements are kept in an in-memory list.
 /// </summary>
 /// <remarks>
 /// This constructor stores only dtype, dynamic_size, clear_after_read, infer_shape,
 /// element_shape and colocate_with_first_write_call. The size, tensor_array_name, handle,
 /// flow and name arguments are not read here. The flow is always a fresh constant 0.
 /// </remarks>
 public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false,
                          bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
                          bool infer_shape = true, Shape? element_shape = null,
                          bool colocate_with_first_write_call = true, string name = null)
 {
     _flow         = constant_op.constant(0);
     _tensor_array = new List <Tensor>();

     _dtype            = dtype.as_base_dtype();
     _dynamic_size     = dynamic_size;
     _clear_after_read = clear_after_read;

     _infer_shape   = infer_shape;
     _element_shape = element_shape ?? Shape.Null;
     _colocate_with_first_write_call = colocate_with_first_write_call;
 }
Exemplo n.º 11
0
        /// <summary>
        /// Converts the scalar <paramref name="x"/> to a tensor of the base type of <paramref name="dtype"/>.
        /// </summary>
        /// <param name="x">Scalar value to convert.</param>
        /// <param name="dtype">Target type; only its base dtype is used.</param>
        /// <param name="name">Optional name; a "Cast" name scope is opened around it.</param>
        /// <returns>The converted tensor. A Cast op is added only if the conversion did not already produce the target type.</returns>
        public static Tensor cast(float x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
        {
            var targetType = dtype.as_base_dtype();

            return tf_with(ops.name_scope(name, "Cast", new { x }), scope =>
            {
                string scopeName = scope;
                var converted = ops.convert_to_tensor(x, name: "x");
                return converted.dtype.as_base_dtype() == targetType
                    ? converted
                    : gen_math_ops.cast(converted, targetType, name: scopeName);
            });
        }
Exemplo n.º 12
0
        /// <summary>
        /// Casts <paramref name="x"/> to the base type of <paramref name="dtype"/>.
        /// </summary>
        /// <param name="x">Input tensor.</param>
        /// <param name="dtype">Target type; only its base dtype is used.</param>
        /// <param name="name">Optional name; a "Cast" name scope is opened around it.</param>
        /// <returns><paramref name="x"/> when it already has the target type, otherwise the output of a Cast op.</returns>
        public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
        {
            var base_type = dtype.as_base_dtype();
            if (base_type == x.dtype)
                return x;

            return with(new ops.name_scope(name, "Cast", new { x }), scope =>
            {
                // Name the Cast op after the opened scope, as the other cast overloads do.
                // Previously the raw caller-supplied name (often null) was used.
                name = scope;
                x = ops.convert_to_tensor(x, name: "x");
                if (x.dtype.as_base_dtype() != base_type)
                    x = gen_math_ops.cast(x, base_type, name: name);

                return x;
            });
        }
Exemplo n.º 13
0
        /// <summary>
        /// Maps a TensorFlow data type to the matching CLR type.
        /// </summary>
        /// <param name="type">The dtype to map; reference variants are reduced to their base dtype first.</param>
        /// <returns><see cref="System.Type"/> equivalent to <paramref name="type"/>, if none exists, returns null.</returns>
        public static Type as_numpy_dtype(this TF_DataType type)
            => type.as_base_dtype() switch
            {
                TF_DataType.TF_BOOL       => typeof(bool),
                TF_DataType.TF_UINT8      => typeof(byte),
                TF_DataType.TF_INT8       => typeof(sbyte),
                TF_DataType.TF_INT64      => typeof(long),
                TF_DataType.TF_UINT64     => typeof(ulong),
                TF_DataType.TF_INT32      => typeof(int),
                TF_DataType.TF_UINT32     => typeof(uint),
                TF_DataType.TF_INT16      => typeof(short),
                TF_DataType.TF_UINT16     => typeof(ushort),
                TF_DataType.TF_FLOAT      => typeof(float),
                TF_DataType.TF_DOUBLE     => typeof(double),
                TF_DataType.TF_STRING     => typeof(string),
                TF_DataType.TF_COMPLEX128 => typeof(Complex),
                TF_DataType.TF_COMPLEX64  => typeof(Complex), // 64 is also TF_COMPLEX
                _                         => null
            };
Exemplo n.º 14
0
        /// <summary>
        /// Maps a TensorFlow data type to the matching CLR type.
        /// </summary>
        /// <param name="type">The dtype to map; reference variants are reduced to their base dtype first.</param>
        /// <returns>The <see cref="System.Type"/> equivalent to <paramref name="type"/>.</returns>
        /// <exception cref="NotSupportedException">Thrown when no equivalent CLR type exists.</exception>
        public static Type as_system_dtype(this TF_DataType type)
            => type.as_base_dtype() switch
            {
                TF_DataType.TF_BOOL       => typeof(bool),
                TF_DataType.TF_UINT8      => typeof(byte),
                TF_DataType.TF_INT8       => typeof(sbyte),
                TF_DataType.TF_INT64      => typeof(long),
                TF_DataType.TF_UINT64     => typeof(ulong),
                TF_DataType.TF_INT32      => typeof(int),
                TF_DataType.TF_UINT32     => typeof(uint),
                TF_DataType.TF_INT16      => typeof(short),
                TF_DataType.TF_UINT16     => typeof(ushort),
                TF_DataType.TF_FLOAT      => typeof(float),
                TF_DataType.TF_DOUBLE     => typeof(double),
                TF_DataType.TF_STRING     => typeof(string),
                TF_DataType.TF_COMPLEX128 => typeof(Complex),
                TF_DataType.TF_COMPLEX64  => typeof(Complex), // 64 is also TF_COMPLEX
                _ => throw new NotSupportedException($"Unable to convert {type} to a system data type.")
            };
Exemplo n.º 15
0
        /// <summary>
        /// Nominally casts the resource variable <paramref name="x"/> to the base type of <paramref name="dtype"/>.
        /// </summary>
        /// <param name="x">Variable to cast.</param>
        /// <param name="dtype">Target type; only its base dtype is used.</param>
        /// <param name="name">Optional name; a "Cast" name scope is opened around it.</param>
        /// <returns>Always <paramref name="x"/> itself.</returns>
        /// <remarks>
        /// NOTE(review): when a conversion is needed, a Cast op is built, but its output is discarded
        /// and the unconverted <paramref name="x"/> is returned. The declared return type
        /// (<see cref="ResourceVariable"/>) cannot carry the cast tensor. Confirm whether callers
        /// rely on this.
        /// </remarks>
        public static ResourceVariable cast(ResourceVariable x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
        {
            var targetType = dtype.as_base_dtype();
            if (x.dtype == targetType)
                return x;

            return tf_with(ops.name_scope(name, "Cast", new { x }), scope =>
            {
                string scopeName = scope;
                var asTensor = ops.convert_to_tensor(x, name: "x");
                if (asTensor.dtype.as_base_dtype() != targetType)
                    gen_math_ops.cast(asTensor, targetType, name: scopeName);

                return x;
            });
        }
Exemplo n.º 16
0
        /// <summary>
        /// Adds a new variable to the layer.
        /// </summary>
        /// <param name="name">Variable name, forwarded to <c>tf.Variable</c>.</param>
        /// <param name="shape">Shape handed to the initializer.</param>
        /// <param name="dtype">Element type handed to the initializer; its base dtype is the variable dtype.</param>
        /// <param name="initializer">Produces the initial value; called lazily through a lambda.</param>
        /// <param name="trainable">Whether the variable is trainable, forwarded to <c>tf.Variable</c>.</param>
        /// <param name="use_resource">Accepted for signature compatibility; not consulted here.</param>
        /// <returns>The created variable.</returns>
        public static RefVariable make_variable(string name,
                                                int[] shape,
                                                TF_DataType dtype        = TF_DataType.TF_FLOAT,
                                                IInitializer initializer = null,
                                                bool trainable           = true,
                                                bool use_resource        = true)
        {
            ops.init_scope();

            // Wrap the initializer so tf.Variable controls when the initial value is produced.
            Func <Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype);

            // name, trainable and dtype were previously accepted but never forwarded, so every
            // variable got a default name and was always trainable.
            var v = tf.Variable(init_val,
                                trainable: trainable,
                                name: name,
                                dtype: dtype.as_base_dtype());

            return(v);
        }
Exemplo n.º 17
0
        public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
        {
            dtype = dtype.as_base_dtype();

            if (tf.executing_eagerly())
            {
                return(tf_with(ops.name_scope(name, "zeros", shape), scope =>
                {
                    name = scope;
                    // var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
                    Tensor zeros = dtype switch
                    {
                        TF_DataType.TF_DOUBLE => constant(0d),
                        TF_DataType.TF_FLOAT => constant(0f),
                        TF_DataType.TF_INT8 => constant((sbyte)0),
                        TF_DataType.TF_UINT8 => constant((byte)0),
                        _ => constant(0)
                    };
                    return fill(shape, zeros, name: name);
                }));
Exemplo n.º 18
0
        /// <summary>
        /// Casts <paramref name="x"/> to the base type of <paramref name="dtype"/>.
        /// </summary>
        /// <param name="x">Input tensor.</param>
        /// <param name="dtype">Target type; only its base dtype is used.</param>
        /// <param name="name">Optional name; a "Cast" name scope is opened around it.</param>
        /// <returns><paramref name="x"/> when it already has the target type, otherwise the output of a Cast op.</returns>
        public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
        {
            var targetType = dtype.as_base_dtype();
            if (x.dtype == targetType)
                return x;

            return tf_with(ops.name_scope(name, "Cast", new { x }), scope =>
            {
                string scopeName = scope;
                return x.dtype.as_base_dtype() == targetType
                    ? x
                    : gen_math_ops.cast(x, targetType, name: scopeName);
            });
        }
Exemplo n.º 19
0
        /// <summary>
        /// Casts the value of variable <paramref name="x"/> to the base type of <paramref name="dtype"/>.
        /// </summary>
        /// <param name="x">Variable whose value is cast.</param>
        /// <param name="dtype">Target type; only its base dtype is used.</param>
        /// <param name="name">Optional name; a "Cast" name scope is opened around it.</param>
        /// <returns>
        /// The variable's tensor when it already has the target type. Otherwise, the converted tensor,
        /// which is the output of a Cast op if conversion did not already yield the target type.
        /// </returns>
        public static Tensor cast(IVariableV1 x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
        {
            var base_type = dtype.as_base_dtype();

            if (base_type == x.dtype)
            {
                return(x.AsTensor());
            }

            return(tf_with(ops.name_scope(name, "Cast", new { x }), scope =>
            {
                name = scope;
                var t_x = ops.convert_to_tensor(x, name: "x");
                if (t_x.dtype.as_base_dtype() != base_type)
                {
                    t_x = gen_math_ops.cast(t_x, base_type, name: name);
                }

                // Return the converted tensor. Returning x.AsTensor() (as before) discarded the
                // Cast op and handed back a value of the original dtype.
                return t_x;
            }));
        }
        /// <summary>
        /// Resolves the base dtype and effective trainability, then delegates to <c>_true_getter</c>.
        /// </summary>
        /// <param name="name">Variable name.</param>
        /// <param name="shape">Variable shape.</param>
        /// <param name="dtype">Requested dtype; reduced to its base dtype.</param>
        /// <param name="initializer">Initializer forwarded to the getter.</param>
        /// <param name="trainable">Resolved through <c>variable_scope._get_trainable_value</c>.</param>
        /// <param name="validate_shape">Forwarded to the getter.</param>
        /// <param name="synchronization">Forwarded to the getter and used to resolve trainability.</param>
        /// <param name="aggregation">Forwarded to the getter.</param>
        /// <returns>Whatever <c>_true_getter</c> returns.</returns>
        public RefVariable get_variable(string name,
                                        TensorShape shape        = null,
                                        TF_DataType dtype        = TF_DataType.TF_FLOAT,
                                        IInitializer initializer = null,
                                        bool? trainable          = null,
                                        bool validate_shape      = true,
                                        VariableSynchronization synchronization = VariableSynchronization.AUTO,
                                        VariableAggregation aggregation         = VariableAggregation.NONE)
        {
            var baseDtype = dtype.as_base_dtype();
            bool? resolvedTrainable = variable_scope._get_trainable_value(synchronization, trainable);

            return _true_getter(name,
                                shape: shape,
                                dtype: baseDtype,
                                initializer: initializer,
                                trainable: resolvedTrainable,
                                validate_shape: validate_shape,
                                synchronization: synchronization,
                                aggregation: aggregation);
        }
Exemplo n.º 21
0
        /// <summary>
        /// Creates a tensor of the given static <paramref name="shape"/> whose elements are all zero.
        /// </summary>
        /// <param name="shape">Static shape of the result; converted to a shape tensor.</param>
        /// <param name="dtype">Element type; reduced to its base dtype first.</param>
        /// <param name="name">Optional op name; a "zeros" name scope is opened around it.</param>
        /// <returns>The output of <c>fill</c> over the converted shape.</returns>
        public static Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
        => tf_with(ops.name_scope(name, "zeros", shape), scope =>
        {
            dtype            = dtype.as_base_dtype();
            name             = scope;
            var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
            // Fill takes its element type from this scalar, so the scalar must already have the
            // requested dtype. The old int32 fallback gave bool/int64/int8/uint8 requests an int32 result.
            Tensor zeros;
            switch (dtype)
            {
            case TF_DataType.TF_BOOL:
                zeros = constant(false);
                break;

            case TF_DataType.TF_DOUBLE:
                zeros = constant(0d);
                break;

            case TF_DataType.TF_FLOAT:
                zeros = constant(0f);
                break;

            case TF_DataType.TF_INT64:
                zeros = constant(0L);
                break;

            case TF_DataType.TF_INT8:
                zeros = constant((sbyte)0);
                break;

            case TF_DataType.TF_UINT8:
                zeros = constant((byte)0);
                break;

            default:
                zeros = constant(0);
                break;
            }
            return(fill(shape_tensor, zeros, name: name));
        });
Exemplo n.º 22
0
        /// <summary>
        /// Adds a new variable to the layer.
        /// </summary>
        /// <param name="name">Variable name, forwarded to <c>tf.Variable</c>.</param>
        /// <param name="shape">Variable shape; also handed to the initializer.</param>
        /// <param name="dtype">Variable dtype; also handed to the initializer.</param>
        /// <param name="initializer">Produces the initial value; called lazily through a lambda.</param>
        /// <param name="trainable">Whether the variable is trainable, forwarded to <c>tf.Variable</c>.</param>
        /// <returns>The created variable.</returns>
        public static IVariableV1 make_variable(string name,
                                                int[] shape,
                                                TF_DataType dtype        = TF_DataType.TF_FLOAT,
                                                IInitializer initializer = null,
                                                bool trainable           = true)
        {
            ops.init_scope();

            // Wrap the initializer so tf.Variable controls when the initial value is produced.
            Func <Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype);

            // trainable was previously accepted but never forwarded, so every variable was trainable.
            var v = tf.Variable(init_val,
                                trainable: trainable,
                                dtype: dtype,
                                shape: shape,
                                name: name);

            return(v);
        }
Exemplo n.º 23
0
        /// <summary>
        /// Resolves the base dtype and effective trainability, then delegates to <c>_true_getter</c>.
        /// </summary>
        /// <param name="name">Variable name.</param>
        /// <param name="shape">Variable shape.</param>
        /// <param name="dtype">Requested dtype; reduced to its base dtype.</param>
        /// <param name="initializer">An IInitializer or Tensor, forwarded to the getter.</param>
        /// <param name="reuse">Not consulted here.</param>
        /// <param name="trainable">Resolved through <c>variable_scope._get_trainable_value</c>.</param>
        /// <param name="collections">Forwarded to the getter.</param>
        /// <param name="validate_shape">Forwarded to the getter.</param>
        /// <param name="synchronization">Forwarded to the getter and used to resolve trainability.</param>
        /// <param name="aggregation">Forwarded to the getter.</param>
        /// <returns>Whatever <c>_true_getter</c> returns.</returns>
        public IVariableV1 get_variable(string name,
                                        TensorShape shape         = null,
                                        TF_DataType dtype         = TF_DataType.TF_FLOAT,
                                        object initializer        = null, // IInitializer or Tensor
                                        bool? reuse               = null,
                                        bool? trainable           = null,
                                        List <string> collections = null,
                                        bool validate_shape       = true,
                                        VariableSynchronization synchronization = VariableSynchronization.Auto,
                                        VariableAggregation aggregation         = VariableAggregation.None)
        {
            var baseDtype = dtype.as_base_dtype();
            bool? resolvedTrainable = variable_scope._get_trainable_value(synchronization, trainable);

            return _true_getter(name,
                                shape: shape,
                                dtype: baseDtype,
                                initializer: initializer,
                                trainable: resolvedTrainable,
                                collections: collections,
                                validate_shape: validate_shape,
                                synchronization: synchronization,
                                aggregation: aggregation);
        }
Exemplo n.º 24
0
        /// <summary>
        /// Creates a weight variable for this layer and records it in <c>_trainable_weights</c>.
        /// </summary>
        /// <param name="name">Variable name.</param>
        /// <param name="shape">Variable shape.</param>
        /// <param name="dtype">Element type; <c>DtInvalid</c> is treated as float32.</param>
        /// <param name="initializer">When null: Glorot-uniform for floating dtypes, zeros for integer dtypes.</param>
        /// <param name="trainable">Null is treated as true.</param>
        /// <param name="getter">Custom getter; <c>base_layer_utils.make_variable</c> is used when null.</param>
        /// <returns>The created variable.</returns>
        /// <exception cref="ValueError">No initializer was given and the dtype is neither floating nor integer.</exception>
        protected virtual RefVariable add_weight(string name,
                                                 int[] shape,
                                                 TF_DataType dtype        = TF_DataType.DtInvalid,
                                                 IInitializer initializer = null,
                                                 bool? trainable          = null,
                                                 Func <string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null)
        {
            if (dtype == TF_DataType.DtInvalid)
                dtype = TF_DataType.TF_FLOAT;

            trainable = trainable ?? true;

            if (initializer == null)
            {
                // Pick a default initializer based on the element type.
                if (dtype.is_floating())
                    initializer = tf.glorot_uniform_initializer;
                else if (dtype.is_integer())
                    initializer = tf.zeros_initializer;
                else
                    throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.name}");
            }

            var variable = _add_variable_with_custom_getter(name,
                                                            shape,
                                                            dtype: dtype,
                                                            getter: (getter == null) ? base_layer_utils.make_variable : getter,
                                                            overwrite: true,
                                                            initializer: initializer,
                                                            trainable: trainable.Value);

            backend.track_variable(variable);
            // NOTE(review): the variable is appended here even when trainable is false;
            // a separate non-trainable list may have been intended. Confirm.
            _trainable_weights.Add(variable);

            return variable;
        }
Exemplo n.º 25
0
        /// <summary>
        /// Creates a weight variable for this layer and records it as trainable or non-trainable.
        /// </summary>
        /// <param name="name">Variable name.</param>
        /// <param name="shape">Variable shape.</param>
        /// <param name="dtype">Element type; <c>DtInvalid</c> is treated as float32.</param>
        /// <param name="initializer">When null: Glorot-uniform for floating dtypes, zeros for integer dtypes.</param>
        /// <param name="trainable">Null is treated as true.</param>
        /// <param name="getter">Custom getter; <c>base_layer_utils.make_variable</c> is used when null.</param>
        /// <returns>The created variable.</returns>
        /// <exception cref="ValueError">No initializer was given and the dtype is neither floating nor integer.</exception>
        protected virtual IVariableV1 add_weight(string name,
                                                 TensorShape shape,
                                                 TF_DataType dtype        = TF_DataType.DtInvalid,
                                                 IInitializer initializer = null,
                                                 bool? trainable          = null,
                                                 Func <VariableArgs, IVariableV1> getter = null)
        {
            if (dtype == TF_DataType.DtInvalid)
                dtype = TF_DataType.TF_FLOAT;

            trainable = trainable ?? true;

            if (initializer == null)
            {
                // Pick a default initializer based on the element type.
                if (dtype.is_floating())
                    initializer = tf.glorot_uniform_initializer;
                else if (dtype.is_integer())
                    initializer = tf.zeros_initializer;
                else
                    throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.Name}");
            }

            var variable = _add_variable_with_custom_getter(new VariableArgs
            {
                Name        = name,
                Shape       = shape,
                DType       = dtype,
                Getter      = getter ?? base_layer_utils.make_variable,
                Overwrite   = true,
                Initializer = initializer,
                Trainable   = trainable.Value
            });

            //backend.track_variable(variable);
            if (!trainable.Value)
                nonTrainableWeights.Add(variable);
            else
                trainableWeights.Add(variable);

            return variable;
        }
Exemplo n.º 26
0
        /// <summary>
        /// Creates the variable registered under <paramref name="name"/> and records it in <c>_vars</c>.
        /// </summary>
        /// <param name="name">Variable name; must not already be present in <c>_vars</c>.</param>
        /// <param name="shape">Shape passed to the initializer.</param>
        /// <param name="dtype">Requested dtype; its base dtype is handed to the variable creator.</param>
        /// <param name="initializer">
        /// Produces the initial value. When null: Glorot-uniform for floating dtypes, zeros for
        /// integer and bool dtypes.
        /// </param>
        /// <param name="reuse">Not yet honoured: an existing name is always rejected.</param>
        /// <param name="trainable">Forwarded to the variable creator.</param>
        /// <param name="collections">Forwarded to the variable creator.</param>
        /// <param name="validate_shape">Forwarded to the variable creator.</param>
        /// <param name="use_resource">Defaults to false when null; not otherwise consulted here.</param>
        /// <param name="synchronization">Forwarded to the variable creator.</param>
        /// <param name="aggregation">Forwarded to the variable creator.</param>
        /// <returns>The newly created variable.</returns>
        /// <exception cref="NotImplementedException">The name is already registered.</exception>
        /// <exception cref="ArgumentException">No initializer was given and no default exists for <paramref name="dtype"/>.</exception>
        private IVariableV1 _get_single_variable(string name,
                                                 TensorShape shape        = null,
                                                 TF_DataType dtype        = TF_DataType.DtInvalid,
                                                 IInitializer initializer = null,
                                                 bool reuse                = false,
                                                 bool? trainable           = null,
                                                 List <string> collections = null,
                                                 bool validate_shape       = false,
                                                 bool? use_resource        = null,
                                                 VariableSynchronization synchronization = VariableSynchronization.Auto,
                                                 VariableAggregation aggregation         = VariableAggregation.None)
        {
            bool initializing_from_value = false;

            if (use_resource == null)
            {
                use_resource = false;
            }

            if (_vars.ContainsKey(name))
            {
                // Reuse / duplicate-name handling is not ported yet: any repeated name is rejected.
                throw new NotImplementedException("_get_single_variable");
            }

            IVariableV1 v = null;

            // Create the tensor to initialize the variable with default value.
            // Previously only floating dtypes got a default. Any other dtype left initializer null
            // and failed later with a NullReferenceException inside the deferred init lambda.
            if (initializer == null)
            {
                if (dtype.is_floating())
                {
                    initializer = tf.glorot_uniform_initializer;
                }
                else if (dtype.is_integer() || dtype.as_base_dtype() == TF_DataType.TF_BOOL)
                {
                    initializer = tf.zeros_initializer;
                }
                else
                {
                    throw new ArgumentException($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required.", nameof(initializer));
                }
                initializing_from_value = false;
            }

            // Create the variable.
            ops.init_scope();
            {
                if (initializing_from_value)
                {
                    // Initializing from an existing value is not ported yet.
                }
                else
                {
                    Func <Tensor> init_val       = () => initializer.call(shape, dtype);
                    var           variable_dtype = dtype.as_base_dtype();

                    v = variable_scope.default_variable_creator(init_val,
                                                                name: name,
                                                                trainable: trainable,
                                                                collections: collections,
                                                                dtype: variable_dtype,
                                                                validate_shape: validate_shape,
                                                                synchronization: synchronization,
                                                                aggregation: aggregation);
                }
            }

            _vars[name] = v;

            return(v);
        }
Exemplo n.º 27
0
        /// <summary>
        /// Creates a weight variable for this layer, applies an optional regularizer, and records it
        /// as trainable or non-trainable.
        /// </summary>
        /// <param name="name">Variable name.</param>
        /// <param name="shape">Variable shape.</param>
        /// <param name="dtype">Element type.</param>
        /// <param name="initializer">When null: Glorot-uniform for floating dtypes, zeros for integer dtypes.</param>
        /// <param name="regularizer">Optional regularizer, registered under the variable name without its ":N" suffix.</param>
        /// <param name="synchronization"><c>OnRead</c> forces the variable to be non-trainable.</param>
        /// <param name="aggregation">Forwarded through <see cref="VariableArgs"/>.</param>
        /// <param name="trainable">Whether the variable is trainable (overridden by <c>OnRead</c> synchronization).</param>
        /// <param name="getter">Custom getter; <c>base_layer_utils.make_variable</c> is used when null.</param>
        /// <returns>The created variable.</returns>
        /// <exception cref="ValueError">No initializer was given and the dtype is neither floating nor integer.</exception>
        protected virtual IVariableV1 add_weight(string name,
                                                 Shape shape,
                                                 TF_DataType dtype        = TF_DataType.TF_FLOAT,
                                                 IInitializer initializer = null,
                                                 IRegularizer regularizer = null,
                                                 VariableSynchronization synchronization = VariableSynchronization.Auto,
                                                 VariableAggregation aggregation         = VariableAggregation.None,
                                                 bool trainable = true,
                                                 Func <VariableArgs, IVariableV1> getter = null)
        {
            // Initialize variable when no initializer provided
            if (initializer == null)
            {
                // If dtype is DT_FLOAT, provide a uniform unit scaling initializer
                if (dtype.is_floating())
                {
                    initializer = tf.glorot_uniform_initializer;
                }
                else if (dtype.is_integer())
                {
                    initializer = tf.zeros_initializer;
                }
                else
                {
                    // Report the layer's name. The message previously repeated the variable name here.
                    throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.Name}");
                }
            }

            // On-read synchronized variables are never trained directly.
            if (synchronization == VariableSynchronization.OnRead)
            {
                trainable = false;
            }

            var args = new VariableArgs
            {
                Name            = name,
                Shape           = shape,
                DType           = dtype,
                Getter          = getter ?? base_layer_utils.make_variable,
                Overwrite       = true,
                Initializer     = initializer,
                Synchronization = synchronization,
                Aggregation     = aggregation,
                Trainable       = trainable
            };
            var variable = _add_variable_with_custom_getter(args);

            if (regularizer != null)
            {
                // Strip the ":N" output suffix so the regularizer is keyed by the bare variable name.
                var name_in_scope = variable.Name.Split(':')[0];
                _handle_weight_regularization(name_in_scope, variable, regularizer);
            }

            //backend.track_variable(variable);
            if (trainable)
            {
                trainable_weights.Add(variable);
            }
            else
            {
                non_trainable_weights.Add(variable);
            }

            return(variable);
        }
Exemplo n.º 28
0
 /// <summary>
 /// Converts <paramref name="v"/> to its <see cref="DataType"/> enum value after
 /// normalizing it with <c>as_base_dtype()</c>.
 /// </summary>
 /// <param name="v">Data type to convert.</param>
 /// <param name="attr_def">Attribute definition; not consulted by this conversion.</param>
 /// <returns>The <see cref="DataType"/> corresponding to the base dtype of <paramref name="v"/>.</returns>
 public DataType _MakeType(TF_DataType v, AttrDef attr_def)
 {
     var base_dtype = v.as_base_dtype();
     var enum_value = base_dtype.as_datatype_enum();
     return enum_value;
 }
Exemplo n.º 29
0
        /// <summary>
        /// Initializes this variable from <paramref name="initial_value"/>. It converts the value to a
        /// tensor, creates the VariableV2 op, the assign (initializer) op and a "read" snapshot, and
        /// adds the variable to the requested graph collections.
        /// </summary>
        /// <param name="initial_value">Value (tensor or convertible object) used to initialize the variable; must not be null.</param>
        /// <param name="trainable">When true, the variable is also added to TRAINABLE_VARIABLES.</param>
        /// <param name="collections">Graph collections to join; defaults to GLOBAL_VARIABLES. The caller's list is not modified.</param>
        /// <param name="validate_shape">Forwarded to the assign op.</param>
        /// <param name="caching_device">Caching device; currently not applied as a device placement.</param>
        /// <param name="name">Optional variable name, resolved through a "Variable" name scope.</param>
        /// <param name="dtype">Overwritten by the initial value's dtype; not used for conversion here.</param>
        /// <exception cref="ValueError"><paramref name="initial_value"/> is null.</exception>
        private void _init_from_args(object initial_value,
                                     bool trainable            = true,
                                     List <string> collections = null,
                                     bool validate_shape       = true,
                                     string caching_device     = "",
                                     string name       = "",
                                     TF_DataType dtype = TF_DataType.DtInvalid)
        {
            if (initial_value is null)
            {
                throw new ValueError("initial_value must be specified.");
            }

            // Initialization from a callable is not implemented in this version.
            var init_from_fn = false;

            // Work on a private copy so adding TRAINABLE_VARIABLES below never mutates the caller's list.
            collections = collections == null
                ? new List <string> { ops.GraphKeys.GLOBAL_VARIABLES }
                : new List <string>(collections);

            // Store the graph key so optimizers know how to only retrieve variables from
            // this graph.
            _graph_key = ops.get_default_graph()._graph_key;

            _trainable = trainable;
            // Only trainable variables belong in the TRAINABLE_VARIABLES collection.
            if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES))
            {
                collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES);
            }

            ops.init_scope();
            var values = init_from_fn ? new List <object>() : new List <object> {
                initial_value
            };

            using (var namescope = new ops.name_scope <object>(name, "Variable", values))
            {
                name = namescope;

                if (init_from_fn)
                {
                    // Unreachable: init_from_fn is always false above.
                }
                // Or get the initial value from a Tensor or Python object.
                else
                {
                    _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");

                    var shape = _initial_value.shape;
                    dtype     = _initial_value.dtype;
                    _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), name);
                }

                // NOTE(review): validate_shape is only forwarded to the assign op below; the initial
                // value's shape is not checked for being fully defined here.

                // If 'initial_value' makes use of other variables, make sure we don't
                // have an issue if these other variables aren't initialized first by
                // using their initialized_value() method.
                var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value);

                _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;

                if (String.IsNullOrEmpty(caching_device))
                {
                    // No caching device: colocate the read snapshot with the initializer op.
                    ops.colocate_with(_initializer_op);
                }
                // NOTE(review): a non-empty caching_device is not applied as a device placement yet,
                // but the snapshot is always created so _snapshot is never left null.
                _snapshot = gen_array_ops.identity(_variable, name: "read");

                ops.add_to_collections(collections, this);
            }
        }
Exemplo n.º 30
0
        /// <summary>
        /// Initializes this variable from <paramref name="initial_value"/>, which may be a tensor,
        /// a convertible object, or a <c>Func&lt;Tensor&gt;</c> invoked to produce the value. It creates
        /// the VariableV2 op, the assign (initializer) op and a "read" snapshot, and adds the variable
        /// to the requested graph collections.
        /// </summary>
        /// <param name="initial_value">Initial value or <c>Func&lt;Tensor&gt;</c> producing it; must not be null.</param>
        /// <param name="trainable">When true, the variable is also added to TRAINABLE_VARIABLES.</param>
        /// <param name="collections">Graph collections to join; defaults to GLOBAL_VARIABLES. The caller's list is not modified.</param>
        /// <param name="validate_shape">When true, the initial value's shape must be fully defined; also forwarded to assign.</param>
        /// <param name="caching_device">Caching device; currently not applied as a device placement.</param>
        /// <param name="name">Optional variable name, resolved through a "Variable" name scope.</param>
        /// <param name="dtype">Requested dtype, passed to tensor conversion of the initial value.</param>
        /// <exception cref="ValueError">
        /// <paramref name="initial_value"/> is null, or shape validation fails.
        /// </exception>
        private void _init_from_args(object initial_value,
                                     bool trainable            = true,
                                     List <string> collections = null,
                                     bool validate_shape       = true,
                                     string caching_device     = "",
                                     string name       = null,
                                     TF_DataType dtype = TF_DataType.DtInvalid)
        {
            if (initial_value is null)
            {
                throw new ValueError("initial_value must be specified.");
            }

            // Only a Func<Tensor> can be invoked below. A plain type-name check ("Func`1") would also
            // match e.g. Func<int> and then fail on the `as Func<Tensor>` cast.
            var init_from_fn = initial_value is Func <Tensor>;

            // Work on a private copy so adding TRAINABLE_VARIABLES below never mutates the caller's list.
            collections = collections == null
                ? new List <string> { tf.GraphKeys.GLOBAL_VARIABLES }
                : new List <string>(collections);

            // Store the graph key so optimizers know how to only retrieve variables from
            // this graph.
            _graph_key = ops.get_default_graph().graph_key;

            _trainable = trainable;
            if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES))
            {
                collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);
            }

            ops.init_scope();
            var values = init_from_fn ? new object[0] : new object[] { initial_value };

            tf_with(ops.name_scope(name, "Variable", values), scope =>
            {
                name = scope;
                if (init_from_fn)
                {
                    // Use attr_scope and device(None) to simulate the behavior of
                    // colocate_with when the variable we want to colocate with doesn't
                    // yet exist.
                    // NOTE(review): `attr` is built but never applied to any scope yet.
                    string true_name = ops._name_from_scope_name(name);
                    var attr         = new AttrValue
                    {
                        List = new AttrValue.Types.ListValue()
                    };
                    attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}"));
                    tf_with(ops.name_scope("Initializer"), scope2 =>
                    {
                        _initial_value = (initial_value as Func <Tensor>)();
                        _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype);
                    });
                    _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
                }
                // Or get the initial value from a Tensor or Python object.
                else
                {
                    _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype);

                    var shape = _initial_value.shape;
                    dtype     = _initial_value.dtype;
                    _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope);
                }

                // Manually overrides the variable's shape with the initial value's.
                if (validate_shape)
                {
                    var initial_value_shape = _initial_value.TensorShape;
                    if (!initial_value_shape.is_fully_defined())
                    {
                        throw new ValueError($"initial_value must have a shape specified: {_initial_value}");
                    }
                }

                // If 'initial_value' makes use of other variables, make sure we don't
                // have an issue if these other variables aren't initialized first by
                // using their initialized_value() method.
                var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value);

                _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;

                if (String.IsNullOrEmpty(caching_device))
                {
                    // No caching device: colocate the read snapshot with the initializer op.
                    ops.colocate_with(_initializer_op);
                }
                // NOTE(review): a non-empty caching_device is not applied as a device placement yet,
                // but the snapshot is always created so _snapshot is never left null.
                // Named argument, not `name = "read"`, so the captured `name` is not overwritten.
                _snapshot = gen_array_ops.identity(_variable, name: "read");

                ops.add_to_collections(collections, this as VariableV1);
            });
        }