Example #1
0
        /// <summary>
        /// Trains a calibrator on the scored <see cref="IDataView"/> and wraps it in a
        /// <see cref="CalibratorTransformer{TICalibrator}"/>, which can transform data by adding a
        /// <see cref="DefaultColumnNames.Probability"/> column containing the calibrated <see cref="DefaultColumnNames.Score"/>.
        /// </summary>
        /// <param name="input">The scored data the calibrator is trained on. The score, label and feature columns
        /// (and the weight column, when it is valid) are bound to their roles by name, with <c>opt: false</c>.</param>
        /// <returns>A trained <see cref="CalibratorTransformer{TICalibrator}"/> that will transform the data by adding the
        /// <see cref="DefaultColumnNames.Probability"/> column.</returns>
        public CalibratorTransformer<TICalibrator> Fit(IDataView input)
        {
            // Score, label and feature roles are always bound.
            var columnRoles = new List<KeyValuePair<RoleMappedSchema.ColumnRole, string>>
            {
                RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, DefaultColumnNames.Score),
                RoleMappedSchema.ColumnRole.Label.Bind(LabelColumn.Name),
                RoleMappedSchema.ColumnRole.Feature.Bind(FeatureColumn.Name)
            };

            // The weight role is bound only when a valid weight column is configured.
            if (WeightColumn.IsValid)
            {
                columnRoles.Add(RoleMappedSchema.ColumnRole.Weight.Bind(WeightColumn.Name));
            }

            var trainingData = new RoleMappedData(input, opt: false, columnRoles.ToArray());

            // The channel is disposed before the transformer is created, so the calibrator is declared outside the scope.
            TICalibrator trained;
            using (var channel = Host.Start("Creating calibrator."))
            {
                trained = (TICalibrator)CalibratorUtils.TrainCalibrator(Host, channel, CalibratorTrainer, Predictor, trainingData);
            }

            return Create(Host, trained);
        }
        /// <summary>
        /// Combines the evaluation outputs carried by <paramref name="input"/> (per-instance metrics, overall
        /// metrics, confusion matrices and warnings) into a single <see cref="CombinedOutput"/>.
        /// </summary>
        /// <param name="env">Host environment used for assertions and for creating the combined data views.</param>
        /// <param name="input">The metric data views to combine, together with the evaluator kind and the
        /// label, weight and group column names.</param>
        /// <returns>A <see cref="CombinedOutput"/> holding the concatenated per-instance metrics, the combined overall
        /// metrics, the concatenated confusion matrix (or null when none was supplied) and the appended warnings
        /// (or null when there are none).</returns>
        public static CombinedOutput CombineMetrics(IHostEnvironment env, CombineMetricsInput input)
        {
            var evaluator = GetEvaluator(env, input.Kind);

            // Attach label/weight/group roles to each per-instance data view before concatenating them.
            var perInstanceData = input.PerInstanceMetrics
                .Select(view => RoleMappedData.CreateOpt(view, new[]
                {
                    RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, input.LabelColumn),
                    RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Weight, input.WeightColumn.Value),
                    RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, input.GroupColumn.Value)
                }))
                .ToArray();

            var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(
                env, evaluator, true, true, perInstanceData, out var variableLengthColumns);

            // Start from the caller-supplied warnings, if any.
            var warningViews = new List<IDataView>();
            if (input.Warnings != null)
            {
                warningViews.AddRange(input.Warnings);
            }

            // Report any variable-length columns detected during concatenation as an extra warning row.
            if (variableLengthColumns.Length > 0)
            {
                var message = $"Detected columns of variable length: {string.Join(", ", variableLengthColumns)}." +
                              $" Consider setting collateMetrics- for meaningful per-Folds results.";
                var warningBuilder = new ArrayDataViewBuilder(env);
                warningBuilder.AddColumn(MetricKinds.ColumnNames.WarningText, TextType.Instance, new DvText(message));
                warningViews.Add(warningBuilder.GetDataView());
            }

            env.Assert(Utils.Size(perInstance) == 1);

            var overallMetrics = evaluator.GetOverallResults(input.OverallMetrics);
            overallMetrics = EvaluateUtils.CombineFoldMetricsDataViews(env, overallMetrics, input.OverallMetrics.Length);

            IDataView confusionMatrix = null;
            if (Utils.Size(input.ConfusionMatrix) > 0)
            {
                EvaluateUtils.ReconcileSlotNames<double>(env, input.ConfusionMatrix, MetricKinds.ColumnNames.Count, NumberType.R8);

                for (int m = 0; m < input.ConfusionMatrix.Length; m++)
                {
                    var matrixView = input.ConfusionMatrix[m];
                    var schema = matrixView.Schema;

                    // Drop the first hidden column named Count (the old Count column), if there is one.
                    for (int c = 0; c < schema.ColumnCount; c++)
                    {
                        if (!schema.IsHidden(c) ||
                            !schema.GetColumnName(c).Equals(MetricKinds.ColumnNames.Count))
                        {
                            continue;
                        }

                        var dropArgs = new ChooseColumnsByIndexTransform.Arguments()
                        {
                            Drop = true,
                            Index = new[] { c }
                        };
                        input.ConfusionMatrix[m] = new ChooseColumnsByIndexTransform(env, dropArgs, matrixView);
                        break;
                    }
                }

                confusionMatrix = EvaluateUtils.ConcatenateOverallMetrics(env, input.ConfusionMatrix);
            }

            // All warning views are appended using the schema of the first one.
            var combinedWarnings = warningViews.Count > 0
                ? AppendRowsDataView.Create(env, warningViews[0].Schema, warningViews.ToArray())
                : null;

            return new CombinedOutput()
            {
                PerInstanceMetrics = perInstance[0],
                OverallMetrics = overallMetrics,
                ConfusionMatrix = confusionMatrix,
                Warnings = combinedWarnings
            };
        }
        /// <summary>
        /// Trains count tables over the configured columns of <paramref name="input"/> and returns a
        /// <see cref="CountTableTransformer"/> built from the resulting multi count table.
        /// </summary>
        /// <param name="input">The training data. It must contain the label column and the input column of
        /// every configured column; otherwise an exception created by the host is thrown.</param>
        /// <returns>A <see cref="CountTableTransformer"/> wrapping the trained featurizer, the label class names,
        /// and the per-column seeds and (output, input) column name pairs.</returns>
        public CountTableTransformer Fit(IDataView input)
        {
            var labelCol = input.Schema.GetColumnOrNull(_labelColumnName);

            if (labelCol == null)
            {
                throw _host.ExceptUserArg(nameof(_labelColumnName), "Label column '{0}' not found", _labelColumnName);
            }

            CheckLabelType(new RoleMappedData(input, roles: RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, _labelColumnName)), out var labelCardinality);

            var labelColumn = labelCol.GetValueOrDefault();
            var labelClassNames = InitLabelClassNames(_host, labelColumn, labelCardinality);

            // Resolve every configured input column once, failing fast on the first one that is missing.
            var inputColumns = new DataViewSchema.Column[_columns.Length];

            for (int i = 0; i < inputColumns.Length; i++)
            {
                var col = input.Schema.GetColumnOrNull(_columns[i].InputColumnName);
                if (col == null)
                {
                    throw _host.Except($"Could not find column {_columns[i].InputColumnName} in input schema");
                }
                inputColumns[i] = col.GetValueOrDefault();
            }

            // Exactly which source is set is decided elsewhere; here we only require that at least one is.
            _host.Assert(_initialCounts != null || _sharedBuilder != null || _builders != null);
            MultiCountTableBuilderBase multiBuilder;

            if (_initialCounts != null)
            {
                multiBuilder = _initialCounts.Featurizer.MultiCountTable.ToBuilder(_host, inputColumns, labelCardinality);
            }
            else if (_builders != null)
            {
                multiBuilder = new ParallelMultiCountTableBuilder(_host, inputColumns, _builders, labelCardinality);
            }
            else
            {
                multiBuilder = new BagMultiCountTableBuilder(_host, inputColumns, _sharedBuilder, labelCardinality);
            }

            // Training takes the same columns as a list; reuse the already-resolved columns
            // instead of looking each one up in the schema a second time.
            var cols = new List<DataViewSchema.Column>(inputColumns);

            TrainTables(input, cols, multiBuilder, labelColumn);

            var multiCountTable = multiBuilder.CreateMultiCountTable();

            var featurizer = new CountTargetEncodingFeaturizer(
                _host,
                _columns.Select(c => c.PriorCoefficient).ToArray(),
                _columns.Select(c => c.LaplaceScale).ToArray(),
                labelCardinality,
                multiCountTable);

            return new CountTableTransformer(
                _host,
                featurizer,
                labelClassNames,
                _columns.Select(c => c.Seed).ToArray(),
                _columns.Select(c => (c.Name, c.InputColumnName)).ToArray());
        }