예제 #1
0
        /// <summary>
        /// Wraps the native tensor referenced by <paramref name="onnxValueHandle"/>, capturing its data
        /// buffer pointer, element width, element count and dimensions.
        /// </summary>
        /// <param name="onnxValueHandle">Handle to a native OnnxValue holding a tensor.</param>
        /// <exception cref="NotSupportedException">
        /// The native element type does not map to <typeparamref name="T"/>, or the tensor reports a
        /// negative element count (symbolic dimensions).
        /// </exception>
        /// <exception cref="OverflowException">
        /// The element count or a single dimension does not fit in an <see cref="int"/>.
        /// </exception>
        /// <remarks>
        /// Failing native calls are surfaced by <c>NativeApiStatus.VerifySuccess</c> (OnnxRuntimeException).
        /// If this constructor throws, no wrapper exists, so the caller stays responsible for disposing
        /// <paramref name="onnxValueHandle"/>; the tensor is deliberately not released here.
        /// </remarks>
        public NativeOnnxTensorMemory(IntPtr onnxValueHandle)
        {
            IntPtr typeAndShape = IntPtr.Zero;

            try
            {
                NativeApiStatus.VerifySuccess(NativeMethods.ONNXRuntimeGetTensorShapeAndType(onnxValueHandle, out typeAndShape));

                // Verify the native element type is the one this wrapper was instantiated with.
                TensorElementType elemType = NativeMethods.ONNXRuntimeGetTensorElementType(typeAndShape);

                Type type  = null;
                int  width = 0;
                TensorElementTypeConverter.GetTypeAndWidth(elemType, out type, out width);
                if (typeof(T) != type)
                {
                    // typeof(T).Name reports the actual type argument; nameof(T) would always yield "T".
                    throw new NotSupportedException(nameof(NativeOnnxTensorMemory<T>) + " does not support T = " + typeof(T).Name);
                }
                _elementWidth = width;

                _onnxValueHandle = onnxValueHandle;

                // Derive the data buffer pointer; VerifySuccess throws if the native call failed.
                NativeApiStatus.VerifySuccess(NativeMethods.ONNXRuntimeGetTensorMutableData(_onnxValueHandle, out _dataBufferHandle));

                ulong dimension = NativeMethods.ONNXRuntimeGetNumOfDimensions(typeAndShape);
                long  count     = NativeMethods.ONNXRuntimeGetTensorShapeElementCount(typeAndShape); // count can be negative.
                if (count < 0)
                {
                    throw new NotSupportedException("Symbolic dimensions in the tensor is not supported");
                }

                long[] shape = new long[dimension];
                NativeMethods.ONNXRuntimeGetDimensions(typeAndShape, shape, dimension); // Note: shape must be alive during the call

                // checked: sizes beyond int.MaxValue must fail loudly instead of silently wrapping
                // to a wrong element count or dimension.
                _elementCount = checked((int)count);
                _dimensions   = new int[dimension];
                for (ulong i = 0; i < dimension; i++)
                {
                    _dimensions[i] = checked((int)shape[i]);
                }
            }
            finally
            {
                // Only the temporary type-and-shape info is released here; onnxValueHandle is untouched.
                // TODO: clean up any partially assigned state (e.g. _onnxValueHandle) when construction fails.
                if (typeAndShape != IntPtr.Zero)
                {
                    NativeMethods.ONNXRuntimeReleaseObject(typeAndShape);
                }
            }
        }
예제 #2
0
        /// <summary>
        /// Builds a <see cref="NodeMetadata"/> (dimensions and .NET element type) from a native type-info handle.
        /// </summary>
        /// <param name="typeInfo">
        /// Native ONNXRuntimeTypeInfo handle. It is not released here; presumably the caller owns it — confirm.
        /// </param>
        /// <returns>
        /// Metadata holding every dimension cast to <see cref="int"/> and the .NET type mapped from the
        /// tensor element type.
        /// </returns>
        /// <exception cref="NotSupportedException">The type info does not describe a tensor.</exception>
        private NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo)
        {
            // Convert the newly introduced ONNXRuntimeTypeInfo* to the older ONNXRuntimeTypeAndShapeInfo*.
            // The result is not released below; NOTE(review): it appears to be a view owned by typeInfo — confirm.
            IntPtr tensorInfo = NativeMethods.ONNXRuntimeCastTypeInfoToTensorInfo(typeInfo);

            // The native cast yields a null pointer when typeInfo is not a tensor type; fail with a
            // managed exception instead of handing null to the native getters below.
            if (tensorInfo == IntPtr.Zero)
            {
                throw new NotSupportedException("Only tensor-typed node metadata is supported");
            }

            TensorElementType type = NativeMethods.ONNXRuntimeGetTensorElementType(tensorInfo);
            Type dotnetType        = null;
            int  width             = 0;
            TensorElementTypeConverter.GetTypeAndWidth(type, out dotnetType, out width);

            ulong numDimensions = NativeMethods.ONNXRuntimeGetNumOfDimensions(tensorInfo);
            int   rank          = (int)numDimensions;

            long[] dimensions = new long[rank];
            NativeMethods.ONNXRuntimeGetDimensions(tensorInfo, dimensions, numDimensions);

            int[] intDimensions = new int[rank];
            for (int i = 0; i < rank; i++)
            {
                intDimensions[i] = (int)dimensions[i];
            }
            return new NodeMetadata(intDimensions, dotnetType);
        }
예제 #3
0
        /// <summary>
        /// Creates a <see cref="NamedOnnxValue"/> named <paramref name="name"/> from a native OnnxValue,
        /// picking the generic element type from the value's native tensor element type.
        /// </summary>
        /// <param name="name">Name given to the resulting value.</param>
        /// <param name="nativeOnnxValue">Handle to the native OnnxValue; assumed to hold a tensor.</param>
        /// <returns>The wrapped value produced by <c>NameOnnxValueFromNativeTensor</c>.</returns>
        /// <exception cref="NotSupportedException">The tensor element type has no supported .NET mapping.</exception>
        internal static NamedOnnxValue CreateFromOnnxValue(string name, IntPtr nativeOnnxValue)
        {
            // TODO: the value is assumed to be a tensor; non-tensor values need support in the future.
            IntPtr            shapeInfo   = IntPtr.Zero;
            TensorElementType elementType = TensorElementType.DataTypeMax;

            try
            {
                NativeApiStatus.VerifySuccess(NativeMethods.ONNXRuntimeGetTensorShapeAndType(nativeOnnxValue, out shapeInfo));
                elementType = NativeMethods.ONNXRuntimeGetTensorElementType(shapeInfo);
            }
            finally
            {
                // The type-and-shape info is only needed to read the element type; release it immediately.
                if (shapeInfo != IntPtr.Zero)
                {
                    NativeMethods.ONNXRuntimeReleaseObject(shapeInfo);
                }
            }

            // Dispatch to the generic factory that matches the native element type.
            switch (elementType)
            {
                case TensorElementType.Float:
                    return NameOnnxValueFromNativeTensor<float>(name, nativeOnnxValue);
                case TensorElementType.Double:
                    return NameOnnxValueFromNativeTensor<double>(name, nativeOnnxValue);
                case TensorElementType.Int16:
                    return NameOnnxValueFromNativeTensor<short>(name, nativeOnnxValue);
                case TensorElementType.UInt16:
                    return NameOnnxValueFromNativeTensor<ushort>(name, nativeOnnxValue);
                case TensorElementType.Int32:
                    return NameOnnxValueFromNativeTensor<int>(name, nativeOnnxValue);
                case TensorElementType.UInt32:
                    return NameOnnxValueFromNativeTensor<uint>(name, nativeOnnxValue);
                case TensorElementType.Int64:
                    return NameOnnxValueFromNativeTensor<long>(name, nativeOnnxValue);
                case TensorElementType.UInt64:
                    return NameOnnxValueFromNativeTensor<ulong>(name, nativeOnnxValue);
                case TensorElementType.UInt8:
                    return NameOnnxValueFromNativeTensor<byte>(name, nativeOnnxValue);
                default:
                    throw new NotSupportedException("Tensor of element type: " + elementType + " is not supported");
            }
        }