/// <summary>
/// Restore-on-create for a variable to be saved with this `Checkpointable`.
/// Creates the variable via <c>args.Getter</c>, then registers it as a
/// checkpoint dependency unless the caller asked to overwrite an existing one.
/// </summary>
/// <param name="args">Construction arguments; <c>args.Getter</c> performs the actual variable creation.</param>
/// <returns>The newly created (and possibly tracked) variable.</returns>
protected virtual IVariableV1 _add_variable_with_custom_getter(VariableArgs args)
{
    ops.init_scope();

    // NOTE(review): the original contained a dead eager/graph branch that only
    // assigned an unused `checkpoint_initializer = null` local, with #pragma
    // directives suppressing the resulting CS0219/CS0642 warnings. Deferred
    // restore initializers are evidently not implemented yet, so the dead code
    // was removed rather than silenced.
    IVariableV1 new_variable = args.Getter(args);

    // If we set an initializer and the variable processed it, tracking will not
    // assign again. It will add this variable to our dependencies, and if there
    // is a non-trivial restoration queued, it will handle that. This also
    // handles slot variables.
    if (!args.Overwrite || new_variable is RefVariable)
    {
        return _track_checkpointable(new_variable, name: args.Name, overwrite: args.Overwrite);
    }
    else
    {
        return new_variable;
    }
}
/// <summary>
/// Adds a new variable to the layer.
/// </summary>
/// <param name="args">Name, shape, dtype, initializer and trainability for the new variable.</param>
/// <returns>The created variable.</returns>
public static IVariableV1 make_variable(VariableArgs args)
{
    // NOTE(review): removed the unused `initializing_from_value` local that was
    // only kept alive by a #pragma CS0219 suppression.
    ops.init_scope();

    // Wrap initialization in a thunk so the variable op can (re-)run it lazily.
    Func<Tensor> init_val = () => args.Initializer.Apply(new InitializerArgs
    {
        Shape = args.Shape,
        DType = args.DType
    });

    var variable_dtype = args.DType.as_base_dtype();

    return tf.Variable(init_val,
        dtype: variable_dtype,
        shape: args.Shape,
        name: args.Name,
        trainable: args.Trainable,
        validate_shape: args.ValidateShape,
        use_resource: args.UseResource);
}
/// <summary>
/// Adds a new weight variable to the layer. Fills in defaults (float32 dtype,
/// trainable, dtype-appropriate initializer) when the caller omits them, then
/// creates the variable and records it in the trainable or non-trainable
/// weight collection.
/// </summary>
/// <param name="name">Variable name.</param>
/// <param name="shape">Variable shape.</param>
/// <param name="dtype">Data type; DtInvalid means "default to float32".</param>
/// <param name="initializer">Initializer; inferred from the dtype when null.</param>
/// <param name="trainable">Whether the weight is trained; null means true.</param>
/// <param name="getter">Custom variable factory; defaults to base_layer_utils.make_variable.</param>
/// <returns>The created variable.</returns>
/// <exception cref="ValueError">When no initializer can be inferred for <paramref name="dtype"/>.</exception>
protected virtual IVariableV1 add_weight(string name,
    TensorShape shape,
    TF_DataType dtype = TF_DataType.DtInvalid,
    IInitializer initializer = null,
    bool? trainable = null,
    Func<VariableArgs, IVariableV1> getter = null)
{
    if (dtype == TF_DataType.DtInvalid)
        dtype = TF_DataType.TF_FLOAT;

    if (trainable == null)
        trainable = true;

    // Pick a default initializer when none was supplied.
    if (initializer == null)
    {
        if (dtype.is_floating())
        {
            // Floating point weights get uniform unit scaling.
            initializer = tf.glorot_uniform_initializer;
        }
        else if (dtype.is_integer())
        {
            initializer = tf.zeros_initializer;
        }
        else
        {
            throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.Name}");
        }
    }

    var variableArgs = new VariableArgs
    {
        Name = name,
        Shape = shape,
        DType = dtype,
        Getter = getter ?? base_layer_utils.make_variable,
        Overwrite = true,
        Initializer = initializer,
        Trainable = trainable.Value
    };

    var weight = _add_variable_with_custom_getter(variableArgs);
    //backend.track_variable(weight);

    // Register the weight in the matching collection.
    var collection = trainable.Value ? trainableWeights : nonTrainableWeights;
    collection.Add(weight);

    return weight;
}
/// <summary>
/// Adds a new weight variable to the layer, applying defaults, optional
/// regularization and distribution synchronization rules, and registers the
/// variable in the trainable or non-trainable weight collection.
/// </summary>
/// <param name="name">Variable name.</param>
/// <param name="shape">Variable shape.</param>
/// <param name="dtype">Data type; defaults to float32.</param>
/// <param name="initializer">Initializer; inferred from <paramref name="dtype"/> when null.</param>
/// <param name="regularizer">Optional weight regularizer.</param>
/// <param name="synchronization">Distribution synchronization mode; OnRead forces non-trainable.</param>
/// <param name="aggregation">Distribution aggregation mode.</param>
/// <param name="trainable">Whether the variable participates in training.</param>
/// <param name="getter">Custom variable factory; defaults to base_layer_utils.make_variable.</param>
/// <returns>The created variable.</returns>
/// <exception cref="ValueError">When no initializer can be inferred for <paramref name="dtype"/>.</exception>
protected virtual IVariableV1 add_weight(string name,
    Shape shape,
    TF_DataType dtype = TF_DataType.TF_FLOAT,
    IInitializer initializer = null,
    IRegularizer regularizer = null,
    VariableSynchronization synchronization = VariableSynchronization.Auto,
    VariableAggregation aggregation = VariableAggregation.None,
    bool trainable = true,
    Func<VariableArgs, IVariableV1> getter = null)
{
    // Initialize variable when no initializer provided
    if (initializer == null)
    {
        // If dtype is DT_FLOAT, provide a uniform unit scaling initializer
        if (dtype.is_floating())
        {
            initializer = tf.glorot_uniform_initializer;
        }
        else if (dtype.is_integer())
        {
            initializer = tf.zeros_initializer;
        }
        else
        {
            // BUGFIX: the message previously interpolated the *variable* name
            // where the layer name belongs; the sibling overload correctly
            // uses the layer's Name property.
            throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.Name}");
        }
    }

    // Variables synchronized on read (e.g. metrics under a distribution
    // strategy) must not receive gradient updates.
    if (synchronization == VariableSynchronization.OnRead)
        trainable = false;

    var args = new VariableArgs
    {
        Name = name,
        Shape = shape,
        DType = dtype,
        Getter = getter ?? base_layer_utils.make_variable,
        Overwrite = true,
        Initializer = initializer,
        Synchronization = synchronization,
        Aggregation = aggregation,
        Trainable = trainable
    };
    var variable = _add_variable_with_custom_getter(args);

    if (regularizer != null)
    {
        // Regularization losses are recorded under the variable's name
        // without the ":0" tensor-output suffix.
        var name_in_scope = variable.Name.Split(':')[0];
        _handle_weight_regularization(name_in_scope, variable, regularizer);
    }

    //backend.track_variable(variable);
    if (trainable)
    {
        trainable_weights.Add(variable);
    }
    else
    {
        non_trainable_weights.Add(variable);
    }

    return variable;
}