示例#1
0
        /// <summary>
        /// Serializes the pipeline that ends at <paramref name="dataPipe"/> into <paramref name="repositoryWriter"/>.
        /// The root of the pipeline is saved through its own <c>Save</c> method when it is an
        /// <see cref="IDataLoader"/> and <paramref name="blankLoader"/> is <c>false</c>; otherwise the root is
        /// persisted as a <see cref="BinaryLoader"/> instance carrying the root's schema.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="repositoryWriter">Repository the data model is written to; must not be null.</param>
        /// <param name="dataPipe">Last data view of the pipeline to save; must not be null.</param>
        /// <param name="blankLoader">When <c>true</c>, the root is always replaced by a binary loader stub.</param>
        public static void SaveDataPipe(IHostEnvironment env, RepositoryWriter repositoryWriter, IDataView dataPipe, bool blankLoader = false)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(repositoryWriter, nameof(repositoryWriter));
            env.CheckValue(dataPipe, nameof(dataPipe));

            // Walk back from the end of the pipe to find its root and the transforms applied on top of it.
            var transforms = BacktrackPipe(dataPipe, out IDataView root);

            Action<ModelSaveContext> saveRoot;
            if (blankLoader || !(root is IDataLoader rootLoader))
            {
                // A serialized pipe has to begin with a loader. When the root is not one (or a blank
                // loader was requested), stand in a binary loader that has the root's schema.
                saveRoot = saveCtx => BinaryLoader.SaveInstance(env, saveCtx, root.Schema);
            }
            else
            {
                saveRoot = rootLoader.Save;
            }

            using (var modelCtx = ModelFileUtils.GetDataModelSavingContext(repositoryWriter))
            {
                CompositeDataLoader.SavePipe(env, modelCtx, saveRoot, transforms);
                modelCtx.Done();
            }
        }
示例#2
0
        /// <summary>
        /// Decides, based on <paramref name="autoNorm"/>, the trainer's needs and the state of the feature
        /// column, whether a MinMax normalizer should be applied to <paramref name="view"/>, and applies it.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="ch">Channel used for argument checks and informational/warning messages.</param>
        /// <param name="trainer">Trainer whose <c>Info.NeedNormalization</c> flag is consulted.</param>
        /// <param name="view">Data view to normalize; replaced by the normalized view when one is added.</param>
        /// <param name="featureColumn">Name of the feature column; may be null or empty, in which case nothing is added.</param>
        /// <param name="autoNorm">The normalization option selected by the user.</param>
        /// <returns><c>true</c> if a normalizer was added; otherwise <c>false</c>.</returns>
        public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITrainer trainer, ref IDataView view, string featureColumn, NormalizeOption autoNorm)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(trainer, nameof(trainer));
            ch.CheckValue(view, nameof(view));
            ch.CheckValueOrNull(featureColumn);
            ch.CheckUserArg(Enum.IsDefined(typeof(NormalizeOption), autoNorm), nameof(TrainCommand.Arguments.NormalizeFeatures),
                            "Normalize option is invalid. Specify one of 'norm=No', 'norm=Warn', 'norm=Auto', or 'norm=Yes'.");

            // Normalization explicitly disabled.
            if (autoNorm == NormalizeOption.No)
            {
                ch.Info("Not adding a normalizer.");
                return false;
            }

            // Without a feature column there is nothing to normalize.
            if (string.IsNullOrEmpty(featureColumn))
            {
                return false;
            }

            var viewSchema = view.Schema;
            if (!viewSchema.TryGetColumnIndex(featureColumn, out int featureIndex))
            {
                // The named column does not exist in the view.
                return false;
            }

            // 'Yes' forces the normalizer; every other remaining option first checks whether one is needed.
            if (autoNorm != NormalizeOption.Yes)
            {
                bool nothingToDo = !trainer.Info.NeedNormalization || viewSchema[featureIndex].IsNormalized();
                if (nothingToDo)
                {
                    ch.Info("Not adding a normalizer.");
                    return false;
                }

                if (autoNorm == NormalizeOption.Warn)
                {
                    ch.Warning("A normalizer is needed for this trainer. Either add a normalizing transform or use the 'norm=Auto', 'norm=Yes' or 'norm=No' options.");
                    return false;
                }
            }

            ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.");

            IDataView WithMinMax(IHostEnvironment innerEnv, IDataView input)
            {
                return NormalizeTransform.CreateMinMaxNormalizer(innerEnv, input, featureColumn);
            }

            // A loader is extended through CompositeDataLoader; any other view is wrapped directly.
            if (view is IDataLoader sourceLoader)
            {
                view = CompositeDataLoader.ApplyTransform(env, sourceLoader, tag: null, creationArgs: null, WithMinMax);
            }
            else
            {
                view = WithMinMax(env, view);
            }

            return true;
        }
            /// <summary>
            /// Opens the model file named by <c>Args.InputModelFile</c>, reads its data-loader model entry and
            /// builds a composite loader on top of <paramref name="srcData"/> from it.
            /// </summary>
            /// <param name="srcData">The loader that the loaded chain is created over.</param>
            /// <returns>The loader returned by <c>CompositeDataLoader.Create</c>.</returns>
            private IDataLoader LoadTransformChain(IDataLoader srcData)
            {
                Host.Assert(!string.IsNullOrWhiteSpace(Args.InputModelFile));

                using (var modelFile = Host.OpenInputFile(Args.InputModelFile))
                using (var modelStream = modelFile.OpenReadStream())
                using (var repository = RepositoryReader.Open(modelStream, Host))
                using (var loaderEntry = repository.OpenEntry(ModelFileUtils.DirDataLoaderModel, ModelLoadContext.ModelStreamName))
                using (var loadCtx = new ModelLoadContext(repository, loaderEntry, ModelFileUtils.DirDataLoaderModel))
                {
                    // The predicate accepts every item it is asked about.
                    return CompositeDataLoader.Create(Host, loadCtx, srcData, _ => true);
                }
            }
 /// <summary>
 /// Delegates to <c>CompositeDataLoader.Create</c> with <paramref name="loader"/> and <c>Args.Transforms</c>.
 /// </summary>
 private IDataLoader CreateTransformChain(IDataLoader loader)
     => CompositeDataLoader.Create(Host, loader, Args.Transforms);
            /// <summary>
            /// Reads the objects of interest (predictor, data pipe, training schema) out of the input
            /// model file, taking the command line arguments into account.
            /// </summary>
            /// <param name="ch">The channel that receives trace output.</param>
            /// <param name="wantPredictor">Controls predictor loading: <c>false</c> means no attempt
            /// is made to load one; <c>null</c> means the predictor is loaded if it is present;
            /// <c>true</c> means the predictor is loaded and a failure to do so is an error.</param>
            /// <param name="predictor">Receives the predictor from the model, or <c>null</c> when
            /// <paramref name="wantPredictor"/> was <c>false</c>, or was <c>null</c> and the model
            /// held no predictor.</param>
            /// <param name="wantTrainSchema">Whether the training schema is requested. There is no
            /// "fail if missing" mode here: even when this is <c>true</c>,
            /// <paramref name="trainSchema"/> stays <c>null</c> if the model has no role mappings.</param>
            /// <param name="trainSchema">Receives the training schema when
            /// <paramref name="wantTrainSchema"/> is <c>true</c> and role mappings were stored in the model.</param>
            /// <param name="pipe">Receives the data pipe assembled from the model and the command
            /// line arguments.</param>
            protected void LoadModelObjects(
                IChannel ch,
                bool? wantPredictor, out IPredictor predictor,
                bool wantTrainSchema, out RoleMappedSchema trainSchema,
                out IDataLoader pipe)
            {
                // The input model file is opened unconditionally; everything below reads from its repository.
                using (var modelFile = Host.OpenInputFile(Args.InputModelFile))
                using (var modelStream = modelFile.OpenReadStream())
                using (var repository = RepositoryReader.Open(modelStream, Host))
                {
                    // Step 1: the predictor.
                    predictor = null;
                    if (wantPredictor != false)
                    {
                        ch.Trace("Loading predictor");
                        predictor = ModelFileUtils.LoadPredictorOrNull(Host, repository);
                        if (wantPredictor == true)
                        {
                            Host.Check(predictor != null, "Could not load predictor from model file");
                        }
                    }

                    // Step 2: the loader.
                    // modelPipe remembers the pipe loaded from the model together with its transforms, if any,
                    // so that it can be reused for the training schema below.
                    IDataLoader modelPipe = null;
                    var commandLineLoader = Args.Loader;
                    if (commandLineLoader == null)
                    {
                        bool withTransforms = Args.LoadTransforms ?? true;
                        pipe = LoadLoader(repository, Args.DataFile, withTransforms);
                        if (withTransforms)
                        {
                            modelPipe = pipe;
                        }
                    }
                    else
                    {
                        // A loader given on the command line replaces the one stored in the model.
                        pipe = commandLineLoader.CreateComponent(Host, new MultiFileSource(Args.DataFile));
                        if (Args.LoadTransforms == true)
                        {
                            Host.CheckUserArg(!string.IsNullOrWhiteSpace(Args.InputModelFile), nameof(Args.InputModelFile));
                            pipe = LoadTransformChain(pipe);
                        }
                    }

                    // Transforms from the command line are applied on top of whatever was built above.
                    if (Utils.Size(Args.Transforms) > 0)
                    {
                        pipe = CompositeDataLoader.Create(Host, pipe, Args.Transforms);
                    }

                    // Step 3: the role-mapped schema of the training data.
                    trainSchema = null;
                    if (wantTrainSchema)
                    {
                        var roleMappings = ModelFileUtils.LoadRoleMappingsOrNull(Host, repository);
                        if (roleMappings != null)
                        {
                            // If the pipe loaded above is not the training pipe, load the full pipeline from
                            // the model with no data source; this relies on every loader being loadable
                            // without data so that its schema can be obtained.
                            modelPipe = modelPipe ?? ModelFileUtils.LoadLoader(Host, repository, new MultiFileSource(null), loadTransforms: true);
                            trainSchema = new RoleMappedSchema(modelPipe.Schema, roleMappings);
                        }
                        // Missing role mappings are deliberately not an error: the scorer should still
                        // succeed, possibly with reduced functionality, because not every version of a
                        // TLC model is believed to preserve role mappings, and backwards compatibility
                        // is to be maintained.
                    }
                }
            }
示例#6
0
        /// <summary>
        /// Loads the predictor and data pipe from the input model, attaches a scorer and the post-transforms,
        /// optionally saves the resulting pipe, and writes the selected columns to <c>Args.OutputDataFile</c>.
        /// </summary>
        /// <param name="ch">Channel used for tracing, warnings, assertions and user-argument errors.</param>
        private void RunCore(IChannel ch)
        {
            Host.AssertValue(ch);

            ch.Trace("Creating loader");

            // A predictor is mandatory here (wantPredictor: true); the training schema is optional.
            LoadModelObjects(ch, true, out var predictor, true, out var trainSchema, out var loader);
            ch.AssertValue(predictor);
            ch.AssertValueOrNull(trainSchema);
            ch.AssertValue(loader);

            ch.Trace("Creating pipeline");
            var scorer = Args.Scorer;

            ch.Assert(scorer == null || scorer is ICommandLineComponentFactory, "ScoreCommand should only be used from the command line.");
            var bindable = ScoreUtils.GetSchemaBindableMapper(Host, predictor, scorerFactorySettings: scorer as ICommandLineComponentFactory);

            ch.AssertValue(bindable);

            // REVIEW: We probably ought to prefer role mappings from the training schema.
            string feat = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema,
                                                              nameof(Args.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema,
                                                               nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
            var schema     = new RoleMappedSchema(loader.Schema, label: null, feature: feat, group: group, custom: customCols, opt: true);
            var mapper     = bindable.Bind(Host, schema);

            // Fall back to the scorer component derived from the bound mapper when none was specified.
            if (scorer == null)
            {
                scorer = ScoreUtils.GetScorerComponent(Host, mapper);
            }

            loader = CompositeDataLoader.ApplyTransform(Host, loader, "Scorer", scorer.ToString(),
                                                        (env, view) => scorer.CreateComponent(env, view, mapper, trainSchema));

            loader = CompositeDataLoader.Create(Host, loader, Args.PostTransform);

            if (!string.IsNullOrWhiteSpace(Args.OutputModelFile))
            {
                ch.Trace("Saving the data pipe");
                SaveLoader(loader, Args.OutputModelFile);
            }

            ch.Trace("Creating saver");
            IDataSaver writer;

            if (Args.Saver == null)
            {
                // No saver specified: pick one from the output file extension.
                var ext    = Path.GetExtension(Args.OutputDataFile);
                var isText = ext == ".txt" || ext == ".tlc";
                if (isText)
                {
                    writer = new TextSaver(Host, new TextSaver.Arguments());
                }
                else
                {
                    writer = new BinarySaver(Host, new BinarySaver.Arguments());
                }
            }
            else
            {
                writer = Args.Saver.CreateComponent(Host);
            }
            ch.Assert(writer != null);

            // The saver type to test for is BinarySaver. The previous check against BinaryWriter
            // (System.IO) could never match an IDataSaver, so the binary default below never applied.
            var outputIsBinary = writer is BinarySaver;

            // All columns are written when explicitly requested, or by default for binary output
            // when no specific output columns were named.
            bool outputAllColumns =
                Args.OutputAllColumns == true ||
                (Args.OutputAllColumns == null && Utils.Size(Args.OutputColumn) == 0 && outputIsBinary);

            bool outputNamesAndLabels =
                Args.OutputAllColumns == true || Utils.Size(Args.OutputColumn) == 0;

            if (Args.OutputAllColumns == true && Utils.Size(Args.OutputColumn) != 0)
            {
                ch.Warning(nameof(Args.OutputAllColumns) + "=+ always writes all columns irrespective of " + nameof(Args.OutputColumn) + " specified.");
            }

            // Every explicitly requested output column must exist in the scored schema.
            if (!outputAllColumns && Utils.Size(Args.OutputColumn) != 0)
            {
                foreach (var outCol in Args.OutputColumn)
                {
                    if (!loader.Schema.TryGetColumnIndex(outCol, out int _))
                    {
                        throw ch.ExceptUserArg(nameof(Arguments.OutputColumn), "Column '{0}' not found.", outCol);
                    }
                }
            }

            uint maxScoreId = 0;

            if (!outputAllColumns)
            {
                maxScoreId = loader.Schema.GetMaxMetadataKind(out int _, MetadataUtils.Kinds.ScoreColumnSetId);
            }
            ch.Assert(outputAllColumns || maxScoreId > 0); // score set IDs are one-based
            var cols = new List<int>();

            // Collect the indices of the columns to write, skipping hidden, unselected and unsavable ones.
            for (int i = 0; i < loader.Schema.Count; i++)
            {
                if (!Args.KeepHidden && loader.Schema[i].IsHidden)
                {
                    continue;
                }
                if (!(outputAllColumns || ShouldAddColumn(loader.Schema, i, maxScoreId, outputNamesAndLabels)))
                {
                    continue;
                }
                var type = loader.Schema[i].Type;
                if (writer.IsColumnSavable(type))
                {
                    cols.Add(i);
                }
                else
                {
                    ch.Warning("The column '{0}' will not be written as it has unsavable column type.",
                               loader.Schema[i].Name);
                }
            }

            ch.Check(cols.Count > 0, "No valid columns to save");

            ch.Trace("Scoring and saving data");
            using (var file = Host.CreateOutputFile(Args.OutputDataFile))
            using (var stream = file.CreateWriteStream())
            {
                writer.SaveData(stream, loader, cols.ToArray());
            }
        }
            /// <summary>
            /// Executes one cross-validation fold: splits the input into train and test ranges on the split
            /// column, trains a predictor, scores the test data, optionally saves a per-fold model, and
            /// evaluates the scored data.
            /// </summary>
            /// <param name="fold">Index of the fold to run.</param>
            /// <returns>The fold's metrics, score schema, optional per-instance results and training schema.</returns>
            private FoldResult RunFold(int fold)
            {
                var host = GetHost();

                // NOTE(review): the upper bound is inclusive here although the range computation below
                // suggests valid folds are 0 .. _numFolds - 1 - confirm against the caller.
                host.Assert(0 <= fold && fold <= _numFolds);
                // REVIEW: Make channels buffered in multi-threaded environments.
                using (var ch = host.Start($"Fold {fold}"))
                {
                    ch.Trace("Constructing trainer");
                    ITrainer trainer = _trainer.CreateComponent(host);

                    // Training data: the complement of this fold's slice of the split column.
                    var trainRange = new RangeFilter.Options
                    {
                        Column = _splitColumn,
                        Min = (Double)fold / _numFolds,
                        Max = (Double)(fold + 1) / _numFolds,
                        Complement = true
                    };
                    IDataView trainPipe = new RangeFilter(host, trainRange, _inputDataView);
                    trainPipe = new OpaqueDataView(trainPipe);
                    var trainData = _createExamples(host, ch, trainPipe, trainer);

                    // Test data: the same slice, not complemented.
                    var testRange = new RangeFilter.Options
                    {
                        Column = trainRange.Column,
                        Min = trainRange.Min,
                        Max = trainRange.Max
                    };
                    ch.Assert(!testRange.Complement);
                    IDataView testPipe = new RangeFilter(host, testRange, _inputDataView);
                    testPipe = new OpaqueDataView(testPipe);
                    var testData = _applyTransformsToTestData(host, ch, testPipe, trainData, trainPipe);

                    // Validation data, only when a validation view factory was supplied.
                    RoleMappedData validData = null;
                    if (_getValidationDataView != null)
                    {
                        ch.Assert(_applyTransformsToValidationData != null);
                        if (trainer.Info.SupportsValidation)
                        {
                            ch.Trace("Constructing the validation pipeline");
                            IDataView rawValidation = _getValidationDataView();
                            var validPipe = ApplyTransformUtils.ApplyAllTransformsToData(host, _inputDataView, rawValidation);
                            validPipe = new OpaqueDataView(validPipe);
                            validData = _applyTransformsToValidationData(host, ch, validPipe, trainData, trainPipe);
                        }
                        else
                        {
                            ch.Warning("Trainer does not accept validation dataset.");
                        }
                    }

                    // Training.
                    var predictor = TrainUtils.Train(host, ch, trainData, trainer, validData,
                                                     _calibrator, _maxCalibrationExamples, _cacheData, _inputPredictor);

                    // Scoring.
                    ch.Trace("Scoring and evaluating");
                    ch.Assert(_scorer == null || _scorer is ICommandLineComponentFactory, "CrossValidationCommand should only be used from the command line.");
                    var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, scorerFactorySettings: _scorer as ICommandLineComponentFactory);
                    ch.AssertValue(bindable);
                    var mapper = bindable.Bind(host, testData.Schema);
                    var scorerComp = _scorer ?? ScoreUtils.GetScorerComponent(host, mapper);
                    IDataScorerTransform scorePipe = scorerComp.CreateComponent(host, testData.Data, mapper, trainData.Schema);

                    // Per-fold model, written only when both a file name and a loader are available.
                    string foldModelPath = ConstructPerFoldName(_outputModelFile, fold);
                    if (foldModelPath != null && _loader != null)
                    {
                        using (var modelFile = host.CreateOutputFile(foldModelPath))
                        {
                            var foldPipe = CompositeDataLoader.ApplyTransform(host, _loader, null, null,
                                (e, newSource) => ApplyTransformUtils.ApplyAllTransformsToData(e, trainData.Data, newSource));
                            var rmd = new RoleMappedData(foldPipe, trainData.Schema.GetColumnRoleNames());
                            TrainUtils.SaveModel(host, ch, modelFile, predictor, rmd, _cmd);
                        }
                    }

                    // Evaluation.
                    var eval = _evaluator?.CreateComponent(host) ??
                               EvaluateUtils.GetEvaluator(host, scorePipe.Schema);
                    // The "opt" flag means the role columns need not exist. A scorer is not normally expected
                    // to drop columns, but if it does, the test pipeline's columns must not all be required.
                    var dataEval = new RoleMappedData(scorePipe, testData.Schema.GetColumnRoleNames(), opt: true);

                    var metrics = eval.Evaluate(dataEval);
                    RoleMappedData perInstance = _savePerInstance
                        ? new RoleMappedData(eval.GetPerInstanceMetrics(dataEval), dataEval.Schema.GetColumnRoleNames(), opt: true)
                        : null;

                    return new FoldResult(metrics, dataEval.Schema.Schema, perInstance, trainData.Schema);
                }
            }
        /// <summary>
        /// Runs cross-validation: builds the data pipeline, executes all folds, prints per-fold and overall
        /// metrics, and optionally saves the per-instance results.
        /// </summary>
        /// <param name="ch">Channel used for tracing, warnings, assertions and errors.</param>
        /// <param name="cmd">Command string forwarded to the fold helper.</param>
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);

            IPredictor inputPredictor = null;

            if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing data pipeline");
            IDataLoader loader = CreateRawLoader();

            // If the per-instance results are requested and there is no name column, add a GenerateNumberTransform.
            var preXf = Args.PreTransforms;

            if (!string.IsNullOrEmpty(Args.OutputDataFile))
            {
                string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
                if (name == null)
                {
                    var nameGenerator = new KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>(
                        "", ComponentFactoryUtils.CreateFromFunction<IDataView, IDataTransform>(
                            (env, input) =>
                            {
                                var args = new GenerateNumberTransform.Options();
                                args.Columns = new[] { new GenerateNumberTransform.Column()
                                                       {
                                                           Name = DefaultColumnNames.Name
                                                       }, };
                                args.UseCounter = true;
                                return new GenerateNumberTransform(env, args, input);
                            }));

                    // Guard against a null pre-transform array: Enumerable.Concat throws
                    // ArgumentNullException on a null source.
                    preXf = preXf == null
                        ? new[] { nameGenerator }
                        : preXf.Concat(new[] { nameGenerator }).ToArray();
                }
            }
            loader = CompositeDataLoader.Create(Host, loader, preXf);

            ch.Trace("Binding label and features columns");

            IDataView pipe = loader;
            var       stratificationColumn = GetSplitColumn(ch, loader, ref pipe);
            var       scorer    = Args.Scorer;
            var       evaluator = Args.Evaluator;

            Func<IDataView> validDataCreator = null;

            if (Args.ValidationFile != null)
            {
                validDataCreator =
                    () =>
                    {
                        // Fork the command.
                        var impl = new CrossValidationCommand(this);
                        return impl.CreateRawLoader(dataFile: Args.ValidationFile);
                    };
            }

            FoldHelper fold = new FoldHelper(Host, RegistrationName, pipe, stratificationColumn,
                                             Args, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator,
                                             validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, !string.IsNullOrEmpty(Args.OutputDataFile));
            var tasks = fold.GetCrossValidationTasks();

            // Without an explicit evaluator, derive one from the first fold's score schema.
            var eval = evaluator?.CreateComponent(Host) ??
                       EvaluateUtils.GetEvaluator(Host, tasks[0].Result.ScoreSchema);

            // Print confusion matrix and fold results for each fold.
            for (int i = 0; i < tasks.Length; i++)
            {
                var dict = tasks[i].Result.Metrics;
                MetricWriter.PrintWarnings(ch, dict);
                eval.PrintFoldResults(ch, dict);
            }

            // Collect the per-fold metrics once; they are needed for the overall results,
            // the additional metrics and telemetry.
            Dictionary<string, IDataView>[] foldMetrics = tasks.Select(t => t.Result.Metrics).ToArray();

            // Print the overall results.
            if (!TryGetOverallMetrics(foldMetrics, out var overallList))
            {
                throw ch.Except("No overall metrics found");
            }

            var overall = eval.GetOverallResults(overallList.ToArray());

            MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, Args.NumFolds);
            eval.PrintAdditionalMetrics(ch, foldMetrics);
            SendTelemetryMetric(foldMetrics);

            // Save the per-instance results.
            if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
            {
                var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(Host, eval, Args.CollateMetrics,
                                                                                Args.OutputExampleFoldIndex, tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames);
                if (variableSizeVectorColumnNames.Length > 0)
                {
                    ch.Warning("Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-Folds results.",
                               string.Join(", ", variableSizeVectorColumnNames));
                }
                if (Args.CollateMetrics)
                {
                    // Collated output is a single data view written to the requested file.
                    ch.Assert(perInstance.Length == 1);
                    MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInstance[0]);
                }
                else
                {
                    // Otherwise each data view goes to its own per-fold file.
                    int foldIndex = 0;
                    foreach (var idv in perInstance)
                    {
                        MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(Args.OutputDataFile, foldIndex), idv);
                        foldIndex++;
                    }
                }
            }
        }