Example #1
0
        /// <summary>
        /// Builds a <c>NodeDef</c> describing a node of type <paramref name="op_type"/>.
        /// </summary>
        /// <param name="op_type">Op type name stored in <c>NodeDef.Op</c> (e.g. "Const").</param>
        /// <param name="name">Node name stored in <c>NodeDef.Name</c>.</param>
        /// <param name="device">Optional device string; left at the proto default when null or empty.</param>
        /// <param name="attrs">Optional attributes copied into <c>NodeDef.Attr</c>; <c>null</c> means none.</param>
        /// <returns>The newly built <c>NodeDef</c>.</returns>
        public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary <string, AttrValue> attrs = null)
        {
            var node_def = new node_def_pb2.NodeDef();

            node_def.Op   = op_type;
            node_def.Name = name;

            // Honour the requested placement; an empty/null device leaves the field unset.
            if (!string.IsNullOrEmpty(device))
            {
                node_def.Device = device;
            }

            // attrs defaults to null, so it must be guarded before enumeration.
            if (attrs != null)
            {
                foreach (var attr in attrs)
                {
                    node_def.Attr.Add(attr.Key, attr.Value);
                }
            }

            return(node_def);
        }
Example #2
0
        /// <summary>
        /// Creates a "Const" <c>NodeDef</c> whose "value" attribute embeds <paramref name="data"/>.
        /// </summary>
        /// <param name="node_name">Name assigned to the new node.</param>
        /// <param name="dtype">Stored as the "dtype" attr; its <c>Type</c> also selects the tensor encoding.</param>
        /// <param name="data">Values serialized into the "value" attr via <c>tensor_util.make_tensor_proto</c>.</param>
        /// <param name="data_shape">Optional shape forwarded to <c>make_tensor_proto</c>.</param>
        /// <returns>A Const node carrying "dtype" and "value" attributes.</returns>
        private NodeDef create_const_op(string node_name, AttrValue dtype, NDArray data, int[] data_shape = null)
        {
            var constNode = new NodeDef();
            constNode.Op = "Const";
            constNode.Name = node_name;
            constNode.Attr["dtype"] = dtype;

            var tensorProto = tensor_util.make_tensor_proto(data,
                                                            dtype: dtype.Type.as_tf_dtype(),
                                                            shape: data_shape);
            constNode.Attr["value"] = new AttrValue { Tensor = tensorProto };

            return constNode;
        }
Example #3
0
        /// <summary>
        /// Creates a native operation in <paramref name="graph"/> from <paramref name="node_def"/>.
        /// </summary>
        /// <param name="graph">Graph in which the operation description is created.</param>
        /// <param name="node_def">Supplies the op type, node name and attributes.</param>
        /// <param name="inputs">Tensors added as inputs, in order; may be <c>null</c>.</param>
        /// <returns>The handle returned by <c>TF_FinishOperation</c>.</returns>
        /// <remarks>
        /// Each attribute is serialized to protobuf bytes and passed to <c>TF_SetAttrValueProto</c>.
        /// <c>status.Check(true)</c> runs after every attribute and after finishing the operation.
        /// Control inputs are not handled here.
        /// </remarks>
        public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List <Tensor> inputs)
        {
            var op_desc = graph.NewOperation(node_def.Op, node_def.Name);

            // Add inputs one at a time (list-valued inputs are not supported here).
            if (inputs != null)
            {
                foreach (var op_input in inputs)
                {
                    c_api.TF_AddInput(op_desc, op_input._as_tf_output());
                }
            }

            var status = new Status();

            // Add attrs.
            foreach (var attr in node_def.Attr)
            {
                var bytes = attr.Value.ToByteArray();
                var proto = Marshal.AllocHGlobal(bytes.Length);
                try
                {
                    Marshal.Copy(bytes, 0, proto, bytes.Length);
                    c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status);
                }
                finally
                {
                    // The serialized AttrValue is consumed during the call, so the
                    // unmanaged copy is released here instead of leaking per attribute.
                    Marshal.FreeHGlobal(proto);
                }

                status.Check(true);
            }

            var c_op = c_api.TF_FinishOperation(op_desc, status);

            status.Check(true);

            return(c_op);
        }
Example #4
0
        /// <summary>
        /// Update the input to this operation at the given index.
        ///
        /// NOTE: This is for TF internal use only. Please don't use it.
        /// </summary>
        /// <remarks>
        /// The native edge rewrite (<c>c_api.TF_UpdateEdge</c>) is currently disabled. The method
        /// still checks the graph, resolves both endpoints, clears the cached
        /// <c>_inputs_val</c> / <c>_node_def</c>, and checks <c>tf.Status</c> under the process-wide lock.
        /// </remarks>
        /// <param name="index">the index of the input to update.</param>
        /// <param name="tensor"> the Tensor to be used as the input at the given index.</param>
        public void _update_input(int index, Tensor tensor)
        {
            _assert_same_graph(tensor);

            // Endpoints for the (currently disabled) native edge update.
            var targetInput  = _tf_input(index);
            var sourceOutput = tensor._as_tf_output();

            // Invalidate caches so inputs/node_def are re-read from the C API on next access.
            _inputs_val = null;
            _node_def   = null;

            lock (Locks.ProcessWide)
            {
                // Edge update disabled; it would be:
                // c_api.TF_UpdateEdge(_graph, sourceOutput, targetInput, tf.Status.Handle);
                tf.Status.Check();
            }
        }
Example #5
0
        /// <summary>
        /// Creates an <c>Operation</c> in graph <paramref name="g"/> from <paramref name="node_def"/>.
        /// </summary>
        /// <remarks>
        /// Assigns a fresh id from the graph, creates the native op via <c>ops._create_c_op</c>, and
        /// registers this operation with the graph. <paramref name="control_inputs"/>,
        /// <paramref name="input_types"/> and <paramref name="original_op"/> are not read; the resolved
        /// <paramref name="op_def"/> and recomputed <paramref name="output_types"/> are not stored.
        /// </remarks>
        public Operation(NodeDef node_def, Graph g, List <Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
        {
            Graph     = g;
            _id_value = Graph._next_id();

            // Fall back to the graph's registered OpDef when none was supplied.
            op_def = op_def ?? g.GetOpDef(node_def.Op);

            _handle = ops._create_c_op(g, node_def, inputs);

            var resolvedTypes = new TF_DataType[NumOutputs];
            for (var idx = 0; idx < NumOutputs; ++idx)
            {
                resolvedTypes[idx] = OutputType(idx);
            }
            output_types = resolvedTypes;

            Graph._add_op(this);
        }
Example #6
0
        /// <summary>
        /// Creates a native operation in <paramref name="graph"/> from <paramref name="node_def"/>.
        /// </summary>
        /// <param name="graph">Graph whose <c>Handle</c> receives the new operation.</param>
        /// <param name="node_def">Supplies the op type, node name and attributes.</param>
        /// <param name="inputs">Tensors added as inputs, in order; may be <c>null</c>.</param>
        /// <returns>The handle returned by <c>TF_FinishOperation</c>.</returns>
        /// <exception cref="InvalidOperationException">
        /// A native call reported a non-OK status; the message is the status message.
        /// </exception>
        /// <remarks>Control inputs are not handled here.</remarks>
        public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List <Tensor> inputs)
        {
            var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name);

            // Add inputs.
            if (inputs != null)
            {
                foreach (var op_input in inputs)
                {
                    c_api.TF_AddInput(op_desc, op_input._as_tf_output());
                }
            }

            var status = new Status();

            // Add attrs: each AttrValue is serialized and handed to the C API.
            foreach (var attr in node_def.Attr)
            {
                var bytes = attr.Value.ToByteArray();
                var proto = Marshal.AllocHGlobal(bytes.Length);
                try
                {
                    Marshal.Copy(bytes, 0, proto, bytes.Length);
                    c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle);
                }
                finally
                {
                    // The serialized AttrValue is consumed during the call; free the
                    // unmanaged copy so it does not leak per attribute.
                    Marshal.FreeHGlobal(proto);
                }

                ThrowIfFailed(status);
            }

            var c_op = c_api.TF_FinishOperation(op_desc, status.Handle);

            ThrowIfFailed(status);

            return(c_op);

            // Surfaces a non-OK native status as a managed exception carrying its message.
            void ThrowIfFailed(Status s)
            {
                if (s.Code != TF_Code.TF_OK)
                {
                    throw new InvalidOperationException(s.Message);
                }
            }
        }
Example #7
0
 /// <summary>
 /// Copies each attribute's default value from <paramref name="op_def"/> into
 /// <paramref name="node_def"/> when the node has no entry for it or the entry is null.
 /// Attributes without a default are skipped.
 /// </summary>
 /// <param name="node_def">Node whose <c>Attr</c> map is updated in place.</param>
 /// <param name="op_def">Op definition supplying each attribute's <c>DefaultValue</c>.</param>
 private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def)
 {
     foreach (var attrDef in op_def.Attr)
     {
         var attrName = attrDef.Name;
         if (attrDef.DefaultValue == null)
         {
             continue;
         }

         var hasValue = node_def.Attr.TryGetValue(attrName, out var existing) && existing != null;
         if (!hasValue)
         {
             node_def.Attr[attrName] = attrDef.DefaultValue;
         }
     }
 }
Example #8
0
        /// <summary>
        /// Creates an `Operation`.
        /// </summary>
        /// <param name="node_def">`node_def_pb2.NodeDef`.  `NodeDef` for the `Operation`.</param>
        /// <param name="g">`Graph`. The parent graph.</param>
        /// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param>
        /// <param name="output_types">
        /// list of `DType` objects. Not read by this constructor; output types come from the
        /// created native operation via <c>OutputType(i)</c>.
        /// </param>
        /// <param name="control_inputs">
        /// list of operations or tensors from which to have a
        /// control dependency.
        /// </param>
        /// <param name="input_types">
        /// List of `DType` objects representing the types of the tensors accepted by the
        /// `Operation`. Not read by this constructor.
        /// </param>
        /// <param name="original_op">Not read by this constructor.</param>
        /// <param name="op_def">`OpDef` for the op; looked up with <c>g.GetOpDef(node_def.Op)</c> when null.</param>
        /// <exception cref="NotImplementedException">A control input is neither an `Operation` nor a `Tensor`.</exception>
        public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
        {
            _graph = g;

            // Build the list of control inputs.
            var control_input_ops = new List <Operation>();

            if (control_inputs != null)
            {
                foreach (var c in control_inputs)
                {
                    switch (c)
                    {
                    case Operation c1:
                        control_input_ops.Add(c1);
                        break;

                    case Tensor tensor:
                        control_input_ops.Add(tensor.op);
                        break;

                    // TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented
                    //case IndexedSlices islices:
                    //    control_input_ops.Add(islices.op);
                    //    break;
                    default:
                        throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
                    }
                }
            }

            // Record the control-flow context reported by the graph at creation time.
            _control_flow_context = graph._get_control_flow_context();

            // Resolve the OpDef from the graph when the caller did not pass one.
            if (op_def == null)
            {
                op_def = g.GetOpDef(node_def.Op);
            }

            var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);

            (_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

            // Initialize self._outputs: one Tensor per native output, typed by OutputType(i).
            _outputs = new Tensor[NumOutputs];
            for (int i = 0; i < NumOutputs; i++)
            {
                _outputs[i] = new Tensor(this, i, OutputType(i));
            }

            graph._add_op(this);

            if (_handle != IntPtr.Zero)
            {
                _control_flow_post_processing();
            }
        }
Example #9
0
        /// <summary>
        /// Replaces all the variables in a graph with constants of the same values.
        /// </summary>
        /// <param name="sess">Active TensorFlow session containing the variables.</param>
        /// <param name="input_graph_def">GraphDef object holding the network.</param>
        /// <param name="output_node_names">List of name strings for the result nodes of the graph.</param>
        /// <param name="variable_names_whitelist">
        /// Optional variable node names to convert; <c>null</c> converts every variable.
        /// </param>
        /// <param name="variable_names_blacklist">
        /// Optional variable node names to keep as variables; <c>null</c> excludes none.
        /// </param>
        /// <returns>GraphDef containing a simplified version of the original.</returns>
        /// <exception cref="NotImplementedException">
        /// A ReadVariableOp/ResourceGather reads its variable through an Identity op, or a
        /// ResourceGather reads a variable being converted; neither rewrite is implemented.
        /// </exception>
        public GraphDef convert_variables_to_constants(Session sess,
                                                       GraphDef input_graph_def,
                                                       string[] output_node_names,
                                                       string[] variable_names_whitelist = null,
                                                       string[] variable_names_blacklist = null)
        {
            // Op types treated as variables and as reads of resource variables.
            var variable_op_types      = new[] { "Variable", "VariableV2", "VarHandleOp" };
            var resource_read_op_types = new[] { "ReadVariableOp", "ResourceGather" };

            // This graph only includes the nodes needed to evaluate the output nodes, and
            // removes unneeded nodes like those involved in saving and assignment.
            var inference_graph = extract_sub_graph(input_graph_def, output_node_names);

            // Identify the ops in the graph.
            var map_name_to_node = new Dictionary <string, NodeDef>();
            foreach (var node in inference_graph.Node)
            {
                map_name_to_node[node.Name] = node;
            }

            // Get list of variables to convert, honouring the optional whitelist/blacklist.
            var variable_names      = new List <string>();
            var variable_dict_names = new List <string>();

            foreach (var node in inference_graph.Node)
            {
                if (variable_op_types.Contains(node.Op))
                {
                    var variable_name = node.Name;
                    var selected = (variable_names_whitelist == null || variable_names_whitelist.Contains(variable_name))
                                   && (variable_names_blacklist == null || !variable_names_blacklist.Contains(variable_name));
                    if (!selected)
                    {
                        continue;
                    }

                    variable_dict_names.Add(variable_name);
                    // VarHandleOp values are fetched through "<name>/Read/ReadVariableOp:0".
                    if (node.Op == "VarHandleOp")
                    {
                        variable_names.Add(variable_name + "/Read/ReadVariableOp:0");
                    }
                    else
                    {
                        variable_names.Add(variable_name + ":0");
                    }
                }
                else if (resource_read_op_types.Contains(node.Op))
                {
                    // There can be one or more Identity ops in between the ReadVariableOp and
                    // VarHandleOp; that layout is not supported yet.
                    var source_op_name = get_input_name(node);
                    if (map_name_to_node[source_op_name].Op == "Identity")
                    {
                        throw new NotImplementedException("map_name_to_node[source_op_name].Op");
                    }
                }
            }

            // Gets map of variables and the associated data (nothing to fetch when no variables).
            NDArray returned_variables = null;

            if (variable_names.Count > 0)
            {
                returned_variables = sess.run(variable_names);
            }

            var variables_data_map = new Dictionary <string, NDArray>();

            foreach (var(i, name) in enumerate(variable_dict_names))
            {
                variables_data_map[name] = returned_variables[i];
            }
            print($"Froze {(returned_variables == null ? 0 : len(returned_variables))} variables.");

            // Reconstruct the graph with constants in place of variables.
            var output_graph_def   = new GraphDef();
            int how_many_converted = 0;

            foreach (var input_node in inference_graph.Node)
            {
                var output_node = new NodeDef();
                if (variables_data_map.TryGetValue(input_node.Name, out var data))
                {
                    output_node = create_const_op(input_node.Name, input_node.Attr["dtype"],
                                                  data, data.shape);
                    how_many_converted += 1;
                }
                else if (input_node.Op == "ReadVariableOp"
                         && variables_data_map.ContainsKey(get_input_name(input_node)))
                {
                    // Its variable became a Const above, so the read becomes an Identity.
                    output_node.Op   = "Identity";
                    output_node.Name = input_node.Name;
                    output_node.Input.Add(input_node.Input[0]);
                    output_node.Attr["T"] = input_node.Attr["dtype"];
                    if (input_node.Attr.TryGetValue("_class", out var colocation))
                    {
                        output_node.Attr["_class"] = colocation;
                    }
                }
                else if (input_node.Op == "ResourceGather"
                         && variables_data_map.ContainsKey(get_input_name(input_node)))
                {
                    // Emitting an empty NodeDef here would corrupt the graph; fail loudly instead.
                    throw new NotImplementedException(
                        $"Freezing ResourceGather '{input_node.Name}' on a converted variable is not supported.");
                }
                else
                {
                    output_node.MergeFrom(input_node);
                }

                output_graph_def.Node.Add(output_node);
            }

            output_graph_def.Library = inference_graph.Library;
            print($"Converted {how_many_converted} variables to const ops.");
            return(output_graph_def);
        }