/// <summary>
/// Trains <paramref name="trainer"/> on <paramref name="data"/>, optionally validating against
/// <paramref name="validData"/> and continuing from <paramref name="inputPredictor"/>, then
/// calibrates the resulting predictor when a <paramref name="calibrator"/> is supplied.
/// </summary>
/// <returns>The trained (and possibly calibrated) predictor.</returns>
private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData, ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(data, nameof(data));
    ch.CheckValue(trainer, nameof(trainer));
    ch.CheckValueOrNull(validData);
    ch.CheckValueOrNull(inputPredictor);

    // Wrap the data in a cache view when the trainer benefits from multiple passes.
    AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
    ch.Trace("Training");
    if (validData != null)
        AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);

    // An initial model only makes sense for trainers that support incremental training.
    if (inputPredictor != null && !trainer.Info.SupportsIncrementalTraining)
    {
        ch.Warning("Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) +
            ": Trainer does not support incremental training.");
        inputPredictor = null;
    }

    ch.Assert(validData == null || trainer.Info.SupportsValidation);
    var predictor = trainer.Train(new TrainContext(data, validData, inputPredictor));
    return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data);
}
/// <summary>
/// Initializes the <see cref="MetaMulticlassTrainer{TTransformer, TModel}"/> from the <see cref="ArgumentsBase"/> class.
/// </summary>
/// <param name="env">The private instance of the <see cref="IHostEnvironment"/>.</param>
/// <param name="args">The legacy <see cref="ArgumentsBase"/> arguments.</param>
/// <param name="name">The component name.</param>
/// <param name="labelColumn">The label column for the metalinear trainer and the binary trainer.</param>
/// <param name="singleEstimator">The binary estimator.</param>
/// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
internal MetaMulticlassTrainer(IHostEnvironment env, ArgumentsBase args, string name, string labelColumn = null, TScalarTrainer singleEstimator = null, ICalibratorTrainer calibrator = null)
{
    Host = Contracts.CheckRef(env, nameof(env)).Register(name);
    Host.CheckValue(args, nameof(args));
    Args = args;

    if (labelColumn != null)
    {
        LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
    }

    // Instantiate the inner trainer immediately so errors in the args surface early.
    _trainer = singleEstimator ?? CreateTrainer();

    // A calibrator factory supplied via args takes precedence over the calibrator parameter.
    Calibrator = calibrator ?? new PlattCalibratorTrainer(env);
    if (args.Calibrator != null)
    {
        Calibrator = args.Calibrator.CreateComponent(Host);
    }

    // Regarding caching, no matter what the internal predictor, we're performing many passes
    // simply by virtue of this being a meta-trainer, so we will still cache.
    Info = new TrainerInfo(normalization: _trainer.Info.NeedNormalization);
}
/// <summary>
/// Legacy entry point: creates the calibrator from its <see cref="SubComponent"/> specification
/// (when one is given) and delegates the actual work to <c>TrainCore</c>.
/// </summary>
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData, SubComponent<ICalibratorTrainer, SignatureCalibrator> calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
{
    ICalibratorTrainer caliTrainer = calibrator.IsGood() ? calibrator.CreateInstance(env) : null;
    return TrainCore(env, ch, data, trainer, name, validData, caliTrainer, maxCalibrationExamples, cacheData, inputPredictor);
}
/// <summary>
/// Entry point taking an <see cref="IComponentFactory{TComponent}"/> for the calibrator:
/// instantiates it (when non-null) and delegates to <c>TrainCore</c>.
/// </summary>
public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData, IComponentFactory<ICalibratorTrainer> calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
{
    ICalibratorTrainer caliTrainer = calibrator == null ? null : calibrator.CreateComponent(env);
    return TrainCore(env, ch, data, trainer, validData, caliTrainer, maxCalibrationExamples, cacheData, inputPredictor);
}
/// <summary>
/// Creates the transform from user arguments, building the training/scoring pipeline immediately
/// via <c>Create</c>.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="args">User-supplied arguments; must not be null.</param>
/// <param name="input">The input data view the pipeline is built over.</param>
public TagTrainOrScoreTransform(IHostEnvironment env, Arguments args, IDataView input)
    : base(env, input, LoaderSignature)
{
    // Use nameof instead of the "args" string literal so the reported parameter
    // name stays correct if the parameter is ever renamed.
    _host.CheckValue(args, nameof(args));
    _args = args;
    _cali = null;
    _scorer = null;
    _predictor = null;
    _sourcePipe = Create(_host, args, input, out _sourceCtx);
}
/// <summary>
/// Predicts a target using a linear multiclass classification model trained with the <see cref="Pkpd"/>.
/// </summary>
/// <remarks>
/// <para>
/// In the pairwise coupling (PKPD) strategy, a binary classification algorithm is used to train one
/// classifier for each pair of classes. Prediction is then performed by running these binary classifiers,
/// and computing a score for each class by counting how many of the binary classifiers predicted it.
/// The prediction is the class with the highest score.
/// </para>
/// </remarks>
/// <param name="ctx">The <see cref="MulticlassClassificationContext.MulticlassClassificationTrainers"/>.</param>
/// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="imputeMissingLabelsAsNegative">Whether to treat missing labels as having negative labels, instead of keeping them missing.</param>
/// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
/// <param name="maxCalibrationExamples">Number of instances to train the calibrator.</param>
public static Pkpd PairwiseCoupling(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>> binaryEstimator, string labelColumn = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, ICalibratorTrainer calibrator = null, int maxCalibrationExamples = 1000000000)
{
    Contracts.CheckValue(ctx, nameof(ctx));
    var env = CatalogUtils.GetEnvironment(ctx);
    return new Pkpd(env, binaryEstimator, labelColumn, imputeMissingLabelsAsNegative, calibrator, maxCalibrationExamples);
}
/// <summary>
/// Predicts a target using a linear multiclass classification model trained with the <see cref="Ova"/>.
/// </summary>
/// <remarks>
/// <para>
/// In the one-versus-all (OVA) strategy, a binary classification algorithm is used to train one classifier
/// for each class, which distinguishes that class from all other classes. Prediction is then performed by
/// running these binary classifiers and choosing the prediction with the highest confidence score.
/// </para>
/// </remarks>
/// <param name="catalog">The <see cref="MulticlassClassificationCatalog.MulticlassClassificationTrainers"/>.</param>
/// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="imputeMissingLabelsAsNegative">Whether to treat missing labels as having negative labels, instead of keeping them missing.</param>
/// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
/// <param name="maxCalibrationExamples">Number of instances to train the calibrator.</param>
/// <param name="useProbabilities">Use probabilities (vs. raw outputs) to identify the top-score category.</param>
public static Ova OneVersusAll(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>> binaryEstimator, string labelColumn = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, ICalibratorTrainer calibrator = null, int maxCalibrationExamples = 1000000000, bool useProbabilities = true)
{
    Contracts.CheckValue(catalog, nameof(catalog));
    var env = CatalogUtils.GetEnvironment(catalog);
    return new Ova(env, binaryEstimator, labelColumn, imputeMissingLabelsAsNegative, calibrator, maxCalibrationExamples, useProbabilities);
}
/// <summary>
/// Initializes a new instance of the <see cref="Pkpd"/>.
/// </summary>
/// <param name="env">The <see cref="IHostEnvironment"/> instance.</param>
/// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="imputeMissingLabelsAsNegative">Whether to treat missing labels as having negative labels, instead of keeping them missing.</param>
/// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
/// <param name="maxCalibrationExamples">Number of instances to train the calibrator.</param>
public Pkpd(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumn = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, ICalibratorTrainer calibrator = null, int maxCalibrationExamples = 1000000000)
    : base(env, new Arguments { ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative, MaxCalibrationExamples = maxCalibrationExamples, }, LoadNameValue, labelColumn, binaryEstimator, calibrator)
{
    Host.CheckValue(labelColumn, nameof(labelColumn), "Label column should not be null.");
}
/// <summary>
/// Initializes a new instance of the <see cref="PairwiseCouplingTrainer"/>.
/// </summary>
/// <param name="env">The <see cref="IHostEnvironment"/> instance.</param>
/// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
/// <param name="labelColumnName">The name of the label column.</param>
/// <param name="imputeMissingLabelsAsNegative">Whether to treat missing labels as having negative labels, instead of keeping them missing.</param>
/// <param name="calibrator">The calibrator to use for each model instance. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
/// <param name="maximumCalibrationExampleCount">Number of instances to train the calibrator.</param>
internal PairwiseCouplingTrainer(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, ICalibratorTrainer calibrator = null, int maximumCalibrationExampleCount = 1000000000)
    : base(env, new Options { ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative, MaxCalibrationExamples = maximumCalibrationExampleCount, }, LoadNameValue, labelColumnName, binaryEstimator, calibrator)
{
    Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null.");
}
/// <summary>
/// Predicts a target using a linear multiclass classification model trained with the <see cref="Pkpd"/>.
/// </summary>
/// <remarks>
/// <para>
/// In the pairwise coupling (PKPD) strategy, a binary classification algorithm is used to train one
/// classifier for each pair of classes. Prediction is then performed by running these binary classifiers,
/// and computing a score for each class by counting how many of the binary classifiers predicted it.
/// The prediction is the class with the highest score.
/// </para>
/// </remarks>
/// <param name="catalog">The <see cref="MulticlassClassificationCatalog.MulticlassClassificationTrainers"/>.</param>
/// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="imputeMissingLabelsAsNegative">Whether to treat missing labels as having negative labels, instead of keeping them missing.</param>
/// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
/// <param name="maxCalibrationExamples">Number of instances to train the calibrator.</param>
/// <typeparam name="TModel">The type of the model. This type parameter will usually be inferred automatically from <paramref name="binaryEstimator"/>.</typeparam>
public static Pkpd PairwiseCoupling<TModel>(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, ITrainerEstimator<ISingleFeaturePredictionTransformer<TModel>, TModel> binaryEstimator, string labelColumn = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, ICalibratorTrainer calibrator = null, int maxCalibrationExamples = 1_000_000_000)
    where TModel : class
{
    Contracts.CheckValue(catalog, nameof(catalog));
    var env = CatalogUtils.GetEnvironment(catalog);

    // Pkpd requires a trainer producing a float predictor; reject anything else up front.
    var est = binaryEstimator as ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
    if (est == null)
        throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model.");

    return new Pkpd(env, est, labelColumn, imputeMissingLabelsAsNegative, calibrator, maxCalibrationExamples);
}
/// <summary>
/// Initializes a new instance of <see cref="OneVersusAllTrainer"/>.
/// </summary>
/// <param name="env">The <see cref="IHostEnvironment"/> instance.</param>
/// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
/// <param name="labelColumnName">The name of the label column.</param>
/// <param name="imputeMissingLabelsAsNegative">If true will treat missing labels as negative labels.</param>
/// <param name="calibrator">The calibrator. If a calibrator is not provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
/// <param name="maximumCalibrationExampleCount">Number of instances to train the calibrator.</param>
/// <param name="useProbabilities">Use probabilities (vs. raw outputs) to identify top-score category.</param>
internal OneVersusAllTrainer(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, ICalibratorTrainer calibrator = null, int maximumCalibrationExampleCount = 1000000000, bool useProbabilities = true)
    : base(env, new Options { ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative, MaxCalibrationExamples = maximumCalibrationExampleCount, }, LoadNameValue, labelColumnName, binaryEstimator, calibrator)
{
    Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null.");
    // The base class stores the options; re-expose them typed so UseProbabilities can be set.
    _options = (Options)Args;
    _options.UseProbabilities = useProbabilities;
}
/// <summary>
/// Initializes a new instance of <see cref="Ova"/>.
/// </summary>
/// <param name="env">The <see cref="IHostEnvironment"/> instance.</param>
/// <param name="binaryEstimator">An instance of a binary <see cref="ITrainerEstimator{TTransformer, TPredictor}"/> used as the base trainer.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="imputeMissingLabelsAsNegative">Whether to treat missing labels as having negative labels, instead of keeping them missing.</param>
/// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
/// <param name="maxCalibrationExamples">Number of instances to train the calibrator.</param>
/// <param name="useProbabilities">Use probabilities (vs. raw outputs) to identify top-score category.</param>
internal Ova(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumn = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, ICalibratorTrainer calibrator = null, int maxCalibrationExamples = 1000000000, bool useProbabilities = true)
    : base(env, new Arguments { ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative, MaxCalibrationExamples = maxCalibrationExamples, }, LoadNameValue, labelColumn, binaryEstimator, calibrator)
{
    Host.CheckValue(labelColumn, nameof(labelColumn), "Label column should not be null.");
    // The base class stores the arguments; re-expose them typed so UseProbabilities can be set.
    _args = (Arguments)Args;
    _args.UseProbabilities = useProbabilities;
}
/// <summary>
/// Initializes the <see cref="MetaMulticlassClassificationTrainer{TTransformer, TModel}"/> from the <see cref="OptionsBase"/> class.
/// </summary>
/// <param name="env">The private instance of the <see cref="IHostEnvironment"/>.</param>
/// <param name="options">The legacy <see cref="OptionsBase"/> arguments.</param>
/// <param name="name">The component name.</param>
/// <param name="labelColumn">The label column for the metalinear trainer and the binary trainer.</param>
/// <param name="singleEstimator">The binary estimator.</param>
/// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
internal MetaMulticlassClassificationTrainer(IHostEnvironment env, OptionsBase options, string name, string labelColumn = null, TScalarTrainer singleEstimator = null, ICalibratorTrainer calibrator = null)
{
    Host = Contracts.CheckRef(env, nameof(env)).Register(name);
    Host.CheckValue(options, nameof(options));
    Args = options;

    if (labelColumn != null)
    {
        LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);
    }

    // Instantiate the inner trainer immediately so errors in the options surface early.
    Trainer = singleEstimator ?? CreateTrainer();

    // A calibrator factory supplied via options takes precedence over the calibrator parameter.
    Calibrator = calibrator ?? new PlattCalibratorTrainer(env);
    if (options.Calibrator != null)
    {
        Calibrator = options.Calibrator.CreateComponent(Host);
    }

    // Regarding caching, no matter what the internal predictor, we're performing many passes
    // simply by virtue of this being a meta-trainer, so we will still cache.
    Info = new TrainerInfo(normalization: Trainer.Info.NeedNormalization);
}
/// <summary>
/// Base constructor: records the calibrator trainer and builds the schema columns the
/// calibrator consumes (label, score, and optional weight).
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="calibratorTrainer">The trainer used to fit the calibrator; may already be trained.</param>
/// <param name="labelColumn">Label column name; required only when the calibrator still needs training.</param>
/// <param name="scoreColumn">Score column name.</param>
/// <param name="weightColumn">Optional example-weight column name; null means no weight column.</param>
private protected CalibratorEstimatorBase(IHostEnvironment env, ICalibratorTrainer calibratorTrainer, string labelColumn, string scoreColumn, string weightColumn)
{
    Host = env;
    _calibratorTrainer = calibratorTrainer;

    if (!string.IsNullOrWhiteSpace(labelColumn))
    {
        LabelColumn = TrainerUtils.MakeBoolScalarLabel(labelColumn);
    }
    else
    {
        // A label is only needed to fit the calibrator; an already-trained calibrator can do without.
        env.CheckParam(!calibratorTrainer.NeedsTraining, nameof(labelColumn), "For trained calibrators, " + nameof(labelColumn) + " must be specified.");
    }
    ScoreColumn = TrainerUtils.MakeR4ScalarColumn(scoreColumn);
    // Do we fathom this being named anything else (renaming column)? Complete metadata?
    if (weightColumn != null)
    {
        WeightColumn = TrainerUtils.MakeR4ScalarWeightColumn(weightColumn);
    }
}
/// <summary>
/// Trains <paramref name="trainer"/> on <paramref name="data"/> via the legacy
/// <see cref="ITrainer{TDataSet}"/> interface, dispatching through reflection to a generic
/// TrainCore overload, then calibrates the resulting predictor if requested.
/// </summary>
private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData, ICalibratorTrainer calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inpPredictor = null)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(data, nameof(data));
    ch.CheckValue(trainer, nameof(trainer));
    ch.CheckNonEmpty(name, nameof(name));
    ch.CheckValueOrNull(validData);
    ch.CheckValueOrNull(inpPredictor);

    // This path only supports trainers that consume RoleMappedData directly.
    var trainerRmd = trainer as ITrainer<RoleMappedData>;
    if (trainerRmd == null)
    {
        throw ch.ExceptUserArg(nameof(TrainCommand.Arguments.Trainer), "Trainer '{0}' does not accept known training data type", name);
    }

    // Capture the generic TrainCore overload as a delegate so its generic method
    // definition can be re-instantiated below with the runtime predictor type.
    Action<IChannel, ITrainer, Action<object>, object, object, object> trainCoreAction = TrainCore;
    IPredictor predictor;
    AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
    ch.Trace("Training");
    if (validData != null)
    {
        AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);
    }

    // Close the generic method over (RoleMappedData, runtime type of the input predictor);
    // IPredictor is used as the type argument when no input predictor was supplied.
    var genericExam = trainCoreAction.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(
        typeof(RoleMappedData),
        inpPredictor != null ? inpPredictor.GetType() : typeof(IPredictor));
    Action<RoleMappedData> trainExam = trainerRmd.Train;
    // Invoke the closed generic: it drives training through trainExam with the given data.
    genericExam.Invoke(null, new object[] { ch, trainerRmd, trainExam, data, validData, inpPredictor });

    ch.Trace("Constructing predictor");
    predictor = trainerRmd.CreatePredictor();
    return (CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data));
}
/// <summary>
/// Trains a model with this instance's <c>Trainer</c>, optionally continuing from
/// <paramref name="inputPredictor"/> and calibrating the result.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="ch">Channel used for tracing and warnings.</param>
/// <param name="data">Training data.</param>
/// <param name="validData">Optional validation data.</param>
/// <param name="calibrator">Optional calibrator trainer.</param>
/// <param name="maxCalibrationExamples">Number of examples used to calibrate.</param>
/// <param name="cacheData">Whether to cache the training data.</param>
/// <param name="inputPredictor">For continued training, the initial predictor state.</param>
/// <returns>The trained (and possibly calibrated) predictor.</returns>
public IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, RoleMappedData validData = null, ICalibratorTrainer calibrator = null, int maxCalibrationExamples = 0, bool? cacheData = null, IPredictor inputPredictor = null)
{
    var trainer = Trainer;
    var name = LoadName;

    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(data, nameof(data));
    ch.CheckValue(trainer, nameof(trainer));
    ch.CheckNonEmpty(name, nameof(name));
    ch.CheckValueOrNull(validData);
    ch.CheckValueOrNull(inputPredictor);

    // Wrap the data in a cache view when the trainer benefits from multiple passes.
    AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
    ch.Trace(MessageSensitivity.None, "Training");
    if (validData != null)
        AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);

    // An initial model only makes sense for trainers that support incremental training.
    if (inputPredictor != null && !trainer.Info.SupportsIncrementalTraining)
    {
        ch.Warning(MessageSensitivity.None, "Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) +
            ": Trainer does not support incremental training.");
        inputPredictor = null;
    }

    ch.Assert(validData == null || trainer.Info.SupportsValidation);
    var trained = trainer.Train(new TrainContext(data, validData, null, inputPredictor));
    return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, trained, data);
}
/// <summary>
/// Deserialization constructor: rebuilds the transform from a saved model.
/// NOTE(review): the read order (args, then three presence bytes, then the predictor,
/// calibrator and scorer sub-models) must match the corresponding save logic — do not reorder.
/// </summary>
/// <param name="host">The host the transform is registered with.</param>
/// <param name="ctx">The model load context to read state from.</param>
/// <param name="input">The input data view.</param>
public TagTrainOrScoreTransform(IHost host, ModelLoadContext ctx, IDataView input)
    : base(host, ctx, input, LoaderSignature)
{
    _args = new Arguments();
    _args.Read(ctx, _host);

    // Presence flags for the optional sub-models, in save order.
    bool hasPredictor = ctx.Reader.ReadByte() == 1;
    bool hasCali = ctx.Reader.ReadByte() == 1;
    bool hasScorer = ctx.Reader.ReadByte() == 1;

    if (hasPredictor)
    {
        ctx.LoadModel<IPredictor, SignatureLoadModel>(host, out _predictor, "predictor");
    }
    else
    {
        _predictor = null;
    }

    using (var ch = _host.Start("TagTrainOrScoreTransform loading"))
    {
        // The tag must be unique among the tagged views already attached upstream.
        var views = TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == _args.tag);
        if (views.Any())
        {
            throw _host.Except("Tag '{0}' is already used.", _args.tag);
        }

        var customCols = TrainUtils.CheckAndGenerateCustomColumns(_host, _args.CustomColumn);
        string feat;
        string group;
        var data = CreateDataFromArgs(_host, ch, new OpaqueDataView(input), _args, out feat, out group);

        if (hasCali)
        {
            ctx.LoadModel<ICalibratorTrainer, SignatureLoadModel>(host, out _cali, "calibrator", _predictor);
        }
        else
        {
            _cali = null;
        }
        // Calibrators are deliberately rejected after loading: the feature is unfinished.
        if (_cali != null)
        {
            throw ch.ExceptNotImpl("Calibrator is not implemented yet.");
        }

        if (hasScorer)
        {
            ctx.LoadModel<IDataScorerTransform, SignatureLoadDataTransform>(host, out _scorer, "scorer", data.Data);
        }
        else
        {
            _scorer = null;
        }

        ch.Info("Tagging with tag '{0}'.", _args.tag);
        var ar = new TagViewTransform.Arguments { tag = _args.tag };
        var res = new TagViewTransform(_host, ar, _scorer, _predictor);
        _sourcePipe = res;
    }
}
/// <summary>
/// Calibrates this predictor's scores: when <paramref name="caliTrainer"/> needs training, streams up to
/// <paramref name="maxRows"/> (score, label, weight) triples from <paramref name="data"/> into it, then
/// wraps this predictor with the finished calibrator.
/// </summary>
/// <param name="ch">Channel for checks and messages.</param>
/// <param name="data">Data supplying scores, labels and weights for calibration.</param>
/// <param name="caliTrainer">The calibrator trainer; may already be trained.</param>
/// <param name="maxRows">Maximum number of rows to feed the calibrator; non-positive means no limit.</param>
/// <returns>A calibrated predictor wrapping this one.</returns>
public IPredictor Calibrate(IChannel ch, IDataView data, ICalibratorTrainer caliTrainer, int maxRows)
{
    Host.CheckValue(ch, nameof(ch));
    ch.CheckValue(data, nameof(data));
    ch.CheckValue(caliTrainer, nameof(caliTrainer));

    if (caliTrainer.NeedsTraining)
    {
        var bound = new Bound(this, new RoleMappedSchema(data.Schema));
        using (var curs = data.GetRowCursor(col => true))
        {
            var scoreGetter = (ValueGetter<Single>)bound.CreateScoreGetter(curs, col => true, out Action disposer);

            // We assume that we can use the label column of the first predictor, since if the labels are not identical
            // then the whole model is garbage anyway.
            var labelGetter = bound.GetLabelGetter(curs, 0, out Action disp);
            disposer += disp;
            var weightGetter = bound.GetWeightGetter(curs, 0, out disp);
            disposer += disp;
            try
            {
                int num = 0;
                while (curs.MoveNext())
                {
                    // Rows with a non-finite label, score or weight are skipped entirely.
                    Single label = 0;
                    labelGetter(ref label);
                    if (!FloatUtils.IsFinite(label))
                    {
                        continue;
                    }
                    Single score = 0;
                    scoreGetter(ref score);
                    if (!FloatUtils.IsFinite(score))
                    {
                        continue;
                    }
                    Single weight = 0;
                    weightGetter(ref weight);
                    if (!FloatUtils.IsFinite(weight))
                    {
                        continue;
                    }
                    // Labels are binarized as (label > 0) for the calibrator.
                    caliTrainer.ProcessTrainingExample(score, label > 0, weight);
                    if (maxRows > 0 && ++num >= maxRows)
                    {
                        break;
                    }
                }
            }
            finally
            {
                // The getters accumulated disposers; run them even if enumeration throws.
                disposer?.Invoke();
            }
        }
    }
    var calibrator = caliTrainer.FinishTraining(ch);
    return (CalibratorUtils.CreateCalibratedPredictor(Host, this, calibrator));
}
/// <summary>
/// Builds the train-then-score pipeline: instantiates the trainer from args, trains a predictor
/// (optionally calibrated and saved to disk), builds a scorer over it, and returns the result
/// wrapped in a <see cref="TagViewTransform"/> under the requested tag.
/// Side effects: assigns the <c>_predictor</c> and <c>_scorer</c> fields.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="args">User arguments (tag, trainer, calibrator, scorer, output model path).</param>
/// <param name="input">Input data view used for training and scoring.</param>
/// <param name="sourceCtx">Receives the original input view unchanged.</param>
IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx)
{
    sourceCtx = input;
    Contracts.CheckValue(env, "env");
    env.CheckValue(args, "args");
    env.CheckValue(input, "input");
    env.CheckValue(args.tag, "tag is empty");
    env.CheckValue(args.trainer, "trainer", "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead.");

    // The tag must be unique among the tagged views already attached upstream.
    var views = TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == args.tag);
    if (views.Any())
    {
        throw env.Except("Tag '{0}' is already used.", args.tag);
    }

    var host = env.Register("TagTrainOrScoreTransform");
    using (var ch = host.Start("Train"))
    {
        ch.Trace("Constructing trainer");
        var trainerSett = ScikitSubComponent<ITrainer, SignatureTrainer>.AsSubComponent(args.trainer);
        ITrainer trainer = trainerSett.CreateInstance(host);
        var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);
        string feat;
        string group;
        var data = CreateDataFromArgs(_host, ch, new OpaqueDataView(input), args, out feat, out group);
        ICalibratorTrainer calibrator = args.calibrator == null
            ? null
            : ScikitSubComponent<ICalibratorTrainer, SignatureCalibrator>.AsSubComponent(args.calibrator).CreateInstance(host);

        // Derive a component-safe name from the trainer settings string by stripping/encoding
        // characters that are not valid in a name ('+' -> 'Y', '-' -> 'N').
        var nameTrainer = args.trainer.ToString().Replace("{", "").Replace("}", "").Replace(" ", "").Replace("=", "").Replace("+", "Y").Replace("-", "N");
        var extTrainer = new ExtendedTrainer(trainer, nameTrainer);
        _predictor = extTrainer.Train(host, ch, data, null, calibrator, args.maxCalibrationExamples);

        if (!string.IsNullOrEmpty(args.outputModel))
        {
            ch.Info("Saving model into '{0}'", args.outputModel);
            using (var fs = File.Create(args.outputModel))
                TrainUtils.SaveModel(env, ch, fs, _predictor, data);
            ch.Info("Done.");
        }

        // Calibrators are deliberately rejected: the feature is unfinished.
        if (_cali != null)
        {
            throw ch.ExceptNotImpl("Calibrator is not implemented yet.");
        }

        ch.Trace("Scoring");
        if (_args.scorer != null)
        {
            // An explicit scorer spec: bind the predictor and instantiate the requested scorer.
            var mapper = new SchemaBindablePredictorWrapper(_predictor);
            var roles = new RoleMappedSchema(input.Schema, null, feat, group: group);
            var bound = mapper.Bind(_host, roles);
            var scorPars = ScikitSubComponent<IDataScorerTransform, SignatureDataScorer>.AsSubComponent(_args.scorer);
            _scorer = scorPars.CreateInstance(_host, input, bound, roles);
        }
        else
        {
            _scorer = PredictorHelper.CreateDefaultScorer(_host, input, feat, group, _predictor);
        }

        ch.Info("Tagging with tag '{0}'.", args.tag);
        var ar = new TagViewTransform.Arguments { tag = args.tag };
        var res = new TagViewTransform(env, ar, _scorer, _predictor);
        return (res);
    }
}