Example #1
0
        /// <summary>
        /// Creates an eager op, adds <paramref name="inputs"/> and <paramref name="attrs"/> to it,
        /// executes it and wraps every produced output handle in an <see cref="EagerTensor"/>.
        /// </summary>
        /// <param name="ctx">Eager context the op is created in (passed to <c>GetOp</c>).</param>
        /// <param name="device_name">Device the op is placed on via <c>TFE_OpSetDevice</c>.</param>
        /// <param name="op_name">Op type name passed to <c>GetOp</c>.</param>
        /// <param name="inputs">
        /// Op inputs. <see cref="EagerTensor"/> inputs use their existing handle; any other tensor gets a
        /// temporary handle from <c>TFE_NewTensorHandle</c>.
        /// </param>
        /// <param name="attrs">Attr list forwarded to <c>SetOpAttrs</c>; skipped when null.</param>
        /// <param name="num_outputs">Capacity of the output buffer handed to <c>TFE_Execute</c>.</param>
        /// <returns>One tensor per output handle that <c>TFE_Execute</c> reports as produced.</returns>
        public Tensor[] TFE_ExecuteCancelable(Context ctx,
                                              string device_name,
                                              string op_name,
                                              Tensor[] inputs,
                                              object[] attrs,
                                              int num_outputs)
        {
            var status = tf.Status;
            var op     = GetOp(ctx, op_name, status);

            status.Check(true);
            c_api.TFE_OpSetDevice(op, device_name, status.Handle);
            if (status.ok())
            {
                foreach (var input in inputs)
                {
                    if (input is EagerTensor et)
                    {
                        c_api.TFE_OpAddInput(op, et.EagerTensorHandle, status.Handle);
                    }
                    else
                    {
                        // Release the temporary handle deterministically instead of waiting for finalization.
                        // NOTE(review): relies on TFE_OpAddInput taking its own reference to the handle
                        // (TF's EagerOperation::AddInput Ref()s it) — confirm against the bundled TF version.
                        using (var tensor_handle = c_api.TFE_NewTensorHandle(input, status.Handle))
                        {
                            // Don't feed a handle whose creation failed into the op.
                            status.Check(true);
                            c_api.TFE_OpAddInput(op, tensor_handle, status.Handle);
                        }
                    }
                    status.Check(true);
                }
            }
            if (status.ok() && attrs != null)
            {
                SetOpAttrs(op, attrs);
            }

            var outputs = new SafeTensorHandleHandle[num_outputs];

            if (status.ok())
            {
                // num_outputs goes in as the buffer capacity and comes back as the number of outputs produced.
                c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle);
                status.Check(true);
            }
            // Only wrap the slots TFE_Execute actually filled; any remaining slots are still null.
            return outputs.Take(num_outputs).Select(x => new EagerTensor(x)).ToArray();
        }
Example #2
0
        /// <summary>
        /// Creates an eager op, adds <paramref name="inputs"/> and <paramref name="attrs"/> to it,
        /// executes it and wraps every produced output handle in an <see cref="EagerTensor"/> that
        /// also holds the op handle.
        /// </summary>
        /// <param name="ctx">Eager context the op is created in (passed to <c>GetOp</c>).</param>
        /// <param name="device_name">Device the op is placed on via <c>TFE_OpSetDevice</c>.</param>
        /// <param name="op_name">Op type name passed to <c>GetOp</c>.</param>
        /// <param name="inputs">Op inputs; every element must be an <see cref="EagerTensor"/>.</param>
        /// <param name="attrs">Attr list forwarded to <c>SetOpAttrs</c>; skipped when null.</param>
        /// <param name="num_outputs">Capacity of the output buffer handed to <c>TFE_Execute</c>.</param>
        /// <returns>One tensor per output handle that <c>TFE_Execute</c> reports as produced.</returns>
        /// <exception cref="NotImplementedException">An input is not an <see cref="EagerTensor"/>.</exception>
        public Tensor[] TFE_ExecuteCancelable(Context ctx,
                                              string device_name,
                                              string op_name,
                                              Tensor[] inputs,
                                              object[] attrs,
                                              int num_outputs)
        {
            var status = tf.Status;
            var op     = GetOp(ctx, op_name, status);

            status.Check(true);
            c_api.TFE_OpSetDevice(op, device_name, status.Handle);
            if (status.ok())
            {
                for (int i = 0; i < inputs.Length; ++i)
                {
                    SafeTensorHandleHandle tensor_handle = inputs[i] switch
                    {
                        EagerTensor et => et.EagerTensorHandle,
                        // Name the offending input so the failure is diagnosable.
                        _ => throw new NotImplementedException(
                            $"Input {i} of op '{op_name}' is {inputs[i]?.GetType().Name ?? "null"}; only EagerTensor inputs are supported.")
                    };
                    c_api.TFE_OpAddInput(op, tensor_handle, status.Handle);
                    status.Check(true);
                }
            }
            if (status.ok() && attrs != null)
            {
                SetOpAttrs(op, attrs);
            }

            var outputs = new SafeTensorHandleHandle[num_outputs];

            if (status.ok())
            {
                // num_outputs goes in as the buffer capacity and comes back as the number of outputs produced.
                c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle);
                status.Check(true);
            }
            // Only wrap the slots TFE_Execute actually filled; any remaining slots are still null.
            return outputs.Take(num_outputs).Select(x => new EagerTensor(x, op)).ToArray();
        }
    }
Example #3
0
 /// <summary>
 /// Wraps an existing eager tensor handle. The base tensor is initialised with a zero pointer,
 /// and <c>Resolve()</c> is run once the handle has been stored.
 /// </summary>
 /// <param name="handle">Eager tensor handle to wrap.</param>
 public EagerTensor(SafeTensorHandleHandle handle)
     : base(IntPtr.Zero)
 {
     // Store the handle first: Resolve() runs after it and may depend on it.
     EagerTensorHandle = handle;
     Resolve();
 }
        /// <summary>
        /// Builds an eager op from its <c>OpDef</c>, sets its attrs and inputs from
        /// <paramref name="args"/>, executes it and returns the outputs as eager tensors.
        /// </summary>
        /// <param name="ctx">Eager context used to create the op; must not be null.</param>
        /// <param name="device_name">Device the op is placed on via <c>TFE_OpSetDevice</c>.</param>
        /// <param name="opName">Op type name; used to create the op and to look up its <c>OpDef</c>.</param>
        /// <param name="name">Name recorded in the exec info that is passed to <c>RunCallbacks</c>.</param>
        /// <param name="callbacks">When non-null, post-execution callback bookkeeping is enabled.</param>
        /// <param name="args">
        /// One entry per <c>OpDef</c> input arg, starting at <c>kFastPathExecuteInputStartIndex</c>.
        /// The inputs are followed by alternating attr-name / attr-value pairs.
        /// Inputs backed by a number or type-list attr must be arrays (<c>object[]</c>).
        /// </param>
        /// <returns>The op outputs, or <c>null</c> if <c>AddInputToOp</c> reports failure.</returns>
        /// <exception cref="ValueError">
        /// <paramref name="ctx"/> is null, or a list-valued input is not an array.
        /// </exception>
        /// <exception cref="RuntimeError">Attrs imply a negative output-list size.</exception>
        public Tensor[] TFE_FastPathExecute(Context ctx,
                                            string device_name,
                                            string opName,
                                            string name,
                                            Action callbacks,
                                            params object[] args)
        {
            if (ctx == null)
            {
                throw new ValueError("This function does not handle the case of the path where " +
                                     "all inputs are not already EagerTensors.");
            }

            int args_size       = args.Length;
            var attr_list_sizes = new Dictionary <string, long>();

            FastPathOpExecInfo op_exec_info = new FastPathOpExecInfo()
            {
                ctx         = ctx,
                args        = args,
                device_name = device_name,
                op_name     = opName,
                name        = name,
            };

            op_exec_info.run_gradient_callback   = HasAccumulatorOrTape();
            op_exec_info.run_post_exec_callbacks = callbacks != null;
            op_exec_info.run_callbacks           = op_exec_info.run_gradient_callback || op_exec_info.run_post_exec_callbacks;

            var status = tf.Status;
            var op     = GetOp(ctx, opName, status);

            var op_def = tf.get_default_graph().GetOpDef(opName);

            var flattened_attrs  = new List <object>(op_def.InputArg.Count);
            var flattened_inputs = new List <Tensor>(op_def.InputArg.Count);

            // Set non-inferred attrs, including setting defaults if the attr is passed in
            // as None. Attr name/value pairs follow the positional inputs in args.
            for (int i = kFastPathExecuteInputStartIndex + op_def.InputArg.Count; i < args_size; i += 2)
            {
                var attr_name  = args[i].ToString();
                var attr_value = args[i + 1];

                var attr = op_def.Attr.FirstOrDefault(x => x.Name == attr_name);
                if (attr != null)
                {
                    flattened_attrs.Add(attr_name);
                    flattened_attrs.Add(attr_value);

                    SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status);
                    status.Check(true);
                }
            }

            c_api.TFE_OpSetDevice(op, device_name, status.Handle);
            status.Check(true);

            // Add inferred attrs and inputs.
            for (int i = 0; i < op_def.InputArg.Count; i++)
            {
                // Inputs sit at an offset inside args, so always go through `input`, never args[i].
                var input     = args[kFastPathExecuteInputStartIndex + i];
                var input_arg = op_def.InputArg[i];
                if (!string.IsNullOrEmpty(input_arg.NumberAttr))
                {
                    // A homogeneous list input: N tensors that share one type attr.
                    var fast_input_array = input as object[]
                        ?? throw new ValueError($"Input '{input_arg.Name}' of op '{opName}' must be a list (object[]).");
                    int len = fast_input_array.Length;
                    c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len);
                    if (op_exec_info.run_callbacks)
                    {
                        flattened_attrs.Add(input_arg.NumberAttr);
                        flattened_attrs.Add(len);
                    }
                    attr_list_sizes[input_arg.NumberAttr] = len;

                    if (len > 0)
                    {
                        // First item adds the type attr.
                        if (!AddInputToOp(fast_input_array[0], true, input_arg, flattened_attrs, flattened_inputs, op, status))
                        {
                            return(null);
                        }

                        for (var j = 1; j < len; j++)
                        {
                            // Since the list is homogeneous, we don't need to re-add the attr.
                            if (!AddInputToOp(fast_input_array[j], false, input_arg, flattened_attrs, flattened_inputs, op, status))
                            {
                                return(null);
                            }
                        }
                    }
                }
                else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
                {
                    // A heterogeneous list input: each element contributes its own dtype to the type-list attr.
                    var attr_name        = input_arg.TypeListAttr;
                    var fast_input_array = input as object[]
                        ?? throw new ValueError($"Input '{input_arg.Name}' of op '{opName}' must be a list (object[]).");
                    var len         = fast_input_array.Length;
                    var attr_values = new TF_DataType[len];

                    for (var j = 0; j < len; j++)
                    {
                        var eager_tensor = ops.convert_to_tensor(fast_input_array[j]);
                        attr_values[j] = eager_tensor.dtype;

                        c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status.Handle);
                        status.Check(true);

                        if (op_exec_info.run_callbacks)
                        {
                            flattened_inputs.Add(eager_tensor);
                        }
                    }

                    if (op_exec_info.run_callbacks)
                    {
                        flattened_attrs.Add(attr_name);
                        flattened_attrs.Add(attr_values);
                    }
                    c_api.TFE_OpSetAttrTypeList(op, attr_name, attr_values, attr_values.Length);
                    attr_list_sizes[attr_name] = len;
                }
                else
                {
                    // The item is a single item.
                    if (!AddInputToOp(input, true, input_arg, flattened_attrs, flattened_inputs, op, status))
                    {
                        return(null);
                    }
                }
            }

            // Each output arg yields one tensor, or a list whose length was recorded above.
            int num_retvals = 0;

            for (int i = 0; i < op_def.OutputArg.Count; i++)
            {
                var output_arg = op_def.OutputArg[i];
                var delta      = 1L;
                if (!string.IsNullOrEmpty(output_arg.NumberAttr))
                {
                    delta = attr_list_sizes[output_arg.NumberAttr];
                }
                else if (!string.IsNullOrEmpty(output_arg.TypeListAttr))
                {
                    delta = attr_list_sizes[output_arg.TypeListAttr];
                }
                if (delta < 0)
                {
                    throw new RuntimeError("Attributes suggest that the size of an output list is less than 0");
                }
                num_retvals += (int)delta;
            }

            var retVals = new SafeTensorHandleHandle[num_retvals];

            c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle);
            status.Check(true);

            var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray();

            if (op_exec_info.run_callbacks)
            {
                RunCallbacks(op_exec_info,
                             kFastPathExecuteInputStartIndex + op_def.InputArg.Count,
                             flattened_inputs.ToArray(), flattened_attrs.ToArray(), flat_result);
            }

            return(flat_result);
        }
Example #5
0
 /// <summary>
 /// Wraps an eager tensor handle together with the op handle it came from. The base tensor is
 /// initialised with a zero pointer, and <c>Resolve()</c> runs after both handles are stored.
 /// </summary>
 /// <param name="handle">Eager tensor handle to wrap.</param>
 /// <param name="opHandle">Op handle kept in <c>_opHandle</c> alongside the tensor.</param>
 public EagerTensor(SafeTensorHandleHandle handle, SafeOpHandle opHandle)
     : base(IntPtr.Zero)
 {
     // Keep the original assignment order; Resolve() comes last.
     _opHandle = opHandle;
     EagerTensorHandle = handle;

     Resolve();
 }
Example #6
0
 /// <summary>
 /// Wraps an existing eager tensor handle. The instance first takes an id from <c>ops.uid()</c>,
 /// then stores the handle, then runs <c>Resolve()</c>.
 /// </summary>
 /// <param name="handle">Eager tensor handle to wrap.</param>
 public EagerTensor(SafeTensorHandleHandle handle)
 {
     _id = ops.uid();

     EagerTensorHandle = handle;
     Resolve();
 }