/// <summary>
/// Loads the frozen graph from <c>modelDir</c>/<c>frozen_graph</c>, runs it through a
/// graph-transform / constant-folding / dropout-removal pipeline, and writes the result to
/// <c>modelDir</c>/<c>output_graph</c>.
/// </summary>
/// <returns>Always <c>null</c>; the optimized graph is written to disk rather than returned.</returns>
public override Graph BuildGraph()
        {
            // Node names and transform pipeline shared by both transform_graph passes below,
            // so the two passes cannot drift apart.
            string[] inputs     = { "Placeholder" };
            string[] outputs    = { "Score" };
            string[] transforms =
            {
                "remove_nodes(op=PlaceholderWithDefault)",
                "strip_unused_nodes(type=float, shape=\"1,28,28,1\")",
                "remove_nodes(op=Identity, op=CheckNumerics, op=Switch)",
                "fold_constants(ignore_errors=true)",
                "fold_batch_norms",
                "fold_old_batch_norms",
                "sort_by_execution_order"
            };

            string pb        = Path.Combine(modelDir, frozen_graph);
            var    graph_def = GraphDef.Parser.ParseFrom(File.ReadAllBytes(pb));

            // transform_graph (first pass)
            graph_def = tf.graph_transforms.TransformGraph(graph_def, inputs, outputs, transforms);

            // convert_to_constant: nodes named keep_prob / weight_factor / is_training are
            // replaced by the NodeDefs of scalar constants (1.0f, 1.0f, false).
            var keep_prob     = tf.constant(1.0f, dtype: tf.float32, shape: new int[0], name: "keep_prob");
            var weight_factor = tf.constant(1.0f, dtype: tf.float32, shape: new int[0], name: "weight_factor");
            var is_training   = tf.constant(false, dtype: tf.@bool, shape: new int[0], name: "is_training");

            var new_graph_def = new GraphDef();
            foreach (var node in graph_def.Node)
            {
                switch (node.Name)
                {
                case "keep_prob":
                    new_graph_def.Node.Add(keep_prob.op.node_def);
                    break;

                case "weight_factor":
                    new_graph_def.Node.Add(weight_factor.op.node_def);
                    break;

                case "is_training":
                    new_graph_def.Node.Add(is_training.op.node_def);
                    break;

                default:
                    new_graph_def.Node.Add(node.Clone());
                    break;
                }
            }

            // optimize_batch_normalization: placeholder pass. No rewrite is implemented yet,
            // so every node is copied unchanged.
            graph_def     = new_graph_def;
            new_graph_def = new GraphDef();
            foreach (var node in graph_def.Node)
            {
                // TODO: rewrite conv* / fc* / logits* batch-norm nodes here.
                new_graph_def.Node.Add(node.Clone());
            }

            // transform_graph (second pass). It consumes new_graph_def, the output of the pass
            // above, so that pass's result is not thrown away.
            new_graph_def = tf.graph_transforms.TransformGraph(new_graph_def, inputs, outputs, transforms);

            // remove_dropout: drop every node whose name starts with dropout1 / dropout2 and
            // point two named batch-norm mul_1 nodes at fixed replacement inputs.
            graph_def     = new_graph_def;
            new_graph_def = new GraphDef();
            foreach (var node in graph_def.Node)
            {
                // Node names are identifiers: compare ordinally, not culture-sensitively.
                if (node.Name.StartsWith("dropout1", System.StringComparison.Ordinal) ||
                    node.Name.StartsWith("dropout2", System.StringComparison.Ordinal))
                {
                    continue;
                }

                var modified_node = node.Clone();

                if (node.Name == "fc2/fc2/batch_norm/batchnorm/mul_1")
                {
                    modified_node.Input[0] = "mul";
                    modified_node.Input[1] = "fc2/weights";
                }

                if (node.Name == "logits/logits/batch_norm/batchnorm/mul_1")
                {
                    modified_node.Input[0] = "fc2/activation";
                    modified_node.Input[1] = "logits/weights";
                }

                new_graph_def.Node.Add(modified_node);
            }

            // save the graph
            string output_pb = Path.Combine(modelDir, output_graph);
            File.WriteAllBytes(output_pb, new_graph_def.ToByteArray());
            return null;
        }
Ejemplo n.º 2
0
 /// <summary>
 /// Intended to dump a GraphDef object into a JSON file.
 /// </summary>
 /// <remarks>
 /// NOTE(review): the body is currently empty, so this method is a no-op. It does not read
 /// <paramref name="datasource"/> and does not create or write <paramref name="JsonFilePath"/>.
 /// TODO: implement or remove. If GraphDef is a Google.Protobuf message, its JSON formatter
 /// may be usable here; confirm which protobuf library generated GraphDef first.
 /// </remarks>
 /// <param name="datasource">GraphDef object to dump (currently ignored).</param>
 /// <param name="JsonFilePath">Destination path of the JSON file (currently ignored).</param>
 void JsonDump(GraphDef datasource, string JsonFilePath)
 {
     // TODO: not implemented; intentionally left empty for now.
 }
Ejemplo n.º 3
0
        /// <summary>
        /// Scratch entry point for TensorFlow / protobuf experiments.
        /// </summary>
        /// <remarks>
        /// Steps:
        /// 1. Create and dispose an empty TFGraph.
        /// 2. Round-trip "Hello, world!" through a TFBuffer and print it.
        /// 3. Create (truncate) <c>outFile</c>.
        /// 4. Read <c>inFile</c> and deserialize it as a GraphDef.
        /// Nothing is written to <c>outFile</c>, and the deserialized graph is not used yet.
        /// All streams and native buffers are disposed even if a step throws.
        /// </remarks>
        /// <param name="args">Command-line arguments (unused).</param>
        static void Main(string[] args)
        {
            // Graph loading (disabled): importing saved_model.pb via graph.Import is turned off,
            // so only an empty graph is created and disposed here.
            using (new TFGraph())
            {
            }

            // TFBuffer round-trip: copy UTF-8 bytes into a native buffer, read them back, print.
            var hello = Encoding.UTF8.GetBytes("Hello, world!");
            using (var buffer = new TFBuffer(hello))
            {
                var bytes = buffer.ToArray();
                Console.WriteLine(Encoding.UTF8.GetString(bytes));
            }

            // Proto decoding: the output file is created before the input is read, as before.
            // The using blocks guarantee the handles are released if reading or parsing throws.
            using (var fs = new FileStream(outFile, FileMode.Create))
            {
                byte[] data = File.ReadAllBytes(inFile);
                using (var ms1 = new MemoryStream(data))
                {
                    // The parsed GraphDef is not consumed yet; the call only exercises decoding.
                    Deserialize<GraphDef>(ms1);
                }

                fs.Flush();
            }
        }
Ejemplo n.º 4
0
 /// <summary>
 /// Intended to dump a GraphDef object into a JSON file.
 /// </summary>
 /// <remarks>
 /// NOTE(review): the body is currently empty, so this method is a no-op. It does not read
 /// <paramref name="data"/> and does not create or write <paramref name="jsonpath"/>.
 /// TODO: implement or remove. If GraphDef is a Google.Protobuf message, its JSON formatter
 /// may be usable here; confirm which protobuf library generated GraphDef first.
 /// </remarks>
 /// <param name="data">GraphDef object to dump (currently ignored).</param>
 /// <param name="jsonpath">Destination path of the JSON file (currently ignored).</param>
 public static void DumpJson(GraphDef data, string jsonpath)
 {
     // TODO: not implemented; intentionally left empty for now.
 }