コード例 #1
0
        /// <summary>
        /// Returns `MetaGraphDef` proto built from the default graph (or <paramref name="graph_def"/>).
        /// Writing the proto to <paramref name="filename"/> is not implemented in this overload.
        /// </summary>
        /// <param name="filename">Destination file. Empty means "do not write"; a non-empty value is not supported yet.</param>
        /// <param name="graph_def">Graph to export; forwarded to create_meta_graph_def (null means the default graph).</param>
        /// <param name="as_text">Currently unused by this overload.</param>
        /// <param name="unbound_inputs_col_name">Currently unused by this overload.</param>
        /// <param name="clear_devices">Currently unused by this overload.</param>
        /// <param name="saver_def">Forwarded to create_meta_graph_def.</param>
        /// <param name="clear_extraneous_savers">Forwarded to create_meta_graph_def.</param>
        /// <param name="strip_default_attrs">Forwarded to create_meta_graph_def.</param>
        /// <param name="meta_info_def">Currently unused by this overload.</param>
        /// <returns>The MetaGraphDef produced by create_meta_graph_def.</returns>
        /// <exception cref="NotImplementedException">When <paramref name="filename"/> is non-empty.</exception>
        public static MetaGraphDef export_scoped_meta_graph(string filename                = "",
                                                            GraphDef graph_def             = null,
                                                            bool as_text                   = false,
                                                            string unbound_inputs_col_name = "unbound_inputs",
                                                            bool clear_devices             = false,
                                                            SaverDef saver_def             = null,
                                                            bool clear_extraneous_savers   = false,
                                                            bool strip_default_attrs       = false,
                                                            byte[] meta_info_def           = null)
        {
            var graph = ops.get_default_graph();

            // Index the global variables by name. The runtime type of the collection is not
            // visible here; if it is not a RefVariable[] the old `as` cast yielded null and the
            // foreach threw NullReferenceException, so guard the cast instead.
            // NOTE(review): this map is not part of this overload's return value.
            var var_list = new Dictionary<string, RefVariable>();
            if (graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) is RefVariable[] global_vars)
            {
                foreach (var v in global_vars)
                {
                    var_list[v.name] = v;
                }
            }

            var scoped_meta_graph_def = create_meta_graph_def(
                graph_def: graph_def,
                export_scope: "",
                exclude_nodes: "",
                clear_extraneous_savers: clear_extraneous_savers,
                saver_def: saver_def,
                strip_default_attrs: strip_default_attrs);

            // Persisting the proto to disk is the only part still missing in this overload.
            if (!string.IsNullOrEmpty(filename))
            {
                throw new NotImplementedException("meta_graph.export_scoped_meta_graph");
            }

            return scoped_meta_graph_def;
        }
コード例 #2
0
        /// <summary>
        /// Builds a <c>MetaGraphDef</c> from <paramref name="graph_def"/> or, when it is null, from the
        /// default graph. The result includes the version strings, the stripped op list, the optional
        /// saver definition and the graph's collections.
        /// </summary>
        /// <param name="meta_info_def">
        /// Meta info to attach; a new one is created when null. A non-null instance is modified in place
        /// (its version fields are overwritten) and attached by reference.
        /// </param>
        /// <param name="graph_def">Graph to embed; when null the default graph is serialized with shapes.</param>
        /// <param name="export_scope">Currently unused.</param>
        /// <param name="exclude_nodes">Currently unused.</param>
        /// <param name="saver_def">Saver definition to embed in the MetaGraphDef; skipped when null.</param>
        /// <param name="clear_extraneous_savers">Not supported yet (see exception).</param>
        /// <param name="strip_default_attrs">Currently unused.</param>
        /// <returns>The populated MetaGraphDef.</returns>
        /// <exception cref="NotImplementedException">
        /// When <paramref name="clear_extraneous_savers"/> is true and the graph has at least one collection key.
        /// </exception>
        private static MetaGraphDef create_meta_graph_def(MetaInfoDef meta_info_def    = null,
                                                          GraphDef graph_def           = null,
                                                          string export_scope          = "",
                                                          string exclude_nodes         = "",
                                                          SaverDef saver_def           = null,
                                                          bool clear_extraneous_savers = false,
                                                          bool strip_default_attrs     = false)
        {
            // Sets graph to default graph if it's not passed in.
            var graph = ops.get_default_graph().as_default();
            // Creates a MetaGraphDef proto.
            var meta_graph_def = new MetaGraphDef();

            if (meta_info_def == null)
            {
                meta_info_def = new MetaInfoDef();
            }

            // Set the tf version strings to the current tf build.
            meta_info_def.TensorflowVersion    = tf.VERSION;
            meta_info_def.TensorflowGitVersion = "unknown";
            meta_graph_def.MetaInfoDef         = meta_info_def;

            // Adds graph_def or the default.
            if (graph_def == null)
            {
                meta_graph_def.GraphDef = graph.as_graph_def(add_shapes: true);
            }
            else
            {
                meta_graph_def.GraphDef = graph_def;
            }

            // Fills in meta_info_def.stripped_op_list using the ops from graph_def.
            if (meta_graph_def.MetaInfoDef.StrippedOpList == null ||
                meta_graph_def.MetaInfoDef.StrippedOpList.Op.Count == 0)
            {
                meta_graph_def.MetaInfoDef.StrippedOpList = stripped_op_list_for_graph(meta_graph_def.GraphDef);
            }

            // Adds saver_def. Previously the argument was accepted but silently dropped, so the
            // exported MetaGraphDef lost the saver configuration callers passed in.
            if (saver_def != null)
            {
                meta_graph_def.SaverDef = saver_def;
            }

            var clist = graph.get_all_collection_keys();

            foreach (var ctype in clist)
            {
                if (clear_extraneous_savers)
                {
                    throw new NotImplementedException("create_meta_graph_def clear_extraneous_savers");
                }
                else
                {
                    add_collection_def(meta_graph_def, ctype, graph);
                }
            }

            return meta_graph_def;
        }
コード例 #3
0
ファイル: Saver.cs プロジェクト: wtf3505-git/TensorFlow.NET
        /// <summary>
        /// Builds the saver exactly once. When no SaverDef exists yet, one is produced by the builder
        /// from the variable list; afterwards the SaverDef is validated and the next checkpoint time
        /// is scheduled.
        /// </summary>
        /// <param name="checkpoint_path">Forwarded to the builder as its <c>filename</c>.</param>
        /// <param name="build_save">Forwarded to the builder.</param>
        /// <param name="build_restore">Forwarded to the builder.</param>
        /// <exception cref="ValueError">There is nothing to save and empty savers are not allowed.</exception>
        /// <exception cref="NotImplementedException">A SaverDef was supplied together with a non-empty name.</exception>
        private void _build(string checkpoint_path, bool build_save, bool build_restore)
        {
            // Repeated calls do nothing.
            if (_is_built)
            {
                return;
            }
            _is_built = true;

            if (_saver_def != null)
            {
                // Re-scoping a supplied SaverDef under a name is not supported yet.
                if (!string.IsNullOrEmpty(_name))
                {
                    throw new NotImplementedException("Saver._build");
                }
            }
            else
            {
                if (_builder == null)
                {
                    _builder = new BulkSaverBuilder(_write_version);
                }

                // Without an explicit list, fall back to every saveable object.
                if (_var_list == null)
                {
                    _var_list = variables._all_saveable_objects();
                }

                bool nothingToSave = _var_list == null || _var_list.Length == 0;
                if (nothingToSave)
                {
                    if (!_allow_empty)
                    {
                        throw new ValueError("No variables to save");
                    }

                    // An empty saver is permitted: flag it and skip the remaining build steps.
                    _is_empty = true;
                    return;
                }

                _is_empty = false;
                _saver_def = _builder._build_internal(_var_list,
                                                      reshape: _reshape,
                                                      sharded: _sharded,
                                                      max_to_keep: _max_to_keep,
                                                      keep_checkpoint_every_n_hours: _keep_checkpoint_every_n_hours,
                                                      name: _name,
                                                      restore_sequentially: _restore_sequentially,
                                                      filename: checkpoint_path,
                                                      build_save: build_save,
                                                      build_restore: build_restore);
            }

            _check_saver_def();

            const int secondsPerHour = 3600;
            _next_checkpoint_time = time() + _saver_def.KeepCheckpointEveryNHours * secondsPerHour;
        }
コード例 #4
0
ファイル: Saver.cs プロジェクト: wtf3505-git/TensorFlow.NET
        /// <summary>
        /// Creates a saver. The configuration is stored first and, unless
        /// <paramref name="defer_build"/> is set, <c>build()</c> is called immediately.
        /// </summary>
        /// <param name="var_list">Variables to save; may be null.</param>
        /// <param name="reshape">Stored for the build step.</param>
        /// <param name="sharded">Stored for the build step.</param>
        /// <param name="max_to_keep">Stored for the build step.</param>
        /// <param name="keep_checkpoint_every_n_hours">Stored for the build step.</param>
        /// <param name="name">Stored for the build step.</param>
        /// <param name="restore_sequentially">Stored for the build step.</param>
        /// <param name="saver_def">Existing SaverDef; when present (after building) it is validated and its Version wins.</param>
        /// <param name="builder">Saver builder to use; may be null.</param>
        /// <param name="defer_build">When true, <c>build()</c> is not called here.</param>
        /// <param name="allow_empty">Stored for the build step.</param>
        /// <param name="write_version">Checkpoint format, overridden by the SaverDef's Version if one exists.</param>
        /// <param name="pad_step_number">Stored on the instance.</param>
        /// <param name="save_relative_paths">Stored on the instance after building.</param>
        /// <param name="filename">
        /// NOTE(review): accepted but neither stored nor forwarded in this constructor — confirm whether
        /// <c>build()</c> is meant to use it.
        /// </param>
        public Saver(IVariableV1[] var_list = null,
                     bool reshape           = false,
                     bool sharded           = false,
                     int max_to_keep        = 5,
                     float keep_checkpoint_every_n_hours = 10000,
                     string name = null,
                     bool restore_sequentially = false,
                     SaverDef saver_def        = null,
                     ISaverBuilder builder     = null,
                     bool defer_build          = false,
                     bool allow_empty          = false,
                     SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2,
                     bool pad_step_number     = false,
                     bool save_relative_paths = false,
                     string filename          = "")
        {
            // Capture the configuration before any build happens.
            (_var_list, _reshape, _sharded) = (var_list, reshape, sharded);
            (_max_to_keep, _keep_checkpoint_every_n_hours) = (max_to_keep, keep_checkpoint_every_n_hours);
            (_name, _restore_sequentially, _saver_def) = (name, restore_sequentially, saver_def);
            (_builder, _is_built, _allow_empty) = (builder, false, allow_empty);
            (_write_version, _pad_step_number) = (write_version, pad_step_number);

            if (!defer_build)
            {
                build();
            }

            // A SaverDef (given or freshly built) decides the checkpoint format actually used.
            if (_saver_def != null)
            {
                _check_saver_def();
                _write_version = _saver_def.Version;
            }

            _save_relative_paths  = save_relative_paths;
            _object_restore_saver = null;

            _last_checkpoints          = new Dictionary<string, float>();
            _checkpoints_to_be_deleted = new Dictionary<string, float>();
        }
コード例 #5
0
ファイル: Saver.cs プロジェクト: wtf3505-git/TensorFlow.NET
        /// <summary>
        /// Exports the graph as a <c>MetaGraphDef</c> through
        /// <c>meta_graph.export_scoped_meta_graph</c> and returns only the proto.
        /// </summary>
        /// <param name="filename">Forwarded; where the callee optionally writes the proto.</param>
        /// <param name="meta_info_def">Forwarded.</param>
        /// <param name="graph_def">Forwarded.</param>
        /// <param name="saver_def">Forwarded.</param>
        /// <param name="collection_list">Accepted for API parity; not forwarded.</param>
        /// <param name="as_text">Forwarded.</param>
        /// <param name="clear_devices">Forwarded.</param>
        /// <param name="clear_extraneous_savers">Forwarded.</param>
        /// <param name="strip_default_attrs">Forwarded.</param>
        /// <param name="export_scope">Accepted for API parity; not forwarded.</param>
        /// <returns>The first element of the callee's result (the MetaGraphDef).</returns>
        public MetaGraphDef export_meta_graph(string filename              = "",
                                              byte[] meta_info_def         = null,
                                              GraphDef graph_def           = null,
                                              SaverDef saver_def           = null,
                                              string[] collection_list     = null,
                                              bool as_text                 = false,
                                              bool clear_devices           = false,
                                              bool clear_extraneous_savers = false,
                                              bool strip_default_attrs     = false,
                                              string export_scope          = "")
        {
            // NOTE(review): collection_list and export_scope are currently not passed on.
            var (exported, _) = meta_graph.export_scoped_meta_graph(
                filename: filename,
                graph_def: graph_def,
                as_text: as_text,
                clear_devices: clear_devices,
                saver_def: saver_def,
                clear_extraneous_savers: clear_extraneous_savers,
                strip_default_attrs: strip_default_attrs,
                meta_info_def: meta_info_def);

            return exported;
        }
コード例 #6
0
        /// <summary>
        /// Builds a `MetaGraphDef` for the default graph (or <paramref name="graph_def"/>) and, when
        /// <paramref name="filename"/> is non-empty, writes it out via <c>graph_io.write_graph</c>.
        /// </summary>
        /// <param name="filename">Destination file; empty skips writing.</param>
        /// <param name="graph_def">Forwarded to create_meta_graph_def (null means the default graph).</param>
        /// <param name="as_text">Forwarded to graph_io.write_graph.</param>
        /// <param name="unbound_inputs_col_name">Currently unused.</param>
        /// <param name="clear_devices">Currently unused.</param>
        /// <param name="saver_def">Forwarded to create_meta_graph_def.</param>
        /// <param name="clear_extraneous_savers">Forwarded to create_meta_graph_def.</param>
        /// <param name="strip_default_attrs">Forwarded to create_meta_graph_def.</param>
        /// <param name="meta_info_def">Currently unused.</param>
        /// <returns>
        /// The MetaGraphDef, plus the default graph's GLOBAL_VARIABLES keyed by <c>Name</c>
        /// (empty when the collection is null).
        /// </returns>
        public static (MetaGraphDef, Dictionary <string, IVariableV1>) export_scoped_meta_graph(string filename                = "",
                                                                                                GraphDef graph_def             = null,
                                                                                                bool as_text                   = false,
                                                                                                string unbound_inputs_col_name = "unbound_inputs",
                                                                                                bool clear_devices             = false,
                                                                                                SaverDef saver_def             = null,
                                                                                                bool clear_extraneous_savers   = false,
                                                                                                bool strip_default_attrs       = false,
                                                                                                byte[] meta_info_def           = null)
        {
            var graph = ops.get_default_graph();

            // Map every global variable by name; a later duplicate name overwrites an earlier one.
            var exported_vars = new Dictionary<string, IVariableV1>();
            var global_vars   = graph.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES);
            if (global_vars != null)
            {
                foreach (var variable in global_vars)
                {
                    exported_vars[variable.Name] = variable;
                }
            }

            var meta_graph_def = create_meta_graph_def(
                graph_def: graph_def,
                export_scope: "",
                exclude_nodes: "",
                clear_extraneous_savers: clear_extraneous_savers,
                saver_def: saver_def,
                strip_default_attrs: strip_default_attrs);

            // Persist only when a destination was requested.
            bool should_write = !string.IsNullOrEmpty(filename);
            if (should_write)
            {
                graph_io.write_graph(meta_graph_def, "", filename, as_text: as_text);
            }

            return (meta_graph_def, exported_vars);
        }