예제 #1
0
        /// <summary>
        /// Create a TransformModel containing the given (optional) transforms applied to the
        /// given root schema.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="schemaRoot">Schema of the empty view the transforms are rooted at; must not be null.</param>
        /// <param name="xfs">Transforms to apply, in order. The array may be null or empty, but no entry may be null.</param>
        public TransformModelImpl(IHostEnvironment env, Schema schemaRoot, IDataTransform[] xfs)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(schemaRoot, nameof(schemaRoot));
            env.CheckValueOrNull(xfs);

            IDataView view = new EmptyDataView(env, schemaRoot);

            _schemaRoot = view.Schema;

            if (xfs != null)
            {
                foreach (var xf in xfs)
                {
                    // The entries come from the caller, so validate them in every build. An Assert* check is
                    // compiled out of release builds, which would let a null slip through to ApplyTransformToData.
                    env.CheckParam(xf != null, nameof(xfs), "Transforms should not be null");
                    view = ApplyTransformUtils.ApplyTransformToData(env, xf, view);
                }
            }

            _chain = view;
        }
예제 #2
0
        /// <summary>
        /// Composes this transform model on top of <paramref name="input"/>. The result is rooted at
        /// <paramref name="input"/>'s input schema and runs <paramref name="input"/>'s transforms first,
        /// then this model's transforms.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="input">Model to compose underneath this one; must not be null.</param>
        /// <returns>A new composite transform model.</returns>
        public ITransformModel Apply(IHostEnvironment env, ITransformModel input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(input, nameof(input));

            Schema schemaRoot = input.InputSchema;
            IDataView composed;

            if (input is TransformModel other)
            {
                // Both chains are concrete, so replay this model's chain directly on top of the other one.
                composed = ApplyTransformUtils.ApplyAllTransformsToData(env, _chain, other._chain);
            }
            else
            {
                // Generic path: start from an empty view over the root schema, run the input model, then this one.
                IDataView start = new EmptyDataView(env, schemaRoot);
                IDataView afterInput = input.Apply(env, start);
                composed = Apply(env, afterInput);
            }

            return new TransformModel(env, schemaRoot, composed);
        }
예제 #3
0
        /// <summary>
        /// Builds a lambda-based row filter, checks that it reports no row count, re-applies it to a
        /// fresh view of the same examples and writes the Label column of the result to disk.
        /// </summary>
        public void LambdaTransformCreate()
        {
            var mlContext  = new MLContext(seed: 42);
            var examples   = ReadBreastCancerExamples();
            var sourceView = mlContext.CreateDataView(examples);

            // Keep only examples whose label is 0. The lambda ignores its state argument; the trailing null
            // is presumably the (unused) state initializer.
            var keepLabelZero = LambdaTransform.CreateFilter<BreastCancerExample, object>(
                mlContext, sourceView, (example, state) => example.Label == 0, null);

            // The filtered view does not report a row count.
            Assert.Null(keepLabelZero.GetRowCount());

            // Re-apply the filter on top of a freshly created view of the same examples.
            var freshView = mlContext.CreateDataView(examples);
            var reapplied = ApplyTransformUtils.ApplyAllTransformsToData(mlContext, keepLabelZero, freshView);

            var saver = new TextSaver(mlContext, new TextSaver.Arguments());

            // Save only the Label column of the re-applied view.
            Assert.True(reapplied.Schema.TryGetColumnIndex("Label", out int labelCol));
            string outputPath = GetOutputPath(OutputRelativePath, "lambda-output.tsv");
            using (var stream = File.Create(outputPath))
            {
                saver.SaveData(stream, reapplied, labelCol);
            }
        }
예제 #4
0
        /// <summary>
        /// Trains a multi-class SDCA model on Iris. It then removes the term transform from the scoring
        /// pipeline so a prediction engine can run on label-less rows, and checks the first 20 predictions.
        /// </summary>
        void DecomposableTrainAndPredict()
        {
            var dataPath = GetDataPath(IrisDataPath);

            using (var env = new TlcEnvironment())
            {
                var loader  = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var term    = TermTransform.Create(env, loader, "Label");
                var concat  = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term);
                var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments
                {
                    MaxIterations = 100, Shuffle = true, NumThreads = 1
                });

                // Cache the training data only when the trainer requests it.
                IDataView trainData  = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat;
                var       trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
                var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

                var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                // Cut the term transform out of the pipeline by re-rooting the scorer's chain from `term` onto
                // the raw loader. Presumably this lets the engine consume IrisDataNoLabel rows.
                var newScorer  = ApplyTransformUtils.ApplyAllTransformsToData(env, scorer, loader, term);
                var keyToValue = new KeyToValueTransform(env, "PredictedLabel").Transform(newScorer);
                var model      = env.CreatePredictionEngine<IrisDataNoLabel, IrisPrediction>(keyToValue);

                var testLoader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var testData   = testLoader.AsEnumerable<IrisDataNoLabel>(env, false);
                foreach (var input in testData.Take(20))
                {
                    var prediction = model.Predict(input);
                    // Assert.Equal reports both the expected and the actual label on failure;
                    // Assert.True over a comparison only reports "false".
                    Assert.Equal("Iris-setosa", prediction.PredictedLabel);
                }
            }
        }
예제 #5
0
        /// <summary>
        /// Trains the predictor on <paramref name="trainSet"/> and wraps it in a scoring transformer.
        /// Training and validation data are cached when <c>TrainerInfo.WantCaching</c> is set. A min-max
        /// normalizer over the feature column is added when <c>TrainerInfo.NeedNormalization</c> is set
        /// and <c>FeaturesAreNormalized()</c> reports false.
        /// </summary>
        /// <param name="trainSet">Training data, read through <c>_labelCol</c> and <c>_featureCol</c>.</param>
        /// <param name="validationSet">Optional validation data; it goes through the same normalizer as the training data.</param>
        /// <param name="initPredictor">Optional predictor forwarded to <c>TrainCore</c> through the <c>TrainContext</c>.</param>
        /// <returns>
        /// The result of <c>MakeScorer</c>. It is built over an empty view with the training schema, plus the
        /// normalizer if one was added.
        /// </returns>
        protected TTransformer TrainTransformer(IDataView trainSet,
                                                IDataView validationSet = null, IPredictor initPredictor = null)
        {
            IDataView trainView = TrainerInfo.WantCaching ? new CacheDataView(_env, trainSet, prefetch: null) : trainSet;
            var trainRoles = new RoleMappedData(trainView, label: _labelCol, feature: _featureCol);

            // The scoring pipeline starts from an empty view that only carries the training schema.
            var schemaOnlyView = new EmptyDataView(_env, trainSet.Schema);
            IDataView scoringPipeline = schemaOnlyView;

            bool mustNormalize = TrainerInfo.NeedNormalization && trainRoles.Schema.FeaturesAreNormalized() == false;
            if (mustNormalize)
            {
                var normalizedTrain = NormalizeTransform.CreateMinMaxNormalizer(_env, trainRoles.Data, name: trainRoles.Schema.Feature.Name);
                // Replay the fitted normalizer on the schema-only view so scoring normalizes the same way.
                scoringPipeline = ApplyTransformUtils.ApplyAllTransformsToData(_env, normalizedTrain, schemaOnlyView, trainView);

                trainRoles = new RoleMappedData(normalizedTrain, trainRoles.Schema.GetColumnRoleNames());
            }

            RoleMappedData validRoles = null;
            if (validationSet != null)
            {
                IDataView validView = TrainerInfo.WantCaching ? new CacheDataView(_env, validationSet, prefetch: null) : validationSet;
                IDataView normalizedValid = ApplyTransformUtils.ApplyAllTransformsToData(_env, scoringPipeline, validView);
                validRoles = new RoleMappedData(normalizedValid, label: _labelCol, feature: _featureCol);
            }

            var trained = TrainCore(new TrainContext(trainRoles, validRoles, initPredictor));
            return MakeScorer(trained, new RoleMappedData(scoringPipeline, label: _labelCol, feature: _featureCol));
        }
        /// <summary>
        /// Writes the pipeline and its predictor to <paramref name="fs"/>.
        /// </summary>
        /// <param name="fs">Destination stream; it must already be open.</param>
        /// <param name="removeFirstTransform">
        /// When true, the first transform must be a <see cref="PassThroughTransform"/>. It is dropped, and the
        /// rest of the pipeline is re-rooted on that transform's source before saving.
        /// </param>
        public void Save(Stream fs, bool removeFirstTransform = false)
        {
            RoleMappedData roleMap;

            if (!removeFirstTransform)
            {
                roleMap = _predictor.roleMapData;
            }
            else
            {
                var first = _transforms.First().transform;
                if (!(first is PassThroughTransform passThrough))
                {
                    throw Contracts.ExceptNotSupp($"The first transform should be of type PassThroughTransform.");
                }

                // Replace the pass-through with the view it wraps, then replay the remaining transforms on top of it.
                var wrappedSource = passThrough.Source;
                var tail          = _transforms.Last().transform;
                var rerooted      = ApplyTransformUtils.ApplyAllTransformsToData(_env, tail, wrappedSource, first);
                var roleNames     = _predictor.roleMapData.Schema.GetColumnRoleNames().ToArray();
                roleMap = new RoleMappedData(rerooted, roleNames);
            }

            using (var ch = _env.Start("Save Predictor"))
            {
                TrainUtils.SaveModel(_env, ch, fs, _predictor.predictor, roleMap);
            }
        }
예제 #7
0
        /// <summary>
        /// Builds a ValueMapper that pushes the rows of a source DataFrame through <c>_transform</c>
        /// and writes the transformed rows into the destination DataFrame.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the end of this method (the closing braces of the <c>using</c> block and of the
        /// method itself) is missing from this excerpt; restore it before compiling.
        /// </remarks>
        ValueMapper <DataFrame, DataFrame> GetMapperRow()
        {
            // Root of the chain to re-apply: the explicit replacement source if set, otherwise the first view of _transform.
            var firstView = _sourceToReplace ?? DataViewHelper.GetFirstView(_transform);
            // NOTE(review): `schema` is never read below; firstView.Schema is used directly instead.
            var schema    = firstView.Schema;

            // Input view fed from the DataFrame passed to the mapper (see inputView.Set(src) below).
            var inputView = new InfiniteLoopViewCursorDataFrame(null, firstView.Schema);

            // This is extremely time consuming as the transform is serialized and deserialized.
            // If the replaced source is _transform's direct source, only _transform is re-applied;
            // otherwise the whole chain between _sourceToReplace and _transform is replayed on inputView.
            var outputView = _sourceToReplace == _transform.Source
                                ? ApplyTransformUtils.ApplyTransformToData(_computeEnv, _transform, inputView)
                                : ApplyTransformUtils.ApplyAllTransformsToData(_computeEnv, _transform, inputView, _sourceToReplace);

            // We assume all columns are needed, otherwise they should be removed.
            // NOTE(review): `cur` is captured by the returned delegate, but the `using` disposes it as soon as
            // this method returns, before the delegate ever runs. Confirm the cursor is still usable after
            // disposal, or move its lifetime out of the `using`.
            using (var cur = outputView.GetRowCursor(i => true))
            {
                var getRowFiller = DataFrame.GetRowFiller(cur);

                return((in DataFrame src, ref DataFrame dst) =>
                {
                    // Allocate the destination on first use; otherwise require it to match the output schema.
                    if (dst is null)
                    {
                        dst = new DataFrame(outputView.Schema, src.Length);
                    }
                    else if (!dst.CheckSharedSchema(outputView.Schema))
                    {
                        throw _env.Except($"DataFrame does not share the same schema, expected {SchemaHelper.ToString(outputView.Schema)}.");
                    }
                    dst.Resize(src.Length);

                    // Point the input view at src, then pull one transformed row per source row into dst.
                    inputView.Set(src);
                    for (int i = 0; i < src.Length; ++i)
                    {
                        // NOTE(review): MoveNext's result is ignored; this presumably relies on the infinite-loop
                        // input view always producing a row. Verify.
                        cur.MoveNext();
                        getRowFiller(dst, i);
                    }
                });
        /// <summary>
        /// Factory method for SignatureDataTransform. Builds the tree featurizer over <paramref name="input"/>
        /// in one of two ways: from a previously trained model file (<c>args.TrainedModelFile</c>), or by
        /// training a new model with <c>args.Trainer</c>. If both are given, the model file is used and a
        /// warning is logged.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">
        /// Transform arguments; must not be null. Either <c>TrainedModelFile</c> or <c>Trainer</c> must be set,
        /// and <c>FeatureColumn</c> must be non-empty.
        /// </param>
        /// <param name="input">Data to featurize; must not be null.</param>
        /// <returns>The featurizing transform.</returns>
        private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("Tree Featurizer Transform");

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));
            host.CheckUserArg(!string.IsNullOrWhiteSpace(args.TrainedModelFile) || args.Trainer != null, nameof(args.TrainedModelFile),
                              "Please specify either a trainer or an input model file.");
            host.CheckUserArg(!string.IsNullOrEmpty(args.FeatureColumn), nameof(args.FeatureColumn), "Transform needs an input features column");

            // Assigned only on the model-file path; the trainer path returns from inside the else branch.
            IDataTransform xf;

            using (var ch = host.Start("Create Tree Ensemble Scorer"))
            {
                // Mapper/scorer arguments shared by both paths; only Suffix is set here.
                var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments()
                {
                    Suffix = args.Suffix
                };
                if (!string.IsNullOrWhiteSpace(args.TrainedModelFile))
                {
                    // Path 1: featurize with a predictor loaded from the model file.
                    if (args.Trainer != null)
                    {
                        ch.Warning("Both an input model and a trainer were specified. Using the model file.");
                    }

                    ch.Trace("Loading model");
                    IPredictor predictor;
                    using (Stream strm = new FileStream(args.TrainedModelFile, FileMode.Open, FileAccess.Read))
                        using (var rep = RepositoryReader.Open(strm, ch))
                            ModelLoadContext.LoadModel <IPredictor, SignatureLoadModel>(host, out predictor, rep, ModelFileUtils.DirPredictor);

                    ch.Trace("Creating scorer");
                    var data = TrainAndScoreTransformer.CreateDataFromArgs(ch, input, args);
                    Contracts.Assert(data.Schema.Feature.HasValue);

                    // Make sure that the given predictor has the correct number of input features.
                    // A calibrated predictor is first unwrapped so the checks below see the underlying model.
                    if (predictor is CalibratedPredictorBase)
                    {
                        predictor = ((CalibratedPredictorBase)predictor).SubPredictor;
                    }
                    // Predictor should be a TreeEnsembleModelParameters, which implements IValueMapper, so this should
                    // be non-null.
                    var vm = predictor as IValueMapper;
                    ch.CheckUserArg(vm != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type");
                    if (vm.InputType.VectorSize != data.Schema.Feature.Value.Type.VectorSize)
                    {
                        throw ch.ExceptUserArg(nameof(args.TrainedModelFile),
                                               "Predictor in model file expects {0} features, but data has {1} features",
                                               vm.InputType.VectorSize, data.Schema.Feature.Value.Type.VectorSize);
                    }

                    // Bind the featurizer mapper to the role-mapped schema and wrap it in a scorer over the original input.
                    ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
                    var bound = bindable.Bind(env, data.Schema);
                    xf = new GenericScorer(env, scorerArgs, input, bound, data.Schema);
                }
                else
                {
                    // Path 2: train a new model with args.Trainer and score it through the featurizer mapper.
                    ch.AssertValue(args.Trainer);

                    ch.Trace("Creating TrainAndScoreTransform");

                    var trainScoreArgs = new TrainAndScoreTransformer.Arguments();
                    args.CopyTo(trainScoreArgs);
                    trainScoreArgs.Trainer = args.Trainer;

                    trainScoreArgs.Scorer = ComponentFactoryUtils.CreateFromFunction <IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>(
                        (e, data, mapper, trainSchema) => Create(e, scorerArgs, data, mapper, trainSchema));

                    var mapperFactory = ComponentFactoryUtils.CreateFromFunction <IPredictor, ISchemaBindableMapper>(
                        (e, predictor) => new TreeEnsembleFeaturizerBindableMapper(e, scorerArgs, predictor));

                    // AppendLabelTransform may hand back `input` itself; that case is checked below.
                    var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed);
                    var scoreXf    = TrainAndScoreTransformer.Create(host, trainScoreArgs, labelInput, mapperFactory);

                    if (input == labelInput)
                    {
                        return(scoreXf);
                    }
                    // A label transform was added for training. Re-root the scorer on the original input in place of
                    // `labelInput`, so the returned pipeline does not include that label transform.
                    // NOTE(review): the cast throws InvalidCastException if the re-rooted view is not an IDataTransform; confirm.
                    return((IDataTransform)ApplyTransformUtils.ApplyAllTransformsToData(host, scoreXf, input, labelInput));
                }
            }
            return(xf);
        }
예제 #9
0
 /// <summary>
 /// Re-applies the transform chain ending at <c>_xf</c> on top of <paramref name="input"/>.
 /// </summary>
 /// <param name="input">The view to use as the new source of the chain.</param>
 /// <returns>The transformed view.</returns>
 public IDataView Transform(IDataView input)
 {
     return ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input);
 }
예제 #10
0
        /// <summary>
        /// Runs 5-fold cross-validation of a linear binary classifier over the sentiment data set,
        /// evaluating each held-out fold with the binary-classification evaluator.
        /// </summary>
        void CrossValidation()
        {
            var dataPath = GetDataPath(SentimentDataPath);

            const int numFolds = 5;

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline.
                var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));

                var text = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader);
                // Add a generated numeric column, "StratificationColumn"; FoldView splits it into folds.
                IDataView trans = new GenerateNumberTransform(env, text, "StratificationColumn");

                // Train.
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads           = 1,
                    ConvergenceTolerance = 1f
                });

                // Returns the rows of `trans` whose StratificationColumn lies in the range of fold `fold`
                // (complement == false), or all other rows (complement == true). The result is wrapped in an
                // OpaqueDataView.
                IDataView FoldView(int fold, bool complement)
                {
                    IDataView filtered = new RangeFilter(env, new RangeFilter.Arguments()
                    {
                        Column     = "StratificationColumn",
                        Min        = (double)fold / numFolds,
                        Max        = (double)(fold + 1) / numFolds,
                        Complement = complement
                    }, trans);
                    return new OpaqueDataView(filtered);
                }

                // NOTE(review): per-fold metrics are collected but never asserted on.
                var metrics = new List<BinaryClassificationMetrics>();
                for (int fold = 0; fold < numFolds; fold++)
                {
                    // Train on every fold except `fold`.
                    IDataView trainPipe = FoldView(fold, complement: true);
                    var trainData = new RoleMappedData(trainPipe, label: "Label", feature: "Features");
                    // Auto-normalization.
                    NormalizeTransform.CreateIfNeeded(env, ref trainData, trainer);
                    // Keep the uncached (possibly normalized) data: its transforms are replayed on the test fold below.
                    var preCachedData = trainData;
                    // Auto-caching.
                    if (trainer.Info.WantCaching)
                    {
                        var prefetch  = trainData.Schema.GetColumnRoles().Select(kc => kc.Value.Index).ToArray();
                        var cacheView = new CacheDataView(env, trainData.Data, prefetch);
                        // Because the prefetching worked, we know that these are valid columns.
                        trainData = new RoleMappedData(cacheView, trainData.Schema.GetColumnRoleNames());
                    }

                    var predictor = trainer.Train(new Runtime.TrainContext(trainData));

                    // Test on `fold` only, replaying the transforms added on top of trainPipe (e.g. normalization).
                    IDataView testPipe = FoldView(fold, complement: false);
                    var pipe = ApplyTransformUtils.ApplyAllTransformsToData(env, preCachedData.Data, testPipe, trainPipe);

                    var testRoles = new RoleMappedData(pipe, trainData.Schema.GetColumnRoleNames());

                    IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, testRoles, env, testRoles.Schema);

                    var eval        = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments());
                    var dataEval    = new RoleMappedData(scorer, testRoles.Schema.GetColumnRoleNames(), opt: true);
                    var dict        = eval.Evaluate(dataEval);
                    var foldMetrics = BinaryClassificationMetrics.FromMetrics(env, dict["OverallMetrics"], dict["ConfusionMatrix"]);
                    metrics.Add(foldMetrics.Single());
                }
            }
        }
예제 #11
0
 /// <summary>
 /// Runs this model's transform chain over <paramref name="input"/>.
 /// </summary>
 /// <param name="env">Host environment; must not be null.</param>
 /// <param name="input">Data to transform; must not be null.</param>
 /// <returns>
 /// The view produced by <c>ApplyTransformUtils.ApplyAllTransformsToData</c> with <c>_chain</c> re-rooted
 /// on <paramref name="input"/>.
 /// </returns>
 public IDataView Apply(IHostEnvironment env, IDataView input)
 {
     Contracts.CheckValue(env, nameof(env));
     env.CheckValue(input, nameof(input));

     IDataView transformed = ApplyTransformUtils.ApplyAllTransformsToData(env, _chain, input);
     return transformed;
 }