/// <summary>
/// Wraps <paramref name="input"/> into a tagged view which also carries <paramref name="predictor"/>.
/// </summary>
/// <param name="env">environment</param>
/// <param name="args">arguments; <c>args.tag</c> is the tag given to this view</param>
/// <param name="input">view to tag; it must not already be a tagged view</param>
/// <param name="predictor">predictor attached to the tagged view; must not be null</param>
public TagViewTransform(IHostEnvironment env, Arguments args, IDataView input, IPredictor predictor)
{
    Contracts.CheckValue(env, "env");
    // args must be validated first: the error messages below format args.tag.
    env.CheckValue(args, "args");
    if (input is ITaggedDataView)
    {
        throw env.Except("The input view is already tagged. Don't tag again with '{0}'.", args.tag);
    }
    if (predictor == null)
    {
        throw env.Except("Predictor is null, it cannot be tagged with '{0}'.", args.tag);
    }

    _host = env.Register(RegistrationName);
    _args = args;
    _input = input;
    _host.CheckValue(args.tag, "Tag cannot be empty.");
    // A tag must be unique among all views reachable from the input.
    if (TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == args.tag))
    {
        throw _host.Except("Tag '{0}' is already used.", args.tag);
    }

    // This view is always registered under its own tag.
    _parallelViews = new List<Tuple<string, ITaggedDataView>>
    {
        new Tuple<string, ITaggedDataView>(_args.tag, this)
    };
    _taggedPredictor = predictor;
}
/// <summary>
/// Looks for a predictor among tagged predictors.
/// If the tag ends with .zip, it is treated as the path of a model file and the predictor is loaded from it.
/// </summary>
/// <param name="env">environment</param>
/// <param name="input">view whose tagged views are searched (not used when <paramref name="tag"/> is a .zip file)</param>
/// <param name="tag">tag name, or path to a .zip model file</param>
/// <returns>predictor; when loaded from a file, whatever <c>ModelFileUtils.LoadPredictorOrNull</c> returns (possibly null)</returns>
/// <exception>thrown (via <c>env.Except</c>) when the tag is empty, not found, ambiguous, or its view hosts no predictor</exception>
public static IPredictor GetTaggedPredictor(IHostEnvironment env, IDataView input, string tag)
{
    Contracts.CheckValue(env, "env");
    if (string.IsNullOrEmpty(tag))
    {
        throw env.Except("tag must not be null.");
    }

    // Ordinal comparison: file extensions must not depend on the current culture.
    if (tag.EndsWith(".zip", StringComparison.Ordinal))
    {
        using (Stream modelStream = new FileStream(tag, FileMode.Open, FileAccess.Read))
        {
            return ModelFileUtils.LoadPredictorOrNull(env, modelStream);
        }
    }

    // Materialize once: the recursive enumeration was previously walked several times.
    var allTagged = TagHelper.EnumerateTaggedView(true, input).ToArray();
    var matching = allTagged.Where(c => c.Item1 == tag).ToArray();
    if (matching.Length == 0)
    {
        throw env.Except("Unable to find any view with tag '{0}'.", tag);
    }
    if (matching.Length > 1)
    {
        throw env.Except("Ambiguous tag '{0}' - {1}/{2}.", tag, matching.Length, allTagged.Length);
    }

    var predictor = matching[0].Item2.TaggedPredictor;
    if (predictor == null)
    {
        // Previously the exception was created but never thrown, so null leaked to callers.
        throw env.Except("Tagged view '{0}' does not host a predictor.", tag);
    }
    return predictor;
}
/// <summary>
/// Deserialization constructor: restores the arguments from <paramref name="ctx"/>
/// and registers this view under the restored tag.
/// </summary>
/// <param name="host">host used for validation and error reporting</param>
/// <param name="ctx">model load context holding the serialized arguments</param>
/// <param name="input">source view of this transform</param>
private TagViewTransform(IHost host, ModelLoadContext ctx, IDataView input)
{
    Contracts.CheckValue(host, "host");
    _host = host;
    _host.CheckValue(input, "input");
    _host.CheckValue(ctx, "ctx");
    _input = input;

    _args = new Arguments();
    _args.Read(ctx, _host);

    var ownTag = _args.tag;
    _host.CheckValue(ownTag, "Tag cannot be empty.");

    bool tagTaken = TagHelper.EnumerateTaggedView(true, input).Any(pair => pair.Item1 == ownTag);
    if (tagTaken)
    {
        throw _host.Except("Tag '{0}' is already used.", ownTag);
    }

    _parallelViews = new List<Tuple<string, ITaggedDataView>>
    {
        new Tuple<string, ITaggedDataView>(ownTag, this)
    };
}
/// <summary>
/// Builds a view that appends, after the rows of <paramref name="input"/>, the rows of every
/// view whose tag is listed in <c>_args.tag</c>.
/// </summary>
/// <param name="input">first view of the concatenation; its schema is used for the result</param>
/// <returns>the concatenated view created by <c>AppendRowsDataView.Create</c></returns>
/// <exception>thrown (via <c>_host.Except</c>) when a tag matches no view or more than one view</exception>
private IDataView Setup(IDataView input)
{
    var concat = new List<IDataView> { input };
    foreach (var tag in _args.tag)
    {
        // Materialize once per tag. The lazy query used to be walked by Any, First, Skip(1).Any and Count.
        var selected = TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == tag).ToArray();
        if (selected.Length == 0)
        {
            throw _host.Except("Unable to find a view to append with tag '{0}'", tag);
        }
        if (selected.Length > 1)
        {
            throw _host.Except("Tag '{0}' is ambiguous, {1} views were found.", tag, selected.Length);
        }
        concat.Add(selected[0].Item2);
    }
    return AppendRowsDataView.Create(_host, input.Schema, concat.ToArray());
}
/// <summary>
/// Tags <paramref name="input"/> with <c>args.tag</c> and switches to the view tagged <c>args.selectTag</c>.
/// If <c>args.filename</c> is empty, the selected view must already exist among the tagged views reachable
/// from <paramref name="input"/>. Otherwise, the file is loaded and tagged with <c>args.selectTag</c>.
/// Both views are then linked to each other.
/// </summary>
/// <param name="env">environment</param>
/// <param name="args">arguments (tag, selectTag, filename, loaderSettings)</param>
/// <param name="input">current view</param>
/// <param name="sourceCtx">set to <paramref name="input"/></param>
/// <returns>the selected (or loaded) view</returns>
static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx)
{
    sourceCtx = input;
    Contracts.CheckValue(env, "env");
    env.CheckValue(args, "args");
    env.CheckValue(input, "input");
    env.CheckValue(args.tag, "Tag cannot be empty.");
    if (TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == args.tag))
    {
        throw env.Except("Tag '{0}' is already used.", args.tag);
    }
    env.CheckValue(args.selectTag, "Selected tag cannot be empty.");

    if (string.IsNullOrEmpty(args.filename))
    {
        // Materialize once instead of re-walking the recursive enumeration for every check.
        var selected = TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == args.selectTag).ToArray();
        if (selected.Length == 0)
        {
            throw env.Except("Unable to find a view to select with tag '{0}'. Did you forget to specify a filename?", args.selectTag);
        }
        if (selected.Length > 1)
        {
            throw env.Except("Tag '{0}' is ambiguous, {1} views were found.", args.selectTag, selected.Length);
        }
        var first = selected[0];

        // Check before linking the views so that a failure leaves the tagged graph untouched.
        // This was a debug-only assert, which let release builds return null.
        var tr = first.Item2 as IDataTransform;
        if (tr == null)
        {
            throw env.Except("The view tagged '{0}' is not a transform and cannot be selected.", args.selectTag);
        }

        var tagged = input as ITaggedDataView;
        if (tagged == null)
        {
            var ag = new TagViewTransform.Arguments { tag = args.tag };
            tagged = new TagViewTransform(env, ag, input);
        }
        // Link both views so each one can reach the other through its tag.
        first.Item2.AddRange(new[] { new Tuple<string, ITaggedDataView>(args.tag, tagged) });
        tagged.AddRange(new[] { new Tuple<string, ITaggedDataView>(args.selectTag, first.Item2) });
#if (DEBUG_TIP)
        long count = DataViewUtils.ComputeRowCount(tagged);
        if (count == 0)
        {
            throw env.Except("Replaced view is empty.");
        }
        count = DataViewUtils.ComputeRowCount(first.Item2);
        if (count == 0)
        {
            throw env.Except("Selected view is empty.");
        }
#endif
        return tr;
    }
    else
    {
        if (!File.Exists(args.filename))
        {
            throw env.Except("Unable to find file '{0}'.", args.filename);
        }
        if (TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == args.selectTag))
        {
            throw env.Except("Tag '{0}' was already given. It cannot be assigned to the new file.", args.selectTag);
        }

        var file = new MultiFileSource(args.filename);
        var loadSettings = ScikitSubComponent<ILegacyDataLoader, SignatureDataLoader>.AsSubComponent(args.loaderSettings);
        IDataView loader = loadSettings.CreateInstance(env, file);
        // Validate the loaded view before wiring it into the tagged graph.
        // This check used to run after the views were already cross-linked.
        if (loader.Schema.Count == 0)
        {
            throw env.Except("The loaded view '{0}' is empty (empty schema).", args.filename);
        }

        var ag = new TagViewTransform.Arguments { tag = args.selectTag };
        var newInput = new TagViewTransform(env, ag, loader);
        var tagged = input as ITaggedDataView;
        if (tagged == null)
        {
            ag = new TagViewTransform.Arguments { tag = args.tag };
            tagged = new TagViewTransform(env, ag, input);
        }
        newInput.AddRange(new[] { new Tuple<string, ITaggedDataView>(args.tag, tagged) });
        tagged.AddRange(new[] { new Tuple<string, ITaggedDataView>(args.selectTag, newInput) });
        return newInput;
    }
}
/// <summary>
/// Returns the (tag, view) pairs reported by <c>TagHelper.EnumerateTaggedView</c> starting from this view.
/// </summary>
/// <param name="recursive">forwarded unchanged to <c>TagHelper.EnumerateTaggedView</c></param>
public IEnumerable<Tuple<string, ITaggedDataView>> EnumerateTaggedView(bool recursive = true)
    => TagHelper.EnumerateTaggedView(recursive, this);
/// <summary>
/// Trains a predictor on <paramref name="input"/> and builds a scorer over <paramref name="input"/>.
/// The scorer is wrapped in a TagViewTransform tagged with <c>args.tag</c> that also hosts the trained predictor.
/// Side effects: sets the fields <c>_predictor</c> and <c>_scorer</c>, and writes the model to <c>args.outputModel</c> if given.
/// </summary>
/// <param name="env">environment</param>
/// <param name="args">arguments (tag, trainer, calibrator, outputModel, ...)</param>
/// <param name="input">training data, also the data the scorer is applied to</param>
/// <param name="sourceCtx">set to <paramref name="input"/></param>
/// <returns>the tagged scoring view</returns>
IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx)
{
    sourceCtx = input;
    Contracts.CheckValue(env, "env");
    env.CheckValue(args, "args");
    env.CheckValue(input, "input");
    env.CheckValue(args.tag, "tag is empty");
    env.CheckValue(args.trainer, "trainer",
        "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead.");
    // The new tag must not collide with any tag reachable from the input.
    var views = TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == args.tag);
    if (views.Any())
    {
        throw env.Except("Tag '{0}' is already used.", args.tag);
    }
    var host = env.Register("TagTrainOrScoreTransform");
    using (var ch = host.Start("Train"))
    {
        ch.Trace("Constructing trainer");
        var trainerSett = ScikitSubComponent<ITrainer, SignatureTrainer>.AsSubComponent(args.trainer);
        ITrainer trainer = trainerSett.CreateInstance(host);
        // NOTE(review): customCols is never used; the call is kept in case it validates args.CustomColumn — confirm.
        var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);
        string feat;
        string group;
        // NOTE(review): uses the field _host while the rest of the training code uses the local host — confirm intent.
        var data = CreateDataFromArgs(_host, ch, new OpaqueDataView(input), args, out feat, out group);
        ICalibratorTrainer calibrator = args.calibrator == null
            ? null
            : ScikitSubComponent<ICalibratorTrainer, SignatureCalibrator>.AsSubComponent(args.calibrator).CreateInstance(host);
        // Derives a compact trainer name from the trainer settings string: strips '{', '}', ' ' and '=',
        // and maps '+' to 'Y' and '-' to 'N' (presumably so the name is identifier-friendly).
        var nameTrainer = args.trainer.ToString().Replace("{", "").Replace("}", "").Replace(" ", "").Replace("=", "").Replace("+", "Y").Replace("-", "N");
        var extTrainer = new ExtendedTrainer(trainer, nameTrainer);
        _predictor = extTrainer.Train(host, ch, data, null, calibrator, args.maxCalibrationExamples);

        if (!string.IsNullOrEmpty(args.outputModel))
        {
            ch.Info("Saving model into '{0}'", args.outputModel);
            using (var fs = File.Create(args.outputModel))
                TrainUtils.SaveModel(env, ch, fs, _predictor, data);
            ch.Info("Done.");
        }

        // NOTE(review): _cali is not assigned in this method (the calibrator trainer lives in the local
        // `calibrator`), so this checks whatever state the field already had — confirm intent.
        if (_cali != null)
        {
            throw ch.ExceptNotImpl("Calibrator is not implemented yet.");
        }

        ch.Trace("Scoring");
        // NOTE(review): reads _args.scorer rather than args.scorer; equivalent only if the caller assigned
        // _args = args beforehand — confirm.
        if (_args.scorer != null)
        {
            // Bind the trained predictor to the input schema and build the user-specified scorer.
            var mapper = new SchemaBindablePredictorWrapper(_predictor);
            var roles = new RoleMappedSchema(input.Schema, null, feat, group: group);
            var bound = mapper.Bind(_host, roles);
            var scorPars = ScikitSubComponent<IDataScorerTransform, SignatureDataScorer>.AsSubComponent(_args.scorer);
            _scorer = scorPars.CreateInstance(_host, input, bound, roles);
        }
        else
        {
            _scorer = PredictorHelper.CreateDefaultScorer(_host, input, feat, group, _predictor);
        }

        // Expose the scored view under args.tag, carrying the trained predictor.
        ch.Info("Tagging with tag '{0}'.", args.tag);
        var ar = new TagViewTransform.Arguments { tag = args.tag };
        var res = new TagViewTransform(env, ar, _scorer, _predictor);
        return res;
    }
}
/// <summary>
/// Deserialization constructor. In read order, it restores:
/// the arguments, three presence flags (a byte equal to 1 means present) for the predictor,
/// the calibrator and the scorer, and then the submodels that are present.
/// It then rebuilds the tagged scoring view into <c>_sourcePipe</c>.
/// The read order must match the writer; do not reorder statements.
/// </summary>
/// <param name="host">host</param>
/// <param name="ctx">model load context</param>
/// <param name="input">input view</param>
public TagTrainOrScoreTransform(IHost host, ModelLoadContext ctx, IDataView input) :
    base(host, ctx, input, LoaderSignature)
{
    _args = new Arguments();
    _args.Read(ctx, _host);
    // Presence flags: a byte equal to 1 means the corresponding submodel was saved.
    bool hasPredictor = ctx.Reader.ReadByte() == 1;
    bool hasCali = ctx.Reader.ReadByte() == 1;
    bool hasScorer = ctx.Reader.ReadByte() == 1;
    if (hasPredictor)
    {
        ctx.LoadModel<IPredictor, SignatureLoadModel>(host, out _predictor, "predictor");
    }
    else
    {
        _predictor = null;
    }
    using (var ch = _host.Start("TagTrainOrScoreTransform loading"))
    {
        // The restored tag must not collide with any tag reachable from the input.
        var views = TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == _args.tag);
        if (views.Any())
        {
            throw _host.Except("Tag '{0}' is already used.", _args.tag);
        }
        // NOTE(review): customCols is never used; the call is kept in case it validates _args.CustomColumn — confirm.
        var customCols = TrainUtils.CheckAndGenerateCustomColumns(_host, _args.CustomColumn);
        string feat;
        string group;
        // data.Data is the input the serialized scorer is reattached to below.
        var data = CreateDataFromArgs(_host, ch, new OpaqueDataView(input), _args, out feat, out group);
        if (hasCali)
        {
            ctx.LoadModel<ICalibratorTrainer, SignatureLoadModel>(host, out _cali, "calibrator", _predictor);
        }
        else
        {
            _cali = null;
        }
        if (_cali != null)
        {
            throw ch.ExceptNotImpl("Calibrator is not implemented yet.");
        }
        if (hasScorer)
        {
            ctx.LoadModel<IDataScorerTransform, SignatureLoadDataTransform>(host, out _scorer, "scorer", data.Data);
        }
        else
        {
            _scorer = null;
        }
        ch.Info("Tagging with tag '{0}'.", _args.tag);
        // NOTE(review): when hasScorer or hasPredictor is false, null is passed here. If this resolves to
        // the TagViewTransform constructor that rejects a null predictor, such models cannot be loaded — confirm.
        var ar = new TagViewTransform.Arguments { tag = _args.tag };
        var res = new TagViewTransform(_host, ar, _scorer, _predictor);
        _sourcePipe = res;
    }
}