/// <summary>
/// Reads the current value of a resource variable by executing the <c>ReadVariableOp</c> op.
/// </summary>
/// <param name="resource">Handle tensor of the resource variable to read.</param>
/// <param name="dtype">Element type of the stored value; forwarded as the op's <c>dtype</c> attribute.</param>
/// <param name="name">Optional name for the operation.</param>
/// <returns>The tensor produced by <c>ReadVariableOp</c>.</returns>
public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string name = null)
{
    var opArgs = new ExecuteOpArgs(resource).SetAttributes(new { dtype });
    return tf.Context.ExecuteOp("ReadVariableOp", name, opArgs);
}
/// <summary>
/// Index-of-minimum along <paramref name="dimension"/>; thin wrapper over <c>gen_math_ops.arg_min</c>.
/// </summary>
/// <param name="input">Tensor to search.</param>
/// <param name="dimension">Axis to reduce over.</param>
/// <param name="output_type">Integer type of the returned indices.</param>
/// <param name="name">Optional name for the operation.</param>
/// <returns>The result of <c>gen_math_ops.arg_min</c>.</returns>
public Tensor arg_min(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null)
{
    return gen_math_ops.arg_min(input, dimension, output_type: output_type, name: name);
}
/// <summary>
/// Builds a <c>Size</c> op for <paramref name="input"/> through the op-def helper.
/// </summary>
/// <param name="input">Tensor whose element count is requested.</param>
/// <param name="out_type">Element type of the result, forwarded as the <c>out_type</c> attribute.</param>
/// <param name="name">Optional name for the operation.</param>
/// <returns>The first output of the created <c>Size</c> op.</returns>
public static Tensor size(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
{
    var sizeOp = _op_def_lib._apply_op_helper("Size", name, new { input, out_type });
    var outputs = sizeOp.outputs;
    return outputs[0];
}
/// <summary>
/// Executes an op eagerly on the "fast path": positional inputs and attribute
/// name/value pairs are read from <paramref name="args"/>, applied to a TFE op,
/// and the op is run immediately.
/// </summary>
/// <param name="ctx">Eager context; must not be null.</param>
/// <param name="device_name">Device passed to <c>TFE_OpSetDevice</c>.</param>
/// <param name="opName">Op type name, used to fetch the op and its <c>OpDef</c>.</param>
/// <param name="name">Name recorded in the exec info handed to callbacks.</param>
/// <param name="callbacks">When non-null, post-execution callbacks are enabled.</param>
/// <param name="args">
/// Starting at <c>kFastPathExecuteInputStartIndex</c>: one entry per input arg of the op
/// (list-valued inputs as <c>object[]</c>), followed by alternating attribute name / value entries.
/// </param>
/// <returns>The op's outputs, or <c>null</c> if <c>AddInputToOp</c> reports failure.</returns>
/// <exception cref="ValueError"><paramref name="ctx"/> is null.</exception>
/// <exception cref="RuntimeError">An output list size derived from the attributes is negative.</exception>
public Tensor[] TFE_FastPathExecute(Context ctx, string device_name, string opName, string name, Action callbacks, params object[] args)
{
    if (ctx == null)
    {
        throw new ValueError("This function does not handle the case of the path where " +
            "all inputs are not already EagerTensors.");
    }

    int args_size = args.Length;
    var attr_list_sizes = new Dictionary<string, long>();

    var op_exec_info = new FastPathOpExecInfo()
    {
        ctx = ctx,
        args = args,
        device_name = device_name,
        op_name = opName,
        name = name,
    };

    op_exec_info.run_gradient_callback = HasAccumulatorOrTape();
    op_exec_info.run_post_exec_callbacks = callbacks != null;
    op_exec_info.run_callbacks = op_exec_info.run_gradient_callback || op_exec_info.run_post_exec_callbacks;

    var status = tf.Status;
    var op = GetOp(ctx, opName, status);

    var op_def = tf.get_default_graph().GetOpDef(opName);

    var flattened_attrs = new List<object>(op_def.InputArg.Count);
    var flattened_inputs = new List<Tensor>(op_def.InputArg.Count);

    // Set non-inferred attrs, including setting defaults if the attr is passed in
    // as None. They follow the positional inputs as (name, value) pairs.
    for (int i = kFastPathExecuteInputStartIndex + op_def.InputArg.Count; i < args_size; i += 2)
    {
        var attr_name = args[i].ToString();
        var attr_value = args[i + 1];

        var attr = op_def.Attr.FirstOrDefault(x => x.Name == attr_name);
        if (attr != null)
        {
            flattened_attrs.Add(attr_name);
            flattened_attrs.Add(attr_value);

            SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status);
            status.Check(true);
        }
    }

    c_api.TFE_OpSetDevice(op, device_name, status.Handle);
    status.Check(true);

    // Add inferred attrs and inputs.
    for (int i = 0; i < op_def.InputArg.Count; i++)
    {
        var input = args[kFastPathExecuteInputStartIndex + i];
        var input_arg = op_def.InputArg[i];

        if (!string.IsNullOrEmpty(input_arg.NumberAttr))
        {
            // Homogeneous list input: its length becomes the number attr.
            var fast_input_array = input as object[];
            int len = fast_input_array.Length;
            c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len);
            if (op_exec_info.run_callbacks)
            {
                flattened_attrs.Add(input_arg.NumberAttr);
                flattened_attrs.Add(len);
            }
            attr_list_sizes[input_arg.NumberAttr] = len;

            if (len > 0)
            {
                // First item adds the type attr. Index 0 of the list — not the
                // input-arg index `i`, which skipped/duplicated elements for i > 0.
                if (!AddInputToOp(fast_input_array[0], true, input_arg, flattened_attrs, flattened_inputs, op, status))
                {
                    return null;
                }

                for (var j = 1; j < len; j++)
                {
                    // Since the list is homogeneous, we don't need to re-add the attr.
                    if (!AddInputToOp(fast_input_array[j], false, input_arg, flattened_attrs, flattened_inputs, op, status))
                    {
                        return null;
                    }
                }
            }
        }
        else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
        {
            // Heterogeneous list input: collect each element's dtype into a type-list attr.
            var attr_name = input_arg.TypeListAttr;
            var fast_input_array = input as object[];
            var len = fast_input_array.Length;
            var attr_values = new TF_DataType[len];

            for (var j = 0; j < len; j++)
            {
                var eager_tensor = ops.convert_to_tensor(fast_input_array[j]);
                attr_values[j] = eager_tensor.dtype;

                c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status.Handle);
                status.Check(true);

                if (op_exec_info.run_callbacks)
                {
                    flattened_inputs.Add(eager_tensor);
                }
            }

            if (op_exec_info.run_callbacks)
            {
                flattened_attrs.Add(attr_name);
                flattened_attrs.Add(attr_values);
            }
            c_api.TFE_OpSetAttrTypeList(op, attr_name, attr_values, attr_values.Length);
            attr_list_sizes[attr_name] = len;
        }
        else
        {
            // The item is a single item (read at the same offset as the list branches).
            if (!AddInputToOp(input, true, input_arg, flattened_attrs, flattened_inputs, op, status))
            {
                return null;
            }
        }
    }

    // Each output arg contributes one handle, or N handles for list outputs whose
    // length was recorded in attr_list_sizes above.
    int num_retvals = 0;
    for (int i = 0; i < op_def.OutputArg.Count; i++)
    {
        var output_arg = op_def.OutputArg[i];
        var delta = 1L;
        if (!string.IsNullOrEmpty(output_arg.NumberAttr))
        {
            delta = attr_list_sizes[output_arg.NumberAttr];
        }
        else if (!string.IsNullOrEmpty(output_arg.TypeListAttr))
        {
            delta = attr_list_sizes[output_arg.TypeListAttr];
        }

        if (delta < 0)
        {
            throw new RuntimeError("Attributes suggest that the size of an output list is less than 0");
        }

        num_retvals += (int)delta;
    }

    var retVals = new SafeTensorHandleHandle[num_retvals];
    c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle);
    status.Check(true);

    var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray();

    if (op_exec_info.run_callbacks)
    {
        RunCallbacks(op_exec_info,
            kFastPathExecuteInputStartIndex + op_def.InputArg.Count,
            flattened_inputs.ToArray(), flattened_attrs.ToArray(), flat_result);
    }

    return flat_result;
}
/// <summary>
/// Casts <paramref name="x"/> to <paramref name="dtype"/>; thin wrapper over <c>math_ops.cast</c>.
/// </summary>
/// <param name="x">Tensor to cast.</param>
/// <param name="dtype">Target element type.</param>
/// <param name="name">Optional name for the operation.</param>
/// <returns>The result of <c>math_ops.cast</c>.</returns>
public Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
{
    return math_ops.cast(x, dtype, name);
}
/// <summary>
/// Sets the type-valued attribute <paramref name="attr_name"/> on an eager op via <c>c_api.TFE_OpSetAttrType</c>.
/// </summary>
/// <param name="op">Eager op handle.</param>
/// <param name="attr_name">Attribute to set.</param>
/// <param name="value">Data type to assign.</param>
protected void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value)
{
    c_api.TFE_OpSetAttrType(op, attr_name, value);
}
/// <summary>
/// Builds the graph-mode variable op, its initializer (an <c>assign</c> of the initial
/// value) and its "read" snapshot, then registers the variable in <paramref name="collections"/>.
/// </summary>
/// <param name="initial_value">A value convertible to a tensor, or a <c>Func&lt;Tensor&gt;</c> that produces it.</param>
/// <param name="trainable">When true, the variable is also added to TRAINABLE_VARIABLES.</param>
/// <param name="collections">Collections to add the variable to; defaults to GLOBAL_VARIABLES.</param>
/// <param name="validate_shape">When true, the initial value's shape must be fully defined; also forwarded to <c>assign</c>.</param>
/// <param name="caching_device">When non-empty, the snapshot is created without colocating with the initializer.</param>
/// <param name="name">Optional name; replaced by the resolved name scope.</param>
/// <param name="dtype">Requested element type for the initial value (DtInvalid to infer).</param>
/// <exception cref="ValueError">
/// <paramref name="initial_value"/> is null, or <paramref name="validate_shape"/> is set and its shape is not fully defined.
/// </exception>
private void _init_from_args(object initial_value, bool trainable = true, List<string> collections = null, bool validate_shape = true, string caching_device = "", string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
    if (initial_value is null)
    {
        throw new ValueError("initial_value must be specified.");
    }

    var init_from_fn = initial_value.GetType().Name == "Func`1";

    if (collections == null)
    {
        collections = new List<string> { ops.GraphKeys.GLOBAL_VARIABLES };
    }

    // Store the graph key so optimizers know how to only retrieve variables from
    // this graph.
    _graph_key = ops.get_default_graph()._graph_key;

    _trainable = trainable;
    if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES))
    {
        collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES);
    }

    ops.init_scope();
    var values = init_from_fn ? new object[0] : new object[] { initial_value };
    with(ops.name_scope(name, "Variable", values), scope =>
    {
        name = scope;
        if (init_from_fn)
        {
            // TODO: the Python implementation attaches a `_class = ["loc:@<name>"]`
            // colocation attr via attr_scope here; that is not applied in this port.
            with(ops.name_scope("Initializer"), scope2 =>
            {
                _initial_value = (initial_value as Func<Tensor>)();
                _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype);
            });
            _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
        }
        // Or get the initial value from a Tensor or Python object.
        else
        {
            // Honor the requested dtype, as the function branch above already does.
            _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype);
            var shape = _initial_value.shape;
            dtype = _initial_value.dtype;
            _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope);
        }

        // The initial value must have a fully defined shape when shape validation is on.
        if (validate_shape)
        {
            var initial_value_shape = _initial_value.GetShape();
            if (!initial_value_shape.is_fully_defined())
            {
                throw new ValueError($"initial_value must have a shape specified: {_initial_value}");
            }
        }

        // If 'initial_value' makes use of other variables, make sure we don't
        // have an issue if these other variables aren't initialized first by
        // using their initialized_value() method.
        var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value);

        _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;

        // Named argument (not `name = "read"`, which would overwrite the captured `name`).
        if (!string.IsNullOrEmpty(caching_device))
        {
            // TODO: place the snapshot on caching_device (Python wraps this in ops.device(caching_device)).
            // Previously this branch was empty and `_snapshot` was never created.
            _snapshot = gen_array_ops.identity(_variable, name: "read");
        }
        else
        {
            ops.colocate_with(_initializer_op);
            _snapshot = gen_array_ops.identity(_variable, name: "read");
        }

        ops.add_to_collections(collections, this as VariableV1);
    });
}
/// <summary>
/// Maps a CLR type to its <see cref="TF_DataType"/>. Array types are unwrapped to their
/// innermost element type before the lookup.
/// </summary>
/// <param name="type">The CLR type to map (e.g. <c>typeof(float[])</c>).</param>
/// <returns>
/// The matching dtype, or <see cref="TF_DataType.DtInvalid"/> when the element type has no
/// mapping. No exception is thrown for unmapped types.
/// </returns>
public static TF_DataType as_tf_dtype(this Type type)
{
    var element = type;
    while (element.IsArray)
    {
        element = element.GetElementType();
    }

    // Dispatch on the simple type name, so e.g. any type named "Complex" matches.
    // NOTE(review): System.Char is 16-bit, yet it maps to TF_UINT8 — confirm this is intended.
    return element.Name switch
    {
        "Char" => TF_DataType.TF_UINT8,
        "SByte" => TF_DataType.TF_INT8,
        "Byte" => TF_DataType.TF_UINT8,
        "Int16" => TF_DataType.TF_INT16,
        "UInt16" => TF_DataType.TF_UINT16,
        "Int32" => TF_DataType.TF_INT32,
        "UInt32" => TF_DataType.TF_UINT32,
        "Int64" => TF_DataType.TF_INT64,
        "UInt64" => TF_DataType.TF_UINT64,
        "Single" => TF_DataType.TF_FLOAT,
        "Double" => TF_DataType.TF_DOUBLE,
        "Complex" => TF_DataType.TF_COMPLEX128,
        "String" => TF_DataType.TF_STRING,
        "Boolean" => TF_DataType.TF_BOOL,
        _ => TF_DataType.DtInvalid,
    };
}
/// <summary>
/// Converts a <see cref="TF_DataType"/> to the protobuf <c>DataType</c> enum by a direct numeric cast.
/// </summary>
/// <param name="type">The dtype to convert.</param>
/// <returns>The <c>DataType</c> with the same underlying value.</returns>
public static DataType as_datatype_enum(this TF_DataType type) => (DataType)type;
/// <summary>
/// Creates a placeholder tensor; delegates to <c>gen_array_ops.placeholder</c>.
/// </summary>
/// <param name="dtype">Element type of the placeholder.</param>
/// <param name="shape">Optional shape of the placeholder.</param>
/// <param name="name">Optional name for the operation.</param>
/// <returns>The placeholder tensor.</returns>
// The previous `unsafe` modifier was unnecessary: no pointer code runs here.
public Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)
    => gen_array_ops.placeholder(dtype, shape, name);
/// <summary> /// Creates a recurrent neural network specified by RNNCell `cell`. /// </summary> /// <param name="cell">An instance of RNNCell.</param> /// <param name="inputs">The RNN inputs.</param> /// <param name="dtype"></param> /// <param name="swap_memory"></param> /// <param name="time_major"></param> /// <returns>A pair (outputs, state)</returns> public (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid, int?parallel_iterations = null, bool swap_memory = false, bool time_major = false) => rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype,
/// <summary>
/// Initializes a resource variable from <paramref name="initial_value"/>. Graph mode builds a
/// variable op, an assign initializer and a "read" identity; eager mode creates a handle via
/// <c>eager_safe_variable_handle</c> and assigns the value immediately.
/// </summary>
/// <param name="initial_value">A tensor-convertible value, a <c>Func&lt;Tensor&gt;</c>, or an <c>IInitializer</c>.</param>
/// <param name="trainable">When true, the variable is also added to TRAINABLE_VARIABLES.</param>
/// <param name="collections">Collections to add the variable to (graph mode); defaults to GLOBAL_VARIABLES.</param>
/// <param name="caching_device">Currently unused by this method.</param>
/// <param name="name">Optional name; replaced by the resolved name scope.</param>
/// <param name="dtype">Requested element type for the initial value.</param>
/// <param name="aggregation">Currently unused by this method.</param>
/// <param name="shape">Shape used by an <c>IInitializer</c>; otherwise taken from the initial value.</param>
/// <exception cref="ArgumentNullException"><paramref name="initial_value"/> is null.</exception>
private void _init_from_args(object initial_value = null, bool trainable = true, List<string> collections = null, string caching_device = "", string name = null, TF_DataType dtype = TF_DataType.DtInvalid, VariableAggregation aggregation = VariableAggregation.None, TensorShape shape = null)
{
    // The default is null, which previously surfaced as a NullReferenceException from GetType().
    if (initial_value is null)
    {
        throw new ArgumentNullException(nameof(initial_value), "initial_value must be specified.");
    }

    var is_initializer = initial_value.GetType().GetInterface("IInitializer") != null;
    var init_from_fn = initial_value.GetType().Name == "Func`1" || is_initializer;

    if (collections == null)
    {
        collections = new List<string>() { tf.GraphKeys.GLOBAL_VARIABLES };
    }

    _trainable = trainable;
    if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES))
    {
        collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);
    }

    _in_graph_mode = !tf.Context.executing_eagerly();
    tf_with(ops.init_scope2(), delegate
    {
        var values = init_from_fn ? new object[0] : new object[] { initial_value };
        tf_with(ops.name_scope(name, "Variable", values), scope =>
        {
            name = scope;
            var handle_name = ops.name_from_scope_name(name);
            string unique_id;
            string shared_name;
            if (_in_graph_mode)
            {
                shared_name = handle_name;
                unique_id = shared_name;
            }
            else
            {
                unique_id = $"{handle_name}_{ops.uid()}";
                shared_name = tf.Context.shared_name();
            }

            // TODO: the Python implementation attaches a `_class = ["loc:@<handle_name>"]`
            // colocation attr here; that is not applied in this port.

            tf_with(ops.name_scope("Initializer"), delegate
            {
                if (is_initializer)
                {
                    initial_value = ops.convert_to_tensor((initial_value as IInitializer).Apply(new InitializerArgs(shape, dtype: dtype)));
                }
                else
                {
                    var value = init_from_fn ? (initial_value as Func<Tensor>)() : initial_value;
                    initial_value = ops.convert_to_tensor(value, name: "initial_value", dtype: dtype);
                }
            });

            _shape = shape ?? (initial_value as Tensor).TensorShape;
            _initial_value = initial_value as Tensor;

            if (_in_graph_mode)
            {
                handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
                initializer_op = gen_state_ops.assign(handle, _initial_value, true).op;

                ops.colocate_with(initializer_op);

                // Named argument: `name = "read"` overwrote the local `name`, so
                // base.__init__ below received "read" as the variable name.
                _graph_element = gen_array_ops.identity(handle, name: "read");
                ops.add_to_collections<IVariableV1>(collections, this);
                _dtype = handle.dtype;
            }
            else
            {
                handle = resource_variable_ops.eager_safe_variable_handle(
                    initial_value: _initial_value,
                    shape: _shape,
                    shared_name: shared_name,
                    name: name,
                    graph_mode: _in_graph_mode);

                gen_resource_variable_ops.assign_variable_op(handle, _initial_value);
                is_initialized_op = null;
                initializer_op = null;
                _graph_element = null;
                _dtype = _initial_value.dtype.as_base_dtype();
                // NOTE(review): Python sets `_initial_value = None` in eager mode; the previous
                // port only reassigned the local parameter (a no-op), so that is left unchanged.
            }

            base.__init__(trainable: trainable, handle: handle, name: name, unique_id: unique_id, handle_name: handle_name);
        });
    });
}
/// <summary>
/// Returns the shapes of several tensors by executing the <c>ShapeN</c> op.
/// </summary>
/// <param name="input">Tensors whose shapes are requested; passed to the op as one list input.</param>
/// <param name="out_type">Element type of the shape tensors, forwarded as the <c>out_type</c> attribute.</param>
/// <param name="name">Optional name for the operation.</param>
/// <returns>The op's outputs (presumably one shape tensor per input).</returns>
public static Tensor[] shape_n(Tensor[] input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
{
    var opArgs = new ExecuteOpArgs
    {
        OpInputArgs = new object[] { input }
    };
    return tf.Context.ExecuteOp("ShapeN", name, opArgs.SetAttributes(new { out_type }));
}
/// <summary>
/// Returns the shape of <paramref name="input"/> by executing the <c>Shape</c> op.
/// </summary>
/// <param name="input">Tensor whose shape is requested.</param>
/// <param name="out_type">Element type of the result, forwarded as the <c>out_type</c> attribute.</param>
/// <param name="name">Optional name for the operation.</param>
/// <returns>The op's output tensor.</returns>
public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
{
    var opArgs = new ExecuteOpArgs(input).SetAttributes(new { out_type });
    return tf.Context.ExecuteOp("Shape", name, opArgs);
}
/// <summary>
/// Creates a constant tensor populated with values of type <paramref name="dtype"/>,
/// as specified by <paramref name="value"/> and (optionally) <paramref name="shape"/>.
/// </summary>
/// <param name="value">A constant value (or list) of output type <paramref name="dtype"/>.</param>
/// <param name="dtype">The type of the elements of the resulting tensor.</param>
/// <param name="shape">Optional dimensions of the resulting tensor.</param>
/// <param name="name">Optional name for the tensor.</param>
/// <returns>
/// The tensor built by <c>_constant_impl</c>, always called with shape verification off and
/// broadcasting allowed.
/// </returns>
public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, string name = "Const")
    => _constant_impl(value, dtype, shape, name, verify_shape: false, allow_broadcast: true);
/// <summary>
/// Strips the "reference" offset from a dtype. Values above 100 are treated as reference
/// dtypes encoded as base value + 100 (presumably TensorFlow's DT_*_REF numbering).
/// </summary>
/// <param name="type">The dtype to normalize.</param>
/// <returns><paramref name="type"/> minus 100 when it exceeds 100; otherwise <paramref name="type"/> unchanged.</returns>
public static TF_DataType as_base_dtype(this TF_DataType type)
{
    var code = (int)type;
    if (code <= 100)
    {
        return type;
    }

    return (TF_DataType)(code - 100);
}
/// <summary>
/// Sets the type-valued attribute <paramref name="attrName"/> on an operation under construction
/// via <c>c_api.TF_SetAttrType</c>.
/// </summary>
/// <param name="desc">Description of the operation being built.</param>
/// <param name="attrName">Attribute to set.</param>
/// <param name="dtype">Data type to assign.</param>
protected void TF_SetAttrType(OperationDescription desc, string attrName, TF_DataType dtype)
{
    c_api.TF_SetAttrType(desc, attrName, dtype);
}
/// <summary>
/// Returns the underlying numeric value of <paramref name="type"/> (not a textual name).
/// </summary>
/// <param name="type">The dtype to convert.</param>
/// <returns>The enum's integer value.</returns>
public static int name(this TF_DataType type) => (int)type;
/// <summary>
/// Builds a <c>Range</c> op producing the sequence from <paramref name="start"/> up to
/// <paramref name="limit"/> in steps of <paramref name="delta"/>.
/// </summary>
/// <param name="start">First value; when <paramref name="limit"/> is null this is used as the limit and start becomes 0.</param>
/// <param name="limit">Exclusive upper bound (see <paramref name="start"/>).</param>
/// <param name="delta">Step; defaults to 1.</param>
/// <param name="dtype">Element type applied when converting start/limit/delta to tensors (DtInvalid to infer).</param>
/// <param name="name">Name for the enclosing name scope.</param>
/// <returns>The result of <c>gen_math_ops.range</c>.</returns>
public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range")
{
    // range(n) is shorthand for range(0, n).
    if (limit == null)
    {
        limit = start;
        start = 0;
    }

    if (delta == null)
    {
        delta = 1;
    }

    return with(ops.name_scope(name, "Range", new { start, limit, delta }), scope =>
    {
        name = scope;
        // Pass dtype through: it was previously accepted but silently ignored.
        var start1 = ops.convert_to_tensor(start, dtype: dtype, name: "start");
        var limit1 = ops.convert_to_tensor(limit, dtype: dtype, name: "limit");
        var delta1 = ops.convert_to_tensor(delta, dtype: dtype, name: "delta");

        return gen_math_ops.range(start1, limit1, delta1, name);
    });
}
public static string as_numpy_name(this TF_DataType type) => type switch {
/// <summary>
/// Creates a placeholder tensor; delegates to <c>gen_array_ops.placeholder</c>.
/// </summary>
/// <param name="dtype">Element type of the placeholder.</param>
/// <param name="shape">Optional shape of the placeholder.</param>
/// <param name="name">Optional name for the operation.</param>
/// <returns>The placeholder tensor.</returns>
public Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)
{
    return gen_array_ops.placeholder(dtype, shape, name);
}
/// <summary>
/// Describes a dense tensor by shape, element type and optional name.
/// </summary>
/// <param name="shape">Shape of the described tensor.</param>
/// <param name="dtype">Element type; defaults to TF_FLOAT.</param>
/// <param name="name">Optional name.</param>
public DenseSpec(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
    _name = name;
    _dtype = dtype;
    _shape = shape;
}
/// <summary>
/// Creates a sequence tensor; forwards every argument to <c>math_ops.range</c>.
/// </summary>
/// <param name="start">First value (or the limit when <paramref name="limit"/> is null).</param>
/// <param name="limit">Exclusive upper bound.</param>
/// <param name="delta">Step size.</param>
/// <param name="dtype">Requested element type.</param>
/// <param name="name">Name for the operation scope.</param>
/// <returns>The result of <c>math_ops.range</c>.</returns>
public Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range")
{
    return math_ops.range(start, limit: limit, delta: delta, dtype: dtype, name: name);
}
/// <summary>
/// Creates a <c>TensorProto</c> from a scalar, an array, an <c>NDArray</c>, or an existing proto.
/// </summary>
/// <param name="values">Value(s) to encode. A <c>TensorProto</c> is returned as-is.</param>
/// <param name="dtype">Requested dtype; only used here to detect quantized types, whose dtype overrides the inferred one.</param>
/// <param name="shape">Target shape; defaults to the shape of the converted values.</param>
/// <param name="verify_shape">When true and <paramref name="shape"/> is given, the values' shape must equal it exactly.</param>
/// <param name="allow_broadcast">Mutually exclusive with <paramref name="verify_shape"/>.</param>
/// <returns>The populated <c>TensorProto</c>.</returns>
/// <exception cref="ValueError">Both flags are set, or <paramref name="values"/> is null.</exception>
/// <exception cref="TypeError">The value dtype is unrecognized, or shape verification fails.</exception>
public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, bool verify_shape = false, bool allow_broadcast = false)
{
    if (allow_broadcast && verify_shape)
    {
        throw new ValueError("allow_broadcast and verify_shape are not both allowed.");
    }

    if (values is TensorProto tp)
    {
        return tp;
    }

    bool is_quantized = new TF_DataType[]
    {
        TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16,
        TF_DataType.TF_QINT32
    }.Contains(dtype);

    // We first convert value to a numpy array or scalar.
    NDArray nparray = null;
    if (values is NDArray nd)
    {
        nparray = nd;
    }
    else
    {
        if (values == null)
        {
            throw new ValueError("None values not supported.");
        }

        switch (values)
        {
            case bool boolVal:
                nparray = boolVal;
                break;
            case int intVal:
                nparray = intVal;
                break;
            case int[] intVals:
                nparray = np.array(intVals);
                break;
            case float floatVal:
                nparray = floatVal;
                break;
            case double doubleVal:
                nparray = doubleVal;
                break;
            case string strVal:
                nparray = strVal;
                break;
            case string[] strVals:
                nparray = strVals;
                break;
            default:
                throw new Exception("make_tensor_proto Not Implemented");
        }
    }

    var numpy_dtype = dtypes.as_dtype(nparray.dtype);
    if (numpy_dtype == TF_DataType.DtInvalid)
    {
        throw new TypeError($"Unrecognized data type: {nparray.dtype}");
    }

    // If dtype was specified and is a quantized type, we convert
    // numpy_dtype back into the quantized version.
    if (is_quantized)
    {
        numpy_dtype = dtype;
    }

    bool is_same_size = false;
    int shape_size = 0;

    // If shape is not given, get the shape from the numpy array.
    if (shape == null)
    {
        shape = nparray.shape;
        is_same_size = true;
        shape_size = nparray.size;
    }
    else
    {
        shape_size = new TensorShape(shape).Size;
        is_same_size = shape_size == nparray.size;

        // Honor verify_shape (previously accepted but ignored), as in TensorFlow's Python make_tensor_proto.
        if (verify_shape)
        {
            int[] value_shape = nparray.shape;
            if (!value_shape.SequenceEqual(shape))
            {
                throw new TypeError($"Expected Tensor's shape: ({string.Join(", ", shape)}), got ({string.Join(", ", value_shape)}).");
            }
        }
    }

    var tensor_proto = new tensor_pb2.TensorProto
    {
        Dtype = numpy_dtype.as_datatype_enum(),
        TensorShape = tensor_util.as_shape(shape)
    };

    // Multi-element arrays whose size matches the shape are stored as raw bytes.
    if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1)
    {
        byte[] bytes = nparray.ToByteArray();
        tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
        return tensor_proto;
    }

    if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray))
    {
        if (values is string str)
        {
            tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str));
        }
        else if (values is string[] str_values)
        {
            tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x)));
        }

        return tensor_proto;
    }

    var proto_values = nparray.ravel();

    switch (nparray.dtype.Name)
    {
        // Type.Name of System.Boolean is "Boolean"; "Bool" alone never matched, so bool scalars threw.
        case "Bool":
        case "Boolean":
            tensor_proto.BoolVal.AddRange(proto_values.Data<bool>());
            break;
        case "Int32":
            tensor_proto.IntVal.AddRange(proto_values.Data<int>());
            break;
        case "Single":
            tensor_proto.FloatVal.AddRange(proto_values.Data<float>());
            break;
        case "Double":
            tensor_proto.DoubleVal.AddRange(proto_values.Data<double>());
            break;
        case "String":
            tensor_proto.StringVal.AddRange(proto_values.Data<string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x.ToString())));
            break;
        default:
            throw new Exception("make_tensor_proto Not Implemented");
    }

    return tensor_proto;
}
/// <summary>
/// Index-of-maximum along an axis; wrapper over <c>gen_math_ops.arg_max</c>.
/// </summary>
/// <param name="input">Tensor to search.</param>
/// <param name="axis">Axis to reduce over (default -1).</param>
/// <param name="name">Optional name for the operation.</param>
/// <param name="dimension">Deprecated alias for <paramref name="axis"/>; when supplied it takes precedence.</param>
/// <param name="output_type">Integer type of the returned indices.</param>
/// <returns>The result of <c>gen_math_ops.arg_max</c>.</returns>
// `dimension` was previously accepted but silently ignored.
public Tensor argmax(Tensor input, int axis = -1, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64)
    => gen_math_ops.arg_max(input, dimension ?? axis, name: name, output_type: output_type);
/// <summary>
/// Wraps raw bytes as an <c>NDArray</c> of the given shape and dtype.
/// </summary>
/// <param name="bytes">Raw element data.</param>
/// <param name="shape">Shape of the resulting array.</param>
/// <param name="dtype">Element type used to interpret <paramref name="bytes"/>.</param>
/// <returns>A new <c>NDArray</c> built from the arguments.</returns>
public NDArray frombuffer(byte[] bytes, Shape shape, TF_DataType dtype)
    => new NDArray(bytes, shape, dtype);
/// <summary>
/// Returns the shapes of several tensors by building a <c>ShapeN</c> op through the op-def helper.
/// </summary>
/// <param name="input">Tensors whose shapes are requested.</param>
/// <param name="out_type">Element type of the shape tensors.</param>
/// <param name="name">Optional name for the operation.</param>
/// <returns>All outputs of the created op.</returns>
public static Tensor[] shape_n(Tensor[] input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
    => _op_def_lib._apply_op_helper("ShapeN", name, new { input, out_type }).outputs;
/// <summary>
/// Decodes an encoded image, choosing the JPEG, PNG, GIF or BMP decoder at run time
/// from the leading bytes of <paramref name="contents"/>, then converts to <paramref name="dtype"/>.
/// </summary>
/// <param name="contents">Scalar string tensor holding the encoded image bytes.</param>
/// <param name="channels">Requested channel count (0 = decoder default); validated per format by assert ops.</param>
/// <param name="dtype">Output element type passed to <c>convert_image_dtype</c>.</param>
/// <param name="name">Optional name for the enclosing name scope.</param>
/// <param name="expand_animations">Only <c>true</c> is supported; <c>false</c> throws when the GIF branch is built.</param>
/// <returns>The decoded image tensor selected by the nested <c>cond</c> ops.</returns>
public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8, string name = null, bool expand_animations = true)
{
    // First three bytes of `contents`; assigned inside the name scope below and read by check_gif.
    Tensor substr = null;

    Func<ITensorOrOperation> _jpeg = () =>
    {
        int jpeg_channels = channels;
        var good_channels = math_ops.not_equal(jpeg_channels, 4, name: "check_jpeg_channels");
        string channels_msg = "Channels must be in (None, 0, 1, 3) when decoding JPEG 'images'";
        var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
        return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
        {
            return convert_image_dtype(gen_image_ops.decode_jpeg(contents, channels), dtype);
        });
    };

    Func<ITensorOrOperation> _gif = () =>
    {
        int gif_channels = channels;
        var good_channels = math_ops.logical_and(
            math_ops.not_equal(gif_channels, 1, name: "check_gif_channels"),
            math_ops.not_equal(gif_channels, 4, name: "check_gif_channels"));

        string channels_msg = "Channels must be in (None, 0, 3) when decoding GIF images";
        var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
        return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
        {
            var result = convert_image_dtype(gen_image_ops.decode_gif(contents), dtype);
            if (!expand_animations)
            {
                // TODO: result = array_ops.gather(result, 0);
                throw new NotImplementedException("decode_image: expand_animations = false is not supported yet.");
            }
            return result;
        });
    };

    Func<ITensorOrOperation> _bmp = () =>
    {
        int bmp_channels = channels;
        var signature = string_ops.substr(contents, 0, 2);
        var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp");
        string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP";
        var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg });
        var good_channels = math_ops.not_equal(bmp_channels, 1, name: "check_channels");
        string channels_msg = "Channels must be in (None, 0, 3) when decoding BMP images";
        var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
        return tf_with(ops.control_dependencies(new[] { assert_decode, assert_channels }), delegate
        {
            return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype);
        });
    };

    Func<ITensorOrOperation> _png = () =>
    {
        // The DecodePng op only supports uint8/uint16 output: decode as uint16 for any
        // other requested dtype and let convert_image_dtype produce the final type.
        var png_dtype = dtype == TF_DataType.TF_UINT8 ? TF_DataType.TF_UINT8 : TF_DataType.TF_UINT16;
        return convert_image_dtype(gen_image_ops.decode_png(contents, channels, dtype: png_dtype), dtype);
    };

    Func<ITensorOrOperation> check_gif = () =>
    {
        // "\x47\x49\x46" is the ASCII signature "GIF".
        var is_gif = math_ops.equal(substr, "\x47\x49\x46", name: "is_gif");
        return control_flow_ops.cond(is_gif, _gif, _bmp, name: "cond_gif");
    };

    Func<ITensorOrOperation> check_png = () =>
    {
        return control_flow_ops.cond(_is_png(contents), _png, check_gif, name: "cond_png");
    };

    return tf_with(ops.name_scope(name, "decode_image"), scope =>
    {
        substr = string_ops.substr(contents, 0, 3);
        return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg");
    });
}
/// <summary>
/// Creates an `Operation`.
/// </summary>
/// <param name="node_def">`NodeDef` describing the operation to create.</param>
/// <param name="g">The parent graph.</param>
/// <param name="inputs">Tensors consumed by the operation; regrouped per the op def before the C op is created.</param>
/// <param name="output_types">Ignored on input: it is recomputed from the created C op.</param>
/// <param name="control_inputs">
/// Operations or tensors to take a control dependency on; tensors contribute their producing op.
/// </param>
/// <param name="input_types">Currently unused by this constructor.</param>
/// <param name="original_op">Currently unused by this constructor.</param>
/// <param name="op_def">The op's <c>OpDef</c>; looked up from the graph by <c>node_def.Op</c> when null.</param>
/// <exception cref="NotImplementedException">A control input is neither an Operation nor a Tensor.</exception>
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{
    _graph = g;

    // Reduce each control input to the Operation it refers to.
    var controlOps = new List<Operation>();
    if (control_inputs != null)
    {
        foreach (var dependency in control_inputs)
        {
            Operation resolved = dependency switch
            {
                Operation asOp => asOp,
                Tensor asTensor => asTensor.op,
                // TODO: IndexedSlices don't yet exist; once they do, map them to `islices.op` here.
                _ => throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {dependency}")
            };
            controlOps.Add(resolved);
        }
    }

    _id_value = _graph._next_id();

    // Capture the graph's current control-flow context.
    _control_flow_context = graph._get_control_flow_context();

    if (op_def == null)
    {
        op_def = g.GetOpDef(node_def.Op);
    }

    var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
    (_handle, OpDesc) = ops._create_c_op(g, node_def, grouped_inputs, controlOps.ToArray());
    _is_stateful = op_def.IsStateful;

    // Output dtypes are read back from the created C op, overwriting the argument.
    output_types = new TF_DataType[NumOutputs];
    for (int index = 0; index < NumOutputs; index++)
    {
        output_types[index] = OutputType(index);
    }

    _outputs = new Tensor[NumOutputs];
    for (int index = 0; index < NumOutputs; index++)
    {
        _outputs[index] = new Tensor(this, index, output_types[index]);
    }

    graph._add_op(this);

    if (_handle != IntPtr.Zero)
    {
        _control_flow_post_processing();
    }
}
/// <summary>
/// Builds an <c>ArgMax</c> op (index of the largest value along <paramref name="dimension"/>)
/// through the op-def helper.
/// </summary>
/// <param name="input">Tensor to search.</param>
/// <param name="dimension">Axis to reduce over.</param>
/// <param name="output_type">Integer type of the returned indices.</param>
/// <param name="name">Optional name for the operation.</param>
/// <returns>The first output of the created op.</returns>
public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null)
    => _op_def_lib._apply_op_helper("ArgMax", name, new { input, dimension, output_type }).outputs[0];