示例#1
0
 /// <summary>
 /// Executes the <c>ReadVariableOp</c> kernel to read the current value of a variable.
 /// </summary>
 /// <param name="resource">Tensor identifying the variable to read.</param>
 /// <param name="dtype">Element type of the value; forwarded as the op's <c>dtype</c> attribute.</param>
 /// <param name="name">Optional name for the operation.</param>
 /// <returns>The tensor produced by <c>ReadVariableOp</c>.</returns>
 public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string name = null)
 {
     var opArgs = new ExecuteOpArgs(resource).SetAttributes(new { dtype });
     return tf.Context.ExecuteOp("ReadVariableOp", name, opArgs);
 }
示例#2
0
 /// <summary>
 /// Thin wrapper that forwards to <c>gen_math_ops.arg_min</c> (the ArgMin op) along <paramref name="dimension"/>.
 /// </summary>
 public Tensor arg_min(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null)
 {
     return gen_math_ops.arg_min(input, dimension, output_type: output_type, name: name);
 }
示例#3
0
        /// <summary>
        /// Builds a <c>Size</c> op for <paramref name="input"/> via <c>_apply_op_helper</c> and returns its first output.
        /// </summary>
        public static Tensor size(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
        {
            var sizeOp  = _op_def_lib._apply_op_helper("Size", name, new { input, out_type });
            var outputs = sizeOp.outputs;
            return outputs[0];
        }
        /// <summary>
        /// Eager fast path: builds the op <paramref name="opName"/> from a flat argument list,
        /// executes it, and runs gradient / post-execution callbacks when they are enabled.
        /// </summary>
        /// <param name="ctx">Eager context; must not be null.</param>
        /// <param name="device_name">Device passed to <c>TFE_OpSetDevice</c>.</param>
        /// <param name="opName">Registered op type name; also used to look up the OpDef.</param>
        /// <param name="name">Op name recorded in the exec info.</param>
        /// <param name="callbacks">Post-execution callback; callbacks run when this is non-null or a tape/accumulator is active.</param>
        /// <param name="args">
        /// Flat argument list. Starting at <c>kFastPathExecuteInputStartIndex</c> there is one entry per
        /// OpDef input arg (list inputs are <c>object[]</c>), followed by alternating attribute name / value pairs.
        /// </param>
        /// <returns>The op's outputs, or null if an input could not be added to the op.</returns>
        /// <exception cref="ValueError">When <paramref name="ctx"/> is null.</exception>
        public Tensor[] TFE_FastPathExecute(Context ctx,
                                            string device_name,
                                            string opName,
                                            string name,
                                            Action callbacks,
                                            params object[] args)
        {
            if (ctx == null)
            {
                throw new ValueError("This function does not handle the case of the path where " +
                                     "all inputs are not already EagerTensors.");
            }

            int args_size       = args.Length;
            var attr_list_sizes = new Dictionary <string, long>();

            FastPathOpExecInfo op_exec_info = new FastPathOpExecInfo()
            {
                ctx         = ctx,
                args        = args,
                device_name = device_name,
                op_name     = opName,
                name        = name,
            };

            op_exec_info.run_gradient_callback   = HasAccumulatorOrTape();
            op_exec_info.run_post_exec_callbacks = callbacks != null;
            op_exec_info.run_callbacks           = op_exec_info.run_gradient_callback || op_exec_info.run_post_exec_callbacks;

            var status = tf.Status;
            var op     = GetOp(ctx, opName, status);

            var op_def = tf.get_default_graph().GetOpDef(opName);

            var flattened_attrs  = new List <object>(op_def.InputArg.Count);
            var flattened_inputs = new List <Tensor>(op_def.InputArg.Count);

            // Set non-inferred attrs, including setting defaults if the attr is passed in
            // as None. They follow the inputs in args as (name, value) pairs.
            for (int i = kFastPathExecuteInputStartIndex + op_def.InputArg.Count; i < args_size; i += 2)
            {
                var attr_name  = args[i].ToString();
                var attr_value = args[i + 1];

                var attr = op_def.Attr.FirstOrDefault(x => x.Name == attr_name);
                if (attr != null)
                {
                    flattened_attrs.Add(attr_name);
                    flattened_attrs.Add(attr_value);

                    SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status);
                    status.Check(true);
                }
            }

            c_api.TFE_OpSetDevice(op, device_name, status.Handle);
            status.Check(true);

            // Add inferred attrs and inputs.
            for (int i = 0; i < op_def.InputArg.Count; i++)
            {
                var input     = args[kFastPathExecuteInputStartIndex + i];
                var input_arg = op_def.InputArg[i];
                if (!string.IsNullOrEmpty(input_arg.NumberAttr))
                {
                    // Homogeneous list input: its length becomes the NumberAttr value.
                    // Use the current input (not args[i], which ignores the start offset).
                    var fast_input_array = (object[])input;
                    int len              = fast_input_array.Length;
                    c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len);
                    if (op_exec_info.run_callbacks)
                    {
                        flattened_attrs.Add(input_arg.NumberAttr);
                        flattened_attrs.Add(len);
                    }
                    attr_list_sizes[input_arg.NumberAttr] = len;

                    if (len > 0)
                    {
                        // First item adds the type attr.
                        if (!AddInputToOp(fast_input_array[0], true, input_arg, flattened_attrs, flattened_inputs, op, status))
                        {
                            return(null);
                        }

                        for (var j = 1; j < len; j++)
                        {
                            // Since the list is homogeneous, we don't need to re-add the attr.
                            if (!AddInputToOp(fast_input_array[j], false, input_arg, flattened_attrs, flattened_inputs, op, status))
                            {
                                return(null);
                            }
                        }
                    }
                }
                else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
                {
                    var attr_name        = input_arg.TypeListAttr;
                    var fast_input_array = input as object[];
                    var len         = fast_input_array.Length;
                    var attr_values = new TF_DataType[len];

                    for (var j = 0; j < len; j++)
                    {
                        var eager_tensor = ops.convert_to_tensor(fast_input_array[j]);
                        attr_values[j] = eager_tensor.dtype;

                        c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status.Handle);
                        // Surface a failed AddInput immediately instead of at execute time.
                        status.Check(true);

                        if (op_exec_info.run_callbacks)
                        {
                            flattened_inputs.Add(eager_tensor);
                        }
                    }

                    if (op_exec_info.run_callbacks)
                    {
                        flattened_attrs.Add(attr_name);
                        flattened_attrs.Add(attr_values);
                    }
                    c_api.TFE_OpSetAttrTypeList(op, attr_name, attr_values, attr_values.Length);
                    attr_list_sizes[attr_name] = len;
                }
                else
                {
                    // The item is a single item.
                    if (!AddInputToOp(input, true, input_arg, flattened_attrs, flattened_inputs, op, status))
                    {
                        return(null);
                    }
                }
            }

            int num_retvals = 0;

            for (int i = 0; i < op_def.OutputArg.Count; i++)
            {
                var output_arg = op_def.OutputArg[i];
                var delta      = 1L;
                if (!string.IsNullOrEmpty(output_arg.NumberAttr))
                {
                    delta = attr_list_sizes[output_arg.NumberAttr];
                }
                else if (!string.IsNullOrEmpty(output_arg.TypeListAttr))
                {
                    delta = attr_list_sizes[output_arg.TypeListAttr];
                }
                if (delta < 0)
                {
                    throw new RuntimeError("Attributes suggest that the size of an output list is less than 0");
                }
                num_retvals += (int)delta;
            }

            var retVals = new SafeTensorHandleHandle[num_retvals];

            c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle);
            status.Check(true);

            var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray();

            if (op_exec_info.run_callbacks)
            {
                RunCallbacks(op_exec_info,
                             kFastPathExecuteInputStartIndex + op_def.InputArg.Count,
                             flattened_inputs.ToArray(), flattened_attrs.ToArray(), flat_result);
            }

            return(flat_result);
        }
示例#5
0
 /// <summary>
 /// Casts <paramref name="x"/> to <paramref name="dtype"/> by forwarding to <c>math_ops.cast</c>.
 /// </summary>
 public Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
 {
     return math_ops.cast(x, dtype, name);
 }
示例#6
0
 /// <summary>
 /// Sets the type-valued attribute <paramref name="attr_name"/> on an eager op through the C API.
 /// </summary>
 protected void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value)
 {
     c_api.TFE_OpSetAttrType(op, attr_name, value);
 }
示例#7
0
        /// <summary>
        /// Initializes a reference variable from an initial value or a deferred initializer function.
        /// </summary>
        /// <param name="initial_value">A value convertible to a tensor, or a <c>Func&lt;Tensor&gt;</c> producing it. Must not be null.</param>
        /// <param name="trainable">When true, the variable is also added to <c>TRAINABLE_VARIABLES</c>.</param>
        /// <param name="collections">Collections to add the variable to; defaults to <c>GLOBAL_VARIABLES</c>. A caller-supplied list may be mutated.</param>
        /// <param name="validate_shape">When true, the initial value's shape must be fully defined; also forwarded to <c>assign</c>.</param>
        /// <param name="caching_device">Optional caching device (see NOTE in the body).</param>
        /// <param name="name">Optional name; the enclosing name scope determines the final name.</param>
        /// <param name="dtype">Requested dtype for the initial value.</param>
        /// <exception cref="ValueError">
        /// When <paramref name="initial_value"/> is null, or its shape is not fully defined while <paramref name="validate_shape"/> is true.
        /// </exception>
        private void _init_from_args(object initial_value,
                                     bool trainable            = true,
                                     List <string> collections = null,
                                     bool validate_shape       = true,
                                     string caching_device     = "",
                                     string name       = null,
                                     TF_DataType dtype = TF_DataType.DtInvalid)
        {
            if (initial_value is null)
            {
                throw new ValueError("initial_value must be specified.");
            }

            // A generic delegate of arity 1 (Func<T>) is treated as a deferred initializer.
            var init_from_fn = initial_value.GetType().Name == "Func`1";

            if (collections == null)
            {
                collections = new List <string> {
                    ops.GraphKeys.GLOBAL_VARIABLES
                };
            }

            // Store the graph key so optimizers know how to only retrieve variables from
            // this graph.
            _graph_key = ops.get_default_graph()._graph_key;

            _trainable = trainable;
            if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES))
            {
                collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES);
            }

            // NOTE(review): the value returned by init_scope() is discarded, so the scope may
            // never actually be entered here — confirm against ops.init_scope.
            ops.init_scope();
            var values = init_from_fn ? new object[0] : new object[] { initial_value };

            with(ops.name_scope(name, "Variable", values), scope =>
            {
                name = scope;
                if (init_from_fn)
                {
                    // Use attr_scope and device(None) to simulate the behavior of
                    // colocate_with when the variable we want to colocate with doesn't
                    // yet exist.
                    string true_name = ops._name_from_scope_name(name);
                    var attr         = new AttrValue
                    {
                        List = new AttrValue.Types.ListValue()
                    };
                    // TensorFlow colocation attributes use the "loc:@<op name>" form.
                    attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{true_name}"));
                    with(ops.name_scope("Initializer"), scope2 =>
                    {
                        _initial_value = (initial_value as Func <Tensor>)();
                        _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype);
                    });
                    _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
                }
                // Or get the initial value from a Tensor or Python object.
                else
                {
                    // Honor the requested dtype here as well, matching the init_from_fn branch.
                    _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype);

                    var shape = _initial_value.shape;
                    dtype     = _initial_value.dtype;
                    _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope);
                }

                // Manually overrides the variable's shape with the initial value's.
                if (validate_shape)
                {
                    var initial_value_shape = _initial_value.GetShape();
                    if (!initial_value_shape.is_fully_defined())
                    {
                        throw new ValueError($"initial_value must have a shape specified: {_initial_value}");
                    }
                }

                // If 'initial_value' makes use of other variables, make sure we don't
                // have an issue if these other variables aren't initialized first by
                // using their initialized_value() method.
                var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value);

                _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;

                if (!String.IsNullOrEmpty(caching_device))
                {
                    // NOTE(review): caching_device is not applied to placement here; the snapshot
                    // is still created so _snapshot is never left unset.
                    _snapshot = gen_array_ops.identity(_variable, name: "read");
                }
                else
                {
                    ops.colocate_with(_initializer_op);

                    // Named argument: the original `name = "read"` overwrote the captured `name`.
                    _snapshot = gen_array_ops.identity(_variable, name: "read");
                }

                ops.add_to_collections(collections, this as VariableV1);
            });
        }
示例#8
0
        /// <summary>
        /// Maps a CLR type to the matching <see cref="TF_DataType"/>.
        /// Array types are unwrapped (repeatedly) to their innermost element type first.
        /// </summary>
        /// <param name="type">The CLR type to map.</param>
        /// <returns>
        /// The matching <see cref="TF_DataType"/>, or <see cref="TF_DataType.DtInvalid"/> when the type has
        /// no equivalent. No exception is thrown for unsupported types.
        /// </returns>
        /// <remarks>
        /// Matching is by simple type name (<c>Type.Name</c>), not by namespace-qualified identity.
        /// <c>Char</c> is mapped to <c>TF_UINT8</c> even though a .NET <c>char</c> is 16 bits wide.
        /// </remarks>
        public static TF_DataType as_tf_dtype(this Type type)
        {
            while (type.IsArray)
            {
                type = type.GetElementType();
            }

            return type.Name switch
            {
                "Char"    => TF_DataType.TF_UINT8,
                "SByte"   => TF_DataType.TF_INT8,
                "Byte"    => TF_DataType.TF_UINT8,
                "Int16"   => TF_DataType.TF_INT16,
                "UInt16"  => TF_DataType.TF_UINT16,
                "Int32"   => TF_DataType.TF_INT32,
                "UInt32"  => TF_DataType.TF_UINT32,
                "Int64"   => TF_DataType.TF_INT64,
                "UInt64"  => TF_DataType.TF_UINT64,
                "Single"  => TF_DataType.TF_FLOAT,
                "Double"  => TF_DataType.TF_DOUBLE,
                "Complex" => TF_DataType.TF_COMPLEX128,
                "String"  => TF_DataType.TF_STRING,
                "Boolean" => TF_DataType.TF_BOOL,
                _         => TF_DataType.DtInvalid,
            };
        }
示例#9
0
 /// <summary>
 /// Reinterprets a <see cref="TF_DataType"/> as the protobuf <see cref="DataType"/> enum by numeric value.
 /// </summary>
 public static DataType as_datatype_enum(this TF_DataType type)
     => (DataType)type;
示例#10
0
 /// <summary>
 /// Creates a placeholder tensor of <paramref name="dtype"/> with an optional shape by forwarding
 /// to <c>gen_array_ops.placeholder</c>. The body performs no pointer operations, so no <c>unsafe</c> context is needed.
 /// </summary>
 /// <param name="dtype">Element type of the placeholder.</param>
 /// <param name="shape">Optional shape; null is forwarded as-is.</param>
 /// <param name="name">Optional operation name.</param>
 /// <returns>The placeholder tensor.</returns>
 public Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)
 {
     return gen_array_ops.placeholder(dtype, shape, name);
 }
示例#11
0
 /// <summary>
 /// Creates a recurrent neural network specified by RNNCell `cell`.
 /// </summary>
 /// <param name="cell">An instance of RNNCell.</param>
 /// <param name="inputs">The RNN inputs.</param>
 /// <param name="dtype"></param>
 /// <param name="swap_memory"></param>
 /// <param name="time_major"></param>
 /// <returns>A pair (outputs, state)</returns>
 public (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs,
                                     Tensor sequence_length  = null, TF_DataType dtype = TF_DataType.DtInvalid,
                                     int?parallel_iterations = null, bool swap_memory  = false, bool time_major = false)
 => rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype,
示例#12
0
        /// <summary>
        /// Initializes a resource variable from an initial value, a <c>Func&lt;Tensor&gt;</c>, or an initializer object.
        /// </summary>
        /// <param name="initial_value">Value convertible to a tensor, a <c>Func&lt;Tensor&gt;</c>, or an object implementing <c>IInitializer</c>. Must not be null.</param>
        /// <param name="trainable">When true, the variable is also added to <c>TRAINABLE_VARIABLES</c>.</param>
        /// <param name="collections">Collections the variable is added to (graph mode only); defaults to <c>GLOBAL_VARIABLES</c>. A caller-supplied list may be mutated.</param>
        /// <param name="caching_device">Not used by this implementation.</param>
        /// <param name="name">Optional name; the enclosing name scope determines the final name.</param>
        /// <param name="dtype">Requested dtype of the initial value.</param>
        /// <param name="aggregation">Not used by this implementation.</param>
        /// <param name="shape">Variable shape; defaults to the initial value's shape. Also passed to an <c>IInitializer</c>.</param>
        /// <exception cref="ArgumentNullException">When <paramref name="initial_value"/> is null.</exception>
        private void _init_from_args(object initial_value      = null,
                                     bool trainable            = true,
                                     List <string> collections = null,
                                     string caching_device     = "",
                                     string name       = null,
                                     TF_DataType dtype = TF_DataType.DtInvalid,
                                     VariableAggregation aggregation = VariableAggregation.None,
                                     TensorShape shape = null)
        {
            // Fail clearly instead of with a NullReferenceException on GetType() below.
            if (initial_value is null)
            {
                throw new ArgumentNullException(nameof(initial_value), "initial_value must be specified.");
            }

            // A Func<T> or an IInitializer is evaluated lazily inside the "Initializer" name scope.
            var init_from_fn = initial_value.GetType().Name == "Func`1" ||
                               initial_value.GetType().GetInterface("IInitializer") != null;

            if (collections == null)
            {
                collections = new List <string>()
                {
                    tf.GraphKeys.GLOBAL_VARIABLES
                };
            }

            _trainable = trainable;

            if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES))
            {
                collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);
            }

            _in_graph_mode = !tf.Context.executing_eagerly();
            tf_with(ops.init_scope2(), delegate
            {
                var values = init_from_fn ? new object[0] : new object[] { initial_value };
                tf_with(ops.name_scope(name, "Variable", values), scope =>
                {
                    name               = scope;
                    var handle_name    = ops.name_from_scope_name(name);
                    string unique_id   = "";
                    string shared_name = "";

                    if (_in_graph_mode)
                    {
                        shared_name = handle_name;
                        unique_id   = shared_name;
                    }
                    else
                    {
                        unique_id   = $"{handle_name}_{ops.uid()}";
                        shared_name = tf.Context.shared_name();
                    }

                    var attr  = new AttrValue();
                    attr.List = new AttrValue.Types.ListValue();
                    attr.List.S.Add(ByteString.CopyFromUtf8($"loc:@{handle_name}"));
                    tf_with(ops.name_scope("Initializer"), delegate
                    {
                        if (initial_value.GetType().GetInterface("IInitializer") != null)
                        {
                            initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype)));
                        }
                        else
                        {
                            var value     = init_from_fn ? (initial_value as Func <Tensor>)() : initial_value;
                            initial_value = ops.convert_to_tensor(value,
                                                                  name: "initial_value",
                                                                  dtype: dtype);
                        }
                    });
                    _shape         = shape ?? (initial_value as Tensor).TensorShape;
                    _initial_value = initial_value as Tensor;

                    if (_in_graph_mode)
                    {
                        handle         = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
                        initializer_op = gen_state_ops.assign(handle, _initial_value, true).op;

                        ops.colocate_with(initializer_op);

                        // Named argument: the original `name = "read"` overwrote the captured `name`.
                        _graph_element = gen_array_ops.identity(handle, name: "read");
                        ops.add_to_collections <IVariableV1>(collections, this);
                        _dtype = handle.dtype;
                    }
                    else
                    {
                        handle = resource_variable_ops.eager_safe_variable_handle(
                            initial_value: _initial_value,
                            shape: _shape,
                            shared_name: shared_name,
                            name: name,
                            graph_mode: _in_graph_mode);

                        gen_resource_variable_ops.assign_variable_op(handle, _initial_value);
                        is_initialized_op = null;
                        initializer_op    = null;
                        _graph_element    = null;
                        _dtype            = _initial_value.dtype.as_base_dtype();
                        // NOTE(review): _in_graph_mode is false in this branch, so this always clears the
                        // local parameter; possibly intended to clear the _initial_value field — confirm.
                        initial_value     = _in_graph_mode ? initial_value : null;
                    }

                    base.__init__(trainable: trainable,
                                  handle: handle,
                                  name: name,
                                  unique_id: unique_id,
                                  handle_name: handle_name);
                });
            });
        }
示例#13
0
 /// <summary>
 /// Returns the shapes of the tensors in <paramref name="input"/> by executing the <c>ShapeN</c> op.
 /// </summary>
 /// <param name="input">Tensors whose shapes are requested; passed to the op as a single list input.</param>
 /// <param name="out_type">Element type of the shape tensors (the op's <c>out_type</c> attribute).</param>
 /// <param name="name">Optional name for the operation.</param>
 /// <returns>The outputs of the <c>ShapeN</c> op.</returns>
 public static Tensor[] shape_n(Tensor[] input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
 {
     var opArgs = new ExecuteOpArgs()
     {
         OpInputArgs = new object[] { input }
     };
     return tf.Context.ExecuteOp("ShapeN", name, opArgs.SetAttributes(new { out_type }));
 }
示例#14
0
 /// <summary>
 /// Returns the shape of <paramref name="input"/> by executing the <c>Shape</c> op with the requested <paramref name="out_type"/>.
 /// </summary>
 public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
 {
     var opArgs = new ExecuteOpArgs(input).SetAttributes(new { out_type });
     return tf.Context.ExecuteOp("Shape", name, opArgs);
 }
示例#15
0
 /// <summary>
 /// Creates a constant tensor.
 ///
 /// The resulting tensor is populated with values of type `dtype`, as
 /// specified by arguments `value` and (optionally) `shape`.
 /// Shape verification is disabled and broadcasting is allowed.
 /// </summary>
 /// <param name="value">A constant value (or list) of output type `dtype`.</param>
 /// <param name="dtype">The type of the elements of the resulting tensor.</param>
 /// <param name="shape">Optional dimensions of resulting tensor.</param>
 /// <param name="name">Optional name for the tensor.</param>
 /// <returns>The constant tensor produced by <c>_constant_impl</c>.</returns>
 public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, string name = "Const")
 {
     return _constant_impl(value, dtype, shape, name, verify_shape: false, allow_broadcast: true);
 }
示例#16
0
 /// <summary>
 /// Maps dtype codes above 100 (presumably the reference variants such as <c>*_REF</c>) back to their
 /// base type by subtracting 100; any other code is returned unchanged.
 /// </summary>
 public static TF_DataType as_base_dtype(this TF_DataType type)
 {
     int code = (int)type;
     if (code <= 100)
     {
         return type;
     }
     return (TF_DataType)(code - 100);
 }
示例#17
0
 /// <summary>
 /// Sets the type-valued attribute <paramref name="attrName"/> on an operation description through the C API.
 /// </summary>
 protected void TF_SetAttrType(OperationDescription desc, string attrName, TF_DataType dtype)
 {
     c_api.TF_SetAttrType(desc, attrName, dtype);
 }
示例#18
0
 /// <summary>
 /// Returns the numeric code of <paramref name="type"/>.
 /// </summary>
 public static int name(this TF_DataType type) => (int)type;
示例#19
0
        /// <summary>
        /// Builds a <c>Range</c> op producing a sequence from <paramref name="start"/> towards <paramref name="limit"/>
        /// in steps of <paramref name="delta"/>.
        /// </summary>
        /// <param name="start">First value, or the limit when <paramref name="limit"/> is null (start then becomes 0).</param>
        /// <param name="limit">Limit of the sequence; when null, <paramref name="start"/> is used as the limit.</param>
        /// <param name="delta">Step; defaults to 1 when null.</param>
        /// <param name="dtype">
        /// Element type used when converting start/limit/delta to tensors. Previously this argument was ignored;
        /// <c>DtInvalid</c> presumably lets the type be inferred from the values.
        /// </param>
        /// <param name="name">Name scope for the op.</param>
        /// <returns>The output of <c>gen_math_ops.range</c>.</returns>
        public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range")
        {
            // range(n) is shorthand for range(0, n).
            if (limit == null)
            {
                limit = start;
                start = 0;
            }

            if (delta == null)
            {
                delta = 1;
            }

            return(with(ops.name_scope(name, "Range", new { start, limit, delta }), scope =>
            {
                name = scope;
                var start1 = ops.convert_to_tensor(start, name: "start", dtype: dtype);
                var limit1 = ops.convert_to_tensor(limit, name: "limit", dtype: dtype);
                var delta1 = ops.convert_to_tensor(delta, name: "delta", dtype: dtype);

                return gen_math_ops.range(start1, limit1, delta1, name);
            }));
        }
示例#20
0
 public static string as_numpy_name(this TF_DataType type)
 => type switch
 {
示例#21
0
 /// <summary>
 /// Creates a placeholder tensor by forwarding to <c>gen_array_ops.placeholder</c>.
 /// </summary>
 public Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)
 {
     return gen_array_ops.placeholder(dtype, shape, name);
 }
示例#22
0
 /// <summary>
 /// Describes a dense tensor by its shape, element type, and optional name.
 /// </summary>
 /// <param name="shape">Shape of the described tensor.</param>
 /// <param name="dtype">Element type; defaults to <c>TF_FLOAT</c>.</param>
 /// <param name="name">Optional name.</param>
 public DenseSpec(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
 {
     this._name  = name;
     this._dtype = dtype;
     this._shape = shape;
 }
示例#23
0
 /// <summary>
 /// Creates a numeric sequence by forwarding every argument to <c>math_ops.range</c>.
 /// </summary>
 public Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range")
 {
     return math_ops.range(start, limit: limit, delta: delta, dtype: dtype, name: name);
 }
示例#24
0
        /// <summary>
        /// Create a TensorProto from a TensorProto, an NDArray, or a supported scalar/array value.
        /// </summary>
        /// <param name="values">
        /// An existing TensorProto (returned as-is), an NDArray, or a bool / int / int[] / float / double / string / string[] value.
        /// </param>
        /// <param name="dtype">
        /// Requested element type. It is only consulted for quantized types, which then override the type inferred from the values.
        /// </param>
        /// <param name="shape">Optional shape; defaults to the shape of the converted values.</param>
        /// <param name="verify_shape">When true and <paramref name="shape"/> is given, the values' shape must equal it exactly.</param>
        /// <param name="allow_broadcast">May not be combined with <paramref name="verify_shape"/>.</param>
        /// <returns>The populated TensorProto.</returns>
        /// <exception cref="ValueError">When both flags are set, or <paramref name="values"/> is null.</exception>
        /// <exception cref="TypeError">When the data type is unrecognized, or the shape check requested by <paramref name="verify_shape"/> fails.</exception>
        /// <exception cref="NotImplementedException">When the CLR type of the values is not supported.</exception>
        public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, bool verify_shape = false, bool allow_broadcast = false)
        {
            if (allow_broadcast && verify_shape)
            {
                throw new ValueError("allow_broadcast and verify_shape are not both allowed.");
            }
            if (values is TensorProto tp)
            {
                return(tp);
            }

            bool is_quantized = new TF_DataType[]
            {
                TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16,
                TF_DataType.TF_QINT32
            }.Contains(dtype);

            // We first convert value to a numpy array or scalar.
            NDArray nparray = null;

            if (values is NDArray nd)
            {
                nparray = nd;
            }
            else
            {
                if (values == null)
                {
                    throw new ValueError("None values not supported.");
                }

                switch (values)
                {
                case bool boolVal:
                    nparray = boolVal;
                    break;

                case int intVal:
                    nparray = intVal;
                    break;

                case int[] intVals:
                    nparray = np.array(intVals);
                    break;

                case float floatVal:
                    nparray = floatVal;
                    break;

                case double doubleVal:
                    nparray = doubleVal;
                    break;

                case string strVal:
                    nparray = strVal;
                    break;

                case string[] strVals:
                    nparray = strVals;
                    break;

                default:
                    throw new NotImplementedException("make_tensor_proto Not Implemented");
                }
            }

            var numpy_dtype = dtypes.as_dtype(nparray.dtype);

            if (numpy_dtype == TF_DataType.DtInvalid)
            {
                throw new TypeError($"Unrecognized data type: {nparray.dtype}");
            }

            // If dtype was specified and is a quantized type, we convert
            // numpy_dtype back into the quantized version.
            if (is_quantized)
            {
                numpy_dtype = dtype;
            }

            bool is_same_size = false;
            int  shape_size   = 0;

            // If shape is not given, get the shape from the numpy array.
            if (shape == null)
            {
                shape        = nparray.shape;
                is_same_size = true;
                shape_size   = nparray.size;
            }
            else
            {
                // verify_shape was previously accepted but never checked.
                if (verify_shape && !nparray.shape.SequenceEqual(shape))
                {
                    throw new TypeError($"Expected Tensor's shape: ({string.Join(", ", shape)}), got ({string.Join(", ", nparray.shape)}).");
                }
                shape_size   = new TensorShape(shape).Size;
                is_same_size = shape_size == nparray.size;
            }

            var tensor_proto = new tensor_pb2.TensorProto
            {
                Dtype       = numpy_dtype.as_datatype_enum(),
                TensorShape = tensor_util.as_shape(shape)
            };

            // Multi-element buffers of content-capable types are stored as raw bytes.
            if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1)
            {
                byte[] bytes = nparray.ToByteArray();
                tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
                return(tensor_proto);
            }

            if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray))
            {
                if (values is string str)
                {
                    tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str));
                }
                else if (values is string[] str_values)
                {
                    tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x)));
                }
                return(tensor_proto);
            }

            var proto_values = nparray.ravel();

            // nparray.dtype is a CLR Type, whose Name for bool is "Boolean"; "Bool" is kept for compatibility.
            switch (nparray.dtype.Name)
            {
            case "Bool":
            case "Boolean":
                tensor_proto.BoolVal.AddRange(proto_values.Data <bool>());
                break;

            case "Int32":
                tensor_proto.IntVal.AddRange(proto_values.Data <int>());
                break;

            case "Single":
                tensor_proto.FloatVal.AddRange(proto_values.Data <float>());
                break;

            case "Double":
                tensor_proto.DoubleVal.AddRange(proto_values.Data <double>());
                break;

            case "String":
                tensor_proto.StringVal.AddRange(proto_values.Data <string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x)));
                break;

            default:
                throw new NotImplementedException("make_tensor_proto Not Implemented");
            }

            return(tensor_proto);
        }
示例#25
0
 /// <summary>
 /// Computes the arg-max of <paramref name="input"/> along an axis by forwarding to <c>gen_math_ops.arg_max</c>.
 /// </summary>
 /// <param name="input">Input tensor.</param>
 /// <param name="axis">Axis to reduce over; defaults to -1.</param>
 /// <param name="name">Optional operation name.</param>
 /// <param name="dimension">Deprecated alias for <paramref name="axis"/>. When supplied, it takes precedence (previously it was ignored).</param>
 /// <param name="output_type">Integer type of the returned indices.</param>
 public Tensor argmax(Tensor input, int axis = -1, string name = null, int?dimension = null, TF_DataType output_type = TF_DataType.TF_INT64)
 => gen_math_ops.arg_max(input, dimension ?? axis, name: name, output_type: output_type);
示例#26
0
 /// <summary>
 /// Wraps raw <paramref name="bytes"/> as an <see cref="NDArray"/> with the given shape and dtype.
 /// </summary>
 public NDArray frombuffer(byte[] bytes, Shape shape, TF_DataType dtype)
     => new NDArray(bytes, shape, dtype);
示例#27
0
        /// <summary>
        /// Builds a <c>ShapeN</c> op for <paramref name="input"/> via <c>_apply_op_helper</c> and returns all of its outputs.
        /// </summary>
        /// <param name="input">Tensors whose shapes are requested.</param>
        /// <param name="out_type">Element type of the shape tensors (the op's <c>out_type</c> attribute).</param>
        /// <param name="name">Optional operation name.</param>
        /// <returns>The op's output tensors.</returns>
        public static Tensor[] shape_n(Tensor[] input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
            => _op_def_lib._apply_op_helper("ShapeN", name, new { input, out_type }).outputs;
示例#28
0
        /// <summary>
        /// Decodes an encoded image whose format (JPEG, PNG, GIF or BMP) is detected at
        /// graph run time from its leading bytes, dispatching through nested conds:
        /// JPEG -> PNG -> GIF -> BMP (BMP branch also asserts the "BM" signature).
        /// </summary>
        /// <param name="contents">Tensor holding the encoded image bytes.</param>
        /// <param name="channels">
        /// Requested number of color channels. It is forwarded to the JPEG and PNG decoders
        /// and validated per format by runtime asserts (JPEG rejects 4; GIF rejects 1 and 4;
        /// BMP rejects 1).
        /// </param>
        /// <param name="dtype">Requested output dtype; every branch ends in <c>convert_image_dtype(..., dtype)</c>.</param>
        /// <param name="name">Optional name scope; defaults to "decode_image".</param>
        /// <param name="expand_animations">
        /// When false, the GIF branch throws <see cref="NotImplementedException"/>
        /// (first-frame extraction is not implemented yet).
        /// </param>
        /// <returns>The output of the outermost `cond_jpeg` cond.</returns>
        public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8,
                                          string name = null, bool expand_animations           = true)
        {
            // First 3 bytes of `contents`; assigned inside the name scope below and read
            // lazily by check_gif when the cond graph is built.
            Tensor substr = null;

            Func <ITensorOrOperation> _jpeg = () =>
            {
                var    good_channels   = math_ops.not_equal(channels, 4, name: "check_jpeg_channels");
                string channels_msg    = "Channels must be in (None, 0, 1, 3) when decoding JPEG 'images'";
                var    assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
                return(tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
                {
                    return convert_image_dtype(gen_image_ops.decode_jpeg(contents, channels), dtype);
                }));
            };

            Func <ITensorOrOperation> _gif = () =>
            {
                var good_channels = math_ops.logical_and(
                    math_ops.not_equal(channels, 1, name: "check_gif_channels"),
                    math_ops.not_equal(channels, 4, name: "check_gif_channels"));

                string channels_msg    = "Channels must be in (None, 0, 3) when decoding GIF images";
                var    assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
                return(tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
                {
                    var result = convert_image_dtype(gen_image_ops.decode_gif(contents), dtype);
                    if (!expand_animations)
                    {
                        // TODO: port `result = array_ops.gather(result, 0)` to keep only the first frame.
                        throw new NotImplementedException("decode_image: expand_animations = false is not supported for GIF images yet.");
                    }
                    return result;
                }));
            };

            Func <ITensorOrOperation> _bmp = () =>
            {
                var    signature       = string_ops.substr(contents, 0, 2);
                var    is_bmp          = math_ops.equal(signature, "BM", name: "is_bmp");
                string decode_msg      = "Unable to decode bytes as JPEG, PNG, GIF, or BMP";
                var    assert_decode   = control_flow_ops.Assert(is_bmp, new string[] { decode_msg });
                var    good_channels   = math_ops.not_equal(channels, 1, name: "check_channels");
                string channels_msg    = "Channels must be in (None, 0, 3) when decoding BMP images";
                var    assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
                return(tf_with(ops.control_dependencies(new[] { assert_decode, assert_channels }), delegate
                {
                    return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype);
                }));
            };

            Func <ITensorOrOperation> _png = () =>
            {
                // The DecodePng op's `dtype` attr only accepts uint8 or uint16, so passing the
                // caller's dtype through (e.g. float32) would fail. Decode to uint8 when uint8
                // is requested, otherwise to uint16 (keeps 16-bit precision), and let
                // convert_image_dtype produce the requested dtype. Mirrors the Python reference.
                var png_dtype = dtype == TF_DataType.TF_UINT8 ? TF_DataType.TF_UINT8 : TF_DataType.TF_UINT16;
                return(convert_image_dtype(gen_image_ops.decode_png(
                                               contents,
                                               channels,
                                               dtype: png_dtype),
                                           dtype));
            };

            Func <ITensorOrOperation> check_gif = () =>
            {
                // GIF files start with the ASCII bytes "GIF".
                var is_gif = math_ops.equal(substr, "\x47\x49\x46", name: "is_gif");
                return(control_flow_ops.cond(is_gif, _gif, _bmp, name: "cond_gif"));
            };

            Func <ITensorOrOperation> check_png = () =>
            {
                return(control_flow_ops.cond(_is_png(contents), _png, check_gif, name: "cond_png"));
            };

            return(tf_with(ops.name_scope(name, "decode_image"), scope =>
            {
                substr = string_ops.substr(contents, 0, 3);
                return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg");
            }));
        }
示例#29
0
        /*public Operation(Graph g, string opType, string oper_name)
         * {
         *  _graph = g;
         *
         *  var _operDesc = c_api.TF_NewOperation(g, opType, oper_name);
         *  c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32);
         *  lock (Locks.ProcessWide)
         *      using (var status = new Status())
         *      {
         *          _handle = c_api.TF_FinishOperation(_operDesc, status);
         *          status.Check(true);
         *      }
         *
         *  // Dict mapping op name to file and line information for op colocation
         *  // context managers.
         *  _control_flow_context = graph._get_control_flow_context();
         * }*/

        /// <summary>
        /// Creates an `Operation` in graph <paramref name="g"/> from a `NodeDef`. The
        /// constructor resolves the control inputs to `Operation`s and takes the next id
        /// from the graph. It then builds the underlying C op, creates one output `Tensor`
        /// per op output, and registers the op with the graph.
        /// </summary>
        /// <param name="node_def">`node_def_pb2.NodeDef`.  `NodeDef` for the `Operation`.</param>
        /// <param name="g">`Graph`. The parent graph.</param>
        /// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param>
        /// <param name="output_types">
        /// list of `DType` objects. NOTE(review): any value passed here is ignored. The
        /// parameter is overwritten with the output types reported by the created C op.
        /// </param>
        /// <param name="control_inputs">
        /// list of operations or tensors from which to have a
        /// control dependency. A `Tensor` contributes its producing `op`.
        /// </param>
        /// <param name="input_types">
        /// List of `DType` objects representing the types of the tensors accepted by the
        /// `Operation` (ported from the Python API, where it defaults to the inputs' base
        /// dtypes). NOTE(review): not read by this constructor.
        /// </param>
        /// <param name="original_op">NOTE(review): not read by this constructor.</param>
        /// <param name="op_def">`OpDef` for the op; when null it is looked up with `g.GetOpDef(node_def.Op)`.</param>
        /// <exception cref="NotImplementedException">
        /// A control input is neither an `Operation` nor a `Tensor`.
        /// </exception>
        public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
        {
            _graph = g;
            // Build the list of control inputs, normalized to Operations.
            var control_input_ops = new List <Operation>();

            if (control_inputs != null)
            {
                foreach (var c in control_inputs)
                {
                    switch (c)
                    {
                    case Operation c1:
                        control_input_ops.Add(c1);
                        break;

                    case Tensor tensor:
                        // A tensor dependency is expressed through the op that produces it.
                        control_input_ops.Add(tensor.op);
                        break;

                    // TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented
                    //case IndexedSlices islices:
                    //    control_input_ops.Add(islices.op);
                    //    break;
                    default:
                        throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
                    }
                }
            }

            // Take the next id from the graph before the C op is created.
            _id_value = _graph._next_id();

            // Capture the graph's control-flow context that is current at construction time.
            _control_flow_context = graph._get_control_flow_context();

            // Look up the OpDef by op type when the caller did not supply one.
            if (op_def == null)
            {
                op_def = g.GetOpDef(node_def.Op);
            }

            // NOTE(review): presumably regroups the flat `inputs` to match the OpDef's
            // (possibly list-typed) input args, using node_def.Attr. Confirm in _reconstruct_sequence_inputs.
            var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);

            // Create the underlying C op with the grouped inputs and the control inputs.
            (_handle, OpDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
            _is_stateful      = op_def.IsStateful;

            // Initialize self._outputs. Output types are read back from the created C op,
            // replacing whatever the caller passed as `output_types`.
            output_types = new TF_DataType[NumOutputs];
            for (int i = 0; i < NumOutputs; i++)
            {
                output_types[i] = OutputType(i);
            }

            _outputs = new Tensor[NumOutputs];
            for (int i = 0; i < NumOutputs; i++)
            {
                _outputs[i] = new Tensor(this, i, output_types[i]);
            }

            // Register the constructed op with its graph.
            graph._add_op(this);

            // Control-flow post-processing only runs when a C handle was actually created.
            if (_handle != IntPtr.Zero)
            {
                _control_flow_post_processing();
            }
        }
示例#30
0
        /// <summary>
        /// Builds an `ArgMax` op: the index of the largest value of <paramref name="input"/>
        /// along <paramref name="dimension"/>.
        /// </summary>
        /// <param name="input">Tensor to search.</param>
        /// <param name="dimension">Dimension to reduce across.</param>
        /// <param name="output_type">Element type of the returned indices (default int64).</param>
        /// <param name="name">Optional name for the created op.</param>
        /// <returns>The first output of the created `ArgMax` op.</returns>
        public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null)
            => _op_def_lib._apply_op_helper("ArgMax", name, new { input, dimension, output_type }).outputs[0];