// Example No. 1
// 0
        /// <summary>
        /// Removes node attributes that the consumer's op registry does not declare, but only
        /// when the attribute's value equals the default declared by the producer's op definition.
        /// Modeled on TensorFlow's Python importer: a graph written by a newer producer may carry
        /// default-valued attrs that this consumer's op definitions do not know about.
        /// </summary>
        /// <param name="op_dict">Consumer-side op registry, keyed by op name.</param>
        /// <param name="producer_op_list">Op definitions recorded by the producer of the graph.</param>
        /// <param name="graph_def">Graph whose node attributes are modified in place.</param>
        private static void _RemoveDefaultAttrs(Dictionary <string, OpDef> op_dict, OpList producer_op_list, GraphDef graph_def)
        {
            var producer_op_dict = new Dictionary <string, OpDef>();
            foreach (var op in producer_op_list.Op)
            {
                // Indexer assignment: a later definition with the same name replaces an earlier one.
                producer_op_dict[op.Name] = op;
            }

            foreach (var node in graph_def.Node)
            {
                // Only ops known to both the producer and the consumer can be reconciled;
                // ops missing from the consumer registry are left untouched instead of throwing.
                if (!producer_op_dict.TryGetValue(node.Op, out var producer_op_def) ||
                    !op_dict.TryGetValue(node.Op, out var op_def))
                {
                    continue;
                }

                // Snapshot the keys: entries are removed from node.Attr inside the loop.
                foreach (var key in node.Attr.Keys.ToList())
                {
                    // The consumer's op definition declares this attr, so keep it.
                    if (_FindAttrInOpDef(key, op_def) != null)
                    {
                        continue;
                    }

                    // Unknown to the consumer: drop it only if it holds the producer's default value.
                    var attr_def = _FindAttrInOpDef(key, producer_op_def);
                    // Protobuf messages compare contents via Equals; '==' would only compare
                    // references and therefore never match.
                    if (attr_def != null && attr_def.DefaultValue != null &&
                        attr_def.DefaultValue.Equals(node.Attr[key]))
                    {
                        node.Attr.Remove(key);
                    }
                }
            }
        }
 /// <summary>
 /// Registers every op definition in <paramref name="op_list"/>, in list order,
 /// by handing each one to <c>add_op</c>.
 /// </summary>
 /// <param name="op_list">The op definitions to register.</param>
 public void add_op_list(OpList op_list)
 {
     var definitions = op_list.Op;
     for (var index = 0; index < definitions.Count; index++)
     {
         add_op(definitions[index]);
     }
 }
// Example No. 3
// 0
        /// <summary>
        /// Imports the nodes of <paramref name="graph_def"/> into the current default graph
        /// through the native <c>TF_GraphImportGraphDefWithResults</c> call.
        /// </summary>
        /// <param name="graph_def">Graph definition to import.</param>
        /// <param name="input_map">Optional mapping from names in <paramref name="graph_def"/> to existing tensors.</param>
        /// <param name="return_elements">Optional names of ops/tensors to hand back after the import.</param>
        /// <param name="name">Scope name passed to <c>ops.name_scope</c>, with <c>"import"</c> as the default name.</param>
        /// <param name="producer_op_list">When non-null, used to strip producer-default attrs before importing.</param>
        /// <returns>
        /// The gathered return elements, or <c>null</c> when <paramref name="return_elements"/>
        /// is null after preprocessing.
        /// </returns>
        public static ITensorOrOperation[] import_graph_def(GraphDef graph_def,
                                                            Dictionary <string, Tensor> input_map = null,
                                                            string[] return_elements = null,
                                                            string name             = null,
                                                            OpList producer_op_list = null)
        {
            var registered = op_def_registry.get_registered_ops();

            // Normalize the caller-supplied arguments before touching the graph.
            graph_def = _ProcessGraphDefParam(graph_def, registered);
            input_map = _ProcessInputMapParam(input_map);
            return_elements = _ProcessReturnElementsParam(return_elements);

            if (producer_op_list != null)
            {
                _RemoveDefaultAttrs(registered, producer_op_list, graph_def);
            }

            var scope_prefix = "";
            var target_graph = ops.get_default_graph();

            tf_with(ops.name_scope(name, "import", input_map.Values), scope =>
            {
                // The scope opened here becomes the prefix handed to the native importer.
                scope_prefix = scope;

                // Generate any input map tensors inside name scope
                input_map = _ConvertInputMapValues(name, input_map);
            });

            TF_ImportGraphDefResults results = null;
            var serialized = graph_def.ToByteString().ToArray();

            using (var buffer = c_api_util.tf_buffer(serialized))
            {
                using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions())
                {
                    using (var status = new Status())
                    {
                        _PopulateTFImportGraphDefOptions(scoped_options, scope_prefix, input_map, return_elements);
                        var native_results = c_api.TF_GraphImportGraphDefWithResults(target_graph, buffer.Handle, scoped_options.Handle, status.Handle);
                        // TODO: wrap the results in a disposable ImportGraphDefWithResults type.
                        results = new TF_ImportGraphDefResults(native_results);
                        status.Check(true);
                    }
                }
            }

            _ProcessNewOps(target_graph);

            return return_elements == null
                ? null
                : _GatherReturnElements(return_elements, target_graph, results);
        }
// Example No. 4
// 0
        /// <summary>
        /// Builds the stripped op list for <paramref name="graph_def"/>. The list holds the
        /// registered <see cref="OpDef"/> of every op the graph uses, sorted by op name.
        /// </summary>
        /// <remarks>
        /// Ops that the local registry does not know are skipped rather than reported.
        /// NOTE(review): TensorFlow's Python version raises for unregistered ops (apart from a
        /// small allowlist of function-internal ops). Confirm whether that stricter check is wanted.
        /// Assumes <c>ops_used_by_graph_def</c> yields op names as a string sequence; confirm.
        /// </remarks>
        /// <param name="graph_def">Graph whose used ops are collected.</param>
        /// <returns>A new <see cref="OpList"/> with cloned op definitions.</returns>
        private static OpList stripped_op_list_for_graph(GraphDef graph_def)
        {
            var used_ops = ops_used_by_graph_def(graph_def);
            var registered_ops = op_def_registry.get_registered_ops();

            var op_list = new OpList();

            // Distinct guards against repeated names; the ordinal sort keeps output deterministic
            // and independent of the current culture.
            foreach (var op_name in used_ops.Distinct().OrderBy(x => x, StringComparer.Ordinal))
            {
                if (registered_ops.TryGetValue(op_name, out var op_def))
                {
                    // Clone so later edits to the exported list never reach the shared registry entry.
                    op_list.Op.Add(op_def.Clone());
                }
            }

            return op_list;
        }
// Example No. 5
// 0
        /// <summary>
        /// Imports the graph stored in <paramref name="meta_graph_or_file"/> into the current default
        /// graph and restores its collections, including variables and cond/while contexts.
        /// </summary>
        /// <param name="meta_graph_or_file">Meta graph to import.</param>
        /// <param name="clear_devices">When true, every node's device string is cleared before import.</param>
        /// <param name="import_scope">Scope to prefix imported names with; it is made unique in the graph.</param>
        /// <param name="input_map">Optional mapping from graph-def names to existing tensors.</param>
        /// <param name="unbound_inputs_col_name">Name of the unbound-inputs collection, which is never restored.</param>
        /// <param name="return_elements">Optional names of ops/tensors to return from the import.</param>
        /// <returns>
        /// The imported global variables keyed by name with the import scope stripped, plus whatever
        /// <c>importer.import_graph_def</c> returned for <paramref name="return_elements"/>.
        /// </returns>
        /// <exception cref="NotImplementedException">The meta graph contains the unbound-inputs collection.</exception>
        public static (Dictionary <string, IVariableV1>, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
                                                                                                                             bool clear_devices  = false,
                                                                                                                             string import_scope = "",
                                                                                                                             Dictionary <string, Tensor> input_map = null,
                                                                                                                             string unbound_inputs_col_name        = "unbound_inputs",
                                                                                                                             string[] return_elements = null)
        {
            var meta_graph_def = meta_graph_or_file;

            // Remapping unbound inputs is not supported yet.
            if (!string.IsNullOrEmpty(unbound_inputs_col_name) &&
                meta_graph_def.CollectionDef.ContainsKey(unbound_inputs_col_name))
            {
                throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
            }

            // Sets graph to default graph if it's not passed in.
            var graph = ops.get_default_graph();

            // meta_info_def may be absent; in that case there is no producer op list.
            var producer_op_list = meta_graph_def.MetaInfoDef?.StrippedOpList;
            var input_graph_def = meta_graph_def.GraphDef;

            // Remove all the explicit device specifications for this node. This helps to
            // make the graph more portable.
            if (clear_devices)
            {
                foreach (var node in input_graph_def.Node)
                {
                    node.Device = "";
                }
            }

            // Honor the caller's import_scope so node names, collection entries, control-flow
            // contexts and variables all share one prefix.
            var scope_to_prepend_to_names = graph.unique_name(import_scope ?? "", mark_as_used: false);
            var imported_return_elements  = importer.import_graph_def(input_graph_def,
                                                                      name: scope_to_prepend_to_names,
                                                                      input_map: input_map,
                                                                      producer_op_list: producer_op_list,
                                                                      return_elements: return_elements);

            // Restores all the other collections. Ordinal key order is culture-independent.
            var variable_objects = new Dictionary <ByteString, IVariableV1>();

            foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key, StringComparer.Ordinal))
            {
                // Don't add unbound_inputs to the new graph.
                if (col.Key == unbound_inputs_col_name)
                {
                    continue;
                }

                switch (col.Value.KindCase)
                {
                case KindOneofCase.NodeList:
                    foreach (var value in col.Value.NodeList.Value)
                    {
                        var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names));
                        graph.add_to_collection(col.Key, col_op);
                    }
                    break;

                case KindOneofCase.BytesList:
                    if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
                    {
                        foreach (var value in col.Value.BytesList.Value)
                        {
                            // The same serialized variable can appear in several collections;
                            // build it once and reuse the instance.
                            if (!variable_objects.TryGetValue(value, out var variable))
                            {
                                var proto = VariableDef.Parser.ParseFrom(value);
                                if (proto.IsResource)
                                {
                                    variable = new ResourceVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
                                }
                                else
                                {
                                    variable = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
                                }
                                variable_objects[value] = variable;
                            }
                            graph.add_to_collection(col.Key, variable);
                        }
                    }
                    else
                    {
                        foreach (var value in col.Value.BytesList.Value)
                        {
                            switch (col.Key)
                            {
                            case "cond_context":
                            {
                                var proto       = CondContextDef.Parser.ParseFrom(value);
                                // Same prefix the nodes were imported under, so the context resolves them.
                                var condContext = new CondContext().from_proto(proto, scope_to_prepend_to_names);
                                graph.add_to_collection(col.Key, condContext);
                            }
                            break;

                            case "while_context":
                            {
                                var proto        = WhileContextDef.Parser.ParseFrom(value);
                                var whileContext = new WhileContext().from_proto(proto, scope_to_prepend_to_names);
                                graph.add_to_collection(col.Key, whileContext);
                            }
                            break;

                            default:
                                Console.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}");
                                break;
                            }
                        }
                    }

                    break;

                default:
                    Console.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping.");
                    break;
                }
            }

            var variables = graph.get_collection <IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES,
                                                               scope: scope_to_prepend_to_names);
            var var_list = new Dictionary <string, IVariableV1>();

            variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v);

            return(var_list, imported_return_elements);
        }
// Example No. 6
// 0
 /// <summary>
 /// Instance-level entry point for graph import. It forwards every argument unchanged to
 /// <c>importer.import_graph_def</c> and returns that call's result.
 /// </summary>
 public ITensorOrOperation[] import_graph_def(GraphDef graph_def,
                                              Dictionary <string, Tensor> input_map = null,
                                              string[] return_elements = null,
                                              string name             = null,
                                              OpList producer_op_list = null)
 {
     return importer.import_graph_def(graph_def,
                                      input_map: input_map,
                                      return_elements: return_elements,
                                      name: name,
                                      producer_op_list: producer_op_list);
 }
// Example No. 7
// 0
        /// <summary>
        /// Imports the graph stored in <paramref name="meta_graph_or_file"/> into the current default
        /// graph and restores its node-list collections.
        /// </summary>
        /// <remarks>
        /// Partial implementation: restoring byte-list collections is not supported. The method
        /// always returns <c>(null, null)</c>, because the declared return type cannot carry the
        /// imported elements.
        /// </remarks>
        /// <param name="meta_graph_or_file">Meta graph to import.</param>
        /// <param name="clear_devices">When true, every node's device string is cleared before import.</param>
        /// <param name="import_scope">Scope to prefix imported names with; it is made unique in the graph.</param>
        /// <param name="input_map">Optional mapping from graph-def names to existing tensors.</param>
        /// <param name="unbound_inputs_col_name">Name of the unbound-inputs collection, which is never restored.</param>
        /// <param name="return_elements">Optional names of ops/tensors forwarded to the importer.</param>
        /// <returns>Always <c>(null, null)</c>.</returns>
        /// <exception cref="NotImplementedException">
        /// The meta graph has an unbound-inputs collection, a non-empty variable collection, or any
        /// other byte-list collection.
        /// </exception>
        public static (RefVariable[], string[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
                                                                                              bool clear_devices  = false,
                                                                                              string import_scope = "",
                                                                                              Dictionary <string, Tensor> input_map = null,
                                                                                              string unbound_inputs_col_name        = "unbound_inputs",
                                                                                              string[] return_elements = null)
        {
            var meta_graph_def = meta_graph_or_file;

            // Remapping unbound inputs is not supported yet.
            if (!string.IsNullOrEmpty(unbound_inputs_col_name) &&
                meta_graph_def.CollectionDef.ContainsKey(unbound_inputs_col_name))
            {
                throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
            }

            // Sets graph to default graph if it's not passed in.
            var graph = ops.get_default_graph();

            // meta_info_def may be absent; in that case there is no producer op list.
            var producer_op_list = meta_graph_def.MetaInfoDef?.StrippedOpList;
            var input_graph_def = meta_graph_def.GraphDef;

            // Remove all the explicit device specifications for this node. This helps to
            // make the graph more portable.
            if (clear_devices)
            {
                foreach (var node in input_graph_def.Node)
                {
                    node.Device = "";
                }
            }

            // Honor the caller's import_scope so nodes and collection entries share one prefix.
            var scope_to_prepend_to_names = graph.unique_name(import_scope ?? "", mark_as_used: false);

            // The returned elements are not used: the declared return type cannot carry them.
            importer.import_graph_def(input_graph_def,
                                      name: scope_to_prepend_to_names,
                                      input_map: input_map,
                                      producer_op_list: producer_op_list,
                                      return_elements: return_elements);

            // Restores all the other collections. Ordinal key order is culture-independent.
            foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key, StringComparer.Ordinal))
            {
                // Don't add unbound_inputs to the new graph.
                if (col.Key == unbound_inputs_col_name)
                {
                    continue;
                }

                switch (col.Value.KindCase)
                {
                case KindOneofCase.NodeList:
                    foreach (var value in col.Value.NodeList.Value)
                    {
                        var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names));
                        graph.add_to_collection(col.Key, col_op);
                    }
                    break;

                case KindOneofCase.BytesList:
                    // Variable restoration is unimplemented; an empty variable collection is harmless.
                    if (ops.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
                    {
                        if (col.Value.BytesList.Value.Count > 0)
                        {
                            throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
                        }
                    }
                    else
                    {
                        throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
                    }
                    break;

                default:
                    Console.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping.");
                    break;
                }
            }

            return(null, null);
        }