예제 #1
0
        /// <summary>
        /// Validates the inputs, trains <paramref name="trainer"/> on <paramref name="data"/> and passes the
        /// resulting predictor to <c>CalibratorUtils.TrainCalibratorIfNeeded</c>.
        /// </summary>
        /// <param name="env">The host environment. Must not be null.</param>
        /// <param name="ch">The channel used for checks, tracing and warnings. Must not be null.</param>
        /// <param name="data">The training data. Must not be null.</param>
        /// <param name="trainer">The trainer to run. Must not be null.</param>
        /// <param name="validData">Optional validation data.</param>
        /// <param name="calibrator">Calibrator trainer forwarded to <c>TrainCalibratorIfNeeded</c>.</param>
        /// <param name="maxCalibrationExamples">Forwarded to <c>TrainCalibratorIfNeeded</c>.</param>
        /// <param name="cacheData">Forwarded to <c>AddCacheIfWanted</c>.</param>
        /// <param name="inputPredictor">Optional initial predictor; dropped with a warning when the trainer
        /// does not support incremental training.</param>
        /// <returns>The value returned by <c>CalibratorUtils.TrainCalibratorIfNeeded</c>.</returns>
        private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData,
            ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(data, nameof(data));
            ch.CheckValue(trainer, nameof(trainer));
            ch.CheckValueOrNull(validData);
            ch.CheckValueOrNull(inputPredictor);

            // Both data sets are passed by ref, so AddCacheIfWanted may swap them for another instance.
            AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
            ch.Trace("Training");
            if (validData != null)
                AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);

            bool mustDropInitialModel = inputPredictor != null && !trainer.Info.SupportsIncrementalTraining;
            if (mustDropInitialModel)
            {
                const string reason = ": Trainer does not support incremental training.";
                ch.Warning("Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) + reason);
                inputPredictor = null;
            }
            ch.Assert(validData == null || trainer.Info.SupportsValidation);

            var context = new TrainContext(data, validData, inputPredictor);
            var trained = trainer.Train(context);
            return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, trained, data);
        }
        /// <summary>
        /// Initializes the <see cref="MetaMulticlassTrainer{TTransformer, TModel}"/> from the Arguments class.
        /// </summary>
        /// <param name="env">The private instance of the <see cref="IHostEnvironment"/>.</param>
        /// <param name="args">The legacy arguments <see cref="ArgumentsBase"/> class.</param>
        /// <param name="name">The component name.</param>
        /// <param name="labelColumn">The label column for the metalinear trainer and the binary trainer.</param>
        /// <param name="singleEstimator">The binary estimator. When null, one is obtained from <c>CreateTrainer()</c>.</param>
        /// <param name="calibrator">The calibrator. A calibrator specified in <paramref name="args"/> takes precedence;
        /// if neither is provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
        internal MetaMulticlassTrainer(IHostEnvironment env, ArgumentsBase args, string name, string labelColumn = null,
                                       TScalarTrainer singleEstimator = null, ICalibratorTrainer calibrator = null)
        {
            Host = Contracts.CheckRef(env, nameof(env)).Register(name);
            Host.CheckValue(args, nameof(args));
            Args = args;

            if (labelColumn != null)
            {
                LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
            }

            // Create the first trainer so errors in the args surface early.
            _trainer = singleEstimator ?? CreateTrainer();

            // Precedence: args.Calibrator, then the explicit calibrator, then the Platt default.
            // The default is only constructed when it is actually going to be used.
            if (args.Calibrator != null)
            {
                Calibrator = args.Calibrator.CreateComponent(Host);
            }
            else
            {
                Calibrator = calibrator ?? new PlattCalibratorTrainer(env);
            }

            // Regarding caching, no matter what the internal predictor, we're performing many passes
            // simply by virtue of this being a meta-trainer, so we will still cache.
            Info = new TrainerInfo(normalization: _trainer.Info.NeedNormalization);
        }
예제 #3
0
        /// <summary>
        /// Instantiates the calibrator trainer described by <paramref name="calibrator"/> when
        /// <c>calibrator.IsGood()</c> is true, then forwards every argument to the <c>TrainCore</c>
        /// overload that takes the trainer <paramref name="name"/>.
        /// </summary>
        /// <returns>The predictor returned by <c>TrainCore</c>.</returns>
        public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
            SubComponent<ICalibratorTrainer, SignatureCalibrator> calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
        {
            // No calibrator trainer is created unless the sub-component reports itself usable.
            ICalibratorTrainer caliTrainer = null;
            if (calibrator.IsGood())
            {
                caliTrainer = calibrator.CreateInstance(env);
            }

            return TrainCore(env, ch, data, trainer, name, validData, caliTrainer, maxCalibrationExamples, cacheData, inputPredictor);
        }
예제 #4
0
        /// <summary>
        /// Creates the calibrator trainer from <paramref name="calibrator"/> when a factory is supplied,
        /// then forwards every argument to <c>TrainCore</c>.
        /// </summary>
        /// <returns>The predictor returned by <c>TrainCore</c>.</returns>
        public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData,
            IComponentFactory<ICalibratorTrainer> calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
        {
            // A null factory means no calibrator trainer is passed on.
            ICalibratorTrainer caliTrainer = calibrator == null ? null : calibrator.CreateComponent(env);

            return TrainCore(env, ch, data, trainer, validData, caliTrainer, maxCalibrationExamples, cacheData, inputPredictor);
        }
예제 #5
0
 /// <summary>
 /// Builds the transform from user-supplied arguments: stores <paramref name="args"/>, clears the
 /// predictor, scorer and calibrator fields, and then calls <c>Create</c> to produce the source pipe.
 /// </summary>
 /// <param name="env">The host environment passed to the base constructor.</param>
 /// <param name="args">The transform arguments. Must not be null.</param>
 /// <param name="input">The input data view.</param>
 public TagTrainOrScoreTransform(IHostEnvironment env, Arguments args, IDataView input)
     : base(env, input, LoaderSignature)
 {
     _host.CheckValue(args, nameof(args));

     // All fields are assigned before Create is invoked.
     _predictor = null;
     _scorer    = null;
     _cali      = null;
     _args      = args;

     _sourcePipe = Create(_host, args, input, out _sourceCtx);
 }
 /// <summary>
 /// Creates a <see cref="Pkpd"/> trainer, which predicts a target using a multiclass classification model
 /// built on top of the supplied binary estimator.
 /// </summary>
 /// <remarks>
 /// <para>
 /// In the pairwise coupling (PKPD) strategy, a binary classification algorithm is used to train one classifier
 /// for each pair of classes. Prediction is then performed by running these binary classifiers and computing a
 /// score for each class by counting how many of the binary classifiers predicted it. The prediction is the
 /// class with the highest score.
 /// </para>
 /// </remarks>
 /// <param name="ctx">The <see cref="MulticlassClassificationContext.MulticlassClassificationTrainers"/>.</param>
 /// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
 /// <param name="labelColumn">The name of the label column.</param>
 /// <param name="imputeMissingLabelsAsNegative">Whether to treat missing labels as having negative labels, instead of keeping them missing.</param>
 /// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
 /// <param name="maxCalibrationExamples">Number of instances to train the calibrator.</param>
 /// <returns>A new <see cref="Pkpd"/> configured with the given arguments.</returns>
 public static Pkpd PairwiseCoupling(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx,
     ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>> binaryEstimator,
     string labelColumn = DefaultColumnNames.Label,
     bool imputeMissingLabelsAsNegative = false,
     ICalibratorTrainer calibrator = null,
     int maxCalibrationExamples = 1000000000)
 {
     Contracts.CheckValue(ctx, nameof(ctx));

     var env = CatalogUtils.GetEnvironment(ctx);
     return new Pkpd(env, binaryEstimator, labelColumn, imputeMissingLabelsAsNegative, calibrator, maxCalibrationExamples);
 }
예제 #7
0
 /// <summary>
 /// Creates an <see cref="Ova"/> trainer, which predicts a target using a multiclass classification model
 /// built on top of the supplied binary estimator.
 /// </summary>
 /// <remarks>
 /// <para>
 /// In the <see cref="Ova"/> strategy, a binary classification algorithm is used to train one classifier for
 /// each class, which distinguishes that class from all other classes. Prediction is then performed by running
 /// these binary classifiers and choosing the prediction with the highest confidence score.
 /// </para>
 /// </remarks>
 /// <param name="catalog">The <see cref="MulticlassClassificationCatalog.MulticlassClassificationTrainers"/>.</param>
 /// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
 /// <param name="labelColumn">The name of the label column.</param>
 /// <param name="imputeMissingLabelsAsNegative">Whether to treat missing labels as having negative labels, instead of keeping them missing.</param>
 /// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
 /// <param name="maxCalibrationExamples">Number of instances to train the calibrator.</param>
 /// <param name="useProbabilities">Use probabilities (vs. raw outputs) to identify top-score category.</param>
 /// <returns>A new <see cref="Ova"/> configured with the given arguments.</returns>
 public static Ova OneVersusAll(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
     ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>> binaryEstimator,
     string labelColumn = DefaultColumnNames.Label,
     bool imputeMissingLabelsAsNegative = false,
     ICalibratorTrainer calibrator = null,
     int maxCalibrationExamples = 1000000000,
     bool useProbabilities = true)
 {
     Contracts.CheckValue(catalog, nameof(catalog));

     var env = CatalogUtils.GetEnvironment(catalog);
     return new Ova(env, binaryEstimator, labelColumn, imputeMissingLabelsAsNegative, calibrator, maxCalibrationExamples, useProbabilities);
 }
예제 #8
0
 /// <summary>
 /// Initializes a new instance of the <see cref="Pkpd"/> trainer.
 /// </summary>
 /// <param name="env">The <see cref="IHostEnvironment"/> instance.</param>
 /// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
 /// <param name="labelColumn">The name of the label column. Must not be null.</param>
 /// <param name="imputeMissingLabelsAsNegative">Whether to treat missing labels as having negative labels, instead of keeping them missing.</param>
 /// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
 /// <param name="maxCalibrationExamples">Number of instances to train the calibrator.</param>
 public Pkpd(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumn = DefaultColumnNames.Label,
             bool imputeMissingLabelsAsNegative = false, ICalibratorTrainer calibrator = null, int maxCalibrationExamples = 1000000000)
     : base(
         env,
         new Arguments { ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative, MaxCalibrationExamples = maxCalibrationExamples },
         LoadNameValue,
         labelColumn,
         binaryEstimator,
         calibrator)
 {
     // Host is only usable once the base constructor has run, so the label check happens here.
     Host.CheckValue(labelColumn, nameof(labelColumn), "Label column should not be null.");
 }
예제 #9
0
 /// <summary>
 /// Initializes a new instance of the <see cref="PairwiseCouplingTrainer"/>.
 /// </summary>
 /// <param name="env">The <see cref="IHostEnvironment"/> instance.</param>
 /// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
 /// <param name="labelColumnName">The name of the label column. Must not be null.</param>
 /// <param name="imputeMissingLabelsAsNegative">Whether to treat missing labels as having negative labels, instead of keeping them missing.</param>
 /// <param name="calibrator">The calibrator to use for each model instance. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
 /// <param name="maximumCalibrationExampleCount">Number of instances to train the calibrator.</param>
 internal PairwiseCouplingTrainer(IHostEnvironment env,
     TScalarTrainer binaryEstimator,
     string labelColumnName = DefaultColumnNames.Label,
     bool imputeMissingLabelsAsNegative = false,
     ICalibratorTrainer calibrator = null,
     int maximumCalibrationExampleCount = 1000000000)
     : base(
         env,
         new Options { ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative, MaxCalibrationExamples = maximumCalibrationExampleCount },
         LoadNameValue,
         labelColumnName,
         binaryEstimator,
         calibrator)
 {
     // Host is only usable once the base constructor has run, so the label check happens here.
     Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null.");
 }
예제 #10
0
        /// <summary>
        /// Creates a <see cref="Pkpd"/> trainer, which predicts a target using a multiclass classification model
        /// built on top of the supplied binary estimator.
        /// </summary>
        /// <remarks>
        /// <para>
        /// In the pairwise coupling (PKPD) strategy, a binary classification algorithm is used to train one classifier
        /// for each pair of classes. Prediction is then performed by running these binary classifiers and computing a
        /// score for each class by counting how many of the binary classifiers predicted it. The prediction is the
        /// class with the highest score.
        /// </para>
        /// </remarks>
        /// <param name="catalog">The <see cref="MulticlassClassificationCatalog.MulticlassClassificationTrainers"/>.</param>
        /// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
        /// <param name="labelColumn">The name of the label column.</param>
        /// <param name="imputeMissingLabelsAsNegative">Whether to treat missing labels as having negative labels, instead of keeping them missing.</param>
        /// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
        /// <param name="maxCalibrationExamples">Number of instances to train the calibrator.</param>
        /// <typeparam name="TModel">The type of the model. This type parameter will usually be inferred automatically from <paramref name="binaryEstimator"/>.</typeparam>
        /// <returns>A new <see cref="Pkpd"/> configured with the given arguments.</returns>
        public static Pkpd PairwiseCoupling<TModel>(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
            ITrainerEstimator<ISingleFeaturePredictionTransformer<TModel>, TModel> binaryEstimator,
            string labelColumn = DefaultColumnNames.Label,
            bool imputeMissingLabelsAsNegative = false,
            ICalibratorTrainer calibrator = null,
            int maxCalibrationExamples = 1_000_000_000)
            where TModel : class
        {
            Contracts.CheckValue(catalog, nameof(catalog));
            var env = CatalogUtils.GetEnvironment(catalog);

            // The generic estimator must also be usable as a float-producing binary estimator; a failed
            // conversion (including a null estimator) is reported as a parameter error.
            var scalarEstimator = binaryEstimator
                as ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
            if (scalarEstimator == null)
                throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model.");

            return new Pkpd(env, scalarEstimator, labelColumn, imputeMissingLabelsAsNegative, calibrator, maxCalibrationExamples);
        }
 /// <summary>
 /// Initializes a new instance of <see cref="OneVersusAllTrainer"/>.
 /// </summary>
 /// <param name="env">The <see cref="IHostEnvironment"/> instance.</param>
 /// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
 /// <param name="labelColumnName">The name of the label column. Must not be null.</param>
 /// <param name="imputeMissingLabelsAsNegative">If true will treat missing labels as negative labels.</param>
 /// <param name="calibrator">The calibrator. If a calibrator is not provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
 /// <param name="maximumCalibrationExampleCount">Number of instances to train the calibrator.</param>
 /// <param name="useProbabilities">Use probabilities (vs. raw outputs) to identify top-score category.</param>
 internal OneVersusAllTrainer(IHostEnvironment env,
     TScalarTrainer binaryEstimator,
     string labelColumnName = DefaultColumnNames.Label,
     bool imputeMissingLabelsAsNegative = false,
     ICalibratorTrainer calibrator = null,
     int maximumCalibrationExampleCount = 1000000000,
     bool useProbabilities = true)
     : base(
         env,
         new Options { ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative, MaxCalibrationExamples = maximumCalibrationExampleCount },
         LoadNameValue,
         labelColumnName,
         binaryEstimator,
         calibrator)
 {
     // Host is only usable once the base constructor has run, so the label check happens here.
     Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null.");

     // UseProbabilities is applied after the base constructor has stored the options in Args.
     _options = (Options)Args;
     _options.UseProbabilities = useProbabilities;
 }
예제 #12
0
 /// <summary>
 /// Initializes a new instance of <see cref="Ova"/>.
 /// </summary>
 /// <param name="env">The <see cref="IHostEnvironment"/> instance.</param>
 /// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
 /// <param name="labelColumn">The name of the label column. Must not be null.</param>
 /// <param name="imputeMissingLabelsAsNegative">Whether to treat missing labels as having negative labels, instead of keeping them missing.</param>
 /// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
 /// <param name="maxCalibrationExamples">Number of instances to train the calibrator.</param>
 /// <param name="useProbabilities">Use probabilities (vs. raw outputs) to identify top-score category.</param>
 internal Ova(IHostEnvironment env,
     TScalarTrainer binaryEstimator,
     string labelColumn = DefaultColumnNames.Label,
     bool imputeMissingLabelsAsNegative = false,
     ICalibratorTrainer calibrator = null,
     int maxCalibrationExamples = 1000000000,
     bool useProbabilities = true)
     : base(
         env,
         new Arguments { ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative, MaxCalibrationExamples = maxCalibrationExamples },
         LoadNameValue,
         labelColumn,
         binaryEstimator,
         calibrator)
 {
     // Host is only usable once the base constructor has run, so the label check happens here.
     Host.CheckValue(labelColumn, nameof(labelColumn), "Label column should not be null.");

     // UseProbabilities is applied after the base constructor has stored the arguments in Args.
     _args = (Arguments)Args;
     _args.UseProbabilities = useProbabilities;
 }
예제 #13
0
        /// <summary>
        /// Initializes the <see cref="MetaMulticlassClassificationTrainer{TTransformer, TModel}"/> from the <see cref="OptionsBase"/> class.
        /// </summary>
        /// <param name="env">The private instance of the <see cref="IHostEnvironment"/>.</param>
        /// <param name="options">The legacy arguments <see cref="OptionsBase"/> class.</param>
        /// <param name="name">The component name.</param>
        /// <param name="labelColumn">The label column for the metalinear trainer and the binary trainer.</param>
        /// <param name="singleEstimator">The binary estimator. When null, one is obtained from <c>CreateTrainer()</c>.</param>
        /// <param name="calibrator">The calibrator. A calibrator specified in <paramref name="options"/> takes precedence;
        /// if neither is provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
        internal MetaMulticlassClassificationTrainer(IHostEnvironment env, OptionsBase options, string name, string labelColumn = null,
            TScalarTrainer singleEstimator = null, ICalibratorTrainer calibrator = null)
        {
            Host = Contracts.CheckRef(env, nameof(env)).Register(name);
            Host.CheckValue(options, nameof(options));
            Args = options;

            if (labelColumn != null)
                LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);

            Trainer = singleEstimator ?? CreateTrainer();

            // Precedence: options.Calibrator, then the explicit calibrator, then the Platt default.
            // The default is only constructed when it is actually going to be used.
            if (options.Calibrator != null)
                Calibrator = options.Calibrator.CreateComponent(Host);
            else
                Calibrator = calibrator ?? new PlattCalibratorTrainer(env);

            // Regarding caching, no matter what the internal predictor, we're performing many passes
            // simply by virtue of this being a meta-trainer, so we will still cache.
            Info = new TrainerInfo(normalization: Trainer.Info.NeedNormalization);
        }
        /// <summary>
        /// Initializes the state shared by calibrator estimators: the host, the calibrator trainer, and the
        /// label, score and optional weight column descriptions.
        /// </summary>
        /// <param name="env">The host environment, stored as <c>Host</c>. Must not be null.</param>
        /// <param name="calibratorTrainer">The calibrator trainer. Must not be null when
        /// <paramref name="labelColumn"/> is null, empty or whitespace.</param>
        /// <param name="labelColumn">The label column name. May be null, empty or whitespace only when the
        /// calibrator trainer does not need training.</param>
        /// <param name="scoreColumn">The score column name.</param>
        /// <param name="weightColumn">The optional weight column name; no weight column is set when null.</param>
        private protected CalibratorEstimatorBase(IHostEnvironment env,
                                                  ICalibratorTrainer calibratorTrainer, string labelColumn, string scoreColumn, string weightColumn)
        {
            // Fail with an argument error up front instead of storing a null Host.
            Contracts.CheckValue(env, nameof(env));
            Host = env;
            _calibratorTrainer = calibratorTrainer;

            if (!string.IsNullOrWhiteSpace(labelColumn))
            {
                LabelColumn = TrainerUtils.MakeBoolScalarLabel(labelColumn);
            }
            else
            {
                // NeedsTraining is read just below, so a null trainer is rejected here with a clear
                // argument error rather than a NullReferenceException.
                env.CheckValue(calibratorTrainer, nameof(calibratorTrainer));
                env.CheckParam(!calibratorTrainer.NeedsTraining, nameof(labelColumn), "For trained calibrators, " + nameof(labelColumn) + " must be specified.");
            }
            ScoreColumn = TrainerUtils.MakeR4ScalarColumn(scoreColumn); // Do we fathom this being named anything else (renaming column)? Complete metadata?

            if (weightColumn != null)
            {
                WeightColumn = TrainerUtils.MakeR4ScalarWeightColumn(weightColumn);
            }
        }
예제 #15
0
        /// <summary>
        /// Validates the inputs, runs the generic <c>TrainCore</c> overload through reflection for a trainer that
        /// accepts <see cref="RoleMappedData"/>, builds the predictor and passes it to
        /// <c>CalibratorUtils.TrainCalibratorIfNeeded</c>.
        /// </summary>
        /// <param name="env">The host environment. Must not be null.</param>
        /// <param name="ch">The channel used for checks and tracing. Must not be null.</param>
        /// <param name="data">The training data. Must not be null.</param>
        /// <param name="trainer">The trainer; must implement <c>ITrainer&lt;RoleMappedData&gt;</c>.</param>
        /// <param name="name">The trainer name used in the error message. Must be non-empty.</param>
        /// <param name="validData">Optional validation data.</param>
        /// <param name="calibrator">Calibrator trainer forwarded to <c>TrainCalibratorIfNeeded</c>.</param>
        /// <param name="maxCalibrationExamples">Forwarded to <c>TrainCalibratorIfNeeded</c>.</param>
        /// <param name="cacheData">Forwarded to <c>AddCacheIfWanted</c>.</param>
        /// <param name="inpPredictor">Optional initial predictor; its runtime type selects the second
        /// generic argument of the reflected call.</param>
        /// <returns>The value returned by <c>CalibratorUtils.TrainCalibratorIfNeeded</c>.</returns>
        private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
            ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inpPredictor = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(data, nameof(data));
            ch.CheckValue(trainer, nameof(trainer));
            ch.CheckNonEmpty(name, nameof(name));
            ch.CheckValueOrNull(validData);
            ch.CheckValueOrNull(inpPredictor);

            var typedTrainer = trainer as ITrainer<RoleMappedData>;
            if (typedTrainer == null)
                throw ch.ExceptUserArg(nameof(TrainCommand.Arguments.Trainer), "Trainer '{0}' does not accept known training data type", name);

            // The method group is bound with object type arguments only to get hold of the generic
            // method definition; it is re-closed over the real types further down.
            Action<IChannel, ITrainer, Action<object>, object, object, object> openTrainCore = TrainCore;

            AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
            ch.Trace("Training");
            if (validData != null)
                AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);

            var predictorType = inpPredictor == null ? typeof(IPredictor) : inpPredictor.GetType();
            var closedTrainCore = openTrainCore.GetMethodInfo()
                                               .GetGenericMethodDefinition()
                                               .MakeGenericMethod(typeof(RoleMappedData), predictorType);
            Action<RoleMappedData> trainOnData = typedTrainer.Train;

            closedTrainCore.Invoke(null, new object[] { ch, typedTrainer, trainOnData, data, validData, inpPredictor });

            ch.Trace("Constructing predictor");
            IPredictor predictor = typedTrainer.CreatePredictor();
            return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data);
        }
예제 #16
0
        /// <summary>
        /// Trains a model with the wrapped <c>Trainer</c> and calibrates it if needed.
        /// </summary>
        /// <param name="env">The host environment. Must not be null.</param>
        /// <param name="ch">The channel used for checks, tracing and warnings. Must not be null.</param>
        /// <param name="data">The training data. Must not be null.</param>
        /// <param name="validData">Optional validation data.</param>
        /// <param name="calibrator">Calibrator trainer forwarded to <c>TrainCalibratorIfNeeded</c>.</param>
        /// <param name="maxCalibrationExamples">Number of examples used to calibrate.</param>
        /// <param name="cacheData">Caching preference forwarded to <c>AddCacheIfWanted</c>.</param>
        /// <param name="inputPredictor">Initial state for continued training; dropped with a warning when
        /// the trainer does not support incremental training.</param>
        /// <returns>The value returned by <c>CalibratorUtils.TrainCalibratorIfNeeded</c>.</returns>
        public IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, RoleMappedData validData = null,
            ICalibratorTrainer calibrator = null, int maxCalibrationExamples = 0,
            bool? cacheData = null, IPredictor inputPredictor = null)
        {
            // NOTE(review): this looks like an inlined variant of
            // TrainUtils.Train(env, ch, data, Trainer, LoadName, validData, calibrator, maxCalibrationExamples, cacheData, inputPredictor)
            // - confirm before replacing it with that call.
            var trainer = Trainer;
            var name = LoadName;

            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(data, nameof(data));
            ch.CheckValue(trainer, nameof(trainer));
            ch.CheckNonEmpty(name, nameof(name));
            ch.CheckValueOrNull(validData);
            ch.CheckValueOrNull(inputPredictor);

            // Both data sets are passed by ref, so AddCacheIfWanted may swap them for another instance.
            AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
            ch.Trace(MessageSensitivity.None, "Training");
            if (validData != null)
                AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);

            if (inputPredictor != null)
            {
                if (!trainer.Info.SupportsIncrementalTraining)
                {
                    ch.Warning(MessageSensitivity.None, "Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) +
                               ": Trainer does not support incremental training.");
                    inputPredictor = null;
                }
            }
            ch.Assert(validData == null || trainer.Info.SupportsValidation);

            var context = new TrainContext(data, validData, null, inputPredictor);
            var predictor = trainer.Train(context);
            return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data);
        }
예제 #17
0
        /// <summary>
        /// Restores the transform from a model context: reads the arguments and three presence flags,
        /// loads the predictor, calibrator and scorer that are flagged as present, and wraps the result
        /// in a <c>TagViewTransform</c>.
        /// </summary>
        /// <param name="host">The host passed to the base constructor and to the model loaders.</param>
        /// <param name="ctx">The model context to read from.</param>
        /// <param name="input">The input data view.</param>
        public TagTrainOrScoreTransform(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, ctx, input, LoaderSignature)
        {
            _args = new Arguments();
            _args.Read(ctx, _host);

            // Presence flags are read in this fixed order: predictor, calibrator, scorer.
            bool hasPredictor = ctx.Reader.ReadByte() == 1;
            bool hasCali = ctx.Reader.ReadByte() == 1;
            bool hasScorer = ctx.Reader.ReadByte() == 1;

            _predictor = null;
            if (hasPredictor)
                ctx.LoadModel<IPredictor, SignatureLoadModel>(host, out _predictor, "predictor");

            using (var ch = _host.Start("TagTrainOrScoreTransform loading"))
            {
                bool tagAlreadyUsed = TagHelper.EnumerateTaggedView(true, input).Any(view => view.Item1 == _args.tag);
                if (tagAlreadyUsed)
                    throw _host.Except("Tag '{0}' is already used.", _args.tag);

                // The returned columns are not used below; the call is kept as in the original code.
                var customCols = TrainUtils.CheckAndGenerateCustomColumns(_host, _args.CustomColumn);
                string feat;
                string group;
                var data = CreateDataFromArgs(_host, ch, new OpaqueDataView(input), _args, out feat, out group);

                _cali = null;
                if (hasCali)
                    ctx.LoadModel<ICalibratorTrainer, SignatureLoadModel>(host, out _cali, "calibrator", _predictor);

                if (_cali != null)
                    throw ch.ExceptNotImpl("Calibrator is not implemented yet.");

                _scorer = null;
                if (hasScorer)
                    ctx.LoadModel<IDataScorerTransform, SignatureLoadDataTransform>(host, out _scorer, "scorer", data.Data);

                ch.Info("Tagging with tag '{0}'.", _args.tag);
                var tagArgs = new TagViewTransform.Arguments { tag = _args.tag };
                _sourcePipe = new TagViewTransform(_host, tagArgs, _scorer, _predictor);
            }
        }
예제 #18
0
            /// <summary>
            /// Feeds (score, label, weight) examples from <paramref name="data"/> to
            /// <paramref name="caliTrainer"/> when it needs training, then wraps this predictor with the
            /// finished calibrator.
            /// </summary>
            /// <param name="ch">The channel used for checks and passed to <c>FinishTraining</c>. Must not be null.</param>
            /// <param name="data">The data to draw calibration examples from. Must not be null.</param>
            /// <param name="caliTrainer">The calibrator trainer. Must not be null.</param>
            /// <param name="maxRows">When positive, stop after this many examples have been passed to the
            /// calibrator trainer; rows skipped for non-finite values do not count.</param>
            /// <returns>The predictor returned by <c>CalibratorUtils.CreateCalibratedPredictor</c>.</returns>
            public IPredictor Calibrate(IChannel ch, IDataView data, ICalibratorTrainer caliTrainer, int maxRows)
            {
                Host.CheckValue(ch, nameof(ch));
                ch.CheckValue(data, nameof(data));
                ch.CheckValue(caliTrainer, nameof(caliTrainer));

                if (caliTrainer.NeedsTraining)
                {
                    var bound = new Bound(this, new RoleMappedSchema(data.Schema));
                    using (var curs = data.GetRowCursor(col => true))
                    {
                        // The getters are created inside the try block so that disposers already collected
                        // still run if creating a later getter throws.
                        Action disposer = null;
                        try
                        {
                            var scoreGetter = (ValueGetter<Single>)bound.CreateScoreGetter(curs, col => true, out Action scoreDisposer);
                            disposer += scoreDisposer;

                            // We assume that we can use the label column of the first predictor, since if the labels are not identical
                            // then the whole model is garbage anyway.
                            var labelGetter = bound.GetLabelGetter(curs, 0, out Action disp);
                            disposer += disp;
                            var weightGetter = bound.GetWeightGetter(curs, 0, out disp);
                            disposer += disp;

                            int num = 0;
                            while (curs.MoveNext())
                            {
                                // Rows with a non-finite label, score or weight are skipped entirely.
                                Single label = 0;
                                labelGetter(ref label);
                                if (!FloatUtils.IsFinite(label))
                                {
                                    continue;
                                }
                                Single score = 0;
                                scoreGetter(ref score);
                                if (!FloatUtils.IsFinite(score))
                                {
                                    continue;
                                }
                                Single weight = 0;
                                weightGetter(ref weight);
                                if (!FloatUtils.IsFinite(weight))
                                {
                                    continue;
                                }

                                // A strictly positive label is treated as the positive class.
                                caliTrainer.ProcessTrainingExample(score, label > 0, weight);

                                if (maxRows > 0 && ++num >= maxRows)
                                {
                                    break;
                                }
                            }
                        }
                        finally
                        {
                            disposer?.Invoke();
                        }
                    }
                }

                var calibrator = caliTrainer.FinishTraining(ch);

                return CalibratorUtils.CreateCalibratedPredictor(Host, this, calibrator);
            }
예제 #19
0
        /// <summary>
        /// Trains the trainer described by <paramref name="args"/> on <paramref name="input"/>, optionally
        /// saves the model, builds a scorer and returns a <c>TagViewTransform</c> that carries the scorer and
        /// predictor under <c>args.tag</c>.
        /// </summary>
        /// <param name="env">The host environment. Must not be null.</param>
        /// <param name="args">The transform arguments; <c>tag</c> and <c>trainer</c> must not be null.</param>
        /// <param name="input">The input data view. Must not be null.</param>
        /// <param name="sourceCtx">Receives <paramref name="input"/>.</param>
        /// <returns>The tagging transform wrapping the scorer and the trained predictor.</returns>
        IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx)
        {
            sourceCtx = input;
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(args, nameof(args));
            env.CheckValue(input, nameof(input));
            // Parameter name first, message second; the original passed the message as the parameter name.
            env.CheckValue(args.tag, nameof(args.tag), "tag is empty");
            env.CheckValue(args.trainer, "trainer",
                           "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead.");

            bool tagAlreadyUsed = TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == args.tag);
            if (tagAlreadyUsed)
            {
                throw env.Except("Tag '{0}' is already used.", args.tag);
            }

            var host = env.Register("TagTrainOrScoreTransform");

            using (var ch = host.Start("Train"))
            {
                ch.Trace("Constructing trainer");
                var trainerSett = ScikitSubComponent<ITrainer, SignatureTrainer>.AsSubComponent(args.trainer);

                ITrainer trainer = trainerSett.CreateInstance(host);

                // NOTE(review): the result is unused; the call is kept because it may validate the custom columns - confirm.
                TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);

                string feat;
                string group;
                var data = CreateDataFromArgs(_host, ch, new OpaqueDataView(input), args, out feat, out group);
                ICalibratorTrainer calibrator = args.calibrator == null
                    ? null
                    : ScikitSubComponent<ICalibratorTrainer, SignatureCalibrator>.AsSubComponent(args.calibrator).CreateInstance(host);

                // Derive a name for the extended trainer from the trainer settings by stripping braces, spaces and '=',
                // and mapping '+' / '-' to 'Y' / 'N'.
                var nameTrainer = args.trainer.ToString().Replace("{", "").Replace("}", "").Replace(" ", "").Replace("=", "").Replace("+", "Y").Replace("-", "N");
                var extTrainer = new ExtendedTrainer(trainer, nameTrainer);
                _predictor = extTrainer.Train(host, ch, data, null, calibrator, args.maxCalibrationExamples);

                if (!string.IsNullOrEmpty(args.outputModel))
                {
                    ch.Info("Saving model into '{0}'", args.outputModel);
                    using (var fs = File.Create(args.outputModel))
                    {
                        TrainUtils.SaveModel(env, ch, fs, _predictor, data);
                    }
                    ch.Info("Done.");
                }

                if (_cali != null)
                {
                    throw ch.ExceptNotImpl("Calibrator is not implemented yet.");
                }

                ch.Trace("Scoring");
                // Use the args parameter here, as for every other setting in this method, rather than the _args field.
                if (args.scorer != null)
                {
                    var mapper = new SchemaBindablePredictorWrapper(_predictor);
                    var roles = new RoleMappedSchema(input.Schema, null, feat, group: group);
                    var bound = mapper.Bind(_host, roles);
                    var scorPars = ScikitSubComponent<IDataScorerTransform, SignatureDataScorer>.AsSubComponent(args.scorer);

                    _scorer = scorPars.CreateInstance(_host, input, bound, roles);
                }
                else
                {
                    _scorer = PredictorHelper.CreateDefaultScorer(_host, input, feat, group, _predictor);
                }

                ch.Info("Tagging with tag '{0}'.", args.tag);

                var tagArgs = new TagViewTransform.Arguments
                {
                    tag = args.tag
                };
                return new TagViewTransform(env, tagArgs, _scorer, _predictor);
            }
        }