/// <summary>
/// Map the ONNX variable type described by <paramref name="typeProto"/> to the .NET <see cref="Type"/> that
/// ONNXRuntime's C# APIs use for values of that type.
/// </summary>
/// <param name="typeProto">ONNX variable's type.</param>
/// <returns>
/// For a tensor whose shape is missing or has no dimensions, the native element type; for a tensor with dimensions,
/// <c>VBuffer&lt;element type&gt;</c>; for a sequence, <c>IEnumerable&lt;element type&gt;</c>; for a map,
/// <c>IDictionary&lt;key type, value type&gt;</c>. Any other kind of ONNX type yields <see langword="null"/>.
/// </returns>
public static Type GetNativeType(OnnxCSharpToProtoWrapper.TypeProto typeProto)
{
    switch (typeProto.ValueCase)
    {
        case OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType:
        {
            var tensorProto = typeProto.TensorType;
            var dimensionless = tensorProto.Shape == null || tensorProto.Shape.Dim.Count == 0;
            if (dimensionless)
                return GetNativeScalarType(tensorProto.ElemType);

            var itemType = GetNativeScalarType(tensorProto.ElemType);
            return typeof(VBuffer<>).MakeGenericType(itemType);
        }
        case OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType:
        {
            // Sequence elements are themselves described by a TypeProto, so resolve them recursively.
            var itemType = GetNativeType(typeProto.SequenceType.ElemType);
            return typeof(IEnumerable<>).MakeGenericType(itemType);
        }
        case OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType:
        {
            var mapProto = typeProto.MapType;
            var keyType = GetNativeScalarType(mapProto.KeyType);
            var valueType = GetNativeType(mapProto.ValueType);
            return typeof(IDictionary<,>).MakeGenericType(keyType, valueType);
        }
        default:
            return null;
    }
}
/// <summary>
/// Derive the corresponding <see cref="DataViewType"/> for ONNX variable typed to <paramref name="typeProto"/>.
/// The returned <see cref="DataViewType.RawType"/> should match the type system in ONNXRuntime's C# APIs.
/// </summary>
/// <param name="typeProto">ONNX variable's type.</param>
/// <returns>
/// For a tensor whose shape is missing or empty, the scalar type of its elements. For a tensor with dimensions,
/// a <see cref="VectorDataViewType"/>: with those dimensions when every one of them is positive, otherwise with
/// unknown size (0). For a sequence, an <see cref="OnnxSequenceType"/>; for a map, an <see cref="OnnxMapType"/>.
/// </returns>
/// <remarks>
/// Any other kind of ONNX type is rejected by throwing the exception built by <c>Contracts.ExceptParamValue</c>.
/// </remarks>
public static DataViewType GetDataViewType(OnnxCSharpToProtoWrapper.TypeProto typeProto)
{
    if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType)
    {
        var tensorType = typeProto.TensorType;

        // ONNX scalar is a tensor without shape information; that is, its shape is an empty list.
        // A missing shape is treated as a scalar as well (GetNativeType does the same); checking for null
        // first avoids a NullReferenceException when dereferencing Shape.Dim.
        if (tensorType.Shape == null || tensorType.Shape.Dim.Count == 0)
            return GetScalarDataViewType(tensorType.ElemType);

        // Materialize the dimensions once; they are inspected several times below.
        var shape = GetTensorDims(tensorType.Shape)?.ToArray();
        if (shape == null)
        {
            // Scalar has null shape.
            return GetScalarDataViewType(tensorType.ElemType);
        }

        var itemType = (PrimitiveDataViewType)GetScalarDataViewType(tensorType.ElemType);

        // The shape counts as known only when every dimension is strictly positive. Checking each dimension,
        // rather than the sign of their product, cannot overflow and does not mistake an even number of
        // negative (unknown) dimensions for a known shape.
        if (shape.Length != 0 && shape.All(dim => dim > 0))
        {
            // Known shape tensor.
            return new VectorDataViewType(itemType, shape);
        }

        // Tensor with unknown shape.
        return new VectorDataViewType(itemType, 0);
    }
    else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType)
    {
        var elemTypeProto = typeProto.SequenceType.ElemType;
        var elemType = GetNativeType(elemTypeProto);
        return new OnnxSequenceType(elemType);
    }
    else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType)
    {
        var keyType = GetNativeScalarType(typeProto.MapType.KeyType);
        var valueType = GetNativeType(typeProto.MapType.ValueType);
        return new OnnxMapType(keyType, valueType);
    }
    else
    {
        throw Contracts.ExceptParamValue(typeProto, nameof(typeProto), $"Unsupported ONNX variable type {typeProto}");
    }
}
/// <summary>
/// Create a <see cref="Func{T, TResult}"/> that maps a <see cref="NamedOnnxValue"/> produced by ONNXRuntime to the
/// associated .NET <see langword="object"/>, and report that object's actual type through
/// <paramref name="resultedType"/>. The resulting type should match the type system in ONNXRuntime's C# APIs.
/// </summary>
/// <param name="typeProto">ONNX variable's type.</param>
/// <param name="resultedType">
/// C# type of the objects returned by the caster for <paramref name="typeProto"/>: the element type for a tensor
/// whose shape is missing or empty (an ONNX scalar), <c>Tensor&lt;T&gt;</c> for other tensors,
/// <c>IEnumerable&lt;T&gt;</c> for sequences and <c>IDictionary&lt;TKey, TValue&gt;</c> for maps.
/// </param>
/// <returns>A caster converting a <see cref="NamedOnnxValue"/> into an object of type <paramref name="resultedType"/>.</returns>
public static Func<NamedOnnxValue, object> GetDataViewValueCasterAndResultedType(OnnxCSharpToProtoWrapper.TypeProto typeProto, out Type resultedType)
{
    if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType)
    {
        var shape = GetTensorDims(typeProto.TensorType.Shape);

        // Get tensor element type.
        var type = GetScalarDataViewType(typeProto.TensorType.ElemType).RawType;

        // NamedOnnxValue.AsTensor<type>() exposes the value produced by ONNXRuntime as a Tensor<type>.
        var asTensorMethod = typeof(NamedOnnxValue).GetMethod(nameof(NamedOnnxValue.AsTensor)).MakeGenericMethod(type);

        if (shape == null || !shape.Any())
        {
            // Entering this scope means that an ONNX scalar is found: a tensor without shape information
            // (missing or empty shape), matching how GetDataViewType and GetNativeType classify scalars.

            // Tensor<T>.GetValue is an instance method of the generic *type*, not a generic method, so it must be
            // resolved on the closed Tensor<type> and invoked on the tensor (not on the NamedOnnxValue).
            var getValueMethod = typeof(Tensor<>).MakeGenericType(type).GetMethod(nameof(Tensor<int>.GetValue));

            resultedType = type;

            // NamedOnnxValue to scalar: access the first (only) element of the tensor.
            return (NamedOnnxValue value) =>
            {
                var tensor = asTensorMethod.Invoke(value, Array.Empty<object>());
                return getValueMethod.Invoke(tensor, new object[] { 0 });
            };
        }

        // Entering this scope means an ONNX tensor with dimensions is found.
        resultedType = typeof(Tensor<>).MakeGenericType(type);

        // NamedOnnxValue to Tensor.
        return (NamedOnnxValue value) => asTensorMethod.Invoke(value, Array.Empty<object>());
    }
    else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType)
    {
        // Now, we see a Sequence in ONNX. If its element type is T, the variable produced by
        // ONNXRuntime would be typed to IEnumerable<T>.

        // Find a proper caster (a function which maps NamedOnnxValue to a .NET object) for the element in
        // the ONNX sequence. Note that ONNX sequence is typed to IEnumerable<NamedOnnxValue>, so we need
        // to convert NamedOnnxValue to a proper type such as IDictionary<>.
        var elementCaster = GetDataViewValueCasterAndResultedType(typeProto.SequenceType.ElemType, out Type elementType);

        // Set the .NET type which corresponds to the first input argument, typeProto.
        resultedType = typeof(IEnumerable<>).MakeGenericType(elementType);

        // Create the element's caster to map IEnumerable<NamedOnnxValue> produced by ONNXRuntime to
        // IEnumerable<elementType>.
        var methodInfo = typeof(CastHelper).GetMethod(nameof(CastHelper.CastOnnxSequenceToIEnumerable));
        var methodSpecialized = methodInfo.MakeGenericMethod(typeof(NamedOnnxValue), elementType);

        // Use element-level caster to create sequence caster.
        return (NamedOnnxValue value) =>
        {
            var enumerable = value.AsEnumerable<NamedOnnxValue>();
            return methodSpecialized.Invoke(null, new object[] { enumerable, elementCaster });
        };
    }
    else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType)
    {
        // Entering this scope means a ONNX Map (equivalent to IDictionary<>) will be produced.

        var keyType = GetNativeScalarType(typeProto.MapType.KeyType);
        var valueType = GetNativeType(typeProto.MapType.ValueType);

        // The resulted type of the object returned by the caster below.
        resultedType = typeof(IDictionary<,>).MakeGenericType(keyType, valueType);

        // Create a method to convert NamedOnnxValue to IDictionary<keyValue, valueType>.
        var asDictionaryMethodInfo = typeof(NamedOnnxValue).GetMethod(nameof(NamedOnnxValue.AsDictionary));
        var asDictionaryMethod = asDictionaryMethodInfo.MakeGenericMethod(keyType, valueType);

        // Create a caster to convert NamedOnnxValue to IDictionary<keyValue, valueType>.
        return (NamedOnnxValue value) => asDictionaryMethod.Invoke(value, Array.Empty<object>());
    }
    else
    {
        throw Contracts.ExceptParamValue(typeProto, nameof(typeProto), $"Unsupported ONNX variable type {typeProto}");
    }
}