예제 #1
0
        /// <summary>
        /// Builds an ONNX <see cref="ModelProto"/> from a list of ONNX-convertible transforms.
        /// The schema of <paramref name="inputData"/> supplies the graph inputs, each transform
        /// saves itself into <paramref name="ctx"/>, and the visible columns of
        /// <paramref name="outputData"/> become the graph outputs.
        /// </summary>
        /// <param name="ctx">ONNX context that accumulates inputs, nodes, intermediate variables and outputs.</param>
        /// <param name="ch">Channel used for the assertions and checks performed during conversion.</param>
        /// <param name="inputData">Data view whose schema columns are registered as graph inputs.</param>
        /// <param name="outputData">Data view whose non-hidden schema columns are registered as graph outputs.</param>
        /// <param name="transforms">Transforms to save into <paramref name="ctx"/>, processed in list order.</param>
        /// <param name="inputColumnNamesToDrop">
        /// Column names to exclude from both the graph inputs and the graph outputs. Null is treated as an empty set.
        /// </param>
        /// <param name="outputColumnNamesToDrop">
        /// Column names to exclude from the graph outputs only. Null is treated as an empty set.
        /// </param>
        /// <returns>The model produced by <c>ctx.MakeModel()</c> once inputs, nodes and outputs are registered.</returns>
        internal static ModelProto ConvertTransformListToOnnxModel(OnnxContextImpl ctx, IChannel ch, IDataView inputData, IDataView outputData,
                                                                   LinkedList<ITransformCanSaveOnnx> transforms, HashSet<string> inputColumnNamesToDrop = null, HashSet<string> outputColumnNamesToDrop = null)
        {
            inputColumnNamesToDrop = inputColumnNamesToDrop ?? new HashSet<string>();
            outputColumnNamesToDrop = outputColumnNamesToDrop ?? new HashSet<string>();

            // Create graph inputs: every input column that was not explicitly dropped.
            for (int i = 0; i < inputData.Schema.Count; i++)
            {
                string colName = inputData.Schema[i].Name;
                if (inputColumnNamesToDrop.Contains(colName))
                {
                    continue;
                }

                ctx.AddInputVariable(inputData.Schema[i].Type, colName);
            }

            // Create graph nodes, outputs and intermediate values.
            foreach (var trans in transforms)
            {
                ch.Assert(trans.CanSaveOnnx(ctx));
                trans.SaveAsOnnx(ctx);
            }

            // Add graph outputs.
            for (int i = 0; i < outputData.Schema.Count; ++i)
            {
                if (outputData.Schema[i].IsHidden)
                {
                    continue;
                }

                var idataviewColumnName = outputData.Schema[i].Name;

                // Since the last IDataView also contains columns of the initial IDataView, the last IDataView's
                // columns found in inputColumnNamesToDrop should be removed too.
                if (inputColumnNamesToDrop.Contains(idataviewColumnName) || outputColumnNamesToDrop.Contains(idataviewColumnName))
                {
                    continue;
                }

                var variableName = ctx.TryGetVariableName(idataviewColumnName);
                // A null variable name occurs when an unsupported transform produces an output and a downstream step
                // consumes that output, or when the user accidentally removes a transform whose output is used by
                // other transforms.
                ch.Check(variableName != null, "The targeted pipeline can not be fully converted into a well-defined ONNX model. " +
                         "Please check if all steps in that pipeline are convertible to ONNX " +
                         "and all necessary variables are not dropped (via command line arguments).");

                // Route the internal variable through an Identity node into a freshly registered variable, and
                // expose that new variable as the graph output.
                // NOTE(review): presumably this makes the output carry the column's own name - confirm against
                // OnnxContextImpl.AddIntermediateVariable.
                var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName, true);
                ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
                ctx.AddOutputVariable(outputData.Schema[i].Type, trueVariableName);
            }

            return ctx.MakeModel();
        }
        /// <summary>
        /// Converts the configured pipeline to an ONNX model and writes it to <c>_outputModelPath</c>,
        /// optionally also writing an indented JSON rendering to <c>_outputJsonModelPath</c> and
        /// re-saving the data loader to <c>Args.OutputModelFile</c>.
        /// </summary>
        /// <param name="ch">Channel used for checks, warnings and trace messages.</param>
        private void Run(IChannel ch)
        {
            IDataLoader      loader  = null;
            IPredictor       rawPred = null;
            IDataView        view;
            RoleMappedSchema trainSchema = null;

            if (_model == null)
            {
                if (string.IsNullOrEmpty(Args.InputModelFile))
                {
                    // No input model file: build the loader from the arguments; there is no predictor to load.
                    loader      = CreateLoader();
                    rawPred     = null;
                    trainSchema = null;
                    Host.CheckUserArg(Args.LoadPredictor != true, nameof(Args.LoadPredictor),
                                      "Cannot be set to true unless " + nameof(Args.InputModelFile) + " is also specifified.");
                }
                else
                {
                    LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader);
                }

                view = loader;
            }
            else
            {
                // An in-memory model was supplied: apply it to an empty view with the model's input schema.
                view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
            }

            // Create the ONNX context for storing global information.
            var assembly    = System.Reflection.Assembly.GetExecutingAssembly();
            var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
            var ctx         = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion,
                                                  ModelVersion, _domain, Args.OnnxVersion);

            // Get the transform chain.
            IDataView source;
            IDataView end;
            LinkedList<ITransformCanSaveOnnx> transforms;

            GetPipe(ctx, ch, view, out source, out end, out transforms);
            Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);

            // If we have a predictor, try to get the scorer for it.
            if (rawPred != null)
            {
                RoleMappedData data;
                if (trainSchema != null)
                {
                    data = new RoleMappedData(end, trainSchema.GetColumnRoleNames());
                }
                else
                {
                    // We had a predictor, but no roles stored in the model. Just suppose
                    // default column names are OK, if present.
                    data = new RoleMappedData(end, DefaultColumnNames.Label,
                                              DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);
                }

                var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
                var scoreOnnx = scorePipe as ITransformCanSaveOnnx;
                if (scoreOnnx?.CanSaveOnnx(ctx) == true)
                {
                    // The scorer is convertible: it becomes the new end of the chain.
                    Host.Assert(scorePipe.Source == end);
                    end = scorePipe;
                    transforms.AddLast(scoreOnnx);
                }
                else
                {
                    // Fail only if the user explicitly asked for the predictor; otherwise warn and continue.
                    Contracts.CheckUserArg(_loadPredictor != true,
                                           nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but we do not know how to save it as ONNX.");
                    ch.Warning("We do not know how to save the predictor as ONNX. Ignoring.");
                }
            }
            else
            {
                Contracts.CheckUserArg(_loadPredictor != true,
                                       nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but one was not present.");
            }

            // Create graph inputs: every source column that was not explicitly dropped.
            for (int i = 0; i < source.Schema.Count; i++)
            {
                string colName = source.Schema[i].Name;
                if (_inputsToDrop.Contains(colName))
                {
                    continue;
                }

                ctx.AddInputVariable(source.Schema[i].Type, colName);
            }

            // Create graph nodes, outputs and intermediate values.
            foreach (var trans in transforms)
            {
                Host.Assert(trans.CanSaveOnnx(ctx));
                trans.SaveAsOnnx(ctx);
            }

            // Add graph outputs.
            for (int i = 0; i < end.Schema.Count; ++i)
            {
                if (end.Schema[i].IsHidden)
                {
                    continue;
                }

                var idataviewColumnName = end.Schema[i].Name;

                // Since the last IDataView also contains columns of the initial IDataView, the last IDataView's
                // columns found in _inputsToDrop should be removed too.
                if (_inputsToDrop.Contains(idataviewColumnName) || _outputsToDrop.Contains(idataviewColumnName))
                {
                    continue;
                }

                var variableName = ctx.TryGetVariableName(idataviewColumnName);
                // TryGetVariableName can yield null (e.g. a column produced by a step that was not saved into
                // the context). Fail with a descriptive message instead of creating an Identity node with a
                // null input, matching the check done in ConvertTransformListToOnnxModel.
                ch.Check(variableName != null, "The targeted pipeline can not be fully converted into a well-defined ONNX model. " +
                         "Please check if all steps in that pipeline are convertible to ONNX " +
                         "and all necessary variables are not dropped (via command line arguments).");

                // Route the internal variable through an Identity node into a freshly registered variable, and
                // expose that new variable as the graph output.
                var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName, true);
                ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
                ctx.AddOutputVariable(end.Schema[i].Type, trueVariableName);
            }

            var model = ctx.MakeModel();

            // Write the serialized model.
            using (var file = Host.CreateOutputFile(_outputModelPath))
            using (var stream = file.CreateWriteStream())
            {
                model.WriteTo(stream);
            }

            // Optionally write a human-readable version: the model's string form re-serialized as indented JSON.
            if (_outputJsonModelPath != null)
            {
                using (var file = Host.CreateOutputFile(_outputJsonModelPath))
                using (var stream = file.CreateWriteStream())
                using (var writer = new StreamWriter(stream))
                {
                    var parsedJson = JsonConvert.DeserializeObject(model.ToString());
                    writer.Write(JsonConvert.SerializeObject(parsedJson, Formatting.Indented));
                }
            }

            if (!string.IsNullOrWhiteSpace(Args.OutputModelFile))
            {
                // NOTE(review): loader is only assigned on the _model == null path above; confirm that
                // Args.OutputModelFile cannot be set together with an in-memory _model.
                Contracts.Assert(loader != null);

                ch.Trace("Saving the data pipe");
                // Should probably include "end"?
                SaveLoader(loader, Args.OutputModelFile);
            }
        }