Example #1
            public static IRowCursor CreateKeyRowCursor(RangeFilter filter, IRowCursor input, bool[] active)
            {
                Contracts.Assert(filter._type.IsKey);
                // Bind the generic overload to the key column's runtime raw type via reflection.
                Func<RangeFilter, IRowCursor, bool[], IRowCursor> del = CreateKeyRowCursor<int>;
                var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(filter._type.RawType);

                return (IRowCursor)methodInfo.Invoke(null, new object[] { filter, input, active });
            }
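
The two CreateKeyRowCursor overloads (the non-generic entry point above and the generic CreateKeyRowCursor<TSrc> at the end of this example) rely on a common reflection trick: a delegate to an arbitrary instantiation of the generic method yields its MethodInfo, which is then re-instantiated with the column's runtime raw type. A minimal, self-contained sketch of the same pattern; the Process/ProcessCore names are made up for illustration:

using System;
using System.Reflection;

public static class GenericDispatchSketch
{
    // Non-generic entry point: picks the generic instantiation at run time.
    public static object Process(Type runtimeType, object value)
    {
        // Any instantiation works here; we only need the generic method definition.
        Func<object, object> del = ProcessCore<int>;
        var method = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(runtimeType);
        return method.Invoke(null, new object[] { value });
    }

    private static object ProcessCore<T>(object value)
    {
        return $"Handled {value} as {typeof(T).Name}";
    }

    public static void Main()
    {
        Console.WriteLine(Process(typeof(double), 3.14));  // e.g. "Handled 3.14 as Double"
        Console.WriteLine(Process(typeof(string), "abc")); // e.g. "Handled abc as String"
    }
}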
 public DoubleRowCursor(RangeFilter parent, IRowCursor input, bool[] active)
     : base(parent, input, active)
 {
     Ch.Assert(Parent._type == NumberType.R8);
     _srcGetter = Input.GetGetter<Double>(Parent._index);
     _getter    =
         (ref Double value) =>
         {
             Ch.Check(IsGood);
             value = _value;
         };
 }
            public KeyRowCursor(RangeFilter parent, IRowCursor input, bool[] active)
                : base(parent, input, active)
            {
                Ch.Assert(Parent._type.KeyCount > 0);
                _count     = Parent._type.KeyCount;
                _srcGetter = Input.GetGetter<T>(Parent._index);
                _getter    =
                    (ref T dst) =>
                    {
                        Ch.Check(IsGood);
                        dst = _value;
                    };
                bool identity;

                // Conversion from the key type T to ulong, used to test key values against the filter's range.
                _conv = Conversions.Instance.GetStandardConversion<T, ulong>(Parent._type, NumberType.U8, out identity);
            }
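
Both row cursor constructors cache the source column getter and expose the current value through a ref-parameter delegate, so the caller supplies the storage and nothing is returned by copy; the Ch.Check(IsGood) guard ensures the getter is only usable while the cursor is positioned on a row. A small sketch of that getter shape, using a hypothetical RefGetter delegate in place of the library's own delegate type:

using System;

// Hypothetical stand-in for the getter delegate shape used above:
// the caller passes storage by ref and the delegate fills it in.
public delegate void RefGetter<T>(ref T value);

public static class GetterSketch
{
    public static void Main()
    {
        double current = 0;
        bool positioned = false;

        // The "cursor" keeps a cached value; the getter only exposes it while positioned on a row.
        RefGetter<double> getter = (ref double dst) =>
        {
            if (!positioned)
                throw new InvalidOperationException("Cursor is not positioned on a row.");
            dst = current;
        };

        positioned = true;
        current = 42.0;

        double value = 0;
        getter(ref value);
        Console.WriteLine(value); // 42
    }
}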
            private FoldResult RunFold(int fold)
            {
                var host = GetHost();

                host.Assert(0 <= fold && fold <= _numFolds);
                // REVIEW: Make channels buffered in multi-threaded environments.
                using (var ch = host.Start($"Fold {fold}"))
                {
                    ch.Trace("Constructing trainer");
                    ITrainer trainer = _trainer.CreateInstance(host);

                    // Train pipe.
                    var trainFilter = new RangeFilter.Arguments();
                    trainFilter.Column     = _splitColumn;
                    trainFilter.Min        = (Double)fold / _numFolds;
                    trainFilter.Max        = (Double)(fold + 1) / _numFolds;
                    trainFilter.Complement = true;
                    IDataView trainPipe = new RangeFilter(host, trainFilter, _inputDataView);
                    trainPipe = new OpaqueDataView(trainPipe);
                    var trainData = _createExamples(host, ch, trainPipe, trainer);

                    // Test pipe.
                    var testFilter = new RangeFilter.Arguments();
                    testFilter.Column = trainFilter.Column;
                    testFilter.Min    = trainFilter.Min;
                    testFilter.Max    = trainFilter.Max;
                    ch.Assert(!testFilter.Complement);
                    IDataView testPipe = new RangeFilter(host, testFilter, _inputDataView);
                    testPipe = new OpaqueDataView(testPipe);
                    var testData = _applyTransformsToTestData(host, ch, testPipe, trainData, trainPipe);

                    // Validation pipe and examples.
                    RoleMappedData validData = null;
                    if (_getValidationDataView != null)
                    {
                        ch.Assert(_applyTransformsToValidationData != null);
                        if (!trainer.Info.SupportsValidation)
                        {
                            ch.Warning("Trainer does not accept validation dataset.");
                        }
                        else
                        {
                            ch.Trace("Constructing the validation pipeline");
                            IDataView validLoader = _getValidationDataView();
                            var       validPipe   = ApplyTransformUtils.ApplyAllTransformsToData(host, _inputDataView, validLoader);
                            validPipe = new OpaqueDataView(validPipe);
                            validData = _applyTransformsToValidationData(host, ch, validPipe, trainData, trainPipe);
                        }
                    }

                    // Train.
                    var predictor = TrainUtils.Train(host, ch, trainData, trainer, _trainer.Kind, validData,
                                                     _calibrator, _maxCalibrationExamples, _cacheData, _inputPredictor);

                    // Score.
                    ch.Trace("Scoring and evaluating");
                    var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, _scorer);
                    ch.AssertValue(bindable);
                    var mapper     = bindable.Bind(host, testData.Schema);
                    var scorerComp = _scorer.IsGood() ? _scorer : ScoreUtils.GetScorerComponent(mapper);
                    IDataScorerTransform scorePipe = scorerComp.CreateInstance(host, testData.Data, mapper, trainData.Schema);

                    // Save per-fold model.
                    string modelFileName = ConstructPerFoldName(_outputModelFile, fold);
                    if (modelFileName != null && _loader != null)
                    {
                        using (var file = host.CreateOutputFile(modelFileName))
                        {
                            var rmd = new RoleMappedData(
                                CompositeDataLoader.ApplyTransform(host, _loader, null, null,
                                                                   (e, newSource) => ApplyTransformUtils.ApplyAllTransformsToData(e, trainData.Data, newSource)),
                                trainData.Schema.GetColumnRoleNames());
                            TrainUtils.SaveModel(host, ch, file, predictor, rmd, _cmd);
                        }
                    }

                    // Evaluate.
                    var evalComp = _evaluator;
                    if (!evalComp.IsGood())
                    {
                        evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
                    }
                    var eval = evalComp.CreateInstance(host);
                    // Note that this doesn't require the provided columns to exist (because of the "opt" parameter).
                    // We don't normally expect the scorer to drop columns, but if it does, we should not require
                    // all the columns in the test pipeline to still be present.
                    var dataEval = new RoleMappedData(scorePipe, testData.Schema.GetColumnRoleNames(), opt: true);

                    var            dict        = eval.Evaluate(dataEval);
                    RoleMappedData perInstance = null;
                    if (_savePerInstance)
                    {
                        var perInst = eval.GetPerInstanceMetrics(dataEval);
                        perInstance = new RoleMappedData(perInst, dataEval.Schema.GetColumnRoleNames(), opt: true);
                    }
                    ch.Done();
                    return new FoldResult(dict, dataEval.Schema.Schema, perInstance, trainData.Schema);
                }
            }
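
RunFold carves both pipelines out of a single [0, 1)-valued split column: fold k keeps [k/numFolds, (k+1)/numFolds) as the test slice, and setting Complement = true on the train filter keeps everything outside that slice. A tiny standalone sketch of the fold-boundary arithmetic (five folds assumed; it does not depend on RangeFilter itself):

using System;

public static class FoldRangeSketch
{
    public static void Main()
    {
        const int numFolds = 5;
        for (int fold = 0; fold < numFolds; fold++)
        {
            double min = (double)fold / numFolds;
            double max = (double)(fold + 1) / numFolds;
            // Test rows fall inside [min, max); training uses the complement of that range.
            Console.WriteLine($"Fold {fold}: test range [{min:0.00}, {max:0.00}), train on the rest");
        }
    }
}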
        private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output)
        {
            // The stratification column and/or group column, if they exist at all, must be present at this point.
            var schema = input.Schema;

            output = input;
            // If no stratification column was specified, but we have a group column of type Single, Double or
            // Key (contiguous) use it.
            string stratificationColumn = null;

            if (!string.IsNullOrWhiteSpace(Args.StratificationColumn))
            {
                stratificationColumn = Args.StratificationColumn;
            }
            else
            {
                string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
                int    index;
                if (group != null && schema.TryGetColumnIndex(group, out index))
                {
                    // Check if group column key type with known cardinality.
                    var type = schema.GetColumnType(index);
                    if (type.KeyCount > 0)
                    {
                        stratificationColumn = group;
                    }
                }
            }

            if (string.IsNullOrEmpty(stratificationColumn))
            {
                stratificationColumn = "StratificationColumn";
                int tmp;
                int inc = 0;
                while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
                {
                    stratificationColumn = string.Format("StratificationColumn_{0:000}", ++inc);
                }
                var keyGenArgs = new GenerateNumberTransform.Arguments();
                var col        = new GenerateNumberTransform.Column();
                col.Name          = stratificationColumn;
                keyGenArgs.Column = new[] { col };
                output            = new GenerateNumberTransform(Host, keyGenArgs, input);
            }
            else
            {
                int col;
                if (!input.Schema.TryGetColumnIndex(stratificationColumn, out col))
                {
                    throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", stratificationColumn);
                }
                var type = input.Schema.GetColumnType(col);
                if (!RangeFilter.IsValidRangeFilterColumnType(ch, type))
                {
                    ch.Info("Hashing the stratification column");
                    var origStratCol = stratificationColumn;
                    int tmp;
                    int inc = 0;
                    while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
                    {
                        stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
                    }
                    var hashargs = new HashTransform.Arguments();
                    hashargs.Column = new[] { new HashTransform.Column { Source = origStratCol, Name = stratificationColumn } };
                    hashargs.HashBits = 30;
                    output = new HashTransform(Host, hashargs, input);
                }
            }

            return stratificationColumn;
        }
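
When GetSplitColumn has to add a column (the generated stratification column or the hashed copy of an existing one), it probes the schema for an unused name by appending a numeric suffix. The same collision-avoidance idea in isolation, with a plain set of existing names standing in for the schema (GetFreeName is a name invented for this sketch):

using System;
using System.Collections.Generic;

public static class ColumnNameSketch
{
    // Returns baseName if it is unused, otherwise baseName_001, baseName_002, ...
    public static string GetFreeName(string baseName, ISet<string> existing)
    {
        string name = baseName;
        int inc = 0;
        while (existing.Contains(name))
            name = string.Format("{0}_{1:000}", baseName, ++inc);
        return name;
    }

    public static void Main()
    {
        var taken = new HashSet<string> { "StratificationColumn", "StratificationColumn_001" };
        Console.WriteLine(GetFreeName("StratificationColumn", taken)); // StratificationColumn_002
    }
}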
 private static IRowCursor CreateKeyRowCursor<TSrc>(RangeFilter filter, IRowCursor input, bool[] active)
 {
     Contracts.Assert(filter._type.IsKey);
     return new KeyRowCursor<TSrc>(filter, input, active);
 }