Esempio n. 1
0
        /// <summary>
        /// Trains the predictor described by <paramref name="args"/> on <paramref name="input"/>,
        /// optionally saves it to <c>args.outputModel</c>, builds a scorer on top of it and returns a
        /// <see cref="TagViewTransform"/> tagged with <c>args.tag</c> that wraps the scorer and the predictor.
        /// </summary>
        /// <param name="env">Host environment; a child host named "TagTrainOrScoreTransform" is registered from it.</param>
        /// <param name="args">Transform arguments (tag, trainer, optional calibrator and scorer, output model path, ...).</param>
        /// <param name="input">View used both for training and for scoring.</param>
        /// <param name="sourceCtx">Always set to <paramref name="input"/>.</param>
        /// <returns>The tagged view built over the scorer and the trained predictor.</returns>
        /// <remarks>
        /// Side effects: assigns <c>_predictor</c> and <c>_scorer</c>.
        /// Every setting is read from <paramref name="args"/> and every sub-component is created with the
        /// locally registered host, so the result does not depend on instance state set elsewhere
        /// (the only instance value consulted is <c>_cali</c>).
        /// Throws when <c>args.tag</c> is already used by a view reachable from <paramref name="input"/>,
        /// or when <c>_cali</c> is non-null (calibrators are not implemented yet).
        /// </remarks>
        IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx)
        {
            sourceCtx = input;
            Contracts.CheckValue(env, "env");
            env.CheckValue(args, "args");
            env.CheckValue(input, "input");
            env.CheckValue(args.tag, "tag is empty");
            env.CheckValue(args.trainer, "trainer",
                           "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead.");

            if (TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == args.tag))
            {
                throw env.Except("Tag '{0}' is already used.", args.tag);
            }

            var host = env.Register("TagTrainOrScoreTransform");

            using (var ch = host.Start("Train"))
            {
                ch.Trace("Constructing trainer");
                var trainerSett = ScikitSubComponent<ITrainer, SignatureTrainer>.AsSubComponent(args.trainer);
                ITrainer trainer = trainerSett.CreateInstance(host);

                // The result is not needed here; the call is kept for any argument validation it performs.
                TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);

                string feat;
                string group;
                var data = CreateDataFromArgs(host, ch, new OpaqueDataView(input), args, out feat, out group);
                ICalibratorTrainer calibrator = args.calibrator == null
                    ? null
                    : ScikitSubComponent<ICalibratorTrainer, SignatureCalibrator>.AsSubComponent(args.calibrator).CreateInstance(host);

                // Compact trainer name derived from its settings string: braces, spaces and '=' are
                // removed, '+' becomes 'Y' and '-' becomes 'N'.
                var nameTrainer = args.trainer.ToString().Replace("{", "").Replace("}", "").Replace(" ", "").Replace("=", "").Replace("+", "Y").Replace("-", "N");
                var extTrainer = new ExtendedTrainer(trainer, nameTrainer);
                _predictor = extTrainer.Train(host, ch, data, null, calibrator, args.maxCalibrationExamples);

                if (!string.IsNullOrEmpty(args.outputModel))
                {
                    ch.Info("Saving model into '{0}'", args.outputModel);
                    using (var fs = File.Create(args.outputModel))
                    {
                        TrainUtils.SaveModel(env, ch, fs, _predictor, data);
                    }
                    ch.Info("Done.");
                }

                if (_cali != null)
                {
                    throw ch.ExceptNotImpl("Calibrator is not implemented yet.");
                }

                ch.Trace("Scoring");
                // Read the scorer setting from the args passed in, not from a member this method never sets.
                if (args.scorer != null)
                {
                    var mapper = new SchemaBindablePredictorWrapper(_predictor);
                    var roles = new RoleMappedSchema(input.Schema, null, feat, group: group);
                    var bound = mapper.Bind(host, roles);
                    var scorPars = ScikitSubComponent<IDataScorerTransform, SignatureDataScorer>.AsSubComponent(args.scorer);
                    _scorer = scorPars.CreateInstance(host, input, bound, roles);
                }
                else
                {
                    _scorer = PredictorHelper.CreateDefaultScorer(host, input, feat, group, _predictor);
                }

                ch.Info("Tagging with tag '{0}'.", args.tag);
                var ar = new TagViewTransform.Arguments { tag = args.tag };
                return new TagViewTransform(env, ar, _scorer, _predictor);
            }
        }
        /// <summary>
        /// Tags <paramref name="input"/> with <c>args.tag</c>, links it to the view tagged <c>args.selectTag</c>
        /// and returns that selected view.
        /// </summary>
        /// <param name="env">Host environment used for validation and to build the new views.</param>
        /// <param name="args">Arguments: <c>tag</c>, <c>selectTag</c>, optional <c>filename</c> and <c>loaderSettings</c>.</param>
        /// <param name="input">Current view; it receives the tag <c>args.tag</c> (wrapped if it is not already tagged).</param>
        /// <param name="sourceCtx">Always set to <paramref name="input"/>.</param>
        /// <returns>The view tagged <c>args.selectTag</c>.</returns>
        /// <remarks>
        /// When <c>args.filename</c> is empty, exactly one view tagged <c>args.selectTag</c> must already be
        /// reachable from <paramref name="input"/>. Otherwise the file is loaded with <c>args.loaderSettings</c>
        /// and the resulting view is tagged <c>args.selectTag</c>. In both cases the two views are registered
        /// with each other through their tags. All validation happens before any view is modified, so a
        /// failure leaves the tag graph of <paramref name="input"/> untouched.
        /// </remarks>
        static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx)
        {
            sourceCtx = input;
            env.CheckValue(args.tag, "Tag cannot be empty.");
            if (TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == args.tag))
            {
                throw env.Except("Tag '{0}' is already used.", args.tag);
            }
            env.CheckValue(args.selectTag, "Selected tag cannot be empty.");

            if (string.IsNullOrEmpty(args.filename))
            {
                // Materialized once: the query is lazy and is inspected several times below.
                var selected = TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == args.selectTag).ToList();
                if (selected.Count == 0)
                {
                    throw env.Except("Unable to find a view to select with tag '{0}'. Did you forget to specify a filename?", args.selectTag);
                }
                if (selected.Count > 1)
                {
                    throw env.Except("Tag '{0}' is ambiguous, {1} views were found.", args.selectTag, selected.Count);
                }
                var first = selected[0];

                // Checked explicitly (an Assert is compiled out of release builds) and before any view is linked.
                var tr = first.Item2 as IDataTransform;
                if (tr == null)
                {
                    throw env.Except("The view tagged '{0}' is not a data transform and cannot be selected.", args.selectTag);
                }

                var tagged = input as ITaggedDataView;
                if (tagged == null)
                {
                    tagged = new TagViewTransform(env, new TagViewTransform.Arguments { tag = args.tag }, input);
                }
                first.Item2.AddRange(new[] { new Tuple<string, ITaggedDataView>(args.tag, tagged) });
                tagged.AddRange(new[] { new Tuple<string, ITaggedDataView>(args.selectTag, first.Item2) });
#if (DEBUG_TIP)
                long count = DataViewUtils.ComputeRowCount(tagged);
                if (count == 0)
                {
                    throw env.Except("Replaced view is empty.");
                }
                count = DataViewUtils.ComputeRowCount(first.Item2);
                if (count == 0)
                {
                    throw env.Except("Selected view is empty.");
                }
#endif
                return tr;
            }
            else
            {
                if (!File.Exists(args.filename))
                {
                    throw env.Except("Unable to find file '{0}'.", args.filename);
                }
                if (TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == args.selectTag))
                {
                    throw env.Except("Tag '{0}' was already given. It cannot be assigned to the new file.", args.selectTag);
                }

                var file = new MultiFileSource(args.filename);
                var loadSettings = ScikitSubComponent<ILegacyDataLoader, SignatureDataLoader>.AsSubComponent(args.loaderSettings);
                IDataView loader = loadSettings.CreateInstance(env, file);

                // Validate before linking views so a failure does not leave input tagged to a discarded view.
                if (loader.Schema.Count == 0)
                {
                    throw env.Except("The loaded view '{0}' is empty (empty schema).", args.filename);
                }

                var newInput = new TagViewTransform(env, new TagViewTransform.Arguments { tag = args.selectTag }, loader);
                var tagged = input as ITaggedDataView;
                if (tagged == null)
                {
                    tagged = new TagViewTransform(env, new TagViewTransform.Arguments { tag = args.tag }, input);
                }

                newInput.AddRange(new[] { new Tuple<string, ITaggedDataView>(args.tag, tagged) });
                tagged.AddRange(new[] { new Tuple<string, ITaggedDataView>(args.selectTag, newInput) });
                return newInput;
            }
        }
Esempio n. 3
0
        /// <summary>
        /// Deserialization constructor: restores the arguments and the optional predictor, calibrator and
        /// scorer from <paramref name="ctx"/>, then wraps them in a <see cref="TagViewTransform"/> tagged
        /// <c>_args.tag</c>, stored in <c>_sourcePipe</c>.
        /// </summary>
        /// <param name="host">Host used to load the sub-models.</param>
        /// <param name="ctx">Model load context; its reader must be positioned as written by the matching save.</param>
        /// <param name="input">Input view the restored pipeline is attached to.</param>
        /// <remarks>
        /// After the arguments, three flag bytes are read in order (predictor, calibrator, scorer); a byte
        /// equal to 1 means the corresponding sub-model is present. The scorer is loaded on top of the data
        /// view rebuilt by <c>CreateDataFromArgs</c>. Throws when the tag is already used by a view reachable
        /// from <paramref name="input"/>, or when a calibrator was loaded (not implemented yet).
        /// </remarks>
        public TagTrainOrScoreTransform(IHost host, ModelLoadContext ctx, IDataView input) :
            base(host, ctx, input, LoaderSignature)
        {
            _args = new Arguments();
            _args.Read(ctx, _host);

            // Presence flags, consumed in the exact order they were written.
            bool predictorSaved = ctx.Reader.ReadByte() == 1;
            bool calibratorSaved = ctx.Reader.ReadByte() == 1;
            bool scorerSaved = ctx.Reader.ReadByte() == 1;

            _predictor = null;
            if (predictorSaved)
            {
                ctx.LoadModel<IPredictor, SignatureLoadModel>(host, out _predictor, "predictor");
            }

            using (var ch = _host.Start("TagTrainOrScoreTransform loading"))
            {
                bool tagTaken = TagHelper.EnumerateTaggedView(true, input).Any(c => c.Item1 == _args.tag);
                if (tagTaken)
                {
                    throw _host.Except("Tag '{0}' is already used.", _args.tag);
                }

                // The generated columns are not used; only the call itself is kept.
                TrainUtils.CheckAndGenerateCustomColumns(_host, _args.CustomColumn);
                string featureColumn, groupColumn;
                var trainingData = CreateDataFromArgs(_host, ch, new OpaqueDataView(input), _args, out featureColumn, out groupColumn);

                _cali = null;
                if (calibratorSaved)
                {
                    ctx.LoadModel<ICalibratorTrainer, SignatureLoadModel>(host, out _cali, "calibrator", _predictor);
                }
                if (_cali != null)
                {
                    throw ch.ExceptNotImpl("Calibrator is not implemented yet.");
                }

                _scorer = null;
                if (scorerSaved)
                {
                    ctx.LoadModel<IDataScorerTransform, SignatureLoadDataTransform>(host, out _scorer, "scorer", trainingData.Data);
                }

                ch.Info("Tagging with tag '{0}'.", _args.tag);
                var tagArgs = new TagViewTransform.Arguments { tag = _args.tag };
                _sourcePipe = new TagViewTransform(_host, tagArgs, _scorer, _predictor);
            }
        }