Ejemplo n.º 1
0
        /// <summary>
        /// Runs the TensorFlow Graph Transform Tool over <paramref name="input_graph_def"/> and
        /// returns the transformed graph.
        /// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md
        /// </summary>
        /// <param name="input_graph_def">GraphDef object containing a model to be transformed</param>
        /// <param name="inputs">the model input names; joined with "," before the native call</param>
        /// <param name="outputs">the model output names; joined with "," before the native call</param>
        /// <param name="transforms">transform names and parameters; joined with a single space before the native call</param>
        /// <returns>The <see cref="GraphDef"/> parsed from the output buffer filled by the native transform.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="input_graph_def"/>, <paramref name="inputs"/>, <paramref name="outputs"/>
        /// or <paramref name="transforms"/> is null (the arrays are rejected by <see cref="string.Join(string, string[])"/>).
        /// </exception>
        public GraphDef TransformGraph(GraphDef input_graph_def,
                                       string[] inputs,
                                       string[] outputs,
                                       string[] transforms)
        {
            if (input_graph_def == null)
            {
                throw new ArgumentNullException(nameof(input_graph_def));
            }

            var input_graph_def_string = input_graph_def.ToByteArray();
            var inputs_string          = string.Join(",", inputs);
            var outputs_string         = string.Join(",", outputs);
            var transforms_string      = string.Join(" ", transforms);

            // Both native wrappers are disposed on every path, including when Check throws.
            using (var status = new Status())
            using (var buffer = new Buffer())
            {
                // The integer result of the native call is not needed: the transformed graph is
                // delivered through `buffer` and any failure through `status`.
                c_api.TransformGraphWithStringInputs(input_graph_def_string,
                                                     input_graph_def_string.Length,
                                                     inputs_string,
                                                     outputs_string,
                                                     transforms_string,
                                                     buffer,
                                                     status);

                // Throw on a native error rather than parsing whatever the buffer holds after a
                // failed transform (matches the Check(true) usage elsewhere in this codebase).
                status.Check(true);
                return GraphDef.Parser.ParseFrom(buffer.ToArray());
            }
        }
Ejemplo n.º 2
0
 /// <summary>
 /// Asks the C API to serialize this operation's node definition into a native buffer and
 /// parses the bytes into a <see cref="NodeDef"/>. Throws if the shared status reports an error.
 /// </summary>
 private NodeDef GetNodeDef()
 {
     using (var serialized = new Buffer())
     {
         c_api.TF_OperationToNodeDef(_handle, serialized.Handle, tf.Status.Handle);
         tf.Status.Check(throwException: true);

         var nodeDefBytes = serialized.ToArray();
         return NodeDef.Parser.ParseFrom(nodeDefBytes);
     }
 }
Ejemplo n.º 3
0
        /// <summary>
        /// Reads attribute <paramref name="name"/> of this operation through the C API and converts
        /// it to a CLR value.
        /// </summary>
        /// <param name="name">Attribute name to look up.</param>
        /// <returns>
        /// <c>null</c> when the attribute's value oneof is unset; the <c>Type</c> field for type
        /// attributes; the UTF-8 decoded string for <c>s</c> attributes; otherwise the value of the
        /// generated <see cref="AttrValue"/> property whose name matches the oneof case.
        /// </returns>
        /// <exception cref="NotImplementedException">The attribute holds a list value.</exception>
        public virtual object get_attr(string name)
        {
            using var buf = new Buffer();
            c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle);
            tf.Status.Check(true);

            var x = AttrValue.Parser.ParseFrom(buf.ToArray());

            // Switch on the generated enum rather than its string form: an unset value reports
            // ValueOneofCase.None, which has no matching property for the reflection fallback below.
            switch (x.ValueCase)
            {
            case AttrValue.ValueOneofCase.None:
                return null;

            case AttrValue.ValueOneofCase.List:
                throw new NotImplementedException($"Unsupported field type in {x.ValueCase}");

            case AttrValue.ValueOneofCase.Type:
                return x.Type;

            case AttrValue.ValueOneofCase.S:
                return x.S.ToStringUtf8();

            default:
                // The remaining oneof cases are named after their generated properties (I, F, B, Shape, ...).
                return x.GetType().GetProperty(x.ValueCase.ToString()).GetValue(x);
            }
        }
Ejemplo n.º 4
0
        /// <summary>
        /// Reads list-valued attribute <paramref name="name"/> of this operation and converts its
        /// elements to <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">
        /// Element type: <see cref="int"/>/<see cref="long"/> read the integer list,
        /// <see cref="float"/>/<see cref="double"/> the float list, and <see cref="bool"/> the bool list.
        /// </typeparam>
        /// <param name="name">Attribute name to look up.</param>
        /// <returns>
        /// The converted elements, or <c>null</c> when the attribute is not a list or
        /// <typeparamref name="T"/> is not one of the supported element types.
        /// In eager mode the result of <see cref="get_attr(string)"/> is cast directly.
        /// </returns>
        public virtual T[] get_attr_list <T>(string name)
        {
            if (tf.executing_eagerly())
            {
                return((T[])get_attr(name));
            }

            using var buf = new Buffer();
            c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle);
            tf.Status.Check(true);

            var attr = AttrValue.Parser.ParseFrom(buf.ToArray());

            // Unset or scalar attributes have no list payload (attr.List would be null).
            if (attr.ValueCase != AttrValue.ValueOneofCase.List)
            {
                return null;
            }

            var list   = attr.List;
            var target = typeof(T);

            if (target == typeof(int) || target == typeof(long))
            {
                return list.I.Select(v => (T)Convert.ChangeType(v, target)).ToArray();
            }

            if (target == typeof(float) || target == typeof(double))
            {
                return list.F.Select(v => (T)Convert.ChangeType(v, target)).ToArray();
            }

            if (target == typeof(bool))
            {
                return list.B.Select(v => (T)Convert.ChangeType(v, target)).ToArray();
            }

            return null;
        }
Ejemplo n.º 5
0
        /// <summary>
        /// Serializes the native function <paramref name="func"/> through the C API and parses the
        /// result into a <see cref="FunctionDef"/>. Throws if the native status reports an error.
        /// </summary>
        /// <param name="func">Native function handle to export.</param>
        /// <returns>The parsed function definition.</returns>
        public static FunctionDef GetFunctionDef(IntPtr func)
        {
            using (var status = new Status())
            using (var serialized = new Buffer())
            {
                c_api.TF_FunctionToFunctionDef(func, serialized.Handle, status.Handle);
                status.Check(true);
                return FunctionDef.Parser.ParseFrom(serialized.ToArray());
            }
        }
Ejemplo n.º 6
0
        /// <summary>
        /// Returns the cached map from op name to <see cref="OpDef"/>. On first use (or while the
        /// cache is still empty) it is filled from the op list returned by <c>TF_GetAllOpList</c>.
        /// </summary>
        /// <returns>The shared <c>_registered_ops</c> dictionary.</returns>
        public static Dictionary <string, OpDef> get_registered_ops()
        {
            // The emptiness check and the fill both run under the lock. An unlocked Count check
            // would let another thread see Count > 0 as soon as the first entry is inserted and
            // then read the dictionary while it is still being written.
            lock (_registered_ops)
            {
                if (_registered_ops.Count == 0)
                {
                    using var buffer = new Buffer(c_api.TF_GetAllOpList());
                    var op_list = OpList.Parser.ParseFrom(buffer.ToArray());
                    foreach (var op_def in op_list.Op)
                    {
                        _registered_ops[op_def.Name] = op_def;
                    }
                }
            }

            return(_registered_ops);
        }