/// <summary>
/// Creates the estimator from an <see cref="Options"/> object.
/// </summary>
/// <remarks>
/// After validating the arguments, the count table estimator is obtained in one of three ways:
/// loaded from the file named by <c>options.InitialCountsModel</c> when that is non-empty;
/// otherwise, when <c>options.SharedTable</c> is set, built from shared column options and a
/// single component created from <c>options.CountTable</c>; otherwise built from per-column
/// options, each with its own component. The hashing columns and hashing estimator are
/// initialized last.
/// </remarks>
/// <param name="env">Host environment; must not be null. Used to register the host.</param>
/// <param name="options">Options; must not be null, must list at least one column and a non-blank label column.</param>
internal CountTargetEncodingEstimator(IHostEnvironment env, Options options)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(nameof(CountTargetEncodingEstimator));
    _host.CheckValue(options, nameof(options));
    _host.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");
    _host.CheckUserArg(!string.IsNullOrWhiteSpace(options.LabelColumn), nameof(options.LabelColumn), "Must specify the label column name");

    bool hasInitialCounts = !string.IsNullOrEmpty(options.InitialCountsModel);
    if (hasInitialCounts)
    {
        // Only the column name is supplied for each pair.
        var pairs = options.Columns.Select(col => new InputOutputColumnPair(col.Name)).ToArray();
        _countTableEstimator = LoadFromFile(env, options.InitialCountsModel, options.LabelColumn, pairs);
        if (_countTableEstimator == null)
        {
            throw env.Except($"The file {options.InitialCountsModel} does not contain a CountTableTransformer");
        }
    }
    else if (options.SharedTable)
    {
        // Per-column settings take precedence; the options-level value is the fallback.
        var sharedColumns = options.Columns
            .Select(column => new CountTableEstimator.SharedColumnOptions(
                column.Name,
                column.Name,
                column.PriorCoefficient ?? options.PriorCoefficient,
                column.LaplaceScale ?? options.LaplaceScale,
                column.Seed ?? options.Seed))
            .ToArray();

        var tableBuilder = options.CountTable;
        _host.CheckValue(tableBuilder, nameof(options.CountTable));
        _countTableEstimator = new CountTableEstimator(_host, options.LabelColumn, tableBuilder.CreateComponent(_host), sharedColumns);
    }
    else
    {
        var perColumn = options.Columns
            .Select(column =>
            {
                // A column-specific count table builder overrides the options-level one.
                var tableBuilder = column.CountTable ?? options.CountTable;
                _host.CheckValue(tableBuilder, nameof(options.CountTable));
                return new CountTableEstimator.ColumnOptions(
                    column.Name,
                    column.Name,
                    tableBuilder.CreateComponent(_host),
                    column.PriorCoefficient ?? options.PriorCoefficient,
                    column.LaplaceScale ?? options.LaplaceScale,
                    column.Seed ?? options.Seed);
            })
            .ToArray();

        _countTableEstimator = new CountTableEstimator(_host, options.LabelColumn, perColumn);
    }

    _hashingColumns = InitializeHashingColumnOptions(options);
    _hashingEstimator = new HashingEstimator(_host, _hashingColumns);
}
/// <summary>
/// Creates the estimator around an already-constructed count table estimator.
/// </summary>
/// <param name="env">Host environment; must not be null. Used to register the host.</param>
/// <param name="estimator">Count table estimator, stored as-is (it is not validated here).</param>
/// <param name="columns">Column options forwarded to <c>InitializeHashingColumnOptions</c>.</param>
/// <param name="numberOfBits">Forwarded to <c>InitializeHashingColumnOptions</c>.</param>
/// <param name="combine">Forwarded to <c>InitializeHashingColumnOptions</c>.</param>
/// <param name="hashingSeed">Forwarded to <c>InitializeHashingColumnOptions</c>.</param>
private CountTargetEncodingEstimator(
    IHostEnvironment env,
    CountTableEstimator estimator,
    CountTableEstimator.ColumnOptionsBase[] columns,
    int numberOfBits,
    bool combine,
    uint hashingSeed)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(nameof(CountTargetEncodingEstimator));

    // The hashing estimator is built from the hashing column options computed here.
    _hashingColumns = InitializeHashingColumnOptions(columns, numberOfBits, combine, hashingSeed);
    _hashingEstimator = new HashingEstimator(_host, _hashingColumns);

    _countTableEstimator = estimator;
}
/// <summary>
/// Creates the estimator from a previously obtained <see cref="CountTargetEncodingTransformer"/>.
/// </summary>
/// <remarks>
/// The hashing estimator is built from the columns of <c>initialCounts.HashingTransformer</c>,
/// and the count table estimator from <c>initialCounts.CountTable</c>.
/// </remarks>
/// <param name="env">Host environment; must not be null. Used to register the host.</param>
/// <param name="labelColumnName">Label column name, passed to the count table estimator.</param>
/// <param name="initialCounts">Transformer supplying the hashing columns and the count table; must not be null.</param>
/// <param name="columns">Columns to process; must be non-empty and must pass <c>initialCounts.VerifyColumns</c>.</param>
internal CountTargetEncodingEstimator(IHostEnvironment env, string labelColumnName, CountTargetEncodingTransformer initialCounts, params InputOutputColumnPair[] columns)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(nameof(CountTargetEncodingEstimator));
    _host.CheckValue(initialCounts, nameof(initialCounts));
    _host.CheckNonEmpty(columns, nameof(columns));

    // The second argument of Check is the failure message, not a parameter name,
    // so spell out what went wrong instead of passing the bare word "columns".
    _host.Check(initialCounts.VerifyColumns(columns),
        $"The specified {nameof(columns)} failed verification against {nameof(initialCounts)}.");

    // NOTE(review): unlike the Options-based and private constructors, this one leaves
    // _hashingColumns unassigned - confirm nothing reads it on this path.
    _hashingEstimator = new HashingEstimator(_host, initialCounts.HashingTransformer.Columns.ToArray());

    // The count table estimator receives each column's output name as both arguments of the pair.
    _countTableEstimator = new CountTableEstimator(_host, labelColumnName, initialCounts.CountTable,
        columns.Select(c => new InputOutputColumnPair(c.OutputColumnName, c.OutputColumnName)).ToArray());
}