/// <summary>
/// Creates an <see cref="OnnxTransformer"/> from <paramref name="args"/>, loading the ONNX model either from
/// <c>args.ModelFile</c> or, when supplied, from the in-memory <paramref name="modelBytes"/>.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="args">
/// Transformer arguments. Every entry of <c>InputColumns</c> / <c>OutputColumns</c> must be non-whitespace.
/// When either list is empty, the model's own input / output node names are used instead.
/// </param>
/// <param name="modelBytes">
/// Optional serialized ONNX model. When null, the model is read from <c>args.ModelFile</c>,
/// which must be non-whitespace and name an existing file.
/// </param>
/// <remarks>
/// <c>OnnxRuntimeException</c>s raised while loading the model are rethrown as host exceptions.
/// A host exception is also thrown when a requested output column is not an output node of the model.
/// </remarks>
private OnnxTransformer(IHostEnvironment env, Arguments args, byte[] modelBytes = null)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransformer)))
{
    Host.CheckValue(args, nameof(args));
    // Validate the column arrays explicitly so a missing list is reported as a bad argument
    // instead of surfacing as a NullReferenceException from the loops below.
    Host.CheckValue(args.InputColumns, nameof(args.InputColumns));
    Host.CheckValue(args.OutputColumns, nameof(args.OutputColumns));

    foreach (var col in args.InputColumns)
    {
        Host.CheckNonWhiteSpace(col, nameof(args.InputColumns));
    }
    foreach (var col in args.OutputColumns)
    {
        Host.CheckNonWhiteSpace(col, nameof(args.OutputColumns));
    }

    try
    {
        if (modelBytes == null)
        {
            Host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile));
            Host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile));
            Model = new OnnxModel(args.ModelFile, args.GpuDeviceId, args.FallbackToCpu);
        }
        else
        {
            Model = OnnxModel.CreateFromBytes(modelBytes, args.GpuDeviceId, args.FallbackToCpu);
        }
    }
    catch (OnnxRuntimeException e)
    {
        throw Host.Except(e, $"Error initializing model :{e.ToString()}");
    }

    // An empty column list means "use every node the model declares".
    Inputs = args.InputColumns.Length == 0 ? Model.InputNames.ToArray() : args.InputColumns;
    Outputs = args.OutputColumns.Length == 0 ? Model.OutputNames.ToArray() : args.OutputColumns;

    OutputTypes = new ColumnType[Outputs.Length];
    for (int i = 0; i < Outputs.Length; i++)
    {
        // Each requested output column must name one of the model's output nodes.
        var idx = Model.OutputNames.IndexOf(Outputs[i]);
        if (idx < 0)
        {
            throw Host.Except($"Column {Outputs[i]} doesn't match output node names of model");
        }

        var outputNodeInfo = Model.ModelInfo.OutputsInfo[idx];
        // NOTE(review): AdjustDimensions is defined elsewhere; presumably it maps the ONNX node shape
        // (e.g. dropping a dynamic batch axis) to ML.NET vector dimensions — confirm.
        var dims = AdjustDimensions(outputNodeInfo.Shape);
        OutputTypes[i] = new VectorType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims.ToArray());
    }

    _args = args;
}
/// <summary>
/// Creates an <see cref="OnnxTransform"/> from <paramref name="args"/>, loading the ONNX model either from
/// <c>args.ModelFile</c> or, when supplied, from the in-memory <paramref name="modelBytes"/>.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="args">
/// Transformer arguments. Every entry of <c>InputColumns</c> / <c>OutputColumns</c> must be non-whitespace,
/// and every output column must name an output node of the model.
/// </param>
/// <param name="modelBytes">
/// Optional serialized ONNX model. When null, the model is read from <c>args.ModelFile</c>,
/// which must be non-whitespace and name an existing file.
/// </param>
private OnnxTransform(IHostEnvironment env, Arguments args, byte[] modelBytes = null)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransform)))
{
    Host.CheckValue(args, nameof(args));
    // Validate the column arrays explicitly so a missing list is reported as a bad argument
    // instead of surfacing as a NullReferenceException from the loops below.
    Host.CheckValue(args.InputColumns, nameof(args.InputColumns));
    Host.CheckValue(args.OutputColumns, nameof(args.OutputColumns));

    foreach (var col in args.InputColumns)
    {
        Host.CheckNonWhiteSpace(col, nameof(args.InputColumns));
    }
    foreach (var col in args.OutputColumns)
    {
        Host.CheckNonWhiteSpace(col, nameof(args.OutputColumns));
    }

    if (modelBytes == null)
    {
        Host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile));
        Host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile));
        Model = new OnnxModel(args.ModelFile);
    }
    else
    {
        Model = OnnxModel.CreateFromBytes(modelBytes);
    }

    Inputs = args.InputColumns;
    Outputs = args.OutputColumns;

    OutputTypes = new ColumnType[Outputs.Length];
    for (int i = 0; i < Outputs.Length; i++)
    {
        // Each requested output column must name one of the model's output nodes.
        var idx = Model.OutputNames.IndexOf(Outputs[i]);
        if (idx < 0)
        {
            throw Host.Except($"Column {Outputs[i]} doesn't match output node names of model");
        }

        var outputNodeInfo = Model.ModelInfo.OutputsInfo[idx];
        // NOTE(review): AdjustDimensions is defined elsewhere; presumably it maps the ONNX node shape
        // (e.g. dropping a dynamic batch axis) to ML.NET vector dimensions — confirm.
        var dims = AdjustDimensions(outputNodeInfo.Shape);
        OutputTypes[i] = new VectorType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims);
    }

    _args = args;
}
/// <summary>
/// Creates an <see cref="OnnxTransform"/> mapping a single input column to a single output column through an
/// ONNX model loaded from <c>args.ModelFile</c> or, when supplied, from the in-memory <paramref name="modelBytes"/>.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="args">
/// Transform arguments; <c>InputColumn</c> and <c>OutputColumn</c> must be non-whitespace.
/// </param>
/// <param name="modelBytes">
/// Optional serialized ONNX model. When null, the model is read from <c>args.ModelFile</c>,
/// which must be non-whitespace and name an existing file.
/// </param>
/// <remarks>
/// A host exception is thrown when the model does not have exactly one input and exactly one output.
/// </remarks>
private OnnxTransform(IHostEnvironment env, Arguments args, byte[] modelBytes = null)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(RegistrationName);
    _host.CheckValue(args, nameof(args));
    _host.CheckNonWhiteSpace(args.InputColumn, nameof(args.InputColumn));
    _host.CheckNonWhiteSpace(args.OutputColumn, nameof(args.OutputColumn));

    if (modelBytes == null)
    {
        _host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile));
        _host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile));
        Model = new OnnxModel(args.ModelFile);
    }
    else
    {
        Model = OnnxModel.CreateFromBytes(modelBytes);
    }

    var modelInfo = Model.ModelInfo;
    // This transform only handles single-input / single-output models.
    // Errors are raised through the registered host, like every other check in this constructor.
    if (modelInfo.InputsInfo.Length != 1)
    {
        throw _host.Except($"OnnxTransform supports Onnx models with one input. The provided model has {modelInfo.InputsInfo.Length} input(s).");
    }
    if (modelInfo.OutputsInfo.Length != 1)
    {
        throw _host.Except($"OnnxTransform supports Onnx models with one output. The provided model has {modelInfo.OutputsInfo.Length} output(s).");
    }

    Input = args.InputColumn;
    Output = args.OutputColumn;

    var outputNodeInfo = modelInfo.OutputsInfo[0];
    var type = OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type);
    var shape = outputNodeInfo.Shape;
    // An empty shape maps to { 0 }. Otherwise a negative leading dimension (presumably a dynamic
    // batch axis) is dropped and the remaining dimensions are used as-is.
    // NOTE(review): a shape that is only a negative leading dimension yields an empty dims array, and
    // negative inner dimensions are passed through unchanged — confirm VectorType accepts those.
    var dims = shape.Count > 0
        ? shape.Skip(shape[0] < 0 ? 1 : 0).Select(x => (int)x).ToArray()
        : new[] { 0 };
    OutputType = new VectorType(type, dims);
    _args = args;
}