/// <summary>
/// Persists the role-to-column-name associations of <paramref name="schema"/> as a two-column
/// ("Role", "Column") text table inside <paramref name="rep"/>.
/// </summary>
/// <param name="env">Host environment used to build the data view and the text saver.</param>
/// <param name="ch">Channel used for argument assertions.</param>
/// <param name="schema">Schema whose column-role names are written out.</param>
/// <param name="rep">Repository that receives the role-mapping entry.</param>
internal static void SaveRoleMappings(IHostEnvironment env, IChannel ch, RoleMappedSchema schema, RepositoryWriter rep)
{
    // REVIEW: Should we also save this stuff, for instance, in some portion of the
    // score command or transform?
    Contracts.AssertValue(env);
    env.AssertValue(ch);
    ch.AssertValue(schema);

    var tableBuilder = new ArrayDataViewBuilder(env);

    // Sort by role name. OrderBy is a stable sort, so when several columns fill the same
    // role they keep their original relative order.
    var sortedPairs = schema.GetColumnRoleNames().OrderBy(r => r.Key.Value).ToArray();
    string[] roleNames = sortedPairs.Select(pair => pair.Key.Value).ToArray();
    string[] columnNames = sortedPairs.Select(pair => pair.Value).ToArray();

    tableBuilder.AddColumn("Role", roleNames);
    tableBuilder.AddColumn("Column", columnNames);

    using (var entry = rep.CreateEntry(DirTrainingInfo, RoleMappingFile))
    {
        // REVIEW: The role mappings should stay human readable and hand-editable. Going
        // through the text saver/loader means special characters such as '\n' in a role or
        // column name would not round-trip. This assumes such names do not occur in practice.
        var textSaver = new TextSaver(env, new TextSaver.Arguments { Dense = true, Silent = true });
        var tableView = tableBuilder.GetDataView();
        var allColumns = Utils.GetIdentityPermutation(tableView.Schema.ColumnCount);
        textSaver.SaveData(entry.Stream, tableView, allColumns);
    }
}
/// <summary>
/// Serializes the column-role names of <paramref name="schema"/> to the model context.
/// The layout is the pair count, followed by (role name, column name) string pairs.
/// </summary>
/// <param name="ctx">Model save context whose writer receives the data.</param>
/// <param name="schema">Schema whose role/column pairs are written.</param>
public static void Write(ModelSaveContext ctx, RoleMappedSchema schema)
{
    var rolePairs = schema.GetColumnRoleNames().ToArray();
    var writer = ctx.Writer;

    writer.Write(rolePairs.Length);
    for (int index = 0; index < rolePairs.Length; index++)
    {
        var current = rolePairs[index];
        writer.Write(current.Key.Value);
        writer.Write(current.Value);
    }
}
/// <summary>
/// Exports the data pipeline (and, when available, the predictor's scorer) as an ONNX model.
/// The model is written in binary form and, optionally, as indented JSON. The data pipe is
/// also saved when an output model file is requested.
/// </summary>
/// <param name="ch">Channel used for tracing and warnings.</param>
private void Run(IChannel ch)
{
    IDataLoader loader = null;
    IPredictor rawPred = null;
    IDataView view;
    RoleMappedSchema trainSchema = null;

    if (_model == null)
    {
        if (string.IsNullOrEmpty(Args.InputModelFile))
        {
            loader = CreateLoader();
            rawPred = null;
            trainSchema = null;
            Host.CheckUserArg(Args.LoadPredictor != true, nameof(Args.LoadPredictor),
                "Cannot be set to true unless " + nameof(Args.InputModelFile) + " is also specified.");
        }
        else
        {
            LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader);
        }

        view = loader;
    }
    else
    {
        // Apply the in-memory transform model to an empty view to obtain its output schema.
        view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
    }

    // Get the transform chain.
    IDataView source;
    IDataView end;
    LinkedList<ITransformCanSaveOnnx> transforms;
    GetPipe(ch, view, out source, out end, out transforms);
    Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);

    var assembly = System.Reflection.Assembly.GetExecutingAssembly();
    var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
    var ctx = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion, ModelVersion, _domain);

    // If we have a predictor, try to get the scorer for it.
    if (rawPred != null)
    {
        RoleMappedData data;
        if (trainSchema != null)
        {
            data = RoleMappedData.Create(end, trainSchema.GetColumnRoleNames());
        }
        else
        {
            // We had a predictor, but no roles stored in the model. Just suppose
            // default column names are OK, if present.
            data = TrainUtils.CreateExamplesOpt(end, DefaultColumnNames.Label, DefaultColumnNames.Features,
                DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name);
        }

        var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
        var scoreOnnx = scorePipe as ITransformCanSaveOnnx;
        if (scoreOnnx?.CanSaveOnnx == true)
        {
            Host.Assert(scorePipe.Source == end);
            end = scorePipe;
            transforms.AddLast(scoreOnnx);
        }
        else
        {
            Contracts.CheckUserArg(_loadPredictor != true, nameof(Args.LoadPredictor),
                "We were explicitly told to load the predictor but we do not know how to save it as ONNX.");
            ch.Warning("We do not know how to save the predictor as ONNX. Ignoring.");
        }
    }
    else
    {
        Contracts.CheckUserArg(_loadPredictor != true, nameof(Args.LoadPredictor),
            "We were explicitly told to load the predictor but one was not present.");
    }

    // Create graph inputs for every source column that is not explicitly dropped.
    for (int i = 0; i < source.Schema.ColumnCount; i++)
    {
        string colName = source.Schema.GetColumnName(i);
        if (_inputsToDrop.Contains(colName))
            continue;

        ctx.AddInputVariable(source.Schema.GetColumnType(i), colName);
    }

    // Create graph nodes, outputs and intermediate values.
    foreach (var trans in transforms)
    {
        Host.Assert(trans.CanSaveOnnx);
        trans.SaveAsOnnx(ctx);
    }

    // Add graph outputs: visible columns of the final view that were not dropped and that
    // were actually mapped to an ONNX variable.
    for (int i = 0; i < end.Schema.ColumnCount; ++i)
    {
        if (end.Schema.IsHidden(i))
            continue;

        var idataviewColumnName = end.Schema.GetColumnName(i);
        if (_outputsToDrop.Contains(idataviewColumnName) || _inputsToDrop.Contains(idataviewColumnName))
            continue;

        var variableName = ctx.TryGetVariableName(idataviewColumnName);
        if (variableName != null)
            ctx.AddOutputVariable(end.Schema.GetColumnType(i), variableName);
    }

    var model = ctx.MakeModel();

    if (_outputModelPath != null)
    {
        using (var file = Host.CreateOutputFile(_outputModelPath))
        using (var stream = file.CreateWriteStream())
        {
            model.WriteTo(stream);
        }
    }

    if (_outputJsonModelPath != null)
    {
        using (var file = Host.CreateOutputFile(_outputJsonModelPath))
        using (var stream = file.CreateWriteStream())
        using (var writer = new StreamWriter(stream))
        {
            // Round-trip through JSON.NET to pretty-print the model's textual representation.
            var parsedJson = JsonConvert.DeserializeObject(model.ToString());
            writer.Write(JsonConvert.SerializeObject(parsedJson, Formatting.Indented));
        }
    }

    if (!string.IsNullOrWhiteSpace(Args.OutputModelFile))
    {
        // NOTE(review): `loader` is only assigned when `_model == null`. When a transform model
        // was supplied, this guard relies solely on a debug-time assertion.
        Contracts.Assert(loader != null);

        ch.Trace("Saving the data pipe");
        // Should probably include "end"?
        SaveLoader(loader, Args.OutputModelFile);
    }
}
/// <summary>
/// Exports the data pipeline (and, when available, the predictor's scorer) as an ONNX model.
/// The model is written to the configured output path and, optionally, as indented JSON.
/// The data pipe is also saved when an output model file is requested.
/// </summary>
/// <param name="ch">Channel used for tracing and warnings.</param>
private void Run(IChannel ch)
{
    ILegacyDataLoader loader = null;
    IPredictor rawPred = null;
    IDataView view;
    RoleMappedSchema trainSchema = null;

    if (_model == null)
    {
        if (string.IsNullOrEmpty(ImplOptions.InputModelFile))
        {
            loader = CreateLoader();
            rawPred = null;
            trainSchema = null;
            Host.CheckUserArg(ImplOptions.LoadPredictor != true, nameof(ImplOptions.LoadPredictor),
                "Cannot be set to true unless " + nameof(ImplOptions.InputModelFile) + " is also specified.");
        }
        else
        {
            LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader);
        }

        view = loader;
    }
    else
    {
        // Apply the in-memory transform model to an empty view to obtain its output schema.
        view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
    }

    // Create the ONNX context for storing global information.
    var assembly = System.Reflection.Assembly.GetExecutingAssembly();
    var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
    var ctx = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion,
        ModelVersion, _domain, ImplOptions.OnnxVersion);

    // Get the transform chain.
    IDataView source;
    IDataView end;
    LinkedList<ITransformCanSaveOnnx> transforms;
    GetPipe(ctx, ch, view, out source, out end, out transforms);
    Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);

    // If we have a predictor, try to get the scorer for it.
    if (rawPred != null)
    {
        RoleMappedData data;
        if (trainSchema != null)
        {
            data = new RoleMappedData(end, trainSchema.GetColumnRoleNames());
        }
        else
        {
            // We had a predictor, but no roles stored in the model. Just suppose
            // default column names are OK, if present.
            data = new RoleMappedData(end, DefaultColumnNames.Label, DefaultColumnNames.Features,
                DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);
        }

        var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
        var scoreOnnx = scorePipe as ITransformCanSaveOnnx;
        if (scoreOnnx?.CanSaveOnnx(ctx) == true)
        {
            Host.Assert(scorePipe.Source == end);
            end = scorePipe;
            transforms.AddLast(scoreOnnx);
        }
        else
        {
            Contracts.CheckUserArg(_loadPredictor != true, nameof(ImplOptions.LoadPredictor),
                "We were explicitly told to load the predictor but we do not know how to save it as ONNX.");
            ch.Warning("We do not know how to save the predictor as ONNX. Ignoring.");
        }
    }
    else
    {
        Contracts.CheckUserArg(_loadPredictor != true, nameof(ImplOptions.LoadPredictor),
            "We were explicitly told to load the predictor but one was not present.");
    }

    var model = ConvertTransformListToOnnxModel(ctx, ch, source, end, transforms, _inputsToDrop, _outputsToDrop);

    using (var file = Host.CreateOutputFile(_outputModelPath))
    using (var stream = file.CreateWriteStream())
    {
        model.WriteTo(stream);
    }

    if (_outputJsonModelPath != null)
    {
        using (var file = Host.CreateOutputFile(_outputJsonModelPath))
        using (var stream = file.CreateWriteStream())
        using (var writer = new StreamWriter(stream))
        {
            // Round-trip through JSON.NET to pretty-print the model's textual representation.
            var parsedJson = JsonConvert.DeserializeObject(model.ToString());
            writer.Write(JsonConvert.SerializeObject(parsedJson, Formatting.Indented));
        }
    }

    if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
    {
        // NOTE(review): `loader` is only assigned when `_model == null`; this guard is debug-only.
        Contracts.Assert(loader != null);

        ch.Trace("Saving the data pipe");
        // Should probably include "end"?
        SaveLoader(loader, ImplOptions.OutputModelFile);
    }
}
/// <summary>
/// Exports the data pipeline (and, when available, the predictor's scorer) as an ONNX model.
/// The pipeline comes from a loader/input model file, a transform model, or a predictor model.
/// Key-typed predicted-label and pass-through columns are mapped back to their values so that
/// ONNX consumers can read them. The model is written to the configured output path and,
/// optionally, as indented JSON.
/// </summary>
/// <param name="ch">Channel used for tracing and warnings.</param>
private void Run(IChannel ch)
{
    ILegacyDataLoader loader = null;
    IPredictor rawPred = null;
    IDataView view;
    RoleMappedSchema trainSchema = null;

    if (_model == null && _predictiveModel == null)
    {
        if (string.IsNullOrEmpty(ImplOptions.InputModelFile))
        {
            loader = CreateLoader();
            rawPred = null;
            trainSchema = null;
            Host.CheckUserArg(ImplOptions.LoadPredictor != true, nameof(ImplOptions.LoadPredictor),
                "Cannot be set to true unless " + nameof(ImplOptions.InputModelFile) + " is also specified.");
        }
        else
        {
            LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader);
        }

        view = loader;
    }
    else if (_model != null)
    {
        // Apply the in-memory transform model to an empty view to obtain its output schema.
        view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
    }
    else
    {
        // Predictor model: take the transform pipeline, predictor and training schema from it.
        view = _predictiveModel.TransformModel.Apply(Host,
            new EmptyDataView(Host, _predictiveModel.TransformModel.InputSchema));
        rawPred = _predictiveModel.Predictor;
        trainSchema = _predictiveModel.GetTrainingSchema(Host);
    }

    // Create the ONNX context for storing global information.
    var assembly = System.Reflection.Assembly.GetExecutingAssembly();
    var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
    var ctx = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion,
        ModelVersion, _domain, ImplOptions.OnnxVersion);

    // Get the transform chain.
    IDataView source;
    IDataView end;
    LinkedList<ITransformCanSaveOnnx> transforms;
    GetPipe(ctx, ch, view, out source, out end, out transforms);
    Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);

    // If we have a predictor, try to get the scorer for it.
    if (rawPred != null)
    {
        RoleMappedData data;
        if (trainSchema != null)
        {
            data = new RoleMappedData(end, trainSchema.GetColumnRoleNames());
        }
        else
        {
            // We had a predictor, but no roles stored in the model. Just suppose
            // default column names are OK, if present.
            data = new RoleMappedData(end, DefaultColumnNames.Label, DefaultColumnNames.Features,
                DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);
        }

        var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
        var scoreOnnx = scorePipe as ITransformCanSaveOnnx;
        if (scoreOnnx?.CanSaveOnnx(ctx) == true)
        {
            Host.Assert(scorePipe.Source == end);
            end = scorePipe;
            transforms.AddLast(scoreOnnx);

            if (rawPred.PredictionKind == PredictionKind.BinaryClassification
                || rawPred.PredictionKind == PredictionKind.MulticlassClassification)
            {
                // Check if the PredictedLabel column is a KeyDataViewType and has KeyValue annotations.
                // If it does, add a KeyToValueMappingTransformer, to enable NimbusML to get the values back
                // when using an ONNX model, as described in https://github.com/dotnet/machinelearning/pull/4841
                var predictedLabelColumn = scorePipe.Schema.GetColumnOrNull(DefaultColumnNames.PredictedLabel);
                if (predictedLabelColumn.HasValue && HasKeyValues(predictedLabelColumn.Value))
                {
                    var outputData = new KeyToValueMappingTransformer(Host, DefaultColumnNames.PredictedLabel).Transform(scorePipe);
                    end = outputData;
                    // Direct cast: the chain must only contain ONNX-savable transforms. A failed
                    // conversion should fail here rather than append null to the chain.
                    transforms.AddLast((ITransformCanSaveOnnx)outputData);
                }
            }
        }
        else
        {
            Contracts.CheckUserArg(_loadPredictor != true, nameof(ImplOptions.LoadPredictor),
                "We were explicitly told to load the predictor but we do not know how to save it as ONNX.");
            ch.Warning("We do not know how to save the predictor as ONNX. Ignoring.");
        }
    }
    else
    {
        Contracts.CheckUserArg(_loadPredictor != true, nameof(ImplOptions.LoadPredictor),
            "We were explicitly told to load the predictor but one was not present.");
    }

    // Convert back to values the KeyDataViewType "pass-through" columns
    // (i.e. those that remained untouched by the model). This is done to enable NimbusML to get
    // these values, as described in https://github.com/dotnet/machinelearning/pull/4841
    var passThroughColumnNames = GetPassThroughKeyDataViewTypeColumnsNames(source, end);
    foreach (var name in passThroughColumnNames)
    {
        var outputData = new KeyToValueMappingTransformer(Host, name).Transform(end);
        end = outputData;
        // Direct cast for the same reason as above: never append a null transform.
        transforms.AddLast((ITransformCanSaveOnnx)outputData);
    }

    var model = ConvertTransformListToOnnxModel(ctx, ch, source, end, transforms, _inputsToDrop, _outputsToDrop);

    using (var file = Host.CreateOutputFile(_outputModelPath))
    using (var stream = file.CreateWriteStream())
    {
        model.WriteTo(stream);
    }

    if (_outputJsonModelPath != null)
    {
        using (var file = Host.CreateOutputFile(_outputJsonModelPath))
        using (var stream = file.CreateWriteStream())
        using (var writer = new StreamWriter(stream))
        {
            // Round-trip through JSON.NET to pretty-print the model's textual representation.
            var parsedJson = JsonConvert.DeserializeObject(model.ToString());
            writer.Write(JsonConvert.SerializeObject(parsedJson, Formatting.Indented));
        }
    }

    if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
    {
        // NOTE(review): `loader` is only assigned on the loader/input-file path; this guard is debug-only.
        Contracts.Assert(loader != null);

        ch.Trace("Saving the data pipe");
        // Should probably include "end"?
        SaveLoader(loader, ImplOptions.OutputModelFile);
    }
}