/// <summary>
/// Trains <paramref name="trainer"/> on <paramref name="data"/>, optionally using validation/test
/// data and an input predictor for incremental training, then calibrates the result if a
/// calibrator factory was supplied.
/// </summary>
private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer,
    RoleMappedData validData, IComponentFactory<ICalibratorTrainer> calibrator, int maxCalibrationExamples,
    bool? cacheData, IPredictor inputPredictor = null, RoleMappedData testData = null)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(data, nameof(data));
    ch.CheckValue(trainer, nameof(trainer));
    ch.CheckValueOrNull(validData);
    ch.CheckValueOrNull(inputPredictor);

    AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
    ch.Trace("Training");
    if (validData != null)
        AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);

    // Incremental training is only honored when the trainer advertises support for it.
    if (inputPredictor != null && !trainer.Info.SupportsIncrementalTraining)
    {
        ch.Warning("Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) + ": Trainer does not support incremental training.");
        inputPredictor = null;
    }

    ch.Assert(validData == null || trainer.Info.SupportsValidation);
    var trainedModel = trainer.Train(new TrainContext(data, validData, testData, inputPredictor));
    var calibratorTrainer = calibrator?.CreateComponent(env);
    return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibratorTrainer, maxCalibrationExamples, trainer, trainedModel, data);
}
/// <summary>
/// Save the model to the output path.
/// The method saves the loader and the transformations of dataPipe and saves optionally predictor
/// and command. It also uses featureColumn, if provided, to extract feature names.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="ch">The communication channel to use.</param>
/// <param name="output">The output file handle.</param>
/// <param name="predictor">The predictor.</param>
/// <param name="data">The training examples.</param>
/// <param name="command">The command string.</param>
public static void SaveModel(IHostEnvironment env, IChannel ch, IFileHandle output, IPredictor predictor, RoleMappedData data, string command = null)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckParam(output != null && output.CanWrite, nameof(output));
    ch.CheckValueOrNull(predictor);
    ch.CheckValue(data, nameof(data));
    ch.CheckValueOrNull(command);

    // Delegate to the Stream-based overload over a freshly opened write stream.
    using (var writeStream = output.CreateWriteStream())
    {
        SaveModel(env, ch, writeStream, predictor, data, command);
    }
}
/// <summary>
/// Save the model in INI format (if it can save itself); otherwise fall back to the model
/// summary when available, or a not-supported notice.
/// </summary>
public static void SaveIni(IChannel ch, IPredictor predictor, RoleMappedSchema schema, TextWriter writer)
{
    Contracts.CheckValue(ch, nameof(ch));
    ch.CheckValue(predictor, nameof(predictor));
    ch.CheckValueOrNull(schema);
    ch.CheckValue(writer, nameof(writer));

    if (predictor is ICanSaveInIniFormat iniSaver)
    {
        iniSaver.SaveAsIni(writer, schema);
    }
    else if (predictor is ICanSaveSummary summarySaver)
    {
        // No INI support, but the predictor can at least emit a summary.
        writer.WriteLine("'{0}' does not support saving in INI format, writing out model summary instead", predictor.GetType().Name);
        ch.Error("'{0}' doesn't currently have standardized INI format output, will save model summary instead", predictor.GetType().Name);
        summarySaver.SaveSummary(writer, schema);
    }
    else
    {
        writer.WriteLine("'{0}' does not support saving in INI format", predictor.GetType().Name);
        ch.Error("'{0}' doesn't currently have standardized INI format output", predictor.GetType().Name);
    }
}
/// <summary>
/// Save the model to the stream.
/// The method saves the loader and the transformations of dataPipe and saves optionally predictor
/// and command. It also uses featureColumn, if provided, to extract feature names.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="ch">The communication channel to use.</param>
/// <param name="outputStream">The output model stream.</param>
/// <param name="predictor">The predictor.</param>
/// <param name="data">The training examples.</param>
/// <param name="command">The command string.</param>
public static void SaveModel(IHostEnvironment env, IChannel ch, Stream outputStream, IPredictor predictor, RoleMappedData data, string command = null)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(outputStream, nameof(outputStream));
    ch.CheckValueOrNull(predictor);
    ch.CheckValue(data, nameof(data));
    ch.CheckValueOrNull(command);

    using (var saveCh = env.Start("SaveModel"))
    using (var pch = env.StartProgressChannel("Saving model"))
    {
        using (var rep = RepositoryWriter.CreateNew(outputStream, saveCh))
        {
            if (predictor != null)
            {
                saveCh.Trace("Saving predictor");
                ModelSaveContext.SaveModel(rep, predictor, ModelFileUtils.DirPredictor);
            }

            saveCh.Trace("Saving loader and transformations");
            var dataPipe = data.Data;
            if (!(dataPipe is IDataLoader))
                SaveDataPipe(env, rep, dataPipe);
            else
                ModelSaveContext.SaveModel(rep, dataPipe, ModelFileUtils.DirDataLoaderModel);

            // REVIEW: Handle statistics.
            // ModelSaveContext.SaveModel(rep, dataStats, DirDataStats);

            if (!string.IsNullOrWhiteSpace(command))
            {
                using (var ent = rep.CreateEntry(ModelFileUtils.DirTrainingInfo, "Command.txt"))
                using (var cmdWriter = Utils.OpenWriter(ent.Stream))
                    cmdWriter.WriteLine(command);
            }
            ModelFileUtils.SaveRoleMappings(env, ch, data.Schema, rep);

            rep.Commit();
        }
        // Signal completion only after the repository has been flushed and disposed.
        saveCh.Done();
    }
}
/// <summary>
/// Adds a MinMax normalization transform over <paramref name="featureColumn"/> when the trainer
/// needs normalization, the column is not already normalized, and <paramref name="autoNorm"/>
/// allows it. Returns true if a normalizer was added.
/// </summary>
public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITrainer trainer, ref IDataView view, string featureColumn, NormalizeOption autoNorm)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(trainer, nameof(trainer));
    ch.CheckValue(view, nameof(view));
    ch.CheckValueOrNull(featureColumn);
    ch.CheckUserArg(Enum.IsDefined(typeof(NormalizeOption), autoNorm), nameof(TrainCommand.Arguments.NormalizeFeatures), "Normalize option is invalid. Specify one of 'norm=No', 'norm=Warn', 'norm=Auto', or 'norm=Yes'.");

    if (autoNorm == NormalizeOption.No)
    {
        ch.Info("Not adding a normalizer.");
        return false;
    }

    // Without a feature column there is nothing to normalize.
    if (string.IsNullOrEmpty(featureColumn))
        return false;

    int featCol;
    var schema = view.Schema;
    if (schema.TryGetColumnIndex(featureColumn, out featCol))
    {
        if (autoNorm != NormalizeOption.Yes)
        {
            // Fix: removed the unused local `DvBool isNormalized = DvBool.False;` — this overload
            // queries schema.IsNormalized directly and never touched that variable.
            if (!trainer.Info.NeedNormalization || schema.IsNormalized(featCol))
            {
                ch.Info("Not adding a normalizer.");
                return false;
            }
            if (autoNorm == NormalizeOption.Warn)
            {
                ch.Warning("A normalizer is needed for this trainer. Either add a normalizing transform or use the 'norm=Auto', 'norm=Yes' or 'norm=No' options.");
                return false;
            }
        }
        ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.");

        // Local function applying the MinMax normalizer over the feature column.
        IDataView ApplyNormalizer(IHostEnvironment innerEnv, IDataView input)
            => NormalizeTransform.CreateMinMaxNormalizer(innerEnv, input, featureColumn);

        // Preserve the loader composition when possible so the transform is part of the model.
        if (view is IDataLoader loader)
        {
            view = CompositeDataLoader.ApplyTransform(env, loader, tag: null, creationArgs: null, ApplyNormalizer);
        }
        else
        {
            view = ApplyNormalizer(env, view);
        }
        return true;
    }
    return false;
}
/// <summary>
/// Trains <paramref name="trainer"/> on <paramref name="data"/> by reflectively invoking a generic
/// TrainCore helper, then calibrates the resulting predictor if <paramref name="calibrator"/> is given.
/// Legacy path for trainers implementing the generic ITrainer&lt;RoleMappedData&gt; interface.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="ch">Communication channel for tracing and errors.</param>
/// <param name="data">Training examples.</param>
/// <param name="trainer">The trainer; must implement ITrainer&lt;RoleMappedData&gt; or a user error is thrown.</param>
/// <param name="name">Trainer load name, used only for error reporting.</param>
/// <param name="validData">Optional validation examples.</param>
/// <param name="calibrator">Optional calibrator trainer.</param>
/// <param name="maxCalibrationExamples">Maximum number of examples used to calibrate.</param>
/// <param name="cacheData">Whether to cache the data; null lets AddCacheIfWanted decide.</param>
/// <param name="inpPredictor">Optional input predictor for incremental training.</param>
private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData, ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inpPredictor = null)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(data, nameof(data));
    ch.CheckValue(trainer, nameof(trainer));
    ch.CheckNonEmpty(name, nameof(name));
    ch.CheckValueOrNull(validData);
    ch.CheckValueOrNull(inpPredictor);

    // Only trainers with the strongly typed Train(RoleMappedData) entry point are supported here.
    var trainerRmd = trainer as ITrainer<RoleMappedData>;
    if (trainerRmd == null)
    {
        throw ch.ExceptUserArg(nameof(TrainCommand.Arguments.Trainer), "Trainer '{0}' does not accept known training data type", name);
    }

    // Capture a generic TrainCore overload as a delegate so its open generic definition can be
    // closed over the runtime types below.
    Action<IChannel, ITrainer, Action<object>, object, object, object> trainCoreAction = TrainCore;
    IPredictor predictor;
    AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
    ch.Trace("Training");
    if (validData != null)
    {
        AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);
    }

    // Close the generic method over the actual data and predictor types; falls back to
    // IPredictor when no input predictor is supplied.
    var genericExam = trainCoreAction.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(
        typeof(RoleMappedData),
        inpPredictor != null ? inpPredictor.GetType() : typeof(IPredictor));
    Action<RoleMappedData> trainExam = trainerRmd.Train;
    // Invoke the closed generic helper; it drives trainExam over data/validData/inpPredictor.
    genericExam.Invoke(null, new object[] { ch, trainerRmd, trainExam, data, validData, inpPredictor });

    ch.Trace("Constructing predictor");
    predictor = trainerRmd.CreatePredictor();
    return (CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data));
}
/// <summary>
/// Trains a model.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="ch">The communication channel.</param>
/// <param name="data">The training data.</param>
/// <param name="validData">Optional validation data.</param>
/// <param name="calibrator">Optional calibrator trainer.</param>
/// <param name="maxCalibrationExamples">Number of examples used to calibrate.</param>
/// <param name="cacheData">Whether to cache training data; null lets the trainer decide.</param>
/// <param name="inputPredictor">For continuous (incremental) training, the initial state.</param>
/// <returns>The trained (and possibly calibrated) predictor.</returns>
public IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, RoleMappedData validData = null, ICalibratorTrainer calibrator = null, int maxCalibrationExamples = 0, bool? cacheData = null, IPredictor inputPredictor = null)
{
    // Fix: removed a commented-out alternative implementation (dead code) that previously
    // shadowed this inlined version of TrainUtils.Train.
    var trainer = Trainer;
    var name = LoadName;
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(data, nameof(data));
    ch.CheckValue(trainer, nameof(trainer));
    ch.CheckNonEmpty(name, nameof(name));
    ch.CheckValueOrNull(validData);
    ch.CheckValueOrNull(inputPredictor);

    AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
    ch.Trace(MessageSensitivity.None, "Training");
    if (validData != null)
    {
        AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);
    }

    // Incremental training is only honored when the trainer supports it.
    if (inputPredictor != null && !trainer.Info.SupportsIncrementalTraining)
    {
        ch.Warning(MessageSensitivity.None, "Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) + ": Trainer does not support incremental training.");
        inputPredictor = null;
    }

    ch.Assert(validData == null || trainer.Info.SupportsValidation);
    var predictor = trainer.Train(new TrainContext(data, validData, null, inputPredictor));
    return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data);
}
/// <summary>
/// Save the model summary, when the predictor supports it; otherwise write and log a
/// not-supported notice.
/// </summary>
public static void SaveSummary(IChannel ch, IPredictor predictor, RoleMappedSchema schema, TextWriter writer)
{
    Contracts.CheckValue(ch, nameof(ch));
    ch.CheckValue(predictor, nameof(predictor));
    ch.CheckValueOrNull(schema);
    ch.CheckValue(writer, nameof(writer));

    if (predictor is ICanSaveSummary summarySaver)
    {
        summarySaver.SaveSummary(writer, schema);
        return;
    }
    writer.WriteLine("'{0}' does not support saving summary", predictor.GetType().Name);
    ch.Error("'{0}' does not support saving summary", predictor.GetType().Name);
}
/// <summary>
/// Save the model as source code (if it can save itself); otherwise write and log a
/// not-supported notice.
/// </summary>
public static void SaveCode(IChannel ch, IPredictor predictor, RoleMappedSchema schema, TextWriter writer)
{
    Contracts.CheckValue(ch, nameof(ch));
    ch.CheckValue(predictor, nameof(predictor));
    ch.CheckValueOrNull(schema);
    ch.CheckValue(writer, nameof(writer));

    if (predictor is ICanSaveInSourceCode codeSaver)
    {
        codeSaver.SaveAsCode(writer, schema);
        return;
    }
    writer.WriteLine("'{0}' does not support saving in code.", predictor.GetType().Name);
    ch.Error("'{0}' doesn't currently support saving the model as code", predictor.GetType().Name);
}
/// <summary>
/// Adds a MinMax normalization transform over <paramref name="featureColumn"/> when the trainer
/// (via ITrainerEx) needs normalization, the column is not already marked normalized, and
/// <paramref name="autoNorm"/> allows it. Returns true if a normalizer was added.
/// </summary>
public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITrainer trainer, ref IDataView view, string featureColumn, NormalizeOption autoNorm)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(trainer, nameof(trainer));
    ch.CheckValue(view, nameof(view));
    ch.CheckValueOrNull(featureColumn);
    ch.CheckUserArg(Enum.IsDefined(typeof(NormalizeOption), autoNorm), nameof(TrainCommand.Arguments.NormalizeFeatures), "Normalize option is invalid. Specify one of 'norm=No', 'norm=Warn', 'norm=Auto', or 'norm=Yes'.");

    if (autoNorm == NormalizeOption.No)
    {
        ch.Info("Not adding a normalizer.");
        return false;
    }

    // Without a feature column there is nothing to normalize.
    if (string.IsNullOrEmpty(featureColumn))
        return false;

    int featCol;
    var schema = view.Schema;
    if (schema.TryGetColumnIndex(featureColumn, out featCol))
    {
        if (autoNorm != NormalizeOption.Yes)
        {
            var nn = trainer as ITrainerEx;
            DvBool isNormalized = DvBool.False;
            if (nn == null || !nn.NeedNormalization || (schema.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, featCol, ref isNormalized) && isNormalized.IsTrue))
            {
                ch.Info("Not adding a normalizer.");
                return false;
            }
            if (autoNorm == NormalizeOption.Warn)
            {
                // Fix: this warning string literal was broken across two source lines with a raw
                // newline inside the quotes (invalid C#); rejoined to the single-line message used
                // by the sibling overload.
                ch.Warning("A normalizer is needed for this trainer. Either add a normalizing transform or use the 'norm=Auto', 'norm=Yes' or 'norm=No' options.");
                return false;
            }
        }
        ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.");

        // Quote the feature column name so it survives command-line round-tripping.
        string quotedFeatureColumnName = featureColumn;
        StringBuilder sb = new StringBuilder();
        if (CmdQuoter.QuoteValue(quotedFeatureColumnName, sb))
        {
            quotedFeatureColumnName = sb.ToString();
        }
        var component = new SubComponent<IDataTransform, SignatureDataTransform>("MinMax", string.Format("col={{ name={0} source={0} }}", quotedFeatureColumnName));

        // Preserve the loader composition when possible so the transform is part of the model.
        var loader = view as IDataLoader;
        if (loader != null)
        {
            view = CompositeDataLoader.Create(env, loader, new KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>(null, component));
        }
        else
        {
            view = component.CreateInstance(env, view);
        }
        return true;
    }
    return false;
}