예제 #1
0
        /// <summary>
        /// Builds one <see cref="OnnxVariableInfo"/> per entry of <paramref name="nodeMetadata"/>, skipping the ML.NET
        /// placeholder nodes whose names look like <c>mlnet.*.unusedInput</c> or <c>mlnet.*.unusedOutput</c>.
        /// </summary>
        /// <param name="nodeMetadata">Node metadata keyed by node name.</param>
        /// <param name="shapeDictionary">Optional user-specified shapes keyed by node name; may be null. When a node has an
        /// entry, that shape replaces the one from <paramref name="nodeMetadata"/>, and a vector <see cref="DataViewType"/>
        /// is rebuilt to match it.</param>
        /// <param name="typePool">The <see cref="DataViewType"/> of every non-placeholder node, keyed by node name.</param>
        /// <param name="casterPool">Optional casters keyed by node name. When null, every produced variable gets a null
        /// caster; when non-null, it must contain every non-placeholder node.</param>
        /// <returns>The variable infos, in the enumeration order of <paramref name="nodeMetadata"/>.</returns>
        /// <remarks>
        /// Throws the exception built by <c>Contracts.ExceptParamValue</c> when a user-specified shape fails
        /// <c>CheckOnnxShapeCompatibility</c> against the shape recorded in <paramref name="nodeMetadata"/>.
        /// </remarks>
        private List<OnnxVariableInfo> GetOnnxVariablesFromMetadata(IReadOnlyDictionary<string, NodeMetadata> nodeMetadata,
                                                                    IDictionary<string, int[]> shapeDictionary,
                                                                    Dictionary<string, DataViewType> typePool,
                                                                    Dictionary<string, Func<NamedOnnxValue, object>> casterPool)
        {
            var onnxVariableInfos = new List<OnnxVariableInfo>();

            foreach (var pair in nodeMetadata)
            {
                var name = pair.Key;
                var meta = pair.Value;

                // Skip placeholder nodes before touching the type/caster pools, so a placeholder that has no pool
                // entry is skipped instead of throwing KeyNotFoundException. Names are identifiers: compare ordinally.
                if (name.StartsWith("mlnet.", StringComparison.Ordinal) &&
                    (name.EndsWith(".unusedInput", StringComparison.Ordinal) ||
                     name.EndsWith(".unusedOutput", StringComparison.Ordinal)))
                {
                    continue;
                }

                var dataViewType = typePool[name];
                var caster = casterPool?[name];

                List<int> shape;
                if (shapeDictionary != null && shapeDictionary.TryGetValue(name, out var userShape))
                {
                    if (!CheckOnnxShapeCompatibility(userShape.ToList(), meta.Dimensions.ToList()))
                    {
                        throw Contracts.ExceptParamValue(userShape, nameof(shapeDictionary),
                                                         "The specified shape " + string.Join(",", userShape) +
                                                         " is not compatible with the shape " + string.Join(",", meta.Dimensions) +
                                                         " loaded from the ONNX model file. Only unknown dimension can replace or " +
                                                         "be replaced by another dimension.");
                    }

                    if (dataViewType is VectorDataViewType vectorType)
                    {
                        // Fixed vector dimensions are used only when every user dimension is positive;
                        // otherwise the vector's size is left unspecified.
                        dataViewType = userShape.All(value => value > 0)
                            ? new VectorDataViewType(vectorType.ItemType, userShape)
                            : new VectorDataViewType(vectorType.ItemType);
                    }

                    shape = userShape.ToList();
                }
                else
                {
                    // No user-specified shape is found, so the shape loaded from ONNX model file is used.
                    shape = meta.Dimensions.ToList();
                }

                onnxVariableInfos.Add(new OnnxVariableInfo(name, shape, meta.ElementType, dataViewType, caster));
            }

            return onnxVariableInfos;
        }
예제 #2
0
        /// <summary>
        /// Constructs OnnxModel object from file.
        /// </summary>
        /// <param name="modelFile">Model file path.</param>
        /// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU. GPU execution is temporarily not supported,
        /// so any non-null value makes this constructor throw <see cref="NotSupportedException"/>.</param>
        /// <param name="fallbackToCpu">Intended to resume CPU execution quietly upon GPU error. It currently has no effect,
        /// because a GPU request is rejected before any session is created.</param>
        /// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
        /// no longer needed.</param>
        /// <param name="shapeDictionary">Optional shapes keyed by input/output tensor name; may be null. A shape listed here
        /// replaces the shape loaded from the ONNX model file (and the deduced vector type). It must pass
        /// <c>CheckOnnxShapeCompatibility</c> against the model's shape; otherwise the exception built by
        /// <c>Contracts.ExceptParamValue</c> is thrown.</param>
        /// <exception cref="NotSupportedException">Thrown when <paramref name="gpuDeviceId"/> is not null.</exception>
        public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
                         bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null)
        {
            ModelFile = modelFile;
            // Record whether this instance owns the model file; per the ownModelFile contract only an owned file is
            // deleted once the model is no longer needed, so a user's own file is left alone.
            _ownModelFile = ownModelFile;
            _disposed = false;

            if (gpuDeviceId != null)
            {
                // The onnxruntime v1.0 currently does not support running on the GPU on all of ML.NET's supported platforms.
                // This code path will be re-enabled when there is appropriate support in onnxruntime
                throw new NotSupportedException("Running Onnx models on a GPU is temporarily not supported!");
            }

            _session = new InferenceSession(modelFile);

            try
            {
                // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
                // doesn't expose full type information via its C# APIs.
                OnnxCSharpToProtoWrapper.ModelProto model;
                using (var modelStream = File.OpenRead(modelFile))
                using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
                {
                    // Explicit limits: message size limit of Int32.MaxValue bytes, recursion limit of 10.
                    model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);
                }

                // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
                var inputTypePool = new Dictionary<string, DataViewType>();
                foreach (var valueInfo in model.Graph.Input)
                {
                    inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                }

                // Output types, plus casters which map NamedOnnxValue to .NET objects.
                var outputTypePool = new Dictionary<string, DataViewType>();
                var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
                foreach (var valueInfo in model.Graph.Output)
                {
                    outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                    casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out _);
                }

                // Collect input information for this ONNX model from ONNXRuntime's perspective.
                var onnxRuntimeInputInfos = new List<OnnxVariableInfo>();
                foreach (var pair in _session.InputMetadata)
                {
                    var dataViewType = inputTypePool[pair.Key];
                    var shape = ResolveShape(pair.Key, pair.Value.Dimensions, ref dataViewType);
                    onnxRuntimeInputInfos.Add(new OnnxVariableInfo(pair.Key, shape, pair.Value.ElementType, dataViewType, null));
                }

                // Collect output information for this ONNX model from ONNXRuntime's perspective.
                var onnxRuntimeOutputInfos = new List<OnnxVariableInfo>();
                foreach (var pair in _session.OutputMetadata)
                {
                    var dataViewType = outputTypePool[pair.Key];
                    var caster = casterPool[pair.Key];
                    var shape = ResolveShape(pair.Key, pair.Value.Dimensions, ref dataViewType);
                    onnxRuntimeOutputInfos.Add(new OnnxVariableInfo(pair.Key, shape, pair.Value.ElementType, dataViewType, caster));
                }

                // Create a view to the used ONNX model from ONNXRuntime's perspective.
                ModelInfo = new OnnxModelInfo(onnxRuntimeInputInfos, onnxRuntimeOutputInfos);
            }
            catch
            {
                // A constructor that throws never hands the instance to its caller, so the caller cannot Dispose it;
                // release the native session here before propagating the original exception.
                _session.Dispose();
                throw;
            }

            // Returns the shape to record for tensor `name`. If the user supplied a shape for it, that shape is validated
            // against the model's shape and replaces it, and a vector `dataViewType` is rebuilt to match.
            List<int> ResolveShape(string name, IEnumerable<int> modelShape, ref DataViewType dataViewType)
            {
                if (shapeDictionary == null || !shapeDictionary.TryGetValue(name, out var userShape))
                {
                    // No user-specified shape is found, so the shape loaded from ONNX model file is used.
                    return modelShape.ToList();
                }

                if (!CheckOnnxShapeCompatibility(userShape.ToList(), modelShape.ToList()))
                {
                    throw Contracts.ExceptParamValue(userShape, nameof(shapeDictionary),
                                                     "The specified shape " + string.Join(",", userShape) +
                                                     " is not compatible with the shape " + string.Join(",", modelShape) +
                                                     " loaded from the ONNX model file. Only unknown dimension can replace or " +
                                                     "be replaced by another dimension.");
                }

                if (dataViewType is VectorDataViewType vectorType)
                {
                    // Fixed vector dimensions are used only when every user dimension is positive;
                    // otherwise the vector's size is left unspecified.
                    dataViewType = userShape.All(value => value > 0)
                        ? new VectorDataViewType(vectorType.ItemType, userShape)
                        : new VectorDataViewType(vectorType.ItemType);
                }

                return userShape.ToList();
            }
        }
예제 #3
0
        /// <summary>
        /// Constructs OnnxModel object from file.
        /// </summary>
        /// <param name="modelFile">Model file path.</param>
        /// <param name="gpuDeviceId">GPU device ID to execute on (via the CUDA session options). Null for CPU.</param>
        /// <param name="fallbackToCpu">If true, resumes CPU execution quietly when creating the GPU session fails with an
        /// <see cref="OnnxRuntimeException"/>; if false, that exception propagates.</param>
        /// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
        /// no longer needed.</param>
        /// <exception cref="OnnxRuntimeException">Thrown when GPU session creation fails and
        /// <paramref name="fallbackToCpu"/> is false.</exception>
        public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false, bool ownModelFile = false)
        {
            ModelFile = modelFile;
            // Record whether this instance owns the model file; per the ownModelFile contract only an owned file is
            // deleted once the model is no longer needed, so a user's own file is left alone.
            _ownModelFile = ownModelFile;
            _disposed = false;

            if (gpuDeviceId != null)
            {
                try
                {
                    // NOTE(review): the SessionOptions created here is never disposed; confirm whether
                    // InferenceSession takes ownership of it.
                    _session = new InferenceSession(modelFile,
                                                    SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value));
                }
                catch (OnnxRuntimeException) when (fallbackToCpu)
                {
                    // Without fallback the filter does not match and the exception propagates unchanged
                    // (if called from OnnxTransform, it is caught and rethrown there).
                    _session = new InferenceSession(modelFile);
                }
            }
            else
            {
                _session = new InferenceSession(modelFile);
            }

            try
            {
                // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
                // doesn't expose full type information via its C# APIs.
                OnnxCSharpToProtoWrapper.ModelProto model;
                using (var modelStream = File.OpenRead(modelFile))
                {
                    model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(modelStream);
                }

                // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
                var inputTypePool = new Dictionary<string, DataViewType>();
                foreach (var valueInfo in model.Graph.Input)
                {
                    inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                }

                // Output types, plus casters which map NamedOnnxValue to .NET objects.
                var outputTypePool = new Dictionary<string, DataViewType>();
                var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
                foreach (var valueInfo in model.Graph.Output)
                {
                    outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                    casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out _);
                }

                // Inputs and outputs as seen by ONNXRuntime, typed with the DataViewTypes parsed above.
                var onnxRuntimeInputInfos = new List<OnnxVariableInfo>();
                foreach (var pair in _session.InputMetadata)
                {
                    var meta = pair.Value;
                    onnxRuntimeInputInfos.Add(new OnnxVariableInfo(pair.Key, meta.Dimensions.ToList(), meta.ElementType,
                                                                   inputTypePool[pair.Key], null));
                }

                var onnxRuntimeOutputInfos = new List<OnnxVariableInfo>();
                foreach (var pair in _session.OutputMetadata)
                {
                    var meta = pair.Value;
                    var dataViewType = outputTypePool[pair.Key];
                    var caster = casterPool[pair.Key];
                    onnxRuntimeOutputInfos.Add(new OnnxVariableInfo(pair.Key, meta.Dimensions.ToList(), meta.ElementType,
                                                                    dataViewType, caster));
                }

                ModelInfo = new OnnxModelInfo(onnxRuntimeInputInfos, onnxRuntimeOutputInfos);
            }
            catch
            {
                // A constructor that throws never hands the instance to its caller, so the caller cannot Dispose it;
                // release the native session here before propagating the original exception.
                _session.Dispose();
                throw;
            }
        }