/// <summary>
/// Removes, from every node of <paramref name="graph_def"/>, each attribute that is not
/// declared by the matching definition in <paramref name="op_dict"/> but whose value equals
/// the default value declared by the producer's definition of the same op.
/// </summary>
/// <param name="op_dict">Op definitions, keyed by op name, that attributes are checked against first.</param>
/// <param name="producer_op_list">Op definitions supplying the producer-side attribute defaults.</param>
/// <param name="graph_def">Graph whose node attribute maps are modified in place.</param>
/// <exception cref="KeyNotFoundException">
/// A node's op is present in <paramref name="producer_op_list"/> but missing from <paramref name="op_dict"/>.
/// </exception>
private static void _RemoveDefaultAttrs(Dictionary<string, OpDef> op_dict, OpList producer_op_list, GraphDef graph_def)
{
    // Index the producer definitions by name. Indexer assignment is used on purpose so that
    // a duplicated name overwrites the earlier entry instead of throwing.
    var producer_op_dict = new Dictionary<string, OpDef>();
    foreach (var producer_op in producer_op_list.Op)
    {
        producer_op_dict[producer_op.Name] = producer_op;
    }

    foreach (var node in graph_def.Node)
    {
        // Nodes whose op the producer list does not mention are left untouched.
        if (!producer_op_dict.TryGetValue(node.Op, out var producer_op_def))
        {
            continue;
        }

        var op_def = op_dict[node.Op];

        // Iterate over a snapshot of the keys: entries are removed from node.Attr inside the loop.
        foreach (var attr_name in node.Attr.Keys.ToArray())
        {
            // The attribute is declared in op_def, so it must be kept.
            if (_FindAttrInOpDef(attr_name, op_def) != null)
            {
                continue;
            }

            // Not declared in op_def: look the attribute up in the producer's definition.
            var attr_def = _FindAttrInOpDef(attr_name, producer_op_def);
            if (attr_def == null || attr_def.DefaultValue == null)
            {
                continue;
            }

            // Value equality is required here; `==` on message objects only compares references.
            if (object.Equals(node.Attr[attr_name], attr_def.DefaultValue))
            {
                // Drop the whole entry rather than merely clearing its value, so the
                // undeclared attribute no longer appears on the node.
                node.Attr.Remove(attr_name);
            }
        }
    }
}
/// <summary>
/// Passes every op definition contained in <paramref name="op_list"/> to <c>add_op</c>,
/// one at a time and in list order.
/// </summary>
/// <param name="op_list">List whose <c>Op</c> entries are added.</param>
public void add_op_list(OpList op_list)
{
    foreach (var definition in op_list.Op)
    {
        add_op(definition);
    }
}
/// <summary>
/// Imports <paramref name="graph_def"/> into the default graph through the native
/// <c>TF_GraphImportGraphDefWithResults</c> call and returns the requested elements.
/// </summary>
/// <param name="graph_def">Graph definition to import; it is serialized and handed to the native call.</param>
/// <param name="input_map">Optional input mapping; normalized by <c>_ProcessInputMapParam</c> and converted inside the import name scope.</param>
/// <param name="return_elements">Optional names of elements to return; normalized by <c>_ProcessReturnElementsParam</c>.</param>
/// <param name="name">Name used to open the name scope (default scope name is "import").</param>
/// <param name="producer_op_list">Optional producer op list; when supplied, <c>_RemoveDefaultAttrs</c> is applied to the graph definition first.</param>
/// <returns>
/// The elements gathered for <paramref name="return_elements"/>, or <c>null</c> when no
/// return elements were requested.
/// </returns>
public static ITensorOrOperation[] import_graph_def(GraphDef graph_def, Dictionary<string, Tensor> input_map = null, string[] return_elements = null, string name = null, OpList producer_op_list = null)
{
    var registered_ops = op_def_registry.get_registered_ops();

    // Normalize the caller-supplied arguments.
    graph_def = _ProcessGraphDefParam(graph_def, registered_ops);
    input_map = _ProcessInputMapParam(input_map);
    return_elements = _ProcessReturnElementsParam(return_elements);

    if (producer_op_list != null)
    {
        _RemoveDefaultAttrs(registered_ops, producer_op_list, graph_def);
    }

    var import_prefix = string.Empty;
    var default_graph = ops.get_default_graph();

    tf_with(ops.name_scope(name, "import", input_map.Values), scope =>
    {
        // The prefix is taken exactly as the scope yields it; no trailing character is trimmed.
        import_prefix = scope;

        // Generate any input map tensors inside name scope
        input_map = _ConvertInputMapValues(name, input_map);
    });

    TF_ImportGraphDefResults import_results = null;
    var serialized_graph = graph_def.ToByteString().ToArray();

    using (var graph_buffer = c_api_util.tf_buffer(serialized_graph))
    using (var import_options = c_api_util.ScopedTFImportGraphDefOptions())
    using (var import_status = new Status())
    {
        _PopulateTFImportGraphDefOptions(import_options, import_prefix, input_map, return_elements);

        // need to create a class ImportGraphDefWithResults with IDisposal
        var native_results = c_api.TF_GraphImportGraphDefWithResults(default_graph, graph_buffer.Handle, import_options.Handle, import_status.Handle);
        import_results = new TF_ImportGraphDefResults(native_results);
        import_status.Check(true);
    }

    _ProcessNewOps(default_graph);

    // Nothing was requested, so there is nothing to gather.
    if (return_elements == null)
    {
        return null;
    }

    return _GatherReturnElements(return_elements, default_graph, import_results);
}
/// <summary>
/// Calls <c>ops_used_by_graph_def</c> for <paramref name="graph_def"/> and returns a newly
/// created, empty <c>OpList</c>.
/// </summary>
/// <param name="graph_def">Graph definition passed to <c>ops_used_by_graph_def</c>.</param>
/// <returns>A new <c>OpList</c> to which nothing has been added.</returns>
private static OpList stripped_op_list_for_graph(GraphDef graph_def)
{
    // NOTE(review): appears unfinished - the used ops are computed but never copied into the
    // result, and the "verify that all used ops are registered" step is not implemented.
    _ = ops_used_by_graph_def(graph_def);

    return new OpList();
}
/// <summary>
/// Imports the graph of a <c>MetaGraphDef</c> into the default graph, re-creates its
/// collections there, and returns the imported global variables together with the
/// requested return elements.
/// </summary>
/// <param name="meta_graph_or_file">Meta graph to import.</param>
/// <param name="clear_devices">When true, the device of every node in the meta graph's graph definition is set to "" before importing.</param>
/// <param name="import_scope">Scope handed to the cond/while context <c>from_proto</c> calls.</param>
/// <param name="input_map">Input mapping forwarded to <c>importer.import_graph_def</c>.</param>
/// <param name="unbound_inputs_col_name">Collection key that is not supported: its presence raises <see cref="NotImplementedException"/>.</param>
/// <param name="return_elements">Return element names forwarded to <c>importer.import_graph_def</c>.</param>
/// <returns>
/// A dictionary from scope-stripped variable name to variable (taken from the
/// <c>GLOBAL_VARIABLES</c> collection) and the elements returned by the graph import.
/// </returns>
/// <exception cref="NotImplementedException">
/// The meta graph contains a collection keyed by a non-empty <paramref name="unbound_inputs_col_name"/>.
/// </exception>
public static (Dictionary<string, IVariableV1>, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file, bool clear_devices = false, string import_scope = "", Dictionary<string, Tensor> input_map = null, string unbound_inputs_col_name = "unbound_inputs", string[] return_elements = null)
{
    var meta_graph_def = meta_graph_or_file;

    // Unbound inputs are not supported: bail out as soon as that collection is found.
    if (!string.IsNullOrEmpty(unbound_inputs_col_name)
        && meta_graph_def.CollectionDef.Any(entry => entry.Key == unbound_inputs_col_name))
    {
        throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
    }

    // Sets graph to default graph if it's not passed in.
    var graph = ops.get_default_graph();

    // Gathers the list of nodes we are interested in (stays null when no stripped op list is set).
    OpList producer_op_list = meta_graph_def.MetaInfoDef.StrippedOpList;

    var input_graph_def = meta_graph_def.GraphDef;

    // Remove all the explicit device specifications for this node. This helps to
    // make the graph more portable.
    if (clear_devices)
    {
        foreach (var graph_node in input_graph_def.Node)
        {
            graph_node.Device = "";
        }
    }

    var scope_to_prepend_to_names = graph.unique_name("", mark_as_used: false);
    var imported_return_elements = importer.import_graph_def(
        input_graph_def,
        input_map: input_map,
        return_elements: return_elements,
        name: scope_to_prepend_to_names,
        producer_op_list: producer_op_list);

    // Restores all the other collections. Variables are cached by their serialized bytes so
    // that the same serialized variable yields the same object in every collection.
    var variable_objects = new Dictionary<ByteString, IVariableV1>();
    foreach (var col in meta_graph_def.CollectionDef.OrderBy(entry => entry.Key))
    {
        // Don't add unbound_inputs to the new graph.
        if (col.Key == unbound_inputs_col_name)
        {
            continue;
        }

        var collection = col.Value;
        switch (collection.KindCase)
        {
            case KindOneofCase.NodeList:
                foreach (var node_name in collection.NodeList.Value)
                {
                    var scoped_name = ops.prepend_name_scope(node_name, scope_to_prepend_to_names);
                    var col_op = graph.as_graph_element(scoped_name);
                    graph.add_to_collection(col.Key, col_op);
                }
                break;

            case KindOneofCase.BytesList:
                if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
                {
                    foreach (var serialized in collection.BytesList.Value)
                    {
                        if (!variable_objects.TryGetValue(serialized, out IVariableV1 variable))
                        {
                            // First occurrence of these bytes: build the variable and cache it.
                            var variable_def = VariableDef.Parser.ParseFrom(serialized);
                            if (variable_def.IsResource)
                            {
                                variable = new ResourceVariable(variable_def: variable_def, import_scope: scope_to_prepend_to_names);
                            }
                            else
                            {
                                variable = new RefVariable(variable_def: variable_def, import_scope: scope_to_prepend_to_names);
                            }

                            variable_objects[serialized] = variable;
                        }

                        graph.add_to_collection(col.Key, variable);
                    }
                }
                else
                {
                    foreach (var serialized in collection.BytesList.Value)
                    {
                        if (col.Key == "cond_context")
                        {
                            var cond_proto = CondContextDef.Parser.ParseFrom(serialized);
                            var condContext = new CondContext().from_proto(cond_proto, import_scope);
                            graph.add_to_collection(col.Key, condContext);
                        }
                        else if (col.Key == "while_context")
                        {
                            var while_proto = WhileContextDef.Parser.ParseFrom(serialized);
                            var whileContext = new WhileContext().from_proto(while_proto, import_scope);
                            graph.add_to_collection(col.Key, whileContext);
                        }
                        else
                        {
                            // Other byte-list collections are only reported, not restored.
                            Console.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}");
                        }
                    }
                }
                break;

            default:
                Console.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping.");
                break;
        }
    }

    // Key the imported global variables by their name with the import scope stripped off.
    var variables = graph.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope: scope_to_prepend_to_names);
    var var_list = new Dictionary<string, IVariableV1>();
    variables.ForEach(imported => var_list[ops.strip_name_scope(imported.Name, scope_to_prepend_to_names)] = imported);

    return (var_list, imported_return_elements);
}
/// <summary>
/// Instance-level convenience wrapper that forwards all arguments unchanged to
/// <c>importer.import_graph_def</c> and returns its result.
/// </summary>
/// <param name="graph_def">Graph definition to import.</param>
/// <param name="input_map">Optional input mapping, forwarded as-is.</param>
/// <param name="return_elements">Optional return element names, forwarded as-is.</param>
/// <param name="name">Optional name, forwarded as-is.</param>
/// <param name="producer_op_list">Optional producer op list, forwarded as-is.</param>
/// <returns>Whatever <c>importer.import_graph_def</c> returns for these arguments.</returns>
public ITensorOrOperation[] import_graph_def(GraphDef graph_def, Dictionary<string, Tensor> input_map = null, string[] return_elements = null, string name = null, OpList producer_op_list = null)
{
    return importer.import_graph_def(
        graph_def,
        input_map: input_map,
        return_elements: return_elements,
        name: name,
        producer_op_list: producer_op_list);
}
/// <summary>
/// Imports the graph of a <c>MetaGraphDef</c> into the default graph and re-adds its
/// node-list collections. Byte-list collections are not supported by this variant.
/// </summary>
/// <param name="meta_graph_or_file">Meta graph to import.</param>
/// <param name="clear_devices">When true, the device of every node in the meta graph's graph definition is set to "" before importing.</param>
/// <param name="import_scope">Not read by this implementation.</param>
/// <param name="input_map">Input mapping forwarded to <c>importer.import_graph_def</c>.</param>
/// <param name="unbound_inputs_col_name">Collection key that is not supported: its presence raises <see cref="NotImplementedException"/>.</param>
/// <param name="return_elements">Return element names forwarded to <c>importer.import_graph_def</c>.</param>
/// <returns>Always <c>(null, null)</c> when the method completes without throwing.</returns>
/// <exception cref="NotImplementedException">
/// The meta graph contains a collection keyed by a non-empty <paramref name="unbound_inputs_col_name"/>,
/// a non-empty byte-list collection whose key is a variable collection, or any other byte-list collection.
/// </exception>
public static (RefVariable[], string[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file, bool clear_devices = false, string import_scope = "", Dictionary<string, Tensor> input_map = null, string unbound_inputs_col_name = "unbound_inputs", string[] return_elements = null)
{
    var meta_graph_def = meta_graph_or_file;

    // Unbound inputs are not supported: bail out as soon as that collection is found.
    if (!string.IsNullOrEmpty(unbound_inputs_col_name)
        && meta_graph_def.CollectionDef.Any(entry => entry.Key == unbound_inputs_col_name))
    {
        throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
    }

    // Sets graph to default graph if it's not passed in.
    var graph = ops.get_default_graph();

    // Gathers the list of nodes we are interested in (stays null when no stripped op list is set).
    OpList producer_op_list = meta_graph_def.MetaInfoDef.StrippedOpList;

    var input_graph_def = meta_graph_def.GraphDef;

    // Remove all the explicit device specifications for this node. This helps to
    // make the graph more portable.
    if (clear_devices)
    {
        foreach (var graph_node in input_graph_def.Node)
        {
            graph_node.Device = "";
        }
    }

    var scope_to_prepend_to_names = graph.unique_name("", mark_as_used: false);
    var imported_return_elements = importer.import_graph_def(
        input_graph_def,
        input_map: input_map,
        return_elements: return_elements,
        name: scope_to_prepend_to_names,
        producer_op_list: producer_op_list);

    // Restores all the other collections.
    var variable_objects = new Dictionary<string, RefVariable>();
    foreach (var col in meta_graph_def.CollectionDef.OrderBy(entry => entry.Key))
    {
        // Don't add unbound_inputs to the new graph.
        if (col.Key == unbound_inputs_col_name)
        {
            continue;
        }

        var kind = col.Value.KindCase;
        if (kind == KindOneofCase.NodeList)
        {
            foreach (var node_name in col.Value.NodeList.Value)
            {
                var scoped_name = ops.prepend_name_scope(node_name, scope_to_prepend_to_names);
                var col_op = graph.as_graph_element(scoped_name);
                graph.add_to_collection(col.Key, col_op);
            }
        }
        else if (kind == KindOneofCase.BytesList)
        {
            if (!ops.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
            {
                throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
            }

            // Restoring variables is not implemented: the first entry is parsed and then the
            // method throws. An empty list falls through without throwing.
            foreach (var serialized in col.Value.BytesList.Value)
            {
                var variable_def = VariableDef.Parser.ParseFrom(serialized);
                throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
            }
        }
    }

    return (null, null);
}