Beispiel #1
0
        /// <summary>
        /// Builds a single <see cref="PredictorModel"/> out of the predictors carried by
        /// <c>input.ModelArray</c>, by passing them to <c>Create</c>.
        /// </summary>
        /// <param name="env">The host environment. Must not be null.</param>
        /// <param name="input">The models to combine, plus the training data and column names. Must not be null and must contain at least one model.</param>
        /// <returns>An output object whose <c>PredictorModel</c> wraps the combined predictor.</returns>
        public static ModelOperations.PredictorModelOutput CombineOvaModels(IHostEnvironment env, ModelOperations.CombineOvaPredictorModelsInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("CombineOvaModels");

            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
            host.CheckNonEmpty(input.ModelArray, nameof(input.ModelArray));

            // Normalization arguably belongs in macro expansion. However, we receive a subgraph
            // rather than a learner here, which makes it tricky to decide whether a normalization
            // node is needed; everywhere else in the code that responsibility is left to TransformModel.
            // So the first model's transform is applied to the training data to get the view to bind against.
            var normalizedView = input.ModelArray[0].TransformModel.Apply(host, input.TrainingData);

            using (var ch = host.Start("CombineOvaModels"))
            {
                ISchema viewSchema = normalizedView.Schema;

                // Resolve each role's column name: the user-supplied name, else the default, else null.
                var labelName = TrainUtils.MatchNameOrDefaultOrNull(
                    ch, viewSchema, nameof(input.LabelColumn), input.LabelColumn, DefaultColumnNames.Label);
                var featureName = TrainUtils.MatchNameOrDefaultOrNull(
                    ch, viewSchema, nameof(input.FeatureColumn), input.FeatureColumn, DefaultColumnNames.Features);
                var weightName = TrainUtils.MatchNameOrDefaultOrNull(
                    ch, viewSchema, nameof(input.WeightColumn), input.WeightColumn, DefaultColumnNames.Weight);

                // No group column is bound here.
                var examples = TrainUtils.CreateExamples(normalizedView, labelName, featureName, null, weightName);

                // Entries whose predictor is not an IPredictorProducing<float> become null in this array.
                var predictors = input.ModelArray
                    .Select(model => model.Predictor as IPredictorProducing<float>)
                    .ToArray();
                var combined = Create(host, input.UseProbabilities, predictors);

                var output = new ModelOperations.PredictorModelOutput();
                output.PredictorModel = new PredictorModel(env, examples, input.TrainingData, combined);
                return output;
            }
        }
        /// <summary>
        /// Generate training examples for training a predictor or instantiating a scorer.
        /// This is a thin wrapper that validates its arguments and forwards them to
        /// <c>TrainUtils.CreateExamples</c> with no example name.
        /// </summary>
        /// <param name="env">The host environment to use. Must not be null.</param>
        /// <param name="data">The data to use for training or scoring. Must not be null.</param>
        /// <param name="features">The name of the features column. Can be null.</param>
        /// <param name="label">The name of the label column. Can be null.</param>
        /// <param name="group">The name of the group ID column (for ranking). Can be null.</param>
        /// <param name="weight">The name of the weight column. Can be null.</param>
        /// <param name="custom">Additional column mapping to be passed to the trainer or scorer (specific to the prediction type). Can be null or empty.</param>
        /// <returns>The constructed examples.</returns>
        public static RoleMappedData CreateExamples(this IHostEnvironment env, IDataView data, string features, string label = null,
                                                    string group = null, string weight = null, IEnumerable <KeyValuePair <RoleMappedSchema.ColumnRole, string> > custom = null)
        {
            Contracts.CheckValue(env, nameof(env));
            // The data view is the only required argument besides the environment; fail fast here
            // rather than deeper inside the role-mapping code.
            env.CheckValue(data, nameof(data));
            env.CheckValueOrNull(label);
            env.CheckValueOrNull(features);
            env.CheckValueOrNull(group);
            env.CheckValueOrNull(weight);
            env.CheckValueOrNull(custom);

            // Note the argument order: TrainUtils takes the label before the features.
            return(TrainUtils.CreateExamples(data, label, features, group, weight, name: null, custom: custom));
        }
        /// <summary>
        /// Saves the predictor and its training pipeline to an in-memory stream, reloads the loader
        /// from that stream against <paramref name="testDataPath"/>, and returns a scorer bound to
        /// the reloaded data.
        /// </summary>
        /// <param name="env">The host environment used for the channel, saving, loading and scoring.</param>
        /// <param name="transforms">The training pipeline whose "Label" and "Features" columns are bound.</param>
        /// <param name="pred">The trained predictor to save and score with.</param>
        /// <param name="testDataPath">Path handed to <c>MultiFileSource</c> for the reloaded loader.</param>
        /// <returns>The scorer produced by <c>ScoreUtils.GetScorer</c>.</returns>
        private IDataScorerTransform GetScorer(IHostEnvironment env, IDataView transforms, IPredictor pred, string testDataPath = null)
        {
            using (var ch = env.Start("Saving model"))
            {
                using (var modelStream = new MemoryStream())
                {
                    var trainExamples = TrainUtils.CreateExamples(transforms, label: "Label", feature: "Features");

                    // Model cannot be saved with CacheDataView
                    TrainUtils.SaveModel(env, ch, modelStream, pred, trainExamples);

                    // Rewind so the repository reader starts at the beginning of what was just written.
                    modelStream.Seek(0, SeekOrigin.Begin);

                    using (var repository = RepositoryReader.Open(modelStream, ch))
                    {
                        var testSource = new MultiFileSource(testDataPath);
                        IDataLoader testLoader = ModelFileUtils.LoadLoader(env, repository, testSource, true);
                        RoleMappedData testExamples = TrainUtils.CreateExamples(testLoader, label: "Label", feature: "Features");

                        IDataScorerTransform scorer = ScoreUtils.GetScorer(pred, testExamples, env, testExamples.Schema);
                        return scorer;
                    }
                }
            }
        }
        /// <summary>
        /// Trains a fast-forest regression predictor on the previous runs. Each run contributes one row:
        /// its parameter set converted to a float array as the features, and its metric value as the label.
        /// </summary>
        /// <param name="previousRuns">The runs to learn from. Every element must be a <c>RunResult</c>.</param>
        /// <returns>The predictor created by the trainer after training on those rows.</returns>
        private FastForestRegressionPredictor FitModel(IEnumerable <IRunResult> previousRuns)
        {
            // Materialize once: the sequence used to be walked three times (two Count() calls plus the
            // loop), which is wasteful for lazy sequences and lets the array sizes disagree with the
            // number of items actually iterated if the sequence changes between passes.
            IRunResult[] runs = previousRuns.ToArray();

            Single[]   targets  = new Single[runs.Length];
            Single[][] features = new Single[runs.Length][];

            for (int i = 0; i < runs.Length; i++)
            {
                // Same explicit cast the original foreach performed on each element.
                RunResult r = (RunResult)runs[i];
                features[i] = SweeperProbabilityUtils.ParameterSetAsFloatArray(_host, _sweepParameters, r.ParameterSet, true);
                targets[i]  = (Float)r.MetricValue;
            }

            ArrayDataViewBuilder dvBuilder = new ArrayDataViewBuilder(_host);

            dvBuilder.AddColumn("Label", NumberType.Float, targets);
            dvBuilder.AddColumn("Features", NumberType.Float, features);

            IDataView view = dvBuilder.GetDataView();

            _host.Assert(view.GetRowCount() == targets.Length, "This data view will have as many rows as there have been evaluations");
            RoleMappedData data = TrainUtils.CreateExamples(view, "Label", "Features");

            using (IChannel ch = _host.Start("Single training"))
            {
                // Set relevant random forest arguments.
                FastForestRegression.Arguments args = new FastForestRegression.Arguments();
                args.FeatureFraction     = _args.SplitRatio;
                args.NumTrees            = _args.NumOfTrees;
                args.MinDocumentsInLeafs = _args.NMinForSplit;

                // Train random forest.
                FastForestRegression trainer = new FastForestRegression(_host, args);
                trainer.Train(data);
                FastForestRegressionPredictor predictor = trainer.CreatePredictor();

                // Return random forest predictor.
                ch.Done();
                return(predictor);
            }
        }
        /// <summary>
        /// Binds the label, feature and weight columns of <c>input.TrainingData</c> and returns the
        /// class count reported by <c>CheckMultiClassLabel</c>.
        /// </summary>
        /// <param name="env">The host environment used to register a host for this operation.</param>
        /// <param name="input">The arguments carrying the training data and the column names.</param>
        /// <param name="label">Receives the resolved label column name (user-supplied, default, or null).</param>
        /// <returns>The number of classes reported for the label column.</returns>
        private static int GetNumberOfClasses(IHostEnvironment env, Arguments input, out string label)
        {
            var host = env.Register("OVA Macro GetNumberOfClasses");

            using (var ch = host.Start("OVA Macro GetNumberOfClasses"))
            {
                ISchema trainingSchema = input.TrainingData.Schema;

                // Resolve the column names needed to build the role-mapped data.
                label = TrainUtils.MatchNameOrDefaultOrNull(
                    ch, trainingSchema, nameof(Arguments.LabelColumn), input.LabelColumn, DefaultColumnNames.Label);
                var featureName = TrainUtils.MatchNameOrDefaultOrNull(
                    ch, trainingSchema, nameof(Arguments.FeatureColumn), input.FeatureColumn, DefaultColumnNames.Features);
                var weightName = TrainUtils.MatchNameOrDefaultOrNull(
                    ch, trainingSchema, nameof(Arguments.WeightColumn), input.WeightColumn, DefaultColumnNames.Weight);

                // Bind the roles (no group column), then ask the bound data for its class count.
                var examples = TrainUtils.CreateExamples(input.TrainingData, label, featureName, null, weightName);
                examples.CheckMultiClassLabel(out var classCount);

                return classCount;
            }
        }
        /// <summary>
        /// Builds the iris pipeline by instantiating each component directly (loader, concat, min-max
        /// normalizer, cache, SDCA multiclass trainer), then checks metrics, predictions and the first
        /// entry of the predictor summary.
        /// </summary>
        public void TrainAndPredictIrisModelUsingDirectInstantiationTest()
        {
            // The same file serves as both training and test data.
            string dataPath     = GetDataPath("iris.txt");
            string testDataPath = dataPath;

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline: column 0 is the label, columns 1-4 are the four measurements.
                var loaderArgs = new TextLoader.Arguments()
                {
                    HasHeader = false,
                    Column    = new[]
                    {
                        new TextLoader.Column()
                        {
                            Name   = "Label",
                            Source = new [] { new TextLoader.Range() { Min = 0, Max = 0 } },
                            Type   = DataKind.R4
                        },
                        new TextLoader.Column()
                        {
                            Name   = "SepalLength",
                            Source = new [] { new TextLoader.Range() { Min = 1, Max = 1 } },
                            Type   = DataKind.R4
                        },
                        new TextLoader.Column()
                        {
                            Name   = "SepalWidth",
                            Source = new [] { new TextLoader.Range() { Min = 2, Max = 2 } },
                            Type   = DataKind.R4
                        },
                        new TextLoader.Column()
                        {
                            Name   = "PetalLength",
                            Source = new [] { new TextLoader.Range() { Min = 3, Max = 3 } },
                            Type   = DataKind.R4
                        },
                        new TextLoader.Column()
                        {
                            Name   = "PetalWidth",
                            Source = new [] { new TextLoader.Range() { Min = 4, Max = 4 } },
                            Type   = DataKind.R4
                        }
                    }
                };
                var loader = new TextLoader(env, loaderArgs, new MultiFileSource(dataPath));

                IDataTransform concatenated = new ConcatTransform(
                    env, loader, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");

                // Normalizer is not automatically added though the trainer has 'NormalizeFeatures' On/Auto
                IDataTransform trans = NormalizeTransform.CreateMinMaxNormalizer(env, concatenated, "Features");

                // Train
                var trainerArgs = new SdcaMultiClassTrainer.Arguments();
                var trainer     = new SdcaMultiClassTrainer(env, trainerArgs);

                // Explicitly adding CacheDataView since caching is not working though trainer has 'Caching' On/Auto
                var cachedView = new CacheDataView(env, trans, prefetch: null);
                var trainRoles = TrainUtils.CreateExamples(cachedView, label: "Label", feature: "Features");
                trainer.Train(trainRoles);

                // Get scorer and evaluate the predictions from test data.
                // The uncached pipeline is passed on purpose: GetScorer saves the model from it.
                var pred = trainer.CreatePredictor();
                IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
                var metrics = Evaluate(env, testDataScorer);
                CompareMatrics(metrics);

                // Create prediction engine and test predictions
                var engine = env.CreatePredictionEngine<IrisData, IrisPrediction>(testDataScorer);
                ComparePredictions(engine);

                // Get feature importance i.e. weight vector
                var typedPredictor = (MulticlassLogisticRegressionPredictor)pred;
                var summary        = typedPredictor.GetSummaryInKeyValuePairs(trainRoles.Schema);
                Assert.Equal(7.757867, Convert.ToDouble(summary[0].Value), 5);
            }
        }
Beispiel #7
0
        /// <summary>
        /// Trains a predictor on <c>input.TrainingData</c> and wraps it in a new <typeparamref name="TOut"/>.
        /// The steps are: resolve column names, create the trainer, add a normalizer if needed, bind the
        /// column roles, optionally cache the data according to <c>input.Caching</c>, then train.
        /// </summary>
        /// <param name="host">The host used for the channel, normalization, caching and training.</param>
        /// <param name="input">The learner input carrying the training data, feature column, normalization and caching options.</param>
        /// <param name="createTrainer">Factory for the trainer to run.</param>
        /// <param name="getLabel">Optional provider of the label column name; when null no label is bound.</param>
        /// <param name="getWeight">Optional provider of the weight column name; when null no weight is bound.</param>
        /// <param name="getGroup">Optional provider of the group column name; when null no group is bound.</param>
        /// <param name="getName">Optional provider of the name column; when null no name is bound.</param>
        /// <param name="getCustom">Optional provider of additional role-to-column mappings.</param>
        /// <param name="calibrator">Calibrator trainer factory forwarded to <c>TrainUtils.Train</c>.</param>
        /// <param name="maxCalibrationExamples">Calibration example limit forwarded to <c>TrainUtils.Train</c>.</param>
        /// <returns>A new <typeparamref name="TOut"/> whose <c>PredictorModel</c> wraps the trained predictor.</returns>
        public static TOut Train<TArg, TOut>(IHost host, TArg input,
                                             Func<ITrainer> createTrainer,
                                             Func<string> getLabel  = null,
                                             Func<string> getWeight = null,
                                             Func<string> getGroup  = null,
                                             Func<string> getName   = null,
                                             Func<IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>>> getCustom = null,
                                             ICalibratorTrainerFactory calibrator = null,
                                             int maxCalibrationExamples           = 0)
            where TArg : LearnerInputBase
            where TOut : CommonOutputs.TrainerOutput, new()
        {
            using (var ch = host.Start("Training"))
            {
                // Resolve column names first, in this order, and only then create the trainer.
                ISchema trainingSchema = input.TrainingData.Schema;
                var     featureName    = FindColumn(ch, trainingSchema, input.FeatureColumn);
                string  labelName      = getLabel?.Invoke();
                string  weightName     = getWeight?.Invoke();
                string  groupName      = getGroup?.Invoke();
                string  nameColumn     = getName?.Invoke();
                var     customRoles    = getCustom?.Invoke();

                var trainer = createTrainer();

                // AddNormalizerIfNeeded may replace the view it is handed.
                IDataView trainingView = input.TrainingData;
                TrainUtils.AddNormalizerIfNeeded(host, ch, trainer, ref trainingView, featureName, input.NormalizeFeatures);

                ch.Trace("Binding columns");
                var boundExamples = TrainUtils.CreateExamples(
                    trainingView, labelName, featureName, groupName, weightName, nameColumn, customRoles);

                // Translate the requested caching option into a concrete cache type (null = no caching).
                Cache.CachingType? cachingType = null;
                switch (input.Caching)
                {
                case CachingOptions.Memory:
                    cachingType = Cache.CachingType.Memory;
                    break;

                case CachingOptions.Disk:
                    cachingType = Cache.CachingType.Disk;
                    break;

                case CachingOptions.Auto:
                    // REVIEW: we should switch to hybrid caching in future.
                    if (!(input.TrainingData is BinaryLoader))
                    {
                        var trainerEx = trainer as ITrainerEx;
                        if (trainerEx == null || trainerEx.WantCaching)
                        {
                            // default to Memory so mml is on par with maml
                            cachingType = Cache.CachingType.Memory;
                        }
                    }
                    break;

                case CachingOptions.None:
                    break;

                default:
                    throw ch.ExceptParam(nameof(input.Caching), "Unknown option for caching: '{0}'", input.Caching);
                }

                // Train on the cached data when caching applies; otherwise on the bound data directly.
                RoleMappedData trainingExamples = boundExamples;
                if (cachingType != null)
                {
                    var cacheInput = new Cache.CacheInput()
                    {
                        Data    = boundExamples.Data,
                        Caching = cachingType.Value
                    };
                    var cachedView = Cache.CacheData(host, cacheInput).OutputData;
                    trainingExamples = RoleMappedData.Create(cachedView, boundExamples.Schema.GetColumnRoleNames());
                }

                var predictor = TrainUtils.Train(host, ch, trainingExamples, trainer, "Train", calibrator, maxCalibrationExamples);

                // The model is built from the uncached bound data and the original training data.
                var result = new TOut();
                result.PredictorModel = new PredictorModel(host, boundExamples, input.TrainingData, predictor);

                ch.Done();
                return result;
            }
        }
        /// <summary>
        /// Builds the sentiment pipeline by instantiating each component directly (text loader, text
        /// featurizer, fast-tree binary trainer), then checks binary metrics, two batch predictions and
        /// the first entry of the predictor summary.
        /// </summary>
        public void TrainAndPredictSentimentModelWithDirectionInstantiationTest()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loader = new TextLoader(env,
                                            new TextLoader.Arguments()
                {
                    Separator = "tab",
                    HasHeader = true,
                    Column    = new[]
                    {
                        new TextLoader.Column()
                        {
                            Name   = "Label",
                            Source = new [] { new TextLoader.Range()
                                              {
                                                  Min = 0, Max = 0
                                              } },
                            Type = DataKind.Num
                        },

                        new TextLoader.Column()
                        {
                            Name   = "SentimentText",
                            Source = new [] { new TextLoader.Range()
                                              {
                                                  Min = 1, Max = 1
                                              } },
                            Type = DataKind.Text
                        }
                    }
                }, new MultiFileSource(dataPath));

                var trans = TextTransform.Create(env, new TextTransform.Arguments()
                {
                    Column = new TextTransform.Column
                    {
                        Name   = "Features",
                        Source = new[] { "SentimentText" }
                    },
                    KeepDiacritics       = false,
                    KeepPunctuations     = false,
                    TextCase             = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower,
                    OutputTokens         = true,
                    StopWordsRemover     = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(),
                    VectorNormalizer     = TextTransform.TextNormKind.L2,
                    CharFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments()
                    {
                        NgramLength = 3, AllLengths = false
                    },
                    WordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments()
                    {
                        NgramLength = 2, AllLengths = true
                    },
                },
                                                 loader);

                // Train
                var trainer = new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments()
                {
                    NumLeaves           = 5,
                    NumTrees            = 5,
                    MinDocumentsInLeafs = 2
                });

                var trainRoles = TrainUtils.CreateExamples(trans, label: "Label", feature: "Features");
                trainer.Train(trainRoles);

                // Get scorer and evaluate the predictions from test data
                var pred = trainer.CreatePredictor();
                IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
                var metrics = EvaluateBinary(env, testDataScorer);
                ValidateBinaryMetrics(metrics);

                // Create prediction engine and test predictions
                var model      = env.CreateBatchPredictionEngine <SentimentData, SentimentPrediction>(testDataScorer);
                var sentiments = GetTestData();

                // Materialize the predictions once. The assertions below previously enumerated the
                // sequence three times (Count, ElementAt(0), ElementAt(1)), re-running prediction each time.
                var predictions = model.Predict(sentiments, false).ToList();
                Assert.Equal(2, predictions.Count);
                Assert.True(predictions[0].Sentiment.IsFalse);
                Assert.True(predictions[1].Sentiment.IsTrue);

                // Get feature importance based on feature gain during training
                var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema);
                Assert.Equal(1.0, (double)summary[0].Value, 1);
            }
        }