Пример #1
0
        /// <summary>
        /// Creates the estimator from an <see cref="Options"/> object.
        /// The count table estimator is obtained in one of three ways:
        /// loaded from <c>options.InitialCountsModel</c> when that path is non-empty,
        /// built over a single shared count table when <c>options.SharedTable</c> is set,
        /// or built with one count table builder per column otherwise.
        /// Per-column settings that are left unset fall back to the values on <paramref name="options"/>.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="options">The options; must specify at least one column and a non-blank label column.</param>
        internal CountTargetEncodingEstimator(IHostEnvironment env, Options options)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(CountTargetEncodingEstimator));
            _host.CheckValue(options, nameof(options));
            _host.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");
            _host.CheckUserArg(!string.IsNullOrWhiteSpace(options.LabelColumn), nameof(options.LabelColumn), "Must specify the label column name");

            bool hasInitialCounts = !string.IsNullOrEmpty(options.InitialCountsModel);
            if (hasInitialCounts)
            {
                // Input and output share the same name for every column described in the options.
                var columnPairs = options.Columns
                    .Select(col => new InputOutputColumnPair(col.Name))
                    .ToArray();

                _countTableEstimator = LoadFromFile(env, options.InitialCountsModel, options.LabelColumn, columnPairs);
                if (_countTableEstimator == null)
                {
                    throw env.Except($"The file {options.InitialCountsModel} does not contain a CountTableTransformer");
                }
            }
            else if (options.SharedTable)
            {
                // One table for all columns: only the per-column smoothing/noise/seed settings vary.
                var sharedColumns = options.Columns
                    .Select(col => new CountTableEstimator.SharedColumnOptions(
                        col.Name,
                        col.Name,
                        col.PriorCoefficient ?? options.PriorCoefficient,
                        col.LaplaceScale ?? options.LaplaceScale,
                        col.Seed ?? options.Seed))
                    .ToArray();

                var sharedBuilder = options.CountTable;
                _host.CheckValue(sharedBuilder, nameof(options.CountTable));
                _countTableEstimator = new CountTableEstimator(_host, options.LabelColumn, sharedBuilder.CreateComponent(_host), sharedColumns);
            }
            else
            {
                // One table per column: a column-level builder overrides the global one.
                var perColumn = options.Columns
                    .Select(col =>
                    {
                        var tableBuilder = col.CountTable ?? options.CountTable;
                        _host.CheckValue(tableBuilder, nameof(options.CountTable));
                        return new CountTableEstimator.ColumnOptions(
                            col.Name,
                            col.Name,
                            tableBuilder.CreateComponent(_host),
                            col.PriorCoefficient ?? options.PriorCoefficient,
                            col.LaplaceScale ?? options.LaplaceScale,
                            col.Seed ?? options.Seed);
                    })
                    .ToArray();

                _countTableEstimator = new CountTableEstimator(_host, options.LabelColumn, perColumn);
            }

            _hashingColumns = InitializeHashingColumnOptions(options);
            _hashingEstimator = new HashingEstimator(_host, _hashingColumns);
        }
Пример #2
0
        /// <summary>
        /// Converts categorical columns into count-based features: the count of each label class,
        /// the log-odds for each label class, and the back-off indicator.
        /// </summary>
        /// <param name="catalog">The transforms catalog.</param>
        /// <param name="columns">The input/output column pairs to encode. Must not be null.</param>
        /// <param name="labelColumn">The name of the label column.</param>
        /// <param name="builder">The builder that creates the count tables from the training data.
        /// When null, a <see cref="CMCountTableBuilder"/> is used.</param>
        /// <param name="priorCoefficient">The coefficient with which to apply the prior smoothing to the features.</param>
        /// <param name="laplaceScale">The Laplacian noise diversity/scale-parameter. Recommended values are between 0 and 1. Note that the noise
        /// will only be applied if the estimator is part of an <see cref="EstimatorChain{TLastTransformer}"/>, when fitting the next estimator in the chain.</param>
        /// <param name="sharedTable">Whether counts for all columns and slots are kept in one shared count table. If true, the keys in the count table
        /// will include a hash of the column and slot indices.</param>
        /// <param name="numberOfBits">The number of bits to hash the input into. Must be between 1 and 31, inclusive.</param>
        /// <param name="combine">For a vector input column, whether the values are combined into a single hash to create a single
        /// count table, or left as a vector of hashes with multiple count tables.</param>
        /// <param name="hashingSeed">The seed used for hashing the input columns.</param>
        /// <returns>A <see cref="CountTargetEncodingEstimator"/> configured with the given columns and settings.</returns>
        public static CountTargetEncodingEstimator CountTargetEncode(this TransformsCatalog catalog,
                                                                     InputOutputColumnPair[] columns, string labelColumn = DefaultColumnNames.Label,
                                                                     CountTableBuilderBase builder = null,
                                                                     float priorCoefficient        = CountTableTransformer.Defaults.PriorCoefficient,
                                                                     float laplaceScale            = CountTableTransformer.Defaults.LaplaceScale,
                                                                     bool sharedTable = CountTableTransformer.Defaults.SharedTable,
                                                                     int numberOfBits = HashingEstimator.Defaults.NumberOfBits,
                                                                     bool combine     = HashingEstimator.Defaults.Combine,
                                                                     uint hashingSeed = HashingEstimator.Defaults.Seed)
        {
            var env = CatalogUtils.GetEnvironment(catalog);
            env.CheckValue(columns, nameof(columns));

            // Fall back to the default count table builder when the caller did not supply one.
            CountTableBuilderBase tableBuilder = builder ?? new CMCountTableBuilder();

            if (!sharedTable)
            {
                // Each column carries its own reference to the builder.
                var perColumn = new CountTableEstimator.ColumnOptions[columns.Length];
                for (int idx = 0; idx < perColumn.Length; idx++)
                {
                    var pair = columns[idx];
                    perColumn[idx] = new CountTableEstimator.ColumnOptions(
                        pair.OutputColumnName, pair.InputColumnName, tableBuilder, priorCoefficient, laplaceScale);
                }

                return new CountTargetEncodingEstimator(env, labelColumn, perColumn, numberOfBits: numberOfBits, combine: combine, hashingSeed: hashingSeed);
            }

            // Shared table: the builder is handed to the estimator once rather than per column.
            var sharedColumns = new CountTableEstimator.SharedColumnOptions[columns.Length];
            for (int idx = 0; idx < sharedColumns.Length; idx++)
            {
                var pair = columns[idx];
                sharedColumns[idx] = new CountTableEstimator.SharedColumnOptions(
                    pair.OutputColumnName, pair.InputColumnName, priorCoefficient, laplaceScale);
            }

            return new CountTargetEncodingEstimator(env, labelColumn, sharedColumns, tableBuilder, numberOfBits, combine, hashingSeed);
        }