コード例 #1
        /// <summary>
        /// Saves the data pipeline defined by <paramref name="dataPipe"/> into <paramref name="repositoryWriter"/>.
        /// If <paramref name="blankLoader"/> is true, or the root <see cref="IDataView"/> of the pipe is not an
        /// <see cref="ILegacyDataLoader"/>, the root is persisted as a <see cref="BinaryLoader"/> having the same schema.
        /// </summary>
        /// <param name="env">The host environment used for argument checks and saving.</param>
        /// <param name="repositoryWriter">The repository the data-model entry is written into.</param>
        /// <param name="dataPipe">The data view whose pipeline is saved. <c>BacktrackPipe</c> yields its root
        /// (<c>pipeStart</c>) and the pieces above it (<c>xfs</c>) — presumably the transforms; confirm there.</param>
        /// <param name="blankLoader">When true, the root is always saved as a binary loader, even if it is a loader.</param>
        public static void SaveDataPipe(IHostEnvironment env, RepositoryWriter repositoryWriter, IDataView dataPipe, bool blankLoader = false)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(repositoryWriter, nameof(repositoryWriter));
            env.CheckValue(dataPipe, nameof(dataPipe));

            IDataView pipeStart;
            var xfs = BacktrackPipe(dataPipe, out pipeStart);

            Action<ModelSaveContext> saveAction;
            if (!blankLoader && pipeStart is ILegacyDataLoader loader)
            {
                // The root is already a loader: let it serialize itself.
                saveAction = loader.Save;
            }
            else
            {
                // The serialized pipe must start with a loader. If the original data view is not a loader,
                // we replace it with a binary loader with the correct schema.
                saveAction = saveContext => BinaryLoader.SaveInstance(env, saveContext, pipeStart.Schema);
            }

            using (var ctx = ModelFileUtils.GetDataModelSavingContext(repositoryWriter))
            {
                LegacyCompositeDataLoader.SavePipe(env, ctx, saveAction, xfs);
            }
        }
コード例 #2
        /// <summary>
        /// Adds a MinMax normalizer over <paramref name="featureColumn"/> to <paramref name="view"/> when the
        /// <paramref name="autoNorm"/> policy and the trainer's requirements call for one.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="ch">Channel for argument checks, informational messages and warnings.</param>
        /// <param name="trainer">The trainer; <c>Info.NeedNormalization</c> decides whether normalization is needed.</param>
        /// <param name="view">The data; replaced by the normalized view when a normalizer is added.</param>
        /// <param name="featureColumn">Name of the feature column. Nothing is added when it is null/empty or
        /// not present in the schema of <paramref name="view"/>.</param>
        /// <param name="autoNorm">The normalization policy: No, Warn, Auto or Yes.</param>
        /// <returns><c>true</c> if a normalizer was added; otherwise <c>false</c>.</returns>
        public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITrainer trainer, ref IDataView view, string featureColumn, NormalizeOption autoNorm)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(trainer, nameof(trainer));
            ch.CheckValue(view, nameof(view));
            ch.CheckUserArg(Enum.IsDefined(typeof(NormalizeOption), autoNorm), nameof(TrainCommand.Arguments.NormalizeFeatures),
                "Normalize option is invalid. Specify one of 'norm=No', 'norm=Warn', 'norm=Auto', or 'norm=Yes'.");

            if (autoNorm == NormalizeOption.No)
            {
                ch.Info("Not adding a normalizer.");
                return false;
            }

            // Without a feature column there is nothing to normalize.
            if (string.IsNullOrEmpty(featureColumn))
            {
                return false;
            }

            var schema = view.Schema;
            if (!schema.TryGetColumnIndex(featureColumn, out int featCol))
            {
                return false;
            }

            // Unless normalization is forced ('norm=Yes'), only normalize when the trainer needs it
            // and the feature column is not already normalized.
            if (autoNorm != NormalizeOption.Yes)
            {
                if (!trainer.Info.NeedNormalization || schema[featCol].IsNormalized())
                {
                    ch.Info("Not adding a normalizer.");
                    return false;
                }

                // 'norm=Warn' only reports that a normalizer is needed; the data is left unchanged.
                if (autoNorm == NormalizeOption.Warn)
                {
                    ch.Warning("A normalizer is needed for this trainer. Either add a normalizing transform or use the 'norm=Auto', 'norm=Yes' or 'norm=No' options.");
                    return false;
                }
            }

            ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.");
            IDataView ApplyNormalizer(IHostEnvironment innerEnv, IDataView input)
                => NormalizeTransform.CreateMinMaxNormalizer(innerEnv, input, featureColumn);

            // A loader goes through LegacyCompositeDataLoader.ApplyTransform (presumably so the result stays
            // a loader); any other view is wrapped directly.
            if (view is ILegacyDataLoader loader)
            {
                view = LegacyCompositeDataLoader.ApplyTransform(env, loader, tag: null, creationArgs: null, ApplyNormalizer);
            }
            else
            {
                view = ApplyNormalizer(env, view);
            }
            return true;
        }
コード例 #3
            /// <summary>
            /// Reads the data-loader entry of the model file named by <c>ImplOptions.InputModelFile</c> and builds a
            /// loader from it on top of <paramref name="srcData"/>, via <c>LegacyCompositeDataLoader.Create</c>.
            /// </summary>
            /// <param name="srcData">The loader the loaded pipeline is built on.</param>
            /// <returns>The loader returned by <c>LegacyCompositeDataLoader.Create</c>.</returns>
            private ILegacyDataLoader LoadTransformChain(ILegacyDataLoader srcData)
            {
                using (var file = Host.OpenInputFile(ImplOptions.InputModelFile))
                using (var strm = file.OpenReadStream())
                using (var rep = RepositoryReader.Open(strm, Host))
                using (var pipeLoaderEntry = rep.OpenEntry(ModelFileUtils.DirDataLoaderModel, ModelLoadContext.ModelStreamName))
                using (var ctx = new ModelLoadContext(rep, pipeLoaderEntry, ModelFileUtils.DirDataLoaderModel))
                {
                    // The predicate accepts everything (presumably keeping every stored transform — confirm in Create).
                    return LegacyCompositeDataLoader.Create(Host, ctx, srcData, x => true);
                }
            }
コード例 #4
 /// <summary>
 /// Builds a loader from <paramref name="loader"/> and the transforms in <c>ImplOptions.Transforms</c>,
 /// via <c>LegacyCompositeDataLoader.Create</c>.
 /// </summary>
 /// <param name="loader">The loader the transforms are applied to.</param>
 /// <returns>The loader returned by <c>LegacyCompositeDataLoader.Create</c>.</returns>
 private ILegacyDataLoader CreateTransformChain(ILegacyDataLoader loader)
 {
     return LegacyCompositeDataLoader.Create(Host, loader, ImplOptions.Transforms);
 }
コード例 #5
            /// <summary>
            /// Loads multiple artifacts of interest from the input model file, given the context
            /// established by the command line arguments.
            /// </summary>
            /// <param name="ch">The channel to which to provide output.</param>
            /// <param name="wantPredictor">Whether we want a predictor from the model file. If
            /// <c>false</c> we will not even attempt to load a predictor. If <c>null</c> we will
            /// load the predictor, if present. If <c>true</c> we will load the predictor, or fail
            /// noisily if we cannot.</param>
            /// <param name="predictor">The predictor in the model, or <c>null</c> if
            /// <paramref name="wantPredictor"/> was false, or <paramref name="wantPredictor"/> was
            /// <c>null</c> and no predictor was present.</param>
            /// <param name="wantTrainSchema">Whether we want the training schema. Unlike
            /// <paramref name="wantPredictor"/>, this has no "hard fail if not present" option. If
            /// this is <c>true</c>, it is still possible for <paramref name="trainSchema"/> to remain
            /// <c>null</c> if there were no role mappings, or pipeline.</param>
            /// <param name="trainSchema">The training schema if <paramref name="wantTrainSchema"/>
            /// is true, and there were role mappings stored in the model.</param>
            /// <param name="pipe">The data pipe constructed from the combination of the
            /// model and command line arguments.</param>
            protected void LoadModelObjects(
                IChannel ch,
                bool? wantPredictor, out IPredictor predictor,
                bool wantTrainSchema, out RoleMappedSchema trainSchema,
                out ILegacyDataLoader pipe)
            {
                // NOTE(review): a previous comment here said the case with no input model file is handled
                // first (everything coming from the command line), but no such branch exists: the file named by
                // ImplOptions.InputModelFile is opened unconditionally below. Confirm the intended behavior.
                using (var file = Host.OpenInputFile(ImplOptions.InputModelFile))
                using (var strm = file.OpenReadStream())
                using (var rep = RepositoryReader.Open(strm, Host))
                {
                    // First consider loading the predictor.
                    if (wantPredictor == false)
                    {
                        predictor = null;
                    }
                    else
                    {
                        ch.Trace("Loading predictor");
                        predictor = ModelFileUtils.LoadPredictorOrNull(Host, rep);
                        if (wantPredictor == true)
                            Host.Check(predictor != null, "Could not load predictor from model file");
                    }

                    // Next create the loader.
                    var loaderFactory = ImplOptions.Loader;
                    // Set only when the pipe built here is the model's own loader with its transforms, i.e. the
                    // training pipe; it is then reused below to obtain the training schema.
                    ILegacyDataLoader trainPipe = null;
                    if (loaderFactory != null)
                    {
                        // The loader is overridden from the command line.
                        pipe = loaderFactory.CreateComponent(Host, new MultiFileSource(ImplOptions.DataFile));
                        if (ImplOptions.LoadTransforms == true)
                        {
                            Host.CheckUserArg(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile), nameof(ImplOptions.InputModelFile));
                            pipe = LoadTransformChain(pipe);
                        }
                    }
                    else
                    {
                        // No loader on the command line: build it from the model repository. Transforms are
                        // loaded unless LoadTransforms is explicitly false.
                        var loadTrans = ImplOptions.LoadTransforms ?? true;
                        pipe = LoadLoader(rep, ImplOptions.DataFile, loadTrans);
                        if (loadTrans)
                            trainPipe = pipe;
                    }

                    // Transforms given on the command line go on top of whichever pipe was built.
                    if (Utils.Size(ImplOptions.Transforms) > 0)
                        pipe = LegacyCompositeDataLoader.Create(Host, pipe, ImplOptions.Transforms);

                    // Next consider loading the training data's role mapped schema.
                    trainSchema = null;
                    if (wantTrainSchema)
                    {
                        // First try to get the role mappings.
                        var trainRoleMappings = ModelFileUtils.LoadRoleMappingsOrNull(Host, rep);
                        if (trainRoleMappings != null)
                        {
                            // Next create the training schema. In the event that the loaded pipeline happens
                            // to be the training pipe, we can just use that. If it differs, then we need to
                            // load the full pipeline from the model, relying upon the fact that all loaders
                            // can be loaded with no data at all, to get their schemas.
                            if (trainPipe == null)
                                trainPipe = ModelFileUtils.LoadLoader(Host, rep, new MultiFileSource(null), loadTransforms: true);
                            trainSchema = new RoleMappedSchema(trainPipe.Schema, trainRoleMappings);
                        }
                        // If the role mappings are null, an alternative would be to fail. However the idea
                        // is that the scorer should always still succeed, although perhaps with reduced
                        // functionality, even when the training schema is null, since not all versions of
                        // TLC models will have the role mappings preserved, I believe. And, we do want to
                        // maintain backwards compatibility.
                    }
                }
            }
コード例 #6
        /// <summary>
        /// Runs the score command: loads the model objects, appends a scorer and the post-transforms to the
        /// loader, optionally saves the resulting pipe to <c>ImplOptions.OutputModelFile</c>, and writes the
        /// selected, savable columns of the scored data to <c>ImplOptions.OutputDataFile</c>.
        /// </summary>
        /// <param name="ch">The channel used for tracing, warnings and checks.</param>
        private void RunCore(IChannel ch)
        {
            ch.Trace("Creating loader");

            LoadModelObjects(ch, true, out var predictor, true, out var trainSchema, out var loader);

            ch.Trace("Creating pipeline");
            var scorer = ImplOptions.Scorer;

            ch.Assert(scorer == null || scorer is ICommandLineComponentFactory, "ScoreCommand should only be used from the command line.");
            var bindable = ScoreUtils.GetSchemaBindableMapper(Host, predictor, scorerFactorySettings: scorer as ICommandLineComponentFactory);

            // REVIEW: We probably ought to prefer role mappings from the training schema.
            string feat = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema,
                nameof(ImplOptions.FeatureColumn), ImplOptions.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema,
                nameof(ImplOptions.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);
            var schema = new RoleMappedSchema(loader.Schema, label: null, feature: feat, group: group, custom: customCols, opt: true);
            var mapper = bindable.Bind(Host, schema);

            if (scorer == null)
            {
                scorer = ScoreUtils.GetScorerComponent(Host, mapper);
            }

            loader = LegacyCompositeDataLoader.ApplyTransform(Host, loader, "Scorer", scorer.ToString(),
                (env, view) => scorer.CreateComponent(env, view, mapper, trainSchema));

            loader = LegacyCompositeDataLoader.Create(Host, loader, ImplOptions.PostTransform);

            if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
            {
                ch.Trace("Saving the data pipe");
                SaveLoader(loader, ImplOptions.OutputModelFile);
            }

            ch.Trace("Creating saver");
            IDataSaver writer;
            if (ImplOptions.Saver == null)
            {
                // No saver given: choose text or binary output from the output file's extension.
                var ext = Path.GetExtension(ImplOptions.OutputDataFile);
                var isText = ext == ".txt" || ext == ".tlc";
                if (isText)
                {
                    writer = new TextSaver(Host, new TextSaver.Arguments());
                }
                else
                {
                    writer = new BinarySaver(Host, new BinarySaver.Arguments());
                }
            }
            else
            {
                writer = ImplOptions.Saver.CreateComponent(Host);
            }
            ch.Assert(writer != null);

            // Test for the binary *saver*. The previous `writer is BinaryWriter` referred to
            // System.IO.BinaryWriter (System.IO is in scope for Path), which a data saver is not,
            // so binary output never got the "all columns" default below.
            var outputIsBinary = writer is BinarySaver;

            bool outputAllColumns =
                ImplOptions.OutputAllColumns == true ||
                (ImplOptions.OutputAllColumns == null && Utils.Size(ImplOptions.OutputColumns) == 0 && outputIsBinary);

            bool outputNamesAndLabels =
                ImplOptions.OutputAllColumns == true || Utils.Size(ImplOptions.OutputColumns) == 0;

            if (ImplOptions.OutputAllColumns == true && Utils.Size(ImplOptions.OutputColumns) != 0)
                ch.Warning(nameof(ImplOptions.OutputAllColumns) + "=+ always writes all columns irrespective of " + nameof(ImplOptions.OutputColumns) + " specified.");

            // Every explicitly requested output column must exist.
            if (!outputAllColumns && Utils.Size(ImplOptions.OutputColumns) != 0)
            {
                foreach (var outCol in ImplOptions.OutputColumns)
                {
                    if (!loader.Schema.TryGetColumnIndex(outCol, out int _))
                        throw ch.ExceptUserArg(nameof(Arguments.OutputColumns), "Column '{0}' not found.", outCol);
                }
            }

            uint maxScoreId = 0;
            if (!outputAllColumns)
                maxScoreId = loader.Schema.GetMaxAnnotationKind(out int _, AnnotationUtils.Kinds.ScoreColumnSetId);
            ch.Assert(outputAllColumns || maxScoreId > 0); // score set IDs are one-based

            var cols = new List<int>();
            for (int i = 0; i < loader.Schema.Count; i++)
            {
                // Hidden columns are dropped unless KeepHidden is set.
                if (!ImplOptions.KeepHidden && loader.Schema[i].IsHidden)
                    continue;
                // Unless all columns are written, keep only the ones ShouldAddColumn selects.
                if (!(outputAllColumns || ShouldAddColumn(loader.Schema, i, maxScoreId, outputNamesAndLabels)))
                    continue;

                var type = loader.Schema[i].Type;
                if (writer.IsColumnSavable(type))
                {
                    cols.Add(i);
                }
                else
                {
                    ch.Warning("The column '{0}' will not be written as it has unsavable column type.",
                        loader.Schema[i].Name);
                }
            }

            ch.Check(cols.Count > 0, "No valid columns to save");

            ch.Trace("Scoring and saving data");
            using (var file = Host.CreateOutputFile(ImplOptions.OutputDataFile))
            using (var stream = file.CreateWriteStream())
            {
                writer.SaveData(stream, loader, cols.ToArray());
            }
        }
コード例 #7
            /// <summary>
            /// Trains, scores and evaluates a single cross-validation fold.
            /// </summary>
            /// <param name="fold">Zero-based fold index, in the range [0, <c>_numFolds</c>).</param>
            /// <returns>The fold's metrics, the schema of the evaluated data, the per-instance results
            /// (<c>null</c> unless <c>_savePerInstance</c> is set), and the training schema.</returns>
            private FoldResult RunFold(int fold)
            {
                var host = GetHost();

                // Valid folds are 0.._numFolds-1; together their [fold/_numFolds, (fold+1)/_numFolds) ranges span [0, 1].
                // The previous bound (fold <= _numFolds) also accepted a fold whose range starts at 1.
                host.Assert(0 <= fold && fold < _numFolds);
                // REVIEW: Make channels buffered in multi-threaded environments.
                using (var ch = host.Start($"Fold {fold}"))
                {
                    ch.Trace("Constructing trainer");
                    ITrainer trainer = _trainer.CreateComponent(host);

                    // Train pipe: every row whose split value falls outside this fold's range (Complement = true).
                    var trainFilter = new RangeFilter.Options
                    {
                        Column = _splitColumn,
                        Min = (double)fold / _numFolds,
                        Max = (double)(fold + 1) / _numFolds,
                        Complement = true,
                    };
                    IDataView trainPipe = new RangeFilter(host, trainFilter, _inputDataView);
                    trainPipe = new OpaqueDataView(trainPipe);
                    var trainData = _createExamples(host, ch, trainPipe, trainer);

                    // Test pipe: the rows inside this fold's range.
                    var testFilter = new RangeFilter.Options
                    {
                        Column = trainFilter.Column,
                        Min = trainFilter.Min,
                        Max = trainFilter.Max,
                    };
                    IDataView testPipe = new RangeFilter(host, testFilter, _inputDataView);
                    testPipe = new OpaqueDataView(testPipe);
                    var testData = _applyTransformsToTestData(host, ch, testPipe, trainData, trainPipe);

                    // Validation pipe and examples.
                    RoleMappedData validData = null;
                    if (_getValidationDataView != null)
                    {
                        ch.Assert(_applyTransformsToValidationData != null);
                        if (!trainer.Info.SupportsValidation)
                        {
                            ch.Warning("Trainer does not accept validation dataset.");
                        }
                        else
                        {
                            ch.Trace("Constructing the validation pipeline");
                            IDataView validLoader = _getValidationDataView();
                            var validPipe = ApplyTransformUtils.ApplyAllTransformsToData(host, _inputDataView, validLoader);
                            validPipe = new OpaqueDataView(validPipe);
                            validData = _applyTransformsToValidationData(host, ch, validPipe, trainData, trainPipe);
                        }
                    }

                    // Train.
                    var predictor = TrainUtils.Train(host, ch, trainData, trainer, validData,
                        _calibrator, _maxCalibrationExamples, _cacheData, _inputPredictor);

                    // Score.
                    ch.Trace("Scoring and evaluating");
                    ch.Assert(_scorer == null || _scorer is ICommandLineComponentFactory, "CrossValidationCommand should only be used from the command line.");
                    var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, scorerFactorySettings: _scorer as ICommandLineComponentFactory);
                    var mapper = bindable.Bind(host, testData.Schema);
                    var scorerComp = _scorer ?? ScoreUtils.GetScorerComponent(host, mapper);
                    IDataScorerTransform scorePipe = scorerComp.CreateComponent(host, testData.Data, mapper, trainData.Schema);

                    // Save per-fold model.
                    string modelFileName = ConstructPerFoldName(_outputModelFile, fold);
                    if (modelFileName != null && _loader != null)
                    {
                        using (var file = host.CreateOutputFile(modelFileName))
                        {
                            // Re-apply the training data's transforms on top of _loader; the roles are taken from
                            // the training schema, as for the evaluation data below.
                            var rmd = new RoleMappedData(
                                LegacyCompositeDataLoader.ApplyTransform(host, _loader, null, null,
                                    (e, newSource) => ApplyTransformUtils.ApplyAllTransformsToData(e, trainData.Data, newSource)),
                                trainData.Schema.GetColumnRoleNames());
                            TrainUtils.SaveModel(host, ch, file, predictor, rmd, _cmd);
                        }
                    }

                    // Evaluate.
                    var eval = _evaluator?.CreateComponent(host) ??
                        EvaluateUtils.GetEvaluator(host, scorePipe.Schema);
                    // Note that this doesn't require the provided columns to exist (because of the "opt" parameter).
                    // We don't normally expect the scorer to drop columns, but if it does, we should not require
                    // all the columns in the test pipeline to still be present.
                    var dataEval = new RoleMappedData(scorePipe, testData.Schema.GetColumnRoleNames(), opt: true);

                    var dict = eval.Evaluate(dataEval);
                    RoleMappedData perInstance = null;
                    if (_savePerInstance)
                    {
                        var perInst = eval.GetPerInstanceMetrics(dataEval);
                        perInstance = new RoleMappedData(perInst, dataEval.Schema.GetColumnRoleNames(), opt: true);
                    }
                    return new FoldResult(dict, dataEval.Schema.Schema, perInstance, trainData.Schema);
                }
            }
コード例 #8
        /// <summary>
        /// Runs the cross-validation command. In order, it:
        /// optionally loads an input predictor (when ContinueTrain is set);
        /// builds the raw loader plus the pre-transforms;
        /// runs the folds through <c>FoldHelper</c>;
        /// prints per-fold and overall metrics;
        /// and saves per-instance results when <c>ImplOptions.OutputDataFile</c> is set.
        /// </summary>
        /// <param name="ch">The channel used for tracing, warnings and checks.</param>
        /// <param name="cmd">Passed through to <c>FoldHelper</c> (presumably the command text stored with
        /// per-fold models — confirm).</param>
        // NOTE(review): braces and some tokens of this block appear to be missing from this view. Examples:
        // the array wrapper around the KeyValuePair passed to Concat, the lambda bodies' braces, and
        // (presumably) an `else` before the per-fold save loop. As shown it does not compile; restore it from source control.
        private void RunCore(IChannel ch, string cmd)

            IPredictor inputPredictor = null;

            // When continuing training, try to load the starting predictor; failure only produces a warning.
            if (ImplOptions.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor))
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");

            ch.Trace("Constructing data pipeline");
            ILegacyDataLoader loader = CreateRawLoader();

            // If the per-instance results are requested and there is no name column, add a GenerateNumberTransform.
            var preXf = ImplOptions.PreTransforms;

            if (!string.IsNullOrEmpty(ImplOptions.OutputDataFile))
                string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(ImplOptions.NameColumn), ImplOptions.NameColumn, DefaultColumnNames.Name);
                if (name == null)
                    // The generated column is named DefaultColumnNames.Name, with UseCounter = true.
                    preXf = preXf.Concat(
                        new KeyValuePair <string, IComponentFactory <IDataView, IDataTransform> >(
                            "", ComponentFactoryUtils.CreateFromFunction <IDataView, IDataTransform>(
                                (env, input) =>
                            var args     = new GenerateNumberTransform.Options();
                            args.Columns = new[] { new GenerateNumberTransform.Column()
                                                       Name = DefaultColumnNames.Name
                                                   }, };
                            args.UseCounter = true;
                            return(new GenerateNumberTransform(env, args, input));
            loader = LegacyCompositeDataLoader.Create(Host, loader, preXf);

            ch.Trace("Binding label and features columns");

            // GetSplitColumn receives `pipe` by ref and may replace it; the result is used for stratification.
            IDataView pipe = loader;
            var       stratificationColumn = GetSplitColumn(ch, loader, ref pipe);
            var       scorer    = ImplOptions.Scorer;
            var       evaluator = ImplOptions.Evaluator;

            Func <IDataView> validDataCreator = null;

            if (ImplOptions.ValidationFile != null)
                validDataCreator =
                    () =>
                    // Fork the command.
                    var impl = new CrossValidationCommand(this);
                    return(impl.CreateRawLoader(dataFile: ImplOptions.ValidationFile));

            FoldHelper fold = new FoldHelper(Host, RegistrationName, pipe, stratificationColumn,
                                             ImplOptions, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator,
                                             validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, !string.IsNullOrEmpty(ImplOptions.OutputDataFile));
            var tasks = fold.GetCrossValidationTasks();

            // Without an explicit evaluator, pick one from the first fold's score schema.
            // NOTE(review): `.Result` presumably blocks until that fold has finished — confirm the task type.
            var eval = evaluator?.CreateComponent(Host) ??
                       EvaluateUtils.GetEvaluator(Host, tasks[0].Result.ScoreSchema);

            // Print confusion matrix and fold results for each fold.
            for (int i = 0; i < tasks.Length; i++)
                var dict = tasks[i].Result.Metrics;
                MetricWriter.PrintWarnings(ch, dict);
                eval.PrintFoldResults(ch, dict);

            // Print the overall results.
            if (!TryGetOverallMetrics(tasks.Select(t => t.Result.Metrics).ToArray(), out var overallList))
                throw ch.Except("No overall metrics found");

            var overall = eval.GetOverallResults(overallList.ToArray());

            MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overall, ImplOptions.NumFolds);
            eval.PrintAdditionalMetrics(ch, tasks.Select(t => t.Result.Metrics).ToArray());
            // NOTE(review): metricValues is not used in the visible part of this method.
            Dictionary <string, IDataView>[] metricValues = tasks.Select(t => t.Result.Metrics).ToArray();

            // Save the per-instance results.
            if (!string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile))
                var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(Host, eval, ImplOptions.CollateMetrics,
                                                                                ImplOptions.OutputExampleFoldIndex, tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames);
                if (variableSizeVectorColumnNames.Length > 0)
                    ch.Warning("Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-Folds results.",
                               string.Join(", ", variableSizeVectorColumnNames));
                // With CollateMetrics a single combined view is expected and written to OutputDataFile;
                // otherwise each view is written to ConstructPerFoldName(OutputDataFile, i).
                if (ImplOptions.CollateMetrics)
                    ch.Assert(perInstance.Length == 1);
                    MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, perInstance[0]);
                    // NOTE(review): `i` is not incremented in the lines shown. Unless an increment follows this
                    // view, every per-fold view would be written to the name built for index 0 — confirm.
                    int i = 0;
                    foreach (var idv in perInstance)
                        MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(ImplOptions.OutputDataFile, i), idv);