Ejemplo n.º 1
0
        /// <summary>
        /// Adds a weight by delegating to the base <c>add_weight</c> inside
        /// <c>ops.init_scope()</c>.
        /// </summary>
        /// <remarks>
        /// The base call always receives <c>trainable: false</c>. The
        /// <paramref name="trainable"/>, <paramref name="regularizer"/> and
        /// <paramref name="getter"/> arguments are not forwarded by this override.
        /// </remarks>
        /// <param name="name">Weight name, passed through unchanged.</param>
        /// <param name="shape">Weight shape; <c>null</c> is replaced by a shape built from an empty dimension array.</param>
        /// <param name="dtype">Element type, passed through unchanged.</param>
        /// <param name="initializer">Initializer, passed through unchanged.</param>
        /// <param name="regularizer">Not used by this override.</param>
        /// <param name="synchronization">Passed through unchanged.</param>
        /// <param name="aggregation">Passed through unchanged.</param>
        /// <param name="trainable">Not used by this override.</param>
        /// <param name="getter">Not used by this override.</param>
        /// <returns>The variable produced by the base implementation.</returns>
        protected override IVariableV1 add_weight(string name,
                                                  TensorShape shape        = null,
                                                  TF_DataType dtype        = TF_DataType.TF_FLOAT,
                                                  IInitializer initializer = null,
                                                  IRegularizer regularizer = null,
                                                  VariableSynchronization synchronization = VariableSynchronization.OnRead,
                                                  VariableAggregation aggregation         = VariableAggregation.Sum,
                                                  bool trainable = true,
                                                  Func <VariableArgs, IVariableV1> getter = null)
        {
            // A missing shape is treated as an empty dimension list.
            if (shape == null)
            {
                shape = new TensorShape(new int[0]);
            }

            return tf_with(ops.init_scope(), delegate
            {
                var weight = base.add_weight(name, shape,
                                             aggregation: aggregation,
                                             synchronization: synchronization,
                                             initializer: initializer,
                                             trainable: false,
                                             dtype: dtype);
                return weight;
            });
        }
        /// <summary>
        /// Requests a variable from <paramref name="var_store"/>, prefixing the
        /// requested name with this scope's <c>_name</c> when that is non-empty.
        /// </summary>
        /// <param name="var_store">Store whose <c>get_variable</c> performs the request.</param>
        /// <param name="name">Variable name relative to this scope.</param>
        /// <param name="shape">Forwarded unchanged.</param>
        /// <param name="dtype">Element type; <c>DtInvalid</c> falls back to this scope's <c>_dtype</c>.</param>
        /// <param name="initializer">Forwarded unchanged.</param>
        /// <param name="trainable">Forwarded unchanged.</param>
        /// <param name="synchronization">Forwarded unchanged.</param>
        /// <param name="aggregation">Forwarded unchanged.</param>
        /// <returns>The value returned by <c>var_store.get_variable</c>.</returns>
        public RefVariable get_variable(_VariableStore var_store,
                                        string name,
                                        TensorShape shape        = null,
                                        TF_DataType dtype        = TF_DataType.DtInvalid,
                                        IInitializer initializer = null,
                                        bool?trainable           = null,
                                        VariableSynchronization synchronization = VariableSynchronization.AUTO,
                                        VariableAggregation aggregation         = VariableAggregation.NONE)
        {
            // Qualify the name with the scope name only when the scope has one.
            string full_name = name;
            if (!string.IsNullOrEmpty(this._name))
            {
                full_name = this._name + "/" + name;
            }

            return with(new ops.name_scope(null), scope =>
            {
                var resolved_dtype = dtype == TF_DataType.DtInvalid ? _dtype : dtype;

                return var_store.get_variable(full_name,
                                              shape: shape,
                                              dtype: resolved_dtype,
                                              initializer: initializer,
                                              trainable: trainable,
                                              synchronization: synchronization,
                                              aggregation: aggregation);
            });
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Creates a layer weight through the base <c>add_weight</c>, inside this layer's
        /// variable scope and name scope, using <c>tf.get_variable</c> as the getter.
        /// </summary>
        /// <param name="name">Name forwarded to the base <c>add_weight</c>.</param>
        /// <param name="shape">Dimensions forwarded to the base <c>add_weight</c>.</param>
        /// <param name="dtype">Element type; <c>DtInvalid</c> is replaced by <c>TF_FLOAT</c>.</param>
        /// <param name="initializer">Initializer forwarded to the base <c>add_weight</c>.</param>
        /// <param name="trainable">Trainable flag forwarded to the base <c>add_weight</c>.</param>
        /// <param name="synchronization">Accepted but not used in this body.</param>
        /// <param name="aggregation">Accepted but not used in this body.</param>
        /// <returns>The variable returned by the base <c>add_weight</c>.</returns>
        /// <exception cref="NotImplementedException">The default graph reports <c>building_function</c>.</exception>
        protected virtual RefVariable add_weight(string name,
                                                 int[] shape,
                                                 TF_DataType dtype        = TF_DataType.DtInvalid,
                                                 IInitializer initializer = null,
                                                 bool?trainable           = null,
                                                 VariableSynchronization synchronization = VariableSynchronization.AUTO,
                                                 VariableAggregation aggregation         = VariableAggregation.NONE)
        {
            var   default_graph = ops.get_default_graph();
            Graph init_graph    = null;

            // Assigned below but never read afterwards in this method.
            RefVariable[] existing_variables = null;

            if (default_graph.building_function)
            {
                // Function-building graphs are not handled by this implementation.
                throw new NotImplementedException("add_weight");
            }
            else
            {
                init_graph         = default_graph;
                existing_variables = variables.global_variables().ToArray();
            }

            // Fall back to float when the caller did not choose an element type.
            if (dtype == TF_DataType.DtInvalid)
            {
                dtype = TF_DataType.TF_FLOAT;
            }

            _set_scope();
            // Reuse is requested once the layer is built, or when _reuse is explicitly true.
            var reuse = built || (_reuse != null && _reuse.Value);

            return(Python.with(tf.variable_scope(_scope,
                                                 reuse: reuse,
                                                 auxiliary_name_scope: false), scope =>
            {
                // Keep the entered variable scope on the instance.
                _current_scope = scope;
                return Python.with(ops.name_scope(_name_scope()), delegate
                {
                    // The getter handed to the base implementation goes through tf.get_variable.
                    var variable = base.add_weight(name,
                                                   shape,
                                                   dtype: dtype,
                                                   initializer: initializer,
                                                   trainable: trainable,
                                                   getter: (name1, shape1, dtype1, initializer1, trainable1) =>
                    {
                        return tf.get_variable(name1,
                                               shape: new TensorShape(shape1),
                                               dtype: dtype1,
                                               initializer: initializer1,
                                               trainable: trainable1);
                    });

                    if (init_graph != null)
                    {
                        // The result is not used here.
                        // NOTE(review): presumably a placeholder for follow-up bookkeeping - confirm.
                        var trainable_variables = variables.trainable_variables();
                    }
                    return variable;
                });
            }));
        }
Ejemplo n.º 4
0
        /// <summary>
        /// Requests a variable from <paramref name="var_store"/>, prefixing the
        /// requested name with this scope's <c>name</c> when that is non-empty.
        /// </summary>
        /// <remarks>
        /// <paramref name="use_resource"/> and <paramref name="validate_shape"/> are
        /// accepted but not forwarded to the store by this method.
        /// </remarks>
        /// <param name="var_store">Store whose <c>get_variable</c> performs the request.</param>
        /// <param name="name">Variable name relative to this scope.</param>
        /// <param name="shape">Forwarded unchanged.</param>
        /// <param name="dtype">Element type; <c>DtInvalid</c> falls back to this scope's <c>_dtype</c>.</param>
        /// <param name="initializer">An <c>IInitializer</c> or a <c>Tensor</c>; forwarded unchanged.</param>
        /// <param name="trainable">Forwarded unchanged.</param>
        /// <param name="collections">Forwarded unchanged.</param>
        /// <param name="use_resource">Not used by this method.</param>
        /// <param name="validate_shape">Not used by this method.</param>
        /// <param name="synchronization">Forwarded unchanged.</param>
        /// <param name="aggregation">Forwarded unchanged.</param>
        /// <returns>
        /// The store's result converted with <c>as RefVariable</c>, i.e. <c>null</c>
        /// when the store returns something that is not a <c>RefVariable</c>.
        /// </returns>
        public RefVariable get_variable(_VariableStore var_store,
                                        string name,
                                        TensorShape shape         = null,
                                        TF_DataType dtype         = TF_DataType.DtInvalid,
                                        object initializer        = null, // IInitializer or Tensor
                                        bool?trainable            = null,
                                        List <string> collections = null,
                                        bool?use_resource         = null,
                                        bool validate_shape       = true,
                                        VariableSynchronization synchronization = VariableSynchronization.Auto,
                                        VariableAggregation aggregation         = VariableAggregation.None)
        {
            // `this.name` is the scope's own name; the bare `name` is the parameter.
            string full_name = name;
            if (!string.IsNullOrEmpty(this.name))
            {
                full_name = this.name + "/" + name;
            }

            return tf_with(ops.name_scope(null), scope =>
            {
                var resolved_dtype = dtype == TF_DataType.DtInvalid ? _dtype : dtype;

                var found = var_store.get_variable(full_name,
                                                   shape: shape,
                                                   dtype: resolved_dtype,
                                                   initializer: initializer,
                                                   reuse: resue,
                                                   trainable: trainable,
                                                   collections: collections,
                                                   synchronization: synchronization,
                                                   aggregation: aggregation);
                return found as RefVariable;
            });
        }
Ejemplo n.º 5
0
 /// <summary>
 /// Builds a resource variable either from a <c>VariableDef</c> or from
 /// explicit arguments.
 /// </summary>
 /// <remarks>
 /// <paramref name="validate_shape"/> is accepted but not forwarded to either
 /// initialisation path.
 /// </remarks>
 /// <param name="initial_value">Initial value; must be <c>null</c> when <paramref name="variable_def"/> is given.</param>
 /// <param name="trainable">Forwarded to <c>_init_from_args</c>.</param>
 /// <param name="collections">Forwarded to <c>_init_from_args</c>.</param>
 /// <param name="validate_shape">Not used by this constructor.</param>
 /// <param name="caching_device">Forwarded to <c>_init_from_args</c>.</param>
 /// <param name="name">Forwarded to <c>_init_from_args</c>.</param>
 /// <param name="variable_def">When non-null, the variable is initialised via <c>_init_from_proto</c>.</param>
 /// <param name="dtype">Forwarded to <c>_init_from_args</c>.</param>
 /// <param name="import_scope">Forwarded to <c>_init_from_proto</c>.</param>
 /// <param name="aggregation">Forwarded to <c>_init_from_args</c>.</param>
 /// <param name="shape">Forwarded to <c>_init_from_args</c>.</param>
 /// <exception cref="ValueError">Both <paramref name="variable_def"/> and <paramref name="initial_value"/> are supplied.</exception>
 public ResourceVariable(object initial_value      = null,
                         bool trainable            = true,
                         List <string> collections = null,
                         bool validate_shape       = true,
                         string caching_device     = "",
                         string name = null,
                         VariableDef variable_def        = null,
                         TF_DataType dtype               = TF_DataType.DtInvalid,
                         string import_scope             = "",
                         VariableAggregation aggregation = VariableAggregation.None,
                         TensorShape shape               = null)
 {
     // Common case first: no definition supplied, build from the arguments.
     if (variable_def == null)
     {
         _init_from_args(initial_value: initial_value,
                         trainable: trainable,
                         collections: collections,
                         caching_device: caching_device,
                         name: name,
                         dtype: dtype,
                         aggregation: aggregation,
                         shape: shape);
         return;
     }

     // A definition and an initial value cannot be combined.
     if (initial_value != null)
     {
         throw new ValueError("variable_def and initial_value are mutually exclusive.");
     }

     _init_from_proto(variable_def, import_scope: import_scope);
 }
Ejemplo n.º 6
0
        /// <summary>
        /// Creates a <c>RefVariable</c> for <paramref name="initial_value"/>.
        /// </summary>
        /// <remarks>
        /// When <paramref name="use_resource"/> is <c>null</c> it is taken from the current
        /// variable scope, then from <c>_DEFAULT_USE_RESOURCE</c>. The
        /// <paramref name="aggregation"/> argument is not used in this body.
        /// </remarks>
        /// <param name="initial_value">Value handed to the <c>RefVariable</c> constructor.</param>
        /// <param name="name">Forwarded to the constructor.</param>
        /// <param name="trainable">Resolved through <c>_get_trainable_value</c> before use.</param>
        /// <param name="dtype">Forwarded to the constructor.</param>
        /// <param name="validate_shape">Forwarded to the constructor.</param>
        /// <param name="use_resource">Whether a resource variable is requested; see remarks.</param>
        /// <param name="synchronization">Passed to <c>_get_trainable_value</c>.</param>
        /// <param name="aggregation">Not used by this method.</param>
        /// <returns>The newly constructed <c>RefVariable</c>.</returns>
        /// <exception cref="NotImplementedException">The resolved <paramref name="use_resource"/> is <c>true</c>.</exception>
        public static RefVariable default_variable_creator(object initial_value,
                                                           string name         = null,
                                                           bool?trainable      = null,
                                                           TF_DataType dtype   = TF_DataType.DtInvalid,
                                                           bool validate_shape = false,
                                                           bool?use_resource   = null,
                                                           VariableSynchronization synchronization = VariableSynchronization.Auto,
                                                           VariableAggregation aggregation         = VariableAggregation.None)
        {
            trainable = _get_trainable_value(synchronization, trainable);

            // Resolve use_resource: explicit argument, then scope setting, then global default.
            use_resource = use_resource ?? get_variable_scope().use_resource;
            use_resource = use_resource ?? _DEFAULT_USE_RESOURCE;

            if (use_resource.Value)
            {
                throw new NotImplementedException();
            }

            return new RefVariable(initial_value,
                                   name: name,
                                   dtype: dtype,
                                   trainable: trainable.Value,
                                   validate_shape: validate_shape);
        }
Ejemplo n.º 7
0
        /// <summary>
        /// Creates a variable named <paramref name="name"/> from an initializer, records it
        /// in <c>_vars</c> and returns it.
        /// </summary>
        /// <remarks>
        /// A name that is already present in <c>_vars</c> always ends in
        /// <see cref="NotImplementedException"/>, whatever <paramref name="reuse"/> is.
        /// <paramref name="initializer"/> is dereferenced unconditionally although it
        /// defaults to <c>null</c>.
        /// </remarks>
        /// <param name="name">Key under which the variable is stored in <c>_vars</c>.</param>
        /// <param name="shape">Passed to <c>initializer.call</c>.</param>
        /// <param name="dtype">Passed to <c>initializer.call</c>.</param>
        /// <param name="initializer">Produces the initial value via <c>call(shape, dtype)</c>.</param>
        /// <param name="reuse">Only consulted on the already-exists path, which throws.</param>
        /// <param name="trainable">Forwarded to <c>default_variable_creator</c>.</param>
        /// <param name="validate_shape">Forwarded to <c>default_variable_creator</c>.</param>
        /// <param name="use_resource">Defaulted to <c>false</c> when null, but not forwarded.</param>
        /// <param name="synchronization">Forwarded to <c>default_variable_creator</c>.</param>
        /// <param name="aggregation">Forwarded to <c>default_variable_creator</c>.</param>
        /// <returns>The variable returned by <c>default_variable_creator</c>.</returns>
        private RefVariable _get_single_variable(string name,
                                                 TensorShape shape        = null,
                                                 TF_DataType dtype        = TF_DataType.DtInvalid,
                                                 IInitializer initializer = null,
                                                 bool reuse          = false,
                                                 bool?trainable      = null,
                                                 bool validate_shape = false,
                                                 bool?use_resource   = null,
                                                 VariableSynchronization synchronization = VariableSynchronization.AUTO,
                                                 VariableAggregation aggregation         = VariableAggregation.NONE)
        {
            // Never set to true in this method, so only the else-branch below runs.
            bool initializing_from_value = false;

            if (_vars.ContainsKey(name))
            {
                if (!reuse)
                {
                    // Looked up but not used.
                    var var = _vars[name];
                }
                // Existing variables are not supported yet, with or without reuse.
                throw new NotImplementedException("_get_single_variable");
            }

            Tensor init_val = null;

            // The return value of init_scope() is discarded; the braces that follow are a
            // plain statement block, not a scope tied to that call.
            ops.init_scope();
            {
                if (initializing_from_value)
                {
                }
                else
                {
                    // NOTE(review): throws NullReferenceException when initializer is null
                    // (its default) - confirm callers always supply one.
                    init_val = initializer.call(shape, dtype);
                    // Computed but not used afterwards.
                    var variable_dtype = dtype.as_base_dtype();
                }
            }

            // Create the variable.
            if (use_resource == null)
            {
                use_resource = false;
            }

            // Note: DtInvalid is passed as dtype here rather than the dtype argument, and
            // use_resource is not passed on.
            var v = variable_scope.default_variable_creator(init_val,
                                                            name: name,
                                                            trainable: trainable,
                                                            dtype: TF_DataType.DtInvalid,
                                                            validate_shape: validate_shape,
                                                            synchronization: synchronization,
                                                            aggregation: aggregation);

            _vars[name] = v;

            return(v);
        }
Ejemplo n.º 8
0
 /// <summary>
 /// Wraps <paramref name="data"/> in a new <c>ResourceVariable</c>.
 /// </summary>
 /// <remarks>
 /// <paramref name="use_resource"/> is accepted but not forwarded; a
 /// <c>ResourceVariable</c> is constructed regardless of its value.
 /// </remarks>
 /// <typeparam name="T">Type of the initial value.</typeparam>
 /// <param name="data">Initial value passed to the constructor.</param>
 /// <param name="trainable">Forwarded to the constructor.</param>
 /// <param name="validate_shape">Forwarded to the constructor.</param>
 /// <param name="use_resource">Not used by this method.</param>
 /// <param name="name">Forwarded to the constructor.</param>
 /// <param name="dtype">Forwarded to the constructor.</param>
 /// <param name="aggregation">Forwarded to the constructor.</param>
 /// <param name="shape">Forwarded to the constructor.</param>
 /// <returns>The newly constructed variable.</returns>
 public ResourceVariable Variable <T>(T data,
                                      bool trainable                  = true,
                                      bool validate_shape             = true,
                                      bool use_resource               = true,
                                      string name                     = null,
                                      TF_DataType dtype               = TF_DataType.DtInvalid,
                                      VariableAggregation aggregation = VariableAggregation.None,
                                      int[] shape                     = null)
 {
     return new ResourceVariable(data,
                                 name: name,
                                 dtype: dtype,
                                 shape: shape,
                                 trainable: trainable,
                                 validate_shape: validate_shape,
                                 aggregation: aggregation);
 }
Ejemplo n.º 9
0
        /// <summary>
        /// Dispatches to the <c>_get_single_variable</c> overload that matches the runtime
        /// type of <paramref name="initializer"/>.
        /// </summary>
        /// <remarks>
        /// An <c>IInitializer</c> is tested first, then a <c>Tensor</c>; anything else
        /// (including <c>null</c>) is forwarded as a null <c>IInitializer</c>.
        /// <paramref name="collections"/> is only passed on the <c>IInitializer</c> path.
        /// </remarks>
        /// <param name="name">Forwarded unchanged.</param>
        /// <param name="shape">Forwarded unchanged.</param>
        /// <param name="dtype">Forwarded unchanged.</param>
        /// <param name="initializer">An <c>IInitializer</c>, a <c>Tensor</c>, or something else.</param>
        /// <param name="trainable">Forwarded unchanged.</param>
        /// <param name="collections">Forwarded on the <c>IInitializer</c> path only.</param>
        /// <param name="validate_shape">Forwarded unchanged.</param>
        /// <param name="synchronization">Forwarded unchanged.</param>
        /// <param name="aggregation">Forwarded unchanged.</param>
        /// <returns>The variable returned by <c>_get_single_variable</c>.</returns>
        private IVariableV1 _true_getter(string name,
                                         TensorShape shape         = null,
                                         TF_DataType dtype         = TF_DataType.TF_FLOAT,
                                         object initializer        = null,
                                         bool?trainable            = null,
                                         List <string> collections = null,
                                         bool validate_shape       = true,
                                         VariableSynchronization synchronization = VariableSynchronization.Auto,
                                         VariableAggregation aggregation         = VariableAggregation.None)
        {
            // Computed but not consumed further in this method.
            bool is_scalar = !(shape is null) && shape.ndim == 0;

            switch (initializer)
            {
                case IInitializer init:
                {
                    return _get_single_variable(name: name,
                                                shape: shape,
                                                dtype: dtype,
                                                initializer: init,
                                                trainable: trainable,
                                                collections: collections,
                                                validate_shape: validate_shape,
                                                synchronization: synchronization,
                                                aggregation: aggregation);
                }

                case Tensor tensor:
                {
                    return _get_single_variable(name: name,
                                                shape: shape,
                                                dtype: dtype,
                                                initializer: tensor,
                                                trainable: trainable,
                                                validate_shape: validate_shape,
                                                synchronization: synchronization,
                                                aggregation: aggregation);
                }

                default:
                {
                    // Typed null selects the IInitializer overload.
                    IInitializer no_initializer = null;
                    return _get_single_variable(name: name,
                                                shape: shape,
                                                dtype: dtype,
                                                initializer: no_initializer,
                                                trainable: trainable,
                                                validate_shape: validate_shape,
                                                synchronization: synchronization,
                                                aggregation: aggregation);
                }
            }
        }
Ejemplo n.º 10
0
        /// <summary>
        /// Requests a variable through the current variable scope, backed by the default
        /// variable store.
        /// </summary>
        /// <param name="name">Variable name, forwarded to the scope.</param>
        /// <param name="shape">Forwarded unchanged.</param>
        /// <param name="dtype">Forwarded unchanged.</param>
        /// <param name="initializer">An <c>IInitializer</c> or a <c>Tensor</c>; forwarded unchanged.</param>
        /// <param name="trainable">Forwarded unchanged.</param>
        /// <param name="synchronization">Forwarded unchanged.</param>
        /// <param name="aggregation">Forwarded unchanged.</param>
        /// <returns>The variable returned by the scope's <c>get_variable</c>.</returns>
        public static RefVariable get_variable(string name,
                                               TensorShape shape  = null,
                                               TF_DataType dtype  = TF_DataType.DtInvalid,
                                               object initializer = null, // IInitializer or Tensor
                                               bool?trainable     = null,
                                               VariableSynchronization synchronization = VariableSynchronization.Auto,
                                               VariableAggregation aggregation         = VariableAggregation.None)
        {
            var scope = Tensorflow.variable_scope.get_variable_scope();
            var store = Tensorflow.variable_scope._get_default_variable_store();

            // Pass synchronization and aggregation through as well, so that a caller's
            // explicit choice is honoured instead of being replaced by the scope's defaults.
            return(scope.get_variable(store,
                                      name,
                                      shape: shape,
                                      dtype: dtype,
                                      initializer: initializer,
                                      trainable: trainable,
                                      synchronization: synchronization,
                                      aggregation: aggregation));
        }
Ejemplo n.º 11
0
        /// <summary>
        /// Restore-on-create for a variable to be saved with this `Checkpointable`.
        /// </summary>
        /// <remarks>
        /// The variable is produced by <paramref name="getter"/>, which is invoked
        /// unconditionally even though it defaults to <c>null</c>.
        /// <paramref name="use_resource"/>, <paramref name="synchronization"/> and
        /// <paramref name="aggregation"/> are not used in this body.
        /// </remarks>
        /// <param name="name">Passed to <paramref name="getter"/> and to <c>_track_checkpointable</c>.</param>
        /// <param name="shape">Passed to <paramref name="getter"/>.</param>
        /// <param name="dtype">Passed to <paramref name="getter"/>.</param>
        /// <param name="initializer">Passed to <paramref name="getter"/>.</param>
        /// <param name="getter">Callback that creates the variable.</param>
        /// <param name="overwrite">Passed to <c>_track_checkpointable</c>; also selects the return path.</param>
        /// <param name="trainable">Passed to <paramref name="getter"/>.</param>
        /// <param name="use_resource">Not used by this method.</param>
        /// <param name="synchronization">Not used by this method.</param>
        /// <param name="aggregation">Not used by this method.</param>
        /// <returns>
        /// The result of <c>_track_checkpointable</c> when <paramref name="overwrite"/> is
        /// false or the new variable is a <c>RefVariable</c>; otherwise the new variable itself.
        /// </returns>
        protected virtual IVariableV1 _add_variable_with_custom_getter(string name,
                                                                       int[] shape,
                                                                       TF_DataType dtype        = TF_DataType.TF_FLOAT,
                                                                       IInitializer initializer = null,
                                                                       Func <string, int[], TF_DataType, IInitializer, bool, IVariableV1> getter = null,
                                                                       bool overwrite    = false,
                                                                       bool trainable    = false,
                                                                       bool use_resource = false,
                                                                       VariableSynchronization synchronization = VariableSynchronization.Auto,
                                                                       VariableAggregation aggregation         = VariableAggregation.None)
        {
            // The return value of init_scope() is discarded here.
            ops.init_scope();
            // checkpoint_initializer is only ever assigned null and never read.
            // NOTE(review): presumably a placeholder for checkpoint restoration - confirm.
#pragma warning disable CS0219 // Variable is assigned but its value is never used
            IInitializer checkpoint_initializer = null;
#pragma warning restore CS0219 // Variable is assigned but its value is never used
            if (tf.context.executing_eagerly())
#pragma warning disable CS0642 // Possible mistaken empty statement
            {
                ;
            }
#pragma warning restore CS0642 // Possible mistaken empty statement
            else
            {
                checkpoint_initializer = null;
            }

            IVariableV1 new_variable;
            // NOTE(review): throws NullReferenceException when getter is left at its null default.
            new_variable = getter(name, shape, dtype, initializer, trainable);

            // If we set an initializer and the variable processed it, tracking will not
            // assign again. It will add this variable to our dependencies, and if there
            // is a non-trivial restoration queued, it will handle that. This also
            // handles slot variables.
            if (!overwrite || new_variable is RefVariable)
            {
                return(_track_checkpointable(new_variable, name: name,
                                             overwrite: overwrite));
            }
            else
            {
                return(new_variable);
            }
        }
Ejemplo n.º 12
0
        /// <summary>
        /// Creates either a <c>ResourceVariable</c> or a <c>RefVariable</c> for
        /// <paramref name="initial_value"/>, depending on the resolved
        /// <paramref name="use_resource"/> flag.
        /// </summary>
        /// <remarks>
        /// When <paramref name="use_resource"/> is <c>null</c> it is taken from the current
        /// variable scope, then from <c>_DEFAULT_USE_RESOURCE</c>.
        /// <paramref name="shape"/> is only passed to the <c>ResourceVariable</c> path and
        /// <paramref name="aggregation"/> is not used in this body.
        /// </remarks>
        /// <param name="initial_value">Value handed to the variable constructor.</param>
        /// <param name="name">Forwarded to the constructor.</param>
        /// <param name="trainable">Resolved through <c>_get_trainable_value</c> before use.</param>
        /// <param name="collections">Forwarded to the constructor.</param>
        /// <param name="dtype">Forwarded to the constructor.</param>
        /// <param name="shape">Forwarded to the <c>ResourceVariable</c> constructor only.</param>
        /// <param name="validate_shape">Forwarded to the constructor.</param>
        /// <param name="use_resource">Selects the variable kind; see remarks.</param>
        /// <param name="synchronization">Passed to <c>_get_trainable_value</c>.</param>
        /// <param name="aggregation">Not used by this method.</param>
        /// <returns>The newly constructed variable.</returns>
        public static VariableV1 default_variable_creator(object initial_value,
                                                          string name               = null,
                                                          bool?trainable            = null,
                                                          List <string> collections = null,
                                                          TF_DataType dtype         = TF_DataType.DtInvalid,
                                                          int[] shape               = null,
                                                          bool validate_shape       = false,
                                                          bool?use_resource         = null,
                                                          VariableSynchronization synchronization = VariableSynchronization.Auto,
                                                          VariableAggregation aggregation         = VariableAggregation.None)
        {
            trainable = _get_trainable_value(synchronization, trainable);

            // Resolve use_resource: explicit argument, then scope setting, then global default.
            use_resource = use_resource ?? get_variable_scope().use_resource;
            use_resource = use_resource ?? _DEFAULT_USE_RESOURCE;

            if (!use_resource.Value)
            {
                return new RefVariable(initial_value,
                                       trainable: trainable.Value,
                                       validate_shape: validate_shape,
                                       collections: collections,
                                       name: name,
                                       dtype: dtype);
            }

            return new ResourceVariable(initial_value,
                                        trainable: trainable.Value,
                                        validate_shape: validate_shape,
                                        collections: collections,
                                        name: name,
                                        dtype: dtype,
                                        shape: shape);
        }
        /// <summary>
        /// Forwards a variable request to <c>_get_single_variable</c>.
        /// </summary>
        /// <param name="name">Forwarded unchanged.</param>
        /// <param name="shape">Forwarded unchanged; may be <c>null</c>.</param>
        /// <param name="dtype">Forwarded unchanged.</param>
        /// <param name="initializer">Forwarded unchanged.</param>
        /// <param name="trainable">Forwarded unchanged.</param>
        /// <param name="validate_shape">Forwarded unchanged.</param>
        /// <param name="synchronization">Forwarded unchanged.</param>
        /// <param name="aggregation">Forwarded unchanged.</param>
        /// <returns>The variable returned by <c>_get_single_variable</c>.</returns>
        private RefVariable _true_getter(string name,
                                         TensorShape shape        = null,
                                         TF_DataType dtype        = TF_DataType.TF_FLOAT,
                                         IInitializer initializer = null,
                                         bool?trainable           = null,
                                         bool validate_shape      = true,
                                         VariableSynchronization synchronization = VariableSynchronization.AUTO,
                                         VariableAggregation aggregation         = VariableAggregation.NONE)
        {
            // shape is optional (defaults to null), so check it before reading NDim;
            // otherwise a shape-less request would throw NullReferenceException.
            // The flag is not consumed further in this method.
            bool is_scalar = !(shape is null) && shape.NDim == 0;

            return(_get_single_variable(name: name,
                                        shape: shape,
                                        dtype: dtype,
                                        initializer: initializer,
                                        trainable: trainable,
                                        validate_shape: validate_shape,
                                        synchronization: synchronization,
                                        aggregation: aggregation));
        }
Ejemplo n.º 14
0
        /// <summary>
        /// Creates a <c>RefVariable</c> from a tensor initial value, records it in
        /// <c>_vars</c> under <paramref name="name"/> and returns it.
        /// </summary>
        /// <remarks>
        /// A name already present in <c>_vars</c> always ends in
        /// <see cref="NotImplementedException"/>. <paramref name="trainable"/> is read via
        /// <c>.Value</c>, so it must be non-null when the variable is created.
        /// <paramref name="shape"/>, <paramref name="dtype"/>,
        /// <paramref name="synchronization"/> and <paramref name="aggregation"/> are not
        /// used in this body.
        /// </remarks>
        /// <param name="name">Variable name and key in <c>_vars</c>.</param>
        /// <param name="shape">Not used by this method.</param>
        /// <param name="dtype">Not used by this method.</param>
        /// <param name="initializer">Tensor used as the initial value.</param>
        /// <param name="reuse">Only consulted on the already-exists path, which throws.</param>
        /// <param name="trainable">Forwarded to the <c>RefVariable</c> constructor.</param>
        /// <param name="validate_shape">Forwarded to the <c>RefVariable</c> constructor.</param>
        /// <param name="use_resource">Defaulted to <c>false</c> when null; otherwise unused.</param>
        /// <param name="synchronization">Not used by this method.</param>
        /// <param name="aggregation">Not used by this method.</param>
        /// <returns>The newly created variable.</returns>
        private RefVariable _get_single_variable(string name,
                                                 TensorShape shape   = null,
                                                 TF_DataType dtype   = TF_DataType.DtInvalid,
                                                 Tensor initializer  = null,
                                                 bool reuse          = false,
                                                 bool?trainable      = null,
                                                 bool validate_shape = false,
                                                 bool?use_resource   = null,
                                                 VariableSynchronization synchronization = VariableSynchronization.Auto,
                                                 VariableAggregation aggregation         = VariableAggregation.None)
        {
            use_resource = use_resource ?? false;

            if (_vars.ContainsKey(name))
            {
                if (!reuse)
                {
                    // Looked up but not used.
                    var existing = _vars[name];
                }

                // Existing variables are not supported yet, with or without reuse.
                throw new NotImplementedException("_get_single_variable");
            }

            // Create the variable. The return value of init_scope() is discarded.
            ops.init_scope();

            var created = new RefVariable(initializer,
                                          name: name,
                                          validate_shape: validate_shape,
                                          trainable: trainable.Value);

            _vars[name] = created;
            return created;
        }
        /// <summary>
        /// Normalises the request and hands it to <c>_true_getter</c>.
        /// </summary>
        /// <param name="name">Forwarded unchanged.</param>
        /// <param name="shape">Forwarded unchanged.</param>
        /// <param name="dtype">Converted with <c>as_base_dtype()</c> before forwarding.</param>
        /// <param name="initializer">Forwarded unchanged.</param>
        /// <param name="trainable">Resolved through <c>variable_scope._get_trainable_value</c> before forwarding.</param>
        /// <param name="validate_shape">Forwarded unchanged.</param>
        /// <param name="synchronization">Used to resolve <paramref name="trainable"/> and forwarded.</param>
        /// <param name="aggregation">Forwarded unchanged.</param>
        /// <returns>The variable returned by <c>_true_getter</c>.</returns>
        public RefVariable get_variable(string name,
                                        TensorShape shape        = null,
                                        TF_DataType dtype        = TF_DataType.TF_FLOAT,
                                        IInitializer initializer = null,
                                        bool?trainable           = null,
                                        bool validate_shape      = true,
                                        VariableSynchronization synchronization = VariableSynchronization.AUTO,
                                        VariableAggregation aggregation         = VariableAggregation.NONE)
        {
            // Keep this order: dtype is normalised first, then trainable is resolved.
            var base_dtype   = dtype.as_base_dtype();
            var is_trainable = variable_scope._get_trainable_value(synchronization, trainable);

            return _true_getter(name,
                                shape: shape,
                                dtype: base_dtype,
                                initializer: initializer,
                                trainable: is_trainable,
                                validate_shape: validate_shape,
                                synchronization: synchronization,
                                aggregation: aggregation);
        }
Ejemplo n.º 16
0
        /// <summary>
        /// Dispatches to the <c>_get_single_variable</c> overload that matches the runtime
        /// type of <paramref name="initializer"/>.
        /// </summary>
        /// <param name="name">Forwarded unchanged.</param>
        /// <param name="shape">Forwarded unchanged.</param>
        /// <param name="dtype">Forwarded unchanged.</param>
        /// <param name="initializer">An <c>IInitializer</c> (tested first) or a <c>Tensor</c>.</param>
        /// <param name="trainable">Forwarded unchanged.</param>
        /// <param name="validate_shape">Forwarded unchanged.</param>
        /// <param name="synchronization">Forwarded unchanged.</param>
        /// <param name="aggregation">Forwarded unchanged.</param>
        /// <returns>The variable returned by <c>_get_single_variable</c>.</returns>
        /// <exception cref="NotImplementedException">
        /// <paramref name="initializer"/> is neither an <c>IInitializer</c> nor a
        /// <c>Tensor</c> (this includes <c>null</c>).
        /// </exception>
        private RefVariable _true_getter(string name,
                                         TensorShape shape   = null,
                                         TF_DataType dtype   = TF_DataType.TF_FLOAT,
                                         object initializer  = null,
                                         bool?trainable      = null,
                                         bool validate_shape = true,
                                         VariableSynchronization synchronization = VariableSynchronization.AUTO,
                                         VariableAggregation aggregation         = VariableAggregation.NONE)
        {
            // Computed but not consumed further in this method.
            bool is_scalar = !(shape is null) && shape.NDim == 0;

            switch (initializer)
            {
                case IInitializer init:
                    return _get_single_variable(name: name,
                                                shape: shape,
                                                dtype: dtype,
                                                initializer: init,
                                                trainable: trainable,
                                                validate_shape: validate_shape,
                                                synchronization: synchronization,
                                                aggregation: aggregation);

                case Tensor tensor:
                    return _get_single_variable(name: name,
                                                shape: shape,
                                                dtype: dtype,
                                                initializer: tensor,
                                                trainable: trainable,
                                                validate_shape: validate_shape,
                                                synchronization: synchronization,
                                                aggregation: aggregation);

                default:
                    throw new NotImplementedException("_true_getter");
            }
        }
Ejemplo n.º 17
0
        /// <summary>
        /// Adds a new variable to the layer, or gets an existing one; returns it.
        /// </summary>
        /// <remarks>
        /// As written, this body is a stub: after the graph check and scope setup it returns
        /// <c>tf.Variable(0)</c>, so <paramref name="name"/>, <paramref name="shape"/>,
        /// <paramref name="initializer"/> and <paramref name="aggregation"/> do not
        /// influence the result.
        /// </remarks>
        /// <param name="name">Not used by the current body.</param>
        /// <param name="shape">Not used by the current body.</param>
        /// <param name="dtype"><c>DtInvalid</c> is replaced by <c>TF_FLOAT</c>; the value is not used afterwards.</param>
        /// <param name="initializer">Not used by the current body.</param>
        /// <param name="trainable">Forced to <c>false</c> when <paramref name="synchronization"/> is <c>OnRead</c>; not used afterwards.</param>
        /// <param name="synchronization">Only checked against <c>OnRead</c>.</param>
        /// <param name="aggregation">Not used by the current body.</param>
        /// <returns>The result of <c>tf.Variable(0)</c>.</returns>
        /// <exception cref="NotImplementedException">The default graph reports <c>building_function</c>.</exception>
        protected virtual IVariableV1 add_weight(string name,
                                                 int[] shape,
                                                 TF_DataType dtype        = TF_DataType.DtInvalid,
                                                 IInitializer initializer = null,
                                                 bool trainable           = true,
                                                 VariableSynchronization synchronization = VariableSynchronization.Auto,
                                                 VariableAggregation aggregation         = VariableAggregation.None)
        {
            var   default_graph = ops.get_default_graph();
            Graph init_graph    = null;

            // Assigned below but never read afterwards in this method.
            IVariableV1[] existing_variables = null;

            // OnRead synchronization overrides the caller's trainable flag.
            if (synchronization == VariableSynchronization.OnRead)
            {
                trainable = false;
            }

            if (default_graph.building_function)
            {
                // Function-building graphs are not handled by this implementation.
                throw new NotImplementedException("add_weight");
            }
            else
            {
                init_graph         = default_graph;
                existing_variables = variables.global_variables().ToArray();
            }

            // Fall back to float when the caller did not choose an element type.
            if (dtype == TF_DataType.DtInvalid)
            {
                dtype = TF_DataType.TF_FLOAT;
            }

            _set_scope();
            // Computed but not used by the current body.
            var reuse = built || (_reuse != null && _reuse.Value);

            // Placeholder result; see remarks.
            return(tf.Variable(0));
        }
Ejemplo n.º 18
0
        /// <summary>
        /// Requests a variable through the current variable scope, backed by the default
        /// variable store.
        /// </summary>
        /// <param name="name">Variable name, forwarded to the scope.</param>
        /// <param name="shape">Forwarded unchanged.</param>
        /// <param name="dtype">Forwarded unchanged.</param>
        /// <param name="initializer">An <c>IInitializer</c> or a <c>Tensor</c>; forwarded unchanged.</param>
        /// <param name="trainable">Forwarded unchanged.</param>
        /// <param name="collections">Forwarded unchanged.</param>
        /// <param name="use_resource">Forwarded unchanged.</param>
        /// <param name="validate_shape">Forwarded unchanged.</param>
        /// <param name="synchronization">Forwarded unchanged.</param>
        /// <param name="aggregation">Forwarded unchanged.</param>
        /// <returns>The variable returned by the scope's <c>get_variable</c>.</returns>
        public RefVariable get_variable(string name,
                                        TensorShape shape         = null,
                                        TF_DataType dtype         = TF_DataType.DtInvalid,
                                        object initializer        = null, // IInitializer or Tensor
                                        bool?trainable            = null,
                                        List <string> collections = null,
                                        bool?use_resource         = null,
                                        bool validate_shape       = true,
                                        VariableSynchronization synchronization = VariableSynchronization.Auto,
                                        VariableAggregation aggregation         = VariableAggregation.None)
        {
            var scope = Tensorflow.variable_scope.get_variable_scope();
            var store = Tensorflow.variable_scope._get_default_variable_store();

            // Pass synchronization and aggregation through as well, so that a caller's
            // explicit choice is honoured instead of being replaced by the scope's defaults.
            return(scope.get_variable(store,
                                      name,
                                      shape: shape,
                                      dtype: dtype,
                                      use_resource: use_resource,
                                      validate_shape: validate_shape,
                                      initializer: initializer,
                                      trainable: trainable,
                                      collections: collections,
                                      synchronization: synchronization,
                                      aggregation: aggregation));
        }
Ejemplo n.º 19
0
        /// <summary>
        /// Normalises the dtype and the trainable flag, then hands everything to
        /// <c>_true_getter</c>.
        /// </summary>
        /// <param name="name">Name passed to <c>_true_getter</c>.</param>
        /// <param name="shape">Passed through unchanged.</param>
        /// <param name="dtype">Converted with <c>as_base_dtype()</c> before being passed on.</param>
        /// <param name="initializer">IInitializer or Tensor; passed through unchanged.</param>
        /// <param name="reuse">Not referenced in this method body.</param>
        /// <param name="trainable">Resolved through <c>variable_scope._get_trainable_value</c> together with <paramref name="synchronization"/>.</param>
        /// <param name="collections">Passed through unchanged.</param>
        /// <param name="validate_shape">Passed through unchanged.</param>
        /// <param name="synchronization">Passed through unchanged.</param>
        /// <param name="aggregation">Passed through unchanged.</param>
        /// <returns>The value returned by <c>_true_getter</c>.</returns>
        public IVariableV1 get_variable(string name,
                                        TensorShape shape         = null,
                                        TF_DataType dtype         = TF_DataType.TF_FLOAT,
                                        object initializer        = null, // IInitializer or Tensor
                                        bool?reuse                = null,
                                        bool?trainable            = null,
                                        List <string> collections = null,
                                        bool validate_shape       = true,
                                        VariableSynchronization synchronization = VariableSynchronization.Auto,
                                        VariableAggregation aggregation         = VariableAggregation.None)
        {
            TF_DataType base_dtype = dtype.as_base_dtype();
            bool? resolved_trainable = variable_scope._get_trainable_value(synchronization, trainable);

            return _true_getter(name,
                                shape: shape,
                                dtype: base_dtype,
                                initializer: initializer,
                                trainable: resolved_trainable,
                                collections: collections,
                                validate_shape: validate_shape,
                                synchronization: synchronization,
                                aggregation: aggregation);
        }
Ejemplo n.º 20
0
 /// <summary>
 /// Placeholder for variable creation; not implemented.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown.</exception>
 public static VariableV1 make_variable(string name, TensorShape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT, Initializer initializer = null,
                                        bool trainable    = true, string caching_device          = null, bool validate_shape = true, Constraints.ConstraintBase constraint = null,
                                        bool use_resource = false, Graph[] collections           = null, VariableSynchronization synchronization = VariableSynchronization.Auto,
                                        VariableAggregation aggregation = VariableAggregation.None)
 {
     throw new NotImplementedException();
 }
Ejemplo n.º 21
0
 /// <summary>
 /// Placeholder for adding a weight; not implemented.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown.</exception>
 public void add_weight(string name                     = null, TensorShape shape = null, string dtype = null, Initializer initializer = null,
                        Regularizer regularizer         = null, bool?trainable    = null, ConstraintBase constraint = null,
                        dynamic partitioner             = null, bool?use_resource = null, VariableSynchronization synchronization = VariableSynchronization.Auto,
                        VariableAggregation aggregation = VariableAggregation.None, Dictionary <string, object> kwargs = null)
 {
     throw new NotImplementedException();
 }
Ejemplo n.º 22
0
        /// <summary>
        /// Adds a new variable to the layer, or gets an existing one; returns it.
        /// </summary>
        /// <param name="name">Name forwarded to the base <c>add_weight</c>.</param>
        /// <param name="shape">Shape forwarded to the base <c>add_weight</c>.</param>
        /// <param name="dtype">Data type; <c>DtInvalid</c> is replaced by <c>TF_FLOAT</c>.</param>
        /// <param name="initializer">Initializer forwarded to the base <c>add_weight</c>.</param>
        /// <param name="trainable">Forced to false when <paramref name="synchronization"/> is <c>OnRead</c>.</param>
        /// <param name="synchronization">Only inspected for <c>OnRead</c>; not forwarded.</param>
        /// <param name="aggregation">Not referenced in this method body.</param>
        /// <returns>The variable returned by the base <c>add_weight</c>.</returns>
        /// <exception cref="NotImplementedException">When the default graph reports <c>building_function</c>.</exception>
        protected virtual IVariableV1 add_weight(string name,
                                                 int[] shape,
                                                 TF_DataType dtype        = TF_DataType.DtInvalid,
                                                 IInitializer initializer = null,
                                                 bool trainable           = true,
                                                 VariableSynchronization synchronization = VariableSynchronization.Auto,
                                                 VariableAggregation aggregation         = VariableAggregation.None)
        {
            var current_graph = ops.get_default_graph();

            // Graphs that are building a function are not handled by this port.
            if (current_graph.building_function)
            {
                throw new NotImplementedException("add_weight");
            }

            // NOTE(review): both values below are captured but never read again in
            // this method (see the commented-out lines near the end).
            Graph         init_graph         = current_graph;
            IVariableV1[] existing_variables = variables.global_variables().ToArray();

            // Variables synchronised on read are never trainable.
            trainable = trainable && synchronization != VariableSynchronization.OnRead;

            var weight_dtype = dtype == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : dtype;

            _set_scope();
            var reuse = built || (_reuse != null && _reuse.Value);

            var layer_scope = tf.variable_scope(_scope,
                                                reuse: reuse,
                                                auxiliary_name_scope: false);

            return tf_with(layer_scope, entered_scope =>
            {
                _current_scope = entered_scope;
                return tf_with(ops.name_scope(_name_scope()), delegate
                {
                    // The getter creates/fetches the variable through the v1 get_variable API.
                    return base.add_weight(name,
                                           shape,
                                           dtype: weight_dtype,
                                           initializer: initializer,
                                           trainable: trainable,
                                           getter: (variable_args) =>
                                           tf.compat.v1.get_variable(variable_args.Name,
                                                                     shape: variable_args.Shape,
                                                                     dtype: variable_args.DType,
                                                                     initializer: variable_args.Initializer,
                                                                     trainable: variable_args.Trainable)
                                           );

                    //if (init_graph != null)
                    //var trainable_variables = variables.trainable_variables();
                });
            });
        }
Ejemplo n.º 23
0
 /// <summary>
 /// Placeholder for adding a weight; not implemented.
 /// </summary>
 /// <exception cref="NotImplementedException">Always thrown.</exception>
 public void add_weight(string name, TensorShape shape          = null, VariableAggregation aggregation = VariableAggregation.Sum,
                        VariableSynchronization synchronization = VariableSynchronization.OnRead, Initializers.Initializer initializer = null,
                        string dtype = null)
 {
     throw new NotImplementedException();
 }
Ejemplo n.º 24
0
        /// <summary>
        /// Creates one variable through <c>variable_scope.default_variable_creator</c>
        /// and registers it in <c>_vars</c> under <paramref name="name"/>.
        /// </summary>
        /// <param name="name">Key in <c>_vars</c> and name given to the new variable.</param>
        /// <param name="shape">Shape handed to the initializer.</param>
        /// <param name="dtype">Data type handed to the initializer; its base dtype is used for the variable.</param>
        /// <param name="initializer">Initializer to use. When null and <paramref name="dtype"/> is floating, the glorot uniform initializer is used.</param>
        /// <param name="reuse">Currently not consulted: any already-registered name throws.</param>
        /// <param name="trainable">Forwarded to the variable creator.</param>
        /// <param name="collections">Forwarded to the variable creator.</param>
        /// <param name="validate_shape">Forwarded to the variable creator.</param>
        /// <param name="use_resource">Currently not consulted by this method.</param>
        /// <param name="synchronization">Forwarded to the variable creator.</param>
        /// <param name="aggregation">Forwarded to the variable creator.</param>
        /// <returns>The newly created variable.</returns>
        /// <exception cref="NotImplementedException">When <paramref name="name"/> is already present in <c>_vars</c>.</exception>
        /// <exception cref="ArgumentException">When no initializer is given and none can be defaulted for <paramref name="dtype"/>.</exception>
        private IVariableV1 _get_single_variable(string name,
                                                 TensorShape shape        = null,
                                                 TF_DataType dtype        = TF_DataType.DtInvalid,
                                                 IInitializer initializer = null,
                                                 bool reuse                = false,
                                                 bool?trainable            = null,
                                                 List <string> collections = null,
                                                 bool validate_shape       = false,
                                                 bool?use_resource         = null,
                                                 VariableSynchronization synchronization = VariableSynchronization.Auto,
                                                 VariableAggregation aggregation         = VariableAggregation.None)
        {
            // Handling of an already-registered variable (with or without reuse)
            // is not ported yet.
            if (_vars.ContainsKey(name))
            {
                throw new NotImplementedException("_get_single_variable");
            }

            // Default initializer for floating-point variables.
            if (initializer == null && dtype.is_floating())
            {
                initializer = tf.glorot_uniform_initializer;
            }

            // Without an initializer the closure below would fail later with a
            // NullReferenceException; fail here with an actionable message instead.
            if (initializer == null)
            {
                throw new ArgumentException(
                    $"An initializer is required for variable {name} of type {dtype}",
                    nameof(initializer));
            }

            // NOTE(review): the object returned by init_scope() is discarded rather
            // than entered (e.g. via tf_with) - confirm whether that is intended.
            ops.init_scope();

            // Create the variable from a deferred initial value.
            Func <Tensor> init_val       = () => initializer.call(shape, dtype);
            var           variable_dtype = dtype.as_base_dtype();

            IVariableV1 v = variable_scope.default_variable_creator(init_val,
                                                                    name: name,
                                                                    trainable: trainable,
                                                                    collections: collections,
                                                                    dtype: variable_dtype,
                                                                    validate_shape: validate_shape,
                                                                    synchronization: synchronization,
                                                                    aggregation: aggregation);

            _vars[name] = v;

            return(v);
        }
Ejemplo n.º 25
0
        /// <summary>
        /// Initialises this variable from an initial value: resolves the collections,
        /// converts the initial value to a tensor, creates the handle (graph or eager
        /// path) and finally calls <c>base.__init__</c>.
        /// </summary>
        /// <param name="initial_value">A value convertible to a tensor, a <c>Func&lt;Tensor&gt;</c>, or an object implementing <c>IInitializer</c>. Must not be null.</param>
        /// <param name="trainable">Stored in <c>_trainable</c>; when true the trainable-variables collection key is added.</param>
        /// <param name="collections">Collection keys to register the variable under in graph mode; defaults to the global-variables key. The caller's list is not modified.</param>
        /// <param name="caching_device">Not referenced in this method body.</param>
        /// <param name="name">Name used for the enclosing name scope.</param>
        /// <param name="dtype">Data type used when converting <paramref name="initial_value"/>.</param>
        /// <param name="aggregation">Not referenced in this method body.</param>
        /// <param name="shape">Explicit shape; when null the initial value's shape is used.</param>
        /// <exception cref="ArgumentNullException">When <paramref name="initial_value"/> is null.</exception>
        private void _init_from_args(object initial_value      = null,
                                     bool trainable            = true,
                                     List <string> collections = null,
                                     string caching_device     = "",
                                     string name       = null,
                                     TF_DataType dtype = TF_DataType.DtInvalid,
                                     VariableAggregation aggregation = VariableAggregation.None,
                                     TensorShape shape = null)
        {
            // The default is null, but every path below dereferences the value.
            if (initial_value == null)
            {
                throw new ArgumentNullException(nameof(initial_value), "initial_value must be specified.");
            }

            var is_initializer = initial_value.GetType().GetInterface("IInitializer") != null;
            var init_from_fn   = is_initializer || initial_value.GetType().Name == "Func`1";

            // Work on a private list so that adding the trainable key below never
            // leaks into a list the caller may reuse for other variables.
            collections = collections == null
                ? new List <string> { tf.GraphKeys.GLOBAL_VARIABLES }
                : new List <string>(collections);
            _trainable = trainable;

            if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES))
            {
                collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);
            }

            tf_with(ops.init_scope(), init_scope =>
            {
                _in_graph_mode = !tf.Context.executing_eagerly();
                var values     = init_from_fn ? new object[0] : new object[] { initial_value };
                tf_with(ops.name_scope(name, "Variable", values, skip_on_eager: false), scope =>
                {
                    name               = scope;
                    var handle_name    = ops.name_from_scope_name(name);
                    string unique_id   = "";
                    string shared_name = "";

                    if (_in_graph_mode)
                    {
                        shared_name = handle_name;
                        unique_id   = shared_name;
                    }
                    else
                    {
                        unique_id   = $"{handle_name}_{ops.uid()}";
                        shared_name = tf.Context.shared_name();
                    }

                    // Turn whatever was passed in into a concrete tensor.
                    tf_with(ops.name_scope("Initializer"), delegate
                    {
                        if (is_initializer)
                        {
                            initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype)));
                        }
                        else
                        {
                            var value     = init_from_fn ? (initial_value as Func <Tensor>)() : initial_value;
                            initial_value = ops.convert_to_tensor(value,
                                                                  name: "initial_value",
                                                                  dtype: dtype);
                        }
                    });
                    _shape         = shape ?? (initial_value as Tensor).TensorShape;
                    _initial_value = initial_value as Tensor;

                    if (_in_graph_mode)
                    {
                        handle         = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
                        initializer_op = gen_state_ops.assign(handle, _initial_value, true).op;

                        // NOTE(review): the result of colocate_with is discarded rather
                        // than entered as a scope - confirm whether that is intended.
                        ops.colocate_with(initializer_op);

                        // Named argument, not an assignment: `name = "read"` would
                        // overwrite the local `name` that base.__init__ receives below.
                        _graph_element = gen_array_ops.identity(handle, name: "read");
                        ops.add_to_collections <IVariableV1>(collections, this);
                        _dtype = handle.dtype;
                    }
                    else
                    {
                        handle = resource_variable_ops.eager_safe_variable_handle(
                            initial_value: _initial_value,
                            shape: _shape,
                            shared_name: shared_name,
                            name: name,
                            graph_mode: _in_graph_mode);

                        gen_resource_variable_ops.assign_variable_op(handle, _initial_value);
                        initializer_op = null;
                        _graph_element = null;
                        _dtype         = _initial_value.dtype.as_base_dtype();
                        initial_value  = _in_graph_mode ? initial_value : null;
                    }

                    base.__init__(trainable: trainable,
                                  handle: handle,
                                  name: name,
                                  unique_id: unique_id,
                                  handle_name: handle_name);
                });
            });
        }
Ejemplo n.º 26
0
        /// <summary>
        /// Creates a weight variable for this layer, applies optional regularization
        /// and records it in the trainable or non-trainable weight list.
        /// </summary>
        /// <param name="name">Variable name.</param>
        /// <param name="shape">Variable shape.</param>
        /// <param name="dtype">Variable data type; also selects the default initializer.</param>
        /// <param name="initializer">Initializer; when null, glorot uniform is used for floating dtypes and zeros for integer dtypes.</param>
        /// <param name="regularizer">When non-null, passed to <c>_handle_weight_regularization</c>.</param>
        /// <param name="synchronization">When <c>OnRead</c>, the variable is forced non-trainable.</param>
        /// <param name="aggregation">Stored in the variable arguments.</param>
        /// <param name="trainable">Decides which weight list receives the variable.</param>
        /// <param name="getter">Custom creator; defaults to <c>base_layer_utils.make_variable</c>.</param>
        /// <returns>The created variable.</returns>
        /// <exception cref="ValueError">When no initializer is given and the dtype is neither floating nor integer.</exception>
        protected virtual IVariableV1 add_weight(string name,
                                                 Shape shape,
                                                 TF_DataType dtype        = TF_DataType.TF_FLOAT,
                                                 IInitializer initializer = null,
                                                 IRegularizer regularizer = null,
                                                 VariableSynchronization synchronization = VariableSynchronization.Auto,
                                                 VariableAggregation aggregation         = VariableAggregation.None,
                                                 bool trainable = true,
                                                 Func <VariableArgs, IVariableV1> getter = null)
        {
            // Choose a default initializer from the dtype when none was supplied.
            if (initializer == null)
            {
                if (dtype.is_floating())
                {
                    initializer = tf.glorot_uniform_initializer;
                }
                else
                {
                    if (!dtype.is_integer())
                    {
                        // NOTE(review): the message interpolates `name` for both the
                        // variable and the layer - confirm that is intended.
                        throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}");
                    }

                    initializer = tf.zeros_initializer;
                }
            }

            // Variables synchronised on read are never trainable.
            trainable = trainable && synchronization != VariableSynchronization.OnRead;

            var variable_args = new VariableArgs
            {
                Name            = name,
                Shape           = shape,
                DType           = dtype,
                Getter          = getter ?? base_layer_utils.make_variable,
                Overwrite       = true,
                Initializer     = initializer,
                Synchronization = synchronization,
                Aggregation     = aggregation,
                Trainable       = trainable
            };
            var weight = _add_variable_with_custom_getter(variable_args);

            if (regularizer != null)
            {
                // Use the part of the variable name before the first ':'.
                var scoped_name = weight.Name.Split(':')[0];
                _handle_weight_regularization(scoped_name, weight, regularizer);
            }

            //backend.track_variable(variable);
            if (!trainable)
            {
                non_trainable_weights.Add(weight);
            }
            else
            {
                trainable_weights.Add(weight);
            }

            return weight;
        }