/// <summary>
/// Builds a <c>NodeDef</c> proto with the given op type, name, device and attrs.
/// </summary>
/// <param name="op_type">Registered op type (e.g. "Const", "Add").</param>
/// <param name="name">Unique node name within the graph.</param>
/// <param name="device">Optional device placement string; empty means unplaced.</param>
/// <param name="attrs">Optional attribute map copied into the node; may be null.</param>
/// <returns>The populated <c>NodeDef</c>.</returns>
public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null)
{
    var node_def = new node_def_pb2.NodeDef();
    node_def.Op = op_type;
    node_def.Name = name;

    // BUGFIX: the device argument was previously accepted but silently ignored.
    if (!string.IsNullOrEmpty(device))
        node_def.Device = device;

    // BUGFIX: attrs defaults to null, so iterating unconditionally threw
    // a NullReferenceException for the default call.
    if (attrs != null)
    {
        foreach (var attr in attrs)
            node_def.Attr.Add(attr.Key, attr.Value);
    }

    return node_def;
}
/// <summary>
/// Builds a "Const" <c>NodeDef</c> whose value attr carries <paramref name="data"/>
/// serialized as a tensor proto.
/// </summary>
/// <param name="node_name">Name for the new constant node.</param>
/// <param name="dtype">AttrValue holding the element dtype for both attrs.</param>
/// <param name="data">The constant payload to embed.</param>
/// <param name="data_shape">Optional explicit shape for the tensor proto.</param>
/// <returns>The populated constant node.</returns>
private NodeDef create_const_op(string node_name, AttrValue dtype, NDArray data, int[] data_shape = null)
{
    var const_node = new NodeDef
    {
        Op = "Const",
        Name = node_name
    };

    const_node.Attr["dtype"] = dtype;

    // Serialize the payload into a TensorProto and wrap it in an AttrValue.
    var tensor_proto = tensor_util.make_tensor_proto(
        data,
        dtype: dtype.Type.as_tf_dtype(),
        shape: data_shape);
    const_node.Attr["value"] = new AttrValue() { Tensor = tensor_proto };

    return const_node;
}
/// <summary>
/// Creates the underlying C operation for <paramref name="node_def"/> in
/// <paramref name="graph"/>: registers inputs, copies every attr into the
/// op description as a serialized AttrValue proto, and finishes the op.
/// </summary>
/// <param name="graph">Graph the new op belongs to.</param>
/// <param name="node_def">NodeDef supplying op type, name and attrs.</param>
/// <param name="inputs">Tensors wired as inputs, in order; may be null.</param>
/// <returns>Handle to the finished TF_Operation.</returns>
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs)
{
    var op_desc = graph.NewOperation(node_def.Op, node_def.Name);

    // Add inputs.
    // BUGFIX: removed the dead `isList` flag — it was hard-coded to false,
    // so the TF_AddInputList branch was unreachable.
    if (inputs != null)
    {
        foreach (var op_input in inputs)
            c_api.TF_AddInput(op_desc, op_input._as_tf_output());
    }

    var status = new Status();

    // Add control inputs
    // Add attrs
    foreach (var attr in node_def.Attr)
    {
        var bytes = attr.Value.ToByteArray();
        var proto = Marshal.AllocHGlobal(bytes.Length);
        try
        {
            Marshal.Copy(bytes, 0, proto, bytes.Length);
            // TF_SetAttrValueProto parses (copies) the serialized proto, so the
            // buffer can be released as soon as the call returns.
            c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status);
        }
        finally
        {
            // BUGFIX: the unmanaged buffer was never freed — one leak per attr.
            Marshal.FreeHGlobal(proto);
        }
        status.Check(true);
    }

    var c_op = c_api.TF_FinishOperation(op_desc, status);
    status.Check(true);
    return c_op;
}
/// <summary>
/// Update the input to this operation at the given index.
///
/// NOTE: This is for TF internal use only. Please don't use it.
/// </summary>
/// <param name="index">the index of the input to update.</param>
/// <param name="tensor">the Tensor to be used as the input at the given index.</param>
/// <remarks>
/// NOTE(review): the actual edge update (TF_UpdateEdge) is currently commented
/// out, so this method only validates the graph, clears cached state, and
/// checks the shared status — it does not rewire the input. Confirm whether
/// this is intentional before relying on it.
/// </remarks>
public void _update_input(int index, Tensor tensor)
{
    // Both tensors must live in the same graph; throws otherwise.
    _assert_same_graph(tensor);

    // Resolve the C-level endpoints of the edge being (re)wired.
    var input = _tf_input(index);
    var output = tensor._as_tf_output();

    // Reset cached inputs.
    _inputs_val = null;
    _node_def = null;

    // after the c_api call next time _inputs is accessed
    // the updated inputs are reloaded from the c_api
    lock (Locks.ProcessWide)
    {
        // disable
        // c_api.TF_UpdateEdge(_graph, output, input, tf.Status.Handle);
        //var updated_inputs = inputs;
        tf.Status.Check();
    }
}
/// <summary>
/// Constructs an Operation in graph <paramref name="g"/> backed by a newly
/// created C op built from <paramref name="node_def"/> and <paramref name="inputs"/>.
/// </summary>
/// <param name="node_def">NodeDef describing op type, name and attrs.</param>
/// <param name="g">The parent graph.</param>
/// <param name="inputs">Input tensors, in order; may be null.</param>
/// <param name="output_types">Ignored on input; recomputed from the C op.</param>
/// <param name="control_inputs">Unused in this overload.</param>
/// <param name="input_types">Unused in this overload.</param>
/// <param name="original_op">Unused in this overload.</param>
/// <param name="op_def">OpDef; looked up from the registry when null.</param>
public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{
    Graph = g;
    _id_value = Graph._next_id();

    // Fall back to the registered op definition when none was supplied.
    op_def = op_def ?? g.GetOpDef(node_def.Op);

    _handle = ops._create_c_op(g, node_def, inputs);

    // Query each output's dtype from the finished C op.
    output_types = new TF_DataType[NumOutputs];
    for (int idx = 0; idx < NumOutputs; idx++)
        output_types[idx] = OutputType(idx);

    Graph._add_op(this);
}
/// <summary>
/// Creates the underlying C operation for <paramref name="node_def"/> using the
/// raw c_api: news up a TF_OperationDescription, registers inputs, copies every
/// attr as a serialized AttrValue proto, and finishes the op.
/// </summary>
/// <param name="graph">Graph the new op belongs to.</param>
/// <param name="node_def">NodeDef supplying op type, name and attrs.</param>
/// <param name="inputs">Tensors wired as inputs, in order; may be null.</param>
/// <returns>Handle to the finished TF_Operation.</returns>
/// <exception cref="Exception">When the TF status reports a non-OK code.</exception>
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs)
{
    var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name);

    // Add inputs
    if (inputs != null)
    {
        foreach (var op_input in inputs)
            c_api.TF_AddInput(op_desc, op_input._as_tf_output());
    }

    var status = new Status();

    // Add control inputs
    // Add attrs
    foreach (var attr in node_def.Attr)
    {
        var bytes = attr.Value.ToByteArray();
        var proto = Marshal.AllocHGlobal(bytes.Length);
        try
        {
            Marshal.Copy(bytes, 0, proto, bytes.Length);
            // TF_SetAttrValueProto parses (copies) the serialized proto, so the
            // buffer can be released as soon as the call returns.
            c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle);
        }
        finally
        {
            // BUGFIX: the unmanaged buffer was never freed — one leak per attr.
            Marshal.FreeHGlobal(proto);
        }

        if (status.Code != TF_Code.TF_OK)
            throw new Exception(status.Message);
    }

    var c_op = c_api.TF_FinishOperation(op_desc, status.Handle);
    if (status.Code != TF_Code.TF_OK)
        throw new Exception(status.Message);

    return c_op;
}
/// <summary>
/// Fills in the op definition's default attr values on a node: for every attr
/// the op declares with a default, the node gets that default unless it already
/// carries a non-null value for the attr.
/// </summary>
/// <param name="node_def">Node whose Attr map is updated in place.</param>
/// <param name="op_def">Op definition supplying attr names and defaults.</param>
private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def)
{
    foreach (var attr_def in op_def.Attr)
    {
        // Attrs without a declared default are left untouched.
        if (attr_def.DefaultValue == null)
            continue;

        var key = attr_def.Name;

        // Apply the default when the attr is absent, or present but null.
        if (!node_def.Attr.ContainsKey(key) || node_def.Attr[key] == null)
            node_def.Attr[key] = attr_def.DefaultValue;
    }
}
/// <summary>
/// Creates an `Operation`.
/// </summary>
/// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param>
/// <param name="g">`Graph`. The parent graph.</param>
/// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param>
/// <param name="output_types">list of `DType` objects.</param>
/// <param name="control_inputs">
/// list of operations or tensors from which to have a
/// control dependency.
/// </param>
/// <param name="input_types">
/// List of `DType` objects representing the
/// types of the tensors accepted by the `Operation`. By default
/// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect
/// reference-typed inputs must specify these explicitly.
/// </param>
/// <param name="original_op"></param>
/// <param name="op_def"></param>
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{
    _graph = g;

    // Build the list of control inputs.
    // Each entry may be an Operation or a Tensor (in which case the tensor's
    // producing op becomes the control dependency).
    var control_input_ops = new List<Operation>();
    if (control_inputs != null)
    {
        foreach (var c in control_inputs)
        {
            switch (c)
            {
                case Operation c1:
                    control_input_ops.Add(c1);
                    break;
                case Tensor tensor:
                    control_input_ops.Add(tensor.op);
                    break;
                // TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented
                //case IndexedSlices islices:
                //    control_input_ops.Add(islices.op);
                //    break;
                default:
                    throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
            }
        }
    }

    // Dict mapping op name to file and line information for op colocation
    // context managers.
    _control_flow_context = graph._get_control_flow_context();

    // This will be set by self.inputs.
    // Fall back to the registered op definition when none was supplied.
    if (op_def == null)
        op_def = g.GetOpDef(node_def.Op);

    // Re-group the flat input list into per-arg sequences according to the
    // op definition before handing them to the C layer.
    var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);

    (_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

    // Initialize self._outputs.
    // NOTE(review): output_types is recomputed from the finished C op here,
    // overwriting whatever the caller passed — confirm this is intentional.
    output_types = new TF_DataType[NumOutputs];
    for (int i = 0; i < NumOutputs; i++)
        output_types[i] = OutputType(i);

    _outputs = new Tensor[NumOutputs];
    for (int i = 0; i < NumOutputs; i++)
        _outputs[i] = new Tensor(this, i, OutputType(i));

    graph._add_op(this);

    // Only run control-flow post-processing when the C op was actually created.
    if (_handle != IntPtr.Zero)
        _control_flow_post_processing();
}
/// <summary>
/// Replaces all the variables in a graph with constants of the same values.
/// </summary>
/// <param name="sess">Active TensorFlow session containing the variables.</param>
/// <param name="input_graph_def">GraphDef object holding the network.</param>
/// <param name="output_node_names">List of name strings for the result nodes of the graph.</param>
/// <param name="variable_names_whitelist">Currently unused; accepted for API compatibility.</param>
/// <param name="variable_names_blacklist">Currently unused; accepted for API compatibility.</param>
/// <returns>GraphDef containing a simplified version of the original.</returns>
public GraphDef convert_variables_to_constants(Session sess,
    GraphDef input_graph_def,
    string[] output_node_names,
    string[] variable_names_whitelist = null,
    string[] variable_names_blacklist = null)
{
    // This graph only includes the nodes needed to evaluate the output nodes, and
    // removes unneeded nodes like those involved in saving and assignment.
    var inference_graph = extract_sub_graph(input_graph_def, output_node_names);

    // Identify the ops in the graph.
    // BUGFIX(style): was a side-effecting LINQ Select(...).ToArray(); a plain
    // loop expresses the mutation directly and avoids the throwaway array.
    var map_name_to_node = new Dictionary<string, NodeDef>();
    foreach (var node in inference_graph.Node)
        map_name_to_node[node.Name] = node;

    // Get list of variables.
    // Hoisted out of the loop: these arrays were re-allocated for every node.
    var variable_op_types = new string[] { "Variable", "VariableV2", "VarHandleOp" };
    var read_op_types = new string[] { "ReadVariableOp", "ResourceGather" };
    var variable_names = new List<string>();
    var variable_dict_names = new List<string>();
    foreach (var node in inference_graph.Node)
    {
        if (variable_op_types.Contains(node.Op))
        {
            var variable_name = node.Name;
            variable_dict_names.Add(variable_name);
            // Resource variables (VarHandleOp) are read through a ReadVariableOp.
            if (node.Op == "VarHandleOp")
                variable_names.Add(variable_name + "/Read/ReadVariableOp:0");
            else
                variable_names.Add(variable_name + ":0");
        }
        else if (read_op_types.Contains(node.Op))
        {
            // There can be one or more Identity ops in between the ReadVariableOp and
            // VarHandleOp. Store the Identity ops with the associated dtypes.
            var source_op_name = get_input_name(node);
            while (map_name_to_node[source_op_name].Op == "Identity")
            {
                throw new NotImplementedException("map_name_to_node[source_op_name].Op");
                /*resource_identity_types[source_op_name] = node.attr["dtype"];
                 * source_op_name = get_input_name(map_name_to_node[source_op_name]);*/
            }
        }
    }

    // Gets map of variables and the associated data.
    // NOTE(review): variable_names is a freshly created List and can never be
    // null, so this check is always true (kept for behavioral parity).
    // Presumably sess.run on an empty list is safe — confirm for empty graphs.
    NDArray returned_variables = null;
    if (variable_names != null)
        returned_variables = sess.run(variable_names);

    var variables_data_map = new Dictionary<string, NDArray>();
    foreach (var (i, name) in enumerate(variable_dict_names))
        variables_data_map[name] = returned_variables[i];
    print($"Froze {len(returned_variables)} variables.");

    // Reconstruct the graph with constants in place of variables.
    var output_graph_def = new GraphDef();
    int how_many_converted = 0;
    foreach (var input_node in inference_graph.Node)
    {
        var output_node = new NodeDef();
        if (variables_data_map.ContainsKey(input_node.Name))
        {
            // Replace the variable node by a Const node carrying its current value.
            var data = variables_data_map[input_node.Name];
            output_node = create_const_op(input_node.Name, input_node.Attr["dtype"], data, data.shape);
            how_many_converted += 1;
        }
        // else if (resource_identity_types.ContainsKey(input_node.Name))
        else if (input_node.Op == "ReadVariableOp")
        {
            // Reads of a now-constant resource variable become pass-through Identity ops.
            output_node.Op = "Identity";
            output_node.Name = input_node.Name;
            output_node.Input.AddRange(new[] { input_node.Input[0] });
            output_node.Attr["T"] = input_node.Attr["dtype"];
        }
        else if (input_node.Op == "ResourceGather")
        {
            // TODO: not implemented — emits an empty NodeDef (same as original behavior).
        }
        else
        {
            // Any other op is copied through unchanged.
            output_node.MergeFrom(input_node);
        }
        output_graph_def.Node.AddRange(new[] { output_node });
    }

    output_graph_def.Library = inference_graph.Library;
    print($"Converted {how_many_converted} variables to const ops.");
    return output_graph_def;
}