/// <summary>
/// Creates the trainer described by <paramref name="args"/>, trains it on <paramref name="input"/>
/// and asks the resulting predictor to write its feature weights into <paramref name="scores"/>.
/// </summary>
/// <param name="host">Host used for assertions, channel creation and component construction.</param>
/// <param name="input">Data view the trainer is run on.</param>
/// <param name="args">Arguments supplying the trainer factory, column names and training options.</param>
/// <param name="scores">Buffer passed by reference to the predictor's <c>GetFeatureWeights</c>.</param>
private static void TrainCore(IHost host, IDataView input, Arguments args, ref VBuffer<Single> scores)
{
    Contracts.AssertValue(host);
    host.AssertValue(args);
    host.AssertValue(input);
    // Exactly one of Threshold / NumSlotsToKeep is expected to be specified.
    host.Assert(args.Threshold.HasValue != args.NumSlotsToKeep.HasValue);

    using (var ch = host.Start("Train"))
    {
        ch.Trace("Constructing trainer");
        ITrainer trainer = args.Filter.CreateComponent(host);

        IDataView trainingView = input;
        ISchema inputSchema = trainingView.Schema;

        // Resolve each role column: the name given in args, or the default name for that role.
        var labelCol = TrainUtils.MatchNameOrDefaultOrNull(
            ch, inputSchema, nameof(args.LabelColumn), args.LabelColumn, DefaultColumnNames.Label);
        var featureCol = TrainUtils.MatchNameOrDefaultOrNull(
            ch, inputSchema, nameof(args.FeatureColumn), args.FeatureColumn, DefaultColumnNames.Features);
        var groupCol = TrainUtils.MatchNameOrDefaultOrNull(
            ch, inputSchema, nameof(args.GroupColumn), args.GroupColumn, DefaultColumnNames.GroupId);
        var weightCol = TrainUtils.MatchNameOrDefaultOrNull(
            ch, inputSchema, nameof(args.WeightColumn), args.WeightColumn, DefaultColumnNames.Weight);
        var nameCol = TrainUtils.MatchNameOrDefaultOrNull(
            ch, inputSchema, nameof(args.NameColumn), args.NameColumn, DefaultColumnNames.Name);

        // May replace trainingView (it is passed by ref).
        TrainUtils.AddNormalizerIfNeeded(host, ch, trainer, ref trainingView, featureCol, args.NormalizeFeatures);

        ch.Trace("Binding columns");
        var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, args.CustomColumn);
        var examples = new RoleMappedData(
            trainingView, labelCol, featureCol, groupCol, weightCol, nameCol, customCols);

        var trained = TrainUtils.Train(host, ch, examples, trainer, null, null, 0, args.CacheData);

        // The trained predictor is expected to expose per-feature weights.
        var weightProvider = trained as IPredictorWithFeatureWeights<Single>;
        Contracts.AssertValue(weightProvider);
        weightProvider.GetFeatureWeights(ref scores);
    }
}
/// <summary>
/// Trains a predictor on <c>input.TrainingData</c> using the trainer produced by
/// <paramref name="createTrainer"/> and wraps the result in a new <typeparamref name="TOut"/>.
/// </summary>
/// <typeparam name="TArg">Learner input type; supplies training data, feature column, normalization and caching options.</typeparam>
/// <typeparam name="TOut">Trainer output type whose <c>PredictorModel</c> is populated.</typeparam>
/// <param name="host">Host used to open the channel and passed to the helpers invoked here.</param>
/// <param name="input">Learner input describing the training run.</param>
/// <param name="createTrainer">Factory invoked once to obtain the trainer.</param>
/// <param name="getLabel">Optional delegate returning the label column name; when null, null is used.</param>
/// <param name="getWeight">Optional delegate returning the weight column name; when null, null is used.</param>
/// <param name="getGroup">Optional delegate returning the group column name; when null, null is used.</param>
/// <param name="getName">Optional delegate returning the name column name; when null, null is used.</param>
/// <param name="getCustom">Optional delegate returning custom role/column pairs; when null, null is used.</param>
/// <param name="calibrator">Calibrator factory forwarded to <c>TrainUtils.Train</c>.</param>
/// <param name="maxCalibrationExamples">Forwarded to <c>TrainUtils.Train</c>.</param>
/// <returns>A new <typeparamref name="TOut"/> with <c>PredictorModel</c> set.</returns>
public static TOut Train<TArg, TOut>(
    IHost host,
    TArg input,
    Func<ITrainer> createTrainer,
    Func<string> getLabel = null,
    Func<string> getWeight = null,
    Func<string> getGroup = null,
    Func<string> getName = null,
    Func<IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>>> getCustom = null,
    ICalibratorTrainerFactory calibrator = null,
    int maxCalibrationExamples = 0)
    where TArg : LearnerInputBase
    where TOut : CommonOutputs.TrainerOutput, new()
{
    using (var ch = host.Start("Training"))
    {
        var trainingSchema = input.TrainingData.Schema;
        var featureCol = FindColumn(ch, trainingSchema, input.FeatureColumn);

        // Optional role columns: a missing delegate yields a null column.
        string labelCol = getLabel != null ? getLabel() : null;
        string weightCol = getWeight != null ? getWeight() : null;
        string groupCol = getGroup != null ? getGroup() : null;
        string nameCol = getName != null ? getName() : null;
        IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> customCols =
            getCustom != null ? getCustom() : null;

        var trainer = createTrainer();

        // May replace trainingView (it is passed by ref).
        IDataView trainingView = input.TrainingData;
        TrainUtils.AddNormalizerIfNeeded(host, ch, trainer, ref trainingView, featureCol, input.NormalizeFeatures);

        ch.Trace("Binding columns");
        var roleMappedData = new RoleMappedData(
            trainingView, labelCol, featureCol, groupCol, weightCol, nameCol, customCols);

        // Translate the requested caching option; null means "do not cache".
        Cache.CachingType? cachingType = null;
        switch (input.Caching)
        {
            case CachingOptions.None:
                break;

            case CachingOptions.Memory:
                cachingType = Cache.CachingType.Memory;
                break;

            case CachingOptions.Disk:
                cachingType = Cache.CachingType.Disk;
                break;

            case CachingOptions.Auto:
                // REVIEW: we should switch to hybrid caching in future.
                if (!(input.TrainingData is BinaryLoader) && trainer.Info.WantCaching)
                {
                    // Default to Memory so mml is on par with maml.
                    cachingType = Cache.CachingType.Memory;
                }
                break;

            default:
                throw ch.ExceptParam(nameof(input.Caching), "Unknown option for caching: '{0}'", input.Caching);
        }

        // Train on the cached view when caching was selected, otherwise on the data as bound above.
        RoleMappedData dataForTraining = roleMappedData;
        if (cachingType.HasValue)
        {
            var cacheInput = new Cache.CacheInput()
            {
                Data = roleMappedData.Data,
                Caching = cachingType.Value
            };
            var cachedView = Cache.CacheData(host, cacheInput).OutputData;
            dataForTraining = new RoleMappedData(cachedView, roleMappedData.Schema.GetColumnRoleNames());
        }

        var predictor = TrainUtils.Train(host, ch, dataForTraining, trainer, calibrator, maxCalibrationExamples);

        // The model is built from the non-cached role-mapped data and the original training data.
        var output = new TOut();
        output.PredictorModel = new PredictorModelImpl(host, roleMappedData, input.TrainingData, predictor);
        return output;
    }
}