/// <summary>
/// Eagerly executes <paramref name="op_name"/> on <paramref name="device_name"/> with the given
/// inputs and attributes, and wraps every output handle produced by <c>TFE_Execute</c> in an
/// <see cref="EagerTensor"/>.
/// </summary>
/// <param name="ctx">Eager context used to create the op.</param>
/// <param name="device_name">Device the op is placed on.</param>
/// <param name="op_name">Op type name passed to <c>GetOp</c>.</param>
/// <param name="inputs">
/// Op inputs. <see cref="EagerTensor"/> inputs reuse their existing handle; any other tensor gets
/// a handle created with <c>TFE_NewTensorHandle</c>.
/// </param>
/// <param name="attrs">Attributes forwarded to <c>SetOpAttrs</c>; may be null.</param>
/// <param name="num_outputs">Capacity of the output-handle buffer passed to <c>TFE_Execute</c>.</param>
/// <returns>One <see cref="EagerTensor"/> per output reported by <c>TFE_Execute</c>.</returns>
public Tensor[] TFE_ExecuteCancelable(Context ctx, string device_name, string op_name, Tensor[] inputs, object[] attrs, int num_outputs)
{
    var status = tf.Status;
    var op = GetOp(ctx, op_name, status);
    status.Check(true);

    c_api.TFE_OpSetDevice(op, device_name, status.Handle);
    // Check immediately: with a bare status.ok() guard, a placement failure would skip
    // input binding and execution and return tensors wrapping null handles.
    status.Check(true);

    foreach (var input in inputs)
    {
        SafeTensorHandleHandle tensor_handle = input is EagerTensor et
            ? et.EagerTensorHandle
            : c_api.TFE_NewTensorHandle(input, status.Handle);
        // NOTE(review): handles created by TFE_NewTensorHandle here are not disposed explicitly.
        status.Check(true);

        c_api.TFE_OpAddInput(op, tensor_handle, status.Handle);
        status.Check(true);
    }

    if (attrs != null)
    {
        SetOpAttrs(op, attrs);
        // Surfaces any error left on the shared status; a no-op when the status is OK.
        status.Check(true);
    }

    var outputs = new SafeTensorHandleHandle[num_outputs];
    c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle);
    status.Check(true);

    // TFE_Execute writes back the number of outputs it produced; only wrap those slots.
    return outputs.Take(num_outputs).Select(x => new EagerTensor(x)).ToArray();
}
/// <summary>
/// Eagerly executes <paramref name="op_name"/> on <paramref name="device_name"/> and wraps every
/// output handle produced by <c>TFE_Execute</c> in an <see cref="EagerTensor"/> that also holds
/// the op handle.
/// </summary>
/// <param name="ctx">Eager context used to create the op.</param>
/// <param name="device_name">Device the op is placed on.</param>
/// <param name="op_name">Op type name passed to <c>GetOp</c>.</param>
/// <param name="inputs">Op inputs; every element must be an <see cref="EagerTensor"/>.</param>
/// <param name="attrs">Attributes forwarded to <c>SetOpAttrs</c>; may be null.</param>
/// <param name="num_outputs">Capacity of the output-handle buffer passed to <c>TFE_Execute</c>.</param>
/// <returns>One <see cref="EagerTensor"/> per output reported by <c>TFE_Execute</c>.</returns>
/// <exception cref="NotImplementedException">An input is not an <see cref="EagerTensor"/>.</exception>
public Tensor[] TFE_ExecuteCancelable(Context ctx, string device_name, string op_name, Tensor[] inputs, object[] attrs, int num_outputs)
{
    var status = tf.Status;
    var op = GetOp(ctx, op_name, status);
    status.Check(true);

    c_api.TFE_OpSetDevice(op, device_name, status.Handle);
    // Check immediately: with a bare status.ok() guard, a placement failure would skip
    // input binding and execution and return tensors wrapping null handles.
    status.Check(true);

    for (int i = 0; i < inputs.Length; ++i)
    {
        SafeTensorHandleHandle tensor_handle = inputs[i] switch
        {
            EagerTensor et => et.EagerTensorHandle,
            _ => throw new NotImplementedException(
                $"TFE_ExecuteCancelable only supports EagerTensor inputs, but input {i} is {inputs[i]?.GetType().Name ?? "null"}.")
        };
        c_api.TFE_OpAddInput(op, tensor_handle, status.Handle);
        status.Check(true);
    }

    if (attrs != null)
    {
        SetOpAttrs(op, attrs);
        // Surfaces any error left on the shared status; a no-op when the status is OK.
        status.Check(true);
    }

    var outputs = new SafeTensorHandleHandle[num_outputs];
    c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle);
    status.Check(true);

    // TFE_Execute writes back the number of outputs it produced; only wrap those slots.
    return outputs.Take(num_outputs).Select(x => new EagerTensor(x, op)).ToArray();
}
}
/// <summary>
/// Creates an eager tensor around an existing tensor handle. The base tensor is constructed with
/// <see cref="IntPtr.Zero"/>, the handle is stored in <c>EagerTensorHandle</c>, and
/// <c>Resolve()</c> is then invoked.
/// </summary>
/// <param name="handle">Eager tensor handle to wrap.</param>
public EagerTensor(SafeTensorHandleHandle handle)
    : base(IntPtr.Zero)
{
    this.EagerTensorHandle = handle;
    // NOTE(review): Resolve() presumably reads tensor metadata from the handle — confirm.
    this.Resolve();
}
/// <summary>
/// Fast-path eager execution: creates the op, sets attributes and inputs directly from
/// <paramref name="args"/>, executes it and wraps the outputs in <see cref="EagerTensor"/>s.
/// Gradient/post-exec callbacks are run when a tape/accumulator is active or
/// <paramref name="callbacks"/> is non-null.
/// </summary>
/// <param name="ctx">Eager context; must not be null.</param>
/// <param name="device_name">Device the op is placed on.</param>
/// <param name="opName">Op type name; also used to look up the op def.</param>
/// <param name="name">Op instance name, recorded in the exec info for callbacks.</param>
/// <param name="callbacks">Post-execution callbacks; non-null enables callback handling.</param>
/// <param name="args">
/// One entry per op-def input starting at <c>kFastPathExecuteInputStartIndex</c>, followed by
/// alternating attribute name / attribute value pairs.
/// </param>
/// <returns>The op outputs, or <c>null</c> if <c>AddInputToOp</c> rejects an input.</returns>
/// <exception cref="ValueError"><paramref name="ctx"/> is null.</exception>
/// <exception cref="RuntimeError">Attributes imply a negative output-list size.</exception>
public Tensor[] TFE_FastPathExecute(Context ctx, string device_name, string opName, string name, Action callbacks, params object[] args)
{
    if (ctx == null)
    {
        throw new ValueError("This function does not handle the case of the path where " +
            "all inputs are not already EagerTensors.");
    }

    int args_size = args.Length;
    var attr_list_sizes = new Dictionary<string, long>();

    FastPathOpExecInfo op_exec_info = new FastPathOpExecInfo()
    {
        ctx = ctx,
        args = args,
        device_name = device_name,
        op_name = opName,
        name = name,
    };

    op_exec_info.run_gradient_callback = HasAccumulatorOrTape();
    op_exec_info.run_post_exec_callbacks = callbacks != null;
    op_exec_info.run_callbacks = op_exec_info.run_gradient_callback || op_exec_info.run_post_exec_callbacks;

    var status = tf.Status;
    var op = GetOp(ctx, opName, status);
    status.Check(true);

    var op_def = tf.get_default_graph().GetOpDef(opName);

    var flattened_attrs = new List<object>(op_def.InputArg.Count);
    var flattened_inputs = new List<Tensor>(op_def.InputArg.Count);

    // Set non-inferred attrs, including setting defaults if the attr is passed in
    // as None.
    for (int i = kFastPathExecuteInputStartIndex + op_def.InputArg.Count; i < args_size; i += 2)
    {
        var attr_name = args[i].ToString();
        var attr_value = args[i + 1];

        var attr = op_def.Attr.FirstOrDefault(x => x.Name == attr_name);
        if (attr != null)
        {
            flattened_attrs.Add(attr_name);
            flattened_attrs.Add(attr_value);

            SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status);
            status.Check(true);
        }
    }

    c_api.TFE_OpSetDevice(op, device_name, status.Handle);
    status.Check(true);

    // Add inferred attrs and inputs.
    for (int i = 0; i < op_def.InputArg.Count; i++)
    {
        // Inputs start at kFastPathExecuteInputStartIndex; always index through that offset.
        var input = args[kFastPathExecuteInputStartIndex + i];
        var input_arg = op_def.InputArg[i];

        if (!string.IsNullOrEmpty(input_arg.NumberAttr))
        {
            // Homogeneous list input: its length is the value of the number attr.
            var fast_input_array = (object[])input;
            int len = fast_input_array.Length;
            c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len);
            if (op_exec_info.run_callbacks)
            {
                flattened_attrs.Add(input_arg.NumberAttr);
                flattened_attrs.Add(len);
            }
            attr_list_sizes[input_arg.NumberAttr] = len;

            if (len > 0)
            {
                // First item adds the type attr.
                if (!AddInputToOp(fast_input_array[0], true, input_arg, flattened_attrs, flattened_inputs, op, status))
                {
                    return null;
                }

                for (var j = 1; j < len; j++)
                {
                    // Since the list is homogeneous, we don't need to re-add the attr.
                    if (!AddInputToOp(fast_input_array[j], false, input_arg, flattened_attrs, flattened_inputs, op, status))
                    {
                        return null;
                    }
                }
            }
        }
        else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
        {
            // Heterogeneous list input: each element contributes its own dtype to the type list.
            var attr_name = input_arg.TypeListAttr;
            var fast_input_array = (object[])input;
            var len = fast_input_array.Length;
            var attr_values = new TF_DataType[len];

            for (var j = 0; j < len; j++)
            {
                var eager_tensor = ops.convert_to_tensor(fast_input_array[j]);
                attr_values[j] = eager_tensor.dtype;

                c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status.Handle);
                status.Check(true);

                if (op_exec_info.run_callbacks)
                {
                    flattened_inputs.Add(eager_tensor);
                }
            }

            if (op_exec_info.run_callbacks)
            {
                flattened_attrs.Add(attr_name);
                flattened_attrs.Add(attr_values);
            }
            c_api.TFE_OpSetAttrTypeList(op, attr_name, attr_values, attr_values.Length);
            attr_list_sizes[attr_name] = len;
        }
        else
        {
            // The item is a single item.
            if (!AddInputToOp(input, true, input_arg, flattened_attrs, flattened_inputs, op, status))
            {
                return null;
            }
        }
    }

    // Each output arg contributes 1 handle, or the recorded size of its list attr.
    int num_retvals = 0;
    for (int i = 0; i < op_def.OutputArg.Count; i++)
    {
        var output_arg = op_def.OutputArg[i];
        var delta = 1L;
        if (!string.IsNullOrEmpty(output_arg.NumberAttr))
        {
            delta = attr_list_sizes[output_arg.NumberAttr];
        }
        else if (!string.IsNullOrEmpty(output_arg.TypeListAttr))
        {
            delta = attr_list_sizes[output_arg.TypeListAttr];
        }

        if (delta < 0)
        {
            throw new RuntimeError("Attributes suggest that the size of an output list is less than 0");
        }
        num_retvals += (int)delta;
    }

    var retVals = new SafeTensorHandleHandle[num_retvals];
    c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle);
    status.Check(true);

    // TFE_Execute writes back the number of outputs it produced; only wrap those slots.
    var flat_result = retVals.Take(num_retvals).Select(x => new EagerTensor(x)).ToArray();

    if (op_exec_info.run_callbacks)
    {
        RunCallbacks(op_exec_info,
            kFastPathExecuteInputStartIndex + op_def.InputArg.Count,
            flattened_inputs.ToArray(), flattened_attrs.ToArray(), flat_result);
    }

    return flat_result;
}
/// <summary>
/// Creates an eager tensor around a handle produced by the op behind <paramref name="opHandle"/>.
/// The op handle is stored in <c>_opHandle</c> before the tensor handle is assigned and
/// <c>Resolve()</c> is invoked.
/// </summary>
/// <param name="handle">Eager tensor handle to wrap.</param>
/// <param name="opHandle">Handle of the op that produced <paramref name="handle"/>.</param>
public EagerTensor(SafeTensorHandleHandle handle, SafeOpHandle opHandle)
    : base(IntPtr.Zero)
{
    this._opHandle = opHandle;
    this.EagerTensorHandle = handle;
    // NOTE(review): Resolve() presumably reads tensor metadata from the handle — confirm.
    this.Resolve();
}
/// <summary>
/// Creates an eager tensor around an existing tensor handle. A fresh id is taken from
/// <c>ops.uid()</c> before the handle is stored and <c>Resolve()</c> is invoked.
/// </summary>
/// <param name="handle">Eager tensor handle to wrap.</param>
public EagerTensor(SafeTensorHandleHandle handle)
{
    this._id = ops.uid();
    this.EagerTensorHandle = handle;
    // NOTE(review): Resolve() presumably reads tensor metadata from the handle — confirm.
    this.Resolve();
}