        /// <summary>
        /// Constructs OnnxModel object from file.
        /// </summary>
        /// <param name="modelFile">Model file path.</param>
        /// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
        /// <param name="fallbackToCpu">If true, resumes CPU execution quitely upon GPU error.</param>
        /// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
        /// no longer needed.</param>
        /// <param name="shapeDictionary"></param>
        public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
                         bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null)
        {
            ModelFile = modelFile;
            // If we don't own the model file, _disposed should be false to prevent deleting user's file.
            _ownModelFile = ownModelFile;
            _disposed     = false;

            if (gpuDeviceId != null)
            {
                // The onnxruntime v1.0 currently does not support running on the GPU on all of ML.NET's supported platforms.
                // This code path will be re-enabled when there is appropriate support in onnxruntime
                throw new NotSupportedException("Running Onnx models on a GPU is temporarily not supported!");
            }
            else
            {
                _session = new InferenceSession(modelFile);
            }

            // Load the ONNX model file and parse its input and output schema. This is necessary because
            // ONNXRuntime doesn't expose full type information via its C# APIs.
            var model = new OnnxCSharpToProtoWrapper.ModelProto();

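            // CreateWithLimits raises protobuf's default message-size limit to Int32.MaxValue
            // and caps recursion depth at 10, so large model files can be parsed.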
            using (var modelStream = File.OpenRead(modelFile))
                using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
                    model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);

            // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
            var inputTypePool = new Dictionary<string, DataViewType>();

            foreach (var valueInfo in model.Graph.Input)
            {
                inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
            }

            var initializerTypePool = new Dictionary<string, DataViewType>();

            foreach (var valueInfo in model.Graph.Initializer)
            {
                initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);
            }

            var outputTypePool = new Dictionary<string, DataViewType>();
            // Build casters that map NamedOnnxValue to .NET objects.
            var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();

            foreach (var valueInfo in model.Graph.Output)
            {
                outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                casterPool[valueInfo.Name]     = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
            }

            var inputInfos  = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
            var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
            var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

            // Create a view to the used ONNX model from ONNXRuntime's perspective.
            ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);
        }
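
A minimal usage sketch for this overload (the file path and the tensor name "data" are placeholders, not taken from the source): shapeDictionary pins an otherwise dynamic batch dimension, and gpuDeviceId stays null because this version throws NotSupportedException for any GPU request.

        var shapes = new Dictionary<string, int[]>
        {
            { "data", new[] { 1, 3, 224, 224 } }
        };
        var model = new OnnxModel("squeezenet.onnx", gpuDeviceId: null,
                                  fallbackToCpu: false, ownModelFile: false,
                                  shapeDictionary: shapes);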
Example #2
        /// <summary>
        /// Constructs OnnxModel object from file.
        /// </summary>
        /// <param name="modelFile">Model file path.</param>
        /// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
        /// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error.</param>
        /// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
        /// no longer needed.</param>
        /// <param name="shapeDictionary"></param>
        public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
                         bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null)
        {
            // If we don't own the model file, _disposed should be false to prevent deleting user's file.
            _disposed = false;

            if (gpuDeviceId != null)
            {
                try
                {
                    _session = new InferenceSession(modelFile,
                                                    SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value));
                }
                catch (OnnxRuntimeException)
                {
                    if (fallbackToCpu)
                    {
                        _session = new InferenceSession(modelFile);
                    }
                    else
                    {
                        // If called from OnnxTransform, this exception is caught and rethrown.
                        throw;
                    }
                }
            }
            else
            {
                _session = new InferenceSession(modelFile);
            }

            try
            {
                // Load the ONNX model file and parse its input and output schema. This is necessary because
                // ONNXRuntime doesn't expose full type information via its C# APIs.
                var model = new OnnxCSharpToProtoWrapper.ModelProto();
                // If we own the model file, set the DeleteOnClose flag so it is always deleted.
                if (ownModelFile)
                {
                    ModelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read, FileShare.Read, 4096, FileOptions.DeleteOnClose);
                }
                else
                {
                    ModelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read);
                }

                // The CodedInputStream auto-closes the stream, and we need our main stream to stay open, so we create a new one here.
                using (var modelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read, FileShare.Delete | FileShare.Read))
                    using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 100))
                        model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);

                // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
                var inputTypePool = new Dictionary<string, DataViewType>();
                foreach (var valueInfo in model.Graph.Input)
                {
                    inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                }

                var initializerTypePool = new Dictionary<string, DataViewType>();
                foreach (var valueInfo in model.Graph.Initializer)
                {
                    initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);
                }

                var outputTypePool = new Dictionary<string, DataViewType>();
                // Build casters that map NamedOnnxValue to .NET objects.
                var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
                foreach (var valueInfo in model.Graph.Output)
                {
                    outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                    casterPool[valueInfo.Name]     = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
                }

                var inputInfos  = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
                var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
                var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

                // Create a view to the used ONNX model from ONNXRuntime's perspective.
                ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);

                Graph = model.Graph;
            }
            catch
            {
                _session.Dispose();
                _session = null;
                throw;
            }
        }
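
A hedged sketch of the GPU path in this version (the device id and path are illustrative): with fallbackToCpu: true, an OnnxRuntimeException thrown while creating the CUDA session is swallowed and the model is loaded on CPU instead; with fallbackToCpu: false the exception propagates to the caller.

        // Try CUDA device 0 first; quietly fall back to CPU if CUDA is unavailable.
        var model = new OnnxModel("model.onnx", gpuDeviceId: 0, fallbackToCpu: true);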
Example #3
        /// <summary>
        /// Constructs OnnxModel object from file.
        /// </summary>
        /// <param name="modelFile">Model file path.</param>
        /// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
        /// <param name="fallbackToCpu">If true, resumes CPU execution quitely upon GPU error.</param>
        /// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
        /// no longer needed.</param>
        /// <param name="shapeDictionary"></param>
        public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
                         bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null)
        {
            ModelFile = modelFile;
            // If we don't own the model file, _disposed should be false to prevent deleting user's file.
            _ownModelFile = ownModelFile;
            _disposed     = false;

            if (gpuDeviceId != null)
            {
                // The onnxruntime v1.0 currently does not support running on the GPU on all of ML.NET's supported platforms.
                // This code path will be re-enabled when there is appropriate support in onnxruntime
                throw new NotSupportedException("Running Onnx models on a GPU is temporarily not supported!");
            }
            else
            {
                _session = new InferenceSession(modelFile);
            }

            // Load the ONNX model file and parse its input and output schema. This is necessary because
            // ONNXRuntime doesn't expose full type information via its C# APIs.
            var model = new OnnxCSharpToProtoWrapper.ModelProto();

            using (var modelStream = File.OpenRead(modelFile))
                using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
                    model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);

            // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
            var inputTypePool = new Dictionary<string, DataViewType>();

            foreach (var valueInfo in model.Graph.Input)
            {
                inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
            }
            var outputTypePool = new Dictionary<string, DataViewType>();

            // Build casters that map NamedOnnxValue to .NET objects.
            var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();

            foreach (var valueInfo in model.Graph.Output)
            {
                outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                casterPool[valueInfo.Name]     = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
            }

            var onnxRuntimeInputInfos = new List<OnnxVariableInfo>();

            // Collect input information for this ONNX model from ONNXRuntime's perspective.
            foreach (var pair in _session.InputMetadata)
            {
                var name         = pair.Key;
                var meta         = pair.Value;
                var dataViewType = inputTypePool[name];

                OnnxVariableInfo info = null;
                if (shapeDictionary != null && shapeDictionary.ContainsKey(name))
                {
                    // If the user provides a shape for a specific tensor, it overwrites the corresponding shape
                    // loaded from the ONNX model file and the deduced VectorDataViewType.

                    if (!CheckOnnxShapeCompatibility(shapeDictionary[name].ToList(), meta.Dimensions.ToList()))
                    {
                        throw Contracts.ExceptParamValue(shapeDictionary[name], nameof(shapeDictionary),
                                                         "The specified shape " + string.Join(",", shapeDictionary[name]) +
                                                         " is not compatible with the shape " + string.Join(",", meta.Dimensions) +
                                                         " loaded from the ONNX model file. Only unknown dimension can replace or " +
                                                         "be replaced by another dimension.");
                    }

                    if (dataViewType is VectorDataViewType vectorType)
                    {
                        if (shapeDictionary[name].All(value => value > 0))
                        {
                            dataViewType = new VectorDataViewType(vectorType.ItemType, shapeDictionary[name]);
                        }
                        else
                        {
                            dataViewType = new VectorDataViewType(vectorType.ItemType);
                        }
                    }

                    info = new OnnxVariableInfo(name, shapeDictionary[name].ToList(), meta.ElementType, dataViewType, null);
                }
                else
                {
                    // No user-specified shape is found, so the shape loaded from ONNX model file is used.
                    info = new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, null);
                }
                onnxRuntimeInputInfos.Add(info);
            }

            var onnxRuntimeOutputInfos = new List<OnnxVariableInfo>();

            // Collect output information for this ONNX model from ONNXRuntime's perspective.
            foreach (var pair in _session.OutputMetadata)
            {
                var name         = pair.Key;
                var meta         = pair.Value;
                var dataViewType = outputTypePool[name];
                var caster       = casterPool[name];

                OnnxVariableInfo info = null;
                if (shapeDictionary != null && shapeDictionary.ContainsKey(name))
                {
                    // If the user provides a shape for a specific tensor, it overwrites the corresponding shape
                    // loaded from the ONNX model file.

                    if (!CheckOnnxShapeCompatibility(shapeDictionary[name].ToList(), meta.Dimensions.ToList()))
                    {
                        throw Contracts.ExceptParamValue(shapeDictionary[name], nameof(shapeDictionary),
                                                         "The specified shape " + string.Join(",", shapeDictionary[name]) +
                                                         " is not compatible with the shape " + string.Join(",", meta.Dimensions) +
                                                         " loaded from the ONNX model file. Only unknown dimension can replace or " +
                                                         "be replaced by another dimension.");
                    }

                    if (dataViewType is VectorDataViewType vectorType)
                    {
                        if (shapeDictionary[name].All(value => value > 0))
                        {
                            dataViewType = new VectorDataViewType(vectorType.ItemType, shapeDictionary[name]);
                        }
                        else
                        {
                            dataViewType = new VectorDataViewType(vectorType.ItemType);
                        }
                    }

                    info = new OnnxVariableInfo(name, shapeDictionary[name].ToList(), meta.ElementType, dataViewType, caster);
                }
                else
                {
                    // No user-specified shape is found, so the shape loaded from ONNX model file is used.
                    info = new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, caster);
                }

                onnxRuntimeOutputInfos.Add(info);
            }

            // Create a view to the used ONNX model from ONNXRuntime's perspective.
            ModelInfo = new OnnxModelInfo(onnxRuntimeInputInfos, onnxRuntimeOutputInfos);
        }
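
CheckOnnxShapeCompatibility itself is not shown in this listing, so the following is an assumption rather than the actual method body: a hedged re-implementation of the rule stated in the error message, under which only an unknown dimension may replace or be replaced by another dimension.

        // Assumed semantics: shapes must have equal rank, and a pair of dimensions
        // may differ only when at least one side is unknown (non-positive).
        static bool ShapesCompatible(IReadOnlyList<int> requested, IReadOnlyList<int> loaded)
        {
            if (requested.Count != loaded.Count)
                return false;
            for (int i = 0; i < requested.Count; i++)
                if (requested[i] > 0 && loaded[i] > 0 && requested[i] != loaded[i])
                    return false;
            return true;
        }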
Example #4
        /// <summary>
        /// Constructs OnnxModel object from file.
        /// </summary>
        /// <param name="modelFile">Model file path.</param>
        /// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
        /// <param name="fallbackToCpu">If true, resumes CPU execution quitely upon GPU error.</param>
        /// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
        /// no longer needed.</param>
        public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false, bool ownModelFile = false)
        {
            ModelFile = modelFile;
            // If we don't own the model file, _disposed should be false to prevent deleting user's file.
            _ownModelFile = ownModelFile;
            _disposed     = false;

            if (gpuDeviceId != null)
            {
                try
                {
                    _session = new InferenceSession(modelFile,
                                                    SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value));
                }
                catch (OnnxRuntimeException)
                {
                    if (fallbackToCpu)
                    {
                        _session = new InferenceSession(modelFile);
                    }
                    else
                    {
                        // If called from OnnxTransform, this exception is caught and rethrown.
                        throw;
                    }
                }
            }
            else
            {
                _session = new InferenceSession(modelFile);
            }

            // Load the ONNX model file and parse its input and output schema. This is necessary because
            // ONNXRuntime doesn't expose full type information via its C# APIs.
            var model = new OnnxCSharpToProtoWrapper.ModelProto();

            using (var modelStream = File.OpenRead(modelFile))
                model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(modelStream);

            // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
            var inputTypePool = new Dictionary<string, DataViewType>();

            foreach (var valueInfo in model.Graph.Input)
            {
                inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
            }
            var outputTypePool = new Dictionary<string, DataViewType>();

            // Build casters that map NamedOnnxValue to .NET objects.
            var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();

            foreach (var valueInfo in model.Graph.Output)
            {
                outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                casterPool[valueInfo.Name]     = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
            }

            var onnxRuntimeInputInfos = new List<OnnxVariableInfo>();

            foreach (var pair in _session.InputMetadata)
            {
                var name         = pair.Key;
                var meta         = pair.Value;
                var dataViewType = inputTypePool[name];
                var info         = new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, null);
                onnxRuntimeInputInfos.Add(info);
            }

            var onnxRuntimeOutputInfos = new List<OnnxVariableInfo>();

            foreach (var pair in _session.OutputMetadata)
            {
                var name         = pair.Key;
                var meta         = pair.Value;
                var dataViewType = outputTypePool[name];
                var caster       = casterPool[name];
                var info         = new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, caster);
                onnxRuntimeOutputInfos.Add(info);
            }

            ModelInfo = new OnnxModelInfo(onnxRuntimeInputInfos, onnxRuntimeOutputInfos);
        }
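
A minimal usage sketch for this older overload (the path is a placeholder, and disposal is assumed to follow the IDisposable pattern implied by the _disposed field): the model loads on CPU, and ownModelFile: true asks the object to delete the file once it is no longer needed.

        // Loads on CPU; the file is removed on dispose because ownModelFile is true.
        using (var model = new OnnxModel("/tmp/model.onnx", ownModelFile: true))
        {
            // ModelInfo now describes the graph's inputs and outputs.
        }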