/// <summary>
        /// Creates the transformer from a <see cref="ModelLoadContext"/>; the context is handed to the base
        /// constructor, after which the scorer is built from the base's <c>TrainSchema</c> and <c>FeatureColumn</c>.
        /// </summary>
        /// <param name="env">Host environment; checked for null before a host is registered.</param>
        /// <param name="ctx">Model load context forwarded to the base constructor.</param>
        internal RankingPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RankingPredictionTransformer<TModel>)), ctx)
        {
            // Role mapping over the training schema; only the feature column is named
            // (the second argument is left null).
            var roles = new RoleMappedSchema(TrainSchema, null, FeatureColumn);

            // The scorer is constructed over an empty view of the training schema.
            var scorerArgs = new GenericScorer.Arguments();
            var emptyInput = new EmptyDataView(Host, TrainSchema);
            var boundMapper = BindableMapper.Bind(Host, roles);

            _scorer = new GenericScorer(Host, scorerArgs, emptyInput, boundMapper, roles);
        }
        /// <summary>
        /// Creates the transformer from a model, forwarding the model, input schema and feature column
        /// to the base constructor and then building the scorer over <paramref name="inputSchema"/>.
        /// </summary>
        /// <param name="env">Host environment; checked for null before a host is registered.</param>
        /// <param name="model">The model forwarded to the base constructor.</param>
        /// <param name="inputSchema">Schema used for the role mapping and for the empty data view.</param>
        /// <param name="featureColumn">Name of the feature column in <paramref name="inputSchema"/>.</param>
        public RankingPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RankingPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
        {
            // Role mapping over the input schema; only the feature column is named
            // (the second argument is left null).
            var roles = new RoleMappedSchema(inputSchema, null, featureColumn);

            // The scorer is constructed over an empty view of the input schema.
            var scorerArgs = new GenericScorer.Arguments();
            var emptyInput = new EmptyDataView(Host, inputSchema);
            var boundMapper = BindableMapper.Bind(Host, roles);

            _scorer = new GenericScorer(Host, scorerArgs, emptyInput, boundMapper, roles);
        }
        /// <summary>
        /// Builds a <see cref="GenericScorer"/> around a <see cref="TreeEnsembleFeaturizerBindableMapper"/>
        /// for the predictor carried by <c>args.PredictorModel</c>.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Entry-point arguments; <c>PredictorModel</c> must be set.</param>
        /// <param name="input">Input data handed to the predictor model's <c>PrepareData</c>.</param>
        /// <returns>The scorer transform over the prepared data.</returns>
        public static IDataTransform CreateForEntryPoint(IHostEnvironment env, ArgumentsForEntryPoint args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("Tree Featurizer Transform");

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));
            host.CheckUserArg(args.PredictorModel != null, nameof(args.PredictorModel), "Please specify a predictor model.");

            using (var ch = host.Start("Create Tree Ensemble Scorer"))
            {
                var featurizerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments();
                featurizerArgs.Suffix = args.Suffix;

                var predictor = args.PredictorModel.Predictor;

                ch.Trace("Prepare data");
                args.PredictorModel.PrepareData(env, input, out RoleMappedData data, out var preparedPredictor);
                ch.AssertValue(data);
                ch.Assert(predictor == preparedPredictor);

                // A calibrated predictor is unwrapped so the checks below look at its sub-predictor.
                if (predictor is CalibratedPredictorBase calibrated)
                {
                    predictor = calibrated.SubPredictor;
                }

                // Predictor should be a FastTreePredictionWrapper, which implements IValueMapper, so this should
                // be non-null.
                var valueMapper = predictor as IValueMapper;
                ch.CheckUserArg(valueMapper != null, nameof(args.PredictorModel), "Predictor does not have compatible type");

                // Make sure that the given predictor has the correct number of input features.
                if (data != null)
                {
                    var expectedCount = valueMapper.InputType.VectorSize;
                    var actualCount = data.Schema.Feature.Type.VectorSize;
                    if (expectedCount != actualCount)
                    {
                        throw ch.ExceptUserArg(nameof(args.PredictorModel),
                            "Predictor expects {0} features, but data has {1} features",
                            expectedCount, actualCount);
                    }
                }

                var bindable = new TreeEnsembleFeaturizerBindableMapper(env, featurizerArgs, predictor);
                var bound = bindable.Bind(env, data.Schema);
                var scorer = new GenericScorer(env, featurizerArgs, data.Data, bound, data.Schema);
                ch.Done();

                // The channel is disposed by the using statement as control leaves this block.
                return scorer;
            }
        }
        // Example #4
        // 0
        /// <summary>
        /// Creates the tree featurizer transform. When <c>args.TrainedModelFile</c> is set, the predictor is
        /// loaded from that file and wrapped in a <see cref="GenericScorer"/>; otherwise a
        /// <see cref="TrainAndScoreTransform"/> is created from <c>args.Trainer</c>.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Arguments; a model file or a trainer, and a feature column, must be specified.</param>
        /// <param name="input">The input data view.</param>
        /// <returns>The resulting data transform.</returns>
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("Tree Featurizer Transform");

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));

            bool hasModelFile = !string.IsNullOrWhiteSpace(args.TrainedModelFile);
            host.CheckUserArg(hasModelFile || args.Trainer != null, nameof(args.TrainedModelFile),
                "Please specify either a trainer or an input model file.");
            host.CheckUserArg(!string.IsNullOrEmpty(args.FeatureColumn), nameof(args.FeatureColumn), "Transform needs an input features column");

            using (var ch = host.Start("Create Tree Ensemble Scorer"))
            {
                var featurizerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments();
                featurizerArgs.Suffix = args.Suffix;

                if (!hasModelFile)
                {
                    // No model file: train, then score with the tree featurizer mapper.
                    ch.AssertValue(args.Trainer);

                    ch.Trace("Creating TrainAndScoreTransform");

                    var trainScoreArgs = new TrainAndScoreTransform.Arguments();
                    args.CopyTo(trainScoreArgs);
                    trainScoreArgs.Trainer = args.Trainer;

                    trainScoreArgs.Scorer = ComponentFactoryUtils.CreateFromFunction<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>(
                        (e, view, mapper, trainSchema) => Create(e, featurizerArgs, view, mapper, trainSchema));

                    var mapperFactory = ComponentFactoryUtils.CreateFromFunction<IPredictor, ISchemaBindableMapper>(
                        (e, trained) => new TreeEnsembleFeaturizerBindableMapper(e, featurizerArgs, trained));

                    var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed);
                    var scoreXf = TrainAndScoreTransform.Create(host, trainScoreArgs, labelInput, mapperFactory);

                    // If the label step returned the very same view, the score transform is used directly;
                    // otherwise the transforms are re-applied on top of the original input.
                    if (input == labelInput)
                    {
                        return scoreXf;
                    }
                    return (IDataTransform)ApplyTransformUtils.ApplyAllTransformsToData(host, scoreXf, input, labelInput);
                }

                // A model file was given; it takes precedence over any trainer.
                if (args.Trainer != null)
                {
                    ch.Warning("Both an input model and a trainer were specified. Using the model file.");
                }

                ch.Trace("Loading model");
                IPredictor loaded;
                using (Stream modelStream = new FileStream(args.TrainedModelFile, FileMode.Open, FileAccess.Read))
                {
                    using (var repository = RepositoryReader.Open(modelStream, ch))
                    {
                        ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(host, out loaded, repository, ModelFileUtils.DirPredictor);
                    }
                }

                ch.Trace("Creating scorer");
                var roleMapped = TrainAndScoreTransform.CreateDataFromArgs(ch, input, args);

                // A calibrated predictor is unwrapped so the checks below look at its sub-predictor.
                if (loaded is CalibratedPredictorBase calibrated)
                {
                    loaded = calibrated.SubPredictor;
                }

                // Predictor should be a FastTreePredictionWrapper, which implements IValueMapper, so this should
                // be non-null.
                var valueMapper = loaded as IValueMapper;
                ch.CheckUserArg(valueMapper != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type");

                // Make sure that the given predictor has the correct number of input features.
                var expectedCount = valueMapper.InputType.VectorSize;
                var actualCount = roleMapped.Schema.Feature.Type.VectorSize;
                if (expectedCount != actualCount)
                {
                    throw ch.ExceptUserArg(nameof(args.TrainedModelFile),
                        "Predictor in model file expects {0} features, but data has {1} features",
                        expectedCount, actualCount);
                }

                var bindable = new TreeEnsembleFeaturizerBindableMapper(env, featurizerArgs, loaded);
                var bound = bindable.Bind(env, roleMapped.Schema);
                return new GenericScorer(env, featurizerArgs, input, bound, roleMapped.Schema);
            }
        }
 /// <summary>
 /// Private constructor used by <see cref="ApplyToData"/>: creates a scorer over <paramref name="data"/>
 /// that reuses the bindable mapper of <paramref name="transform"/>, with that transform's bindings
 /// re-applied to the schema of <paramref name="data"/>.
 /// </summary>
 /// <param name="env">Host environment passed to the base constructor and to the schema re-application.</param>
 /// <param name="transform">Existing scorer whose bindable mapper and bindings are reused.</param>
 /// <param name="data">The data view the new scorer is applied to.</param>
 private GenericScorer(IHostEnvironment env, GenericScorer transform, IDataView data)
     : base(env, data, RegistrationName, transform.Bindable)
 {
     var sourceBindings = transform._bindings;

     _bindings = sourceBindings.ApplyToSchema(env, data.Schema);

     // The output schema is taken from the freshly applied bindings.
     OutputSchema = _bindings.AsSchema;
 }