示例#1
0
        /// <summary>
        /// Common entry-point training path. It does the following, in order:
        /// <list type="bullet">
        /// <item>Resolves the feature column and the optional role columns.</item>
        /// <item>Builds the trainer.</item>
        /// <item>Lets TrainUtils.AddNormalizerIfNeeded possibly insert a normalizer.</item>
        /// <item>Binds the column roles.</item>
        /// <item>Optionally caches the bound data according to <c>input.Caching</c>.</item>
        /// <item>Trains the predictor and wraps it in a new <typeparamref name="TOut"/>.</item>
        /// </list>
        /// </summary>
        /// <param name="host">Host that opens the "Training" channel and is passed to normalization, caching, training and model construction.</param>
        /// <param name="input">Entry-point arguments: training data, feature column name, normalization flag and caching option.</param>
        /// <param name="createTrainer">Factory invoked exactly once, after all column-name providers have been invoked.</param>
        /// <param name="getLabel">Optional label column name provider; when null, a null label name is passed on.</param>
        /// <param name="getWeight">Optional weight column name provider; when null, a null weight name is passed on.</param>
        /// <param name="getGroup">Optional group column name provider; when null, a null group name is passed on.</param>
        /// <param name="getName">Optional name column provider; when null, a null name column is passed on.</param>
        /// <param name="getCustom">Optional provider of additional (role, column name) bindings; when null, null is passed on.</param>
        /// <param name="calibrator">Optional calibrator trainer factory, forwarded unchanged to TrainUtils.Train.</param>
        /// <param name="maxCalibrationExamples">Forwarded unchanged to TrainUtils.Train.</param>
        /// <returns>
        /// A new <typeparamref name="TOut"/>. Its PredictorModel is built from the uncached role-mapped data and the
        /// original <c>input.TrainingData</c>, not from the cached view.
        /// </returns>
        /// <remarks>
        /// When <c>input.Caching</c> is not one of Memory, Disk, Auto or None, the exception produced by
        /// the channel's ExceptParam is thrown.
        /// </remarks>
        public static TOut Train<TArg, TOut>(IHost host, TArg input,
            Func<ITrainer> createTrainer,
            Func<string> getLabel = null,
            Func<string> getWeight = null,
            Func<string> getGroup = null,
            Func<string> getName = null,
            Func<IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>>> getCustom = null,
            ICalibratorTrainerFactory calibrator = null,
            int maxCalibrationExamples = 0)
            where TArg : LearnerInputBase
            where TOut : CommonOutputs.TrainerOutput, new()
        {
            using (var channel = host.Start("Training"))
            {
                var trainingData = input.TrainingData;

                // Keep the explicit ISchema type so FindColumn binds exactly as before.
                ISchema inputSchema = trainingData.Schema;
                var featureColumn = FindColumn(channel, inputSchema, input.FeatureColumn);

                // Each provider is optional; a missing provider yields a null column name / binding list.
                var labelColumn = getLabel == null ? null : getLabel();
                var weightColumn = getWeight == null ? null : getWeight();
                var groupColumn = getGroup == null ? null : getGroup();
                var nameColumn = getName == null ? null : getName();
                var customColumns = getCustom == null ? null : getCustom();

                var trainer = createTrainer();

                // The normalizer (if any) is applied on top of the raw training data.
                IDataView view = trainingData;
                TrainUtils.AddNormalizerIfNeeded(host, channel, trainer, ref view, featureColumn, input.NormalizeFeatures);

                channel.Trace("Binding columns");
                var examples = TrainUtils.CreateExamples(view, labelColumn, featureColumn, groupColumn, weightColumn, nameColumn, customColumns);

                // Map the requested caching option to a concrete cache kind; null means "train on the uncached data".
                Cache.CachingType? cachingType = null;
                var requested = input.Caching;
                if (requested == CachingOptions.Memory)
                {
                    cachingType = Cache.CachingType.Memory;
                }
                else if (requested == CachingOptions.Disk)
                {
                    cachingType = Cache.CachingType.Disk;
                }
                else if (requested == CachingOptions.Auto)
                {
                    // REVIEW: we should switch to hybrid caching in future.
                    var trainerEx = trainer as ITrainerEx;
                    bool dataIsBinary = trainingData is BinaryLoader;

                    // A trainer that does not expose ITrainerEx is assumed to want caching.
                    if (!dataIsBinary && (trainerEx == null || trainerEx.WantCaching))
                    {
                        // default to Memory so mml is on par with maml
                        cachingType = Cache.CachingType.Memory;
                    }
                }
                else if (requested != CachingOptions.None)
                {
                    throw channel.ExceptParam(nameof(input.Caching), "Unknown option for caching: '{0}'", input.Caching);
                }

                RoleMappedData trainingExamples = examples;
                if (cachingType.HasValue)
                {
                    var cacheArguments = new Cache.CacheInput()
                    {
                        Data = examples.Data,
                        Caching = cachingType.Value
                    };
                    var cachedView = Cache.CacheData(host, cacheArguments).OutputData;

                    // Re-bind the same column roles on top of the cached view.
                    trainingExamples = RoleMappedData.Create(cachedView, examples.Schema.GetColumnRoleNames());
                }

                var predictor = TrainUtils.Train(host, channel, trainingExamples, trainer, "Train", calibrator, maxCalibrationExamples);

                // The model is built from the uncached examples and the original training data.
                var result = new TOut()
                {
                    PredictorModel = new PredictorModel(host, examples, input.TrainingData, predictor)
                };

                channel.Done();
                return result;
            }
        }
示例#2
0
        /// <summary>
        /// Common entry-point training path. It does the following, in order:
        /// <list type="bullet">
        /// <item>Resolves the feature column and the optional role columns.</item>
        /// <item>Builds the trainer.</item>
        /// <item>Lets TrainUtils.AddNormalizerIfNeeded possibly insert a normalizer.</item>
        /// <item>Binds the column roles.</item>
        /// <item>Optionally wraps the bound data in an in-memory CacheDataView according to <c>input.Caching</c>.</item>
        /// <item>Trains the predictor and wraps it in a new <typeparamref name="TOut"/>.</item>
        /// </list>
        /// </summary>
        /// <param name="host">Host that opens the "Training" channel, owns the "CreateCache" child host, and is passed to normalization, training and model construction.</param>
        /// <param name="input">Entry-point arguments: training data, feature column name, normalization flag and caching option.</param>
        /// <param name="createTrainer">Factory invoked exactly once, after all column-name providers have been invoked.</param>
        /// <param name="getLabel">Optional label column name provider; when null, a null label name is passed on.</param>
        /// <param name="getWeight">Optional weight column name provider; when null, a null weight name is passed on.</param>
        /// <param name="getGroup">Optional group column name provider; when null, a null group name is passed on.</param>
        /// <param name="getName">Optional name column provider; when null, a null name column is passed on.</param>
        /// <param name="getCustom">Optional provider of additional (role, column name) bindings; when null, null is passed on.</param>
        /// <param name="calibrator">Optional calibrator trainer factory, forwarded unchanged to TrainUtils.Train.</param>
        /// <param name="maxCalibrationExamples">Forwarded unchanged to TrainUtils.Train.</param>
        /// <returns>
        /// A new <typeparamref name="TOut"/>. Its PredictorModel is built from the uncached role-mapped data and the
        /// original <c>input.TrainingData</c>, not from the cached view.
        /// </returns>
        /// <remarks>
        /// When <c>input.Caching</c> is not one of Memory, Auto or None, the exception produced by the
        /// channel's ExceptParam is thrown.
        /// </remarks>
        public static TOut Train<TArg, TOut>(IHost host, TArg input,
            Func<ITrainer> createTrainer,
            Func<string> getLabel = null,
            Func<string> getWeight = null,
            Func<string> getGroup = null,
            Func<string> getName = null,
            Func<IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>>> getCustom = null,
            ICalibratorTrainerFactory calibrator = null,
            int maxCalibrationExamples = 0)
            where TArg : TrainerInputBase
            where TOut : CommonOutputs.TrainerOutput, new()
        {
            using (var ch = host.Start("Training"))
            {
                var schema = input.TrainingData.Schema;
                var feature = FindColumn(ch, schema, input.FeatureColumnName);
                var label = getLabel?.Invoke();
                var weight = getWeight?.Invoke();
                var group = getGroup?.Invoke();
                var name = getName?.Invoke();
                var custom = getCustom?.Invoke();

                var trainer = createTrainer();

                IDataView view = input.TrainingData;
                TrainUtils.AddNormalizerIfNeeded(host, ch, trainer, ref view, feature, input.NormalizeFeatures);

                ch.Trace("Binding columns");
                var roleMappedData = new RoleMappedData(view, label, feature, group, weight, name, custom);

                // Child host dedicated to the training-data cache. It is now actually handed to CacheDataView
                // instead of being registered and then ignored.
                var createCacheHost = host.Register("CreateCache");

                // Decide whether to cache; the cache itself is built once below, so both the Memory and Auto
                // cases share a single construction site.
                bool useCache;
                switch (input.Caching)
                {
                case CachingOptions.Memory:
                    useCache = true;
                    break;

                case CachingOptions.Auto:
                    // REVIEW: we should switch to hybrid caching in future.
                    // default to Memory so mml is on par with maml, unless the data already comes from a
                    // BinaryLoader or the trainer reports that it does not want caching.
                    useCache = !(input.TrainingData is BinaryLoader) && trainer.Info.WantCaching;
                    break;

                case CachingOptions.None:
                    useCache = false;
                    break;

                default:
                    throw ch.ExceptParam(nameof(input.Caching), "Unknown option for caching: '{0}'", input.Caching);
                }

                RoleMappedData cachedRoleMappedData = roleMappedData;
                if (useCache)
                {
                    IDataView cachedView = new CacheDataView(createCacheHost, roleMappedData.Data, null);

                    // Re-bind the same column roles on top of the cached view.
                    cachedRoleMappedData = new RoleMappedData(cachedView, roleMappedData.Schema.GetColumnRoleNames());
                }

                var predictor = TrainUtils.Train(host, ch, cachedRoleMappedData, trainer, calibrator, maxCalibrationExamples);

                // The model is built from the uncached role-mapped data and the original training data.
                return new TOut()
                {
                    PredictorModel = new PredictorModelImpl(host, roleMappedData, input.TrainingData, predictor)
                };
            }
        }