/// <summary>
        /// Constructs OnnxModel object from file. Creates an ONNXRuntime inference session for the model and
        /// additionally parses the model file as a protobuf to recover the input/output type information.
        /// </summary>
        /// <param name="modelFile">Model file path.</param>
        /// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU. Any non-null value is currently
        /// rejected with <see cref="NotSupportedException"/>.</param>
        /// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error. Not read by this
        /// constructor, because GPU execution is rejected up front.</param>
        /// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
        /// no longer needed.</param>
        /// <param name="shapeDictionary">Optional map from variable name to shape. It is forwarded unchanged to
        /// GetOnnxVariablesFromMetadata for inputs, outputs and overridable initializers (presumably to override the
        /// shapes declared in the model; confirm against that method).</param>
        /// <exception cref="NotSupportedException">Thrown when <paramref name="gpuDeviceId"/> is not null.</exception>
        public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
                         bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null)
        {
            ModelFile = modelFile;
            // If we don't own the model file, _disposed should be false to prevent deleting user's file.
            _ownModelFile = ownModelFile;
            _disposed = false;

            if (gpuDeviceId != null)
            {
                // The onnxruntime v1.0 currently does not support running on the GPU on all of ML.NET's supported platforms.
                // This code path will be re-enabled when there is appropriate support in onnxruntime
                throw new NotSupportedException("Running Onnx models on a GPU is temporarily not supported!");
            }

            _session = new InferenceSession(modelFile);

            try
            {
                // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
                // doesn't expose full type information via its C# APIs.
                OnnxCSharpToProtoWrapper.ModelProto model;
                using (var modelStream = File.OpenRead(modelFile))
                using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
                {
                    model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);
                }

                // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
                var inputTypePool = new Dictionary<string, DataViewType>();
                foreach (var valueInfo in model.Graph.Input)
                {
                    inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                }

                // NOTE(review): this pool is populated but not read anywhere else in this constructor. The loop is
                // kept so that any failure raised by GetScalarDataViewType still surfaces here as it did before.
                var initializerTypePool = new Dictionary<string, DataViewType>();
                foreach (var valueInfo in model.Graph.Initializer)
                {
                    initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);
                }

                var outputTypePool = new Dictionary<string, DataViewType>();
                // Build casters which maps NamedOnnxValue to .NET objects.
                var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
                foreach (var valueInfo in model.Graph.Output)
                {
                    outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                    // The resulting .NET type is not needed here, only the caster.
                    casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type _);
                }

                var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
                var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
                var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

                // Create a view to the used ONNX model from ONNXRuntime's perspective.
                ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);
            }
            catch
            {
                // The session was already created; release it before propagating the failure so that a model file
                // that cannot be read or parsed does not leak the session.
                _session.Dispose();
                throw;
            }
        }
Example #2
0
        /// <summary>
        /// Constructs OnnxModel object from file. Creates an ONNXRuntime inference session (on GPU when requested,
        /// optionally falling back to CPU), keeps a read stream on the model file open in ModelStream, and parses the
        /// model file as a protobuf to recover the input/output type information and the graph.
        /// </summary>
        /// <param name="modelFile">Model file path.</param>
        /// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
        /// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error. If false, the
        /// OnnxRuntimeException raised while creating the GPU session is rethrown.</param>
        /// <param name="ownModelFile">If true, the <paramref name="modelFile"/> is opened with
        /// <see cref="FileOptions.DeleteOnClose"/>, so it will be deleted when the model's stream is closed, i.e. when
        /// <see cref="OnnxModel"/> is no longer needed.</param>
        /// <param name="shapeDictionary">Optional map from variable name to shape. It is forwarded unchanged to
        /// GetOnnxVariablesFromMetadata for inputs, outputs and overridable initializers (presumably to override the
        /// shapes declared in the model; confirm against that method).</param>
        public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
                         bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null)
        {
            // A freshly constructed model has not been disposed yet.
            _disposed = false;

            if (gpuDeviceId != null)
            {
                try
                {
                    _session = new InferenceSession(modelFile,
                                                    SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value));
                }
                catch (OnnxRuntimeException)
                {
                    if (fallbackToCpu)
                    {
                        _session = new InferenceSession(modelFile);
                    }
                    else
                    {
                        // If called from OnnxTransform, is caught and rethrown
                        throw;
                    }
                }
            }
            else
            {
                _session = new InferenceSession(modelFile);
            }

            try
            {
                // If we own the model file set the DeleteOnClose flag so it is always deleted.
                if (ownModelFile)
                {
                    ModelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read, FileShare.Read, 4096, FileOptions.DeleteOnClose);
                }
                else
                {
                    ModelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read);
                }

                // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
                // doesn't expose full type information via its C# APIs.
                // The CodedInputStream auto closes the stream, and we need to make sure that our main stream stays open, so creating a new one here.
                // FileShare.Delete is requested so this second handle can coexist with the DeleteOnClose handle above.
                OnnxCSharpToProtoWrapper.ModelProto model;
                using (var modelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read, FileShare.Delete | FileShare.Read))
                using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 100))
                {
                    model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);
                }

                // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
                var inputTypePool = new Dictionary<string, DataViewType>();
                foreach (var valueInfo in model.Graph.Input)
                {
                    inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                }

                // NOTE(review): this pool is populated but not read anywhere else in this constructor. The loop is
                // kept so that any failure raised by GetScalarDataViewType still surfaces here as it did before.
                var initializerTypePool = new Dictionary<string, DataViewType>();
                foreach (var valueInfo in model.Graph.Initializer)
                {
                    initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);
                }

                var outputTypePool = new Dictionary<string, DataViewType>();
                // Build casters which maps NamedOnnxValue to .NET objects.
                var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
                foreach (var valueInfo in model.Graph.Output)
                {
                    outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                    // The resulting .NET type is not needed here, only the caster.
                    casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type _);
                }

                var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
                var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
                var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

                // Create a view to the used ONNX model from ONNXRuntime's perspective.
                ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);

                Graph = model.Graph;
            }
            catch
            {
                // Construction failed: release everything acquired so far before propagating the error.
                // ModelStream may still be null if opening the file itself failed. Closing it here releases the file
                // handle deterministically (and, for an owned file, triggers the DeleteOnClose deletion right away
                // instead of whenever the stream would have been finalized).
                ModelStream?.Dispose();
                _session.Dispose();
                _session = null;
                throw;
            }
        }