/// <summary>
/// Trains the trainer described by <paramref name="args"/> on <paramref name="input"/>,
/// optionally saves the trained model to <c>args.outputModel</c>, builds a scorer over the
/// input, and wraps that scorer in a <see cref="TagViewTransform"/> labelled with <c>args.tag</c>.
/// The trained predictor and the scorer are stored in <c>_predictor</c> and <c>_scorer</c>.
/// </summary>
/// <param name="env">Environment used for the argument checks, to register the
/// "TagTrainOrScoreTransform" host, to save the model and to build the tagged view.</param>
/// <param name="args">Arguments of the transform; <c>tag</c> and <c>trainer</c> must not be null.</param>
/// <param name="input">Data view to train on and to score.</param>
/// <param name="sourceCtx">Receives <paramref name="input"/> unchanged.</param>
/// <returns>The <see cref="TagViewTransform"/> created over the scorer and the predictor.</returns>
IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx)
{
    // The source context handed back to the caller is the input view itself.
    sourceCtx = input;

    Contracts.CheckValue(env, "env");
    env.CheckValue(args, "args");
    env.CheckValue(input, "input");
    env.CheckValue(args.tag, "tag is empty");
    env.CheckValue(
        args.trainer,
        "trainer",
        "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead.");

    // Reject a tag that already shows up among the tagged views enumerated from the input.
    bool tagAlreadyUsed = TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == args.tag);
    if (tagAlreadyUsed)
    {
        throw env.Except("Tag '{0}' is already used.", args.tag);
    }

    var host = env.Register("TagTrainOrScoreTransform");
    using (var ch = host.Start("Train"))
    {
        ch.Trace("Constructing trainer");
        var trainerFactory = ScikitSubComponent<ITrainer, SignatureTrainer>.AsSubComponent(args.trainer);
        ITrainer trainer = trainerFactory.CreateInstance(host);

        // The return value is not needed in this method.
        // NOTE(review): presumably called for its validation of the custom columns - confirm.
        TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumns);

        string featureCol;
        string groupCol;
        var data = CreateDataFromArgs(_host, ch, new OpaqueDataView(input), args, out featureCol, out groupCol);

        // The calibrator trainer is optional: it is only instantiated when one was specified.
        ICalibratorTrainer calibrator;
        if (args.calibrator == null)
        {
            calibrator = null;
        }
        else
        {
            calibrator = ScikitSubComponent<ICalibratorTrainer, SignatureCalibrator>
                .AsSubComponent(args.calibrator)
                .CreateInstance(host);
        }

        // Build the trainer name from its settings string: braces, spaces and '=' are
        // removed, '+' becomes 'Y' and '-' becomes 'N'.
        var trainerLabel = args.trainer.ToString()
            .Replace("{", "")
            .Replace("}", "")
            .Replace(" ", "")
            .Replace("=", "")
            .Replace("+", "Y")
            .Replace("-", "N");

        var extendedTrainer = new ExtendedTrainer(trainer, trainerLabel);
        _predictor = extendedTrainer.Train(host, ch, data, null, calibrator, args.maxCalibrationExamples);

        // Persist the trained model only when an output path was given.
        if (!string.IsNullOrEmpty(args.outputModel))
        {
            ch.Info("Saving model into '{0}'", args.outputModel);
            using (var modelStream = File.Create(args.outputModel))
            {
                TrainUtils.SaveModel(env, ch, modelStream, _predictor, data);
            }
            ch.Info("Done.");
        }

        // This checks the _cali field, not the local calibrator created above.
        if (_cali != null)
        {
            throw ch.ExceptNotImpl("Calibrator is not implemented yet.");
        }

        ch.Trace("Scoring");
        if (_args.scorer != null)
        {
            // An explicit scorer was requested: bind the predictor to the input schema
            // and let the scorer component wrap the bound mapper.
            var wrapper = new SchemaBindablePredictorWrapper(_predictor);
            var roleSchema = new RoleMappedSchema(input.Schema, null, featureCol, group: groupCol);
            var boundMapper = (wrapper as ISchemaBindableMapper).Bind(_host, roleSchema);
            var scorerFactory = ScikitSubComponent<IDataScorerTransform, SignatureDataScorer>.AsSubComponent(_args.scorer);
            _scorer = scorerFactory.CreateInstance(_host, input, boundMapper, roleSchema);
        }
        else
        {
            // No scorer specified: fall back to the default scorer for this predictor.
            _scorer = PredictorHelper.CreateDefaultScorer(_host, input, featureCol, groupCol, _predictor);
        }

        ch.Info("Tagging with tag '{0}'.", args.tag);
        var tagArgs = new TagViewTransform.Arguments { tag = args.tag };
        var taggedView = new TagViewTransform(env, tagArgs, _scorer, _predictor);
        return taggedView;
    }
}
/// <summary>
/// Enumerates tagged views starting from this transform by delegating to
/// <c>TagHelper.EnumerateTaggedView</c>.
/// </summary>
/// <param name="recursive">Forwarded unchanged to <c>TagHelper.EnumerateTaggedView</c>.</param>
/// <returns>Pairs of a string and an <c>ITaggedDataView</c>
/// (NOTE(review): presumably the tag and the view carrying it - confirm against TagHelper).</returns>
public IEnumerable<Tuple<string, ITaggedDataView>> EnumerateTaggedView(bool recursive = true)
{
    var taggedViews = TagHelper.EnumerateTaggedView(recursive, this);
    return taggedViews;
}
/// <summary>
/// Restores the transform from a model context: reads the arguments, three presence
/// flags (predictor, calibrator, scorer), loads the sub-models flagged as present, and
/// rebuilds the <see cref="TagViewTransform"/> stored in <c>_sourcePipe</c>.
/// </summary>
/// <param name="host">Host passed to the base constructor and to the sub-model loaders.</param>
/// <param name="ctx">Model context the arguments, flags and sub-models are read from.</param>
/// <param name="input">Data view the restored transform is applied to.</param>
public TagTrainOrScoreTransform(IHost host, ModelLoadContext ctx, IDataView input)
    : base(host, ctx, input, LoaderSignature)
{
    _args = new Arguments();
    _args.Read(ctx, _host);

    // Three consecutive bytes tell which sub-models follow; a value of 1 means "present".
    // The read order (predictor, calibrator, scorer) must not change.
    bool predictorSaved = ctx.Reader.ReadByte() == 1;
    bool calibratorSaved = ctx.Reader.ReadByte() == 1;
    bool scorerSaved = ctx.Reader.ReadByte() == 1;

    if (!predictorSaved)
    {
        _predictor = null;
    }
    else
    {
        ctx.LoadModel<IPredictor, SignatureLoadModel>(host, out _predictor, "predictor");
    }

    using (var ch = _host.Start("TagTrainOrScoreTransform loading"))
    {
        // Reject a tag that already shows up among the tagged views enumerated from the input.
        bool tagAlreadyUsed = TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == _args.tag);
        if (tagAlreadyUsed)
        {
            throw _host.Except("Tag '{0}' is already used.", _args.tag);
        }

        // The return value is not needed in this constructor.
        // NOTE(review): presumably called for its validation of the custom columns - confirm.
        TrainUtils.CheckAndGenerateCustomColumns(_host, _args.CustomColumns);

        // The two out values are required by the call but are not used afterwards here.
        string featureCol;
        string groupCol;
        var data = CreateDataFromArgs(_host, ch, new OpaqueDataView(input), _args, out featureCol, out groupCol);

        if (calibratorSaved)
        {
            ctx.LoadModel<ICalibratorTrainer, SignatureLoadModel>(host, out _cali, "calibrator", _predictor);
        }
        else
        {
            _cali = null;
        }

        // A loaded calibrator is rejected right away.
        if (_cali != null)
        {
            throw ch.ExceptNotImpl("Calibrator is not implemented yet.");
        }

        if (!scorerSaved)
        {
            _scorer = null;
        }
        else
        {
            ctx.LoadModel<IDataScorerTransform, SignatureLoadDataTransform>(host, out _scorer, "scorer", data.Data);
        }

        ch.Info("Tagging with tag '{0}'.", _args.tag);
        var tagArgs = new TagViewTransform.Arguments { tag = _args.tag };
        var taggedView = new TagViewTransform(_host, tagArgs, _scorer, _predictor);
        _sourcePipe = taggedView;
    }
}