Exemplo n.º 1
0
        /// <summary>
        /// Builds a RoleMappedData over <paramref name="view"/> from a set of optional column names.
        /// Each of label/feature/group/weight/name is bound to its matching role only when it is non-null
        /// and not whitespace; otherwise that role is skipped. Any <paramref name="custom"/> role bindings
        /// are appended unchanged after the standard ones. The result is produced by RoleMappedData.CreateOpt,
        /// so specified columns that are not valid columns of the schema are ignored rather than rejected.
        /// </summary>
        /// <param name="view">The data view the roles refer to.</param>
        /// <param name="label">Optional label column name.</param>
        /// <param name="feature">Optional feature column name.</param>
        /// <param name="group">Optional group column name.</param>
        /// <param name="weight">Optional weight column name.</param>
        /// <param name="name">Optional name column name.</param>
        /// <param name="custom">Optional extra role/column bindings, appended as-is.</param>
        /// <returns>The role-mapped view.</returns>
        public static RoleMappedData CreateExamplesOpt(IDataView view, string label, string feature,
                                                       string group = null, string weight = null, string name = null,
                                                       IEnumerable<KeyValuePair<ColumnRole, string>> custom = null)
        {
            Contracts.CheckValueOrNull(label);
            Contracts.CheckValueOrNull(feature);
            Contracts.CheckValueOrNull(group);
            Contracts.CheckValueOrNull(weight);
            Contracts.CheckValueOrNull(name);
            Contracts.CheckValueOrNull(custom);

            // Standard roles in the fixed order they are bound: label, feature, group, weight, name.
            var candidates = new[]
            {
                new KeyValuePair<ColumnRole, string>(ColumnRole.Label, label),
                new KeyValuePair<ColumnRole, string>(ColumnRole.Feature, feature),
                new KeyValuePair<ColumnRole, string>(ColumnRole.Group, group),
                new KeyValuePair<ColumnRole, string>(ColumnRole.Weight, weight),
                new KeyValuePair<ColumnRole, string>(ColumnRole.Name, name),
            };

            var roles = new List<KeyValuePair<ColumnRole, string>>();
            foreach (var candidate in candidates)
            {
                // Blank column names mean "role not requested".
                if (string.IsNullOrWhiteSpace(candidate.Value))
                {
                    continue;
                }
                roles.Add(candidate.Key.Bind(candidate.Value));
            }

            if (custom != null)
            {
                roles.AddRange(custom);
            }

            return RoleMappedData.CreateOpt(view, roles);
        }
            /// <summary>
            /// Trains and evaluates a single cross-validation fold.
            /// The training data is every row whose split column falls OUTSIDE the fold's range
            /// [fold / _numFolds, (fold + 1) / _numFolds] (complement range filter). The test data is every
            /// row INSIDE that range. Whether each bound is inclusive is decided by RangeFilter and is not
            /// established here.
            /// When a validation view provider is configured and the trainer accepts validation data, a
            /// validation pipeline is also built. A per-fold model is saved when a per-fold file name and a
            /// loader are available.
            /// </summary>
            /// <param name="fold">Zero-based fold index; must lie in [0, _numFolds).</param>
            /// <returns>The fold's metrics, scored schema, optional per-instance data and training schema.</returns>
            private FoldResult RunFold(int fold)
            {
                var host = GetHost();

                // Folds are indexed 0.._numFolds-1. fold == _numFolds would produce a test range of
                // [1, (n+1)/n], which lies beyond the split space, so it must be rejected as well.
                host.Assert(0 <= fold && fold < _numFolds);
                // REVIEW: Make channels buffered in multi-threaded environments.
                using (var ch = host.Start($"Fold {fold}"))
                {
                    ch.Trace("Constructing trainer");
                    ITrainer trainer = _trainer.CreateInstance(host);

                    // Train pipe: everything outside this fold's slice of the split column.
                    var trainFilter = new RangeFilter.Arguments();
                    trainFilter.Column     = _splitColumn;
                    trainFilter.Min        = (Double)fold / _numFolds;
                    trainFilter.Max        = (Double)(fold + 1) / _numFolds;
                    trainFilter.Complement = true;
                    IDataView trainPipe = new RangeFilter(host, trainFilter, _inputDataView);
                    trainPipe = new OpaqueDataView(trainPipe);
                    var trainData = _createExamples(host, ch, trainPipe, trainer);

                    // Test pipe: the same range, not complemented, so it selects exactly this fold's slice.
                    var testFilter = new RangeFilter.Arguments();
                    testFilter.Column = trainFilter.Column;
                    testFilter.Min    = trainFilter.Min;
                    testFilter.Max    = trainFilter.Max;
                    ch.Assert(!testFilter.Complement);
                    IDataView testPipe = new RangeFilter(host, testFilter, _inputDataView);
                    testPipe = new OpaqueDataView(testPipe);
                    var testData = _applyTransformsToTestData(host, ch, testPipe, trainData, trainPipe);

                    // Validation pipe and examples (only when a provider is configured and the trainer supports it).
                    RoleMappedData validData = null;
                    if (_getValidationDataView != null)
                    {
                        ch.Assert(_applyTransformsToValidationData != null);
                        if (!TrainUtils.CanUseValidationData(trainer))
                        {
                            ch.Warning("Trainer does not accept validation dataset.");
                        }
                        else
                        {
                            ch.Trace("Constructing the validation pipeline");
                            IDataView validLoader = _getValidationDataView();
                            // Replay the transforms found on _inputDataView onto the validation loader.
                            var       validPipe   = ApplyTransformUtils.ApplyAllTransformsToData(host, _inputDataView, validLoader);
                            validPipe = new OpaqueDataView(validPipe);
                            validData = _applyTransformsToValidationData(host, ch, validPipe, trainData, trainPipe);
                        }
                    }

                    // Train.
                    var predictor = TrainUtils.Train(host, ch, trainData, trainer, _trainer.Kind, validData,
                                                     _calibrator, _maxCalibrationExamples, _cacheData, _inputPredictor);

                    // Score: use the configured scorer if set, otherwise pick one from the bound mapper.
                    ch.Trace("Scoring and evaluating");
                    var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, _scorer);
                    ch.AssertValue(bindable);
                    var mapper     = bindable.Bind(host, testData.Schema);
                    var scorerComp = _scorer.IsGood() ? _scorer : ScoreUtils.GetScorerComponent(mapper);
                    IDataScorerTransform scorePipe = scorerComp.CreateInstance(host, testData.Data, mapper, trainData.Schema);

                    // Save per-fold model (only when a per-fold file name could be built and a loader exists).
                    string modelFileName = ConstructPerFoldName(_outputModelFile, fold);
                    if (modelFileName != null && _loader != null)
                    {
                        using (var file = host.CreateOutputFile(modelFileName))
                        {
                            // Re-apply the training data's transforms on top of the loader so the saved model
                            // carries the full pipeline together with the training column roles.
                            var rmd = RoleMappedData.Create(
                                CompositeDataLoader.ApplyTransform(host, _loader, null, null,
                                                                   (e, newSource) => ApplyTransformUtils.ApplyAllTransformsToData(e, trainData.Data, newSource)),
                                trainData.Schema.GetColumnRoleNames());
                            TrainUtils.SaveModel(host, ch, file, predictor, rmd, _cmd);
                        }
                    }

                    // Evaluate: use the configured evaluator if set, otherwise infer one from the scored schema.
                    var evalComp = _evaluator;
                    if (!evalComp.IsGood())
                    {
                        evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
                    }
                    var eval = evalComp.CreateInstance(host);
                    // Note that this doesn't require the provided columns to exist (because of "Opt").
                    // We don't normally expect the scorer to drop columns, but if it does, we should not require
                    // all the columns in the test pipeline to still be present.
                    var dataEval = RoleMappedData.CreateOpt(scorePipe, testData.Schema.GetColumnRoleNames());

                    var       dict        = eval.Evaluate(dataEval);
                    IDataView perInstance = null;
                    if (_savePerInstance)
                    {
                        var perInst     = eval.GetPerInstanceMetrics(dataEval);
                        var perInstData = RoleMappedData.CreateOpt(perInst, dataEval.Schema.GetColumnRoleNames());
                        perInstance = eval.GetPerInstanceDataViewToSave(perInstData);
                    }
                    ch.Done();
                    return(new FoldResult(dict, dataEval.Schema.Schema, perInstance, trainData.Schema));
                }
            }