/// <summary>
/// Trains a calibrator on the scored <see cref="IDataView"/> and wraps it in a
/// <see cref="CalibratorTransformer{TICalibrator}"/>. The resulting transformer adds a
/// <see cref="DefaultColumnNames.Probability"/> column holding the calibrated
/// <see cref="DefaultColumnNames.Score"/>.
/// </summary>
/// <param name="input">The scored data the calibrator is trained on.</param>
/// <returns>A trained <see cref="CalibratorTransformer{TICalibrator}"/> that will transform the data by adding the
/// <see cref="DefaultColumnNames.Probability"/> column.</returns>
public CalibratorTransformer<TICalibrator> Fit(IDataView input)
{
    // Score, label and feature roles are always bound; the weight role is optional.
    var columnRoles = new List<KeyValuePair<RoleMappedSchema.ColumnRole, string>>
    {
        RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, DefaultColumnNames.Score),
        RoleMappedSchema.ColumnRole.Label.Bind(LabelColumn.Name),
        RoleMappedSchema.ColumnRole.Feature.Bind(FeatureColumn.Name)
    };

    if (WeightColumn.IsValid)
    {
        columnRoles.Add(RoleMappedSchema.ColumnRole.Weight.Bind(WeightColumn.Name));
    }

    var trainingData = new RoleMappedData(input, opt: false, columnRoles.ToArray());

    TICalibrator trained = null;
    using (var channel = Host.Start("Creating calibrator."))
    {
        trained = (TICalibrator)CalibratorUtils.TrainCalibrator(Host, channel, CalibratorTrainer, Predictor, trainingData);
    }

    // The channel is disposed before the transformer is created, as in the using-statement form.
    return Create(Host, trained);
}
/// <summary>
/// Combines the per-fold evaluation outputs carried by <paramref name="input"/> (per-instance metrics,
/// overall metrics, confusion matrices and warnings) into a single <see cref="CombinedOutput"/>.
/// </summary>
/// <param name="env">The host environment used for assertions and for creating the intermediate data views.</param>
/// <param name="input">The per-fold data views, the column names and the evaluator kind.</param>
/// <returns>The combined per-instance metrics, overall metrics, confusion matrix (null when none was supplied)
/// and warnings (null when there are none).</returns>
public static CombinedOutput CombineMetrics(IHostEnvironment env, CombineMetricsInput input)
{
    var evaluator = GetEvaluator(env, input.Kind);

    // Attach the label/weight/group roles to every per-instance data view before concatenating them.
    var perFoldData = input.PerInstanceMetrics
        .Select(view => RoleMappedData.CreateOpt(view, new[]
        {
            RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, input.LabelColumn),
            RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Weight, input.WeightColumn.Value),
            RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, input.GroupColumn.Value)
        }))
        .ToArray();

    var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(
        env, evaluator, true, true, perFoldData, out var variableLengthColumns);

    // Start from the caller-supplied warnings, if any.
    List<IDataView> warningViews;
    if (input.Warnings == null)
    {
        warningViews = new List<IDataView>();
    }
    else
    {
        warningViews = new List<IDataView>(input.Warnings);
    }

    // Variable-length vector columns are reported to the user as an additional warning row.
    if (variableLengthColumns.Length > 0)
    {
        var message = $"Detected columns of variable length: {string.Join(", ", variableLengthColumns)}." +
            $" Consider setting collateMetrics- for meaningful per-Folds results.";
        var warningBuilder = new ArrayDataViewBuilder(env);
        warningBuilder.AddColumn(MetricKinds.ColumnNames.WarningText, TextType.Instance, new DvText(message));
        warningViews.Add(warningBuilder.GetDataView());
    }

    env.Assert(Utils.Size(perInstance) == 1);

    var overallMetrics = eval_GetOverall(evaluator, input);
    overallMetrics = EvaluateUtils.CombineFoldMetricsDataViews(env, overallMetrics, input.OverallMetrics.Length);

    IDataView confusion = null;
    if (Utils.Size(input.ConfusionMatrix) > 0)
    {
        EvaluateUtils.ReconcileSlotNames<double>(env, input.ConfusionMatrix, MetricKinds.ColumnNames.Count, NumberType.R8);

        for (int fold = 0; fold < input.ConfusionMatrix.Length; fold++)
        {
            var foldView = input.ConfusionMatrix[fold];

            // Locate the old (hidden) Count column and replace the view with one that drops it.
            for (int c = 0; c < foldView.Schema.ColumnCount; c++)
            {
                bool isOldCount = foldView.Schema.IsHidden(c)
                    && foldView.Schema.GetColumnName(c).Equals(MetricKinds.ColumnNames.Count);
                if (!isOldCount)
                {
                    continue;
                }

                var dropArgs = new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = new[] { c } };
                input.ConfusionMatrix[fold] = new ChooseColumnsByIndexTransform(env, dropArgs, foldView);
                break;
            }
        }

        confusion = EvaluateUtils.ConcatenateOverallMetrics(env, input.ConfusionMatrix);
    }

    IDataView combinedWarnings = null;
    if (warningViews.Count > 0)
    {
        combinedWarnings = AppendRowsDataView.Create(env, warningViews[0].Schema, warningViews.ToArray());
    }

    return new CombinedOutput()
    {
        PerInstanceMetrics = perInstance[0],
        OverallMetrics = overallMetrics,
        ConfusionMatrix = confusion,
        Warnings = combinedWarnings
    };

    // Thin local wrapper so the overall-results call reads as a single named step.
    IDataView[] eval_GetOverall(IMamlEvaluator e, CombineMetricsInput i) => e.GetOverallResults(i.OverallMetrics);
}
/// <summary>
/// Trains count tables over the configured input columns of <paramref name="input"/> and returns a
/// <see cref="CountTableTransformer"/> built from them.
/// </summary>
/// <param name="input">The data view that contains the label column and every configured input column.</param>
/// <returns>A <see cref="CountTableTransformer"/> wrapping the featurizer created from the trained count tables.</returns>
/// <remarks>
/// Throws (via the host) when the label column or any configured input column is missing from the input schema.
/// </remarks>
public CountTableTransformer Fit(IDataView input)
{
    var labelCol = input.Schema.GetColumnOrNull(_labelColumnName);
    if (labelCol == null)
    {
        throw _host.ExceptUserArg(nameof(_labelColumnName), "Label column '{0}' not found", _labelColumnName);
    }

    // Unwrap the nullable once; every later use needs the concrete column.
    var labelColumn = labelCol.GetValueOrDefault();

    CheckLabelType(
        new RoleMappedData(input, roles: RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, _labelColumnName)),
        out var labelCardinality);
    var labelClassNames = InitLabelClassNames(_host, labelColumn, labelCardinality);

    // Resolve every configured input column exactly once. The array feeds the table builders and the
    // list feeds TrainTables, so both are filled in the same pass instead of querying the schema twice.
    var inputColumns = new DataViewSchema.Column[_columns.Length];
    var cols = new List<DataViewSchema.Column>(_columns.Length);
    for (int i = 0; i < inputColumns.Length; i++)
    {
        var found = input.Schema.GetColumnOrNull(_columns[i].InputColumnName);
        if (found == null)
        {
            throw _host.Except($"Could not find column {_columns[i].InputColumnName} in input schema");
        }

        var column = found.GetValueOrDefault();
        inputColumns[i] = column;
        cols.Add(column);
    }

    // Exactly one source of count tables is expected: initial counts, per-column builders, or a shared builder.
    _host.Assert(_initialCounts != null || _sharedBuilder != null || _builders != null);

    MultiCountTableBuilderBase multiBuilder;
    if (_initialCounts != null)
    {
        multiBuilder = _initialCounts.Featurizer.MultiCountTable.ToBuilder(_host, inputColumns, labelCardinality);
    }
    else if (_builders != null)
    {
        multiBuilder = new ParallelMultiCountTableBuilder(_host, inputColumns, _builders, labelCardinality);
    }
    else
    {
        multiBuilder = new BagMultiCountTableBuilder(_host, inputColumns, _sharedBuilder, labelCardinality);
    }

    TrainTables(input, cols, multiBuilder, labelColumn);

    var multiCountTable = multiBuilder.CreateMultiCountTable();
    var featurizer = new CountTargetEncodingFeaturizer(
        _host,
        _columns.Select(c => c.PriorCoefficient).ToArray(),
        _columns.Select(c => c.LaplaceScale).ToArray(),
        labelCardinality,
        multiCountTable);

    return new CountTableTransformer(
        _host,
        featurizer,
        labelClassNames,
        _columns.Select(c => c.Seed).ToArray(),
        _columns.Select(c => (c.Name, c.InputColumnName)).ToArray());
}