Exemplo n.º 1
0
        /// <summary>
        /// Builds (but does not execute) a "Shape" op that takes <paramref name="a"/> as its only input.
        /// Any failure while creating the op or adding the input trips CHECK_EQ with the native message.
        /// The returned op is not disposed here; the caller is responsible for it.
        /// </summary>
        SafeEagerOpHandle ShapeOp(SafeContextHandle ctx, SafeEagerTensorHandle a)
        {
            using var status = TF_NewStatus();

            var shapeOp = TFE_NewOp(ctx, "Shape", status);
            CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

            TFE_OpAddInput(shapeOp, a, status);
            CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

            // The "T" attribute must match the element type of the input handle.
            var inputType = TFE_TensorHandleDataType(a);
            TFE_OpSetAttrType(shapeOp, "T", inputType);

            return shapeOp;
        }
Exemplo n.º 2
0
 /// <summary>
 /// Runs <paramref name="op"/> eagerly and wraps each produced output in a
 /// <see cref="SafeEagerTensorHandle"/> stored in <paramref name="retvals"/>.
 /// </summary>
 /// <remarks>
 /// <paramref name="retvals"/> may be <c>null</c>, in which case no output slots are offered to the native call.
 /// After a successful call, slots <c>[0, num_retvals)</c> of <paramref name="retvals"/> hold handles that the
 /// caller owns and must dispose when finished with them.
 /// </remarks>
 /// <param name="op">The eager operation to execute.</param>
 /// <param name="retvals">Destination array for output handles; its length is the number of outputs requested.</param>
 /// <param name="num_retvals">Receives the output count reported by the native call.</param>
 /// <param name="status">Receives the status of the native call.</param>
 public static void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] retvals, out int num_retvals, SafeStatusHandle status)
 {
     unsafe
     {
         int capacity = retvals == null ? 0 : retvals.Length;
         IntPtr* nativeHandles = stackalloc IntPtr[capacity];

         // Passed by ref: it goes in as the buffer capacity and the loop below uses whatever
         // count the native call leaves in it.
         num_retvals = capacity;
         TFE_Execute(op, nativeHandles, ref num_retvals, status);

         // Every reported slot gets a wrapper, even a null pointer; such a wrapper is non-null but
         // invalid, mirroring what P/Invoke does for a non-array SafeHandle return value.
         for (int slot = 0; slot < num_retvals; ++slot)
         {
             retvals[slot] = new SafeEagerTensorHandle(nativeHandles[slot]);
         }
     }
 }
Exemplo n.º 3
0
 /// <summary>
 /// Native (extern) binding for the TensorFlow eager C function <c>TFE_TensorHandleCopyToDevice</c>.
 /// </summary>
 /// <param name="h">Eager tensor handle to copy.</param>
 /// <param name="ctx">Eager context in which the copy is made.</param>
 /// <param name="device_name">Name of the destination device.</param>
 /// <param name="status">Receives the native status; check it (e.g. with TF_GetCode) before using the result.</param>
 /// <returns>
 /// A new eager tensor handle. NOTE(review): presumably owned (and to be disposed) by the caller,
 /// as with other SafeHandle results in this file — confirm against the TF C API docs.
 /// </returns>
 public static extern SafeEagerTensorHandle TFE_TensorHandleCopyToDevice(SafeEagerTensorHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status);
Exemplo n.º 4
0
 /// <summary>
 /// Native (extern) binding for <c>TFE_TensorHandleBackingDeviceName</c>.
 /// </summary>
 /// <param name="h">Eager tensor handle to query.</param>
 /// <param name="status">Receives the native status.</param>
 /// <returns>
 /// A raw pointer to a native string. The protected wrapper in this file converts it with
 /// <c>c_api.StringPiece</c>. NOTE(review): ownership of the pointed-to buffer is not visible here,
 /// so do not free it without confirming against the TF C API docs.
 /// </returns>
 public static extern IntPtr TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status);
Exemplo n.º 5
0
 /// <summary>
 /// Native (extern) binding for <c>TFE_TensorHandleDim</c>: queries the size of one dimension of an eager tensor.
 /// </summary>
 /// <param name="h">Eager tensor handle to query.</param>
 /// <param name="dim">Index of the dimension. NOTE(review): presumably must be in [0, NumDims) — confirm.</param>
 /// <param name="status">Receives the native status.</param>
 public static extern int TFE_TensorHandleDim(SafeEagerTensorHandle h, int dim, SafeStatusHandle status);
Exemplo n.º 6
0
 /// <summary>
 /// Native (extern) binding for <c>TFE_TensorHandleNumDims</c>: queries the rank of an eager tensor.
 /// </summary>
 /// <param name="h">Eager tensor handle to query.</param>
 /// <param name="status">Receives the native status.</param>
 public static extern int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status);
Exemplo n.º 7
0
 /// <summary>
 /// Native (extern) binding for <c>TFE_TensorHandleResolve</c>: obtains a (non-eager) tensor for an eager handle.
 /// </summary>
 /// <param name="h">Eager tensor handle to resolve.</param>
 /// <param name="status">Receives the native status; check it before using the result.</param>
 /// <returns>A <see cref="SafeTensorHandle"/> wrapping the native tensor. NOTE(review): presumably caller-owned — confirm.</returns>
 public static extern SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status);
Exemplo n.º 8
0
 /// <summary>
 /// Native (extern) binding for <c>TFE_TensorHandleDataType</c>: returns the element type of an eager tensor.
 /// Unlike most bindings here it takes no status argument.
 /// </summary>
 /// <param name="h">Eager tensor handle to query.</param>
 public static extern TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h);
Exemplo n.º 9
0
 /// <summary>
 /// Native (extern) binding for <c>TFE_OpAddInput</c>: appends <paramref name="h"/> as the next input of <paramref name="op"/>.
 /// </summary>
 /// <param name="op">Eager op being built.</param>
 /// <param name="h">Eager tensor handle to add as an input.</param>
 /// <param name="status">Receives the native status; callers in this file check it right after the call.</param>
 public static extern void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status);
Exemplo n.º 10
0
 /// <summary>
 /// Convenience forwarder to <c>c_api.TFE_OpAddInput</c>; every argument is passed through unchanged.
 /// </summary>
 protected void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status)
 {
     c_api.TFE_OpAddInput(op, h, status);
 }
Exemplo n.º 11
0
 /// <summary>
 /// Convenience forwarder to <c>c_api.TFE_TensorHandleNumDims</c>; returns the native result unchanged.
 /// </summary>
 protected int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status)
 {
     return c_api.TFE_TensorHandleNumDims(h, status);
 }
Exemplo n.º 12
0
 /// <summary>
 /// Convenience forwarder to <c>c_api.TFE_TensorHandleDataType</c>; returns the handle's element type.
 /// </summary>
 protected TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h)
 {
     return c_api.TFE_TensorHandleDataType(h);
 }
Exemplo n.º 13
0
 /// <summary>
 /// Convenience forwarder to <c>c_api.TFE_TensorHandleCopyToDevice</c>; returns the new handle produced natively.
 /// </summary>
 protected SafeEagerTensorHandle TFE_TensorHandleCopyToDevice(SafeEagerTensorHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status)
 {
     return c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);
 }
Exemplo n.º 14
0
 /// <summary>
 /// Returns the backing device name of <paramref name="h"/> as a managed string, converting the native
 /// string pointer with <c>c_api.StringPiece</c>.
 /// </summary>
 protected string TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status)
 {
     var nativeName = c_api.TFE_TensorHandleBackingDeviceName(h, status);
     return c_api.StringPiece(nativeName);
 }
Exemplo n.º 15
0
 /// <summary>
 /// Convenience forwarder to <c>c_api.TFE_TensorHandleResolve</c>; returns the resolved tensor handle unchanged.
 /// </summary>
 protected SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status)
 {
     return c_api.TFE_TensorHandleResolve(h, status);
 }
Exemplo n.º 16
0
        /// <summary>
        /// Creates a scalar float resource variable in <paramref name="ctx"/> with a VarHandleOp, then
        /// assigns <paramref name="value"/> to it with an AssignVariableOp.
        /// </summary>
        /// <param name="ctx">Eager context to run both ops in.</param>
        /// <param name="value">Initial value written into the variable.</param>
        /// <param name="status">Receives the status of every native call; it is checked after each step.</param>
        /// <returns>
        /// The variable's handle on success. On any failure, <paramref name="status"/> is left non-OK,
        /// any variable handle created so far is disposed, and a handle wrapping
        /// <see cref="IntPtr.Zero"/> is returned.
        /// </returns>
        unsafe SafeEagerTensorHandle CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status)
        {
            var var_handle = new SafeEagerTensorHandle[1];
            int num_retvals;

            // Shared failure path. It releases the variable handle if one was already produced, so that
            // a failure during assignment does not leak it, and then returns an invalid (zero) handle.
            SafeEagerTensorHandle FailWithInvalidHandle()
            {
                var_handle[0]?.Dispose();
                return new SafeEagerTensorHandle(IntPtr.Zero);
            }

            // Create the variable resource.
            using (var op = TFE_NewOp(ctx, "VarHandleOp", status))
            {
                if (TF_GetCode(status) != TF_OK)
                {
                    return FailWithInvalidHandle();
                }
                TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
                // Empty dims with rank 0: the variable is a scalar.
                TFE_OpSetAttrShape(op, "shape", new long[0], 0, status);
                TFE_OpSetAttrString(op, "container", "", 0);
                TFE_OpSetAttrString(op, "shared_name", "", 0);
                if (TF_GetCode(status) != TF_OK)
                {
                    return FailWithInvalidHandle();
                }
                TFE_Execute(op, var_handle, out num_retvals, status);
                if (TF_GetCode(status) != TF_OK)
                {
                    return FailWithInvalidHandle();
                }
                CHECK_EQ(1, num_retvals);
            }

            // Assign 'value' to it.
            using (var op = TFE_NewOp(ctx, "AssignVariableOp", status))
            {
                if (TF_GetCode(status) != TF_OK)
                {
                    return FailWithInvalidHandle();
                }
                TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
                TFE_OpAddInput(op, var_handle[0], status);
                // Check now: TFE_NewTensorHandle below reuses `status` and would otherwise
                // overwrite, and so hide, a failure from adding the variable input.
                if (TF_GetCode(status) != TF_OK)
                {
                    return FailWithInvalidHandle();
                }

                // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
                // NOTE(review): `t` is never explicitly released here. Confirm whether the return type of
                // TF_AllocateTensor frees the native tensor by itself or whether TF_DeleteTensor is needed.
                var t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, new long[0], 0, sizeof(float));
                tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t));

                // Released deterministically when this block exits, which is after the op has executed.
                using var value_handle = c_api.TFE_NewTensorHandle(t, status);
                if (TF_GetCode(status) != TF_OK)
                {
                    return FailWithInvalidHandle();
                }

                TFE_OpAddInput(op, value_handle, status);
                if (TF_GetCode(status) != TF_OK)
                {
                    return FailWithInvalidHandle();
                }

                c_api.TFE_Execute(op, null, out num_retvals, status);
                if (TF_GetCode(status) != TF_OK)
                {
                    return FailWithInvalidHandle();
                }
                // AssignVariableOp is expected to produce no outputs.
                CHECK_EQ(0, num_retvals);
            }

            return var_handle[0];
        }