예제 #1
0
        /// <summary>
        /// Builds an ONNX model from a chain of ONNX-convertible transforms.
        /// </summary>
        /// <param name="ctx">ONNX context that accumulates graph inputs, nodes and outputs; its <c>MakeModel()</c> result is returned.</param>
        /// <param name="ch">Channel used for assertions and user-facing checks.</param>
        /// <param name="inputData">Data view whose schema columns (minus <paramref name="inputColumnNamesToDrop"/>) become graph inputs.</param>
        /// <param name="outputData">Data view whose non-hidden schema columns (minus both drop sets) become graph outputs.</param>
        /// <param name="transforms">Transforms to serialize, in order; each is expected to report <c>CanSaveOnnx(ctx)</c>.</param>
        /// <param name="inputColumnNamesToDrop">Column names excluded from graph inputs and from graph outputs; <c>null</c> means none.</param>
        /// <param name="outputColumnNamesToDrop">Column names excluded from graph outputs; <c>null</c> means none.</param>
        /// <returns>The model produced by <c>ctx.MakeModel()</c>.</returns>
        internal static ModelProto ConvertTransformListToOnnxModel(OnnxContextImpl ctx, IChannel ch, IDataView inputData, IDataView outputData,
            LinkedList<ITransformCanSaveOnnx> transforms, HashSet<string> inputColumnNamesToDrop = null, HashSet<string> outputColumnNamesToDrop = null)
        {
            inputColumnNamesToDrop = inputColumnNamesToDrop ?? new HashSet<string>();
            outputColumnNamesToDrop = outputColumnNamesToDrop ?? new HashSet<string>();

            // Create graph inputs.
            for (int i = 0; i < inputData.Schema.Count; i++)
            {
                var inputColumn = inputData.Schema[i];
                if (inputColumnNamesToDrop.Contains(inputColumn.Name))
                {
                    continue;
                }

                ctx.AddInputVariable(inputColumn.Type, inputColumn.Name);
            }

            // Create graph nodes, outputs and intermediate values.
            foreach (var trans in transforms)
            {
                ch.Assert(trans.CanSaveOnnx(ctx));
                trans.SaveAsOnnx(ctx);
            }

            // Add graph outputs.
            for (int i = 0; i < outputData.Schema.Count; ++i)
            {
                var column = outputData.Schema[i];
                if (column.IsHidden)
                {
                    continue;
                }

                var idataviewColumnName = column.Name;

                // The last IDataView also carries the columns of the initial IDataView, so columns dropped
                // from the inputs must be dropped from the outputs too.
                if (inputColumnNamesToDrop.Contains(idataviewColumnName) || outputColumnNamesToDrop.Contains(idataviewColumnName))
                {
                    continue;
                }

                var variableName = ctx.TryGetVariableName(idataviewColumnName);
                // A null variable name occurs when an unsupported transform produces an output that a downstream step
                // consumes, or when the user accidentally removes a transform whose output is used by other transforms.
                ch.Check(variableName != null, "The targeted pipeline can not be fully converted into a well-defined ONNX model. " +
                         "Please check if all steps in that pipeline are convertible to ONNX " +
                         "and all necessary variables are not dropped (via command line arguments).");

                // Expose the column under a dedicated "<name>.output" variable through an Identity node.
                var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName + ".output", true);
                ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
                ctx.AddOutputVariable(column.Type, trueVariableName);

                if (column.HasSlotNames())
                {
                    AddSlotNames(ctx, column);
                }
            }

            return ctx.MakeModel();
        }
예제 #2
0
        /// <summary>
        /// Resolves the pipeline to export (from a loader/model file, a transformer model, or a predictive model),
        /// appends an ONNX-savable scorer when a predictor is available, converts the result to ONNX, and writes the
        /// protobuf model (plus optional JSON and data-pipe outputs).
        /// </summary>
        /// <param name="ch">Channel used for warnings, traces and checks.</param>
        private void Run(IChannel ch)
        {
            ILegacyDataLoader loader = null;
            IPredictor rawPred = null;
            IDataView view;
            RoleMappedSchema trainSchema = null;

            if (_model == null && _predictiveModel == null)
            {
                if (string.IsNullOrEmpty(ImplOptions.InputModelFile))
                {
                    loader = CreateLoader();
                    rawPred = null;
                    trainSchema = null;
                    Host.CheckUserArg(ImplOptions.LoadPredictor != true, nameof(ImplOptions.LoadPredictor),
                        "Cannot be set to true unless " + nameof(ImplOptions.InputModelFile) + " is also specified.");
                }
                else
                {
                    LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader);
                }

                view = loader;
            }
            else if (_model != null)
            {
                view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
            }
            else
            {
                view = _predictiveModel.TransformModel.Apply(Host, new EmptyDataView(Host, _predictiveModel.TransformModel.InputSchema));
                rawPred = _predictiveModel.Predictor;
                trainSchema = _predictiveModel.GetTrainingSchema(Host);
            }

            // Create the ONNX context for storing global information.
            var assembly = System.Reflection.Assembly.GetExecutingAssembly();
            var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
            var ctx = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion,
                ModelVersion, _domain, ImplOptions.OnnxVersion);

            // Get the transform chain.
            IDataView source;
            IDataView end;
            LinkedList<ITransformCanSaveOnnx> transforms;

            GetPipe(ctx, ch, view, out source, out end, out transforms);
            Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);

            // If we have a predictor, try to get the scorer for it.
            if (rawPred != null)
            {
                RoleMappedData data;
                if (trainSchema != null)
                {
                    data = new RoleMappedData(end, trainSchema.GetColumnRoleNames());
                }
                else
                {
                    // We had a predictor, but no roles stored in the model. Just suppose
                    // default column names are OK, if present.
                    data = new RoleMappedData(end, DefaultColumnNames.Label,
                        DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);
                }

                var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
                var scoreOnnx = scorePipe as ITransformCanSaveOnnx;
                if (scoreOnnx?.CanSaveOnnx(ctx) == true)
                {
                    Host.Assert(scorePipe.Source == end);
                    end = scorePipe;
                    transforms.AddLast(scoreOnnx);
                }
                else
                {
                    Contracts.CheckUserArg(_loadPredictor != true,
                        nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but we do not know how to save it as ONNX.");
                    ch.Warning("We do not know how to save the predictor as ONNX. Ignoring.");
                }
            }
            else
            {
                Contracts.CheckUserArg(_loadPredictor != true,
                    nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but one was not present.");
            }

            var model = ConvertTransformListToOnnxModel(ctx, ch, source, end, transforms, _inputsToDrop, _outputsToDrop);

            using (var file = Host.CreateOutputFile(_outputModelPath))
            using (var stream = file.CreateWriteStream())
            {
                model.WriteTo(stream);
            }

            if (_outputJsonModelPath != null)
            {
                using (var file = Host.CreateOutputFile(_outputJsonModelPath))
                using (var stream = file.CreateWriteStream())
                using (var writer = new StreamWriter(stream))
                {
                    // Round-trip through a JSON object so the output is indented.
                    var parsedJson = JsonConvert.DeserializeObject(model.ToString());
                    writer.Write(JsonConvert.SerializeObject(parsedJson, Formatting.Indented));
                }
            }

            if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
            {
                // NOTE(review): loader is only assigned on the file/loader path above; this assumes
                // OutputModelFile is never set when exporting from _model/_predictiveModel — confirm with callers.
                Contracts.Assert(loader != null);

                ch.Trace("Saving the data pipe");
                // Should probably include "end"?
                SaveLoader(loader, ImplOptions.OutputModelFile);
            }
        }
예제 #3
0
        /// <summary>
        /// Resolves the pipeline to export (from a loader/model file, a transformer model, or a predictive model),
        /// appends an ONNX-savable scorer when a predictor is available, maps key-typed predicted-label and
        /// pass-through columns back to their values where that mapping is ONNX-savable, converts the result to
        /// ONNX, and writes the protobuf model (plus optional JSON and data-pipe outputs).
        /// </summary>
        /// <param name="ch">Channel used for warnings, traces and checks.</param>
        private void Run(IChannel ch)
        {
            ILegacyDataLoader loader = null;
            IPredictor rawPred = null;
            IDataView view;
            RoleMappedSchema trainSchema = null;

            if (_model == null && _predictiveModel == null)
            {
                if (string.IsNullOrEmpty(ImplOptions.InputModelFile))
                {
                    loader = CreateLoader();
                    rawPred = null;
                    trainSchema = null;
                    Host.CheckUserArg(ImplOptions.LoadPredictor != true, nameof(ImplOptions.LoadPredictor),
                        "Cannot be set to true unless " + nameof(ImplOptions.InputModelFile) + " is also specified.");
                }
                else
                {
                    LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader);
                }

                view = loader;
            }
            else if (_model != null)
            {
                view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
            }
            else
            {
                view = _predictiveModel.TransformModel.Apply(Host, new EmptyDataView(Host, _predictiveModel.TransformModel.InputSchema));
                rawPred = _predictiveModel.Predictor;
                trainSchema = _predictiveModel.GetTrainingSchema(Host);
            }

            // Create the ONNX context for storing global information.
            var assembly = System.Reflection.Assembly.GetExecutingAssembly();
            var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
            var ctx = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion,
                ModelVersion, _domain, ImplOptions.OnnxVersion);

            // Get the transform chain.
            IDataView source;
            IDataView end;
            LinkedList<ITransformCanSaveOnnx> transforms;

            GetPipe(ctx, ch, view, out source, out end, out transforms);
            Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);

            // If we have a predictor, try to get the scorer for it.
            if (rawPred != null)
            {
                RoleMappedData data;
                if (trainSchema != null)
                {
                    data = new RoleMappedData(end, trainSchema.GetColumnRoleNames());
                }
                else
                {
                    // We had a predictor, but no roles stored in the model. Just suppose
                    // default column names are OK, if present.
                    data = new RoleMappedData(end, DefaultColumnNames.Label,
                        DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);
                }

                var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
                var scoreOnnx = scorePipe as ITransformCanSaveOnnx;
                if (scoreOnnx?.CanSaveOnnx(ctx) == true)
                {
                    Host.Assert(scorePipe.Source == end);
                    end = scorePipe;
                    transforms.AddLast(scoreOnnx);

                    if (rawPred.PredictionKind == PredictionKind.BinaryClassification || rawPred.PredictionKind == PredictionKind.MulticlassClassification)
                    {
                        // Check if the PredictedLabel column is a KeyDataViewType and has KeyValue annotations.
                        // If it does, add a KeyToValueMappingTransformer, to enable NimbusML to get the values back
                        // when using an ONNX model, as described in https://github.com/dotnet/machinelearning/pull/4841
                        var predictedLabelColumn = scorePipe.Schema.GetColumnOrNull(DefaultColumnNames.PredictedLabel);
                        if (predictedLabelColumn.HasValue && HasKeyValues(predictedLabelColumn.Value))
                        {
                            end = AppendKeyToValueMappingIfConvertible(ctx, ch, end, transforms, DefaultColumnNames.PredictedLabel);
                        }
                    }
                }
                else
                {
                    Contracts.CheckUserArg(_loadPredictor != true,
                        nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but we do not know how to save it as ONNX.");
                    ch.Warning("We do not know how to save the predictor as ONNX. Ignoring.");
                }
            }
            else
            {
                Contracts.CheckUserArg(_loadPredictor != true,
                    nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but one was not present.");
            }

            // Convert the KeyDataViewType "pass-through" columns (i.e. those that remained untouched by the model)
            // back to values. This is done to enable NimbusML to get these values,
            // as described in https://github.com/dotnet/machinelearning/pull/4841
            var passThroughColumnNames = GetPassThroughKeyDataViewTypeColumnsNames(source, end);

            foreach (var name in passThroughColumnNames)
            {
                end = AppendKeyToValueMappingIfConvertible(ctx, ch, end, transforms, name);
            }

            // The ONNX outputs are read from `end`, so it must still be the last transform in the chain.
            Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);

            var model = ConvertTransformListToOnnxModel(ctx, ch, source, end, transforms, _inputsToDrop, _outputsToDrop);

            using (var file = Host.CreateOutputFile(_outputModelPath))
            using (var stream = file.CreateWriteStream())
            {
                model.WriteTo(stream);
            }

            if (_outputJsonModelPath != null)
            {
                using (var file = Host.CreateOutputFile(_outputJsonModelPath))
                using (var stream = file.CreateWriteStream())
                using (var writer = new StreamWriter(stream))
                {
                    // Round-trip through a JSON object so the output is indented.
                    var parsedJson = JsonConvert.DeserializeObject(model.ToString());
                    writer.Write(JsonConvert.SerializeObject(parsedJson, Formatting.Indented));
                }
            }

            if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
            {
                // NOTE(review): loader is only assigned on the file/loader path above; this assumes
                // OutputModelFile is never set when exporting from _model/_predictiveModel — confirm with callers.
                Contracts.Assert(loader != null);

                ch.Trace("Saving the data pipe");
                // Should probably include "end"?
                SaveLoader(loader, ImplOptions.OutputModelFile);
            }
        }

        /// <summary>
        /// Applies a <see cref="KeyToValueMappingTransformer"/> for <paramref name="columnName"/> on top of
        /// <paramref name="end"/> and appends it to <paramref name="transforms"/>. The transform is appended only
        /// when it implements <see cref="ITransformCanSaveOnnx"/> and reports it can be saved in <paramref name="ctx"/>.
        /// </summary>
        /// <param name="ctx">ONNX context used to check convertibility.</param>
        /// <param name="ch">Channel used to warn when the mapping is skipped.</param>
        /// <param name="end">Current last data view of the pipe.</param>
        /// <param name="transforms">Transform chain to append to.</param>
        /// <param name="columnName">Key-typed column to map back to its values.</param>
        /// <returns>The new end of the pipe when the mapping was appended; otherwise <paramref name="end"/> unchanged.</returns>
        private IDataView AppendKeyToValueMappingIfConvertible(OnnxContextImpl ctx, IChannel ch, IDataView end,
            LinkedList<ITransformCanSaveOnnx> transforms, string columnName)
        {
            var mapped = new KeyToValueMappingTransformer(Host, columnName).Transform(end);

            // Appending a null or non-convertible entry would make ConvertTransformListToOnnxModel
            // fail (or emit a broken graph), so mirror the scorer's convertibility check.
            var mappedOnnx = mapped as ITransformCanSaveOnnx;
            if (mappedOnnx?.CanSaveOnnx(ctx) == true)
            {
                transforms.AddLast(mappedOnnx);
                return mapped;
            }

            ch.Warning("Cannot save the key-to-value mapping of column '{0}' as ONNX. The column is exported unconverted.", columnName);
            return end;
        }