/// <summary>
/// Creates a trainer, trains it on <c>input.TrainingData</c>, and returns a new
/// <typeparamref name="TOut"/> whose <c>PredictorModel</c> wraps the trained predictor.
/// </summary>
/// <remarks>
/// The work is done inside a "Training" channel opened on <paramref name="host"/>:
/// the feature column is looked up, the optional column-name callbacks are invoked,
/// the trainer is created, <c>TrainUtils.AddNormalizerIfNeeded</c> is applied to the view,
/// column roles are bound, and the data is cached when <c>input.Caching</c> selects a cache
/// type. Only the trainer receives the cached data; the predictor model is built from the
/// uncached role-mapped data. An unrecognized <c>input.Caching</c> value raises the exception
/// produced by the channel's <c>ExceptParam</c>.
/// </remarks>
/// <typeparam name="TArg">Input arguments type; must derive from <c>LearnerInputBase</c>.</typeparam>
/// <typeparam name="TOut">Output type; must derive from <c>CommonOutputs.TrainerOutput</c> and be default-constructible.</typeparam>
/// <param name="host">Host used to open the channel and passed on to normalization, caching, training and model construction.</param>
/// <param name="input">Supplies <c>TrainingData</c>, <c>FeatureColumn</c>, <c>NormalizeFeatures</c> and <c>Caching</c>.</param>
/// <param name="createTrainer">Factory invoked once to obtain the trainer.</param>
/// <param name="getLabel">Optional callback returning the label column name; when null, null is used.</param>
/// <param name="getWeight">Optional callback returning the weight column name; when null, null is used.</param>
/// <param name="getGroup">Optional callback returning the group column name; when null, null is used.</param>
/// <param name="getName">Optional callback returning the name column name; when null, null is used.</param>
/// <param name="getCustom">Optional callback returning additional role/column-name pairs; when null, null is used.</param>
/// <param name="calibrator">Forwarded unchanged to <c>TrainUtils.Train</c>.</param>
/// <param name="maxCalibrationExamples">Forwarded unchanged to <c>TrainUtils.Train</c>.</param>
/// <returns>A new <typeparamref name="TOut"/> with its <c>PredictorModel</c> set.</returns>
public static TOut Train<TArg, TOut>(
    IHost host,
    TArg input,
    Func<ITrainer> createTrainer,
    Func<string> getLabel = null,
    Func<string> getWeight = null,
    Func<string> getGroup = null,
    Func<string> getName = null,
    Func<IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>>> getCustom = null,
    ICalibratorTrainerFactory calibrator = null,
    int maxCalibrationExamples = 0)
    where TArg : LearnerInputBase
    where TOut : CommonOutputs.TrainerOutput, new()
{
    using (var ch = host.Start("Training"))
    {
        ISchema inputSchema = input.TrainingData.Schema;
        var featureColumn = FindColumn(ch, inputSchema, input.FeatureColumn);

        // Optional role callbacks: an absent callback leaves the corresponding role null.
        var labelColumn = getLabel?.Invoke();
        var weightColumn = getWeight?.Invoke();
        var groupColumn = getGroup?.Invoke();
        var nameColumn = getName?.Invoke();
        var customColumns = getCustom?.Invoke();

        var trainer = createTrainer();

        // AddNormalizerIfNeeded takes the view by ref, so it may replace it.
        IDataView trainingView = input.TrainingData;
        TrainUtils.AddNormalizerIfNeeded(host, ch, trainer, ref trainingView, featureColumn, input.NormalizeFeatures);

        ch.Trace("Binding columns");
        var boundData = TrainUtils.CreateExamples(trainingView, labelColumn, featureColumn, groupColumn, weightColumn, nameColumn, customColumns);

        // Decide which cache (if any) to put in front of the trainer; null means no caching.
        Cache.CachingType? cacheKind = null;
        switch (input.Caching)
        {
            case CachingOptions.Memory:
                cacheKind = Cache.CachingType.Memory;
                break;

            case CachingOptions.Disk:
                cacheKind = Cache.CachingType.Disk;
                break;

            case CachingOptions.Auto:
                // REVIEW: we should switch to hybrid caching in future.
                if (input.TrainingData is BinaryLoader)
                {
                    break;
                }

                ITrainerEx extendedTrainer = trainer as ITrainerEx;
                if (extendedTrainer == null || extendedTrainer.WantCaching)
                {
                    // default to Memory so mml is on par with maml
                    cacheKind = Cache.CachingType.Memory;
                }

                break;

            case CachingOptions.None:
                break;

            default:
                throw ch.ExceptParam(nameof(input.Caching), "Unknown option for caching: '{0}'", input.Caching);
        }

        // The trainer sees the cached data when a cache was selected, otherwise the bound data as-is.
        RoleMappedData dataForTraining = boundData;
        if (cacheKind != null)
        {
            var cacheInput = new Cache.CacheInput();
            cacheInput.Data = boundData.Data;
            cacheInput.Caching = cacheKind.Value;

            var cachedView = Cache.CacheData(host, cacheInput).OutputData;
            dataForTraining = RoleMappedData.Create(cachedView, boundData.Schema.GetColumnRoleNames());
        }

        var predictor = TrainUtils.Train(host, ch, dataForTraining, trainer, "Train", calibrator, maxCalibrationExamples);

        // The model is built from the uncached role-mapped data and the original input data.
        var result = new TOut();
        result.PredictorModel = new PredictorModel(host, boundData, input.TrainingData, predictor);

        ch.Done();
        return result;
    }
}
/// <summary>
/// Creates a trainer, trains it on <c>input.TrainingData</c>, and returns a new
/// <typeparamref name="TOut"/> whose <c>PredictorModel</c> wraps the trained predictor.
/// </summary>
/// <remarks>
/// The work is done inside a "Training" channel opened on <paramref name="host"/>:
/// the feature column is looked up, the optional column-name callbacks are invoked,
/// the trainer is created, <c>TrainUtils.AddNormalizerIfNeeded</c> is applied to the view,
/// column roles are bound, and the data is wrapped in a <c>CacheDataView</c> when
/// <c>input.Caching</c> asks for it. Only the trainer receives the cached data; the predictor
/// model is built from the uncached role-mapped data. An unrecognized <c>input.Caching</c>
/// value raises the exception produced by the channel's <c>ExceptParam</c>.
/// </remarks>
/// <typeparam name="TArg">Input arguments type; must derive from <c>TrainerInputBase</c>.</typeparam>
/// <typeparam name="TOut">Output type; must derive from <c>CommonOutputs.TrainerOutput</c> and be default-constructible.</typeparam>
/// <param name="host">Host used to open the channel and passed on to normalization, caching, training and model construction.</param>
/// <param name="input">Supplies <c>TrainingData</c>, <c>FeatureColumnName</c>, <c>NormalizeFeatures</c> and <c>Caching</c>.</param>
/// <param name="createTrainer">Factory invoked once to obtain the trainer.</param>
/// <param name="getLabel">Optional callback returning the label column name; when null, null is used.</param>
/// <param name="getWeight">Optional callback returning the weight column name; when null, null is used.</param>
/// <param name="getGroup">Optional callback returning the group column name; when null, null is used.</param>
/// <param name="getName">Optional callback returning the name column name; when null, null is used.</param>
/// <param name="getCustom">Optional callback returning additional role/column-name pairs; when null, null is used.</param>
/// <param name="calibrator">Forwarded unchanged to <c>TrainUtils.Train</c>.</param>
/// <param name="maxCalibrationExamples">Forwarded unchanged to <c>TrainUtils.Train</c>.</param>
/// <returns>A new <typeparamref name="TOut"/> with its <c>PredictorModel</c> set.</returns>
public static TOut Train<TArg, TOut>(
    IHost host,
    TArg input,
    Func<ITrainer> createTrainer,
    Func<string> getLabel = null,
    Func<string> getWeight = null,
    Func<string> getGroup = null,
    Func<string> getName = null,
    Func<IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>>> getCustom = null,
    ICalibratorTrainerFactory calibrator = null,
    int maxCalibrationExamples = 0)
    where TArg : TrainerInputBase
    where TOut : CommonOutputs.TrainerOutput, new()
{
    using (var ch = host.Start("Training"))
    {
        var featureColumn = FindColumn(ch, input.TrainingData.Schema, input.FeatureColumnName);

        // Optional role callbacks: an absent callback leaves the corresponding role null.
        var labelColumn = getLabel?.Invoke();
        var weightColumn = getWeight?.Invoke();
        var groupColumn = getGroup?.Invoke();
        var nameColumn = getName?.Invoke();
        var customColumns = getCustom?.Invoke();

        var trainer = createTrainer();

        // AddNormalizerIfNeeded takes the view by ref, so it may replace it.
        IDataView trainingView = input.TrainingData;
        TrainUtils.AddNormalizerIfNeeded(host, ch, trainer, ref trainingView, featureColumn, input.NormalizeFeatures);

        ch.Trace("Binding columns");
        var boundData = new RoleMappedData(trainingView, labelColumn, featureColumn, groupColumn, weightColumn, nameColumn, customColumns);

        // The returned child host is not used anywhere below (the cache is created on 'host').
        // The call is kept only because registration may affect host state.
        // NOTE(review): confirm whether it has side effects and remove it if it does not.
        host.Register("CreateCache");

        // Decide once whether the trainer should read from an in-memory cache.
        bool cacheInMemory;
        switch (input.Caching)
        {
            case CachingOptions.Memory:
                cacheInMemory = true;
                break;

            case CachingOptions.Auto:
                // REVIEW: we should switch to hybrid caching in future.
                // default to Memory so mml is on par with maml
                cacheInMemory = !(input.TrainingData is BinaryLoader) && trainer.Info.WantCaching;
                break;

            case CachingOptions.None:
                cacheInMemory = false;
                break;

            default:
                throw ch.ExceptParam(nameof(input.Caching), "Unknown option for caching: '{0}'", input.Caching);
        }

        // The trainer sees the cached data when caching was selected, otherwise the bound data as-is.
        RoleMappedData dataForTraining = boundData;
        if (cacheInMemory)
        {
            IDataView cachedView = new CacheDataView(host, boundData.Data, null);
            dataForTraining = new RoleMappedData(cachedView, boundData.Schema.GetColumnRoleNames());
        }

        var predictor = TrainUtils.Train(host, ch, dataForTraining, trainer, calibrator, maxCalibrationExamples);

        // The model is built from the uncached role-mapped data and the original input data.
        var result = new TOut();
        result.PredictorModel = new PredictorModelImpl(host, boundData, input.TrainingData, predictor);
        return result;
    }
}