Example #1
0
        /// <summary>
        /// Serializes the graph collection named <paramref name="key"/> into
        /// <c>meta_graph_def.CollectionDef[key]</c>.
        /// RefVariable values are written as serialized protos into BytesList.
        /// ITensorOrOperation values are written as (scope-stripped) names into NodeList.
        /// Other values are ignored.
        /// </summary>
        /// <param name="meta_graph_def">MetaGraphDef whose CollectionDef map is filled in; the entry is created on demand.</param>
        /// <param name="key">Name of the graph collection to export.</param>
        /// <param name="graph">Graph to read the collection from; the default graph when null.</param>
        /// <param name="export_scope">Name scope stripped from exported names / passed to to_proto.</param>
        /// <remarks>
        /// CollectionDef keeps its values in a protobuf oneof (<c>Kind</c>), so only one of
        /// NodeList / BytesList can be populated at a time, and assigning one clears the other.
        /// The list is therefore created lazily for the kind of value actually encountered.
        /// NOTE(review): a collection mixing both kinds cannot be represented. Switching kinds
        /// discards values already added under the other kind.
        /// </remarks>
        private static void add_collection_def(MetaGraphDef meta_graph_def,
                                               string key,
                                               Graph graph         = null,
                                               string export_scope = "")
        {
            // The parameter defaults to null; use the default graph instead of dereferencing null.
            if (graph == null)
            {
                graph = ops.get_default_graph();
            }

            // Single lookup instead of ContainsKey + indexer.
            if (!meta_graph_def.CollectionDef.TryGetValue(key, out var col_def))
            {
                col_def = new CollectionDef();
                meta_graph_def.CollectionDef[key] = col_def;
            }

            // Start from an empty value list (the previous code re-assigned fresh lists here).
            col_def.ClearKind();

            foreach (object value in graph.get_collection(key))
            {
                switch (value)
                {
                case RefVariable ref_var:
                    if (col_def.BytesList == null)
                    {
                        col_def.BytesList = new Types.BytesList();
                    }
                    col_def.BytesList.Value.Add(ref_var.to_proto(export_scope).ToByteString());
                    break;

                case ITensorOrOperation node:
                    if (col_def.NodeList == null)
                    {
                        col_def.NodeList = new Types.NodeList();
                    }
                    col_def.NodeList.Value.Add(ops.strip_name_scope(node.name, export_scope));
                    break;

                default:
                    // Unsupported value types are skipped.
                    break;
                }
            }

            // Keep the previous shape for empty/unsupported collections: an empty BytesList.
            if (col_def.NodeList == null && col_def.BytesList == null)
            {
                col_def.BytesList = new Types.BytesList();
            }
        }
Example #2
0
 /// <summary>
 /// Return a saver for restoring variable values to an imported MetaGraph.
 /// </summary>
 /// <param name="meta_graph_def">Imported MetaGraph; when it carries a SaverDef, that SaverDef configures the returned saver.</param>
 /// <param name="import_scope">Scope used for the import; used as the saver name when no variables were imported.</param>
 /// <param name="imported_vars">Imported variables, keyed by name. May be null or empty.</param>
 /// <returns>A Saver, or null when there is no SaverDef and no saveable objects in <paramref name="import_scope"/>.</returns>
 public static Saver _create_saver_from_imported_meta_graph(MetaGraphDef meta_graph_def,
                                                            string import_scope,
                                                            Dictionary <string, VariableV1> imported_vars)
 {
     if (meta_graph_def.SaverDef != null)
     {
         // Infer the scope that is prepended by `import_scoped_meta_graph`.
         // Assumes each key is the variable's full name with the import scope prefix removed.
         // The scope is then the prefix that remains after cutting the key off the end of the
         // full name (Python: `sample_var.name[:-len(sample_key)]`).
         string scope = import_scope;
         if (imported_vars != null && imported_vars.Count > 0)
         {
             var sample        = imported_vars.First();
             var full_name     = sample.Value.name;
             var prefix_length = Math.Max(0, full_name.Length - sample.Key.Length);
             scope = full_name.Substring(0, prefix_length);
         }
         return(new Saver(saver_def: meta_graph_def.SaverDef, name: scope));
     }

     if (variables._all_saveable_objects(scope: import_scope).Length > 0)
     {
         // Return the default saver instance for all graph variables.
         return(new Saver());
     }

     // If no graph variables exist, then a Saver cannot be constructed.
     Console.WriteLine("Saver not created because there are no variables in the" +
                       " graph to restore");
     return(null);
 }
Example #3
0
        /// <summary>
        /// Creates a SavedModel proto that holds a single MetaGraphDef, then hands that
        /// MetaGraphDef to <c>_build_meta_graph</c> together with <paramref name="obj"/>,
        /// <paramref name="export_dir"/> and <paramref name="options"/>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the SavedModel is local to this method and is not written anywhere;
        /// <paramref name="export_dir"/> is only forwarded to <c>_build_meta_graph</c>.
        /// </remarks>
        public void save(Trackable obj, string export_dir, SaveOptions options = null)
        {
            var graph_proto = new MetaGraphDef();

            // The collection initializer on the repeated field is equivalent to MetaGraphs.Add(graph_proto).
            var model_proto = new SavedModel { MetaGraphs = { graph_proto } };

            _build_meta_graph(obj, export_dir, options, graph_proto);
        }
Example #4
0
        /// <summary>
        /// Builds a MetaGraphDef from the default graph: version info, GraphDef,
        /// stripped op list, optional SaverDef and every graph collection.
        /// </summary>
        /// <param name="meta_info_def">
        /// MetaInfoDef to embed. A new one is created when null. Its TensorflowVersion and
        /// TensorflowGitVersion fields are overwritten, so a caller-supplied instance is mutated.
        /// </param>
        /// <param name="graph_def">GraphDef to embed; when null the default graph is serialized with shapes.</param>
        /// <param name="export_scope">Name scope stripped from names written into collections.</param>
        /// <param name="exclude_nodes">Currently unused.</param>
        /// <param name="saver_def">SaverDef to embed; nothing is set when null.</param>
        /// <param name="clear_extraneous_savers">Not supported yet.</param>
        /// <param name="strip_default_attrs">Currently unused.</param>
        /// <returns>The populated MetaGraphDef.</returns>
        /// <exception cref="NotImplementedException">
        /// When <paramref name="clear_extraneous_savers"/> is true and the graph has at least one collection.
        /// </exception>
        private static MetaGraphDef create_meta_graph_def(MetaInfoDef meta_info_def    = null,
                                                          GraphDef graph_def           = null,
                                                          string export_scope          = "",
                                                          string exclude_nodes         = "",
                                                          SaverDef saver_def           = null,
                                                          bool clear_extraneous_savers = false,
                                                          bool strip_default_attrs     = false)
        {
            // Always works on the default graph (there is no graph parameter).
            // NOTE(review): as_default() is called on the graph that is already the default;
            // confirm it does not push a graph-stack entry that is never popped.
            var graph = ops.get_default_graph().as_default();
            // Creates a MetaGraphDef proto.
            var meta_graph_def = new MetaGraphDef();

            if (meta_info_def == null)
            {
                meta_info_def = new MetaInfoDef();
            }

            // Set the tf version strings to the current tf build.
            meta_info_def.TensorflowVersion    = tf.VERSION;
            meta_info_def.TensorflowGitVersion = "unknown";
            meta_graph_def.MetaInfoDef         = meta_info_def;

            // Adds graph_def or the default.
            meta_graph_def.GraphDef = graph_def ?? graph.as_graph_def(add_shapes: true);

            // Fills in meta_info_def.stripped_op_list using the ops from graph_def.
            if (meta_graph_def.MetaInfoDef.StrippedOpList == null ||
                meta_graph_def.MetaInfoDef.StrippedOpList.Op.Count == 0)
            {
                meta_graph_def.MetaInfoDef.StrippedOpList = stripped_op_list_for_graph(meta_graph_def.GraphDef);
            }

            // Adds saver_def if supplied (it was previously dropped silently).
            if (saver_def != null)
            {
                meta_graph_def.SaverDef = saver_def;
            }

            foreach (var ctype in graph.get_all_collection_keys())
            {
                if (clear_extraneous_savers)
                {
                    throw new NotImplementedException("create_meta_graph_def clear_extraneous_savers");
                }

                // Forward export_scope so collection entries are scope-stripped like the rest of the export.
                add_collection_def(meta_graph_def, ctype, graph, export_scope);
            }

            return(meta_graph_def);
        }
Example #5
0
        /// <summary>
        /// Serializes the graph collection named <paramref name="key"/> into
        /// <c>meta_graph_def.CollectionDef[key]</c>. The list kind is chosen from the
        /// runtime type of the collection.
        /// </summary>
        /// <param name="meta_graph_def">MetaGraphDef whose CollectionDef map is filled in; the entry is created on demand.</param>
        /// <param name="key">Name of the graph collection to export.</param>
        /// <param name="graph">Graph to read the collection from; the default graph when null.</param>
        /// <param name="export_scope">Name scope stripped from exported names / passed to to_proto.</param>
        /// <remarks>
        /// <c>List&lt;IVariableV1&gt;</c> and <c>List&lt;RefVariable&gt;</c> produce a BytesList of
        /// serialized protos. <c>List&lt;object&gt;</c> produces a NodeList of scope-stripped names for
        /// its ITensorOrOperation items. For <c>List&lt;Operation&gt;</c> and any other collection type,
        /// the CollectionDef entry is left as created, with no values.
        /// </remarks>
        private static void add_collection_def(MetaGraphDef meta_graph_def,
                                               string key,
                                               Graph graph         = null,
                                               string export_scope = "")
        {
            // The parameter defaults to null; use the default graph instead of dereferencing null.
            if (graph == null)
            {
                graph = ops.get_default_graph();
            }

            // Single lookup instead of ContainsKey + indexer.
            if (!meta_graph_def.CollectionDef.TryGetValue(key, out var col_def))
            {
                col_def = new CollectionDef();
                meta_graph_def.CollectionDef[key] = col_def;
            }

            switch (graph.get_collection(key))
            {
            case List <IVariableV1> variable_list:
                col_def.BytesList = new Types.BytesList();
                foreach (var x in variable_list)
                {
                    // Only RefVariable and ResourceVariable are serialized; other
                    // IVariableV1 implementations are skipped.
                    if (x is RefVariable x_ref_var)
                    {
                        col_def.BytesList.Value.Add(x_ref_var.to_proto(export_scope).ToByteString());
                    }
                    else if (x is ResourceVariable x_res_var)
                    {
                        col_def.BytesList.Value.Add(x_res_var.to_proto(export_scope).ToByteString());
                    }
                }
                break;

            case List <RefVariable> ref_variable_list:
                col_def.BytesList = new Types.BytesList();
                foreach (var x in ref_variable_list)
                {
                    col_def.BytesList.Value.Add(x.to_proto(export_scope).ToByteString());
                }
                break;

            case List <object> object_list:
                col_def.NodeList = new Types.NodeList();
                foreach (var x in object_list)
                {
                    // Non-graph-element items are skipped.
                    if (x is ITensorOrOperation x2)
                    {
                        col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope));
                    }
                }
                break;

            case List <Operation> _:
                // Operation collections are not serialized here; the entry stays empty.
                break;
            }
        }
Example #6
0
        /// <summary>
        /// Writes <paramref name="graph_def"/> to <c>Path.Combine(logdir, name)</c>, either as
        /// protobuf text (<c>ToString()</c>) or as binary wire format (<c>ToByteArray()</c>).
        /// </summary>
        /// <param name="graph_def">MetaGraphDef to write.</param>
        /// <param name="logdir">Target directory; created if it does not exist yet.</param>
        /// <param name="name">File name inside <paramref name="logdir"/>.</param>
        /// <param name="as_text">True for text format, false for binary.</param>
        /// <returns>The path of the written file.</returns>
        public static string write_graph(MetaGraphDef graph_def, string logdir, string name, bool as_text = true)
        {
            // Create the directory first (as TensorFlow's write_graph does) so that a fresh
            // log directory does not cause a DirectoryNotFoundException. This is a no-op if it exists.
            if (!string.IsNullOrEmpty(logdir))
            {
                Directory.CreateDirectory(logdir);
            }

            string path = Path.Combine(logdir, name);

            if (as_text)
            {
                File.WriteAllText(path, graph_def.ToString());
            }
            else
            {
                File.WriteAllBytes(path, graph_def.ToByteArray());
            }
            return(path);
        }
Example #7
0
        /// <summary>
        /// Imports the GraphDef of <paramref name="meta_graph_or_file"/> into the default graph
        /// under a unique name scope derived from <paramref name="import_scope"/>, then restores
        /// its collections.
        /// </summary>
        /// <param name="meta_graph_or_file">MetaGraphDef to import. When <paramref name="clear_devices"/> is set, its GraphDef nodes are modified in place.</param>
        /// <param name="clear_devices">Clear every node's Device before importing.</param>
        /// <param name="import_scope">Requested name scope for imported names; uniquified against the graph.</param>
        /// <param name="input_map">Forwarded to importer.import_graph_def.</param>
        /// <param name="unbound_inputs_col_name">Collection that is never re-added to the graph.</param>
        /// <param name="return_elements">Forwarded to importer.import_graph_def.</param>
        /// <returns>
        /// The imported GLOBAL_VARIABLES keyed by name with the import scope stripped, and the
        /// elements returned by import_graph_def.
        /// </returns>
        /// <exception cref="NotImplementedException">When the unbound-inputs collection is present.</exception>
        public static (Dictionary <string, IVariableV1>, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
                                                                                                                             bool clear_devices  = false,
                                                                                                                             string import_scope = "",
                                                                                                                             Dictionary <string, Tensor> input_map = null,
                                                                                                                             string unbound_inputs_col_name        = "unbound_inputs",
                                                                                                                             string[] return_elements = null)
        {
            var meta_graph_def = meta_graph_or_file;

            if (!string.IsNullOrEmpty(unbound_inputs_col_name) &&
                meta_graph_def.CollectionDef.ContainsKey(unbound_inputs_col_name))
            {
                throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
            }

            // Always imports into the default graph.
            var graph = ops.get_default_graph();

            // Gathers the list of nodes we are interested in. MetaInfoDef is an optional message
            // and may be null.
            OpList producer_op_list = meta_graph_def.MetaInfoDef?.StrippedOpList;
            var input_graph_def = meta_graph_def.GraphDef;

            // Remove all the explicit device specifications for this node. This helps to
            // make the graph more portable.
            if (clear_devices)
            {
                foreach (var node in input_graph_def.Node)
                {
                    node.Device = "";
                }
            }

            // Derive the scope from import_scope, as the Python implementation does. It was
            // previously always "", which ignored import_scope for imported names.
            var scope_to_prepend_to_names = graph.unique_name(import_scope ?? "", mark_as_used: false);
            var imported_return_elements  = importer.import_graph_def(input_graph_def,
                                                                      name: scope_to_prepend_to_names,
                                                                      input_map: input_map,
                                                                      producer_op_list: producer_op_list,
                                                                      return_elements: return_elements);

            // Restores all the other collections. Serialized variables shared between collections
            // map to a single variable object.
            var variable_objects = new Dictionary <ByteString, IVariableV1>();

            foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key))
            {
                // Don't add unbound_inputs to the new graph.
                if (col.Key == unbound_inputs_col_name)
                {
                    continue;
                }

                switch (col.Value.KindCase)
                {
                case KindOneofCase.NodeList:
                    foreach (var value in col.Value.NodeList.Value)
                    {
                        var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names));
                        graph.add_to_collection(col.Key, col_op);
                    }
                    break;

                case KindOneofCase.BytesList:
                    if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
                    {
                        foreach (var value in col.Value.BytesList.Value)
                        {
                            // Single lookup instead of ContainsKey + indexer.
                            if (!variable_objects.TryGetValue(value, out IVariableV1 variable))
                            {
                                var proto = VariableDef.Parser.ParseFrom(value);
                                if (proto.IsResource)
                                {
                                    variable = new ResourceVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
                                }
                                else
                                {
                                    variable = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
                                }
                                variable_objects[value] = variable;
                            }
                            graph.add_to_collection(col.Key, variable);
                        }
                    }
                    else
                    {
                        foreach (var value in col.Value.BytesList.Value)
                        {
                            // Control-flow contexts reference imported node names, so they must use
                            // the same scope the graph was imported under (not the raw import_scope).
                            switch (col.Key)
                            {
                            case "cond_context":
                            {
                                var proto       = CondContextDef.Parser.ParseFrom(value);
                                var condContext = new CondContext().from_proto(proto, scope_to_prepend_to_names);
                                graph.add_to_collection(col.Key, condContext);
                            }
                            break;

                            case "while_context":
                            {
                                var proto        = WhileContextDef.Parser.ParseFrom(value);
                                var whileContext = new WhileContext().from_proto(proto, scope_to_prepend_to_names);
                                graph.add_to_collection(col.Key, whileContext);
                            }
                            break;

                            default:
                                Console.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}");
                                continue;
                            }
                        }
                    }

                    break;

                default:
                    Console.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping.");
                    break;
                }
            }

            var variables = graph.get_collection <IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES,
                                                               scope: scope_to_prepend_to_names);
            var var_list = new Dictionary <string, IVariableV1>();

            variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v);

            return(var_list, imported_return_elements);
        }
Example #8
0
        /// <summary>
        /// Smoke test for the Resnet model. It loads "surfers.jpg" as a 224x224 tensor and
        /// initializes the model. It parses the "serving_default" signature, then walks every
        /// operation to exercise attribute, shape and op-def accessors for Placeholder, Const and
        /// Conv2D ops. It reads the graph version buffer and finally runs recognition.
        /// No assertions are made; the test passes if no call throws.
        /// </summary>
        public async Task TestResnet()
        {
            using (Tensor imageTensor = ImageIO.ReadTensorFromImageFile <float>("surfers.jpg", 224, 224, 0, 1.0f / 255.0f))
            using (Resnet resnet = new Resnet())
            {
                await resnet.Init();

                MetaGraphDef metaGraphDef = MetaGraphDef.Parser.ParseFrom(resnet.MetaGraphDefBuffer.Data);
                var          signatureDef = metaGraphDef.SignatureDef["serving_default"];
                var          inputNode    = signatureDef.Inputs;
                var          outputNode   = signatureDef.Outputs;

                HashSet <string> opNames        = new HashSet <string>();
                HashSet <string> couldBeInputs  = new HashSet <string>();
                HashSet <string> couldBeOutputs = new HashSet <string>();
                foreach (Operation op in resnet.Graph)
                {
                    String name = op.Name;
                    opNames.Add(name);

                    if (op.NumInputs == 0 && op.OpType.Equals("Placeholder"))
                    {
                        couldBeInputs.Add(op.Name);
                        AttrMetadata dtypeMeta = op.GetAttrMetadata("dtype");
                        AttrMetadata shapeMeta = op.GetAttrMetadata("shape");
                        DataType     type      = op.GetAttrType("dtype");
                        Int64[]      shape     = op.GetAttrShape("shape");
                        // Buffer is disposable (see Versions() below); release both buffers
                        // instead of leaking them once per Placeholder.
                        using (Buffer valueBuffer = op.GetAttrValueProto("shape"))
                        using (Buffer shapeBuffer = op.GetAttrTensorShapeProto("shape"))
                        {
                            Tensorflow.TensorShapeProto shapeProto =
                                Tensorflow.TensorShapeProto.Parser.ParseFrom(shapeBuffer.Data);
                        }
                    }

                    if (op.OpType.Equals("Const"))
                    {
                        AttrMetadata dtypeMeta = op.GetAttrMetadata("dtype");
                        AttrMetadata valueMeta = op.GetAttrMetadata("value");
                        using (Tensor valueTensor = op.GetAttrTensor("value"))
                        {
                            var dim = valueTensor.Dim;
                        }
                    }

                    if (op.OpType.Equals("Conv2D"))
                    {
                        AttrMetadata stridesMeta = op.GetAttrMetadata("strides");
                        AttrMetadata paddingMeta = op.GetAttrMetadata("padding");
                        AttrMetadata boolMeta    = op.GetAttrMetadata("use_cudnn_on_gpu");
                        Int64[]      strides     = op.GetAttrIntList("strides");
                        bool         useCudnn    = op.GetAttrBool("use_cudnn_on_gpu");
                        String       padding     = op.GetAttrString("padding");
                    }

                    foreach (Output output in op.Outputs)
                    {
                        int[] shape = resnet.Graph.GetTensorShape(output);
                        // An op with any unconsumed output is a candidate graph output.
                        if (output.NumConsumers == 0)
                        {
                            couldBeOutputs.Add(name);
                        }
                    }

                    // One op-def buffer is fetched per op; dispose it after parsing.
                    using (Buffer buffer = resnet.Graph.GetOpDef(op.OpType))
                    {
                        Tensorflow.OpDef opDef = Tensorflow.OpDef.Parser.ParseFrom(buffer.Data);
                    }
                }

                using (Buffer versionDef = resnet.Graph.Versions())
                {
                    int l = versionDef.Length;
                }

                Resnet.RecognitionResult[][] results = resnet.Recognize(imageTensor);
            }
        }
Example #9
0
 /// <summary>
 /// Placeholder: the body is empty, so none of the arguments are read and
 /// <paramref name="meta_graph_def"/> is left untouched.
 /// </summary>
 /// <remarks>
 /// NOTE(review): presumably meant to populate the MetaGraphDef for a SavedModel export
 /// of <paramref name="obj"/>; not implemented yet.
 /// </remarks>
 void _build_meta_graph(Trackable obj,
                        string export_dir,
                        SaveOptions options,
                        MetaGraphDef meta_graph_def = null)
 {
     // Intentionally empty: no graph-building logic has been ported yet.
 }
Example #10
0
        /// <summary>
        /// Partial port of import_scoped_meta_graph_with_return_elements. It imports the GraphDef
        /// of <paramref name="meta_graph_or_file"/> into the default graph under a unique name scope
        /// derived from <paramref name="import_scope"/>, then restores its NodeList collections.
        /// Collections are processed in key order; those before a failing key have already been
        /// added to the graph when an exception is thrown.
        /// </summary>
        /// <param name="meta_graph_or_file">MetaGraphDef to import. When <paramref name="clear_devices"/> is set, its GraphDef nodes are modified in place.</param>
        /// <param name="clear_devices">Clear every node's Device before importing.</param>
        /// <param name="import_scope">Requested name scope for imported names; uniquified against the graph.</param>
        /// <param name="input_map">Forwarded to importer.import_graph_def.</param>
        /// <param name="unbound_inputs_col_name">Collection that is never re-added to the graph.</param>
        /// <param name="return_elements">Forwarded to importer.import_graph_def.</param>
        /// <returns>Always <c>(null, null)</c> in this port.</returns>
        /// <exception cref="NotImplementedException">
        /// Thrown in three cases: when the unbound-inputs collection is present; when a variable
        /// collection contains any serialized value; and for any non-variable BytesList collection.
        /// </exception>
        public static (RefVariable[], string[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
                                                                                              bool clear_devices  = false,
                                                                                              string import_scope = "",
                                                                                              Dictionary <string, Tensor> input_map = null,
                                                                                              string unbound_inputs_col_name        = "unbound_inputs",
                                                                                              string[] return_elements = null)
        {
            var meta_graph_def = meta_graph_or_file;

            if (!string.IsNullOrEmpty(unbound_inputs_col_name) &&
                meta_graph_def.CollectionDef.ContainsKey(unbound_inputs_col_name))
            {
                throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
            }

            // Always imports into the default graph.
            var graph = ops.get_default_graph();

            // Gathers the list of nodes we are interested in. MetaInfoDef is an optional message
            // and may be null.
            OpList producer_op_list = meta_graph_def.MetaInfoDef?.StrippedOpList;
            var input_graph_def = meta_graph_def.GraphDef;

            // Remove all the explicit device specifications for this node. This helps to
            // make the graph more portable.
            if (clear_devices)
            {
                foreach (var node in input_graph_def.Node)
                {
                    node.Device = "";
                }
            }

            // Derive the scope from import_scope (previously always "", which ignored it).
            var scope_to_prepend_to_names = graph.unique_name(import_scope ?? "", mark_as_used: false);

            // The returned elements are not surfaced by this port.
            importer.import_graph_def(input_graph_def,
                                      name: scope_to_prepend_to_names,
                                      input_map: input_map,
                                      producer_op_list: producer_op_list,
                                      return_elements: return_elements);

            // Restores all the other collections.
            foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key))
            {
                // Don't add unbound_inputs to the new graph.
                if (col.Key == unbound_inputs_col_name)
                {
                    continue;
                }

                switch (col.Value.KindCase)
                {
                case KindOneofCase.NodeList:
                    foreach (var value in col.Value.NodeList.Value)
                    {
                        var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names));
                        graph.add_to_collection(col.Key, col_op);
                    }
                    break;

                case KindOneofCase.BytesList:
                    if (ops.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
                    {
                        // Empty variable collections are accepted; any serialized variable is
                        // parsed (so malformed bytes still fail with a parse error) and then rejected.
                        foreach (var value in col.Value.BytesList.Value)
                        {
                            VariableDef.Parser.ParseFrom(value);
                            throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
                        }
                    }
                    else
                    {
                        throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
                    }

                    break;
                }
            }

            return(null, null);
        }