示例#1
0
 /// <summary>
 /// Assigns <paramref name="value"/> to the variable <paramref name="ref"/> by delegating
 /// to <c>gen_state_ops.assign</c>.
 /// </summary>
 /// <param name="ref">Variable that receives the value.</param>
 /// <param name="value">Value to assign.</param>
 /// <param name="validate_shape">Forwarded unchanged to <c>gen_state_ops.assign</c>.</param>
 /// <param name="use_locking">Forwarded unchanged to <c>gen_state_ops.assign</c>.</param>
 /// <param name="name">Optional op name, forwarded unchanged.</param>
 /// <returns>The tensor returned by <c>gen_state_ops.assign</c>.</returns>
 public static Tensor assign(ResourceVariable @ref, object value,
                             bool validate_shape = true,
                             bool use_locking    = true,
                             string name         = null)
     => gen_state_ops.assign(@ref, value,
                             validate_shape: validate_shape,
                             use_locking: use_locking,
                             name: name);
示例#2
0
        /// <summary>
        /// Converts the given <paramref name="value"/> to a <see cref="Tensor"/>.
        /// </summary>
        /// <param name="value">Object to convert: EagerTensor, Tensor, NDArray, Tensor[], RefVariable,
        /// ResourceVariable, TensorShape, int[], string, string[], IEnumerable&lt;object&gt;, or any other
        /// value accepted by <c>constant_op.constant</c>.</param>
        /// <param name="dtype">Requested element type. When <c>DtInvalid</c>, <paramref name="preferred_dtype"/> is used.</param>
        /// <param name="name">Optional name for any op created by the conversion.</param>
        /// <param name="as_ref">Forwarded to the variable conversion functions.</param>
        /// <param name="preferred_dtype">Fallback dtype used when <paramref name="dtype"/> is <c>DtInvalid</c>.</param>
        /// <param name="ctx">Not read by this method.</param>
        /// <returns>The converted tensor. Tensor inputs are returned unchanged.</returns>
        /// <exception cref="RuntimeError">An EagerTensor is passed while not executing eagerly and the
        /// default graph is not a FuncGraph that is building a function.</exception>
        public static Tensor convert_to_tensor(object value,
                                               TF_DataType dtype           = TF_DataType.DtInvalid,
                                               string name                 = null,
                                               bool as_ref                 = false,
                                               TF_DataType preferred_dtype = TF_DataType.DtInvalid,
                                               Context ctx                 = null)
        {
            if (dtype == TF_DataType.DtInvalid)
            {
                dtype = preferred_dtype;
            }

            // EagerTensors are fully handled here (returned or captured), so the switch below
            // never sees them.
            if (value is EagerTensor eager_tensor)
            {
                if (tf.executing_eagerly())
                {
                    return eager_tensor;
                }

                var graph = get_default_graph();
                // Check the concrete type too: `graph as FuncGraph` would yield null (and a
                // NullReferenceException) if building_function were set on another graph type.
                if (!graph.building_function || !(graph is FuncGraph func_graph))
                {
                    throw new RuntimeError("Attempting to capture an EagerTensor without building a function.");
                }
                return func_graph.capture(eager_tensor, name: name);
            }

            return value switch
            {
                NDArray nd => constant_op.constant(nd, dtype: dtype, name: name),
                Tensor tensor => tensor,
                Tensor[] tensors => array_ops._autopacking_helper(tensors, dtype, name ?? "packed"),
                RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
                ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
                TensorShape ts => constant_op.constant(ts.dims, dtype: dtype, name: name),
                int[] dims => constant_op.constant(dims, dtype: dtype, name: name),
                // Strings always become string tensors, whatever dtype was requested.
                string str => constant_op.constant(str, dtype: tf.@string, name: name),
                string[] str => constant_op.constant(str, dtype: tf.@string, name: name),
                IEnumerable <object> objects => array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name),
                _ => constant_op.constant(value, dtype: dtype, name: name)
            };
        }
示例#3
0
        /// <summary>
        /// Cast entry point for resource variables.
        /// </summary>
        /// <param name="x">Variable to cast.</param>
        /// <param name="dtype">Target dtype; only its base dtype is compared.</param>
        /// <param name="name">Optional name; a "Cast" name scope is opened from it.</param>
        /// <returns>Always <paramref name="x"/> itself.</returns>
        /// <remarks>
        /// NOTE(review): when the base dtype differs, this converts <paramref name="x"/> to a tensor and may
        /// create a Cast op, but that tensor is never returned — the ResourceVariable return type cannot
        /// carry it. Callers needing the cast value must cast the tensor instead. TODO confirm intent.
        /// </remarks>
        public static ResourceVariable cast(ResourceVariable x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
        {
            var target = dtype.as_base_dtype();

            // Already the requested base dtype: nothing to do.
            if (x.dtype == target)
            {
                return x;
            }

            return tf_with(ops.name_scope(name, "Cast", new { x }), ns =>
            {
                name = ns;
                var converted = ops.convert_to_tensor(x, name: "x");
                if (converted.dtype.as_base_dtype() != target)
                {
                    converted = gen_math_ops.cast(converted, target, name: name);
                }

                // The converted/cast tensor is dropped here (see remarks).
                return x;
            });
        }
示例#4
0
        /// <summary>
        /// Creates an "Assign" op that writes <paramref name="value"/> into <paramref name="ref"/>.
        /// </summary>
        /// <param name="ref">Variable that receives the value.</param>
        /// <param name="value">Value to assign.</param>
        /// <param name="validate_shape">Passed to the op as its <c>validate_shape</c> attribute.</param>
        /// <param name="use_locking">Passed to the op as its <c>use_locking</c> attribute.</param>
        /// <param name="name">Optional op name.</param>
        /// <returns>The first output of the created op.</returns>
        public static Tensor assign(ResourceVariable @ref, object value,
                                    bool validate_shape = true,
                                    bool use_locking    = true,
                                    string name         = null)
        {
            var _op = tf.OpDefLib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking });

            // The inputs/attrs bookkeeping that used to be collected here was never read, so only
            // the output is extracted.
            return _op.outputs[0];
        }
示例#5
0
        /// <summary>
        /// Applies the binary op selected by <paramref name="default_name"/> to the current value of
        /// <paramref name="x"/> and <paramref name="y"/>, inside a name scope defaulting to that op name.
        /// </summary>
        /// <typeparam name="T">Type of the right-hand operand.</typeparam>
        /// <param name="default_name">One of "add", "sub", "mul", "less", "greater"; also the default name-scope name.</param>
        /// <param name="x">Left-hand variable; its <c>value()</c> is read.</param>
        /// <param name="y">Right-hand operand, converted to a tensor of x's base dtype.</param>
        /// <returns>The output tensor of the selected op.</returns>
        /// <exception cref="NotImplementedException"><paramref name="default_name"/> is not a supported operator.</exception>
        private static Tensor op_helper <T>(string default_name, ResourceVariable x, T y)
        => tf_with(ops.name_scope(null, default_name, new { x, y }), scope =>
        {
            string name   = scope;
            var xVal      = x.value();
            var yTensor   = ops.convert_to_tensor(y, xVal.dtype.as_base_dtype(), "y");

            Tensor result = default_name switch
            {
                // String variables use Add; every other dtype uses AddV2.
                "add" => x.dtype == TF_DataType.TF_STRING
                         ? gen_math_ops.add(xVal, yTensor, name)
                         : gen_math_ops.add_v2(xVal, yTensor, name),
                "sub"     => gen_math_ops.sub(xVal, yTensor, name),
                "mul"     => gen_math_ops.mul(xVal, yTensor, name: name),
                "less"    => gen_math_ops.less(xVal, yTensor, name),
                "greater" => gen_math_ops.greater(xVal, yTensor, name),
                // Name the offending operator instead of throwing with an empty message.
                _ => throw new NotImplementedException($"op_helper: unsupported operator '{default_name}'.")
            };

            // NOTE(review): the result is not written back to x; an `x.assign(result)` here was
            // previously commented out.
            return result;
        });
示例#6
0
        /// <summary>
        /// Imports the GraphDef and collections of <paramref name="meta_graph_or_file"/> into the current
        /// default graph.
        /// </summary>
        /// <param name="meta_graph_or_file">Meta graph to import. When <paramref name="clear_devices"/> is true,
        /// the Device field of its GraphDef nodes is cleared in place.</param>
        /// <param name="clear_devices">Clear every node's device before importing.</param>
        /// <param name="import_scope">Name scope to prepend to imported names; passed through
        /// <c>graph.unique_name</c> (null is treated as "").</param>
        /// <param name="input_map">Forwarded to <c>importer.import_graph_def</c>.</param>
        /// <param name="unbound_inputs_col_name">Name of the unbound-inputs collection. That collection is
        /// never copied into the graph; a meta graph containing it is rejected (not implemented).</param>
        /// <param name="return_elements">Forwarded to <c>importer.import_graph_def</c>.</param>
        /// <returns>The imported GLOBAL_VARIABLES keyed by name with the prepended scope stripped, and the
        /// elements returned by <c>importer.import_graph_def</c>.</returns>
        /// <exception cref="NotImplementedException">The meta graph contains the unbound-inputs collection.</exception>
        public static (Dictionary <string, IVariableV1>, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
                                                                                                                             bool clear_devices  = false,
                                                                                                                             string import_scope = "",
                                                                                                                             Dictionary <string, Tensor> input_map = null,
                                                                                                                             string unbound_inputs_col_name        = "unbound_inputs",
                                                                                                                             string[] return_elements = null)
        {
            var meta_graph_def = meta_graph_or_file;

            // Mapping unbound inputs is not supported yet.
            if (!string.IsNullOrEmpty(unbound_inputs_col_name) &&
                meta_graph_def.CollectionDef.ContainsKey(unbound_inputs_col_name))
            {
                throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
            }

            // Sets graph to default graph if it's not passed in.
            var graph = ops.get_default_graph();

            // MetaInfoDef is an optional sub-message and may be null when it was never set.
            OpList producer_op_list = meta_graph_def.MetaInfoDef?.StrippedOpList;

            var input_graph_def = meta_graph_def.GraphDef;

            // Remove all the explicit device specifications for this node. This helps to
            // make the graph more portable.
            if (clear_devices)
            {
                foreach (var node in input_graph_def.Node)
                {
                    node.Device = "";
                }
            }

            // Derive the name prefix from import_scope; it was previously ignored (always "").
            var scope_to_prepend_to_names = graph.unique_name(import_scope ?? "", mark_as_used: false);
            var imported_return_elements  = importer.import_graph_def(input_graph_def,
                                                                      name: scope_to_prepend_to_names,
                                                                      input_map: input_map,
                                                                      producer_op_list: producer_op_list,
                                                                      return_elements: return_elements);

            // Restores all the other collections.
            var variable_objects = new Dictionary <ByteString, IVariableV1>();

            foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key))
            {
                // Don't add unbound_inputs to the new graph.
                if (col.Key == unbound_inputs_col_name)
                {
                    continue;
                }

                switch (col.Value.KindCase)
                {
                case KindOneofCase.NodeList:
                    foreach (var value in col.Value.NodeList.Value)
                    {
                        var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names));
                        graph.add_to_collection(col.Key, col_op);
                    }
                    break;

                case KindOneofCase.BytesList:
                    //var proto_type = ops.get_collection_proto_type(key)
                    if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
                    {
                        foreach (var value in col.Value.BytesList.Value)
                        {
                            // The same serialized variable may appear in several collections;
                            // build it once and reuse the instance.
                            if (!variable_objects.ContainsKey(value))
                            {
                                var proto = VariableDef.Parser.ParseFrom(value);
                                IVariableV1 created;
                                if (proto.IsResource)
                                {
                                    created = new ResourceVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
                                }
                                else
                                {
                                    created = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
                                }
                                variable_objects[value] = created;
                            }
                            graph.add_to_collection(col.Key, variable_objects[value]);
                        }
                    }
                    else
                    {
                        foreach (var value in col.Value.BytesList.Value)
                        {
                            switch (col.Key)
                            {
                            case "cond_context":
                            {
                                // Use the same prefix as every other imported name.
                                var proto       = CondContextDef.Parser.ParseFrom(value);
                                var condContext = new CondContext().from_proto(proto, scope_to_prepend_to_names);
                                graph.add_to_collection(col.Key, condContext);
                            }
                            break;

                            case "while_context":
                            {
                                var proto        = WhileContextDef.Parser.ParseFrom(value);
                                var whileContext = new WhileContext().from_proto(proto, scope_to_prepend_to_names);
                                graph.add_to_collection(col.Key, whileContext);
                            }
                            break;

                            default:
                                Console.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}");
                                continue;
                            }
                        }
                    }

                    break;

                default:
                    Console.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping.");
                    break;
                }
            }

            var variables = graph.get_collection <IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES,
                                                               scope: scope_to_prepend_to_names);
            var var_list = new Dictionary <string, IVariableV1>();

            variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v);

            return(var_list, imported_return_elements);
        }
示例#7
0
        /// <summary>
        /// Creates a new variable named <paramref name="name"/> and records it in <c>_vars</c>.
        /// </summary>
        /// <param name="name">Variable name; also the key in <c>_vars</c>.</param>
        /// <param name="shape">Shape passed to the initializer (initializer path only).</param>
        /// <param name="dtype">Variable dtype. For floating dtypes with no initializer or init value, a Glorot
        /// uniform initializer is used.</param>
        /// <param name="initializer">Initializer used when <paramref name="init_value"/> is null.</param>
        /// <param name="init_value">When non-null, the variable is a ResourceVariable built from this tensor.</param>
        /// <param name="reuse">Currently has no effect: an existing name always throws.</param>
        /// <param name="trainable">Trainable flag; when null and <paramref name="init_value"/> is given, true is used.</param>
        /// <param name="collections">Forwarded to <c>default_variable_creator</c>.</param>
        /// <param name="validate_shape">Forwarded to the variable constructor/creator.</param>
        /// <param name="use_resource">When null, <c>variable_scope._DEFAULT_USE_RESOURCE</c> is used.</param>
        /// <param name="synchronization">Forwarded to <c>default_variable_creator</c>.</param>
        /// <param name="aggregation">Forwarded to <c>default_variable_creator</c>.</param>
        /// <returns>The newly created variable.</returns>
        /// <exception cref="NotImplementedException">A variable with this name already exists.</exception>
        private IVariableV1 _get_single_variable(string name,
                                                 TensorShape shape        = null,
                                                 TF_DataType dtype        = TF_DataType.DtInvalid,
                                                 IInitializer initializer = null,
                                                 Tensor init_value        = null,
                                                 bool reuse                = false,
                                                 bool?trainable            = null,
                                                 List <string> collections = null,
                                                 bool validate_shape       = false,
                                                 bool?use_resource         = null,
                                                 VariableSynchronization synchronization = VariableSynchronization.Auto,
                                                 VariableAggregation aggregation         = VariableAggregation.None)
        {
            bool initializing_from_value = init_value != null;

            if (use_resource == null)
            {
                use_resource = variable_scope._DEFAULT_USE_RESOURCE;
            }

            // Fetching an existing variable (reuse) is not implemented: any name clash throws,
            // whatever the value of `reuse`.
            if (_vars.ContainsKey(name))
            {
                throw new NotImplementedException("_get_single_variable");
            }

            IVariableV1 v = null;

            // Create the tensor to initialize the variable with default value.
            // NOTE(review): non-floating dtypes get no default initializer here, so the initializer
            // path below would dereference a null initializer — confirm callers always supply one.
            if (initializer == null && init_value == null && dtype.is_floating())
            {
                initializer             = tf.glorot_uniform_initializer;
                initializing_from_value = false;
            }

            // Create the variable.
            // NOTE(review): the value returned by ops.init_scope() is discarded, and the braces below form
            // a plain C# block rather than a scope — confirm whether tf_with(ops.init_scope(), ...) was intended.
            ops.init_scope();
            {
                if (initializing_from_value)
                {
                    // `trainable` is nullable and defaults to null; `.Value` would throw
                    // InvalidOperationException, so an unset flag means trainable.
                    v = new ResourceVariable(init_value,
                                             name: name,
                                             validate_shape: validate_shape,
                                             trainable: trainable ?? true);
                }
                else
                {
                    Func <Tensor> init_val       = () => initializer.Apply(new InitializerArgs(shape, dtype: dtype));
                    var           variable_dtype = dtype.as_base_dtype();

                    v = variable_scope.default_variable_creator(init_val,
                                                                name: name,
                                                                trainable: trainable,
                                                                collections: collections,
                                                                dtype: variable_dtype,
                                                                use_resource: use_resource,
                                                                validate_shape: validate_shape,
                                                                synchronization: synchronization,
                                                                aggregation: aggregation);
                }
            }

            _vars[name] = v;

            return(v);
        }
示例#8
0
 /// <summary>
 /// Wraps the resource variable <paramref name="v"/>.
 /// </summary>
 /// <param name="v">Variable stored in <c>_v</c>.</param>
 public _DenseResourceVariableProcessor(ResourceVariable v) => _v = v;
示例#9
0
 /// <summary>
 /// Returns the optimizable-variable processor for <paramref name="v"/>; always a new
 /// <c>_DenseResourceVariableProcessor</c>.
 /// </summary>
 /// <param name="v">Resource variable to wrap.</param>
 public static _OptimizableVariable _get_processor(ResourceVariable v)
     => new _DenseResourceVariableProcessor(v);
示例#10
0
 /// <summary>
 /// Assigns <paramref name="value"/> to <paramref name="ref"/> by forwarding every argument,
 /// unchanged and in order, to <c>state_ops.assign</c>.
 /// </summary>
 /// <returns>The tensor returned by <c>state_ops.assign</c>.</returns>
 public Tensor assign(ResourceVariable @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
 {
     return state_ops.assign(@ref, value, validate_shape, use_locking, name);
 }