/// <summary>
/// Wraps the data buffer of a native ONNX Runtime tensor value so it can be viewed as managed memory of element type <c>T</c>.
/// The element width, element count and dimensions are read from the native tensor's shape/type descriptor.
/// </summary>
/// <param name="onnxValueHandle">
/// Handle to a native ONNX value holding a tensor. The handle is stored, not duplicated. If this constructor throws,
/// no wrapper exists, so the caller remains responsible for releasing the handle (the constructor deliberately does not).
/// </param>
/// <exception cref="NotSupportedException">
/// Thrown when the native element type does not correspond to <c>T</c>, when the tensor has symbolic (unknown)
/// dimensions, or when the element count or any dimension does not fit in an <see cref="int"/>.
/// </exception>
/// <remarks>
/// Failures of the native calls are surfaced through <c>NativeApiStatus.VerifySuccess</c>.
/// </remarks>
public NativeOnnxTensorMemory(IntPtr onnxValueHandle)
        {
            IntPtr typeAndShape = IntPtr.Zero;

            try
            {
                NativeApiStatus.VerifySuccess(NativeMethods.ONNXRuntimeGetTensorShapeAndType(onnxValueHandle, out typeAndShape));

                TensorElementType elemType = NativeMethods.ONNXRuntimeGetTensorElementType(typeAndShape);

                Type type  = null;
                int  width = 0;
                TensorElementTypeConverter.GetTypeAndWidth(elemType, out type, out width);
                if (typeof(T) != type)
                {
                    // typeof(T).Name reports the requested type; nameof(T) would always yield the literal "T".
                    throw new NotSupportedException(nameof(NativeOnnxTensorMemory<T>) + " does not support T = " + typeof(T).Name);
                }
                _elementWidth = width;

                _onnxValueHandle = onnxValueHandle;
                // Derive the data buffer pointer; VerifySuccess throws if the native call failed.
                NativeApiStatus.VerifySuccess(NativeMethods.ONNXRuntimeGetTensorMutableData(_onnxValueHandle, out _dataBufferHandle));

                ulong dimension = NativeMethods.ONNXRuntimeGetNumOfDimensions(typeAndShape);
                long  count     = NativeMethods.ONNXRuntimeGetTensorShapeElementCount(typeAndShape); // count can be negative.
                if (count < 0)
                {
                    throw new NotSupportedException("Symbolic dimensions in the tensor is not supported");
                }
                if (count > int.MaxValue)
                {
                    // _elementCount is an int; refuse rather than silently truncate.
                    throw new NotSupportedException("Tensor element count " + count + " exceeds the supported maximum of " + int.MaxValue);
                }

                long[] shape = new long[dimension];
                NativeMethods.ONNXRuntimeGetDimensions(typeAndShape, shape, dimension); // Note: shape must be alive during the call

                _elementCount = (int)count;
                _dimensions   = new int[dimension];
                for (ulong i = 0; i < dimension; i++)
                {
                    long extent = shape[i];
                    if (extent > int.MaxValue)
                    {
                        // Possible when another dimension is 0 (count == 0); _dimensions holds ints, so refuse to truncate.
                        throw new NotSupportedException("Tensor dimension " + extent + " exceeds the supported maximum of " + int.MaxValue);
                    }
                    _dimensions[i] = (int)extent;
                }
            }
            finally
            {
                // The shape/type descriptor is only needed during construction; always release it.
                // Do not release the tensor itself here: on failure the wrapper is never created and the caller owns the handle.
                if (typeAndShape != IntPtr.Zero)
                {
                    NativeMethods.ONNXRuntimeReleaseObject(typeAndShape);
                }
            }
        }
Example #2
0
        /// <summary>
        /// Builds a <see cref="NamedOnnxValue"/> called <paramref name="name"/> from a native ONNX value, picking the
        /// managed element type that matches the native tensor's element type.
        /// </summary>
        /// <param name="name">Name given to the resulting value.</param>
        /// <param name="nativeOnnxValue">
        /// Handle to the native ONNX value. It is assumed to be a tensor; non-tensor values are not handled yet (TODO).
        /// </param>
        /// <returns>The result of <c>NameOnnxValueFromNativeTensor&lt;T&gt;</c> for the matching element type.</returns>
        /// <exception cref="NotSupportedException">Thrown when the element type has no mapping in the switch below.</exception>
        internal static NamedOnnxValue CreateFromOnnxValue(string name, IntPtr nativeOnnxValue)
        {
            // TODO: assumes the value is a tensor; non-tensor types need support in the future.
            TensorElementType elementType = TensorElementType.DataTypeMax;
            IntPtr shapeInfo = IntPtr.Zero;

            try
            {
                NativeApiStatus.VerifySuccess(NativeMethods.ONNXRuntimeGetTensorShapeAndType(nativeOnnxValue, out shapeInfo));
                elementType = NativeMethods.ONNXRuntimeGetTensorElementType(shapeInfo);
            }
            finally
            {
                // The descriptor is only needed to read the element type; release it immediately.
                if (shapeInfo != IntPtr.Zero)
                {
                    NativeMethods.ONNXRuntimeReleaseObject(shapeInfo);
                }
            }

            switch (elementType)
            {
                case TensorElementType.Float:
                    return NameOnnxValueFromNativeTensor<float>(name, nativeOnnxValue);
                case TensorElementType.Double:
                    return NameOnnxValueFromNativeTensor<double>(name, nativeOnnxValue);
                case TensorElementType.Int16:
                    return NameOnnxValueFromNativeTensor<short>(name, nativeOnnxValue);
                case TensorElementType.UInt16:
                    return NameOnnxValueFromNativeTensor<ushort>(name, nativeOnnxValue);
                case TensorElementType.Int32:
                    return NameOnnxValueFromNativeTensor<int>(name, nativeOnnxValue);
                case TensorElementType.UInt32:
                    return NameOnnxValueFromNativeTensor<uint>(name, nativeOnnxValue);
                case TensorElementType.Int64:
                    return NameOnnxValueFromNativeTensor<long>(name, nativeOnnxValue);
                case TensorElementType.UInt64:
                    return NameOnnxValueFromNativeTensor<ulong>(name, nativeOnnxValue);
                case TensorElementType.UInt8:
                    return NameOnnxValueFromNativeTensor<byte>(name, nativeOnnxValue);
                default:
                    throw new NotSupportedException("Tensor of element type: " + elementType + " is not supported");
            }
        }