Example #1
        /// <summary>
        /// Trains a predictor via TrainCore, first materializing the calibrator trainer
        /// when a usable calibrator component was supplied.
        /// </summary>
        public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
                                       SubComponent <ICalibratorTrainer, SignatureCalibrator> calibrator, int maxCalibrationExamples, bool?cacheData, IPredictor inputPredictor = null)
        {
            // Instantiate the calibrator trainer only when the component is valid.
            ICalibratorTrainer calibratorTrainer = null;
            if (calibrator.IsGood())
            {
                calibratorTrainer = calibrator.CreateInstance(env);
            }

            return TrainCore(env, ch, data, trainer, name, validData, calibratorTrainer, maxCalibrationExamples, cacheData, inputPredictor);
        }
Example #2
        /// <summary>
        /// Instantiates the early-stopping rule named by <paramref name="name"/>,
        /// parsed from <paramref name="args"/>, against a fresh MLContext environment.
        /// </summary>
        private EarlyStoppingRuleBase CreateEarlyStoppingCriterion(string name, string args, bool lowerIsBetter)
        {
            var environment = new MLContext().AddStandardComponents();
            var component   = new SubComponent <EarlyStoppingRuleBase, SignatureEarlyStoppingCriterion>(name, args);
            return component.CreateInstance(environment, lowerIsBetter);
        }
Example #3
        /// <summary>
        /// Instantiates the early-stopping criterion named by <paramref name="name"/>,
        /// parsed from <paramref name="args"/>, against a fresh console environment.
        /// </summary>
        private IEarlyStoppingCriterion CreateEarlyStoppingCriterion(string name, string args, bool lowerIsBetter)
        {
            var environment = new ConsoleEnvironment().AddStandardComponents();
            var component   = new SubComponent <IEarlyStoppingCriterion, SignatureEarlyStoppingCriterion>(name, args);
            return component.CreateInstance(environment, lowerIsBetter);
        }
Example #4
        /// <summary>
        /// Executes the train command: builds the trainer and data pipeline, optionally
        /// warm-starts from an input model, trains (with optional validation data),
        /// and saves the resulting model to the output file.
        /// </summary>
        /// <param name="ch">Channel for tracing/warnings; must be non-null.</param>
        /// <param name="cmd">The command line string recorded into the saved model.</param>
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);
            Host.AssertNonEmpty(cmd);

            ch.Trace("Constructing trainer");
            ITrainer trainer = _trainer.CreateInstance(Host);

            // Predictor used to initialize training state when ContinueTrain is set.
            IPredictor inputPredictor = null;

            if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing data pipeline");
            IDataView view = CreateLoader();

            // Resolve each role column: prefer the user-specified name, otherwise fall back
            // to the conventional default column name (null when neither matches).
            ISchema schema  = view.Schema;
            var     label   = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn), _labelColumn, DefaultColumnNames.Label);
            var     feature = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn), _featureColumn, DefaultColumnNames.Features);
            var     group   = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn), _groupColumn, DefaultColumnNames.GroupId);
            var     weight  = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn), _weightColumn, DefaultColumnNames.Weight);
            var     name    = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn), _nameColumn, DefaultColumnNames.Name);

            // May replace 'view' with a normalized pipeline, depending on trainer needs and options.
            TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref view, feature, Args.NormalizeFeatures);

            ch.Trace("Binding columns");

            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
            var data       = new RoleMappedData(view, label, feature, group, weight, name, customCols);

            // REVIEW: Unify the code that creates validation examples in Train, TrainTest and CV commands.
            RoleMappedData validData = null;

            if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
            {
                if (!trainer.Info.SupportsValidation)
                {
                    ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
                }
                else
                {
                    // Apply the same transforms to the validation data as to the training data,
                    // and bind it with the training data's role mappings.
                    ch.Trace("Constructing the validation pipeline");
                    IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
                    validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, view, validPipe);
                    validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
                }
            }

            var predictor = TrainUtils.Train(Host, ch, data, trainer, validData,
                                             Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor);

            using (var file = Host.CreateOutputFile(Args.OutputModelFile))
                TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);
        }
Example #5
        /// <summary>
        /// Trains the same predictor on unweighted, unit-weighted and quarter-weighted data,
        /// then verifies that weighted runs report both unweighted and weighted variants of
        /// each metric, and that unit weights reproduce the unweighted metrics exactly.
        /// </summary>
        /// <param name="noWeights">Dataset with no explicit weights (null to skip).</param>
        /// <param name="weights1">Dataset with every example weighted 1 (null to skip).</param>
        /// <param name="weightsQuarter">Dataset with fractional weights (null to skip).</param>
        /// <param name="predictorName">Name of the trainer component to instantiate.</param>
        /// <param name="tester">Factory producing the metric tester to run.</param>
        private void WeightedMetricTest(Instances noWeights, Instances weights1, Instances weightsQuarter, string predictorName, Func <Tester <Float> > tester)
        {
            Instances[] data = new Instances[3] {
                noWeights, weights1, weightsQuarter
            };
            Metric[][] results = new Metric[3][];
            var        sub     = new SubComponent <ITrainer <Instances, IPredictor <Instance, Float> >, SignatureOldTrainer>(
                predictorName, "nl=5 lr=0.25 iter=20 mil=1");

            for (int i = 0; i < 3; i++)
            {
                Instances instances = data[i];
                if (instances == null)
                {
                    continue;
                }

                // Create the trainer with a fixed seed so runs are reproducible.
                var trainer = sub.CreateInstance(new TrainHost(new Random(1), 0));

                // Train a predictor
                trainer.Train(instances);
                var predictor = trainer.CreatePredictor();

                results[i] = tester().Test(predictor, instances);
            }

            // BUGFIX: all comparisons below reference results[0]; the original dereferenced it
            // unconditionally and threw NullReferenceException when noWeights was null.
            if (results[0] == null)
            {
                return;
            }

            //Compare metrics results with unweighted metrics
            for (int i = 1; i < 3; i++)
            {
                if (results[i] == null)
                {
                    continue;
                }
                //The nonweighted result should have half of the metrics
                // (expected value goes first in xUnit's Assert.Equal).
                Assert.Equal(results[0].Length * 2, results[i].Length);
                for (int m = 0; m < results[0].Length; m++)
                {
                    Assert.Equal(results[0][m].Name, results[i][m].Name);
                    Double diff = Math.Abs(results[0][m].Value - results[i][m].Value);
                    if (diff > 1e-6)
                    {
                        Fail("{0} differ: {1} vs. {2}", results[0][m].Name, results[0][m].Value, results[i][m].Value);
                    }
                }
            }

            //Compare all metrics between weight 1 (with and without explicit weight in the input)
            // BUGFIX: guard against a skipped unit-weight run; the original dereferenced
            // results[1] unconditionally here.
            if (results[1] != null)
            {
                for (int m = 0; m < results[0].Length; m++)
                {
                    Assert.True(Math.Abs(results[0][m].Value - results[1][m].Value) < 1e-10);
                    Assert.True(Math.Abs(results[0][m].Value - results[1][m + results[0].Length].Value) < 1e-10);
                }
            }
        }
Example #6
        /// <summary>
        /// Return role/column-name pairs loaded from a repository.
        /// Returns null when the repository contains no role-mapping entry.
        /// </summary>
        /// <param name="env">Host environment; must be non-null.</param>
        /// <param name="rep">Repository to read the role-mapping file from; must be non-null.</param>
        public static IEnumerable <KeyValuePair <ColumnRole, string> > LoadRoleMappingsOrNull(IHostEnvironment env, RepositoryReader rep)
        {
            Contracts.CheckValue(env, nameof(env));
            var h = env.Register("RoleMappingUtils");
            // BUGFIX: validate 'rep' as well; the original only checked 'env' and would
            // throw NullReferenceException at rep.OpenEntryOrNull below.
            h.CheckValue(rep, nameof(rep));

            var list = new List <KeyValuePair <string, string> >();

            // Probe for the role-mapping entry; only its existence matters here, so the
            // handle is disposed immediately and the loader re-opens the stream itself.
            var entry = rep.OpenEntryOrNull(DirTrainingInfo, RoleMappingFile);

            if (entry == null)
            {
                return(null);
            }
            entry.Dispose();

            using (var ch = h.Start("Loading role mappings"))
            {
                // REVIEW: Should really validate the schema here, and consider
                // ignoring this stream if it isn't as expected.
                var loaderSub = new SubComponent <IDataLoader, SignatureDataLoader>("Text");
                var loader    = loaderSub.CreateInstance(env,
                                                         new RepositoryStreamWrapper(rep, DirTrainingInfo, RoleMappingFile));

                using (var cursor = loader.GetRowCursor(c => true))
                {
                    // Column 0 carries the role name, column 1 the mapped column name.
                    var roleGetter = cursor.GetGetter <DvText>(0);
                    var colGetter  = cursor.GetGetter <DvText>(1);
                    var role       = default(DvText);
                    var col        = default(DvText);
                    while (cursor.MoveNext())
                    {
                        roleGetter(ref role);
                        colGetter(ref col);
                        string roleStr = role.ToString();
                        string colStr  = col.ToString();

                        h.CheckDecode(!string.IsNullOrWhiteSpace(roleStr), "Role name must not be empty");
                        h.CheckDecode(!string.IsNullOrWhiteSpace(colStr), "Column name must not be empty");
                        list.Add(new KeyValuePair <string, string>(roleStr, colStr));
                    }
                }

                ch.Done();
            }

            return(TrainUtils.CheckAndGenerateCustomColumns(env, list.ToArray()));
        }
            /// <summary>
            /// Runs one cross-validation fold: splits the input into train/test ranges,
            /// trains a predictor (optionally with validation data), scores and evaluates
            /// the test range, and optionally saves the per-fold model.
            /// </summary>
            /// <param name="fold">Zero-based fold index; must be within [0, _numFolds].</param>
            /// <returns>The metrics, schemas and optional per-instance data for this fold.</returns>
            private FoldResult RunFold(int fold)
            {
                var host = GetHost();

                host.Assert(0 <= fold && fold <= _numFolds);
                // REVIEW: Make channels buffered in multi-threaded environments.
                using (var ch = host.Start($"Fold {fold}"))
                {
                    ch.Trace("Constructing trainer");
                    ITrainer trainer = _trainer.CreateInstance(host);

                    // Train pipe: keep everything OUTSIDE this fold's range (Complement = true).
                    var trainFilter = new RangeFilter.Arguments();
                    trainFilter.Column     = _splitColumn;
                    trainFilter.Min        = (Double)fold / _numFolds;
                    trainFilter.Max        = (Double)(fold + 1) / _numFolds;
                    trainFilter.Complement = true;
                    IDataView trainPipe = new RangeFilter(host, trainFilter, _inputDataView);
                    trainPipe = new OpaqueDataView(trainPipe);
                    var trainData = _createExamples(host, ch, trainPipe, trainer);

                    // Test pipe: the same range, NOT complemented, so it selects this fold's slice.
                    var testFilter = new RangeFilter.Arguments();
                    testFilter.Column = trainFilter.Column;
                    testFilter.Min    = trainFilter.Min;
                    testFilter.Max    = trainFilter.Max;
                    ch.Assert(!testFilter.Complement);
                    IDataView testPipe = new RangeFilter(host, testFilter, _inputDataView);
                    testPipe = new OpaqueDataView(testPipe);
                    var testData = _applyTransformsToTestData(host, ch, testPipe, trainData, trainPipe);

                    // Validation pipe and examples.
                    RoleMappedData validData = null;
                    if (_getValidationDataView != null)
                    {
                        ch.Assert(_applyTransformsToValidationData != null);
                        if (!trainer.Info.SupportsValidation)
                        {
                            ch.Warning("Trainer does not accept validation dataset.");
                        }
                        else
                        {
                            // Re-apply the training transforms so validation data matches the train schema.
                            ch.Trace("Constructing the validation pipeline");
                            IDataView validLoader = _getValidationDataView();
                            var       validPipe   = ApplyTransformUtils.ApplyAllTransformsToData(host, _inputDataView, validLoader);
                            validPipe = new OpaqueDataView(validPipe);
                            validData = _applyTransformsToValidationData(host, ch, validPipe, trainData, trainPipe);
                        }
                    }

                    // Train.
                    var predictor = TrainUtils.Train(host, ch, trainData, trainer, _trainer.Kind, validData,
                                                     _calibrator, _maxCalibrationExamples, _cacheData, _inputPredictor);

                    // Score.
                    ch.Trace("Scoring and evaluating");
                    var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, _scorer);
                    ch.AssertValue(bindable);
                    var mapper     = bindable.Bind(host, testData.Schema);
                    // Fall back to an auto-detected scorer when none was explicitly specified.
                    var scorerComp = _scorer.IsGood() ? _scorer : ScoreUtils.GetScorerComponent(mapper);
                    IDataScorerTransform scorePipe = scorerComp.CreateInstance(host, testData.Data, mapper, trainData.Schema);

                    // Save per-fold model.
                    string modelFileName = ConstructPerFoldName(_outputModelFile, fold);
                    if (modelFileName != null && _loader != null)
                    {
                        using (var file = host.CreateOutputFile(modelFileName))
                        {
                            var rmd = new RoleMappedData(
                                CompositeDataLoader.ApplyTransform(host, _loader, null, null,
                                                                   (e, newSource) => ApplyTransformUtils.ApplyAllTransformsToData(e, trainData.Data, newSource)),
                                trainData.Schema.GetColumnRoleNames());
                            TrainUtils.SaveModel(host, ch, file, predictor, rmd, _cmd);
                        }
                    }

                    // Evaluate.
                    var evalComp = _evaluator;
                    if (!evalComp.IsGood())
                    {
                        evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
                    }
                    var eval = evalComp.CreateInstance(host);
                    // Note that this doesn't require the provided columns to exist (because of the "opt" parameter).
                    // We don't normally expect the scorer to drop columns, but if it does, we should not require
                    // all the columns in the test pipeline to still be present.
                    var dataEval = new RoleMappedData(scorePipe, testData.Schema.GetColumnRoleNames(), opt: true);

                    var            dict        = eval.Evaluate(dataEval);
                    RoleMappedData perInstance = null;
                    if (_savePerInstance)
                    {
                        var perInst = eval.GetPerInstanceMetrics(dataEval);
                        perInstance = new RoleMappedData(perInst, dataEval.Schema.GetColumnRoleNames(), opt: true);
                    }
                    ch.Done();
                    return(new FoldResult(dict, dataEval.Schema.Schema, perInstance, trainData.Schema));
                }
            }
Example #8
        /// <summary>
        /// End-to-end smoke test: trains a FastRank binary classifier on in-memory data,
        /// saves the model in several formats, reloads the binary model and verifies its
        /// predictions and metrics agree with the original predictor.
        /// </summary>
        public void SimpleExampleTest()
        {
            RunMTAThread(() =>
            {
                string dataFilename = GetDataPath(TestDatasets.msm.trainFilename);

                ///*********  Training a model *******//
                // assume data is in memory in matrix/vector form. Sparse format also supported.
                Float[][] data;
                Float[] labels;
                // below just reads some actual data into these arrays
                PredictionUtil.ReadInstancesAsArrays(dataFilename, out data, out labels);

                // Create an Instances dataset.
                ListInstances instances = new ListInstances();
                for (int i = 0; i < data.Length; i++)
                {
                    instances.AddInst(data[i], labels[i]);
                }
                instances.CopyMetadata(null);

                // Create a predictor and specify some non-default settings
                var sub = new SubComponent <ITrainer <Instances, IPredictor <Instance, Float> >, SignatureOldBinaryClassifierTrainer>(
                    "FastRank", "nl=5 lr =0.25 iter= 20");
                var trainer = sub.CreateInstance(new TrainHost(new Random(1), 0));

                // Train a predictor
                trainer.Train(instances);
                var predictor = trainer.CreatePredictor();

                ///*********  Several ways to save models. Only binary can be used to-reload in TLC. *******//

                // Allocate all temp files up front so the finally block can always clean them up.
                string modelFilename = Path.GetTempFileName();
                string modelFilenameText = Path.GetTempFileName();
                string modelFilenameIni = Path.GetTempFileName();
                try
                {
                    // Save the model in internal binary format that can be used for loading it.
                    PredictorUtils.Save(modelFilename, predictor, instances, null);

                    // Save the model as a plain-text description
                    PredictorUtils.SaveText(predictor, instances.Schema.FeatureNames, modelFilenameText);

                    // Save the model in Bing's INI format
                    PredictorUtils.SaveIni(predictor, instances.Schema.FeatureNames, modelFilenameIni);

                    ///*********  Loading and making predictions with a previously saved model *******//
                    // Note:   there are several alternative ways to construct instances
                    // For example, see FactoryExampleTest  below that demonstrates named-feature : value pairs.

                    // Load saved model
                    IDataModel dataModel;
                    IDataStats dataStats;
                    var pred = PredictorUtils.LoadPredictor <Float>(out dataModel, out dataStats, modelFilename);
                    var dp   = pred as IDistributionPredictor <Instance, Float, Float>;

                    // Get predictions for instances
                    Float[] probPredictions = new Float[instances.Count];
                    Float[] rawPredictions  = new Float[instances.Count];
                    Float[] rawPredictions1 = new Float[instances.Count];
                    for (int i = 0; i < instances.Count; i++)
                    {
                        probPredictions[i] = dp.PredictDistribution(instances[i], out rawPredictions[i]);
                        rawPredictions1[i] = dp.Predict(new Instance(data[i]));
                    }

                    Float[] bulkPredictions = ((IBulkPredictor <Instance, Instances, Float, Float[]>)pred).BulkPredict(instances);

                    // Bulk, per-instance, and reconstructed-instance predictions must all agree.
                    Assert.Equal(rawPredictions.Length, bulkPredictions.Length);
                    Assert.Equal(rawPredictions.Length, rawPredictions1.Length);
                    for (int i = 0; i < rawPredictions.Length; i++)
                    {
                        Assert.Equal(rawPredictions[i], bulkPredictions[i]);
                    }
                    for (int i = 0; i < rawPredictions.Length; i++)
                    {
                        Assert.Equal(rawPredictions[i], rawPredictions1[i]);
                    }

                    //test new testers
                    {
                        var results = new ClassifierTester(new ProbabilityPredictorTesterArguments()).Test(pred, instances);

                        // Get metric names and print them alongside numbers
                        for (int i = 0; i < results.Length; i++)
                        {
                            Log("{0,-30} {1}", results[i].Name, results[i].Value);
                        }

                        // sanity check vs. original predictor
                        var results2 = new ClassifierTester(new ProbabilityPredictorTesterArguments()).Test(predictor, instances);
                        Assert.Equal(results.Length, results2.Length);
                        for (int i = 0; i < results.Length; i++)
                        {
                            Assert.Equal(results[i].Name, results2[i].Name);
                            Assert.Equal(results[i].Value, results2[i].Value);
                        }
                    }
                }
                finally
                {
                    // BUGFIX: the original deleted the temp files only on the success path,
                    // leaking them whenever any assertion above failed.
                    File.Delete(modelFilename);
                    File.Delete(modelFilenameText);
                    File.Delete(modelFilenameIni);
                }
            });
            Done();
        }
Example #9
        /// <summary>
        /// Trains the stacking meta-model: runs every base model over the data to produce
        /// per-example prediction features, builds an in-memory dataset from them, and
        /// trains the meta-predictor on that dataset. Sets <c>Meta</c>; may only be called once.
        /// </summary>
        /// <param name="models">Base models whose predictions become the meta-features.</param>
        /// <param name="data">Training data the base models are evaluated on.</param>
        /// <param name="env">Host environment; must be non-null.</param>
        public void Train(List <FeatureSubsetModel <IPredictorProducing <TOutput> > > models, RoleMappedData data, IHostEnvironment env)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(Stacking.LoadName);

            host.CheckValue(models, nameof(models));
            host.CheckValue(data, nameof(data));

            using (var ch = host.Start("Training stacked model"))
            {
                ch.Check(Meta == null, "Train called multiple times");
                ch.Check(BasePredictorType != null);

                // One value mapper per base model; each base predictor must be an IValueMapper.
                var maps = new ValueMapper <VBuffer <Single>, TOutput> [models.Count];
                for (int i = 0; i < maps.Length; i++)
                {
                    Contracts.Assert(models[i].Predictor is IValueMapper);
                    var m = (IValueMapper)models[i].Predictor;
                    maps[i] = m.GetMapper <VBuffer <Single>, TOutput>();
                }

                // REVIEW: Should implement this better....
                // Buffers grow via EnsureSize below; 100 is just the initial capacity.
                var labels   = new Single[100];
                var features = new VBuffer <Single> [100];
                int count    = 0;
                // REVIEW: Should this include bad values or filter them?
                using (var cursor = new FloatLabelCursor(data, CursOpt.AllFeatures | CursOpt.AllLabels))
                {
                    TOutput[] predictions = new TOutput[maps.Length];
                    var       vBuffers    = new VBuffer <Single> [maps.Length];
                    while (cursor.MoveNext())
                    {
                        // Evaluate the base models in parallel; each index i writes only to
                        // its own slots of vBuffers/predictions, so there is no contention.
                        Parallel.For(0, maps.Length, i =>
                        {
                            var model = models[i];
                            if (model.SelectedFeatures != null)
                            {
                                // Project the cursor's features down to this model's subset first.
                                EnsembleUtils.SelectFeatures(ref cursor.Features, model.SelectedFeatures, model.Cardinality, ref vBuffers[i]);
                                maps[i](ref vBuffers[i], ref predictions[i]);
                            }
                            else
                            {
                                maps[i](ref cursor.Features, ref predictions[i]);
                            }
                        });

                        Utils.EnsureSize(ref labels, count + 1);
                        Utils.EnsureSize(ref features, count + 1);
                        labels[count] = cursor.Label;
                        FillFeatureBuffer(predictions, ref features[count]);
                        count++;
                    }
                }

                ch.Info("The number of instances used for stacking trainer is {0}", count);

                // Build the meta-training dataset from the collected labels and prediction features.
                var bldr = new ArrayDataViewBuilder(host);
                Array.Resize(ref labels, count);
                Array.Resize(ref features, count);
                bldr.AddColumn("Label", NumberType.Float, labels);
                bldr.AddColumn("Features", NumberType.Float, features);

                var view = bldr.GetDataView();
                var rmd  = RoleMappedData.Create(view, ColumnRole.Label.Bind("Label"), ColumnRole.Feature.Bind("Features"));

                var trainer = BasePredictorType.CreateInstance(host);
                if (trainer is ITrainerEx ex && ex.NeedNormalization)
                {
                    ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this.");
                }
                trainer.Train(rmd);
                Meta = trainer.CreatePredictor();
                CheckMeta();

                ch.Done();
            }
        }
Example #10
        /// <summary>
        /// Instantiates the early-stopping criterion named by <paramref name="name"/>,
        /// parsed from <paramref name="args"/>, against a fresh TLC environment.
        /// </summary>
        private IEarlyStoppingCriterion CreateEarlyStoppingCriterion(string name, string args, bool lowerIsBetter)
        {
            var component = new SubComponent <IEarlyStoppingCriterion, SignatureEarlyStoppingCriterion>(name, args);
            return component.CreateInstance(new TlcEnvironment(), lowerIsBetter);
        }
Example #11
        /// <summary>
        /// Adds a MinMax normalization transform over the feature column when the option and
        /// the trainer's needs call for one, replacing <paramref name="view"/> in place.
        /// </summary>
        /// <returns>True if a normalizer was added.</returns>
        public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITrainer trainer, ref IDataView view, string featureColumn, NormalizeOption autoNorm)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(trainer, nameof(trainer));
            ch.CheckValue(view, nameof(view));
            ch.CheckValueOrNull(featureColumn);
            ch.CheckUserArg(Enum.IsDefined(typeof(NormalizeOption), autoNorm), nameof(TrainCommand.Arguments.NormalizeFeatures),
                            "Normalize option is invalid. Specify one of 'norm=No', 'norm=Warn', 'norm=Auto', or 'norm=Yes'.");

            // Explicitly disabled by the user.
            if (autoNorm == NormalizeOption.No)
            {
                ch.Info("Not adding a normalizer.");
                return false;
            }

            // Nothing to normalize without a feature column present in the schema.
            if (string.IsNullOrEmpty(featureColumn))
                return false;

            var schema = view.Schema;
            int featCol;
            if (!schema.TryGetColumnIndex(featureColumn, out featCol))
                return false;

            if (autoNorm != NormalizeOption.Yes)
            {
                // Skip when the trainer doesn't need normalization, or the data already
                // declares itself normalized. Note the short-circuit: the metadata probe
                // only runs when the trainer actually wants normalization.
                var    trainerEx    = trainer as ITrainerEx;
                DvBool isNormalized = DvBool.False;
                if (trainerEx == null || !trainerEx.NeedNormalization ||
                    (schema.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, featCol, ref isNormalized) &&
                     isNormalized.IsTrue))
                {
                    ch.Info("Not adding a normalizer.");
                    return false;
                }
                // Warn mode: report the need but do not modify the pipeline.
                if (autoNorm == NormalizeOption.Warn)
                {
                    ch.Warning("A normalizer is needed for this trainer. Either add a normalizing transform or use the 'norm=Auto', 'norm=Yes' or 'norm=No' options.");
                    return false;
                }
            }

            ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.");

            // Quote the feature column name so it survives command-line style parsing.
            string quotedFeatureColumnName = featureColumn;
            var    quoted = new StringBuilder();
            if (CmdQuoter.QuoteValue(quotedFeatureColumnName, quoted))
                quotedFeatureColumnName = quoted.ToString();

            var component = new SubComponent <IDataTransform, SignatureDataTransform>("MinMax", string.Format("col={{ name={0} source={0} }}", quotedFeatureColumnName));
            var loader    = view as IDataLoader;
            if (loader == null)
            {
                view = component.CreateInstance(env, view);
            }
            else
            {
                // Keep the transform part of the loader chain so it round-trips with the model.
                view = CompositeDataLoader.Create(env, loader,
                                                  new KeyValuePair <string, SubComponent <IDataTransform, SignatureDataTransform> >(null, component));
            }
            return true;
        }