/// <summary>
/// Constructs the transformer from a <see cref="ModelLoadContext"/> and builds the
/// internal scorer over an empty data view that carries the training schema.
/// </summary>
/// <param name="env">Host environment; checked for null before the host is registered.</param>
/// <param name="ctx">Model load context handed to the base constructor.</param>
internal RankingPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RankingPredictionTransformer<TModel>)), ctx)
{
    // TrainSchema and FeatureColumn are read here after the base constructor has run;
    // presumably the base restores them from ctx - confirm against the base class.
    var roles = new RoleMappedSchema(TrainSchema, null, FeatureColumn);

    // Scorer inputs are built in the same order the arguments were originally evaluated.
    var scorerArgs = new GenericScorer.Arguments();
    var emptyInput = new EmptyDataView(Host, TrainSchema);
    var boundMapper = BindableMapper.Bind(Host, roles);

    _scorer = new GenericScorer(Host, scorerArgs, emptyInput, boundMapper, roles);
}
/// <summary>
/// Constructs the transformer from an in-memory model and builds the internal scorer
/// over an empty data view that carries <paramref name="inputSchema"/>.
/// </summary>
/// <param name="env">Host environment; checked for null before the host is registered.</param>
/// <param name="model">The model forwarded to the base constructor.</param>
/// <param name="inputSchema">Schema used for role mapping and for the empty data view.</param>
/// <param name="featureColumn">Name of the feature column used for role mapping.</param>
public RankingPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RankingPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
{
    // Only the feature column name is supplied; the second argument is passed as null.
    var roles = new RoleMappedSchema(inputSchema, null, featureColumn);

    // Scorer inputs are built in the same order the arguments were originally evaluated.
    var scorerArgs = new GenericScorer.Arguments();
    var emptyInput = new EmptyDataView(Host, inputSchema);
    var boundMapper = BindableMapper.Bind(Host, roles);

    _scorer = new GenericScorer(Host, scorerArgs, emptyInput, boundMapper, roles);
}
/// <summary>
/// Entry-point factory: wraps the predictor carried by <c>args.PredictorModel</c> in a
/// <see cref="TreeEnsembleFeaturizerBindableMapper"/> and returns a <see cref="GenericScorer"/>
/// over the data prepared by the predictor model.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="args">Entry-point arguments; <c>PredictorModel</c> must be set.</param>
/// <param name="input">Input data view; must not be null.</param>
/// <returns>The scorer transform producing the tree featurizer outputs.</returns>
public static IDataTransform CreateForEntryPoint(IHostEnvironment env, ArgumentsForEntryPoint args, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("Tree Featurizer Transform");
    host.CheckValue(args, nameof(args));
    host.CheckValue(input, nameof(input));
    host.CheckUserArg(args.PredictorModel != null, nameof(args.PredictorModel), "Please specify a predictor model.");

    using (var ch = host.Start("Create Tree Ensemble Scorer"))
    {
        var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments() { Suffix = args.Suffix };
        var predictor = args.PredictorModel.Predictor;

        ch.Trace("Prepare data");
        args.PredictorModel.PrepareData(env, input, out RoleMappedData data, out var preparedPredictor);
        ch.AssertValue(data);
        ch.Assert(predictor == preparedPredictor);

        // A calibrated predictor is unwrapped so that the feature count is validated
        // against (and the featurizer is built from) its sub-predictor.
        if (predictor is CalibratedPredictorBase calibrated)
        {
            predictor = calibrated.SubPredictor;
        }

        // The predictor is expected to be a FastTreePredictionWrapper, which implements
        // IValueMapper, so the cast below should succeed.
        var vm = predictor as IValueMapper;
        ch.CheckUserArg(vm != null, nameof(args.PredictorModel), "Predictor does not have compatible type");

        // NOTE(review): data is null-guarded here but dereferenced unconditionally below;
        // confirm whether PrepareData can ever yield null.
        if (data != null)
        {
            var expectedFeatures = vm.InputType.VectorSize;
            var actualFeatures = data.Schema.Feature.Type.VectorSize;
            if (expectedFeatures != actualFeatures)
            {
                throw ch.ExceptUserArg(nameof(args.PredictorModel),
                    "Predictor expects {0} features, but data has {1} features",
                    expectedFeatures, actualFeatures);
            }
        }

        var bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
        var bound = bindable.Bind(env, data.Schema);
        IDataTransform scorer = new GenericScorer(env, scorerArgs, data.Data, bound, data.Schema);
        ch.Done();
        return scorer;
    }
}
/// <summary>
/// Creates the tree featurizer transform, either from a predictor stored in
/// <c>args.TrainedModelFile</c> or by training one with <c>args.Trainer</c>.
/// When both are specified, the model file is used and a warning is logged.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="args">Transform arguments; a model file or a trainer, plus a feature column, are required.</param>
/// <param name="input">Input data view; must not be null.</param>
/// <returns>The resulting data transform.</returns>
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("Tree Featurizer Transform");
    host.CheckValue(args, nameof(args));
    host.CheckValue(input, nameof(input));
    host.CheckUserArg(
        !string.IsNullOrWhiteSpace(args.TrainedModelFile) || args.Trainer != null,
        nameof(args.TrainedModelFile),
        "Please specify either a trainer or an input model file.");
    host.CheckUserArg(
        !string.IsNullOrEmpty(args.FeatureColumn),
        nameof(args.FeatureColumn),
        "Transform needs an input features column");

    using (var ch = host.Start("Create Tree Ensemble Scorer"))
    {
        var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments() { Suffix = args.Suffix };

        // No model file given: train a predictor and score with it.
        if (string.IsNullOrWhiteSpace(args.TrainedModelFile))
        {
            ch.AssertValue(args.Trainer);
            ch.Trace("Creating TrainAndScoreTransform");

            var trainScoreArgs = new TrainAndScoreTransform.Arguments();
            args.CopyTo(trainScoreArgs);
            trainScoreArgs.Trainer = args.Trainer;
            trainScoreArgs.Scorer = ComponentFactoryUtils.CreateFromFunction<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>(
                (e, view, mapper, trainSchema) => Create(e, scorerArgs, view, mapper, trainSchema));

            var mapperFactory = ComponentFactoryUtils.CreateFromFunction<IPredictor, ISchemaBindableMapper>(
                (e, trained) => new TreeEnsembleFeaturizerBindableMapper(e, scorerArgs, trained));

            var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed);
            var scoreXf = TrainAndScoreTransform.Create(host, trainScoreArgs, labelInput, mapperFactory);

            // If AppendLabelTransform returned a different view, the trained chain is
            // re-applied on top of the original input.
            if (input != labelInput)
            {
                return (IDataTransform)ApplyTransformUtils.ApplyAllTransformsToData(host, scoreXf, input, labelInput);
            }

            return scoreXf;
        }

        // Model file given: it takes precedence over any trainer.
        if (args.Trainer != null)
        {
            ch.Warning("Both an input model and a trainer were specified. Using the model file.");
        }

        ch.Trace("Loading model");
        IPredictor loadedPredictor;
        using (Stream modelStream = new FileStream(args.TrainedModelFile, FileMode.Open, FileAccess.Read))
        {
            using (var repository = RepositoryReader.Open(modelStream, ch))
            {
                ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(host, out loadedPredictor, repository, ModelFileUtils.DirPredictor);
            }
        }

        ch.Trace("Creating scorer");
        var roleMapped = TrainAndScoreTransform.CreateDataFromArgs(ch, input, args);

        // A calibrated predictor is unwrapped so that the feature count is validated
        // against (and the featurizer is built from) its sub-predictor.
        if (loadedPredictor is CalibratedPredictorBase calibrated)
        {
            loadedPredictor = calibrated.SubPredictor;
        }

        // The predictor is expected to be a FastTreePredictionWrapper, which implements
        // IValueMapper, so the cast below should succeed.
        var vm = loadedPredictor as IValueMapper;
        ch.CheckUserArg(vm != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type");

        var expectedFeatures = vm.InputType.VectorSize;
        var actualFeatures = roleMapped.Schema.Feature.Type.VectorSize;
        if (expectedFeatures != actualFeatures)
        {
            throw ch.ExceptUserArg(nameof(args.TrainedModelFile),
                "Predictor in model file expects {0} features, but data has {1} features",
                expectedFeatures, actualFeatures);
        }

        var bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, loadedPredictor);
        var bound = bindable.Bind(env, roleMapped.Schema);

        // The scorer is built over the caller's input view, with the role-mapped schema.
        return new GenericScorer(env, scorerArgs, input, bound, roleMapped.Schema);
    }
}
/// <summary>
/// Constructor for <see cref="ApplyToData"/> method. Creates a scorer over
/// <paramref name="data"/> that reuses the bindable mapper of <paramref name="transform"/>
/// and re-applies that transform's bindings to the schema of the new data.
/// </summary>
/// <param name="env">Host environment passed to the base constructor and to the binding call.</param>
/// <param name="transform">Existing scorer whose bindable mapper and bindings are reused.</param>
/// <param name="data">Data view the new scorer is applied to.</param>
private GenericScorer(IHostEnvironment env, GenericScorer transform, IDataView data)
    : base(env, data, RegistrationName, transform.Bindable)
{
    var reboundBindings = transform._bindings.ApplyToSchema(env, data.Schema);
    _bindings = reboundBindings;
    OutputSchema = reboundBindings.AsSchema;
}