/// <summary>
/// Creates an ONNX export context with empty node, variable and initializer collections
/// and records the model metadata supplied by the caller.
/// </summary>
/// <param name="env">Host environment; must not be null. A host is registered on it under the name of <c>OnnxContext</c>.</param>
/// <param name="name">Model name; must not be null.</param>
/// <param name="producerName">Producer name stored in the context.</param>
/// <param name="producerVersion">Producer version stored in the context.</param>
/// <param name="modelVersion">Model version stored in the context.</param>
/// <param name="domain">Model domain; must not be null.</param>
/// <param name="onnxVersion">ONNX version setting stored in the context.</param>
/// <param name="opSetVersion">
/// Requested OpSet version. It must lie between <c>MinimumOpSetVersion</c> and
/// <c>CurrentOpSetVersion</c> (both inclusive); otherwise the exception built by
/// <c>_host.ExceptParam</c> is thrown.
/// </param>
public OnnxContextImpl(IHostEnvironment env, string name, string producerName, string producerVersion,
    long modelVersion, string domain, OnnxVersion onnxVersion, int opSetVersion = CurrentOpSetVersion)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(nameof(OnnxContext));
    _host.CheckValue(name, nameof(name));
    // Check domain itself (not name a second time) so that a null domain is
    // rejected and reported under its own parameter name.
    _host.CheckValue(domain, nameof(domain));

    // Range-check the OpSet version: the upper bound is tested first, then the lower bound.
    if (opSetVersion > CurrentOpSetVersion)
    {
        throw _host.ExceptParam(nameof(opSetVersion), $"Requested OpSet version {opSetVersion} is higher than the current most updated OpSet version {CurrentOpSetVersion}");
    }
    if (opSetVersion < MinimumOpSetVersion)
    {
        throw _host.ExceptParam(nameof(opSetVersion), $"Requested OpSet version {opSetVersion} is lower than the minimum required OpSet version {MinimumOpSetVersion}");
    }

    // Graph content starts out empty.
    _nodes = new List<OnnxCSharpToProtoWrapper.NodeProto>();
    _intermediateValues = new List<OnnxUtils.ModelArgs>();
    _inputs = new List<OnnxUtils.ModelArgs>();
    _initializers = new List<OnnxCSharpToProtoWrapper.TensorProto>();
    _outputs = new List<OnnxUtils.ModelArgs>();

    // Name bookkeeping.
    _columnNameMap = new Dictionary<string, string>();
    _variableNames = new HashSet<string>();
    _nodeNames = new HashSet<string>();

    // Model metadata.
    _name = name;
    _producerName = producerName;
    _producerVersion = producerVersion;
    _modelVersion = modelVersion;
    _domain = domain;
    _onnxVersion = onnxVersion;
    _opSetVersion = opSetVersion;
}
/// <summary>
/// Exports this pipeline to ONNX by delegating to <c>Convert2Onnx.ToOnnx</c>.
/// </summary>
/// <param name="firstTransform">
/// Index into the transform list; the <c>transform</c> member found there is handed to the
/// converter as the single starting view. Only read when the transform list is not null.
/// </param>
/// <param name="inputs">
/// Input names, forwarded by reference to the converter. This parameter is received by value,
/// so an array the converter substitutes is not visible to the caller of this method.
/// </param>
/// <param name="outputs">
/// Output names, forwarded by reference to the converter, with the same by-value caveat as
/// <paramref name="inputs"/>.
/// </param>
/// <param name="name">Model name forwarded to the converter.</param>
/// <param name="producer">Producer name forwarded to the converter.</param>
/// <param name="version">Model version forwarded to the converter.</param>
/// <param name="domain">Model domain forwarded to the converter.</param>
/// <param name="onnxVersion">ONNX version setting forwarded to the converter.</param>
/// <returns>The context returned by the converter.</returns>
public ScikitOnnxContext ToOnnx(int firstTransform = 0, string[] inputs = null, string[] outputs = null, string name = null, string producer = "Scikit.ML", long version = 0, string domain = "onnx.ai.ml", OnnxVersion onnxVersion = OnnxVersion.Stable)
{
    // Starting view: the selected transform when a transform list exists, otherwise null.
    var startView = _transforms != null ? _transforms[firstTransform].transform : null;

    // End of the pipeline: the scorer when a predictor is present, else the last transform.
    var pipelineEnd = _predictor != null ? GetScorer() : _transforms.Last().transform;

    var startViews = new IDataView[] { startView };
    return Convert2Onnx.ToOnnx(pipelineEnd, ref inputs, ref outputs, name, producer, version,
                               domain, onnxVersion, startViews, _env);
}
/// <summary>
/// Creates an ONNX export context with empty node, variable and initializer collections
/// and records the model metadata supplied by the caller.
/// </summary>
/// <param name="env">Host environment; must not be null. A host is registered on it under the name of <c>OnnxContext</c>.</param>
/// <param name="name">Model name; must not be null.</param>
/// <param name="producerName">Producer name stored in the context.</param>
/// <param name="producerVersion">Producer version stored in the context.</param>
/// <param name="modelVersion">Model version stored in the context.</param>
/// <param name="domain">Model domain; must not be null.</param>
/// <param name="onnxVersion">ONNX version setting stored in the context.</param>
public OnnxContextImpl(IHostEnvironment env, string name, string producerName, string producerVersion,
    long modelVersion, string domain, OnnxVersion onnxVersion)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(nameof(OnnxContext));
    _host.CheckValue(name, nameof(name));
    // Check domain itself (not name a second time) so that a null domain is
    // rejected and reported under its own parameter name.
    _host.CheckValue(domain, nameof(domain));

    // Graph content starts out empty.
    _nodes = new List<OnnxCSharpToProtoWrapper.NodeProto>();
    _intermediateValues = new List<OnnxUtils.ModelArgs>();
    _inputs = new List<OnnxUtils.ModelArgs>();
    _initializers = new List<OnnxCSharpToProtoWrapper.TensorProto>();
    _outputs = new List<OnnxUtils.ModelArgs>();

    // Name bookkeeping.
    _columnNameMap = new Dictionary<string, string>();
    _variableNames = new HashSet<string>();
    _nodeNames = new HashSet<string>();

    // Model metadata.
    _name = name;
    _producerName = producerName;
    _producerVersion = producerVersion;
    _modelVersion = modelVersion;
    _domain = domain;
    _onnxVersion = onnxVersion;
}
/// <summary>
/// Exports the pipeline ending at <paramref name="trans"/> into a new <see cref="ScikitOnnxContext"/>.
/// Every transform enumerated between <paramref name="begin"/> and <paramref name="trans"/> must
/// support ONNX export; otherwise a not-supported exception is raised through the host.
/// </summary>
/// <param name="trans">Last transform of the pipeline to export.</param>
/// <param name="inputs">
/// Names of the graph inputs. Passed by reference to <c>GuessInputs</c>, which may replace the array.
/// Must be non-empty afterwards.
/// </param>
/// <param name="outputs">
/// Names of the graph outputs. Passed by reference to <c>GuessOutputs</c>, which may replace the array.
/// Must be non-empty afterwards.
/// </param>
/// <param name="name">Model name; defaults to the type name of <paramref name="trans"/> when null.</param>
/// <param name="producer">Producer name given to the context.</param>
/// <param name="version">Model version given to the context.</param>
/// <param name="domain">Model domain given to the context.</param>
/// <param name="onnxVersion">ONNX version setting given to the context.</param>
/// <param name="begin">Views marking where the exported pipeline starts; forwarded to the enumeration helpers.</param>
/// <param name="host">Host environment; when null a temporary <c>DelegateEnvironment</c> is used for the call.</param>
/// <returns>The populated ONNX context.</returns>
public static ScikitOnnxContext ToOnnx(IDataTransform trans, ref string[] inputs, ref string[] outputs, string name = null, string producer = "Scikit.ML", long version = 0, string domain = "onnx.ai.ml", OnnxVersion onnxVersion = OnnxVersion.Stable, IDataView[] begin = null, IHostEnvironment host = null)
{
    if (host == null)
    {
        // No host supplied: run the export again with a temporary environment.
        // NOTE(review): the returned context was constructed with an environment that is
        // disposed when this using block exits - confirm the context never needs it afterwards.
        using (var env = new DelegateEnvironment())
        {
            return ToOnnx(trans, ref inputs, ref outputs, name, producer, version, domain, onnxVersion, begin, env);
        }
    }

    if (name == null)
    {
        name = trans.GetType().Name;
    }

    GuessOutputs(trans, ref outputs);
    GuessInputs(trans, ref inputs, begin);
    if (inputs == null || inputs.Length == 0)
    {
        throw host.Except("Inputs cannot be empty.");
    }
    if (outputs == null || outputs.Length == 0)
    {
        throw host.Except("Outputs cannot be empty.");
    }

    // The file version of the executing assembly is recorded as the producer version.
    var assembly = System.Reflection.Assembly.GetExecutingAssembly();
    var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
    var ctx = new ScikitOnnxContext(host, name, producer, versionInfo.FileVersion, version, domain, onnxVersion);

    var uniqueVars = EnumerateVariables(trans, begin).ToArray();

    // Declare the graph inputs with the types of the variables tagged as "first".
    var inputTypes = new Dictionary<string, ColumnType>();
    foreach (var it in uniqueVars.Where(c => c.position == TagHelper.GraphPositionEnum.first))
    {
        inputTypes[it.variableName] = it.variableType;
    }
    foreach (var col in inputs)
    {
        ctx.AddInputVariable(inputTypes[col], col);
    }

    // Save every transform of the pipeline, walking the enumerated views in reverse order.
    var views = TagHelper.EnumerateAllViews(trans, begin);
    var transforms = views.Where(c => c.Item1 is IDataTransform)
                          .Select(c => c.Item1)
                          .ToArray();
    foreach (var tr in transforms.Reverse())
    {
        var tron = tr as ICanSaveOnnx;
        if (tron == null)
        {
            throw host.ExceptNotSupp($"Transform {tr.GetType()} cannot be saved in Onnx format.");
        }
        if (!tron.CanSaveOnnx(ctx))
        {
            throw host.ExceptNotSupp($"Transform {tr.GetType()} cannot be saved in ONNX format.");
        }

        // The cast yields null when the transform does not implement ISaveAsOnnx; test for
        // that explicitly so the caller gets the descriptive error instead of a
        // NullReferenceException.
        var tron2 = tron as ISaveAsOnnx;
        if (tron2 == null || !tron2.CanSaveOnnx(ctx))
        {
            throw host.ExceptNotSupp($"Transform {tr.GetType()} does not implement SaveAsOnnx.");
        }
        tron2.SaveAsOnnx(ctx);
    }

    // Declare the graph outputs with the types of the variables tagged as "last".
    var outputTypes = new Dictionary<string, ColumnType>();
    foreach (var it in uniqueVars.Where(c => c.position == TagHelper.GraphPositionEnum.last))
    {
        outputTypes[it.variableName] = it.variableType;
    }
    foreach (var col in outputs)
    {
        // Connect the variable currently registered for the column to a new variable
        // through an Identity node, and declare that new variable as the graph output.
        var variableName = ctx.TryGetVariableName(col);
        var trueVariableName = ctx.AddIntermediateVariable(null, col, true);
        ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
        ctx.AddOutputVariable(outputTypes[col], trueVariableName);
    }

    return ctx;
}