Esempio n. 1
0
        /// <summary>
        /// Creates a variance-scaling initializer after validating its arguments.
        /// </summary>
        /// <param name="factor">Scale factor, stored in <c>_scale</c>; must be strictly greater than zero.</param>
        /// <param name="mode">One of <c>"FAN_IN"</c>, <c>"FAN_OUT"</c>, <c>"FAN_AVG"</c> (exact, case-sensitive match).</param>
        /// <param name="uniform">
        /// NOTE(review): accepted but not stored by this constructor — confirm whether a backing field
        /// should be assigned here.
        /// </param>
        /// <param name="seed">Optional seed, stored in <c>_seed</c>.</param>
        /// <param name="dtype">Element type, stored in <c>_dtype</c>; must be a floating-point type.</param>
        /// <exception cref="TypeError">
        /// <paramref name="dtype"/> is not floating point, or <paramref name="mode"/> is not a recognised value.
        /// </exception>
        /// <exception cref="ValueError"><paramref name="factor"/> is zero, negative or NaN.</exception>
        public VarianceScaling(float factor      = 2.0f,
                               string mode       = "FAN_IN",
                               bool uniform      = false,
                               int?seed          = null,
                               TF_DataType dtype = TF_DataType.TF_FLOAT)
        {
            if (!dtype.is_floating())
            {
                throw new TypeError("Cannot create initializer for non-floating point type.");
            }
            if (!new string[] { "FAN_IN", "FAN_OUT", "FAN_AVG" }.Contains(mode))
            {
                // Report the offending value and the accepted set (no printf-style placeholder).
                throw new TypeError($"Unknown mode {mode}, expected one of [FAN_IN, FAN_OUT, FAN_AVG]");
            }

            // The message promises a positive value, so zero must be rejected too. Writing the test
            // as !(factor > 0) also rejects NaN, for which every ordered comparison is false and
            // which would otherwise slip through a `factor < 0` check.
            if (!(factor > 0f))
            {
                throw new ValueError("`scale` must be positive float.");
            }

            _scale = factor;
            _mode  = mode;
            _seed  = seed;
            _dtype = dtype;
        }
Esempio n. 2
0
        /// <summary>
        /// Converts <paramref name="image"/> to <paramref name="dtype"/>.
        /// </summary>
        /// <remarks>
        /// Implemented paths: same dtype (identity op), and integer source to a non-integer target
        /// (cast, then multiply by <c>1 / image.dtype.max()</c>). Every other combination throws
        /// <see cref="NotImplementedException"/>.
        /// </remarks>
        /// <param name="image">Image tensor to convert.</param>
        /// <param name="dtype">Target element type.</param>
        /// <param name="saturate">Currently unused by this implementation.</param>
        /// <param name="name">Optional op name; the work runs inside a "convert_image" name scope.</param>
        /// <returns>The converted tensor.</returns>
        /// <exception cref="NotImplementedException">The requested conversion is not supported yet.</exception>
        public static Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name = null)
        {
            // Nothing to convert: still return a (named) op rather than the input itself.
            if (dtype == image.dtype)
            {
                return array_ops.identity(image, name: name);
            }

            return tf_with(ops.name_scope(name, "convert_image", image), scope =>
            {
                name = scope;

                if (image.dtype.is_integer() && dtype.is_integer())
                {
                    throw new NotImplementedException("convert_image_dtype is_integer");
                }

                if (image.dtype.is_floating() && dtype.is_floating())
                {
                    throw new NotImplementedException("convert_image_dtype is_floating");
                }

                if (!image.dtype.is_integer())
                {
                    // Reached e.g. for floating -> integer; name both dtypes instead of the
                    // misleading "is_integer" label.
                    throw new NotImplementedException($"convert_image_dtype from {image.dtype} to {dtype} is not supported");
                }

                // Integer source, non-integer target (expected to be floating point):
                // first cast, then scale. No saturation possible.
                var cast  = math_ops.cast(image, dtype);
                var scale = 1.0f / image.dtype.max();
                return math_ops.multiply(cast, scale, name: name);
            });
        }
Esempio n. 3
0
        /// <summary>
        /// Creates a layer weight and records it in <c>trainableWeights</c> or <c>nonTrainableWeights</c>.
        /// </summary>
        /// <param name="name">Variable name.</param>
        /// <param name="shape">Variable shape.</param>
        /// <param name="dtype">Element type; <c>DtInvalid</c> (the default) is treated as <c>TF_FLOAT</c>.</param>
        /// <param name="initializer">
        /// Initializer; when null, glorot-uniform is used for floating dtypes and zeros for integer dtypes.
        /// </param>
        /// <param name="trainable">Whether the weight is trainable; null means true.</param>
        /// <param name="getter">Variable factory; defaults to <c>base_layer_utils.make_variable</c>.</param>
        /// <returns>The created variable.</returns>
        /// <exception cref="ValueError">No initializer was given and the dtype is neither floating nor integer.</exception>
        protected virtual IVariableV1 add_weight(string name,
                                                 TensorShape shape,
                                                 TF_DataType dtype        = TF_DataType.DtInvalid,
                                                 IInitializer initializer = null,
                                                 bool?trainable           = null,
                                                 Func <VariableArgs, IVariableV1> getter = null)
        {
            // Resolve the "unspecified" sentinels up front.
            var weightDtype = dtype == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : dtype;
            var isTrainable = trainable ?? true;

            // Pick a default initializer from the dtype when none was supplied.
            IInitializer weightInitializer = initializer;
            if (weightInitializer == null)
            {
                if (weightDtype.is_floating())
                {
                    weightInitializer = tf.glorot_uniform_initializer;
                }
                else if (weightDtype.is_integer())
                {
                    weightInitializer = tf.zeros_initializer;
                }
                else
                {
                    throw new ValueError($"An initializer for variable {name} of type {weightDtype.as_base_dtype()} is required for layer {this.Name}");
                }
            }

            var variable = _add_variable_with_custom_getter(new VariableArgs
            {
                Name        = name,
                Shape       = shape,
                DType       = weightDtype,
                Getter      = getter ?? base_layer_utils.make_variable,
                Overwrite   = true,
                Initializer = weightInitializer,
                Trainable   = isTrainable
            });

            // Tracking via backend.track_variable is intentionally disabled in this overload.
            if (!isTrainable)
            {
                nonTrainableWeights.Add(variable);
            }
            else
            {
                trainableWeights.Add(variable);
            }

            return variable;
        }
Esempio n. 4
0
        /// <summary>
        /// Creates a layer weight, registers it with <c>backend.track_variable</c> and appends it to
        /// <c>_trainable_weights</c>.
        /// </summary>
        /// <param name="name">Variable name.</param>
        /// <param name="shape">Variable shape.</param>
        /// <param name="dtype">Element type; <c>DtInvalid</c> (the default) is treated as <c>TF_FLOAT</c>.</param>
        /// <param name="initializer">
        /// Initializer; when null, glorot-uniform is used for floating dtypes and zeros for integer dtypes.
        /// </param>
        /// <param name="trainable">Whether the weight is trainable; null means true.</param>
        /// <param name="getter">Variable factory; defaults to <c>base_layer_utils.make_variable</c>.</param>
        /// <returns>The created variable.</returns>
        /// <exception cref="ValueError">No initializer was given and the dtype is neither floating nor integer.</exception>
        protected virtual RefVariable add_weight(string name,
                                                 int[] shape,
                                                 TF_DataType dtype        = TF_DataType.DtInvalid,
                                                 IInitializer initializer = null,
                                                 bool?trainable           = null,
                                                 Func <string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null)
        {
            // Resolve the "unspecified" sentinels up front.
            var weightDtype = dtype == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : dtype;
            var isTrainable = trainable ?? true;

            // Pick a default initializer from the dtype when none was supplied.
            IInitializer weightInitializer = initializer;
            if (weightInitializer == null)
            {
                if (weightDtype.is_floating())
                {
                    weightInitializer = tf.glorot_uniform_initializer;
                }
                else if (weightDtype.is_integer())
                {
                    weightInitializer = tf.zeros_initializer;
                }
                else
                {
                    throw new ValueError($"An initializer for variable {name} of type {weightDtype.as_base_dtype()} is required for layer {this.name}");
                }
            }

            var variable = _add_variable_with_custom_getter(name,
                                                            shape,
                                                            dtype: weightDtype,
                                                            getter: getter ?? base_layer_utils.make_variable,
                                                            overwrite: true,
                                                            initializer: weightInitializer,
                                                            trainable: isTrainable);

            backend.track_variable(variable);
            // NOTE(review): appended to _trainable_weights even when isTrainable is false;
            // confirm whether a non-trainable collection should be used in that case.
            _trainable_weights.Add(variable);

            return variable;
        }
        /// <summary>
        /// Creates a layer weight, optionally attaches a regularizer, and records the variable in
        /// <c>trainable_weights</c> or <c>non_trainable_weights</c>.
        /// </summary>
        /// <param name="name">Variable name.</param>
        /// <param name="shape">Variable shape.</param>
        /// <param name="dtype">Element type (default <c>TF_FLOAT</c>).</param>
        /// <param name="initializer">
        /// Initializer; when null, glorot-uniform is used for floating dtypes and zeros for integer dtypes.
        /// </param>
        /// <param name="regularizer">
        /// Optional regularizer, registered under the part of the created variable's name before the first ':'.
        /// </param>
        /// <param name="synchronization">Forwarded to the variable; <c>OnRead</c> forces <paramref name="trainable"/> to false.</param>
        /// <param name="aggregation">Forwarded to the variable.</param>
        /// <param name="trainable">Whether the weight is trainable (overridden to false for <c>OnRead</c>).</param>
        /// <param name="getter">Variable factory; defaults to <c>base_layer_utils.make_variable</c>.</param>
        /// <returns>The created variable.</returns>
        /// <exception cref="ValueError">No initializer was given and the dtype is neither floating nor integer.</exception>
        protected virtual IVariableV1 add_weight(string name,
                                                 Shape shape,
                                                 TF_DataType dtype        = TF_DataType.TF_FLOAT,
                                                 IInitializer initializer = null,
                                                 IRegularizer regularizer = null,
                                                 VariableSynchronization synchronization = VariableSynchronization.Auto,
                                                 VariableAggregation aggregation         = VariableAggregation.None,
                                                 bool trainable = true,
                                                 Func <VariableArgs, IVariableV1> getter = null)
        {
            // Pick a default initializer from the dtype when none was supplied.
            if (initializer == null)
            {
                if (dtype.is_floating())
                {
                    initializer = tf.glorot_uniform_initializer;
                }
                else if (dtype.is_integer())
                {
                    initializer = tf.zeros_initializer;
                }
                else
                {
                    // `name` is the variable name here; the layer's own name comes from this.Name.
                    throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.Name}");
                }
            }

            // Variables synchronized on read are never trained; silently override the flag.
            if (synchronization == VariableSynchronization.OnRead)
            {
                trainable = false;
            }

            var args = new VariableArgs
            {
                Name            = name,
                Shape           = shape,
                DType           = dtype,
                Getter          = getter ?? base_layer_utils.make_variable,
                Overwrite       = true,
                Initializer     = initializer,
                Synchronization = synchronization,
                Aggregation     = aggregation,
                Trainable       = trainable
            };
            var variable = _add_variable_with_custom_getter(args);

            if (regularizer != null)
            {
                // Strip the ":<index>" output suffix from the variable name.
                var name_in_scope = variable.Name.Split(':')[0];
                _handle_weight_regularization(name_in_scope, variable, regularizer);
            }

            //backend.track_variable(variable);
            if (trainable)
            {
                trainable_weights.Add(variable);
            }
            else
            {
                non_trainable_weights.Add(variable);
            }

            return variable;
        }
Esempio n. 6
0
        /// <summary>
        /// Creates a new variable named <paramref name="name"/> and records it in <c>_vars</c>.
        /// </summary>
        /// <remarks>
        /// Retrieving an already-registered variable is not implemented: if <paramref name="name"/> is
        /// already in <c>_vars</c> this throws <see cref="NotImplementedException"/> regardless of
        /// <paramref name="reuse"/>. The variable is always created from an initializer (initializing
        /// from a value is not supported).
        /// </remarks>
        /// <param name="name">Variable name, also the key in <c>_vars</c>.</param>
        /// <param name="shape">Shape passed to the initializer.</param>
        /// <param name="dtype">Element type; its base dtype is passed to the variable creator.</param>
        /// <param name="initializer">
        /// Initializer; when null, glorot-uniform is used for floating dtypes and zeros for integer dtypes.
        /// </param>
        /// <param name="reuse">Currently unused (see remarks).</param>
        /// <param name="trainable">Forwarded to the variable creator.</param>
        /// <param name="collections">Forwarded to the variable creator.</param>
        /// <param name="validate_shape">Forwarded to the variable creator.</param>
        /// <param name="use_resource">Currently unused.</param>
        /// <param name="synchronization">Forwarded to the variable creator.</param>
        /// <param name="aggregation">Forwarded to the variable creator.</param>
        /// <returns>The newly created variable.</returns>
        /// <exception cref="NotImplementedException">A variable with this name is already registered.</exception>
        /// <exception cref="ValueError">
        /// No initializer was given and <paramref name="dtype"/> is neither floating nor integer.
        /// </exception>
        private IVariableV1 _get_single_variable(string name,
                                                 TensorShape shape        = null,
                                                 TF_DataType dtype        = TF_DataType.DtInvalid,
                                                 IInitializer initializer = null,
                                                 bool reuse                = false,
                                                 bool?trainable            = null,
                                                 List <string> collections = null,
                                                 bool validate_shape       = false,
                                                 bool?use_resource         = null,
                                                 VariableSynchronization synchronization = VariableSynchronization.Auto,
                                                 VariableAggregation aggregation         = VariableAggregation.None)
        {
            if (_vars.ContainsKey(name))
            {
                // Neither the "already exists" error (reuse == false) nor returning the existing
                // variable (reuse == true) is ported yet.
                throw new NotImplementedException("_get_single_variable");
            }

            // Pick a default initializer from the dtype. Without the integer/else branches a
            // non-floating dtype left `initializer` null, and it was dereferenced inside init_val.
            if (initializer == null)
            {
                if (dtype.is_floating())
                {
                    initializer = tf.glorot_uniform_initializer;
                }
                else if (dtype.is_integer())
                {
                    initializer = tf.zeros_initializer;
                }
                else
                {
                    throw new ValueError($"An initializer for variable {name} of {dtype.as_base_dtype()} is required");
                }
            }

            // NOTE(review): the scope object returned by init_scope() is discarded, so it is not
            // held open around the variable creation below — confirm this is intended.
            ops.init_scope();

            Func <Tensor> init_val       = () => initializer.call(shape, dtype);
            var           variable_dtype = dtype.as_base_dtype();

            IVariableV1 v = variable_scope.default_variable_creator(init_val,
                                                                    name: name,
                                                                    trainable: trainable,
                                                                    collections: collections,
                                                                    dtype: variable_dtype,
                                                                    validate_shape: validate_shape,
                                                                    synchronization: synchronization,
                                                                    aggregation: aggregation);

            _vars[name] = v;

            return v;
        }