/// <summary>
/// Constructs OnnxModel object from file.
/// </summary>
/// <param name="modelFile">Model file path.</param>
/// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
/// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error.</param>
/// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
/// no longer needed.</param>
/// <param name="shapeDictionary">Mapping from ONNX variable names to shapes used to override the shapes loaded from the model file.</param>
public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
    bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null)
{
    ModelFile = modelFile;
    // If we don't own the model file, _disposed should be false to prevent deleting user's file.
    _ownModelFile = ownModelFile;
    _disposed = false;

    if (gpuDeviceId != null)
    {
        // The onnxruntime v1.0 currently does not support running on the GPU on all of ML.NET's supported platforms.
        // This code path will be re-enabled when there is appropriate support in onnxruntime.
        // NOTE(review): fallbackToCpu is intentionally not honored here while GPU execution is unavailable.
        throw new NotSupportedException("Running Onnx models on a GPU is temporarily not supported!");
    }
    else
    {
        _session = new InferenceSession(modelFile);
    }

    // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
    // doesn't expose full type information via its C# APIs.
    var model = new OnnxCSharpToProtoWrapper.ModelProto();
    using (var modelStream = File.OpenRead(modelFile))
    using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
        model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);

    // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
    var inputTypePool = new Dictionary<string, DataViewType>();
    foreach (var valueInfo in model.Graph.Input)
        inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);

    var initializerTypePool = new Dictionary<string, DataViewType>();
    foreach (var valueInfo in model.Graph.Initializer)
        initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);

    var outputTypePool = new Dictionary<string, DataViewType>();
    // Build casters which maps NamedOnnxValue to .NET objects.
    var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
    foreach (var valueInfo in model.Graph.Output)
    {
        outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
        casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
    }

    var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
    var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
    var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

    // Create a view to the used ONNX model from ONNXRuntime's perspective.
    ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);
}
/// <summary>
/// Constructs OnnxModel object from file.
/// </summary>
/// <param name="modelFile">Model file path.</param>
/// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
/// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error.</param>
/// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
/// no longer needed.</param>
/// <param name="shapeDictionary">Mapping from ONNX variable names to shapes used to override the shapes loaded from the model file.</param>
public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
    bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null)
{
    // If we don't own the model file, _disposed should be false to prevent deleting user's file.
    _disposed = false;

    if (gpuDeviceId != null)
    {
        try
        {
            _session = new InferenceSession(modelFile,
                SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value));
        }
        catch (OnnxRuntimeException)
        {
            if (fallbackToCpu)
            {
                _session = new InferenceSession(modelFile);
            }
            else
            {
                // If called from OnnxTransform, is caught and rethrown.
                throw;
            }
        }
    }
    else
    {
        _session = new InferenceSession(modelFile);
    }

    try
    {
        // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
        // doesn't expose full type information via its C# APIs.
        var model = new OnnxCSharpToProtoWrapper.ModelProto();

        // If we own the model file set the DeleteOnClose flag so it is always deleted.
        if (ownModelFile)
            ModelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read, FileShare.Read, 4096, FileOptions.DeleteOnClose);
        else
            ModelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read);

        // The CodedInputStream auto closes the stream, and we need to make sure that our main stream stays open, so creating a new one here.
        using (var modelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read, FileShare.Delete | FileShare.Read))
        using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 100))
            model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);

        // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
        var inputTypePool = new Dictionary<string, DataViewType>();
        foreach (var valueInfo in model.Graph.Input)
            inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);

        var initializerTypePool = new Dictionary<string, DataViewType>();
        foreach (var valueInfo in model.Graph.Initializer)
            initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);

        var outputTypePool = new Dictionary<string, DataViewType>();
        // Build casters which maps NamedOnnxValue to .NET objects.
        var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
        foreach (var valueInfo in model.Graph.Output)
        {
            outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
            casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
        }

        var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
        var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
        var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

        // Create a view to the used ONNX model from ONNXRuntime's perspective.
        ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);
        Graph = model.Graph;
    }
    catch
    {
        _session.Dispose();
        _session = null;
        // Also release the model stream we may have opened above; when ownModelFile is true this closes
        // the DeleteOnClose handle so the owned model file is actually deleted instead of leaking.
        ModelStream?.Dispose();
        ModelStream = null;
        throw;
    }
}