Пример #1
0
        /// <summary>
        /// Trains the trainer configured by <c>args.Filter</c> on <paramref name="input"/> and copies the
        /// trained predictor's feature weights into <paramref name="scores"/>.
        /// </summary>
        /// <param name="host">Host used for assertions, to start the "Train" channel and to create/run the trainer.</param>
        /// <param name="input">The data to train the filter trainer on.</param>
        /// <param name="args">
        /// Transform arguments: the filter trainer, the label/feature/group/weight/name/custom column names,
        /// the normalization flag and the caching flag. Exactly one of <c>Threshold</c> / <c>NumSlotsToKeep</c>
        /// is expected to be set (asserted below).
        /// </param>
        /// <param name="scores">Receives the feature weights reported by the trained predictor.</param>
        /// <remarks>
        /// Throws (via the channel's exception context) when the trained predictor does not implement
        /// <c>IPredictorWithFeatureWeights&lt;Single&gt;</c>, because no feature weights can be read from it.
        /// </remarks>
        private static void TrainCore(IHost host, IDataView input, Arguments args, ref VBuffer<Single> scores)
        {
            Contracts.AssertValue(host);
            host.AssertValue(args);
            host.AssertValue(input);
            // Exactly one selection criterion (threshold or number of slots) must be configured.
            host.Assert(args.Threshold.HasValue != args.NumSlotsToKeep.HasValue);

            using (var ch = host.Start("Train"))
            {
                ch.Trace("Constructing trainer");
                ITrainer trainer = args.Filter.CreateComponent(host);

                IDataView view = input;

                // Resolve each role column: the user-supplied name, else the default name, else null.
                ISchema schema = view.Schema;
                var label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(args.LabelColumn), args.LabelColumn, DefaultColumnNames.Label);
                var feature = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(args.FeatureColumn), args.FeatureColumn, DefaultColumnNames.Features);
                var group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(args.GroupColumn), args.GroupColumn, DefaultColumnNames.GroupId);
                var weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(args.WeightColumn), args.WeightColumn, DefaultColumnNames.Weight);
                var name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(args.NameColumn), args.NameColumn, DefaultColumnNames.Name);

                // May replace 'view' with a normalized view of the feature column.
                TrainUtils.AddNormalizerIfNeeded(host, ch, trainer, ref view, feature, args.NormalizeFeatures);

                ch.Trace("Binding columns");

                var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, args.CustomColumn);
                var data = new RoleMappedData(view, label, feature, group, weight, name, customCols);

                // NOTE(review): the positional null/null/0 are presumably "no validation data, no calibrator,
                // no calibration examples" — confirm against the TrainUtils.Train signature.
                var predictor = TrainUtils.Train(host, ch, data, trainer, null,
                                                 null, 0, args.CacheData);

                // The filter trainer is user-configurable, so its predictor is not guaranteed to expose
                // feature weights. An assertion alone would let a null reach GetFeatureWeights and surface as
                // an unexplained NullReferenceException; report the misconfiguration explicitly instead.
                var rfs = predictor as IPredictorWithFeatureWeights<Single>;
                if (rfs == null)
                {
                    throw ch.Except(
                        "The trainer used for feature selection must produce a predictor with feature weights, but produced '{0}'",
                        predictor?.GetType().Name);
                }

                rfs.GetFeatureWeights(ref scores);
            }
        }
Пример #2
0
        /// <summary>
        /// Common training routine. It binds the columns of <c>input.TrainingData</c> to their roles, may add a
        /// normalizer, may cache the bound data according to <c>input.Caching</c>, trains the trainer built by
        /// <paramref name="createTrainer"/>, and wraps the resulting predictor in a new <typeparamref name="TOut"/>.
        /// </summary>
        /// <typeparam name="TArg">Learner input carrying the training data, feature column, normalization and caching options.</typeparam>
        /// <typeparam name="TOut">Trainer output type whose <c>PredictorModel</c> is populated.</typeparam>
        /// <param name="host">Host used to start the "Training" channel, cache data and train.</param>
        /// <param name="input">Training input arguments.</param>
        /// <param name="createTrainer">Factory invoked once to build the trainer.</param>
        /// <param name="getLabel">Optional callback returning the label column name; when absent the label is null.</param>
        /// <param name="getWeight">Optional callback returning the weight column name; when absent the weight is null.</param>
        /// <param name="getGroup">Optional callback returning the group column name; when absent the group is null.</param>
        /// <param name="getName">Optional callback returning the name column; when absent the name is null.</param>
        /// <param name="getCustom">Optional callback returning extra role/column bindings; when absent none are passed.</param>
        /// <param name="calibrator">Optional calibrator trainer factory, forwarded to <c>TrainUtils.Train</c>.</param>
        /// <param name="maxCalibrationExamples">Forwarded to <c>TrainUtils.Train</c> together with <paramref name="calibrator"/>.</param>
        /// <returns>
        /// A new <typeparamref name="TOut"/> whose predictor model is built over the un-cached bound data and the
        /// original training data.
        /// </returns>
        public static TOut Train<TArg, TOut>(IHost host, TArg input,
            Func<ITrainer> createTrainer,
            Func<string> getLabel = null,
            Func<string> getWeight = null,
            Func<string> getGroup = null,
            Func<string> getName = null,
            Func<IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>>> getCustom = null,
            ICalibratorTrainerFactory calibrator = null,
            int maxCalibrationExamples = 0)
            where TArg : LearnerInputBase
            where TOut : CommonOutputs.TrainerOutput, new()
        {
            using (var ch = host.Start("Training"))
            {
                // The feature column is looked up in the raw training schema; every other role comes from
                // an optional callback and is null when that callback is not supplied.
                var featureColumn = FindColumn(ch, input.TrainingData.Schema, input.FeatureColumn);
                var labelColumn = getLabel?.Invoke();
                var weightColumn = getWeight?.Invoke();
                var groupColumn = getGroup?.Invoke();
                var nameColumn = getName?.Invoke();
                var customColumns = getCustom?.Invoke();

                var learner = createTrainer();

                IDataView trainingView = input.TrainingData;
                TrainUtils.AddNormalizerIfNeeded(host, ch, learner, ref trainingView, featureColumn, input.NormalizeFeatures);

                ch.Trace("Binding columns");
                var boundData = new RoleMappedData(trainingView, labelColumn, featureColumn, groupColumn, weightColumn, nameColumn, customColumns);

                // Map the requested caching option onto a concrete cache kind; null means "do not cache".
                var caching = input.Caching;
                Cache.CachingType? cacheKind;
                if (caching == CachingOptions.Memory)
                {
                    cacheKind = Cache.CachingType.Memory;
                }
                else if (caching == CachingOptions.Disk)
                {
                    cacheKind = Cache.CachingType.Disk;
                }
                else if (caching == CachingOptions.Auto)
                {
                    // REVIEW: we should switch to hybrid caching in future.
                    // Default to Memory so mml is on par with maml, but only when the data does not come straight
                    // from a BinaryLoader and the trainer reports that it wants caching.
                    bool shouldCache = !(input.TrainingData is BinaryLoader) && learner.Info.WantCaching;
                    cacheKind = shouldCache ? Cache.CachingType.Memory : (Cache.CachingType?)null;
                }
                else if (caching == CachingOptions.None)
                {
                    cacheKind = null;
                }
                else
                {
                    throw ch.ExceptParam(nameof(input.Caching), "Unknown option for caching: '{0}'", input.Caching);
                }

                // Train on a cached copy when requested, keeping the same role bindings.
                var trainingData = boundData;
                if (cacheKind.HasValue)
                {
                    var cacheInput = new Cache.CacheInput
                    {
                        Data = boundData.Data,
                        Caching = cacheKind.Value
                    };
                    var cachedView = Cache.CacheData(host, cacheInput).OutputData;
                    trainingData = new RoleMappedData(cachedView, boundData.Schema.GetColumnRoleNames());
                }

                var predictor = TrainUtils.Train(host, ch, trainingData, learner, calibrator, maxCalibrationExamples);

                // The model references the un-cached bound data, not the cache used only for training.
                return new TOut
                {
                    PredictorModel = new PredictorModelImpl(host, boundData, input.TrainingData, predictor)
                };
            }
        }