Пример #1
0
        /// <summary>
        /// Computes <paramref name="a"/> + <paramref name="b"/> by building a graph with two Int32
        /// placeholders ("valA", "valB") feeding an Add op ("sum") and running it in a new session.
        /// Before running, it also exports the graph's VersionDef/GraphDef, re-parses them with protobuf,
        /// compares the re-serialized GraphDef bytes with the exported ones and reads each operation's
        /// device; those intermediate results are not used.
        /// </summary>
        /// <param name="a">First addend.</param>
        /// <param name="b">Second addend.</param>
        /// <param name="sessionOptions">Options for the created session; may be null.</param>
        /// <returns>The first element of the "sum" output tensor.</returns>
        public static int Add(int a, int b, SessionOptions sessionOptions = null)
        {
            // Tensors, graph and session wrap native resources; dispose them deterministically
            // instead of leaving them to finalization.
            using (Tensor tensorA = new Tensor(a))
            using (Tensor tensorB = new Tensor(b))
            using (Graph graph = new Graph())
            {
                // Two placeholders and the op that adds them.
                Operation opA = graph.Placeholder(DataType.Int32, null, "valA");
                Operation opB = graph.Placeholder(DataType.Int32, null, "valB");
                Operation sumOp = graph.Add(opA, opB, "sum");

                using (Buffer versionDef = new Buffer())
                using (Buffer graphDef = new Buffer())
                {
                    graph.Versions(versionDef);

                    Tensorflow.VersionDef vdef = Tensorflow.VersionDef.Parser.ParseFrom(versionDef.Data);

                    graph.ToGraphDef(graphDef);

                    Tensorflow.GraphDef gdef = Tensorflow.GraphDef.Parser.ParseFrom(graphDef.Data);

                    using (MemoryStream ms = new MemoryStream())
                    using (Google.Protobuf.CodedOutputStream stream = new CodedOutputStream(ms))
                    {
                        gdef.WriteTo(stream);
                        stream.Flush();
                        byte[] serializedGraphDef2 = ms.ToArray();

                        byte[] serializedGraphDef1 = graphDef.Data;

                        // Compare lengths first: walking serializedGraphDef1's indices over
                        // serializedGraphDef2 would throw IndexOutOfRangeException if the
                        // re-serialized bytes were shorter.
                        bool equal = serializedGraphDef1.Length == serializedGraphDef2.Length;
                        for (int i = 0; equal && i < serializedGraphDef1.Length; i++)
                        {
                            if (serializedGraphDef1[i] != serializedGraphDef2[i])
                            {
                                equal = false;
                            }
                        }
                    }

                    foreach (Operation op in graph)
                    {
                        String device = op.Device;
                    }
                }

                using (Session session = new Session(graph, sessionOptions))
                {
                    Session.Device[] devices = session.ListDevices();
                    Tensor[] results = session.Run(new Output[] { opA, opB }, new Tensor[] { tensorA, tensorB }, new Output[] { sumOp });
                    try
                    {
                        // Read the value out before the result tensors are released.
                        return results[0].Flat<int>()[0];
                    }
                    finally
                    {
                        foreach (Tensor result in results)
                        {
                            result?.Dispose();
                        }
                    }
                }
            }
        }
Пример #2
0
        /// <summary>
        /// Produces a <c>MetaGraphDef</c> by delegating to <c>meta_graph.export_scoped_meta_graph</c>
        /// and returning only the first item (<c>Item1</c>) of that call's result.
        /// </summary>
        /// <remarks>
        /// NOTE(review): <paramref name="collection_list"/> and <paramref name="export_scope"/> are
        /// accepted but never forwarded, so they currently have no effect.
        /// </remarks>
        public MetaGraphDef export_meta_graph(string filename              = "",
                                              byte[] meta_info_def         = null,
                                              GraphDef graph_def           = null,
                                              SaverDef saver_def           = null,
                                              string[] collection_list     = null,
                                              bool as_text                 = false,
                                              bool clear_devices           = false,
                                              bool clear_extraneous_savers = false,
                                              bool strip_default_attrs     = false,
                                              string export_scope          = "")
        {
            // collection_list is intentionally not passed along (see remarks).
            var exported = meta_graph.export_scoped_meta_graph(
                filename: filename,
                graph_def: graph_def,
                saver_def: saver_def,
                meta_info_def: meta_info_def,
                as_text: as_text,
                clear_devices: clear_devices,
                clear_extraneous_savers: clear_extraneous_savers,
                strip_default_attrs: strip_default_attrs);

            MetaGraphDef metaGraph = exported.Item1;
            return metaGraph;
        }
Пример #3
0
        /// <summary>
        /// Collects the default graph's global variables and builds a <c>MetaGraphDef</c> through
        /// <c>create_meta_graph_def</c>; when <paramref name="filename"/> is non-empty the result is
        /// also written with <c>graph_io.write_graph</c>.
        /// </summary>
        /// <param name="filename">Output path; nothing is written when null or empty.</param>
        /// <param name="graph_def">Forwarded to <c>create_meta_graph_def</c>.</param>
        /// <param name="as_text">Forwarded to <c>graph_io.write_graph</c>.</param>
        /// <param name="unbound_inputs_col_name">NOTE(review): not referenced by this implementation.</param>
        /// <param name="clear_devices">NOTE(review): not referenced by this implementation.</param>
        /// <param name="saver_def">Forwarded to <c>create_meta_graph_def</c>.</param>
        /// <param name="clear_extraneous_savers">Forwarded to <c>create_meta_graph_def</c>.</param>
        /// <param name="strip_default_attrs">Forwarded to <c>create_meta_graph_def</c>.</param>
        /// <param name="meta_info_def">NOTE(review): not referenced by this implementation.</param>
        /// <returns>
        /// The built meta graph def, plus a map from variable name to variable for every entry of the
        /// default graph's GLOBAL_VARIABLES collection (empty when the collection lookup yields null).
        /// </returns>
        public static (MetaGraphDef, Dictionary <string, IVariableV1>) export_scoped_meta_graph(string filename                = "",
                                                                                                GraphDef graph_def             = null,
                                                                                                bool as_text                   = false,
                                                                                                string unbound_inputs_col_name = "unbound_inputs",
                                                                                                bool clear_devices             = false,
                                                                                                SaverDef saver_def             = null,
                                                                                                bool clear_extraneous_savers   = false,
                                                                                                bool strip_default_attrs       = false,
                                                                                                byte[] meta_info_def           = null)
        {
            var defaultGraph = ops.get_default_graph();
            var globals      = defaultGraph.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES);

            // Index variables by name; a later variable with the same name replaces an earlier one.
            var variablesByName = new Dictionary<string, IVariableV1>();
            if (globals != null)
            {
                foreach (var variable in globals)
                {
                    variablesByName[variable.Name] = variable;
                }
            }

            var metaGraphDef = create_meta_graph_def(
                saver_def: saver_def,
                graph_def: graph_def,
                exclude_nodes: "",
                export_scope: "",
                strip_default_attrs: strip_default_attrs,
                clear_extraneous_savers: clear_extraneous_savers);

            bool shouldWrite = !string.IsNullOrEmpty(filename);
            if (shouldWrite)
            {
                graph_io.write_graph(metaGraphDef, "", filename, as_text: as_text);
            }

            return (metaGraphDef, variablesByName);
        }
Пример #4
0
 /// <summary>
 /// Thin wrapper that forwards every argument unchanged to <c>importer.import_graph_def</c>
 /// and returns its result.
 /// </summary>
 /// <param name="graph_def">Graph definition to import.</param>
 /// <param name="input_map">Optional map passed through as-is; may be null.</param>
 /// <param name="return_elements">Optional element names passed through as-is; may be null.</param>
 /// <param name="name">Optional name passed through as-is; may be null.</param>
 /// <param name="producer_op_list">Optional producer op list passed through as-is; may be null.</param>
 /// <returns>Whatever <c>importer.import_graph_def</c> returns.</returns>
 public ITensorOrOperation[] import_graph_def(GraphDef graph_def,
                                              Dictionary <string, Tensor> input_map = null,
                                              string[] return_elements = null,
                                              string name             = null,
                                              OpList producer_op_list = null)
 {
     var imported = importer.import_graph_def(graph_def, input_map, return_elements, name, producer_op_list);
     return imported;
 }
Пример #5
0
        /// <summary>
        /// Indexes the nodes of <paramref name="graph_def"/> by their name as normalized by <c>_node_name</c>.
        /// </summary>
        /// <param name="graph_def">Graph whose <c>Node</c> list is scanned in order.</param>
        /// <returns>
        /// Three maps keyed by normalized node name: the node's inputs (each normalized with
        /// <c>_node_name</c>), the <c>NodeDef</c> itself, and the node's zero-based position in
        /// <c>graph_def.Node</c>. If two nodes normalize to the same name, the later one wins.
        /// </returns>
        private (Dictionary <string, string[]>, Dictionary <string, NodeDef>, Dictionary <string, int>) _extract_graph_summary(GraphDef graph_def)
        {
            var inputsByName   = new Dictionary<string, string[]>();
            var nodesByName    = new Dictionary<string, NodeDef>();
            var positionByName = new Dictionary<string, int>();

            var position = 0;
            foreach (var nodeDef in graph_def.Node)
            {
                var key = _node_name(nodeDef.Name);

                nodesByName[key]    = nodeDef;
                inputsByName[key]   = nodeDef.Input.Select(input => _node_name(input)).ToArray();
                positionByName[key] = position++;
            }

            return (inputsByName, nodesByName, positionByName);
        }
Пример #6
0
        /// <summary>
        /// Replaces all the variables in a graph with constants of the same values.
        /// </summary>
        /// <param name="sess">Active TensorFlow session containing the variables.</param>
        /// <param name="input_graph_def">GraphDef object holding the network.</param>
        /// <param name="output_node_names">List of name strings for the result nodes of the graph.</param>
        /// <param name="variable_names_whitelist">NOTE(review): currently not referenced by this implementation.</param>
        /// <param name="variable_names_blacklist">NOTE(review): currently not referenced by this implementation.</param>
        /// <returns>GraphDef containing a simplified version of the original.</returns>
        /// <exception cref="NotImplementedException">
        /// Thrown when a ReadVariableOp/ResourceGather reads its variable through an Identity op, or when the
        /// graph contains a ResourceGather op; neither case is converted yet.
        /// </exception>
        public GraphDef convert_variables_to_constants(Session sess,
                                                       GraphDef input_graph_def,
                                                       string[] output_node_names,
                                                       string[] variable_names_whitelist = null,
                                                       string[] variable_names_blacklist = null)
        {
            // This graph only includes the nodes needed to evaluate the output nodes, and
            // removes unneeded nodes like those involved in saving and assignment.
            var inference_graph = extract_sub_graph(input_graph_def, output_node_names);

            // Identify the ops in the graph.
            var map_name_to_node = new Dictionary<string, NodeDef>();
            foreach (var node in inference_graph.Node)
            {
                map_name_to_node[node.Name] = node;
            }

            // Op types that hold variable state / read resource variables (hoisted out of the loop).
            var variable_ops      = new[] { "Variable", "VariableV2", "VarHandleOp" };
            var resource_read_ops = new[] { "ReadVariableOp", "ResourceGather" };

            // Get list of variables: tensor names to fetch, and the node names they belong to (same order).
            var variable_names      = new List<string>();
            var variable_dict_names = new List<string>();

            foreach (var node in inference_graph.Node)
            {
                if (variable_ops.Contains(node.Op))
                {
                    var variable_name = node.Name;

                    variable_dict_names.Add(variable_name);
                    // Resource variables are fetched through their read op rather than the handle.
                    variable_names.Add(node.Op == "VarHandleOp"
                                       ? variable_name + "/Read/ReadVariableOp:0"
                                       : variable_name + ":0");
                }
                else if (resource_read_ops.Contains(node.Op))
                {
                    // There can be one or more Identity ops in between the ReadVariableOp and
                    // VarHandleOp. Tracking them (resource_identity_types) is not ported yet.
                    var source_op_name = get_input_name(node);
                    if (map_name_to_node[source_op_name].Op == "Identity")
                    {
                        throw new NotImplementedException("map_name_to_node[source_op_name].Op");
                    }
                }
            }

            // Gets map of variables and the associated data. Only run the session when there is
            // something to fetch; variable_names is never null, so the old null check always ran it.
            NDArray returned_variables = null;

            if (variable_names.Count > 0)
            {
                returned_variables = sess.run(variable_names);
            }

            var variables_data_map = new Dictionary<string, NDArray>();

            foreach (var (i, name) in enumerate(variable_dict_names))
            {
                variables_data_map[name] = returned_variables[i];
            }

            var frozen_count = returned_variables is null ? 0 : len(returned_variables);
            print($"Froze {frozen_count} variables.");

            // Reconstruct the graph with constants in place of variables.
            var output_graph_def   = new GraphDef();
            int how_many_converted = 0;

            foreach (var input_node in inference_graph.Node)
            {
                NodeDef output_node;
                if (variables_data_map.TryGetValue(input_node.Name, out var data))
                {
                    output_node = create_const_op(input_node.Name, input_node.Attr["dtype"],
                                                  data, data.shape);
                    how_many_converted += 1;
                }
                else if (input_node.Op == "ReadVariableOp")
                {
                    // The variable is now a constant, so reading it becomes a plain Identity.
                    output_node = new NodeDef
                    {
                        Op   = "Identity",
                        Name = input_node.Name,
                    };
                    output_node.Input.Add(input_node.Input[0]);
                    output_node.Attr["T"] = input_node.Attr["dtype"];
                }
                else if (input_node.Op == "ResourceGather")
                {
                    // Previously an empty, nameless NodeDef was emitted here, producing an invalid graph.
                    // TODO: port the GatherV2 (+ axis Const) rewrite from TensorFlow's Python implementation.
                    throw new NotImplementedException("ResourceGather");
                }
                else
                {
                    output_node = new NodeDef();
                    output_node.MergeFrom(input_node);
                }

                output_graph_def.Node.Add(output_node);
            }

            output_graph_def.Library = inference_graph.Library;
            print($"Converted {how_many_converted} variables to const ops.");
            return output_graph_def;
        }
Пример #7
0
        /// <summary>
        /// Removes node attributes that the consumer's op definitions do not declare, but only when the
        /// producer's op definition declares that attribute with a default value equal to the node's value.
        /// Only nodes whose op type appears in <paramref name="producer_op_list"/> are inspected.
        /// </summary>
        /// <param name="op_dict">
        /// Consumer op definitions keyed by op name. A KeyNotFoundException is thrown if an inspected node's
        /// op type is missing from it.
        /// </param>
        /// <param name="producer_op_list">Producer op definitions; when null nothing is changed.</param>
        /// <param name="graph_def">Graph whose node attribute maps are edited in place.</param>
        private static void _RemoveDefaultAttrs(Dictionary <string, OpDef> op_dict, OpList producer_op_list, GraphDef graph_def)
        {
            // Without producer op definitions there is nothing to compare defaults against.
            if (producer_op_list == null)
            {
                return;
            }

            var producer_op_dict = new Dictionary<string, OpDef>();
            foreach (var op in producer_op_list.Op)
            {
                producer_op_dict[op.Name] = op;
            }

            foreach (var node in graph_def.Node)
            {
                // Remove any default attr values that aren't in op_def.
                if (!producer_op_dict.TryGetValue(node.Op, out var producer_op_def))
                {
                    continue;
                }

                var op_def = op_dict[node.Op];

                // Snapshot the keys: entries are removed from node.Attr inside the loop.
                var attr_names = new List<string>(node.Attr.Keys);
                foreach (var attr_name in attr_names)
                {
                    // Attributes the consumer knows about are left untouched.
                    if (_FindAttrInOpDef(attr_name, op_def) != null)
                    {
                        continue;
                    }

                    var attr_def = _FindAttrInOpDef(attr_name, producer_op_def);
                    // Equals gives protobuf value equality; `==` on generated messages compares references.
                    if (attr_def != null && attr_def.DefaultValue != null &&
                        node.Attr[attr_name].Equals(attr_def.DefaultValue))
                    {
                        // Drop the key entirely; clearing the value would leave an unknown attr behind.
                        node.Attr.Remove(attr_name);
                    }
                }
            }
        }