/// <summary>
/// Builds a small TensorFlow graph that adds two Int32 values fed through placeholders,
/// runs it in a new <see cref="Session"/> and returns the result.
/// </summary>
/// <param name="a">First operand, fed to the "valA" placeholder.</param>
/// <param name="b">Second operand, fed to the "valB" placeholder.</param>
/// <param name="sessionOptions">Optional options forwarded to the <see cref="Session"/> constructor; may be null.</param>
/// <returns>The first element of the "sum" op's output, i.e. <c>a + b</c> as computed by TensorFlow.</returns>
public static int Add(int a, int b, SessionOptions sessionOptions = null)
{
    // Tensor, Graph and Session wrap native resources; dispose them deterministically
    // instead of leaving them to the finalizer.
    using (Tensor tensorA = new Tensor(a))
    using (Tensor tensorB = new Tensor(b))
    using (Graph graph = new Graph())
    {
        // Placeholders for the two inputs and the actual addition.
        Operation opA = graph.Placeholder(DataType.Int32, null, "valA");
        Operation opB = graph.Placeholder(DataType.Int32, null, "valB");
        Operation sumOp = graph.Add(opA, opB, "sum");

        using (Session session = new Session(graph, sessionOptions))
        {
            Tensor[] results = session.Run(
                new Output[] { opA, opB },
                new Tensor[] { tensorA, tensorB },
                new Output[] { sumOp });
            try
            {
                return results[0].Flat<int>()[0];
            }
            finally
            {
                // The output tensors are owned by the caller of Run.
                foreach (Tensor result in results)
                {
                    result?.Dispose();
                }
            }
        }
    }
}
/// <summary>
/// Exports a <see cref="MetaGraphDef"/> by delegating to <c>meta_graph.export_scoped_meta_graph</c>,
/// which optionally writes it to <paramref name="filename"/>.
/// </summary>
/// <remarks>
/// NOTE(review): <paramref name="collection_list"/> and <paramref name="export_scope"/> are accepted
/// but not forwarded, so they currently have no effect.
/// </remarks>
/// <returns>The exported meta graph; the variable map returned by the scoped export is discarded.</returns>
public MetaGraphDef export_meta_graph(string filename = "",
    byte[] meta_info_def = null,
    GraphDef graph_def = null,
    SaverDef saver_def = null,
    string[] collection_list = null,
    bool as_text = false,
    bool clear_devices = false,
    bool clear_extraneous_savers = false,
    bool strip_default_attrs = false,
    string export_scope = "")
{
    var (exported, _) = meta_graph.export_scoped_meta_graph(
        filename: filename,
        meta_info_def: meta_info_def,
        graph_def: graph_def,
        saver_def: saver_def,
        as_text: as_text,
        clear_devices: clear_devices,
        clear_extraneous_savers: clear_extraneous_savers,
        strip_default_attrs: strip_default_attrs);

    return exported;
}
/// <summary>
/// Returns a <see cref="MetaGraphDef"/> proto for the default graph together with its global
/// variables keyed by name, and writes the proto to <paramref name="filename"/> when one is given.
/// </summary>
/// <param name="filename">Destination file; nothing is written when null or empty.</param>
/// <param name="graph_def">Optional graph definition passed to <c>create_meta_graph_def</c>.</param>
/// <param name="as_text">Forwarded to <c>graph_io.write_graph</c> when a file is written.</param>
/// <param name="unbound_inputs_col_name">NOTE(review): currently unused.</param>
/// <param name="clear_devices">NOTE(review): currently unused.</param>
/// <param name="saver_def">Optional saver definition passed to <c>create_meta_graph_def</c>.</param>
/// <param name="clear_extraneous_savers">Passed to <c>create_meta_graph_def</c>.</param>
/// <param name="strip_default_attrs">Passed to <c>create_meta_graph_def</c>.</param>
/// <param name="meta_info_def">NOTE(review): currently unused.</param>
/// <returns>The meta graph proto and a name-to-variable map of the graph's GLOBAL_VARIABLES collection.</returns>
public static (MetaGraphDef, Dictionary<string, IVariableV1>) export_scoped_meta_graph(string filename = "",
    GraphDef graph_def = null,
    bool as_text = false,
    string unbound_inputs_col_name = "unbound_inputs",
    bool clear_devices = false,
    SaverDef saver_def = null,
    bool clear_extraneous_savers = false,
    bool strip_default_attrs = false,
    byte[] meta_info_def = null)
{
    var default_graph = ops.get_default_graph();
    var variablesByName = new Dictionary<string, IVariableV1>();

    // Index the global variables by name; a later variable with the same name wins.
    var globals = default_graph.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES);
    if (globals != null)
    {
        foreach (var variable in globals)
            variablesByName[variable.Name] = variable;
    }

    var metaGraphDef = create_meta_graph_def(
        graph_def: graph_def,
        export_scope: "",
        exclude_nodes: "",
        clear_extraneous_savers: clear_extraneous_savers,
        saver_def: saver_def,
        strip_default_attrs: strip_default_attrs);

    if (string.IsNullOrEmpty(filename))
        return (metaGraphDef, variablesByName);

    graph_io.write_graph(metaGraphDef, "", filename, as_text: as_text);
    return (metaGraphDef, variablesByName);
}
/// <summary>
/// Imports <paramref name="graph_def"/> by forwarding every argument unchanged to
/// <c>importer.import_graph_def</c>.
/// </summary>
/// <param name="graph_def">The graph definition to import.</param>
/// <param name="input_map">Optional mapping from input names to tensors.</param>
/// <param name="return_elements">Optional names of elements to return.</param>
/// <param name="name">Optional name prefix.</param>
/// <param name="producer_op_list">Optional op list of the graph's producer.</param>
/// <returns>Whatever <c>importer.import_graph_def</c> returns.</returns>
public ITensorOrOperation[] import_graph_def(GraphDef graph_def,
    Dictionary<string, Tensor> input_map = null,
    string[] return_elements = null,
    string name = null,
    OpList producer_op_list = null)
{
    return importer.import_graph_def(graph_def, input_map, return_elements, name, producer_op_list);
}
/// <summary>
/// Builds three lookup tables over the nodes of <paramref name="graph_def"/>, all keyed by the
/// node name as normalized by <c>_node_name</c>.
/// </summary>
/// <param name="graph_def">The graph to summarize.</param>
/// <returns>
/// (node name → normalized input names, node name → node, node name → position in <c>graph_def.Node</c>).
/// If two nodes normalize to the same name, the later node overwrites the earlier entries.
/// </returns>
private (Dictionary<string, string[]>, Dictionary<string, NodeDef>, Dictionary<string, int>) _extract_graph_summary(GraphDef graph_def)
{
    var inputsByName = new Dictionary<string, string[]>();
    var nodesByName = new Dictionary<string, NodeDef>();
    var positionByName = new Dictionary<string, int>();

    var nodes = graph_def.Node;
    for (int position = 0; position < nodes.Count; position++)
    {
        var current = nodes[position];
        var key = _node_name(current.Name);

        nodesByName[key] = current;
        inputsByName[key] = current.Input.Select(input => _node_name(input)).ToArray();
        positionByName[key] = position;
    }

    return (inputsByName, nodesByName, positionByName);
}
/// <summary>
/// Replaces all the variables in a graph with constants of the same values.
/// </summary>
/// <param name="sess">Active TensorFlow session containing the variables.</param>
/// <param name="input_graph_def">GraphDef object holding the network.</param>
/// <param name="output_node_names">List of name strings for the result nodes of the graph.</param>
/// <param name="variable_names_whitelist">When not null, only variable nodes whose names appear here are converted.</param>
/// <param name="variable_names_blacklist">When not null, variable nodes whose names appear here are left as variables.</param>
/// <returns>GraphDef containing a simplified version of the original.</returns>
/// <exception cref="NotImplementedException">
/// An Identity op sits between a resource read and its source op, or a ResourceGather reads a
/// variable being converted; neither rewrite is supported yet.
/// </exception>
public GraphDef convert_variables_to_constants(Session sess,
    GraphDef input_graph_def,
    string[] output_node_names,
    string[] variable_names_whitelist = null,
    string[] variable_names_blacklist = null)
{
    // This graph only includes the nodes needed to evaluate the output nodes, and
    // removes unneeded nodes like those involved in saving and assignment.
    var inference_graph = extract_sub_graph(input_graph_def, output_node_names);

    // Identify the ops in the graph.
    var map_name_to_node = new Dictionary<string, NodeDef>();
    foreach (var node in inference_graph.Node)
        map_name_to_node[node.Name] = node;

    var variable_ops = new HashSet<string> { "Variable", "VariableV2", "VarHandleOp" };
    var resource_read_ops = new HashSet<string> { "ReadVariableOp", "ResourceGather" };

    // Collect the variable node names to replace and the tensor names to fetch their values from.
    var variable_names = new List<string>();
    var variable_dict_names = new List<string>();
    foreach (var node in inference_graph.Node)
    {
        if (variable_ops.Contains(node.Op))
        {
            var variable_name = node.Name;
            bool excluded =
                (variable_names_whitelist != null && !variable_names_whitelist.Contains(variable_name)) ||
                (variable_names_blacklist != null && variable_names_blacklist.Contains(variable_name));
            if (excluded)
                continue;

            variable_dict_names.Add(variable_name);
            // For VarHandleOp the value is fetched from "<name>/Read/ReadVariableOp:0"
            // (assumes the standard read-op naming).
            variable_names.Add(node.Op == "VarHandleOp"
                ? variable_name + "/Read/ReadVariableOp:0"
                : variable_name + ":0");
        }
        else if (resource_read_ops.Contains(node.Op))
        {
            // There can be one or more Identity ops in between the ReadVariableOp and
            // VarHandleOp; that chain is not handled yet.
            var source_op_name = get_input_name(node);
            if (map_name_to_node[source_op_name].Op == "Identity")
                throw new NotImplementedException("map_name_to_node[source_op_name].Op");
        }
    }

    // Gets map of variables and the associated data; only run the session when there is something to fetch.
    NDArray returned_variables = null;
    if (variable_names.Count > 0)
        returned_variables = sess.run(variable_names);

    var variables_data_map = new Dictionary<string, NDArray>();
    foreach (var (i, name) in enumerate(variable_dict_names))
        variables_data_map[name] = returned_variables[i];

    print($"Froze {(returned_variables == null ? 0 : len(returned_variables))} variables.");

    // Reconstruct the graph with constants in place of variables.
    var output_graph_def = new GraphDef();
    int how_many_converted = 0;
    foreach (var input_node in inference_graph.Node)
    {
        NodeDef output_node;
        if (variables_data_map.TryGetValue(input_node.Name, out var data))
        {
            output_node = create_const_op(input_node.Name, input_node.Attr["dtype"], data, data.shape);
            how_many_converted += 1;
        }
        else if (input_node.Op == "ReadVariableOp" && variables_data_map.ContainsKey(input_node.Input[0]))
        {
            // The variable this op reads is now a constant, so the read becomes a plain Identity.
            output_node = new NodeDef { Op = "Identity", Name = input_node.Name };
            output_node.Input.Add(input_node.Input[0]);
            output_node.Attr["T"] = input_node.Attr["dtype"];
        }
        else if (input_node.Op == "ResourceGather" && variables_data_map.ContainsKey(input_node.Input[0]))
        {
            // Previously an empty, nameless NodeDef was emitted here, silently producing an invalid graph.
            throw new NotImplementedException(
                $"Freezing ResourceGather '{input_node.Name}' over variable '{input_node.Input[0]}' is not supported yet.");
        }
        else
        {
            output_node = new NodeDef();
            output_node.MergeFrom(input_node);
        }
        output_graph_def.Node.Add(output_node);
    }

    output_graph_def.Library = inference_graph.Library;
    print($"Converted {how_many_converted} variables to const ops.");
    return output_graph_def;
}
/// <summary>
/// For every node whose op appears in <paramref name="producer_op_list"/>, removes attributes that the
/// consumer's op definition (<paramref name="op_dict"/>) does not declare, but only when the producer's
/// op definition declares a default for that attribute and the node carries exactly that default.
/// </summary>
/// <param name="op_dict">Consumer op definitions keyed by op name; must contain every op found in the producer list.</param>
/// <param name="producer_op_list">Op definitions of the graph's producer; when null there is nothing to do.</param>
/// <param name="graph_def">Graph whose node attributes are modified in place.</param>
private static void _RemoveDefaultAttrs(Dictionary<string, OpDef> op_dict, OpList producer_op_list, GraphDef graph_def)
{
    if (producer_op_list == null)
        return;

    var producer_op_dict = new Dictionary<string, OpDef>();
    foreach (var op in producer_op_list.Op)
        producer_op_dict[op.Name] = op;

    foreach (var node in graph_def.Node)
    {
        // Remove any default attr values that aren't in op_def.
        if (!producer_op_dict.TryGetValue(node.Op, out var producer_op_def))
            continue;

        var op_def = op_dict[node.Op];

        // Snapshot the keys: entries are removed from node.Attr inside the loop.
        foreach (var key in new List<string>(node.Attr.Keys))
        {
            if (_FindAttrInOpDef(key, op_def) != null)
                continue;

            var attr_def = _FindAttrInOpDef(key, producer_op_def);
            // AttrValue is a protobuf message: compare by value with Equals; '==' only compares references.
            if (attr_def != null && attr_def.DefaultValue != null && attr_def.DefaultValue.Equals(node.Attr[key]))
            {
                // Drop the key entirely; clearing the value would leave an unknown, empty attr behind.
                node.Attr.Remove(key);
            }
        }
    }
}