示例#1
0
        /// <summary>
        /// Initializes an ONNX export context: registers a host on <paramref name="env"/>,
        /// validates the arguments, stores the model metadata and creates the empty
        /// collections used while the graph is being built.
        /// </summary>
        /// <param name="env">Host environment; validated with <c>Contracts.CheckValue</c>.</param>
        /// <param name="name">Model name; validated with the host's <c>CheckValue</c>.</param>
        /// <param name="producerName">Producer name stored in the context.</param>
        /// <param name="producerVersion">Producer version stored in the context.</param>
        /// <param name="modelVersion">Model version stored in the context.</param>
        /// <param name="domain">Model domain; validated with the host's <c>CheckValue</c>.</param>
        /// <param name="onnxVersion">ONNX version stored in the context.</param>
        /// <param name="opSetVersion">
        /// Requested OpSet version. It must lie between <c>MinimumOpSetVersion</c> and
        /// <c>CurrentOpSetVersion</c> (both inclusive); otherwise the exception built by the
        /// host's <c>ExceptParam</c> is thrown.
        /// </param>
        public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
                               string producerVersion, long modelVersion, string domain, OnnxVersion onnxVersion, int opSetVersion = CurrentOpSetVersion)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(OnnxContext));
            _host.CheckValue(name, nameof(name));
            // Check the domain argument itself so that the reported parameter name matches the checked value.
            _host.CheckValue(domain, nameof(domain));

            // The upper bound is tested first, so a too-high request always reports the "higher" message.
            if (opSetVersion > CurrentOpSetVersion)
            {
                throw _host.ExceptParam(nameof(opSetVersion), $"Requested OpSet version {opSetVersion} is higher than the current most updated OpSet version {CurrentOpSetVersion}");
            }
            if (opSetVersion < MinimumOpSetVersion)
            {
                throw _host.ExceptParam(nameof(opSetVersion), $"Requested OpSet version {opSetVersion} is lower than the minimum required OpSet version {MinimumOpSetVersion}");
            }

            // Empty graph state, filled in while the model is exported.
            _nodes              = new List<OnnxCSharpToProtoWrapper.NodeProto>();
            _intermediateValues = new List<OnnxUtils.ModelArgs>();
            _inputs             = new List<OnnxUtils.ModelArgs>();
            _initializers       = new List<OnnxCSharpToProtoWrapper.TensorProto>();
            _outputs            = new List<OnnxUtils.ModelArgs>();
            _columnNameMap      = new Dictionary<string, string>();
            _variableNames      = new HashSet<string>();
            _nodeNames          = new HashSet<string>();

            // Model metadata.
            _name            = name;
            _producerName    = producerName;
            _producerVersion = producerVersion;
            _modelVersion    = modelVersion;
            _domain          = domain;
            _onnxVersion     = onnxVersion;
            _opSetVersion    = opSetVersion;
        }
        /// <summary>
        /// Exports this pipeline to ONNX by delegating to <c>Convert2Onnx.ToOnnx</c>.
        /// </summary>
        /// <param name="firstTransform">index of the transform used as the starting view; not read when no transforms are stored</param>
        /// <param name="inputs">inputs, modified by the function to contain what is actually used</param>
        /// <param name="outputs">outputs, modified by the function to contain what is actually used</param>
        /// <param name="name">model name forwarded to the converter</param>
        /// <param name="producer">producer name forwarded to the converter</param>
        /// <param name="version">model version forwarded to the converter</param>
        /// <param name="domain">model domain forwarded to the converter</param>
        /// <param name="onnxVersion">ONNX version forwarded to the converter</param>
        /// <returns>the context returned by the converter</returns>
        public ScikitOnnxContext ToOnnx(int firstTransform      = 0, string[] inputs    = null, string[] outputs = null,
                                        string name             = null, string producer = "Scikit.ML",
                                        long version            = 0, string domain      = "onnx.ai.ml",
                                        OnnxVersion onnxVersion = OnnxVersion.Stable)
        {
            // Starting view: the transform at the requested index, or null when no transforms are stored.
            var startView  = _transforms != null ? _transforms[firstTransform].transform : null;
            var startViews = new IDataView[] { startView };

            // With a predictor the scorer is converted; otherwise the last stored transform is.
            return Convert2Onnx.ToOnnx(_predictor != null ? GetScorer() : _transforms.Last().transform,
                                       ref inputs, ref outputs, name, producer, version, domain,
                                       onnxVersion, startViews, _env);
        }
        /// <summary>
        /// Initializes an ONNX export context: registers a host on <paramref name="env"/>,
        /// validates the arguments, stores the model metadata and creates the empty
        /// collections used while the graph is being built.
        /// </summary>
        /// <param name="env">Host environment; validated with <c>Contracts.CheckValue</c>.</param>
        /// <param name="name">Model name; validated with the host's <c>CheckValue</c>.</param>
        /// <param name="producerName">Producer name stored in the context.</param>
        /// <param name="producerVersion">Producer version stored in the context.</param>
        /// <param name="modelVersion">Model version stored in the context.</param>
        /// <param name="domain">Model domain; validated with the host's <c>CheckValue</c>.</param>
        /// <param name="onnxVersion">ONNX version stored in the context.</param>
        public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
                               string producerVersion, long modelVersion, string domain, OnnxVersion onnxVersion)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(OnnxContext));
            _host.CheckValue(name, nameof(name));
            // Check the domain argument itself so that the reported parameter name matches the checked value.
            _host.CheckValue(domain, nameof(domain));

            // Empty graph state, filled in while the model is exported.
            _nodes              = new List<OnnxCSharpToProtoWrapper.NodeProto>();
            _intermediateValues = new List<OnnxUtils.ModelArgs>();
            _inputs             = new List<OnnxUtils.ModelArgs>();
            _initializers       = new List<OnnxCSharpToProtoWrapper.TensorProto>();
            _outputs            = new List<OnnxUtils.ModelArgs>();
            _columnNameMap      = new Dictionary<string, string>();
            _variableNames      = new HashSet<string>();
            _nodeNames          = new HashSet<string>();

            // Model metadata.
            _name            = name;
            _producerName    = producerName;
            _producerVersion = producerVersion;
            _modelVersion    = modelVersion;
            _domain          = domain;
            _onnxVersion     = onnxVersion;
        }
示例#4
0
        /// <summary>
        /// Converts the transform chain ending at <paramref name="trans"/> into a
        /// <see cref="ScikitOnnxContext"/>. It declares the graph inputs, asks every transform
        /// found by <c>TagHelper.EnumerateAllViews</c> to save itself into the context, and then
        /// links every requested output to a dedicated output variable through an "Identity" node.
        /// </summary>
        /// <param name="trans">transform to convert</param>
        /// <param name="inputs">input column names; passed to <c>GuessInputs</c>, must be non-empty afterwards</param>
        /// <param name="outputs">output column names; passed to <c>GuessOutputs</c>, must be non-empty afterwards</param>
        /// <param name="name">model name; defaults to the type name of <paramref name="trans"/> when null</param>
        /// <param name="producer">producer name stored in the context</param>
        /// <param name="version">model version stored in the context</param>
        /// <param name="domain">model domain stored in the context</param>
        /// <param name="onnxVersion">ONNX version stored in the context</param>
        /// <param name="begin">views forwarded to the input guessing and to the view/variable enumeration</param>
        /// <param name="host">host environment; when null a temporary <c>DelegateEnvironment</c> is created for the call</param>
        /// <returns>the populated ONNX context</returns>
        public static ScikitOnnxContext ToOnnx(IDataTransform trans, ref string[] inputs, ref string[] outputs,
                                               string name             = null, string producer = "Scikit.ML",
                                               long version            = 0, string domain      = "onnx.ai.ml",
                                               OnnxVersion onnxVersion = OnnxVersion.Stable,
                                               IDataView[] begin       = null, IHostEnvironment host = null)
        {
            if (host == null)
            {
                // NOTE(review): the returned context was built with an environment that is
                // disposed when this using block exits - confirm the context does not need it afterwards.
                using (var env = new DelegateEnvironment())
                    return(ToOnnx(trans, ref inputs, ref outputs, name, producer, version, domain, onnxVersion, begin, env));
            }
            if (name == null)
            {
                name = trans.GetType().Name;
            }

            GuessOutputs(trans, ref outputs);
            GuessInputs(trans, ref inputs, begin);

            if (inputs == null || inputs.Length == 0)
            {
                throw host.Except("Inputs cannot be empty.");
            }
            if (outputs == null || outputs.Length == 0)
            {
                throw host.Except("Outputs cannot be empty.");
            }

            var assembly    = System.Reflection.Assembly.GetExecutingAssembly();
            var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
            var ctx         = new ScikitOnnxContext(host, name, producer, versionInfo.FileVersion,
                                                    version, domain, onnxVersion);

            var uniqueVars   = EnumerateVariables(trans, begin).ToArray();
            var mapInputType = new Dictionary <string, ColumnType>();

            // Types of the variables sitting at the first position of the graph.
            foreach (var it in uniqueVars.Where(c => c.position == TagHelper.GraphPositionEnum.first))
            {
                mapInputType[it.variableName] = it.variableType;
            }

            foreach (var col in inputs)
            {
                ctx.AddInputVariable(mapInputType[col], col);
            }

            var views      = TagHelper.EnumerateAllViews(trans, begin);
            var transforms = views.Where(c => (c.Item1 as IDataTransform) != null)
                             .Select(c => c.Item1)
                             .ToArray();

            // The enumeration is walked in reverse order when saving the transforms.
            foreach (var tr in transforms.Reverse())
            {
                var tron = tr as ICanSaveOnnx;
                if (tron == null)
                {
                    throw host.ExceptNotSupp($"Transform {tr.GetType()} cannot be saved in Onnx format.");
                }
                if (!tron.CanSaveOnnx(ctx))
                {
                    throw host.ExceptNotSupp($"Transform {tr.GetType()} cannot be saved in ONNX format.");
                }
                var tron2 = tron as ISaveAsOnnx;
                // A failed 'as' cast yields null: report the missing interface instead of
                // dereferencing the null reference.
                if (tron2 == null)
                {
                    throw host.ExceptNotSupp($"Transform {tr.GetType()} does not implement SaveAsOnnx.");
                }
                tron2.SaveAsOnnx(ctx);
            }

            var mapOuputType = new Dictionary <string, ColumnType>();

            // Types of the variables sitting at the last position of the graph.
            foreach (var it in uniqueVars.Where(c => c.position == TagHelper.GraphPositionEnum.last))
            {
                mapOuputType[it.variableName] = it.variableType;
            }

            foreach (var col in outputs)
            {
                // Link the variable currently registered for the column to a new variable
                // through an Identity node, and declare that new variable as a graph output.
                var variableName     = ctx.TryGetVariableName(col);
                var trueVariableName = ctx.AddIntermediateVariable(null, col, true);
                ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
                ctx.AddOutputVariable(mapOuputType[col], trueVariableName);
            }

            return(ctx);
        }