Example #1
0
        /// <summary>
        /// End-to-end check of the inner training API on top of a DataFrame: reads the iris file,
        /// concatenates two columns into a "Feature" column, trains the "km" trainer, saves the model
        /// to a zip file, reloads it, scores the data and verifies the shape of the scored output.
        /// </summary>
        public void TestI_Q_KMeansInnerAPIWithDataFrame()
        {
            var methodName       = System.Reflection.MethodBase.GetCurrentMethod().Name;
            var outModelFilePath = FileHelper.GetOutputFile("outModelFilePath.zip", methodName);
            var iris             = FileHelper.GetTestFile("iris.txt");

            using (var env = new ConsoleEnvironment(conc: 1))
            {
                ComponentHelper.AddStandardComponents(env);

                // Load the tab-separated file and build the feature vector from two columns.
                var df      = Scikit.ML.DataManipulation.DataFrameIO.ReadCsv(iris, sep: '\t', dtypes: new ColumnType[] { NumberType.R4 });
                var conc    = env.CreateTransform("Concat{col=Feature:Sepal_length,Sepal_width}", df);
                var roleMap = env.CreateExamples(conc, "Feature", label: "Label");
                var trainer = CreateTrainer(env, "km");

                IPredictor model;
                using (var ch = env.Start("test"))
                {
                    model = TrainUtils.Train(env, ch, roleMap, trainer, null, 0);
                }

                // Round-trip the trained model through a file on disk.
                using (var ch = env.Start("Save"))
                {
                    using (var fs = File.Create(outModelFilePath))
                    {
                        TrainUtils.SaveModel(env, ch, fs, model, roleMap);
                    }
                }

                Predictor pred;
                using (var fs = File.OpenRead(outModelFilePath))
                {
                    pred = env.LoadPredictorOrNull(fs);
                }

                // CS0618 (use of an obsolete member) is suppressed on purpose for GetPredictorObject.
#pragma warning disable CS0618
                var scorer = ScoreUtils.GetScorer(pred.GetPredictorObject() as IPredictor, roleMap, env, null);
#pragma warning restore CS0618
                var dfout = Scikit.ML.DataManipulation.DataFrameIO.ReadView(scorer);

                // Expected value goes first so that a failure message reports expected/actual correctly.
                Assert.AreEqual(new Tuple<int, int>(150, 13), dfout.Shape);
            }
        }
 /// <summary>
 /// Trains a predictor by forwarding every argument to <c>TrainUtils.Train</c>, together with
 /// this instance's <c>Trainer</c> and <c>LoadName</c>.
 /// </summary>
 /// <param name="env">host environment</param>
 /// <param name="ch">channel</param>
 /// <param name="data">training data</param>
 /// <param name="validData">validation data (optional)</param>
 /// <param name="calibrator">calibrator (optional)</param>
 /// <param name="maxCalibrationExamples">number of examples used to calibrate</param>
 /// <param name="cacheData">whether to cache the training data; null leaves the choice to the callee</param>
 /// <param name="inpPredictor">initial state for continuous training (optional)</param>
 /// <returns>the trained predictor</returns>
 public IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, RoleMappedData validData = null,
                         SubComponent<ICalibratorTrainer, SignatureCalibrator> calibrator = null, int maxCalibrationExamples = 0,
                         bool? cacheData = null, IPredictor inpPredictor = null)
     => TrainUtils.Train(env, ch, data, Trainer, LoadName, validData, calibrator,
                         maxCalibrationExamples, cacheData, inpPredictor);
        /// <summary>
        /// Trains <paramref name="trainer"/> on <paramref name="input"/> and returns the scorer
        /// built from the resulting predictor.
        /// </summary>
        /// <param name="env">host environment; a child host named "TrainAndScoreTransform" is registered on it</param>
        /// <param name="args">transform arguments (custom columns, calibrator, scorer settings)</param>
        /// <param name="trainer">trainer to run</param>
        /// <param name="input">data view used both for training and as the scorer input</param>
        /// <param name="mapperFactory">factory forwarded to <c>ScoreUtils.GetScorer</c></param>
        /// <returns>the scoring transform</returns>
        private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrainer trainer, IDataView input, IComponentFactory<IPredictor, ISchemaBindableMapper> mapperFactory)
        {
            Contracts.AssertValue(env, nameof(env));
            env.AssertValue(args, nameof(args));
            env.AssertValue(trainer, nameof(trainer));
            env.AssertValue(input, nameof(input));

            var trainHost = env.Register("TrainAndScoreTransform");

            using (var channel = trainHost.Start("Train"))
            {
                channel.Trace("Constructing trainer");

                var customColumns = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);

                // Feature and group column names are reported back by CreateDataFromArgs.
                string featureColumn;
                string groupColumn;
                var trainingData = CreateDataFromArgs(channel, input, args, out featureColumn, out groupColumn);

                var trainedPredictor = TrainUtils.Train(
                    trainHost, channel, trainingData, trainer, null,
                    args.Calibrator, args.MaxCalibrationExamples, null);

                return ScoreUtils.GetScorer(
                    args.Scorer, trainedPredictor, input, featureColumn, groupColumn,
                    customColumns, env, trainingData.Schema, mapperFactory);
            }
        }
Example #4
0
        /// <summary>
        /// Trains the trainer created from <c>args.Filter</c> on <paramref name="input"/> and copies
        /// the feature weights of the resulting predictor into <paramref name="scores"/>.
        /// </summary>
        /// <param name="host">host used to create the trainer and to open the "Train" channel</param>
        /// <param name="input">training data</param>
        /// <param name="args">arguments; exactly one of <c>Threshold</c> and <c>NumSlotsToKeep</c> is asserted to be set</param>
        /// <param name="scores">buffer that receives the feature weights</param>
        private static void TrainCore(IHost host, IDataView input, Arguments args, ref VBuffer<Single> scores)
        {
            Contracts.AssertValue(host);
            host.AssertValue(args);
            host.AssertValue(input);
            host.Assert(args.Threshold.HasValue != args.NumSlotsToKeep.HasValue);

            using (var channel = host.Start("Train"))
            {
                channel.Trace("Constructing trainer");
                ITrainer filterTrainer = args.Filter.CreateComponent(host);

                IDataView trainingView = input;

                // Column names are resolved against the schema of the view before any normalizer is added.
                ISchema inputSchema = trainingView.Schema;
                var labelColumn   = TrainUtils.MatchNameOrDefaultOrNull(channel, inputSchema, nameof(args.LabelColumn), args.LabelColumn, DefaultColumnNames.Label);
                var featureColumn = TrainUtils.MatchNameOrDefaultOrNull(channel, inputSchema, nameof(args.FeatureColumn), args.FeatureColumn, DefaultColumnNames.Features);
                var groupColumn   = TrainUtils.MatchNameOrDefaultOrNull(channel, inputSchema, nameof(args.GroupColumn), args.GroupColumn, DefaultColumnNames.GroupId);
                var weightColumn  = TrainUtils.MatchNameOrDefaultOrNull(channel, inputSchema, nameof(args.WeightColumn), args.WeightColumn, DefaultColumnNames.Weight);
                var nameColumn    = TrainUtils.MatchNameOrDefaultOrNull(channel, inputSchema, nameof(args.NameColumn), args.NameColumn, DefaultColumnNames.Name);

                // May replace trainingView (it is passed by ref).
                TrainUtils.AddNormalizerIfNeeded(host, channel, filterTrainer, ref trainingView, featureColumn, args.NormalizeFeatures);

                channel.Trace("Binding columns");

                var customColumns = TrainUtils.CheckAndGenerateCustomColumns(channel, args.CustomColumn);
                var trainingData  = new RoleMappedData(
                    trainingView, labelColumn, featureColumn, groupColumn, weightColumn, nameColumn, customColumns);

                var trainedPredictor = TrainUtils.Train(
                    host, channel, trainingData, filterTrainer, null,
                    null, 0, args.CacheData);

                // The predictor is expected to expose feature weights; the cast result is only asserted.
                var weightedPredictor = trainedPredictor as IPredictorWithFeatureWeights<Single>;
                Contracts.AssertValue(weightedPredictor);
                weightedPredictor.GetFeatureWeights(ref scores);
            }
        }
Example #5
0
        /// <summary>
        /// Generic training entry point: resolves the column roles, creates the trainer, optionally
        /// normalizes and caches the training data, trains, and wraps the predictor in a
        /// <typeparamref name="TOut"/> whose <c>PredictorModel</c> is set.
        /// </summary>
        /// <typeparam name="TArg">learner input type</typeparam>
        /// <typeparam name="TOut">trainer output type</typeparam>
        /// <param name="host">host used for the channel, normalization, caching and training</param>
        /// <param name="input">learner input (training data, feature column, normalization and caching options)</param>
        /// <param name="createTrainer">factory producing the trainer</param>
        /// <param name="getLabel">optional provider of the label column name</param>
        /// <param name="getWeight">optional provider of the weight column name</param>
        /// <param name="getGroup">optional provider of the group column name</param>
        /// <param name="getName">optional provider of the name column name</param>
        /// <param name="getCustom">optional provider of custom role/column pairs</param>
        /// <param name="calibrator">optional calibrator factory forwarded to <c>TrainUtils.Train</c></param>
        /// <param name="maxCalibrationExamples">number of examples used to calibrate</param>
        /// <returns>a new <typeparamref name="TOut"/> holding the trained predictor model</returns>
        public static TOut Train<TArg, TOut>(IHost host, TArg input,
                                             Func<ITrainer> createTrainer,
                                             Func<string> getLabel  = null,
                                             Func<string> getWeight = null,
                                             Func<string> getGroup  = null,
                                             Func<string> getName   = null,
                                             Func<IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>>> getCustom = null,
                                             ICalibratorTrainerFactory calibrator = null,
                                             int maxCalibrationExamples           = 0)
            where TArg : LearnerInputBase
            where TOut : CommonOutputs.TrainerOutput, new()
        {
            using (var channel = host.Start("Training"))
            {
                var trainingSchema = input.TrainingData.Schema;
                var featureColumn  = FindColumn(channel, trainingSchema, input.FeatureColumn);
                var labelColumn    = getLabel?.Invoke();
                var weightColumn   = getWeight?.Invoke();
                var groupColumn    = getGroup?.Invoke();
                var nameColumn     = getName?.Invoke();
                var customColumns  = getCustom?.Invoke();

                var learner = createTrainer();

                // AddNormalizerIfNeeded may replace trainingView (it is passed by ref).
                IDataView trainingView = input.TrainingData;
                TrainUtils.AddNormalizerIfNeeded(host, channel, learner, ref trainingView, featureColumn, input.NormalizeFeatures);

                channel.Trace("Binding columns");
                var boundData = new RoleMappedData(
                    trainingView, labelColumn, featureColumn, groupColumn, weightColumn, nameColumn, customColumns);

                // Translate the requested caching option into a concrete caching type (null = no caching).
                Cache.CachingType? requestedCaching = null;
                switch (input.Caching)
                {
                case CachingOptions.Memory:
                    requestedCaching = Cache.CachingType.Memory;
                    break;

                case CachingOptions.Disk:
                    requestedCaching = Cache.CachingType.Disk;
                    break;

                case CachingOptions.Auto:
                    // REVIEW: we should switch to hybrid caching in future.
                    bool isBinaryLoader = input.TrainingData is BinaryLoader;
                    if (!isBinaryLoader && learner.Info.WantCaching)
                    {
                        // default to Memory so mml is on par with maml
                        requestedCaching = Cache.CachingType.Memory;
                    }
                    break;

                case CachingOptions.None:
                    break;

                default:
                    throw channel.ExceptParam(nameof(input.Caching), "Unknown option for caching: '{0}'", input.Caching);
                }

                // Train on the cached view when caching was requested, otherwise on the bound data itself.
                RoleMappedData dataForTraining = boundData;
                if (requestedCaching.HasValue)
                {
                    var cacheInput = new Cache.CacheInput();
                    cacheInput.Data    = boundData.Data;
                    cacheInput.Caching = requestedCaching.Value;

                    var cachedView = Cache.CacheData(host, cacheInput).OutputData;
                    dataForTraining = new RoleMappedData(cachedView, boundData.Schema.GetColumnRoleNames());
                }

                var trainedPredictor = TrainUtils.Train(host, channel, dataForTraining, learner, calibrator, maxCalibrationExamples);

                // The model is built from the uncached role-mapped data, as in the original code.
                var output = new TOut();
                output.PredictorModel = new PredictorModelImpl(host, boundData, input.TrainingData, trainedPredictor);
                return output;
            }
        }