Ejemplo n.º 1
0
 /// <summary>
 /// Derive the corresponding <see cref="Type"/> for an ONNX variable typed to <paramref name="typeProto"/>.
 /// The corresponding <see cref="Type"/> should match the type system in ONNXRuntime's C# APIs.
 /// </summary>
 /// <param name="typeProto">ONNX variable's type.</param>
 /// <returns>
 /// The element type for a tensor whose shape is missing or has no dimensions (an ONNX scalar),
 /// <c>VBuffer&lt;T&gt;</c> for any other tensor, <c>IEnumerable&lt;T&gt;</c> for a sequence,
 /// <c>IDictionary&lt;TKey, TValue&gt;</c> for a map, and <see langword="null"/> for every other ONNX value kind.
 /// </returns>
 public static Type GetNativeType(OnnxCSharpToProtoWrapper.TypeProto typeProto)
 {
     switch (typeProto.ValueCase)
     {
         case OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType:
         {
             var tensorProto = typeProto.TensorType;
             var scalarType = GetNativeScalarType(tensorProto.ElemType);
             var hasNoDims = tensorProto.Shape == null || tensorProto.Shape.Dim.Count == 0;

             // A dimensionless tensor is an ONNX scalar; anything else is carried as a VBuffer of its elements.
             return hasNoDims ? scalarType : typeof(VBuffer<>).MakeGenericType(scalarType);
         }

         case OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType:
         {
             // The sequence's element type is resolved recursively.
             var itemType = GetNativeType(typeProto.SequenceType.ElemType);
             return typeof(IEnumerable<>).MakeGenericType(itemType);
         }

         case OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType:
         {
             // Keys are resolved as scalar types; values are resolved recursively.
             var mapProto = typeProto.MapType;
             var keyType = GetNativeScalarType(mapProto.KeyType);
             var valType = GetNativeType(mapProto.ValueType);
             return typeof(IDictionary<,>).MakeGenericType(keyType, valType);
         }

         default:
             return null;
     }
 }
Ejemplo n.º 2
0
        /// <summary>
        /// Derive the corresponding <see cref="DataViewType"/> for ONNX variable typed to <paramref name="typeProto"/>.
        /// The returned <see cref="DataViewType.RawType"/> should match the type system in ONNXRuntime's C# APIs.
        /// </summary>
        /// <param name="typeProto">ONNX variable's type.</param>
        /// <returns>
        /// A scalar type for a tensor whose shape is missing or has no dimensions, a <see cref="VectorDataViewType"/>
        /// for any other tensor (fixed-size when every dimension is positive, otherwise of unknown size), an
        /// <see cref="OnnxSequenceType"/> for a sequence, and an <see cref="OnnxMapType"/> for a map.
        /// </returns>
        /// <remarks>
        /// Any other ONNX value kind is rejected with the exception produced by <c>Contracts.ExceptParamValue</c>.
        /// </remarks>
        public static DataViewType GetDataViewType(OnnxCSharpToProtoWrapper.TypeProto typeProto)
        {
            if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType)
            {
                var tensorProto = typeProto.TensorType;

                // ONNX scalar is a tensor without shape information: its shape is either absent (null) or an
                // empty dimension list. Checking for null first avoids a NullReferenceException on Shape.Dim
                // and keeps this method consistent with GetNativeType.
                if (tensorProto.Shape == null || tensorProto.Shape.Dim.Count == 0)
                {
                    return GetScalarDataViewType(tensorProto.ElemType);
                }

                var shape = GetTensorDims(tensorProto.Shape);
                if (shape == null)
                {
                    // Scalar has null shape.
                    return GetScalarDataViewType(tensorProto.ElemType);
                }

                // Materialize once so the dimensions are not enumerated repeatedly below.
                var dims = shape.ToArray();
                var itemType = (PrimitiveDataViewType)GetScalarDataViewType(tensorProto.ElemType);

                // The shape is known only when every dimension is positive. Checking each dimension (rather than
                // the sign of their int product) cannot be fooled by integer overflow or by pairs of negative dims.
                if (dims.Length != 0 && dims.All(d => d > 0))
                {
                    // Known shape tensor.
                    return new VectorDataViewType(itemType, dims);
                }

                // Tensor with unknown shape.
                return new VectorDataViewType(itemType, 0);
            }
            else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType)
            {
                var elemTypeProto = typeProto.SequenceType.ElemType;
                var elemType = GetNativeType(elemTypeProto);
                return new OnnxSequenceType(elemType);
            }
            else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType)
            {
                var keyType = GetNativeScalarType(typeProto.MapType.KeyType);
                var valueType = GetNativeType(typeProto.MapType.ValueType);
                return new OnnxMapType(keyType, valueType);
            }
            else
            {
                throw Contracts.ExceptParamValue(typeProto, nameof(typeProto), $"Unsupported ONNX variable type {typeProto}");
            }
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Create a <see cref="Func{T, TResult}"/> to map a <see cref="NamedOnnxValue"/> to the associated .NET <see langword="object"/>.
        /// The resulted .NET object's actual type is <paramref name="resultedType"/>.
        /// The returned <see cref="DataViewType.RawType"/> should match the type system in ONNXRuntime's C# APIs.
        /// </summary>
        /// <param name="typeProto">ONNX variable's type.</param>
        /// <param name="resultedType">C# type of <paramref name="typeProto"/>.</param>
        /// <returns>
        /// A caster that unwraps a <see cref="NamedOnnxValue"/> into: the element value for a tensor whose dimensions
        /// are null (ONNX scalar), a <c>Tensor&lt;T&gt;</c> for any other tensor, an <c>IEnumerable&lt;T&gt;</c> for a
        /// sequence, and an <c>IDictionary&lt;TKey, TValue&gt;</c> for a map.
        /// </returns>
        /// <remarks>
        /// Any other ONNX value kind is rejected with the exception produced by <c>Contracts.ExceptParamValue</c>.
        /// </remarks>
        public static Func<NamedOnnxValue, object> GetDataViewValueCasterAndResultedType(OnnxCSharpToProtoWrapper.TypeProto typeProto, out Type resultedType)
        {
            if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType)
            {
                var shape = GetTensorDims(typeProto.TensorType.Shape);

                // Tensor element type, shared by the scalar and tensor cases below.
                var type = GetScalarDataViewType(typeProto.TensorType.ElemType).RawType;

                // NamedOnnxValue.AsTensor<type>() unwraps the value produced by ONNXRuntime into Tensor<type>.
                var asTensorMethod = typeof(NamedOnnxValue).GetMethod(nameof(NamedOnnxValue.AsTensor)).MakeGenericMethod(type);

                // NOTE(review): GetDataViewType also treats a tensor whose shape has zero dimensions as a scalar,
                // while here only a null result from GetTensorDims selects the scalar path — confirm the two agree.
                if (shape == null)
                {
                    // Entering this scope means that an ONNX scalar is found. Note that ONNX scalar is typed to tensor without a shape.
                    //
                    // Tensor<T>.GetValue is an instance method of the generic *type*, not a generic *method*, so it
                    // must be looked up on the closed type Tensor<type> and invoked on the unwrapped tensor rather than
                    // on the NamedOnnxValue itself.
                    var getValueMethod = typeof(Tensor<>).MakeGenericType(type)
                        .GetMethod(nameof(Tensor<int>.GetValue), new[] { typeof(int) });

                    // NamedOnnxValue -> Tensor<type> -> first element as a scalar.
                    Func<NamedOnnxValue, object> scalarCaster = (NamedOnnxValue value) =>
                    {
                        var tensor = asTensorMethod.Invoke(value, new object[] { });
                        return getValueMethod.Invoke(tensor, new object[] { 0 });
                    };

                    resultedType = type;

                    return scalarCaster;
                }

                // Entering this point means an ONNX tensor is found: NamedOnnxValue to Tensor<type>.
                Func<NamedOnnxValue, object> tensorCaster = (NamedOnnxValue value) => asTensorMethod.Invoke(value, new object[] { });

                resultedType = typeof(Tensor<>).MakeGenericType(type);

                return tensorCaster;
            }
            else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType)
            {
                // Now, we see a Sequence in ONNX. If its element type is T, the variable produced by
                // ONNXRuntime would be typed to IEnumerable<T>.

                // Find a proper caster (a function which maps NamedOnnxValue to a .NET object) for the element in
                // the ONNX sequence. Note that ONNX sequence is typed to IEnumerable<NamedOnnxValue>, so we need
                // to convert NamedOnnxValue to a proper type such as IDictionary<>.
                var elementCaster = GetDataViewValueCasterAndResultedType(typeProto.SequenceType.ElemType, out Type elementType);

                // Set the .NET type which corresponds to the first input argument, typeProto.
                resultedType = typeof(IEnumerable<>).MakeGenericType(elementType);

                // Create the element's caster to map IEnumerable<NamedOnnxValue> produced by ONNXRuntime to
                // IEnumerable<elementType>.
                var methodInfo = typeof(CastHelper).GetMethod(nameof(CastHelper.CastOnnxSequenceToIEnumerable));
                var methodSpecialized = methodInfo.MakeGenericMethod(typeof(NamedOnnxValue), elementType);

                // Use element-level caster to create sequence caster.
                Func<NamedOnnxValue, object> caster = (NamedOnnxValue value) =>
                {
                    var enumerable = value.AsEnumerable<NamedOnnxValue>();
                    return methodSpecialized.Invoke(null, new object[] { enumerable, elementCaster });
                };

                return caster;
            }
            else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType)
            {
                // Entering this scope means a ONNX Map (equivalent to IDictionary<>) will be produced.

                var keyType = GetNativeScalarType(typeProto.MapType.KeyType);
                var valueType = GetNativeType(typeProto.MapType.ValueType);

                // The resulted type of the object returned by the caster below.
                resultedType = typeof(IDictionary<,>).MakeGenericType(keyType, valueType);

                // Create a method to convert NamedOnnxValue to IDictionary<keyType, valueType>.
                var asDictionaryMethodInfo = typeof(NamedOnnxValue).GetMethod(nameof(NamedOnnxValue.AsDictionary));
                var asDictionaryMethod = asDictionaryMethodInfo.MakeGenericMethod(keyType, valueType);

                // Create a caster to convert NamedOnnxValue to IDictionary<keyType, valueType>.
                Func<NamedOnnxValue, object> caster = (NamedOnnxValue value) =>
                {
                    return asDictionaryMethod.Invoke(value, new object[] { });
                };

                return caster;
            }
            else
            {
                throw Contracts.ExceptParamValue(typeProto, nameof(typeProto), $"Unsupported ONNX variable type {typeProto}");
            }
        }