/// <summary>
/// Builds one <see cref="OnnxVariableInfo"/> per entry of <paramref name="nodeMetadata"/>, skipping ML.NET placeholder
/// nodes whose names start with "mlnet." and end with ".unusedInput" or ".unusedOutput".
/// </summary>
/// <param name="nodeMetadata">Node metadata keyed by node name.</param>
/// <param name="shapeDictionary">Optional user-specified shapes keyed by node name (may be null). When a node has an
/// entry here, that shape replaces the one in <paramref name="nodeMetadata"/> after a compatibility check, and a vector
/// <see cref="DataViewType"/> is rebuilt from it.</param>
/// <param name="typePool">Data view types keyed by node name; looked up for every non-skipped node.</param>
/// <param name="casterPool">Optional casters keyed by node name. When null, every produced info gets a null caster;
/// otherwise it is looked up for every non-skipped node.</param>
/// <returns>The variable infos, in the enumeration order of <paramref name="nodeMetadata"/>.</returns>
private List<OnnxVariableInfo> GetOnnxVariablesFromMetadata(IReadOnlyDictionary<string, NodeMetadata> nodeMetadata,
    IDictionary<string, int[]> shapeDictionary,
    Dictionary<string, DataViewType> typePool,
    Dictionary<string, Func<NamedOnnxValue, object>> casterPool)
{
    var onnxVariableInfos = new List<OnnxVariableInfo>();

    foreach (var pair in nodeMetadata)
    {
        var name = pair.Key;
        var meta = pair.Value;

        // Skip placeholder nodes BEFORE touching the pools, so such nodes do not need typePool/casterPool entries.
        // Ordinal comparison: these are identifiers, not linguistic text.
        if (name.StartsWith("mlnet.", StringComparison.Ordinal) &&
            (name.EndsWith(".unusedInput", StringComparison.Ordinal) ||
             name.EndsWith(".unusedOutput", StringComparison.Ordinal)))
        {
            continue;
        }

        var dataViewType = typePool[name];
        var caster = casterPool?[name];

        OnnxVariableInfo info;
        if (shapeDictionary != null && shapeDictionary.TryGetValue(name, out var userShape))
        {
            // A user-provided shape overrides the one loaded from the model file, but only if the two are compatible.
            if (!CheckOnnxShapeCompatibility(userShape.ToList(), meta.Dimensions.ToList()))
            {
                throw Contracts.ExceptParamValue(userShape, nameof(shapeDictionary),
                    "The specified shape " + string.Join(",", userShape) +
                    " is not compatible with the shape " + string.Join(",", meta.Dimensions) +
                    " loaded from the ONNX model file. Only unknown dimension can replace or " +
                    "be replaced by another dimension.");
            }

            if (dataViewType is VectorDataViewType vectorType)
            {
                // All dimensions positive => fixed-size vector type; otherwise keep the vector type variable-sized.
                dataViewType = userShape.All(value => value > 0)
                    ? new VectorDataViewType(vectorType.ItemType, userShape)
                    : new VectorDataViewType(vectorType.ItemType);
            }

            info = new OnnxVariableInfo(name, userShape.ToList(), meta.ElementType, dataViewType, caster);
        }
        else
        {
            // No user-specified shape is found, so the shape loaded from ONNX model file is used.
            info = new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, caster);
        }

        onnxVariableInfos.Add(info);
    }

    return onnxVariableInfos;
}
/// <summary>
/// Constructs OnnxModel object from file.
/// </summary>
/// <param name="modelFile">Model file path.</param>
/// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU. GPU execution is currently disabled in this
/// constructor: a non-null value throws <see cref="NotSupportedException"/>.</param>
/// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error. Not read by this constructor
/// while the GPU path is disabled.</param>
/// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
/// no longer needed.</param>
/// <param name="shapeDictionary">Optional user-specified tensor shapes keyed by input/output name (may be null). A shape
/// listed here overwrites the one loaded from the model file; it must be compatible with it, i.e. only unknown dimensions
/// may replace or be replaced by another dimension.</param>
public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
    bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null)
{
    ModelFile = modelFile;
    // If we don't own the model file, _disposed should be false to prevent deleting user's file.
    _ownModelFile = ownModelFile;
    _disposed = false;

    if (gpuDeviceId != null)
    {
        // The onnxruntime v1.0 currently does not support running on the GPU on all of ML.NET's supported platforms.
        // This code path will be re-enabled when there is appropriate support in onnxruntime
        throw new NotSupportedException("Running Onnx models on a GPU is temporarily not supported!");
    }

    _session = new InferenceSession(modelFile);

    try
    {
        // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
        // doesn't expose full type information via its C# APIs.
        OnnxCSharpToProtoWrapper.ModelProto model;
        using (var modelStream = File.OpenRead(modelFile))
        using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
        {
            model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);
        }

        // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
        var inputTypePool = new Dictionary<string, DataViewType>();
        foreach (var valueInfo in model.Graph.Input)
        {
            inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
        }

        var outputTypePool = new Dictionary<string, DataViewType>();
        // Build casters which maps NamedOnnxValue to .NET objects.
        var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
        foreach (var valueInfo in model.Graph.Output)
        {
            outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
            // Only the caster is needed here; the resulting .NET type is discarded.
            casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out _);
        }

        // Builds the info for one tensor. Shared by the input and output loops below.
        OnnxVariableInfo CreateVariableInfo(string name, NodeMetadata meta, DataViewType dataViewType,
            Func<NamedOnnxValue, object> caster)
        {
            if (shapeDictionary == null || !shapeDictionary.TryGetValue(name, out var userShape))
            {
                // No user-specified shape is found, so the shape loaded from ONNX model file is used.
                return new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, caster);
            }

            // If user provides a shape of a specific tensor, the provided shape overwrites the corresponding one loaded
            // from ONNX model file and the deduced DataViewVectorType.
            if (!CheckOnnxShapeCompatibility(userShape.ToList(), meta.Dimensions.ToList()))
            {
                throw Contracts.ExceptParamValue(userShape, nameof(shapeDictionary),
                    "The specified shape " + string.Join(",", userShape) +
                    " is not compatible with the shape " + string.Join(",", meta.Dimensions) +
                    " loaded from the ONNX model file. Only unknown dimension can replace or " +
                    "be replaced by another dimension.");
            }

            if (dataViewType is VectorDataViewType vectorType)
            {
                // All dimensions positive => fixed-size vector type; otherwise keep the vector type variable-sized.
                dataViewType = userShape.All(value => value > 0)
                    ? new VectorDataViewType(vectorType.ItemType, userShape)
                    : new VectorDataViewType(vectorType.ItemType);
            }

            return new OnnxVariableInfo(name, userShape.ToList(), meta.ElementType, dataViewType, caster);
        }

        // Collect input information for this ONNX model from ONNXRuntime's perspective. Inputs get no caster.
        var onnxRuntimeInputInfos = new List<OnnxVariableInfo>();
        foreach (var pair in _session.InputMetadata)
        {
            onnxRuntimeInputInfos.Add(CreateVariableInfo(pair.Key, pair.Value, inputTypePool[pair.Key], null));
        }

        // Collect output information for this ONNX model from ONNXRuntime's perspective.
        var onnxRuntimeOutputInfos = new List<OnnxVariableInfo>();
        foreach (var pair in _session.OutputMetadata)
        {
            onnxRuntimeOutputInfos.Add(
                CreateVariableInfo(pair.Key, pair.Value, outputTypePool[pair.Key], casterPool[pair.Key]));
        }

        // Create a view to the used ONNX model from ONNXRuntime's perspective.
        ModelInfo = new OnnxModelInfo(onnxRuntimeInputInfos, onnxRuntimeOutputInfos);
    }
    catch
    {
        // The constructor is failing, so no caller will ever hold this object to dispose it. Release the
        // ONNXRuntime session here instead of leaking it, then propagate the original exception.
        _session.Dispose();
        throw;
    }
}
/// <summary>
/// Constructs OnnxModel object from file.
/// </summary>
/// <param name="modelFile">Model file path.</param>
/// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
/// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error.</param>
/// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
/// no longer needed.</param>
public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false, bool ownModelFile = false)
{
    ModelFile = modelFile;
    // If we don't own the model file, _disposed should be false to prevent deleting user's file.
    _ownModelFile = ownModelFile;
    _disposed = false;

    if (gpuDeviceId != null)
    {
        try
        {
            _session = new InferenceSession(modelFile,
                SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value));
        }
        catch (OnnxRuntimeException) when (fallbackToCpu)
        {
            // GPU session creation failed and the caller asked for a quiet CPU fallback. When fallbackToCpu is false
            // the exception propagates unchanged (the original note says OnnxTransform catches and rethrows it).
            _session = new InferenceSession(modelFile);
        }
    }
    else
    {
        _session = new InferenceSession(modelFile);
    }

    try
    {
        // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
        // doesn't expose full type information via its C# APIs.
        OnnxCSharpToProtoWrapper.ModelProto model;
        using (var modelStream = File.OpenRead(modelFile))
        {
            model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(modelStream);
        }

        // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
        var inputTypePool = new Dictionary<string, DataViewType>();
        foreach (var valueInfo in model.Graph.Input)
        {
            inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
        }

        var outputTypePool = new Dictionary<string, DataViewType>();
        // Build casters which maps NamedOnnxValue to .NET objects.
        var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
        foreach (var valueInfo in model.Graph.Output)
        {
            outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
            // Only the caster is needed here; the resulting .NET type is discarded.
            casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out _);
        }

        // Inputs: shapes come straight from ONNXRuntime's metadata and no caster is attached.
        var onnxRuntimeInputInfos = new List<OnnxVariableInfo>();
        foreach (var pair in _session.InputMetadata)
        {
            onnxRuntimeInputInfos.Add(new OnnxVariableInfo(pair.Key, pair.Value.Dimensions.ToList(),
                pair.Value.ElementType, inputTypePool[pair.Key], null));
        }

        // Outputs: same as inputs, plus the caster parsed from the model file.
        var onnxRuntimeOutputInfos = new List<OnnxVariableInfo>();
        foreach (var pair in _session.OutputMetadata)
        {
            onnxRuntimeOutputInfos.Add(new OnnxVariableInfo(pair.Key, pair.Value.Dimensions.ToList(),
                pair.Value.ElementType, outputTypePool[pair.Key], casterPool[pair.Key]));
        }

        ModelInfo = new OnnxModelInfo(onnxRuntimeInputInfos, onnxRuntimeOutputInfos);
    }
    catch
    {
        // The constructor is failing, so no caller will ever hold this object to dispose it. Release the
        // ONNXRuntime session here instead of leaking it, then propagate the original exception.
        _session.Dispose();
        throw;
    }
}