/// <summary>
/// Creates an <see cref="OnnxTransformer"/> from <paramref name="options"/>. The ONNX model is loaded from
/// <c>options.ModelFile</c>, or from <paramref name="modelBytes"/> when that argument is supplied.
/// </summary>
/// <param name="env">The host environment; a child host named <c>OnnxTransformer</c> is registered on it.</param>
/// <param name="options">
/// Model file path, input/output column names, GPU device id and CPU-fallback flag. Column names must not be
/// null or whitespace; an empty column list means "use all of the model's input (or output) names".
/// </param>
/// <param name="modelBytes">
/// Optional in-memory model. When <see langword="null"/>, <c>options.ModelFile</c> must be non-whitespace and
/// name an existing file.
/// </param>
private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes = null)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransformer)))
{
    Host.CheckValue(options, nameof(options));

    foreach (var col in options.InputColumns)
    {
        Host.CheckNonWhiteSpace(col, nameof(options.InputColumns));
    }
    foreach (var col in options.OutputColumns)
    {
        Host.CheckNonWhiteSpace(col, nameof(options.OutputColumns));
    }

    try
    {
        if (modelBytes == null)
        {
            // No in-memory model: fall back to the model file named in the options.
            Host.CheckNonWhiteSpace(options.ModelFile, nameof(options.ModelFile));
            Host.CheckIO(File.Exists(options.ModelFile), "Model file {0} does not exists.", options.ModelFile);
            Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu);
        }
        else
        {
            Model = OnnxModel.CreateFromBytes(modelBytes, options.GpuDeviceId, options.FallbackToCpu);
        }
    }
    catch (OnnxRuntimeException e)
    {
        // Only OnnxRuntimeException is caught here; it is wrapped in a host exception with the full details.
        throw Host.Except(e, $"Error initializing model :{e.ToString()}");
    }

    // When the caller names no columns, expose every input/output node declared by the model.
    Inputs = options.InputColumns.Length == 0 ? Model.InputNames.ToArray() : options.InputColumns;
    Outputs = options.OutputColumns.Length == 0 ? Model.OutputNames.ToArray() : options.OutputColumns;

    // Resolve each requested output column to its model output node and derive its ML.NET vector type.
    OutputTypes = new DataViewType[Outputs.Length];
    for (int i = 0; i < Outputs.Length; i++)
    {
        var idx = Model.OutputNames.IndexOf(Outputs[i]);
        if (idx < 0)
        {
            throw Host.Except($"Column {Outputs[i]} doesn't match output node names of model");
        }

        var outputNodeInfo = Model.ModelInfo.OutputsInfo[idx];
        var dims = AdjustDimensions(outputNodeInfo.Shape);
        OutputTypes[i] = new VectorType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims.ToArray());
    }

    _options = options;
}
/// <summary>
/// Creates an <see cref="OnnxTransformer"/> from <paramref name="options"/>. The ONNX model is loaded from
/// <c>options.ModelFile</c>, or from <paramref name="modelBytes"/> when that argument is supplied.
/// </summary>
/// <param name="env">The host environment; a child host named <c>OnnxTransformer</c> is registered on it.</param>
/// <param name="options">
/// Model file path, input/output column names, GPU device id and CPU-fallback flag. Column names must not be
/// null or whitespace; an empty column list means "use all of the model's input (or output) names".
/// </param>
/// <param name="modelBytes">
/// Optional in-memory model. When <see langword="null"/>, <c>options.ModelFile</c> must be non-whitespace and
/// name an existing file.
/// </param>
private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes = null)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransformer)))
{
    Host.CheckValue(options, nameof(options));

    foreach (var col in options.InputColumns)
    {
        Host.CheckNonWhiteSpace(col, nameof(options.InputColumns));
    }
    foreach (var col in options.OutputColumns)
    {
        Host.CheckNonWhiteSpace(col, nameof(options.OutputColumns));
    }

    // ONNX Runtime is used to determine the input and output configuration. Because it does not offer
    // strongly-typed access to the produced variables, type information is read from the ONNX model itself.
    try
    {
        if (modelBytes == null)
        {
            // The model file was supplied by the user.
            Host.CheckNonWhiteSpace(options.ModelFile, nameof(options.ModelFile));
            Host.CheckIO(File.Exists(options.ModelFile), "Model file {0} does not exists.", options.ModelFile);

            // The file belongs to the user and must not be deleted, so the model does not own it.
            Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu, ownModelFile: false);
        }
        else
        {
            // The model was supplied as a byte[]. To hand it to ONNX Runtime, it is written to a temporary
            // file, which ONNX Runtime's API then loads.
            Model = OnnxModel.CreateFromBytes(modelBytes, options.GpuDeviceId, options.FallbackToCpu);
        }
    }
    catch (OnnxRuntimeException e)
    {
        // Only OnnxRuntimeException is caught here; it is wrapped in a host exception with the full details.
        throw Host.Except(e, $"Error initializing model :{e.ToString()}");
    }

    // When the caller names no columns, expose every input/output node declared by the model.
    Inputs = options.InputColumns.Length == 0 ? Model.ModelInfo.InputNames.ToArray() : options.InputColumns;
    Outputs = options.OutputColumns.Length == 0 ? Model.ModelInfo.OutputNames.ToArray() : options.OutputColumns;

    // Each output column's ML.NET type comes from the model's metadata for that output node.
    OutputTypes = new DataViewType[Outputs.Length];
    for (int i = 0; i < Outputs.Length; i++)
    {
        var outputInfo = Model.ModelInfo.GetOutput(Outputs[i]);
        OutputTypes[i] = outputInfo.DataViewType;
    }

    _options = options;
}