/// <summary>
/// Entry point that post-processes the predicted-label column of <c>input.Data</c>.
/// When the column type reports <c>IsNumber</c> or <c>IsBool</c> the data is passed through
/// <see cref="NopTransform"/>; otherwise the column is run through <see cref="KeyToValueTransform"/>.
/// </summary>
/// <param name="env">Host environment; validated with <c>Contracts.CheckValue</c>.</param>
/// <param name="input">Arguments supplying the data view and the predicted-label column name.</param>
/// <returns>A transform output holding the resulting data and a <see cref="TransformModel"/> built over it.</returns>
public static CommonOutputs.TransformOutput ConvertPredictedLabel(IHostEnvironment env, PredictedLabelInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("ConvertPredictedLabel");

            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            // Resolve the column by name; a missing column is reported through the host.
            int labelIndex;
            bool found = input.Data.Schema.TryGetColumnIndex(input.PredictedLabelColumn, out labelIndex);
            if (!found)
            {
                throw host.Except($"Column '{input.PredictedLabelColumn}' not found.");
            }

            var labelType = input.Data.Schema.GetColumnType(labelIndex);
            bool needsKeyToValue = !(labelType.IsNumber || labelType.IsBool);

            if (needsKeyToValue)
            {
                // Any other column type goes through the key-to-value conversion.
                var converted = new KeyToValueTransform(host, input.PredictedLabelColumn).Transform(input.Data);

                var convertedOutput = new CommonOutputs.TransformOutput();
                convertedOutput.Model = new TransformModel(env, converted, input.Data);
                convertedOutput.OutputData = converted;
                return convertedOutput;
            }

            // Numeric / boolean labels are left as they are.
            var passThrough = NopTransform.CreateIfNeeded(env, input.Data);

            var passThroughOutput = new CommonOutputs.TransformOutput();
            passThroughOutput.Model = new TransformModel(env, passThrough, input.Data);
            passThroughOutput.OutputData = passThrough;
            return passThroughOutput;
        }
示例#2
0
        // Scenario: train an SdcaMultiClassTrainer on a Term + Concat pipeline, then score with the
        // term transform cut out of the pipeline and check the first 20 predictions.
        void DecomposableTrainAndPredict()
        {
            // ScoreUtils.GetScorer requires scorers to be registered in the ComponentCatalog
            using (var env = new LocalEnvironment().AddStandardComponents())
            {
                var loaderArgs = MakeIrisTextLoaderArgs();
                var trainSource = new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename));
                var loader = TextLoader.ReadFile(env, loaderArgs, trainSource);

                var term = TermTransform.Create(env, loader, "Label");
                var concatenator = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");
                var concat = concatenator.Transform(term);

                var trainer = new SdcaMultiClassTrainer(env, "Features", "Label", advancedSettings: settings =>
                {
                    settings.MaxIterations = 100;
                    settings.Shuffle = true;
                    settings.NumThreads = 1;
                });

                // Wrap the training view in a cache only when the trainer asks for it.
                IDataView trainingView;
                if (trainer.Info.WantCaching)
                {
                    trainingView = (IDataView) new CacheDataView(env, concat, prefetch: null);
                }
                else
                {
                    trainingView = concat;
                }
                var trainingRoles = new RoleMappedData(trainingView, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainingRoles, trainer);
                var trainContext = new Runtime.TrainContext(trainingRoles);
                var predictor = trainer.Train(trainContext);

                var scoringRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoringRoles, env, trainingRoles.Schema);

                // Cut out term transform from pipeline.
                var termlessScorer = ApplyTransformUtils.ApplyAllTransformsToData(env, scorer, loader, term);
                var keyToValue = new KeyToValueTransform(env, "PredictedLabel").Transform(termlessScorer);
                var engine = env.CreatePredictionEngine<IrisDataNoLabel, IrisPrediction>(keyToValue);

                // Every one of the first 20 rows is expected to be predicted as "Iris-setosa".
                var rows = loader.AsEnumerable<IrisDataNoLabel>(env, false);
                foreach (var row in rows.Take(20))
                {
                    var predicted = engine.Predict(row);
                    Assert.True(predicted.PredictedLabel == "Iris-setosa");
                }
            }
        }
示例#3
0
        /// <summary>
        /// Builds a <see cref="KeyToValueTransform"/> over <paramref name="input"/> from the
        /// <c>_name</c> / <c>_source</c> fields, re-applies it via
        /// <c>ApplyTransformUtils.ApplyAllTransformsToData</c> using an <see cref="EmptyDataView"/>
        /// with the same schema, and wraps the result.
        /// </summary>
        /// <param name="input">Data view whose schema the transform is built against.</param>
        /// <returns>A <see cref="TransformWrapper"/> around the re-applied transform chain.</returns>
        public TransformWrapper Fit(IDataView input)
        {
            var keyToValue = new KeyToValueTransform(_env, input, _name, _source);

            // Schema-only stand-in for the input; presumably this detaches the transform
            // from the concrete data it was created on (confirm against ApplyTransformUtils).
            var schemaOnly = new EmptyDataView(_env, input.Schema);
            var rebased = ApplyTransformUtils.ApplyAllTransformsToData(_env, keyToValue, schemaOnly, input);

            return new TransformWrapper(_env, rebased);
        }
示例#4
0
 /// <summary>
 /// Applies a KeyToValue -> Term -> KeyToVector chain to <paramref name="viewTrain"/> for the
 /// columns listed in <paramref name="ktv"/>. When <paramref name="ktv"/> is null or empty the
 /// view is returned unchanged.
 /// </summary>
 /// <param name="ktv">Columns to convert; may be null (asserted with AssertValueOrNull).</param>
 /// <param name="viewTrain">The data view to transform; reassigned after each stage.</param>
 /// <param name="host">Host passed to each transform constructor.</param>
 /// <returns>The view after all three transforms, or the original view when there is nothing to do.</returns>
 private static IDataView ApplyKeyToVec(List <KeyToVectorTransform.Column> ktv, IDataView viewTrain, IHost host)
 {
     Contracts.AssertValueOrNull(ktv);
     Contracts.AssertValue(viewTrain);
     Contracts.AssertValue(host);
     if (Utils.Size(ktv) > 0)
     {
         // Instead of simply using KeyToVector, we are jumping through some hoops here to do the right thing in a very common case
         // when the user has slightly different key values between the training and testing set.
         // The solution is to apply KeyToValue, then Term using the terms from the key metadata of the original key column
         // and finally the KeyToVector transform.

         // Stage 1: KeyToValue, mapping each column's Source to its Name.
         viewTrain = new KeyToValueTransform(host,
                                             new KeyToValueTransform.Arguments()
         {
             Column = ktv
                      .Select(c => new KeyToValueTransform.Column()
             {
                 Name = c.Name, Source = c.Source
             })
                      .ToArray()
         },
                                             viewTrain);
         // Stage 2: Term, applied in place (Name -> Name). The Select/ToArray below is evaluated
         // while the arguments are being built, so GetTerms sees the view produced by stage 1
         // and is asked for the c.Source column of that view.
         viewTrain = new Data.TermTransform(host,
                                            new Data.TermTransform.Arguments()
         {
             Column = ktv
                      .Select(c => new Data.TermTransform.Column()
             {
                 Name = c.Name, Source = c.Name, Terms = GetTerms(viewTrain, c.Source)
             })
                      .ToArray(),
             TextKeyValues = true
         },
                                            viewTrain);
         // Stage 3: KeyToVector, applied in place (Name -> Name) with Bag disabled.
         viewTrain = new KeyToVectorTransform(host,
                                              new KeyToVectorTransform.Arguments()
         {
             Column = ktv
                      .Select(c => new KeyToVectorTransform.Column()
             {
                 Name = c.Name, Source = c.Name
             })
                      .ToArray(),
             Bag = false
         },
                                              viewTrain);
     }
     return(viewTrain);
 }
        // Scenario: insert a user-defined row mapping (LambdaTransform) ahead of the Term + Concat
        // pipeline, train an SdcaMultiClassTrainer, and check that the first 20 predictions match
        // the labels read from the same file.
        void Extensibility()
        {
            var dataPath = GetDataPath(IrisDataPath);

            using (var env = new LocalEnvironment())
            {
                var loaderArgs = MakeIrisTextLoaderArgs();
                var trainSource = new MultiFileSource(dataPath);
                var loader = TextLoader.ReadFile(env, loaderArgs, trainSource);

                // Custom mapping: copies every field, except that PetalLength is replaced by
                // SepalLength whenever SepalLength is not greater than 3.
                Action<IrisData, IrisData> copyRow = (src, dst) =>
                {
                    dst.Label = src.Label;
                    if (src.SepalLength > 3)
                    {
                        dst.PetalLength = src.PetalLength;
                    }
                    else
                    {
                        dst.PetalLength = src.SepalLength;
                    }
                    dst.PetalWidth = src.PetalWidth;
                    dst.SepalLength = src.SepalLength;
                    dst.SepalWidth = src.SepalWidth;
                };

                var lambda = LambdaTransform.CreateMap(env, loader, copyRow);
                var term = TermTransform.Create(env, lambda, "Label");
                var concatenator = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");
                var concat = concatenator.Transform(term);

                var sdcaArgs = new SdcaMultiClassTrainer.Arguments();
                sdcaArgs.MaxIterations = 100;
                sdcaArgs.Shuffle = true;
                sdcaArgs.NumThreads = 1;
                var trainer = new SdcaMultiClassTrainer(env, sdcaArgs);

                // Wrap the training view in a cache only when the trainer asks for it.
                IDataView trainingView;
                if (trainer.Info.WantCaching)
                {
                    trainingView = (IDataView) new CacheDataView(env, concat, prefetch: null);
                }
                else
                {
                    trainingView = concat;
                }
                var trainingRoles = new RoleMappedData(trainingView, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainingRoles, trainer);
                var trainContext = new Runtime.TrainContext(trainingRoles);
                var predictor = trainer.Train(trainContext);

                var scoringRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoringRoles, env, trainingRoles.Schema);

                var keyToValue = new KeyToValueTransform(env, "PredictedLabel").Transform(scorer);
                var engine = env.CreatePredictionEngine<IrisData, IrisPrediction>(keyToValue);

                // Re-read the same file and compare predictions with the recorded labels.
                var testLoader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var rows = testLoader.AsEnumerable<IrisData>(env, false);
                foreach (var row in rows.Take(20))
                {
                    var predicted = engine.Predict(row);
                    Assert.True(predicted.PredictedLabel == row.Label);
                }
            }
        }
示例#6
0
        // Scenario: train an Ova trainer whose inner predictor factory creates a
        // FastTreeBinaryClassificationTrainer, then check that the first 20 predictions match
        // the labels read from the same file.
        void Metacomponents()
        {
            var dataPath = GetDataPath(IrisDataPath);

            using (var env = new TlcEnvironment())
            {
                var loaderArgs = MakeIrisTextLoaderArgs();
                var trainSource = new MultiFileSource(dataPath);
                var loader = new TextLoader(env, loaderArgs, trainSource);

                var term = new TermTransform(env, loader, "Label");
                var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");

                // The meta-trainer receives a factory rather than a ready-made binary trainer.
                var ovaArgs = new Ova.Arguments();
                ovaArgs.PredictorType = new SimpleComponentFactory<ITrainer<IPredictorProducing<float>>>(
                    e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()));
                var trainer = new Ova(env, ovaArgs);

                // Wrap the training view in a cache only when the trainer asks for it.
                IDataView trainingView;
                if (trainer.Info.WantCaching)
                {
                    trainingView = (IDataView) new CacheDataView(env, concat, prefetch: null);
                }
                else
                {
                    trainingView = concat;
                }
                var trainingRoles = new RoleMappedData(trainingView, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainingRoles, trainer);
                var trainContext = new Runtime.TrainContext(trainingRoles);
                var predictor = trainer.Train(trainContext);

                var scoringRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoringRoles, env, trainingRoles.Schema);

                var keyToValue = new KeyToValueTransform(env, scorer, "PredictedLabel");
                var engine = env.CreatePredictionEngine<IrisData, IrisPrediction>(keyToValue);

                // Re-read the same file and compare predictions with the recorded labels.
                var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var rows = testLoader.AsEnumerable<IrisData>(env, false);
                foreach (var row in rows.Take(20))
                {
                    var predicted = engine.Predict(row);
                    Assert.True(predicted.PredictedLabel == row.Label);
                }
            }
        }
示例#7
0
        // Scenario: train an SdcaMultiClassTrainer on a Term + Concat pipeline, then score with the
        // term transform cut out of the pipeline and check the first 20 predictions.
        void DecomposableTrainAndPredict()
        {
            var dataPath = GetDataPath(IrisDataPath);

            using (var env = new TlcEnvironment())
            {
                var loaderArgs = MakeIrisTextLoaderArgs();
                var trainSource = new MultiFileSource(dataPath);
                var loader = new TextLoader(env, loaderArgs, trainSource);

                var term = new TermTransform(env, loader, "Label");
                var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");

                var sdcaArgs = new SdcaMultiClassTrainer.Arguments();
                sdcaArgs.MaxIterations = 100;
                sdcaArgs.Shuffle = true;
                sdcaArgs.NumThreads = 1;
                var trainer = new SdcaMultiClassTrainer(env, sdcaArgs);

                // Wrap the training view in a cache only when the trainer asks for it.
                IDataView trainingView;
                if (trainer.Info.WantCaching)
                {
                    trainingView = (IDataView) new CacheDataView(env, concat, prefetch: null);
                }
                else
                {
                    trainingView = concat;
                }
                var trainingRoles = new RoleMappedData(trainingView, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainingRoles, trainer);
                var trainContext = new Runtime.TrainContext(trainingRoles);
                var predictor = trainer.Train(trainContext);

                var scoringRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoringRoles, env, trainingRoles.Schema);

                // Cut out term transform from pipeline.
                var termlessScorer = ApplyTransformUtils.ApplyAllTransformsToData(env, scorer, loader, term);
                var keyToValue = new KeyToValueTransform(env, termlessScorer, "PredictedLabel");
                var engine = env.CreatePredictionEngine<IrisDataNoLabel, IrisPrediction>(keyToValue);

                // Every one of the first 20 rows is expected to be predicted as "Iris-setosa".
                var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var rows = testLoader.AsEnumerable<IrisDataNoLabel>(env, false);
                foreach (var row in rows.Take(20))
                {
                    var predicted = engine.Predict(row);
                    Assert.True(predicted.PredictedLabel == "Iris-setosa");
                }
            }
        }
示例#8
0
 /// <summary>
 /// Creates a mapper for <paramref name="parent"/> over <paramref name="inputSchema"/>.
 /// The base constructor receives a host registered under the name "Mapper" on the parent's host.
 /// </summary>
 /// <param name="parent">Owning transform. It is dereferenced in the base-constructor call, so it must not be null.</param>
 /// <param name="inputSchema">Schema handed to the base class and to <c>ComputeKvMaps</c>.</param>
 public Mapper(KeyToValueTransform parent, Schema inputSchema)
     : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
 {
     // NOTE(review): ComputeKvMaps is not visible here and may read _parent,
     // so keep this assignment ahead of the call - confirm before reordering.
     _parent = parent;
     // Populates the _types and _kvMaps fields (both passed as out arguments).
     ComputeKvMaps(inputSchema, out _types, out _kvMaps);
 }