Example #1
0
        /// <summary>
        /// Creates a view over <paramref name="input"/> tagged with <c>args.tag</c> that also
        /// carries <paramref name="predictor"/> as its tagged predictor.
        /// </summary>
        /// <param name="env">Host environment; a child host named <c>RegistrationName</c> is registered from it.</param>
        /// <param name="args">Transform arguments; <c>args.tag</c> is the tag given to this view.</param>
        /// <param name="input">View to tag. It must not already be an <c>ITaggedDataView</c>.</param>
        /// <param name="predictor">Predictor attached to the tagged view; must not be null.</param>
        public TagViewTransform(IHostEnvironment env, Arguments args, IDataView input, IPredictor predictor)
        {
            Contracts.CheckValue(env, "env");
            // Validate args before anything else: the error messages below dereference args.tag.
            env.CheckValue(args, "args");

            if (input is ITaggedDataView)
            {
                throw env.Except("The input view is already tagged. Don't tag again with '{0}'.", args.tag);
            }
            if (predictor == null)
            {
                throw env.Except("Predictor is null, it cannot be tagged with '{0}'.", args.tag);
            }

            _host = env.Register(RegistrationName);
            _args = args;
            _input = input;
            _host.CheckValue(args.tag, "Tag cannot be empty.");

            // Reject a tag that is already used by one of the views enumerated from input.
            if (TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == args.tag))
            {
                throw _host.Except("Tag '{0}' is already used.", args.tag);
            }

            // This view registers itself under its own tag.
            _parallelViews = new List<Tuple<string, ITaggedDataView>>();
            _parallelViews.Add(new Tuple<string, ITaggedDataView>(_args.tag, this));
            _taggedPredictor = predictor;
        }
Example #2
0
 /// <summary>
 /// Looks for a predictor among tagged predictors.
 /// If the tag ends with ".zip" (ordinal, case-sensitive), it is treated as the path of a model file
 /// and the predictor is loaded from that file instead.
 /// Throws (through <c>env.Except</c>) when the tag is null or empty, when no view or several views
 /// carry the tag, or when the tagged view hosts no predictor.
 /// </summary>
 /// <param name="env">environment</param>
 /// <param name="input">view whose tagged views are searched (not used when loading from a file)</param>
 /// <param name="tag">tag name, or path of a .zip model file</param>
 /// <returns>
 /// the tagged view's predictor; when loading from a file, whatever
 /// <c>ModelFileUtils.LoadPredictorOrNull</c> returns (which may be null)
 /// </returns>
 public static IPredictor GetTaggedPredictor(IHostEnvironment env, IDataView input, string tag)
 {
     if (string.IsNullOrEmpty(tag))
     {
         throw env.Except("tag must not be null.");
     }

     // File extensions are compared ordinally, not with the current culture.
     if (tag.EndsWith(".zip", StringComparison.Ordinal))
     {
         using (Stream modelStream = new FileStream(tag, FileMode.Open, FileAccess.Read))
         {
             return ModelFileUtils.LoadPredictorOrNull(env, modelStream);
         }
     }

     // Materialize the matches once rather than re-running the deferred query for every check.
     var matches = TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == tag).ToArray();
     if (matches.Length == 0)
     {
         throw env.Except("Unable to find any view with tag '{0}'.", tag);
     }
     if (matches.Length > 1)
     {
         int total = TagHelper.EnumerateTaggedView(true, input).Count();
         throw env.Except("Ambiguous tag '{0}' - {1}/{2}.", tag, matches.Length, total);
     }

     var predictor = matches[0].Item2.TaggedPredictor;
     if (predictor == null)
     {
         // The exception must actually be thrown; merely creating it would let a null predictor escape.
         throw env.Except("Tagged view '{0}' does not host a predictor.", tag);
     }
     return predictor;
 }
Example #3
0
 /// <summary>
 /// Deserialization constructor: restores the transform arguments from <paramref name="ctx"/>
 /// and registers this view under the restored tag on top of <paramref name="input"/>.
 /// </summary>
 /// <param name="host">Host used for validation and error reporting; must not be null.</param>
 /// <param name="ctx">Model load context the arguments are read from; must not be null.</param>
 /// <param name="input">Upstream view wrapped by this transform; must not be null.</param>
 private TagViewTransform(IHost host, ModelLoadContext ctx, IDataView input)
 {
     Contracts.CheckValue(host, "host");
     _host = host;
     _host.CheckValue(input, "input");
     _host.CheckValue(ctx, "ctx");
     _input = input;

     _args = new Arguments();
     _args.Read(ctx, _host);

     string restoredTag = _args.tag;
     _host.CheckValue(restoredTag, "Tag cannot be empty.");

     // A restored tag must not collide with a tag already present upstream.
     bool tagTaken = TagHelper.EnumerateTaggedView(true, input).Any(pair => pair.Item1 == restoredTag);
     if (tagTaken)
     {
         throw _host.Except("Tag '{0}' is already used.", restoredTag);
     }

     // The view starts out knowing only itself, under its own tag.
     _parallelViews = new List<Tuple<string, ITaggedDataView>>
     {
         new Tuple<string, ITaggedDataView>(restoredTag, this),
     };
 }
Example #4
0
        /// <summary>
        /// Builds the view that appends, after <paramref name="input"/>, the rows of every view whose
        /// tag is listed in <c>_args.tag</c>, in the order the tags are listed.
        /// </summary>
        /// <param name="input">First view of the concatenation; its schema is used for the result.</param>
        /// <returns>The view produced by <c>AppendRowsDataView.Create</c> over all collected views.</returns>
        private IDataView Setup(IDataView input)
        {
            // Walk the tagged-view graph once and reuse it for every requested tag.
            var taggedViews = TagHelper.EnumerateTaggedView(true, input).ToArray();
            var concat = new List<IDataView> { input };

            foreach (var tag in _args.tag)
            {
                var selected = taggedViews.Where(c => c.Item1 == tag).ToArray();
                if (selected.Length == 0)
                {
                    throw _host.Except("Unable to find a view to append with tag '{0}'", tag);
                }
                if (selected.Length > 1)
                {
                    throw _host.Except("Tag '{0}' is ambiguous, {1} views were found.", tag, selected.Length);
                }
                concat.Add(selected[0].Item2);
            }
            return AppendRowsDataView.Create(_host, input.Schema, concat.ToArray());
        }
        /// <summary>
        /// Switches the pipeline to another view and links it with the current one through tags.
        /// When <c>args.filename</c> is empty, the existing view tagged <c>args.selectTag</c> is selected;
        /// otherwise the file is loaded with <c>args.loaderSettings</c> and tagged <c>args.selectTag</c>.
        /// In both cases <paramref name="input"/> is tagged <c>args.tag</c> (unless it already is an
        /// <c>ITaggedDataView</c>), and each of the two views registers the other under its tag.
        /// All validation happens before the views are linked, so a failure leaves them untouched.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="args">Transform arguments.</param>
        /// <param name="input">Current view.</param>
        /// <param name="sourceCtx">Receives <paramref name="input"/>.</param>
        /// <returns>The selected view, or the tagged view over the loaded file.</returns>
        static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx)
        {
            sourceCtx = input;
            env.CheckValue(args, "args");
            env.CheckValue(args.tag, "Tag cannot be empty.");
            if (TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == args.tag))
            {
                throw env.Except("Tag '{0}' is already used.", args.tag);
            }
            env.CheckValue(args.selectTag, "Selected tag cannot be empty.");

            if (string.IsNullOrEmpty(args.filename))
            {
                // Materialize once so the tagged views are enumerated a single time.
                var selected = TagHelper.EnumerateTaggedView(true, input)
                                        .Where(c => c.Item1 == args.selectTag)
                                        .ToArray();
                if (selected.Length == 0)
                {
                    throw env.Except("Unable to find a view to select with tag '{0}'. Did you forget to specify a filename?", args.selectTag);
                }
                if (selected.Length > 1)
                {
                    throw env.Except("Tag '{0}' is ambiguous, {1} views were found.", args.selectTag, selected.Length);
                }
                var first = selected[0];

                // The selected view is returned as a transform; check that in every build
                // (not only via a debug assertion) and before any view is modified.
                var tr = first.Item2 as IDataTransform;
                if (tr == null)
                {
                    throw env.Except("The view tagged '{0}' is not a transform and cannot be selected.", args.selectTag);
                }

                var tagged = input as ITaggedDataView;
                if (tagged == null)
                {
                    var ag = new TagViewTransform.Arguments { tag = args.tag };
                    tagged = new TagViewTransform(env, ag, input);
                }
                first.Item2.AddRange(new[] { new Tuple<string, ITaggedDataView>(args.tag, tagged) });
                tagged.AddRange(new[] { new Tuple<string, ITaggedDataView>(args.selectTag, first.Item2) });
#if (DEBUG_TIP)
                long count = DataViewUtils.ComputeRowCount(tagged);
                if (count == 0)
                {
                    throw env.Except("Replaced view is empty.");
                }
                count = DataViewUtils.ComputeRowCount(first.Item2);
                if (count == 0)
                {
                    throw env.Except("Selected view is empty.");
                }
#endif
                return tr;
            }

            if (!File.Exists(args.filename))
            {
                throw env.Except("Unable to find file '{0}'.", args.filename);
            }
            if (TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == args.selectTag))
            {
                throw env.Except("Tag '{0}' was already given. It cannot be assigned to the new file.", args.selectTag);
            }

            var file = new MultiFileSource(args.filename);
            var loadSettings = ScikitSubComponent<ILegacyDataLoader, SignatureDataLoader>.AsSubComponent(args.loaderSettings);
            IDataView loader = loadSettings.CreateInstance(env, file);

            // Reject an empty schema before any view is tagged or linked.
            var schema = loader.Schema;
            if (schema.Count == 0)
            {
                throw env.Except("The loaded view '{0}' is empty (empty schema).", args.filename);
            }

            var selectArgs = new TagViewTransform.Arguments { tag = args.selectTag };
            var newInput = new TagViewTransform(env, selectArgs, loader);

            var inputTagged = input as ITaggedDataView;
            if (inputTagged == null)
            {
                var inputArgs = new TagViewTransform.Arguments { tag = args.tag };
                inputTagged = new TagViewTransform(env, inputArgs, input);
            }

            newInput.AddRange(new[] { new Tuple<string, ITaggedDataView>(args.tag, inputTagged) });
            inputTagged.AddRange(new[] { new Tuple<string, ITaggedDataView>(args.selectTag, newInput) });
            return newInput;
        }
Example #6
0
 /// <summary>
 /// Returns the (tag, view) pairs that <c>TagHelper.EnumerateTaggedView</c> yields for this view.
 /// </summary>
 /// <param name="recursive">Forwarded unchanged to <c>TagHelper.EnumerateTaggedView</c>.</param>
 public IEnumerable<Tuple<string, ITaggedDataView>> EnumerateTaggedView(bool recursive = true) =>
     TagHelper.EnumerateTaggedView(recursive, this);
Example #7
0
        /// <summary>
        /// Trains the predictor described by <c>args.trainer</c> on <paramref name="input"/>, optionally
        /// saves it to <c>args.outputModel</c>, scores <paramref name="input"/> with it (using
        /// <c>args.scorer</c> when given, otherwise a default scorer), and returns the scored view
        /// tagged with <c>args.tag</c>.
        /// Side effects: assigns the <c>_predictor</c> and <c>_scorer</c> fields.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="args">Transform arguments; this method reads them only through this parameter.</param>
        /// <param name="input">Training data, also the view that is scored.</param>
        /// <param name="sourceCtx">Receives <paramref name="input"/>.</param>
        /// <returns>A <c>TagViewTransform</c> over the scorer output that carries the trained predictor.</returns>
        IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx)
        {
            sourceCtx = input;
            Contracts.CheckValue(env, "env");
            env.CheckValue(args, "args");
            env.CheckValue(input, "input");
            env.CheckValue(args.tag, "tag is empty");
            env.CheckValue(args.trainer, "trainer",
                           "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead.");

            if (TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == args.tag))
            {
                throw env.Except("Tag '{0}' is already used.", args.tag);
            }

            var host = env.Register("TagTrainOrScoreTransform");

            using (var ch = host.Start("Train"))
            {
                ch.Trace("Constructing trainer");
                var trainerSett = ScikitSubComponent<ITrainer, SignatureTrainer>.AsSubComponent(args.trainer);
                ITrainer trainer = trainerSett.CreateInstance(host);

                // The generated columns are not used below; the call is kept for its checks on
                // args.CustomColumn (presumably it throws on invalid columns — TODO confirm).
                TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);

                string feat;
                string group;
                var data = CreateDataFromArgs(_host, ch, new OpaqueDataView(input), args, out feat, out group);
                ICalibratorTrainer calibrator = args.calibrator == null
                    ? null
                    : ScikitSubComponent<ICalibratorTrainer, SignatureCalibrator>.AsSubComponent(args.calibrator).CreateInstance(host);

                // Derive a compact name from the trainer settings string by stripping braces, spaces and '='
                // and mapping '+' -> 'Y', '-' -> 'N'.
                var nameTrainer = args.trainer.ToString().Replace("{", "").Replace("}", "").Replace(" ", "").Replace("=", "").Replace("+", "Y").Replace("-", "N");
                var extTrainer = new ExtendedTrainer(trainer, nameTrainer);
                _predictor = extTrainer.Train(host, ch, data, null, calibrator, args.maxCalibrationExamples);

                if (!string.IsNullOrEmpty(args.outputModel))
                {
                    ch.Info("Saving model into '{0}'", args.outputModel);
                    using (var fs = File.Create(args.outputModel))
                    {
                        TrainUtils.SaveModel(env, ch, fs, _predictor, data);
                    }
                    ch.Info("Done.");
                }

                if (_cali != null)
                {
                    throw ch.ExceptNotImpl("Calibrator is not implemented yet.");
                }

                ch.Trace("Scoring");
                // Read the scorer settings from the args parameter, like every other setting in this
                // method; the _args field is not guaranteed to hold the same object here.
                if (args.scorer != null)
                {
                    var mapper = new SchemaBindablePredictorWrapper(_predictor);
                    var roles = new RoleMappedSchema(input.Schema, null, feat, group: group);
                    var bound = mapper.Bind(_host, roles);
                    var scorPars = ScikitSubComponent<IDataScorerTransform, SignatureDataScorer>.AsSubComponent(args.scorer);

                    _scorer = scorPars.CreateInstance(_host, input, bound, roles);
                }
                else
                {
                    _scorer = PredictorHelper.CreateDefaultScorer(_host, input, feat, group, _predictor);
                }

                ch.Info("Tagging with tag '{0}'.", args.tag);

                var ar = new TagViewTransform.Arguments { tag = args.tag };
                return new TagViewTransform(env, ar, _scorer, _predictor);
            }
        }
Example #8
0
        /// <summary>
        /// Deserialization constructor: restores the arguments and the optional predictor, calibrator
        /// and scorer from <paramref name="ctx"/>, then rebuilds the tagged view over the restored
        /// scorer and stores it in <c>_sourcePipe</c>.
        /// </summary>
        /// <param name="host">Host forwarded to the base constructor and used to load the sub-models.</param>
        /// <param name="ctx">Model load context to read from.</param>
        /// <param name="input">Upstream view.</param>
        public TagTrainOrScoreTransform(IHost host, ModelLoadContext ctx, IDataView input) :
            base(host, ctx, input, LoaderSignature)
        {
            // _host is presumably assigned by the base constructor — TODO confirm.
            _args = new Arguments();
            _args.Read(ctx, _host);

            // Three presence flags read sequentially from the stream: predictor, calibrator, scorer.
            // This order must match the order in which they were written.
            bool hasPredictor = ctx.Reader.ReadByte() == 1;
            bool hasCali      = ctx.Reader.ReadByte() == 1;
            bool hasScorer    = ctx.Reader.ReadByte() == 1;

            if (hasPredictor)
            {
                ctx.LoadModel <IPredictor, SignatureLoadModel>(host, out _predictor, "predictor");
            }
            else
            {
                _predictor = null;
            }

            using (var ch = _host.Start("TagTrainOrScoreTransform loading"))
            {
                // The restored tag must not already be used upstream.
                var views = TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == _args.tag);
                if (views.Any())
                {
                    throw _host.Except("Tag '{0}' is already used.", _args.tag);
                }

                // NOTE(review): customCols is never read; the call is presumably kept for its checks on
                // _args.CustomColumn — confirm before removing.
                var    customCols = TrainUtils.CheckAndGenerateCustomColumns(_host, _args.CustomColumn);
                string feat;
                string group;
                var    data = CreateDataFromArgs(_host, ch, new OpaqueDataView(input), _args, out feat, out group);

                if (hasCali)
                {
                    // The calibrator is loaded with the restored predictor as an extra argument.
                    ctx.LoadModel <ICalibratorTrainer, SignatureLoadModel>(host, out _cali, "calibrator", _predictor);
                }
                else
                {
                    _cali = null;
                }

                // Models that carry a calibrator are rejected for now.
                if (_cali != null)
                {
                    throw ch.ExceptNotImpl("Calibrator is not implemented yet.");
                }

                if (hasScorer)
                {
                    // The scorer is reloaded on top of data.Data, the view built by CreateDataFromArgs above.
                    ctx.LoadModel <IDataScorerTransform, SignatureLoadDataTransform>(host, out _scorer, "scorer", data.Data);
                }
                else
                {
                    // NOTE(review): _scorer stays null and is passed below as the input view of
                    // TagViewTransform — confirm that constructor tolerates a null input.
                    _scorer = null;
                }

                ch.Info("Tagging with tag '{0}'.", _args.tag);
                var ar = new TagViewTransform.Arguments {
                    tag = _args.tag
                };
                // NOTE(review): the 4-argument TagViewTransform constructor in this file throws when the
                // predictor is null, so a model saved without a predictor cannot be reloaded here — confirm intended.
                var res = new TagViewTransform(_host, ar, _scorer, _predictor);
                _sourcePipe = res;
            }
        }