Example #1
0
        /// <summary>
        /// Trains the trainer described by <paramref name="args"/> on <paramref name="input"/>,
        /// optionally saves the resulting model to <c>args.outputModel</c>, builds a scorer on top
        /// of the trained predictor and wraps everything in a <see cref="TagViewTransform"/>
        /// registered under <c>args.tag</c>.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Transform arguments; <c>tag</c> and <c>trainer</c> must not be null.</param>
        /// <param name="input">Input view used both for training and for scoring; must not be null.</param>
        /// <param name="sourceCtx">Receives <paramref name="input"/> unchanged.</param>
        /// <returns>The tagging transform built over the scorer and the trained predictor.</returns>
        /// <remarks>
        /// Throws (through <c>env.Except</c>) when the tag is already used by a view reachable from
        /// <paramref name="input"/>, and throws a not-implemented exception when the field
        /// <c>_cali</c> is not null.
        /// </remarks>
        IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx)
        {
            sourceCtx = input;
            Contracts.CheckValue(env, "env");
            env.CheckValue(args, "args");
            env.CheckValue(input, "input");
            // The second argument is the parameter name and the third one the message
            // (same shape as the trainer check below).
            env.CheckValue(args.tag, "tag", "tag is empty");
            env.CheckValue(args.trainer, "trainer",
                           "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead.");

            // Refuse a tag that is already carried by a view reachable from the input.
            if (TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == args.tag))
            {
                throw env.Except("Tag '{0}' is already used.", args.tag);
            }

            var host = env.Register("TagTrainOrScoreTransform");

            using (var ch = host.Start("Train"))
            {
                ch.Trace("Constructing trainer");
                var trainerSett = ScikitSubComponent <ITrainer, SignatureTrainer> .AsSubComponent(args.trainer);

                ITrainer trainer = trainerSett.CreateInstance(host);

                // The result is not needed here; the call is kept in case it validates
                // the custom columns - TODO confirm against TrainUtils.
                TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumns);

                string feat;
                string group;
                // NOTE(review): this uses the field _host while the rest of the training part
                // uses the local host - confirm _host is initialized before Create is called.
                var data = CreateDataFromArgs(_host, ch, new OpaqueDataView(input), args, out feat, out group);
                ICalibratorTrainer calibrator = args.calibrator == null
                                    ? null
                                    : ScikitSubComponent <ICalibratorTrainer, SignatureCalibrator> .AsSubComponent(args.calibrator).CreateInstance(host);

                // Derive a compact name from the trainer settings: braces, spaces and '=' are
                // removed, '+' becomes 'Y' and '-' becomes 'N'.
                var nameTrainer = args.trainer.ToString().Replace("{", "").Replace("}", "").Replace(" ", "").Replace("=", "").Replace("+", "Y").Replace("-", "N");
                var extTrainer  = new ExtendedTrainer(trainer, nameTrainer);
                _predictor = extTrainer.Train(host, ch, data, null, calibrator, args.maxCalibrationExamples);

                if (!string.IsNullOrEmpty(args.outputModel))
                {
                    ch.Info("Saving model into '{0}'", args.outputModel);
                    using (var fs = File.Create(args.outputModel))
                    {
                        TrainUtils.SaveModel(env, ch, fs, _predictor, data);
                    }
                    ch.Info("Done.");
                }

                // NOTE(review): this tests the field _cali, not the local calibrator used for
                // training above; _cali is not assigned in this method - confirm this is intended.
                if (_cali != null)
                {
                    throw ch.ExceptNotImpl("Calibrator is not implemented yet.");
                }

                ch.Trace("Scoring");
                // NOTE(review): the scorer settings are read from the field _args, not from the
                // args parameter - presumably both refer to the same object; verify against caller.
                if (_args.scorer != null)
                {
                    // A scorer was requested explicitly: bind the predictor to the input schema
                    // and instantiate that scorer.
                    var mapper   = new SchemaBindablePredictorWrapper(_predictor);
                    var roles    = new RoleMappedSchema(input.Schema, null, feat, group: group);
                    var bound    = (mapper as ISchemaBindableMapper).Bind(_host, roles);
                    var scorPars = ScikitSubComponent <IDataScorerTransform, SignatureDataScorer> .AsSubComponent(_args.scorer);

                    _scorer = scorPars.CreateInstance(_host, input, bound, roles);
                }
                else
                {
                    _scorer = PredictorHelper.CreateDefaultScorer(_host, input, feat, group, _predictor);
                }

                ch.Info("Tagging with tag '{0}'.", args.tag);

                var ar = new TagViewTransform.Arguments {
                    tag = args.tag
                };
                return new TagViewTransform(env, ar, _scorer, _predictor);
            }
        }
 /// <summary>
 /// Enumerates the tagged views found from this instance by delegating to
 /// <c>TagHelper.EnumerateTaggedView</c>.
 /// </summary>
 /// <param name="recursive">Forwarded unchanged to the helper; defaults to <c>true</c>.</param>
 /// <returns>The (tag, view) pairs produced by the helper.</returns>
 public IEnumerable <Tuple <string, ITaggedDataView> > EnumerateTaggedView(bool recursive = true)
 {
     var taggedViews = TagHelper.EnumerateTaggedView(recursive, this);
     return taggedViews;
 }
Example #3
0
        /// <summary>
        /// Loading constructor: restores the arguments, then the optional predictor, calibrator
        /// and scorer from <paramref name="ctx"/>, and rebuilds the <see cref="TagViewTransform"/>
        /// stored in <c>_sourcePipe</c>.
        /// </summary>
        /// <param name="host">Host forwarded to the base constructor and to the sub-model loaders.</param>
        /// <param name="ctx">Model context to read from. After the arguments, three bytes are read
        /// in this order: predictor present, calibrator present, scorer present (1 means present).</param>
        /// <param name="input">Input view the transform is attached to.</param>
        /// <remarks>
        /// Throws when <c>_args.tag</c> is already used by a view reachable from
        /// <paramref name="input"/>, and throws a not-implemented exception when a calibrator
        /// was loaded.
        /// </remarks>
        public TagTrainOrScoreTransform(IHost host, ModelLoadContext ctx, IDataView input) :
            base(host, ctx, input, LoaderSignature)
        {
            _args = new Arguments();
            _args.Read(ctx, _host);

            // The three presence flags are read up front, before any sub-model is loaded.
            bool hasPredictor = ctx.Reader.ReadByte() == 1;
            bool hasCali      = ctx.Reader.ReadByte() == 1;
            bool hasScorer    = ctx.Reader.ReadByte() == 1;

            if (hasPredictor)
            {
                ctx.LoadModel <IPredictor, SignatureLoadModel>(host, out _predictor, "predictor");
            }
            else
            {
                _predictor = null;
            }

            using (var ch = _host.Start("TagTrainOrScoreTransform loading"))
            {
                // Refuse a tag that is already carried by a view reachable from the input.
                if (TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == _args.tag))
                {
                    throw _host.Except("Tag '{0}' is already used.", _args.tag);
                }

                // The result is not needed here; the call is kept in case it validates
                // the custom columns - TODO confirm against TrainUtils.
                TrainUtils.CheckAndGenerateCustomColumns(_host, _args.CustomColumns);

                string feat;
                string group;
                var    data = CreateDataFromArgs(_host, ch, new OpaqueDataView(input), _args, out feat, out group);

                if (hasCali)
                {
                    ctx.LoadModel <ICalibratorTrainer, SignatureLoadModel>(host, out _cali, "calibrator", _predictor);
                }
                else
                {
                    _cali = null;
                }

                // A stored calibrator is loaded first (see above) and only then rejected.
                if (_cali != null)
                {
                    throw ch.ExceptNotImpl("Calibrator is not implemented yet.");
                }

                if (hasScorer)
                {
                    // The scorer is reloaded on top of the data view built from the arguments.
                    ctx.LoadModel <IDataScorerTransform, SignatureLoadDataTransform>(host, out _scorer, "scorer", data.Data);
                }
                else
                {
                    _scorer = null;
                }

                ch.Info("Tagging with tag '{0}'.", _args.tag);
                var ar = new TagViewTransform.Arguments {
                    tag = _args.tag
                };
                _sourcePipe = new TagViewTransform(_host, ar, _scorer, _predictor);
            }
        }