Example #1
0
        /// <summary>
        /// Produces an <see cref="ISchemaBindableMapper"/> that is compatible with the given predictor.
        /// Candidates are tried in the following order:
        /// 1. A mapper instantiated from <paramref name="scorerSettings"/>. This only succeeds when there is a
        ///    registered BindableMapper creation method whose load name equals the one of the scorer.
        /// 2. The predictor itself, when it implements <see cref="ISchemaBindableMapper"/> directly.
        /// 3. A 'matching' standard wrapper around the predictor.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="predictor">The predictor to produce a mapper for; must not be null.</param>
        /// <param name="scorerSettings">The optional scorer sub-component; may be null.</param>
        /// <returns>The bindable mapper selected by the rules above.</returns>
        public static ISchemaBindableMapper GetSchemaBindableMapper(IHostEnvironment env, IPredictor predictor,
                                                                    SubComponent<IDataScorerTransform, SignatureDataScorer> scorerSettings)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(predictor, nameof(predictor));
            env.CheckValueOrNull(scorerSettings);

            // Preferred route: build the mapper from the scorer arguments, when there are any.
            if (scorerSettings.IsGood())
            {
                if (TryCreateBindableFromScorer(env, predictor, scorerSettings, out var fromScorer))
                {
                    return fromScorer;
                }
            }

            // Next best: the predictor already is a bindable mapper.
            if (predictor is ISchemaBindableMapper selfMapper)
            {
                return selfMapper;
            }

            // Otherwise fall back to one of the standard wrappers.
            if (predictor is IValueMapperDist)
            {
                return new SchemaBindableBinaryPredictorWrapper(predictor);
            }

            return new SchemaBindablePredictorWrapper(predictor);
        }
Example #2
0
        /// <summary>
        /// Trains a predictor by delegating to <c>TrainCore</c>. The calibrator sub-component is first turned
        /// into an <see cref="ICalibratorTrainer"/> instance; when the sub-component does not pass
        /// <c>IsGood()</c>, null is handed to <c>TrainCore</c> instead.
        /// </summary>
        /// <returns>The predictor returned by <c>TrainCore</c>.</returns>
        public static IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
                                       SubComponent<ICalibratorTrainer, SignatureCalibrator> calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null)
        {
            // Only instantiate a calibrator trainer when one was actually configured.
            ICalibratorTrainer calibratorTrainer = null;
            if (calibrator.IsGood())
            {
                calibratorTrainer = calibrator.CreateInstance(env);
            }

            return TrainCore(env, ch, data, trainer, name, validData, calibratorTrainer, maxCalibrationExamples, cacheData, inputPredictor);
        }
        /// <summary>
        /// Instantiates a loader over <paramref name="dataFile"/> and returns the result of
        /// <c>GetBytesFromDataView</c> for the given term and value columns. When <paramref name="sub"/> does
        /// not pass <c>IsGood()</c>, the loader is chosen from the file extension: ".idv" selects
        /// "BinaryLoader" and ".tdv" selects "TransposeLoader" (case-insensitive). Any other extension
        /// raises a user-argument exception asking for the loader to be specified.
        /// </summary>
        private static byte[] GetBytesOne(IHost host, string dataFile, SubComponent<IDataLoader, SignatureDataLoader> sub,
                                          string termColumn, string valueColumn)
        {
            Contracts.AssertValue(host);
            host.Assert(!string.IsNullOrWhiteSpace(dataFile));
            host.AssertNonEmpty(termColumn);
            host.AssertNonEmpty(valueColumn);

            if (!sub.IsGood())
            {
                // REVIEW: Should there be defaults for loading from text?
                string extension = Path.GetExtension(dataFile);
                bool isIdv = string.Equals(extension, ".idv", StringComparison.OrdinalIgnoreCase);
                bool isTdv = string.Equals(extension, ".tdv", StringComparison.OrdinalIgnoreCase);
                if (!(isIdv || isTdv))
                {
                    throw host.ExceptUserArg(nameof(Arguments.Loader), "must specify the loader");
                }

                // Exactly one of the two extensions matched.
                host.Assert(isIdv != isTdv);

                string loaderName;
                if (isIdv)
                {
                    loaderName = "BinaryLoader";
                }
                else
                {
                    loaderName = "TransposeLoader";
                }
                sub = new SubComponent<IDataLoader, SignatureDataLoader>(loaderName);
            }

            var source = new MultiFileSource(dataFile);
            var loader = sub.CreateInstance(host, source);

            return GetBytesFromDataView(host, loader, termColumn, valueColumn);
        }
        // Produces the lookup data as a byte array encoded as a binary .idv file.
        // When args.TermColumn is given, args.TermColumn / args.ValueColumn and args.Loader are used as-is.
        // Otherwise the loader comes from GetLoaderSubComponent and the columns default to "Term" / "Value";
        // that default path is rejected when a loader was supplied or the data file has an ".idv" extension.
        private static byte[] GetBytes(IHost host, ColInfo[] infos, Arguments args)
        {
            Contracts.AssertValue(host);
            host.AssertNonEmpty(infos);
            host.AssertValue(args);

            string dataFile = args.DataFile;
            SubComponent<IDataLoader, SignatureDataLoader> loader = args.Loader;

            if (string.IsNullOrEmpty(args.TermColumn))
            {
                // No columns were named, so only the default loader with its default columns is acceptable.
                var extension = Path.GetExtension(dataFile);
                if (loader.IsGood() || string.Equals(extension, ".idv", StringComparison.OrdinalIgnoreCase))
                {
                    throw host.ExceptUserArg(nameof(args.TermColumn), "Term and value columns needed.");
                }

                loader = GetLoaderSubComponent(args.DataFile, args.KeyValues, host);
                return GetBytesOne(host, dataFile, loader, "Term", "Value");
            }

            // A term column was named; the value column is expected to be present as well.
            host.Assert(!string.IsNullOrEmpty(args.ValueColumn));
            string termColumn = args.TermColumn;
            string valueColumn = args.ValueColumn;
            return GetBytesOne(host, dataFile, loader, termColumn, valueColumn);
        }
Example #5
0
 /// <summary>
 /// Attempts to unparse these settings into <paramref name="sb"/>. Returns false, without calling
 /// <c>TryUnparseCore</c>, when any of <c>NewDim</c>, <c>MatrixGenerator</c>, <c>UseSin</c> or
 /// <c>Seed</c> is set; otherwise returns the result of <c>TryUnparseCore</c>.
 /// </summary>
 public bool TryUnparse(StringBuilder sb)
 {
     Contracts.AssertValue(sb);

     bool hasExtraSettings = NewDim != null || MatrixGenerator.IsGood() || UseSin != null || Seed != null;
     return !hasExtraSettings && TryUnparseCore(sb);
 }
            /// <summary>
            /// Sends a trainer telemetry message built from the kind and settings of <paramref name="sub"/>
            /// through <paramref name="pipe"/>. Nothing is sent when <paramref name="sub"/> does not pass
            /// <c>IsGood()</c>; <paramref name="sub"/> is allowed to be null.
            /// </summary>
            protected void SendTelemetryComponent(IPipe<TelemetryMessage> pipe, SubComponent sub)
            {
                Host.AssertValue(pipe);
                Host.AssertValueOrNull(sub);

                if (!sub.IsGood())
                {
                    return;
                }

                var message = TelemetryMessage.CreateTrainer(sub.Kind, sub.SubComponentSettings);
                pipe.Send(message);
            }
Example #7
0
        /// <summary>
        /// Attempts to create a bindable mapper for <paramref name="predictor"/> from a mapper sub-component
        /// that reuses the kind (load name) and settings of <paramref name="scorerSettings"/>.
        /// </summary>
        /// <param name="bindable">Receives the output of <c>ComponentCatalog.TryCreateInstance</c>.</param>
        /// <returns>The result of <c>ComponentCatalog.TryCreateInstance</c>.</returns>
        private static bool TryCreateBindableFromScorer(IHostEnvironment env, IPredictor predictor,
                                                        SubComponent<IDataScorerTransform, SignatureDataScorer> scorerSettings, out ISchemaBindableMapper bindable)
        {
            Contracts.AssertValue(env);
            env.AssertValue(predictor);
            env.Assert(scorerSettings.IsGood());

            // Look for a mapper factory method registered under the same load name as the scorer.
            var loadName = scorerSettings.Kind;
            var settings = scorerSettings.Settings;
            var mapperSub = new SubComponent<ISchemaBindableMapper, SignatureBindableMapper>(loadName, settings);

            return ComponentCatalog.TryCreateInstance(env, out bindable, mapperSub, predictor);
        }
Example #8
0
        /// <summary>
        /// Looks up the loadable class info for a sub-component, keyed by its lower-cased, trimmed kind
        /// and the signature type <typeparamref name="TSig"/>.
        /// </summary>
        /// <typeparam name="TRes">The result type of the sub-component.</typeparam>
        /// <typeparam name="TSig">The signature type; must be a delegate type.</typeparam>
        /// <param name="sub">The sub-component to look up; must pass <c>IsGood()</c>.</param>
        /// <returns>The result of <c>FindClassCore</c> for the computed key.</returns>
        public static LoadableClassInfo GetLoadableClassInfo<TRes, TSig>(SubComponent<TRes, TSig> sub)
            where TRes : class
        {
            var signatureType = typeof(TSig);
            Contracts.CheckParam(signatureType.BaseType == typeof(MulticastDelegate), nameof(TSig), "TSig must be a delegate type");
            Contracts.CheckParam(sub.IsGood(), nameof(sub), "SubComponent must be non-null and non-empty");

            // SubComponent.Kind is never null (may be empty).
            Contracts.Assert(sub.Kind != null);

            // Normalize the load name before using it as a lookup key.
            var normalizedKind = sub.Kind.ToLowerInvariant().Trim();
            var key = new LoadableClassInfo.Key(normalizedKind, signatureType);

            return FindClassCore(key);
        }
Example #9
0
        /// <summary>
        /// Creates the schema bound mapper for the predictor and determines which scorer sub-component to use.
        /// The bindable mapper comes from <c>GetSchemaBindableMapper</c> and is bound to
        /// <paramref name="schema"/>; the bound mapper is returned through <paramref name="mapper"/>.
        /// </summary>
        /// <returns>
        /// <paramref name="scorer"/> itself when it passes <c>IsGood()</c>; otherwise the result of
        /// <c>GetScorerComponent</c> for the bound mapper.
        /// </returns>
        private static SubComponent<IDataScorerTransform, SignatureDataScorer> GetScorerComponentAndMapper(
            IPredictor predictor, SubComponent<IDataScorerTransform, SignatureDataScorer> scorer,
            RoleMappedSchema schema, IHostEnvironment env, out ISchemaBoundMapper mapper)
        {
            Contracts.AssertValue(env);

            var bindableMapper = GetSchemaBindableMapper(env, predictor, scorer);
            env.AssertValue(bindableMapper);

            mapper = bindableMapper.Bind(env, schema);

            // No usable scorer was supplied: derive one from the bound mapper.
            if (!scorer.IsGood())
            {
                return GetScorerComponent(mapper);
            }

            return scorer;
        }
Example #10
0
        /// <summary>
        /// Creates a data loader over <paramref name="dataFile"/> for reading stopwords. When
        /// <paramref name="loader"/> does not pass <c>IsGood()</c>, a default is chosen from the file
        /// extension (case-insensitive): ".idv" uses "BinaryLoader" and ".tdv" uses "TransposeLoader", both
        /// of which require <paramref name="stopwordsCol"/> to be specified. Any other extension uses
        /// "TextLoader" with the settings "sep=tab col=Stopwords:TX:0", and
        /// <paramref name="stopwordsCol"/> is overwritten with "Stopwords" (a warning is logged if a
        /// column had been specified).
        /// </summary>
        private static IDataLoader LoadStopwords(IHostEnvironment env, IChannel ch, string dataFile,
                                                 SubComponent<IDataLoader, SignatureDataLoader> loader, ref string stopwordsCol)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));

            if (!loader.IsGood())
            {
                // No loader supplied: pick the default one from the file extension.
                string extension = Path.GetExtension(dataFile);
                bool isIdv = string.Equals(extension, ".idv", StringComparison.OrdinalIgnoreCase);
                bool isTdv = string.Equals(extension, ".tdv", StringComparison.OrdinalIgnoreCase);

                if (!isIdv && !isTdv)
                {
                    // Text default: the column name is fixed by the loader settings below.
                    if (!string.IsNullOrWhiteSpace(stopwordsCol))
                    {
                        ch.Warning("{0} should not be specified when default loader is TextLoader. Ignoring stopwordsColumn={0}",
                                   stopwordsCol);
                    }
                    loader = new SubComponent<IDataLoader, SignatureDataLoader>("TextLoader", "sep=tab col=Stopwords:TX:0");
                    stopwordsCol = "Stopwords";
                }
                else
                {
                    // Binary or transposed data: the caller has to name the stopwords column.
                    ch.Assert(isIdv != isTdv);
                    ch.CheckUserArg(!string.IsNullOrWhiteSpace(stopwordsCol), nameof(Arguments.StopwordsColumn),
                                    "stopwordsColumn should be specified");
                    if (isIdv)
                    {
                        loader = new SubComponent<IDataLoader, SignatureDataLoader>("BinaryLoader");
                    }
                    else
                    {
                        ch.Assert(isTdv);
                        loader = new SubComponent<IDataLoader, SignatureDataLoader>("TransposeLoader");
                    }
                }
            }

            ch.AssertNonEmpty(stopwordsCol);

            return loader.CreateInstance(env, new MultiFileSource(dataFile));
        }
Example #11
0
        /// <summary>
        /// Validates the trainer sub-component and the data file argument, and returns the loadable class
        /// info found for the trainer's kind. The checks run in this order: the trainer must pass
        /// <c>IsGood()</c>, its kind must resolve to a loadable class, that class must implement
        /// <c>ITrainer</c>, and <paramref name="dataFile"/> must not be null or whitespace.
        /// </summary>
        /// <param name="ectx">The exception context used to raise errors; may be null.</param>
        /// <param name="trainer">The trainer sub-component to validate.</param>
        /// <param name="dataFile">The data file path to validate.</param>
        /// <returns>The loadable class info for the trainer.</returns>
        public static ComponentCatalog.LoadableClassInfo CheckTrainer<TSig>(IExceptionContext ectx, SubComponent<ITrainer, TSig> trainer, string dataFile)
        {
            Contracts.CheckValueOrNull(ectx);
            ectx.CheckUserArg(trainer.IsGood(), nameof(TrainCommand.Arguments.Trainer), "A trainer is required.");

            var classInfo = ComponentCatalog.GetLoadableClassInfo<TSig>(trainer.Kind);
            if (classInfo == null)
            {
                throw ectx.ExceptUserArg(nameof(TrainCommand.Arguments.Trainer), "Unknown trainer: '{0}'", trainer.Kind);
            }

            bool implementsTrainer = typeof(ITrainer).IsAssignableFrom(classInfo.Type);
            if (!implementsTrainer)
            {
                throw ectx.Except("Loadable class '{0}' does not implement 'ITrainer'", classInfo.LoadNames[0]);
            }

            bool dataFileMissing = string.IsNullOrWhiteSpace(dataFile);
            if (dataFileMissing)
            {
                throw ectx.ExceptUserArg(nameof(TrainCommand.Arguments.DataFile), "Data file must be defined.");
            }

            return classInfo;
        }
            /// <summary>
            /// Runs a single fold: constructs the trainer, builds the train and test pipes by range-filtering
            /// the input data view on the split column, optionally builds validation data, trains a
            /// predictor, scores the test data, optionally saves a per-fold model, and evaluates the scored
            /// data.
            /// </summary>
            /// <param name="fold">The index of the fold to run.</param>
            /// <returns>
            /// A <c>FoldResult</c> built from the evaluation result, the schema of the evaluated data, the
            /// per-instance data (null unless <c>_savePerInstance</c> is set) and the training schema.
            /// </returns>
            private FoldResult RunFold(int fold)
            {
                var host = GetHost();

                // NOTE(review): this admits fold == _numFolds, for which the filter range below would start at 1.0
                // — confirm whether the upper bound is meant to be exclusive.
                host.Assert(0 <= fold && fold <= _numFolds);
                // REVIEW: Make channels buffered in multi-threaded environments.
                using (var ch = host.Start($"Fold {fold}"))
                {
                    ch.Trace("Constructing trainer");
                    ITrainer trainer = _trainer.CreateInstance(host);

                    // Train pipe.
                    // The filter range is [fold / _numFolds, (fold + 1) / _numFolds] on the split column, with
                    // Complement set; presumably this keeps the rows outside the fold's range for training.
                    var trainFilter = new RangeFilter.Arguments();
                    trainFilter.Column     = _splitColumn;
                    trainFilter.Min        = (Double)fold / _numFolds;
                    trainFilter.Max        = (Double)(fold + 1) / _numFolds;
                    trainFilter.Complement = true;
                    IDataView trainPipe = new RangeFilter(host, trainFilter, _inputDataView);
                    trainPipe = new OpaqueDataView(trainPipe);
                    var trainData = _createExamples(host, ch, trainPipe, trainer);

                    // Test pipe.
                    // Same column and range as the train filter, but Complement is left at its default, which
                    // the assert below expects to be false.
                    var testFilter = new RangeFilter.Arguments();
                    testFilter.Column = trainFilter.Column;
                    testFilter.Min    = trainFilter.Min;
                    testFilter.Max    = trainFilter.Max;
                    ch.Assert(!testFilter.Complement);
                    IDataView testPipe = new RangeFilter(host, testFilter, _inputDataView);
                    testPipe = new OpaqueDataView(testPipe);
                    var testData = _applyTransformsToTestData(host, ch, testPipe, trainData, trainPipe);

                    // Validation pipe and examples.
                    // validData stays null when no validation view is configured or the trainer does not
                    // support validation.
                    RoleMappedData validData = null;
                    if (_getValidationDataView != null)
                    {
                        ch.Assert(_applyTransformsToValidationData != null);
                        if (!trainer.Info.SupportsValidation)
                        {
                            ch.Warning("Trainer does not accept validation dataset.");
                        }
                        else
                        {
                            ch.Trace("Constructing the validation pipeline");
                            IDataView validLoader = _getValidationDataView();
                            var       validPipe   = ApplyTransformUtils.ApplyAllTransformsToData(host, _inputDataView, validLoader);
                            validPipe = new OpaqueDataView(validPipe);
                            validData = _applyTransformsToValidationData(host, ch, validPipe, trainData, trainPipe);
                        }
                    }

                    // Train.
                    var predictor = TrainUtils.Train(host, ch, trainData, trainer, _trainer.Kind, validData,
                                                     _calibrator, _maxCalibrationExamples, _cacheData, _inputPredictor);

                    // Score.
                    // Uses the configured scorer when it is good; otherwise one is derived from the bound mapper.
                    ch.Trace("Scoring and evaluating");
                    var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, _scorer);
                    ch.AssertValue(bindable);
                    var mapper     = bindable.Bind(host, testData.Schema);
                    var scorerComp = _scorer.IsGood() ? _scorer : ScoreUtils.GetScorerComponent(mapper);
                    IDataScorerTransform scorePipe = scorerComp.CreateInstance(host, testData.Data, mapper, trainData.Schema);

                    // Save per-fold model.
                    // Skipped when no per-fold file name could be constructed or there is no loader.
                    string modelFileName = ConstructPerFoldName(_outputModelFile, fold);
                    if (modelFileName != null && _loader != null)
                    {
                        using (var file = host.CreateOutputFile(modelFileName))
                        {
                            var rmd = new RoleMappedData(
                                CompositeDataLoader.ApplyTransform(host, _loader, null, null,
                                                                   (e, newSource) => ApplyTransformUtils.ApplyAllTransformsToData(e, trainData.Data, newSource)),
                                trainData.Schema.GetColumnRoleNames());
                            TrainUtils.SaveModel(host, ch, file, predictor, rmd, _cmd);
                        }
                    }

                    // Evaluate.
                    // Falls back to an evaluator chosen from the scored schema when none is configured.
                    var evalComp = _evaluator;
                    if (!evalComp.IsGood())
                    {
                        evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
                    }
                    var eval = evalComp.CreateInstance(host);
                    // Note that this doesn't require the provided columns to exist (because of the "opt" parameter).
                    // We don't normally expect the scorer to drop columns, but if it does, we should not require
                    // all the columns in the test pipeline to still be present.
                    var dataEval = new RoleMappedData(scorePipe, testData.Schema.GetColumnRoleNames(), opt: true);

                    var            dict        = eval.Evaluate(dataEval);
                    RoleMappedData perInstance = null;
                    if (_savePerInstance)
                    {
                        var perInst = eval.GetPerInstanceMetrics(dataEval);
                        perInstance = new RoleMappedData(perInst, dataEval.Schema.GetColumnRoleNames(), opt: true);
                    }
                    ch.Done();
                    return(new FoldResult(dict, dataEval.Schema.Schema, perInstance, trainData.Schema));
                }
            }