/// <summary>
        /// Callback used by cross-validation: applies the transforms of the train data
        /// (<paramref name="srcData"/>) on top of <paramref name="dstData"/> (the test and/or validation data)
        /// and binds the result using the column role names of <paramref name="srcData"/>.
        /// </summary>
        /// <param name="env">The host environment handed to <c>ApplyTransformUtils</c>.</param>
        /// <param name="ch">Channel supplied by the caller; not used by this method.</param>
        /// <param name="dstData">The data view the transforms are applied to.</param>
        /// <param name="srcData">The role-mapped train data whose transforms and roles are reused.</param>
        /// <param name="marker">Passed through as the last argument of
        /// <c>ApplyTransformUtils.ApplyAllTransformsToData</c>.</param>
        /// <returns>The transformed data, bound with the roles of <paramref name="srcData"/>.</returns>
        private RoleMappedData ApplyAllTransformsToData(IHostEnvironment env, IChannel ch, IDataView dstData,
                                                        RoleMappedData srcData, IDataView marker)
        {
            var transformed = ApplyTransformUtils.ApplyAllTransformsToData(env, srcData.Data, dstData, marker);

            // The role assignments are carried over from the source data unchanged.
            var roleNames = srcData.Schema.GetColumnRoleNames();
            return RoleMappedData.Create(transformed, roleNames);
        }
        /// <summary>
        /// Produces the per-instance data view to save: re-binds <paramref name="perInstance"/> using the
        /// roles returned by <c>GetInputColumnRoles</c> (called with <c>needName: true</c>) and hands the
        /// result to <c>WrapPerInstance</c>.
        /// </summary>
        /// <param name="perInstance">The role-mapped per-instance data; must not be null.</param>
        /// <returns>Whatever <c>WrapPerInstance</c> returns for the re-bound data.</returns>
        public IDataView GetPerInstanceDataViewToSave(RoleMappedData perInstance)
        {
            Host.CheckValue(perInstance, nameof(perInstance));

            var source = perInstance.Data;
            var roles  = GetInputColumnRoles(perInstance.Schema, needName: true);
            var bound  = RoleMappedData.Create(source, roles);

            return WrapPerInstance(bound);
        }
        /// <summary>
        /// Applies a min-max normalizer to the feature column of <paramref name="data"/> when the trainer
        /// asks for normalization and the features are known not to be normalized yet. All role mappings
        /// of the original data are re-applied by name to the normalized view.
        /// </summary>
        /// <param name="env">The host environment used to instantiate the transform, if one is needed.</param>
        /// <param name="data">The role-mapped data; replaced with the normalized data when this method
        /// returns <c>true</c>, otherwise left untouched.</param>
        /// <param name="trainer">The trainer to query with <see cref="NormalizeUtils.NeedNormalization(ITrainer)"/>.
        /// Nothing is done unless that query returns <c>true</c>.</param>
        /// <returns><c>true</c> if the normalizer was applied and <paramref name="data"/> was modified;
        /// <c>false</c> otherwise.</returns>
        public static bool CreateIfNeeded(IHostEnvironment env, ref RoleMappedData data, ITrainer trainer)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(data, nameof(data));
            env.CheckValue(trainer, nameof(trainer));

            // Normalize only when the trainer definitely wants it (true, not false/null) AND the features
            // are definitely not normalized (false, not true/null). The second query is only made when
            // the first one passes.
            if (trainer.NeedNormalization() != true || data.Schema.FeaturesAreNormalized() != false)
            {
                return false;
            }

            var featureColumn = data.Schema.Feature;

            // Should be defined, since FeaturesAreNormalized returned a definite value.
            env.AssertValue(featureColumn);

            var normalized = CreateMinMaxNormalizer(env, data.Data, name: featureColumn.Name);
            var roleNames  = data.Schema.GetColumnRoleNames();

            data = RoleMappedData.Create(normalized, roleNames);
            return true;
        }
        /// <summary>
        /// Computes per-instance metrics for scored data: re-binds <paramref name="scoredData"/> with the
        /// roles returned by <c>GetInputColumnRoles</c> and delegates to
        /// <c>Evaluator.GetPerInstanceMetrics</c>.
        /// </summary>
        /// <param name="scoredData">The scored, role-mapped data. Validated with an assert only.</param>
        /// <returns>The transform returned by the underlying evaluator.</returns>
        public IDataTransform GetPerInstanceMetrics(RoleMappedData scoredData)
        {
            Host.AssertValue(scoredData);

            var inputRoles = GetInputColumnRoles(scoredData.Schema);
            var evalInput  = RoleMappedData.Create(scoredData.Data, inputRoles);

            return Evaluator.GetPerInstanceMetrics(evalInput);
        }
示例#5
0
        /// <summary>
        /// Body of the train command: creates the trainer, optionally loads an input predictor to
        /// continue training from, builds the data pipeline (adding a normalizer if needed), binds the
        /// column roles, optionally builds a validation set, trains, and saves the resulting model to
        /// <c>Args.OutputModelFile</c>.
        /// </summary>
        /// <param name="ch">The channel used for tracing and warnings.</param>
        /// <param name="cmd">The command string, passed on to <c>TrainUtils.SaveModel</c>.</param>
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);
            Host.AssertNonEmpty(cmd);

            ch.Trace("Constructing trainer");
            ITrainer trainer = _trainer.CreateInstance(Host);

            // When continuing training, try to pick up the predictor from the input model file.
            IPredictor inputPredictor = null;
            if (Args.ContinueTrain)
            {
                bool loaded = TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor);
                if (!loaded)
                {
                    ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
                }
            }

            ch.Trace("Constructing data pipeline");
            IDataView pipeline = CreateLoader();

            // Resolve each role's column: explicit argument, else the default name, else null.
            ISchema inputSchema = pipeline.Schema;
            string labelCol   = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Arguments.LabelColumn), _labelColumn, DefaultColumnNames.Label);
            string featureCol = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Arguments.FeatureColumn), _featureColumn, DefaultColumnNames.Features);
            string groupCol   = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Arguments.GroupColumn), _groupColumn, DefaultColumnNames.GroupId);
            string weightCol  = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Arguments.WeightColumn), _weightColumn, DefaultColumnNames.Weight);
            string nameCol    = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Arguments.NameColumn), _nameColumn, DefaultColumnNames.Name);

            // May replace 'pipeline' with a normalized view.
            TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref pipeline, featureCol, Args.NormalizeFeatures);

            ch.Trace("Binding columns");

            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
            var data       = TrainUtils.CreateExamples(pipeline, labelCol, featureCol, groupCol, weightCol, nameCol, customCols);

            // REVIEW: Unify the code that creates validation examples in Train, TrainTest and CV commands.
            RoleMappedData validData = null;
            if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
            {
                if (TrainUtils.CanUseValidationData(trainer))
                {
                    ch.Trace("Constructing the validation pipeline");

                    // Load the raw validation file, then apply the training pipeline's transforms to it.
                    IDataView rawValid  = CreateRawLoader(dataFile: Args.ValidationFile);
                    IDataView validView = ApplyTransformUtils.ApplyAllTransformsToData(Host, pipeline, rawValid);
                    validData = RoleMappedData.Create(validView, data.Schema.GetColumnRoleNames());
                }
                else
                {
                    ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
                }
            }

            var predictor = TrainUtils.Train(Host, ch, data, trainer, _info.LoadNames[0], validData,
                                             Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor);

            using (var modelFile = Host.CreateOutputFile(Args.OutputModelFile))
            {
                TrainUtils.SaveModel(Host, ch, modelFile, predictor, data, cmd);
            }
        }
示例#6
0
        /// <summary>
        /// Builds a <see cref="RoleMappedData"/> over <paramref name="view"/> from a set of column names.
        /// A column name that is null or whitespace is skipped; every name that is supplied is bound to
        /// its role and must be a valid column of the schema. Roles are bound in the order label,
        /// feature, group, weight, name, followed by any <paramref name="custom"/> bindings.
        /// </summary>
        /// <param name="view">The data view to wrap.</param>
        /// <param name="label">Name of the label column, or null/whitespace for none.</param>
        /// <param name="feature">Name of the feature column, or null/whitespace for none.</param>
        /// <param name="group">Name of the group column, or null/whitespace for none.</param>
        /// <param name="weight">Name of the weight column, or null/whitespace for none.</param>
        /// <param name="name">Name of the name column, or null/whitespace for none.</param>
        /// <param name="custom">Additional role bindings appended as given, or null for none.</param>
        /// <returns>The role-mapped data created from the collected bindings.</returns>
        public static RoleMappedData CreateExamples(IDataView view, string label, string feature,
                                                    string group = null, string weight = null, string name = null,
                                                    IEnumerable<KeyValuePair<ColumnRole, string>> custom = null)
        {
            Contracts.CheckValueOrNull(label);
            Contracts.CheckValueOrNull(feature);
            Contracts.CheckValueOrNull(group);
            Contracts.CheckValueOrNull(weight);
            Contracts.CheckValueOrNull(name);
            Contracts.CheckValueOrNull(custom);

            var bindings = new List<KeyValuePair<ColumnRole, string>>();

            // Adds a binding for the role only when a usable column name was supplied.
            void BindIfSpecified(ColumnRole role, string column)
            {
                if (!string.IsNullOrWhiteSpace(column))
                {
                    bindings.Add(role.Bind(column));
                }
            }

            BindIfSpecified(ColumnRole.Label, label);
            BindIfSpecified(ColumnRole.Feature, feature);
            BindIfSpecified(ColumnRole.Group, group);
            BindIfSpecified(ColumnRole.Weight, weight);
            BindIfSpecified(ColumnRole.Name, name);

            if (custom != null)
            {
                bindings.AddRange(custom);
            }

            return RoleMappedData.Create(view, bindings);
        }
示例#7
0
        /// <summary>
        /// Wraps <paramref name="data"/> in a <c>CacheDataView</c> when caching is wanted.
        /// An explicit <paramref name="cacheData"/> value always wins. When it is null, caching is used
        /// unless the data is a <c>BinaryLoader</c> or the trainer is an <c>ITrainerEx</c> whose
        /// <c>WantCaching</c> is false.
        /// </summary>
        /// <param name="env">The host environment used to create the cache view.</param>
        /// <param name="ch">The channel used for tracing.</param>
        /// <param name="trainer">The trainer consulted for its caching preference.</param>
        /// <param name="data">The role-mapped data; replaced by the cached data when caching is applied.</param>
        /// <param name="cacheData">Explicit caching choice, or null to decide automatically.</param>
        /// <returns><c>true</c> if the data was wrapped in a cache; <c>false</c> otherwise.</returns>
        private static bool AddCacheIfWanted(IHostEnvironment env, IChannel ch, ITrainer trainer, ref RoleMappedData data, bool? cacheData)
        {
            Contracts.AssertValue(env, nameof(env));
            env.AssertValue(ch, nameof(ch));
            ch.AssertValue(trainer, nameof(trainer));
            ch.AssertValue(data, nameof(data));

            bool useCache;
            if (cacheData.HasValue)
            {
                useCache = cacheData.Value;
            }
            else if (data.Data is BinaryLoader)
            {
                useCache = false;
            }
            else
            {
                // Trainers that do not implement ITrainerEx are cached by default.
                var extended = trainer as ITrainerEx;
                useCache = extended == null || extended.WantCaching;
            }

            if (!useCache)
            {
                ch.Trace("Not caching");
                return false;
            }

            ch.Trace("Caching");
            var prefetchColumns = data.Schema.GetColumnRoles().Select(role => role.Value.Index).ToArray();
            var cached          = new CacheDataView(env, data.Data, prefetchColumns);

            // Because the prefetching worked, we know that these are valid columns.
            data = RoleMappedData.Create(cached, data.Schema.GetColumnRoleNames());
            return true;
        }
 /// <summary>
 /// Evaluates <paramref name="data"/>: re-binds it with the roles returned by
 /// <c>GetInputColumnRoles</c> (called with <c>needStrat: true</c>) and delegates to
 /// <c>Evaluator.Evaluate</c>.
 /// </summary>
 /// <param name="data">The role-mapped data to evaluate.</param>
 /// <returns>The dictionary of data views returned by the underlying evaluator.</returns>
 public Dictionary<string, IDataView> Evaluate(RoleMappedData data)
 {
     var source     = data.Data;
     var inputRoles = GetInputColumnRoles(data.Schema, needStrat: true);
     var evalInput  = RoleMappedData.Create(source, inputRoles);

     return Evaluator.Evaluate(evalInput);
 }
            /// <summary>
            /// Runs a single cross-validation fold: builds the train and test pipes by range-filtering
            /// <c>_inputDataView</c> on <c>_splitColumn</c>, optionally builds validation data, trains a
            /// predictor, scores the test data, optionally saves a per-fold model, evaluates, and
            /// optionally collects per-instance results.
            /// </summary>
            /// <param name="fold">The fold number; it selects the range
            /// [fold / _numFolds, (fold + 1) / _numFolds] of the split column.</param>
            /// <returns>A <c>FoldResult</c> holding the metrics dictionary, the schema of the evaluated
            /// data, the per-instance view (null unless <c>_savePerInstance</c> is set) and the training
            /// schema.</returns>
            private FoldResult RunFold(int fold)
            {
                var host = GetHost();

                // NOTE(review): the upper bound here is inclusive (fold <= _numFolds). If folds are
                // 0-based indices, fold < _numFolds may have been intended -- confirm against the caller.
                host.Assert(0 <= fold && fold <= _numFolds);
                // REVIEW: Make channels buffered in multi-threaded environments.
                using (var ch = host.Start($"Fold {fold}"))
                {
                    ch.Trace("Constructing trainer");
                    ITrainer trainer = _trainer.CreateInstance(host);

                    // Train pipe.
                    // The train filter uses the fold's range with Complement = true, i.e. the complement
                    // of the range the test filter below selects.
                    var trainFilter = new RangeFilter.Arguments();
                    trainFilter.Column     = _splitColumn;
                    trainFilter.Min        = (Double)fold / _numFolds;
                    trainFilter.Max        = (Double)(fold + 1) / _numFolds;
                    trainFilter.Complement = true;
                    IDataView trainPipe = new RangeFilter(host, trainFilter, _inputDataView);
                    trainPipe = new OpaqueDataView(trainPipe);
                    var trainData = _createExamples(host, ch, trainPipe, trainer);

                    // Test pipe.
                    // Same column and range as the train filter; Complement is not set and is asserted
                    // to be false, so the test pipe uses the range itself.
                    var testFilter = new RangeFilter.Arguments();
                    testFilter.Column = trainFilter.Column;
                    testFilter.Min    = trainFilter.Min;
                    testFilter.Max    = trainFilter.Max;
                    ch.Assert(!testFilter.Complement);
                    IDataView testPipe = new RangeFilter(host, testFilter, _inputDataView);
                    testPipe = new OpaqueDataView(testPipe);
                    var testData = _applyTransformsToTestData(host, ch, testPipe, trainData, trainPipe);

                    // Validation pipe and examples.
                    // validData stays null when no validation view is configured or when the trainer
                    // cannot use validation data.
                    RoleMappedData validData = null;
                    if (_getValidationDataView != null)
                    {
                        ch.Assert(_applyTransformsToValidationData != null);
                        if (!TrainUtils.CanUseValidationData(trainer))
                        {
                            ch.Warning("Trainer does not accept validation dataset.");
                        }
                        else
                        {
                            ch.Trace("Constructing the validation pipeline");
                            IDataView validLoader = _getValidationDataView();
                            // First apply the transforms of the shared input view to the validation
                            // loader, then the transforms from the fold's train data.
                            var       validPipe   = ApplyTransformUtils.ApplyAllTransformsToData(host, _inputDataView, validLoader);
                            validPipe = new OpaqueDataView(validPipe);
                            validData = _applyTransformsToValidationData(host, ch, validPipe, trainData, trainPipe);
                        }
                    }

                    // Train.
                    var predictor = TrainUtils.Train(host, ch, trainData, trainer, _trainer.Kind, validData,
                                                     _calibrator, _maxCalibrationExamples, _cacheData, _inputPredictor);

                    // Score.
                    ch.Trace("Scoring and evaluating");
                    var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, _scorer);
                    ch.AssertValue(bindable);
                    var mapper     = bindable.Bind(host, testData.Schema);
                    // Use the configured scorer when it is good; otherwise derive one from the mapper.
                    var scorerComp = _scorer.IsGood() ? _scorer : ScoreUtils.GetScorerComponent(mapper);
                    IDataScorerTransform scorePipe = scorerComp.CreateInstance(host, testData.Data, mapper, trainData.Schema);

                    // Save per-fold model.
                    // Skipped when no per-fold file name could be constructed or when there is no loader.
                    string modelFileName = ConstructPerFoldName(_outputModelFile, fold);
                    if (modelFileName != null && _loader != null)
                    {
                        using (var file = host.CreateOutputFile(modelFileName))
                        {
                            // The saved data is _loader with the transforms of this fold's train data
                            // applied on top, bound with the train data's roles.
                            var rmd = RoleMappedData.Create(
                                CompositeDataLoader.ApplyTransform(host, _loader, null, null,
                                                                   (e, newSource) => ApplyTransformUtils.ApplyAllTransformsToData(e, trainData.Data, newSource)),
                                trainData.Schema.GetColumnRoleNames());
                            TrainUtils.SaveModel(host, ch, file, predictor, rmd, _cmd);
                        }
                    }

                    // Evaluate.
                    // Fall back to an evaluator chosen from the scored schema when none is configured.
                    var evalComp = _evaluator;
                    if (!evalComp.IsGood())
                    {
                        evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
                    }
                    var eval = evalComp.CreateInstance(host);
                    // Note that this doesn't require the provided columns to exist (because of "Opt").
                    // We don't normally expect the scorer to drop columns, but if it does, we should not require
                    // all the columns in the test pipeline to still be present.
                    var dataEval = RoleMappedData.CreateOpt(scorePipe, testData.Schema.GetColumnRoleNames());

                    var       dict        = eval.Evaluate(dataEval);
                    // Per-instance results are only produced on request; otherwise this stays null.
                    IDataView perInstance = null;
                    if (_savePerInstance)
                    {
                        var perInst     = eval.GetPerInstanceMetrics(dataEval);
                        var perInstData = RoleMappedData.CreateOpt(perInst, dataEval.Schema.GetColumnRoleNames());
                        perInstance = eval.GetPerInstanceDataViewToSave(perInstData);
                    }
                    ch.Done();
                    return(new FoldResult(dict, dataEval.Schema.Schema, perInstance, trainData.Schema));
                }
            }
        /// <summary>
        /// Body of the train-test command: creates the trainer, optionally loads an input predictor,
        /// builds and binds the training pipeline, optionally builds validation data, trains, saves the
        /// model (to <c>Args.OutputModelFile</c> or to a temporary file), reloads the pipeline from the
        /// saved model over <c>Args.TestFile</c>, scores and evaluates it, prints the metrics and, when
        /// <c>Args.OutputDataFile</c> is given, saves the per-instance results.
        /// </summary>
        /// <param name="ch">The channel used for tracing, warnings and errors.</param>
        /// <param name="cmd">The command string, passed on to <c>TrainUtils.SaveModel</c>.</param>
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);
            Host.AssertNonEmpty(cmd);

            ch.Trace("Constructing trainer");
            ITrainer trainer = Args.Trainer.CreateInstance(Host);

            // When continuing training, try to pick up the predictor from the input model file.
            IPredictor inputPredictor = null;
            if (Args.ContinueTrain)
            {
                bool loaded = TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor);
                if (!loaded)
                {
                    ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
                }
            }

            ch.Trace("Constructing the training pipeline");
            IDataView trainPipe = CreateLoader();

            // Resolve each role's column: explicit argument, else the default name, else null.
            ISchema trainSchema = trainPipe.Schema;
            string labelCol   = TrainUtils.MatchNameOrDefaultOrNull(ch, trainSchema, nameof(Arguments.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label);
            string featureCol = TrainUtils.MatchNameOrDefaultOrNull(ch, trainSchema, nameof(Arguments.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features);
            string groupCol   = TrainUtils.MatchNameOrDefaultOrNull(ch, trainSchema, nameof(Arguments.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
            string weightCol  = TrainUtils.MatchNameOrDefaultOrNull(ch, trainSchema, nameof(Arguments.WeightColumn), Args.WeightColumn, DefaultColumnNames.Weight);
            string nameCol    = TrainUtils.MatchNameOrDefaultOrNull(ch, trainSchema, nameof(Arguments.NameColumn), Args.NameColumn, DefaultColumnNames.Name);

            // May replace 'trainPipe' with a normalized view.
            TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, featureCol, Args.NormalizeFeatures);

            ch.Trace("Binding columns");
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
            var data       = TrainUtils.CreateExamples(trainPipe, labelCol, featureCol, groupCol, weightCol, nameCol, customCols);

            RoleMappedData validData = null;
            if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
            {
                if (TrainUtils.CanUseValidationData(trainer))
                {
                    ch.Trace("Constructing the validation pipeline");

                    // Load the raw validation file, then apply the training pipeline's transforms to it.
                    IDataView rawValid  = CreateRawLoader(dataFile: Args.ValidationFile);
                    IDataView validView = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, rawValid);
                    validData = RoleMappedData.Create(validView, data.Schema.GetColumnRoleNames());
                }
                else
                {
                    ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
                }
            }

            var predictor = TrainUtils.Train(Host, ch, data, trainer, _info.LoadNames[0], validData,
                                             Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor);

            // Save the model (to a temp file when no output file was requested) and build the test
            // pipeline by reading the loader back from that saved model.
            IDataLoader testPipe;
            using (var modelFile = string.IsNullOrEmpty(Args.OutputModelFile)
                                       ? Host.CreateTempFile(".zip")
                                       : Host.CreateOutputFile(Args.OutputModelFile))
            {
                TrainUtils.SaveModel(Host, ch, modelFile, predictor, data, cmd);

                ch.Trace("Constructing the testing pipeline");
                using (var modelStream = modelFile.OpenReadStream())
                {
                    using (var repository = RepositoryReader.Open(modelStream, ch))
                    {
                        testPipe = LoadLoader(repository, Args.TestFile, true);
                    }
                }
            }

            // Score.
            ch.Trace("Scoring and evaluating");
            IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, testPipe, featureCol, groupCol, customCols, Host, data.Schema);

            // Evaluate. Fall back to an evaluator chosen from the scored schema when none is configured.
            var evalComp = Args.Evaluator;
            if (!evalComp.IsGood())
            {
                evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
            }

            var evaluator = evalComp.CreateInstance(Host);
            var dataEval  = TrainUtils.CreateExamplesOpt(scorePipe, labelCol, featureCol, groupCol, weightCol, nameCol, customCols);
            var metrics   = evaluator.Evaluate(dataEval);

            MetricWriter.PrintWarnings(ch, metrics);
            evaluator.PrintFoldResults(ch, metrics);

            // The overall metrics are mandatory; everything printed below depends on them.
            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
            {
                throw ch.Except("No overall metrics found");
            }

            overall = evaluator.GetOverallResults(overall);
            MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1);
            evaluator.PrintAdditionalMetrics(ch, metrics);

            SendTelemetryMetric(new Dictionary<string, IDataView>[] { metrics });

            // Per-instance output is optional.
            if (string.IsNullOrWhiteSpace(Args.OutputDataFile))
            {
                return;
            }

            // Note: the feature column is deliberately not bound here (null is passed for it).
            var perInst     = evaluator.GetPerInstanceMetrics(dataEval);
            var perInstData = TrainUtils.CreateExamples(perInst, labelCol, null, groupCol, weightCol, nameCol, customCols);
            var toSave      = evaluator.GetPerInstanceDataViewToSave(perInstData);
            MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, toSave);
        }