示例#1
0
        /// <summary>
        /// Creates a variable through <c>args.Getter</c> and, where applicable, registers it
        /// with this object's checkpoint tracking (restore-on-create for a variable to be
        /// saved with this `Checkpointable`).
        /// </summary>
        /// <param name="args">
        /// Description of the variable to create. This method reads <c>Getter</c>,
        /// <c>Overwrite</c> and <c>Name</c>; the whole object is passed to the getter.
        /// </param>
        /// <returns>
        /// The result of <c>_track_checkpointable</c> when <c>args.Overwrite</c> is false or the
        /// created variable is a <c>RefVariable</c>; otherwise the variable exactly as the
        /// getter returned it.
        /// </returns>
        protected virtual IVariableV1 _add_variable_with_custom_getter(VariableArgs args)
        {
            ops.init_scope();

            // NOTE(review): the result of this query was never used by the original code
            // (both branches were no-ops). The call is kept only in case it has side
            // effects on the context -- confirm, then remove it.
            tf.Context.executing_eagerly();

            IVariableV1 created = args.Getter(args);

            // If we set an initializer and the variable processed it, tracking will not
            // assign again. It will add this variable to our dependencies, and if there
            // is a non-trivial restoration queued, it will handle that. This also
            // handles slot variables.
            bool needsTracking = !args.Overwrite || created is RefVariable;
            if (!needsTracking)
            {
                return created;
            }

            return _track_checkpointable(created, name: args.Name, overwrite: args.Overwrite);
        }
        /// <summary>
        /// Builds a variable from <paramref name="args"/> by calling <c>tf.Variable</c>.
        /// The variable is only created and returned here; this method does not attach
        /// it to any layer.
        /// </summary>
        /// <param name="args">
        /// Variable description. Reads <c>Initializer</c>, <c>Shape</c>, <c>DType</c>,
        /// <c>Name</c>, <c>Trainable</c>, <c>ValidateShape</c> and <c>UseResource</c>.
        /// </param>
        /// <returns>The variable returned by <c>tf.Variable</c>.</returns>
        public static IVariableV1 make_variable(VariableArgs args)
        {
            ops.init_scope();

            // The initial value is supplied as a delegate, so the initializer is applied
            // only when the delegate is invoked. It receives the dtype as given in args,
            // while the variable itself is declared with the base dtype below.
            Func<Tensor> init_val = () => args.Initializer.Apply(new InitializerArgs
            {
                Shape = args.Shape,
                DType = args.DType
            });

            var variable_dtype = args.DType.as_base_dtype();

            return tf.Variable(init_val,
                               dtype: variable_dtype,
                               shape: args.Shape,
                               name: args.Name,
                               trainable: args.Trainable,
                               validate_shape: args.ValidateShape,
                               use_resource: args.UseResource);
        }
示例#3
0
        /// <summary>
        /// Creates a weight variable for this layer and records it in either the trainable
        /// or the non-trainable weight collection.
        /// </summary>
        /// <param name="name">Name given to the variable.</param>
        /// <param name="shape">Shape of the variable.</param>
        /// <param name="dtype">Element type; <c>DtInvalid</c> is replaced by <c>TF_FLOAT</c>.</param>
        /// <param name="initializer">
        /// Initializer to use. When null, <c>tf.glorot_uniform_initializer</c> is used for
        /// floating dtypes and <c>tf.zeros_initializer</c> for integer dtypes.
        /// </param>
        /// <param name="trainable">Whether the weight is trainable; null is treated as true.</param>
        /// <param name="getter">Variable factory; defaults to <c>base_layer_utils.make_variable</c>.</param>
        /// <returns>The variable produced by <c>_add_variable_with_custom_getter</c>.</returns>
        /// <exception cref="ValueError">
        /// No initializer was given and the dtype is neither floating nor integer.
        /// </exception>
        protected virtual IVariableV1 add_weight(string name,
                                                 TensorShape shape,
                                                 TF_DataType dtype        = TF_DataType.DtInvalid,
                                                 IInitializer initializer = null,
                                                 bool?trainable           = null,
                                                 Func <VariableArgs, IVariableV1> getter = null)
        {
            // Resolve the optional arguments to their effective values up front.
            dtype = dtype == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : dtype;
            bool isTrainable = trainable ?? true;

            // No initializer supplied: choose one from the dtype family, or fail.
            if (initializer == null)
            {
                if (dtype.is_floating())
                {
                    initializer = tf.glorot_uniform_initializer;
                }
                else
                {
                    if (!dtype.is_integer())
                    {
                        throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.Name}");
                    }

                    initializer = tf.zeros_initializer;
                }
            }

            var weight = _add_variable_with_custom_getter(new VariableArgs
            {
                Name        = name,
                Shape       = shape,
                DType       = dtype,
                Getter      = getter ?? base_layer_utils.make_variable,
                Overwrite   = true,
                Initializer = initializer,
                Trainable   = isTrainable
            });

            //backend.track_variable(variable);
            if (!isTrainable)
            {
                nonTrainableWeights.Add(weight);
            }
            else
            {
                trainableWeights.Add(weight);
            }

            return weight;
        }
        /// <summary>
        /// Creates a weight variable for this layer, optionally registers a regularizer
        /// for it, and records it in either the trainable or the non-trainable weight
        /// collection.
        /// </summary>
        /// <param name="name">Name given to the variable.</param>
        /// <param name="shape">Shape of the variable.</param>
        /// <param name="dtype">Element type of the variable.</param>
        /// <param name="initializer">
        /// Initializer to use. When null, <c>tf.glorot_uniform_initializer</c> is used for
        /// floating dtypes and <c>tf.zeros_initializer</c> for integer dtypes.
        /// </param>
        /// <param name="regularizer">
        /// When non-null, passed to <c>_handle_weight_regularization</c> together with the
        /// variable and its name up to the first ':'.
        /// </param>
        /// <param name="synchronization">
        /// Forwarded to the variable arguments. <c>OnRead</c> forces the weight to be
        /// treated as non-trainable.
        /// </param>
        /// <param name="aggregation">Forwarded to the variable arguments.</param>
        /// <param name="trainable">Whether the weight is trainable (overridden by <c>OnRead</c>).</param>
        /// <param name="getter">Variable factory; defaults to <c>base_layer_utils.make_variable</c>.</param>
        /// <returns>The variable produced by <c>_add_variable_with_custom_getter</c>.</returns>
        /// <exception cref="ValueError">
        /// No initializer was given and the dtype is neither floating nor integer.
        /// </exception>
        protected virtual IVariableV1 add_weight(string name,
                                                 Shape shape,
                                                 TF_DataType dtype        = TF_DataType.TF_FLOAT,
                                                 IInitializer initializer = null,
                                                 IRegularizer regularizer = null,
                                                 VariableSynchronization synchronization = VariableSynchronization.Auto,
                                                 VariableAggregation aggregation         = VariableAggregation.None,
                                                 bool trainable = true,
                                                 Func <VariableArgs, IVariableV1> getter = null)
        {
            // No initializer supplied: choose one from the dtype family, or fail.
            if (initializer == null)
            {
                if (dtype.is_floating())
                {
                    initializer = tf.glorot_uniform_initializer;
                }
                else if (dtype.is_integer())
                {
                    initializer = tf.zeros_initializer;
                }
                else
                {
                    // The final placeholder names the layer, not the variable; the
                    // original interpolated the variable name there by mistake.
                    throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.Name}");
                }
            }

            // Variables synchronized on read are never treated as trainable.
            if (synchronization == VariableSynchronization.OnRead)
            {
                trainable = false;
            }

            var args = new VariableArgs
            {
                Name            = name,
                Shape           = shape,
                DType           = dtype,
                Getter          = getter ?? base_layer_utils.make_variable,
                Overwrite       = true,
                Initializer     = initializer,
                Synchronization = synchronization,
                Aggregation     = aggregation,
                Trainable       = trainable
            };
            var variable = _add_variable_with_custom_getter(args);

            if (regularizer != null)
            {
                // Keep only the part of the variable name before the first ':'.
                var name_in_scope = variable.Name.Split(':')[0];
                _handle_weight_regularization(name_in_scope, variable, regularizer);
            }

            //backend.track_variable(variable);
            if (trainable)
            {
                trainable_weights.Add(variable);
            }
            else
            {
                non_trainable_weights.Add(variable);
            }

            return variable;
        }