コード例 #1
0
        /// <summary>
        /// Creates the estimator from a user-supplied options object: a hashing step over
        /// <c>options.Columns</c> followed by a count-table estimator.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="options">
        /// Estimator options. <c>Columns</c> must be non-empty and <c>LabelColumn</c> must be a non-blank name.
        /// </param>
        /// <remarks>
        /// The count-table estimator is chosen in priority order:
        /// <list type="number">
        /// <item><description>If <c>InitialCountsModel</c> is set, it is loaded from that file via <c>LoadFromFile</c>
        /// (only the label and column names are passed; per-column count settings are not used).</description></item>
        /// <item><description>Otherwise, if <c>SharedTable</c> is true, a single component created from
        /// <c>options.CountTable</c> is passed together with all columns as shared column options.</description></item>
        /// <item><description>Otherwise, every column gets its own component, created from the column's own
        /// <c>CountTable</c> builder or, when that is null, from <c>options.CountTable</c>.</description></item>
        /// </list>
        /// In both non-file cases, a column's <c>PriorCoefficient</c>, <c>LaplaceScale</c> and <c>Seed</c> fall back to
        /// the corresponding values on <paramref name="options"/> when the column does not set them.
        /// The hashing columns are always derived from <paramref name="options"/>.
        /// </remarks>
        /// <exception cref="System.Exception">
        /// Raised through the host when validation fails, when a required count-table builder is null,
        /// or when <c>LoadFromFile</c> returns null for <c>InitialCountsModel</c>.
        /// </exception>
        internal CountTargetEncodingEstimator(IHostEnvironment env, Options options)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(CountTargetEncodingEstimator));
            _host.CheckValue(options, nameof(options));
            _host.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");
            _host.CheckUserArg(!string.IsNullOrWhiteSpace(options.LabelColumn), nameof(options.LabelColumn), "Must specify the label column name");

            if (!string.IsNullOrEmpty(options.InitialCountsModel))
            {
                // Reuse counts stored in a file; only the label and the column names are taken from the options.
                _countTableEstimator = LoadFromFile(env, options.InitialCountsModel, options.LabelColumn,
                                                    options.Columns.Select(col => new InputOutputColumnPair(col.Name)).ToArray());
                if (_countTableEstimator == null)
                {
                    // Report through the registered host, like every other check in this constructor.
                    throw _host.Except($"The file {options.InitialCountsModel} does not contain a CountTableTransformer");
                }
            }
            else if (options.SharedTable)
            {
                // A single builder serves every column, so validate it once before building any column options.
                var builder = options.CountTable;
                _host.CheckValue(builder, nameof(options.CountTable));

                var columns = new CountTableEstimator.SharedColumnOptions[options.Columns.Length];
                for (int i = 0; i < options.Columns.Length; i++)
                {
                    var column = options.Columns[i];
                    columns[i] = new CountTableEstimator.SharedColumnOptions(
                        column.Name,
                        column.Name,
                        column.PriorCoefficient ?? options.PriorCoefficient,
                        column.LaplaceScale ?? options.LaplaceScale,
                        column.Seed ?? options.Seed);
                }
                _countTableEstimator = new CountTableEstimator(_host, options.LabelColumn, builder.CreateComponent(_host), columns);
            }
            else
            {
                var columns = new CountTableEstimator.ColumnOptions[options.Columns.Length];
                for (int i = 0; i < options.Columns.Length; i++)
                {
                    var column = options.Columns[i];
                    // A per-column builder overrides the global one; at least one of them must be set.
                    var builder = column.CountTable ?? options.CountTable;
                    _host.CheckValue(builder, nameof(options.CountTable));
                    columns[i] = new CountTableEstimator.ColumnOptions(
                        column.Name,
                        column.Name,
                        builder.CreateComponent(_host),
                        column.PriorCoefficient ?? options.PriorCoefficient,
                        column.LaplaceScale ?? options.LaplaceScale,
                        column.Seed ?? options.Seed);
                }
                _countTableEstimator = new CountTableEstimator(_host, options.LabelColumn, columns);
            }

            _hashingColumns   = InitializeHashingColumnOptions(options);
            _hashingEstimator = new HashingEstimator(_host, _hashingColumns);
        }
コード例 #2
0
        /// <summary>
        /// Creates the estimator from an already-constructed count-table estimator plus the column options
        /// from which the hashing step is derived.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="estimator">The count-table estimator to use; stored as-is. Must not be null.</param>
        /// <param name="columns">Column options passed to <c>InitializeHashingColumnOptions</c>; must not be null.</param>
        /// <param name="numberOfBits">Forwarded to <c>InitializeHashingColumnOptions</c>.</param>
        /// <param name="combine">Forwarded to <c>InitializeHashingColumnOptions</c>.</param>
        /// <param name="hashingSeed">Forwarded to <c>InitializeHashingColumnOptions</c>.</param>
        private CountTargetEncodingEstimator(IHostEnvironment env, CountTableEstimator estimator, CountTableEstimator.ColumnOptionsBase[] columns,
                                             int numberOfBits, bool combine, uint hashingSeed)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(CountTargetEncodingEstimator));
            // Validate inputs up front (as the other constructors do) instead of storing a null that fails later.
            _host.CheckValue(estimator, nameof(estimator));
            _host.CheckValue(columns, nameof(columns));

            _hashingColumns      = InitializeHashingColumnOptions(columns, numberOfBits, combine, hashingSeed);
            _hashingEstimator    = new HashingEstimator(_host, _hashingColumns);
            _countTableEstimator = estimator;
        }
コード例 #3
0
        /// <summary>
        /// Creates the estimator from a previously built <c>CountTargetEncodingTransformer</c>, reusing its
        /// hashing column configuration and its count table.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="labelColumnName">Name of the label column; must not be null, empty or whitespace.</param>
        /// <param name="initialCounts">Transformer whose hashing columns and count table are reused; must not be null.</param>
        /// <param name="columns">
        /// Non-empty column pairs; they must pass <c>initialCounts.VerifyColumns</c>. The count table is built
        /// over each pair's <c>OutputColumnName</c> (used as both input and output).
        /// </param>
        internal CountTargetEncodingEstimator(IHostEnvironment env, string labelColumnName, CountTargetEncodingTransformer initialCounts, params InputOutputColumnPair[] columns)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(CountTargetEncodingEstimator));
            _host.CheckValue(initialCounts, nameof(initialCounts));
            _host.CheckNonEmpty(columns, nameof(columns));
            _host.Check(initialCounts.VerifyColumns(columns), "The specified columns are not valid for the given initialCounts transformer");
            // Mirrors the label check done by the Options-based constructor.
            _host.CheckNonWhiteSpace(labelColumnName, nameof(labelColumnName));

            // Populate _hashingColumns as the other constructors do, so it always agrees with _hashingEstimator.
            _hashingColumns      = initialCounts.HashingTransformer.Columns.ToArray();
            _hashingEstimator    = new HashingEstimator(_host, _hashingColumns);
            // The count table reads the hashed columns, which are identified by the pairs' output names.
            _countTableEstimator = new CountTableEstimator(_host, labelColumnName, initialCounts.CountTable,
                                                           columns.Select(c => new InputOutputColumnPair(c.OutputColumnName, c.OutputColumnName)).ToArray());
        }