/// <summary>
/// Builds an eager op, binds it to <paramref name="device_name"/>, feeds it the
/// given inputs/attrs and executes it, wrapping the resulting native handles in
/// <see cref="EagerTensor"/> instances.
/// </summary>
/// <param name="ctx">Eager context the op is created in.</param>
/// <param name="device_name">Device to place the op on.</param>
/// <param name="op_name">Registered op type name.</param>
/// <param name="inputs">Input tensors; each must already carry an eager handle.</param>
/// <param name="attrs">Flat name/value attr pairs, or null for none.</param>
/// <param name="num_outputs">Expected number of outputs to allocate.</param>
/// <returns>The executed op's outputs as eager tensors.</returns>
public Tensor[] TFE_ExecuteCancelable(Context ctx,
                                      string device_name,
                                      string op_name,
                                      Tensor[] inputs,
                                      object[] attrs,
                                      int num_outputs)
        {
            var status = tf.Status;
            var op = GetOp(ctx, op_name, status);

            c_api.TFE_OpSetDevice(op, device_name, status.Handle);
            if (status.ok())
            {
                foreach (var input in inputs)
                {
                    SafeEagerTensorHandle inputHandle;
                    if (input is EagerTensor eager)
                    {
                        inputHandle = eager.EagerTensorHandle;
                    }
                    else if (input is Tensor plain)
                    {
                        inputHandle = plain.EagerTensorHandle;
                    }
                    else
                    {
                        // Null (or otherwise unmatched) inputs have no eager handle to feed.
                        throw new NotImplementedException("Eager tensor handle has not been allocated.");
                    }

                    c_api.TFE_OpAddInput(op, inputHandle, status.Handle);
                    status.Check(true);
                }
            }

            if (status.ok() && attrs != null)
            {
                SetOpAttrs(op, attrs);
            }

            // NOTE(review): when an earlier step failed, execution is skipped and the
            // pre-sized (unfilled) handle array is still wrapped below — original behavior.
            var resultHandles = new SafeEagerTensorHandle[num_outputs];
            if (status.ok())
            {
                c_api.TFE_Execute(op, resultHandles, out num_outputs, status.Handle);
                status.Check(true);
            }

            return resultHandles.Select(h => new EagerTensor(h)).ToArray();
        }
    }
    }
        /// <summary>
        /// Fast-path eager execution, mirroring TFE_Py_FastPathExecute: resolves
        /// context/device defaults, applies explicit and inferred attrs, flattens
        /// inputs, executes the op, and optionally fires gradient/post-exec callbacks.
        /// </summary>
        /// <param name="op_exec_info">Bundled op name, context, device, positional args, attrs and callbacks.</param>
        /// <returns>The op's flattened outputs, or null when adding a list input fails.</returns>
        public Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info)
        {
            // Fall back to the global context/device when the caller did not pin one.
            if (op_exec_info.ctx == null)
            {
                op_exec_info.ctx = tf.Context;
            }
            if (string.IsNullOrEmpty(op_exec_info.device_name))
            {
                op_exec_info.device_name = tf.Context.DeviceName;
            }

            var attr_list_sizes = new Dictionary<string, long>();

            op_exec_info.run_gradient_callback   = HasAccumulatorOrTape();
            op_exec_info.run_post_exec_callbacks = op_exec_info.callbacks != null;
            op_exec_info.run_callbacks           = op_exec_info.run_gradient_callback || op_exec_info.run_post_exec_callbacks;

            var status = tf.Status;
            var op     = GetOp(op_exec_info.ctx, op_exec_info.op_name, status);

            var op_def = tf.get_default_graph().GetOpDef(op_exec_info.op_name);

            // These two lists are only consumed by RunCallbacks at the end.
            var flattened_attrs  = new List<object>(op_def.Attr.Count * 2);
            var flattened_inputs = new List<Tensor>(op_def.InputArg.Count);

            // Set non-inferred attrs, including setting defaults if the attr is passed in
            // as None.
            if (op_exec_info.attrs != null)
            {
                foreach (var attr1 in op_exec_info.attrs)
                {
                    var attr = op_def.Attr.FirstOrDefault(x => x.Name == attr1.Key);
                    if (attr != null)
                    {
                        flattened_attrs.Add(attr.Name);
                        flattened_attrs.Add(attr1.Value);

                        SetOpAttrWithDefaults(op_exec_info.ctx, op, attr, attr.Name, attr1.Value, attr_list_sizes, status);
                        status.Check(true);
                    }
                }
            }

            // c_api.TFE_OpSetDevice(op, op_exec_info.device_name, status.Handle);
            // status.Check(true);

            // Add inferred attrs and inputs.
            for (int i = 0; i < op_def.InputArg.Count; i++)
            {
                var input     = op_exec_info.args[i];
                var input_arg = op_def.InputArg[i];
                if (!string.IsNullOrEmpty(input_arg.NumberAttr))
                {
                    // Homogeneous list input: its length is carried by NumberAttr.
                    var fast_input_array = (object[])input;
                    int len = fast_input_array.Length;
                    c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len);
                    if (op_exec_info.run_callbacks)
                    {
                        flattened_attrs.Add(input_arg.NumberAttr);
                        flattened_attrs.Add(len);
                    }
                    attr_list_sizes[input_arg.NumberAttr] = len;

                    if (len > 0)
                    {
                        // BUG FIX: the first list element is index 0, not the input-arg
                        // index i — the old code read the wrong element (or threw) for
                        // any list input not at positional-arg slot 0.
                        // The first item adds the type attr for the whole list.
                        if (!AddInputToOp(fast_input_array[0], true, input_arg, flattened_attrs, flattened_inputs, op, status))
                        {
                            return null;
                        }

                        for (var j = 1; j < len; j++)
                        {
                            // Since the list is homogeneous, we don't need to re-add the attr.
                            if (!AddInputToOp(fast_input_array[j], false, input_arg, flattened_attrs, flattened_inputs, op, status))
                            {
                                return null;
                            }
                        }
                    }
                }
                else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
                {
                    // Heterogeneous list input: each element contributes its own dtype
                    // to the TypeListAttr.
                    var attr_name        = input_arg.TypeListAttr;
                    var fast_input_array = (object[])input;
                    var len         = fast_input_array.Length;
                    var attr_values = new TF_DataType[len];

                    for (var j = 0; j < len; j++)
                    {
                        var eager_tensor = ops.convert_to_tensor(fast_input_array[j]);
                        attr_values[j] = eager_tensor.dtype;

                        c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status.Handle);

                        if (op_exec_info.run_callbacks)
                        {
                            flattened_inputs.Add(eager_tensor);
                        }
                    }

                    if (op_exec_info.run_callbacks)
                    {
                        flattened_attrs.Add(attr_name);
                        flattened_attrs.Add(attr_values);
                    }
                    c_api.TFE_OpSetAttrTypeList(op, attr_name, attr_values, attr_values.Length);
                    attr_list_sizes[attr_name] = len;
                }
                else
                {
                    // The item is a single (non-list) input.
                    AddInputToOp(input, true, input_arg, flattened_attrs, flattened_inputs, op, status);
                }
            }

            // Sum up expected outputs; list-typed outputs expand by their recorded attr size.
            int num_retvals = 0;
            for (int i = 0; i < op_def.OutputArg.Count; i++)
            {
                var output_arg = op_def.OutputArg[i];
                var delta      = 1L;
                if (!string.IsNullOrEmpty(output_arg.NumberAttr))
                {
                    delta = attr_list_sizes[output_arg.NumberAttr];
                }
                else if (!string.IsNullOrEmpty(output_arg.TypeListAttr))
                {
                    delta = attr_list_sizes[output_arg.TypeListAttr];
                }
                if (delta < 0)
                {
                    throw new RuntimeError("Attributes suggest that the size of an output list is less than 0");
                }
                num_retvals += (int)delta;
            }

            var retVals = new SafeEagerTensorHandle[num_retvals];

            c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle);
            status.Check(true);

            var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray();

            if (op_exec_info.run_callbacks)
            {
                RunCallbacks(op_exec_info,
                             op_def.InputArg.Count,
                             flattened_inputs.ToArray(), flattened_attrs.ToArray(), flat_result);
            }

            return flat_result;
        }
// Beispiel #3 ("Example #3") — leftover artifact from the code-sample scrape, not part of the source.
// 0
 /// <summary>
 /// Wraps an already-allocated native eager tensor handle.
 /// </summary>
 /// <param name="handle">Safe handle to the native TFE_TensorHandle this tensor owns.</param>
 public EagerTensor(SafeEagerTensorHandle handle)
 {
     // ops.uid() hands out a fresh identifier; presumably process-unique — TODO confirm.
     _id = ops.uid();
     _eagerTensorHandle = handle;
 }