예제 #1
0
        /// <summary>
        /// Gets a variable from <paramref name="var_store"/>, qualifying <paramref name="name"/>
        /// with this scope's name ("scope/name") when the scope has a non-empty name.
        /// </summary>
        /// <param name="var_store">Store that performs the actual lookup/creation.</param>
        /// <param name="name">Variable name relative to this scope.</param>
        /// <param name="dtype">Element type; <c>DtInvalid</c> means "use this scope's <c>_dtype</c>".</param>
        /// <param name="initializer">An IInitializer or a Tensor (forwarded unchanged).</param>
        /// <returns>The store's result cast with <c>as RefVariable</c> (null if it is not a RefVariable).</returns>
        /// <remarks>
        /// NOTE(review): <paramref name="use_resource"/> and <paramref name="validate_shape"/> are accepted
        /// but not forwarded to the store — confirm whether that is intentional.
        /// </remarks>
        public RefVariable get_variable(_VariableStore var_store,
                                        string name,
                                        TensorShape shape         = null,
                                        TF_DataType dtype         = TF_DataType.DtInvalid,
                                        object initializer        = null, // IInitializer or Tensor
                                        bool?trainable            = null,
                                        List <string> collections = null,
                                        bool?use_resource         = null,
                                        bool validate_shape       = true,
                                        VariableSynchronization synchronization = VariableSynchronization.Auto,
                                        VariableAggregation aggregation         = VariableAggregation.None)
        {
            // Prefix with the scope name only when this scope actually has one.
            var qualified_name = string.IsNullOrEmpty(this.name)
                ? name
                : this.name + "/" + name;

            return tf_with(ops.name_scope(null), scope =>
            {
                // DtInvalid is the "unspecified" sentinel: fall back to the scope's dtype.
                var resolved_dtype = dtype != TF_DataType.DtInvalid ? dtype : _dtype;

                // NOTE(review): `resue` is presumably a member of this class spelled that way; verify.
                return var_store.get_variable(qualified_name,
                                              shape: shape,
                                              dtype: resolved_dtype,
                                              initializer: initializer,
                                              reuse: resue,
                                              trainable: trainable,
                                              collections: collections,
                                              synchronization: synchronization,
                                              aggregation: aggregation) as RefVariable;
            });
        }
        /// <summary>
        /// Gets a variable from <paramref name="var_store"/>, qualifying <paramref name="name"/>
        /// with this scope's <c>_name</c> ("scope/name") when that name is non-empty.
        /// </summary>
        /// <param name="var_store">Store that performs the actual lookup/creation.</param>
        /// <param name="name">Variable name relative to this scope.</param>
        /// <param name="dtype">Element type; <c>DtInvalid</c> means "use this scope's <c>_dtype</c>".</param>
        /// <returns>Whatever <c>var_store.get_variable</c> returns, unchanged.</returns>
        public RefVariable get_variable(_VariableStore var_store,
                                        string name,
                                        TensorShape shape        = null,
                                        TF_DataType dtype        = TF_DataType.DtInvalid,
                                        IInitializer initializer = null,
                                        bool?trainable           = null,
                                        VariableSynchronization synchronization = VariableSynchronization.AUTO,
                                        VariableAggregation aggregation         = VariableAggregation.NONE)
        {
            // Prefix with the scope name only when this scope actually has one.
            var qualified_name = string.IsNullOrEmpty(this._name)
                ? name
                : this._name + "/" + name;

            return with(new ops.name_scope(null), scope =>
            {
                // DtInvalid is the "unspecified" sentinel: fall back to the scope's dtype.
                var resolved_dtype = dtype != TF_DataType.DtInvalid ? dtype : _dtype;

                return var_store.get_variable(qualified_name,
                                              shape: shape,
                                              dtype: resolved_dtype,
                                              initializer: initializer,
                                              trainable: trainable,
                                              synchronization: synchronization,
                                              aggregation: aggregation);
            });
        }
예제 #3
0
 /// <summary>
 /// Creates a pure variable scope identified by <paramref name="name"/>, recording the
 /// previous name scope and binding to the default variable store and scope store.
 /// </summary>
 /// <param name="name">Name of the scope.</param>
 /// <param name="old_name_scope">Name scope to remember; may be null.</param>
 /// <param name="dtype">NOTE(review): accepted but currently unused by this constructor.</param>
 public PureVariableScope(string name,
                          string old_name_scope = null,
                          TF_DataType dtype     = TF_DataType.DtInvalid)
 {
     this._name = name;
     this._old_name_scope = old_name_scope;

     // Resolve the shared stores (same order as before: variable store, then scope store).
     this._var_store = variable_scope._get_default_variable_store();
     this._var_scope_store = variable_scope.get_variable_scope_store();
 }
예제 #4
0
        /// <summary>
        /// Returns the first <see cref="_VariableStore"/> in the <c>_VARSTORE_KEY</c> collection,
        /// creating and registering a new store when the collection has none.
        /// </summary>
        /// <returns>The existing store, or a newly created and registered one.</returns>
        public static _VariableStore _get_default_variable_store()
        {
            var existing = ops.get_collection <_VariableStore>(_VARSTORE_KEY).FirstOrDefault();
            if (existing != null)
            {
                return existing;
            }

            // None registered yet: create one and register it so later calls find it.
            var created = new _VariableStore();
            ops.add_to_collection(_VARSTORE_KEY, created);
            return created;
        }
예제 #5
0
        /// <summary>
        /// Returns the <see cref="_VariableStore"/> registered under <c>_VARSTORE_KEY</c>,
        /// creating and registering a new one when none is found.
        /// </summary>
        /// <returns>The first store found in the collection; otherwise a newly created and registered store.</returns>
        public static _VariableStore _get_default_variable_store()
        {
            // get_collection returns an untyped object. The previous code did
            // `(store as List<_VariableStore>)[0]`, which threw NullReferenceException whenever the
            // collection had any other runtime type (`as` yields null) and
            // ArgumentOutOfRangeException when the list was empty. Inspect it defensively instead.
            object raw = ops.get_collection(_VARSTORE_KEY);

            // Tolerate a store registered directly rather than inside a list.
            var direct = raw as _VariableStore;
            if (direct != null)
            {
                return direct;
            }

            var items = raw as System.Collections.IEnumerable;
            if (items != null)
            {
                foreach (object item in items)
                {
                    var existing = item as _VariableStore;
                    if (existing != null)
                    {
                        return existing;
                    }
                }
            }

            // Nothing usable registered yet: create a store and register it so later calls reuse it.
            var created = new _VariableStore();
            ops.add_to_collection(_VARSTORE_KEY, created);
            return created;
        }
예제 #6
0
        /// <summary>
        /// Creates a pure variable scope that re-enters an existing <paramref name="scope"/>,
        /// binding to the default variable store and scope store and building a new
        /// <see cref="VariableScope"/> with the same name and name scope.
        /// </summary>
        /// <param name="scope">Scope to re-enter; must not be null (its <c>name</c> and <c>_name_scope</c> are read).</param>
        /// <param name="old_name_scope">Name scope to remember; may be null.</param>
        /// <param name="dtype">NOTE(review): accepted but currently unused by this constructor.</param>
        public PureVariableScope(VariableScope scope,
                                 string old_name_scope = null,
                                 TF_DataType dtype     = TF_DataType.DtInvalid)
        {
            this._scope = scope;
            this._old_name_scope = old_name_scope;

            // Resolve the shared stores (same order as before: variable store, then scope store).
            this._var_store = variable_scope._get_default_variable_store();
            this._var_scope_store = variable_scope.get_variable_scope_store();

            this._new_name = this._scope.name;
            var inherited_name_scope = this._scope._name_scope;

            // Mirror the entered scope's name and name scope; reuse comes from this object's _reuse.
            this.variable_scope_object = new VariableScope(this._reuse,
                                                           name: this._new_name,
                                                           name_scope: inherited_name_scope);

            this._cached_variable_scope_object = this.variable_scope_object;
        }