public void TestI_Q_KMeansInnerAPIWithDataFrame()
{
    // End-to-end check: train the "km" trainer on two iris columns read through a
    // DataFrame, save the model, reload it, score, and verify the scored shape.
    var methodName = System.Reflection.MethodBase.GetCurrentMethod().Name;
    var outModelFilePath = FileHelper.GetOutputFile("outModelFilePath.zip", methodName);
    var iris = FileHelper.GetTestFile("iris.txt");

    using (var env = new ConsoleEnvironment(conc: 1))
    {
        ComponentHelper.AddStandardComponents(env);

        var df = Scikit.ML.DataManipulation.DataFrameIO.ReadCsv(iris, sep: '\t', dtypes: new ColumnType[] { NumberType.R4 });
        var conc = env.CreateTransform("Concat{col=Feature:Sepal_length,Sepal_width}", df);
        var roleMap = env.CreateExamples(conc, "Feature", label: "Label");
        var trainer = CreateTrainer(env, "km");

        IPredictor model;
        using (var ch = env.Start("test"))
        {
            model = TrainUtils.Train(env, ch, roleMap, trainer, null, 0);
        }

        using (var ch = env.Start("Save"))
        using (var fs = File.Create(outModelFilePath))
        {
            TrainUtils.SaveModel(env, ch, fs, model, roleMap);
        }

        Predictor pred;
        using (var fs = File.OpenRead(outModelFilePath))
        {
            pred = env.LoadPredictorOrNull(fs);
        }

        // The loader's name says it may return null; fail with a clear assertion
        // instead of a NullReferenceException on the next line.
        Assert.IsNotNull(pred);

#pragma warning disable CS0618
        var loadedPredictor = pred.GetPredictorObject() as IPredictor;
        // The 'as' cast yields null when the loaded object is not an IPredictor.
        Assert.IsNotNull(loadedPredictor);
        var scorer = ScoreUtils.GetScorer(loadedPredictor, roleMap, env, null);
#pragma warning restore CS0618

        var dfout = Scikit.ML.DataManipulation.DataFrameIO.ReadView(scorer);

        // Expected value goes first so a failure message reports expected/actual correctly.
        Assert.AreEqual(new Tuple<int, int>(150, 13), dfout.Shape);
    }
}
/// <summary>
/// Trains a model by forwarding every argument to <c>TrainUtils.Train</c>,
/// together with this instance's <c>Trainer</c> and <c>LoadName</c>.
/// </summary>
/// <param name="env">host</param>
/// <param name="ch">channel</param>
/// <param name="data">training data</param>
/// <param name="validData">validation data</param>
/// <param name="calibrator">calibrator</param>
/// <param name="maxCalibrationExamples">number of examples used to calibrate</param>
/// <param name="cacheData">cache training data</param>
/// <param name="inpPredictor">for continuous training, initial state</param>
/// <returns>predictor</returns>
public IPredictor Train(
    IHostEnvironment env,
    IChannel ch,
    RoleMappedData data,
    RoleMappedData validData = null,
    SubComponent<ICalibratorTrainer, SignatureCalibrator> calibrator = null,
    int maxCalibrationExamples = 0,
    bool? cacheData = null,
    IPredictor inpPredictor = null)
{
    var trained = TrainUtils.Train(
        env, ch, data, Trainer, LoadName,
        validData, calibrator, maxCalibrationExamples, cacheData, inpPredictor);
    return trained;
}
/// <summary>
/// Trains a predictor on <paramref name="input"/> with <paramref name="trainer"/> and
/// returns the scorer that <c>ScoreUtils.GetScorer</c> builds for it over the same input.
/// </summary>
/// <param name="env">Environment; a host named "TrainAndScoreTransform" is registered on it.</param>
/// <param name="args">Supplies custom columns, calibrator, calibration example count and scorer settings.</param>
/// <param name="trainer">Trainer handed to <c>TrainUtils.Train</c>.</param>
/// <param name="input">Data view used for both training and scoring.</param>
/// <param name="mapperFactory">Passed through to <c>ScoreUtils.GetScorer</c>.</param>
/// <returns>The scorer transform.</returns>
private static IDataTransform Create(
    IHostEnvironment env,
    Arguments args,
    ITrainer trainer,
    IDataView input,
    IComponentFactory<IPredictor, ISchemaBindableMapper> mapperFactory)
{
    Contracts.AssertValue(env, nameof(env));
    env.AssertValue(args, nameof(args));
    env.AssertValue(trainer, nameof(trainer));
    env.AssertValue(input, nameof(input));

    var host = env.Register("TrainAndScoreTransform");
    IDataTransform scorer;

    using (var ch = host.Start("Train"))
    {
        ch.Trace("Constructing trainer");

        var customColumns = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);

        // CreateDataFromArgs reports the feature and group column names it used.
        string featureColumn;
        string groupColumn;
        var trainingData = CreateDataFromArgs(ch, input, args, out featureColumn, out groupColumn);

        var trained = TrainUtils.Train(
            host, ch, trainingData, trainer, null,
            args.Calibrator, args.MaxCalibrationExamples, null);

        scorer = ScoreUtils.GetScorer(
            args.Scorer, trained, input, featureColumn, groupColumn,
            customColumns, env, trainingData.Schema, mapperFactory);
    }

    return scorer;
}
/// <summary>
/// Trains the trainer created from <c>args.Filter</c> on <paramref name="input"/> and
/// writes the resulting predictor's feature weights into <paramref name="scores"/>.
/// </summary>
/// <param name="host">Host used to create the trainer and the "Train" channel.</param>
/// <param name="input">Training data view.</param>
/// <param name="args">Column names, normalization, caching and filter settings; exactly one of
/// <c>Threshold</c> and <c>NumSlotsToKeep</c> is expected to be set.</param>
/// <param name="scores">Receives the feature weights from the trained predictor.</param>
private static void TrainCore(IHost host, IDataView input, Arguments args, ref VBuffer<Single> scores)
{
    Contracts.AssertValue(host);
    host.AssertValue(args);
    host.AssertValue(input);
    host.Assert(args.Threshold.HasValue != args.NumSlotsToKeep.HasValue);

    using (var ch = host.Start("Train"))
    {
        ch.Trace("Constructing trainer");
        ITrainer trainer = args.Filter.CreateComponent(host);

        IDataView view = input;
        ISchema schema = view.Schema;

        // Resolve each role's column from the argument value, with the default column name as the alternative.
        var label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(args.LabelColumn), args.LabelColumn, DefaultColumnNames.Label);
        var feature = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(args.FeatureColumn), args.FeatureColumn, DefaultColumnNames.Features);
        var group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(args.GroupColumn), args.GroupColumn, DefaultColumnNames.GroupId);
        var weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(args.WeightColumn), args.WeightColumn, DefaultColumnNames.Weight);
        var name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(args.NameColumn), args.NameColumn, DefaultColumnNames.Name);

        // May replace 'view' with a normalized view (passed by ref).
        TrainUtils.AddNormalizerIfNeeded(host, ch, trainer, ref view, feature, args.NormalizeFeatures);

        ch.Trace("Binding columns");
        var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, args.CustomColumn);
        var data = new RoleMappedData(view, label, feature, group, weight, name, customCols);

        var predictor = TrainUtils.Train(host, ch, data, trainer, null, null, 0, args.CacheData);

        // The 'as' cast yields null when the predictor does not expose feature weights.
        // Fail with an explicit, descriptive exception rather than relying on an
        // assertion (presumably debug-only - confirm) followed by a null dereference.
        var rfs = predictor as IPredictorWithFeatureWeights<Single>;
        if (rfs == null)
        {
            throw ch.Except("Training did not produce a predictor that provides feature weights.");
        }

        rfs.GetFeatureWeights(ref scores);
    }
}
/// <summary>
/// Trains a predictor on <c>input.TrainingData</c> and returns it wrapped in a new
/// <typeparamref name="TOut"/> whose <c>PredictorModel</c> is set.
/// </summary>
/// <typeparam name="TArg">Learner input type.</typeparam>
/// <typeparam name="TOut">Trainer output type to construct.</typeparam>
/// <param name="host">Host used for the "Training" channel, normalization, caching and training.</param>
/// <param name="input">Training data plus feature column, normalization and caching options.</param>
/// <param name="createTrainer">Factory for the trainer.</param>
/// <param name="getLabel">Optional provider of the label column name.</param>
/// <param name="getWeight">Optional provider of the weight column name.</param>
/// <param name="getGroup">Optional provider of the group column name.</param>
/// <param name="getName">Optional provider of the name column name.</param>
/// <param name="getCustom">Optional provider of custom column roles.</param>
/// <param name="calibrator">Calibrator passed to <c>TrainUtils.Train</c>.</param>
/// <param name="maxCalibrationExamples">Calibration example count passed to <c>TrainUtils.Train</c>.</param>
/// <returns>A new <typeparamref name="TOut"/> holding the trained predictor model.</returns>
public static TOut Train<TArg, TOut>(
    IHost host,
    TArg input,
    Func<ITrainer> createTrainer,
    Func<string> getLabel = null,
    Func<string> getWeight = null,
    Func<string> getGroup = null,
    Func<string> getName = null,
    Func<IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>>> getCustom = null,
    ICalibratorTrainerFactory calibrator = null,
    int maxCalibrationExamples = 0)
    where TArg : LearnerInputBase
    where TOut : CommonOutputs.TrainerOutput, new()
{
    using (var ch = host.Start("Training"))
    {
        var trainingSchema = input.TrainingData.Schema;
        var featureColumn = FindColumn(ch, trainingSchema, input.FeatureColumn);

        // Optional column providers; each yields null when no provider was supplied.
        var labelColumn = getLabel?.Invoke();
        var weightColumn = getWeight?.Invoke();
        var groupColumn = getGroup?.Invoke();
        var nameColumn = getName?.Invoke();
        var customColumns = getCustom?.Invoke();

        var trainer = createTrainer();

        // The view is passed by ref, so AddNormalizerIfNeeded may replace it.
        IDataView trainingView = input.TrainingData;
        TrainUtils.AddNormalizerIfNeeded(host, ch, trainer, ref trainingView, featureColumn, input.NormalizeFeatures);

        ch.Trace("Binding columns");
        var roleMappedData = new RoleMappedData(
            trainingView, labelColumn, featureColumn, groupColumn, weightColumn, nameColumn, customColumns);

        // Translate the requested caching option into a concrete caching type (null = no caching).
        Cache.CachingType? cachingType = null;
        switch (input.Caching)
        {
            case CachingOptions.Memory:
                cachingType = Cache.CachingType.Memory;
                break;

            case CachingOptions.Disk:
                cachingType = Cache.CachingType.Disk;
                break;

            case CachingOptions.Auto:
                // REVIEW: we should switch to hybrid caching in future.
                bool loadedFromBinary = input.TrainingData is BinaryLoader;
                if (!loadedFromBinary && trainer.Info.WantCaching)
                {
                    // default to Memory so mml is on par with maml
                    cachingType = Cache.CachingType.Memory;
                }
                break;

            case CachingOptions.None:
                break;

            default:
                throw ch.ExceptParam(nameof(input.Caching), "Unknown option for caching: '{0}'", input.Caching);
        }

        // Train on the cached view when caching was selected, otherwise on the bound data directly.
        RoleMappedData dataForTraining;
        if (cachingType.HasValue)
        {
            var cacheInput = new Cache.CacheInput()
            {
                Data = roleMappedData.Data,
                Caching = cachingType.Value
            };
            var cachedView = Cache.CacheData(host, cacheInput).OutputData;
            dataForTraining = new RoleMappedData(cachedView, roleMappedData.Schema.GetColumnRoleNames());
        }
        else
        {
            dataForTraining = roleMappedData;
        }

        var trainedPredictor = TrainUtils.Train(host, ch, dataForTraining, trainer, calibrator, maxCalibrationExamples);

        // The model is built from the uncached role-mapped data and the original training data.
        var output = new TOut()
        {
            PredictorModel = new PredictorModelImpl(host, roleMappedData, input.TrainingData, trainedPredictor)
        };
        return output;
    }
}