internal CountTableTransformer(IHostEnvironment env, CountTargetEncodingFeaturizer featurizer, string[] labelClassNames, int[] seeds, (string outputColumnName, string inputColumnName)[] columns)
/// <summary>
/// Trains count tables over the configured input columns of <paramref name="input"/> and returns a
/// <see cref="CountTableTransformer"/> whose featurizer is built from those trained tables.
/// </summary>
/// <param name="input">
/// Training data. It must contain the label column and every configured input column.
/// </param>
/// <returns>
/// A transformer wrapping a <see cref="CountTargetEncodingFeaturizer"/> built from the trained multi-count table,
/// the label class names, and the per-column seeds and (output, input) column-name pairs.
/// </returns>
/// <remarks>
/// Throws the exception produced by <c>_host.ExceptUserArg</c> when the label column is missing, and the exception
/// produced by <c>_host.Except</c> when any configured input column is missing from the schema.
/// </remarks>
public CountTableTransformer Fit(IDataView input)
{
    var labelCol = input.Schema.GetColumnOrNull(_labelColumnName);
    if (labelCol == null)
    {
        throw _host.ExceptUserArg(nameof(_labelColumnName), "Label column '{0}' not found", _labelColumnName);
    }

    // labelCardinality is forwarded to every count-table builder and to the featurizer below.
    CheckLabelType(new RoleMappedData(input, roles: RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, _labelColumnName)), out var labelCardinality);

    var labelColumn = labelCol.GetValueOrDefault();
    var labelClassNames = InitLabelClassNames(_host, labelColumn, labelCardinality);

    // Resolve each configured input column exactly once. The same columns are used both to create the
    // builder and to train the tables.
    var inputColumns = new DataViewSchema.Column[_columns.Length];
    for (int i = 0; i < inputColumns.Length; i++)
    {
        var col = input.Schema.GetColumnOrNull(_columns[i].InputColumnName);
        if (col == null)
        {
            throw _host.Except($"Could not find column {_columns[i].InputColumnName} in input schema");
        }
        inputColumns[i] = col.GetValueOrDefault();
    }

    // At least one builder source must be configured. They are tried in priority order:
    // previously trained counts, then per-column builders, then the shared builder.
    _host.Assert(_initialCounts != null || _sharedBuilder != null || _builders != null);
    MultiCountTableBuilderBase multiBuilder;
    if (_initialCounts != null)
    {
        multiBuilder = _initialCounts.Featurizer.MultiCountTable.ToBuilder(_host, inputColumns, labelCardinality);
    }
    else if (_builders != null)
    {
        multiBuilder = new ParallelMultiCountTableBuilder(_host, inputColumns, _builders, labelCardinality);
    }
    else
    {
        multiBuilder = new BagMultiCountTableBuilder(_host, inputColumns, _sharedBuilder, labelCardinality);
    }

    // Reuse the columns validated above rather than looking each one up in the schema a second time.
    TrainTables(input, new List<DataViewSchema.Column>(inputColumns), multiBuilder, labelColumn);

    var multiCountTable = multiBuilder.CreateMultiCountTable();
    var featurizer = new CountTargetEncodingFeaturizer(
        _host,
        _columns.Select(col => col.PriorCoefficient).ToArray(),
        _columns.Select(col => col.LaplaceScale).ToArray(),
        labelCardinality,
        multiCountTable);

    return new CountTableTransformer(
        _host,
        featurizer,
        labelClassNames,
        _columns.Select(col => col.Seed).ToArray(),
        _columns.Select(col => (col.Name, col.InputColumnName)).ToArray());
}