/// <summary>
/// Loads the frozen graph at <c>Path.Combine(modelDir, frozen_graph)</c>, rewrites it
/// for inference and writes the result to <c>Path.Combine(modelDir, output_graph)</c>.
/// </summary>
/// <remarks>
/// Pipeline:
/// 1. graph transforms (strip/fold, see <see cref="ApplyInferenceTransforms"/>);
/// 2. replace the <c>keep_prob</c>, <c>weight_factor</c> and <c>is_training</c> nodes with constants;
/// 3. graph transforms again on the constant-substituted graph;
/// 4. drop the <c>dropout1*</c>/<c>dropout2*</c> nodes and rewire the two batch-norm
///    <c>mul_1</c> nodes listed in <see cref="RemoveDropout"/>.
/// </remarks>
/// <returns>Always <c>null</c>; the optimized graph is only written to disk.</returns>
public override Graph BuildGraph()
{
    string pb = Path.Combine(modelDir, frozen_graph);
    var graph_def = GraphDef.Parser.ParseFrom(File.ReadAllBytes(pb));

    // transform_graph
    graph_def = ApplyInferenceTransforms(graph_def);

    // convert_to_constant
    graph_def = ReplaceTrainingInputsWithConstants(graph_def);

    // optimize_batch_normalization: not implemented yet. A previous placeholder loop
    // copied every node unchanged and its output was never used. Per-layer rewrites
    // of the "conv*", "fc*" and "logits*" nodes would go here, and their result must
    // be what is passed to the transform pass below.

    // transform_graph again, now on the constant-substituted graph
    // (presumably so fold_constants can fold the new constants — confirm).
    graph_def = ApplyInferenceTransforms(graph_def);

    // remove_dropout
    var new_graph_def = RemoveDropout(graph_def);

    // save the graph
    string output_pb = Path.Combine(modelDir, output_graph);
    File.WriteAllBytes(output_pb, new_graph_def.ToByteArray());

    // NOTE(review): the signature promises a Graph, but this override only writes the
    // optimized GraphDef to disk; null is kept for compatibility with existing callers.
    return null;
}

/// <summary>
/// Runs the fixed graph-transform pipeline used by <see cref="BuildGraph"/>, with
/// input <c>Placeholder</c> and output <c>Score</c>.
/// </summary>
/// <param name="graph_def">Graph to transform.</param>
/// <returns>The transformed graph returned by <c>tf.graph_transforms.TransformGraph</c>.</returns>
private GraphDef ApplyInferenceTransforms(GraphDef graph_def)
{
    return tf.graph_transforms.TransformGraph(graph_def,
        new[] { "Placeholder" },
        new[] { "Score" },
        new[]
        {
            "remove_nodes(op=PlaceholderWithDefault)",
            "strip_unused_nodes(type=float, shape=\"1,28,28,1\")",
            "remove_nodes(op=Identity, op=CheckNumerics, op=Switch)",
            "fold_constants(ignore_errors=true)",
            "fold_batch_norms",
            "fold_old_batch_norms",
            "sort_by_execution_order"
        });
}

/// <summary>
/// Builds a new GraphDef from the nodes of <paramref name="graph_def"/>. The nodes named
/// <c>keep_prob</c> and <c>weight_factor</c> are replaced by scalar float constants 1.0f.
/// The node named <c>is_training</c> is replaced by a scalar bool constant <c>false</c>.
/// Every other node is cloned unchanged.
/// </summary>
/// <remarks>
/// Only the <c>Node</c> list is copied into the new GraphDef; its other fields are left at
/// their defaults. NOTE(review): confirm that dropping e.g. versions/library is intended.
/// </remarks>
/// <param name="graph_def">Source graph.</param>
/// <returns>A new GraphDef with the substitutions applied.</returns>
private GraphDef ReplaceTrainingInputsWithConstants(GraphDef graph_def)
{
    // The constants are created through the tf API only to obtain their NodeDefs.
    // NOTE(review): this presumably also adds them to tf's current default graph — confirm.
    var keep_prob = tf.constant(1.0f, dtype: tf.float32, shape: new int[0], name: "keep_prob");
    var weight_factor = tf.constant(1.0f, dtype: tf.float32, shape: new int[0], name: "weight_factor");
    var is_training = tf.constant(false, dtype: tf.@bool, shape: new int[0], name: "is_training");

    var new_graph_def = new GraphDef();
    foreach (var node in graph_def.Node)
    {
        switch (node.Name)
        {
            case "keep_prob":
                new_graph_def.Node.Add(keep_prob.op.node_def);
                break;
            case "weight_factor":
                new_graph_def.Node.Add(weight_factor.op.node_def);
                break;
            case "is_training":
                new_graph_def.Node.Add(is_training.op.node_def);
                break;
            default:
                new_graph_def.Node.Add(node.Clone());
                break;
        }
    }
    return new_graph_def;
}

/// <summary>
/// Returns a copy of <paramref name="graph_def"/> with every node whose name starts with
/// <c>dropout1</c> or <c>dropout2</c> removed. The two batch-norm <c>mul_1</c> nodes have
/// their first two inputs rewired to bypass the removed nodes.
/// </summary>
/// <param name="graph_def">Graph to strip.</param>
/// <returns>A new GraphDef containing only the <c>Node</c> list.</returns>
private GraphDef RemoveDropout(GraphDef graph_def)
{
    var new_graph_def = new GraphDef();
    foreach (var node in graph_def.Node)
    {
        // Node names are identifiers, so compare ordinally rather than by culture.
        if (node.Name.StartsWith("dropout1", StringComparison.Ordinal) ||
            node.Name.StartsWith("dropout2", StringComparison.Ordinal))
        {
            continue;
        }

        var modified_node = node.Clone();
        if (node.Name == "fc2/fc2/batch_norm/batchnorm/mul_1")
        {
            modified_node.Input[0] = "mul";
            modified_node.Input[1] = "fc2/weights";
        }
        if (node.Name == "logits/logits/batch_norm/batchnorm/mul_1")
        {
            modified_node.Input[0] = "fc2/activation";
            modified_node.Input[1] = "logits/weights";
        }
        new_graph_def.Node.Add(modified_node);
    }
    return new_graph_def;
}
/// <summary>
/// Dump a GraphDef object into a JSON file, using protobuf's JSON mapping
/// (<c>Google.Protobuf.JsonFormatter.Default</c>).
/// </summary>
/// <param name="datasource">GraphDef object to serialize.</param>
/// <param name="JsonFilePath">Path of the JSON file; it is created or overwritten.</param>
/// <exception cref="ArgumentNullException"><paramref name="datasource"/> or <paramref name="JsonFilePath"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="JsonFilePath"/> is empty or not a valid path.</exception>
void JsonDump(GraphDef datasource, string JsonFilePath)
{
    if (datasource == null)
    {
        throw new ArgumentNullException(nameof(datasource));
    }

    // Previously this method was an empty stub, so callers silently got no file.
    string json = Google.Protobuf.JsonFormatter.Default.Format(datasource);

    // File.WriteAllText validates the path (null/empty/invalid) and writes UTF-8.
    File.WriteAllText(JsonFilePath, json);
}
/// <summary>
/// Experimental entry point that:
/// 1. creates an (empty) TFGraph;
/// 2. round-trips a string through a TFBuffer;
/// 3. deserializes the GraphDef stored in <c>inFile</c>.
/// <c>outFile</c> is created (truncated) but nothing is written to it yet.
/// </summary>
/// <param name="args">Unused.</param>
static void Main(string[] args)
{
    // Graph loading: placeholder only; no model is imported yet.
    using (var graph = new TFGraph())
    {
        // TODO: import a model here, e.g. graph.Import(File.ReadAllBytes("saved_model.pb")).
    }

    // TFBuffer encode/decode round trip.
    var hello = Encoding.UTF8.GetBytes("Hello, world!");
    using (var buffer = new TFBuffer(hello))
    {
        var bytes = buffer.ToArray();
        Console.WriteLine(Encoding.UTF8.GetString(bytes));
    }

    // Earlier experiments (copying the file as text line by line, and decoding raw
    // bytes as Unicode) were abandoned in favour of protobuf decoding below.

    // Protobuf decoding. The output file is opened before the input is read, so it
    // is created/truncated even if reading fails. The `using` blocks guarantee the
    // handles are released on every path, including exceptions.
    using (var fs = new FileStream(outFile, FileMode.Create))
    using (var sw = new StreamWriter(fs))
    {
        byte[] data = File.ReadAllBytes(inFile);
        using (var ms1 = new MemoryStream(data))
        {
            // NOTE(review): the parsed graph is not used yet; parsing still surfaces
            // a corrupt input file as an exception.
            GraphDef parsedGraph = Deserialize<GraphDef>(ms1);
        }
        sw.Flush();
    }
}
/// <summary>
/// Dump a GraphDef into a JSON file, using protobuf's JSON mapping
/// (<c>Google.Protobuf.JsonFormatter.Default</c>).
/// </summary>
/// <param name="data">GraphDef object to serialize.</param>
/// <param name="jsonpath">Path of the JSON file; it is created or overwritten.</param>
/// <exception cref="ArgumentNullException"><paramref name="data"/> or <paramref name="jsonpath"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="jsonpath"/> is empty or not a valid path.</exception>
public static void DumpJson(GraphDef data, string jsonpath)
{
    if (data == null)
    {
        throw new ArgumentNullException(nameof(data));
    }

    // Previously this method was an empty stub, so callers silently got no file.
    string json = Google.Protobuf.JsonFormatter.Default.Format(data);

    // File.WriteAllText validates the path (null/empty/invalid) and writes UTF-8.
    File.WriteAllText(jsonpath, json);
}