コード例 #1
0
            /// <summary>
            /// Constructs the transform over <paramref name="input"/>. Note that the base constructor
            /// (constructor initializer) runs before the argument checks in this body.
            /// </summary>
            /// <param name="truncationLevel">
            /// Must be greater than 0 (validated with <c>Host.CheckParam</c>). It sizes the discount map
            /// and is passed to the output <c>Bindings</c>.
            /// </param>
            /// <param name="labelGains">
            /// Must be non-null (validated with <c>Host.CheckValue</c>). The array is stored by reference, not copied.
            /// </param>
            public Transform(IHostEnvironment env, IDataView input,
                string labelCol, string scoreCol, string groupCol,
                int truncationLevel, Double[] labelGains)
                : base(env, input, labelCol, scoreCol, groupCol, RegistrationName)
            {
                // Reject bad arguments before any instance state is captured.
                Host.CheckParam(truncationLevel > 0, nameof(truncationLevel), "Truncation level must be greater than 0");
                Host.CheckValue(labelGains, nameof(labelGains));

                _labelGains = labelGains;
                _truncationLevel = truncationLevel;

                // Everything below depends only on the (already validated) truncation level.
                _discountMap = RankingUtils.GetDiscountMap(truncationLevel);
                _bindings = new Bindings(Host, Source.Schema, true, LabelCol, ScoreCol, GroupCol, truncationLevel);
            }
コード例 #2
0
 /// <summary>
 /// Finishes the query currently buffered in <paramref name="state"/>. It fills the per-position
 /// max DCG, DCG and NDCG arrays, then empties the buffered labels and outputs so the next
 /// query starts fresh.
 /// </summary>
 protected override void UpdateState(RowCursorState state)
 {
     var labels = state.QueryLabels;
     var outputs = state.QueryOutputs;
     var maxDcg = state.MaxDcgCur;
     var dcg = state.DcgCur;
     var ndcg = state.NdcgCur;

     // The last argument of each call appears to be the output buffer; it is read right after.
     RankingUtils.QueryMaxDcg(_labelGains, _truncationLevel, _discountMap, labels, outputs, maxDcg);
     RankingUtils.QueryDcg(_labelGains, _truncationLevel, _discountMap, labels, outputs, dcg);

     // NDCG at each position is DCG / MaxDCG, or 0 when the max DCG is not positive.
     for (int pos = 0; pos < _truncationLevel; pos++)
     {
         ndcg[pos] = maxDcg[pos] > 0 ? dcg[pos] / maxDcg[pos] : 0;
     }

     labels.Clear();
     outputs.Clear();
 }
コード例 #3
0
                /// <summary>
                /// Finalizes the group currently buffered in <c>_queryLabels</c>/<c>_queryOutputs</c>. The steps are:
                /// <list type="number">
                /// <item>Compute the group's max DCG and DCG per truncation position.</item>
                /// <item>Add the weighted DCG and NDCG into the running sums, and add <paramref name="weight"/> to <c>_sumWeights</c>.</item>
                /// <item>When per-group lists were allocated (group summary enabled), append copies of the group's vectors.</item>
                /// <item>Clear the buffers.</item>
                /// </list>
                /// </summary>
                /// <param name="weight">Weight applied to this group's DCG and NDCG contributions.</param>
                public void UpdateGroup(Single weight)
                {
                    RankingUtils.QueryMaxDcg(_labelGains, TruncationLevel, _discountMap, _queryLabels, _queryOutputs, _groupMaxDcgCur);
                    if (_groupMaxDcg != null)
                    {
                        // _groupMaxDcgCur is a reused scratch buffer, so record a copy rather than the buffer itself.
                        var maxDcg = new Double[TruncationLevel];
                        Array.Copy(_groupMaxDcgCur, maxDcg, TruncationLevel);
                        _groupMaxDcg.Add(maxDcg);
                    }

                    RankingUtils.QueryDcg(_labelGains, TruncationLevel, _discountMap, _queryLabels, _queryOutputs, _groupDcgCur);
                    if (_groupDcg != null)
                    {
                        var groupDcg = new Double[TruncationLevel];
                        Array.Copy(_groupDcgCur, groupDcg, TruncationLevel);
                        _groupDcg.Add(groupDcg);
                    }

                    // Only materialize the per-group NDCG vector when it will actually be recorded.
                    // Previously an array was allocated for every group and discarded when summaries were off.
                    Double[] groupNdcg = _groupNdcg != null ? new Double[TruncationLevel] : null;

                    for (int t = 0; t < TruncationLevel; t++)
                    {
                        // A group whose max DCG is not positive contributes an NDCG of 0.
                        Double ndcg = _groupMaxDcgCur[t] > 0 ? _groupDcgCur[t] / _groupMaxDcgCur[t] : 0;
                        _sumNdcgAtN[t] += ndcg * weight;
                        _sumDcgAtN[t] += _groupDcgCur[t] * weight;
                        if (groupNdcg != null)
                        {
                            groupNdcg[t] = ndcg;
                        }
                    }
                    _sumWeights += weight;

                    if (groupNdcg != null)
                    {
                        _groupNdcg.Add(groupNdcg);
                    }

                    _queryLabels.Clear();
                    _queryOutputs.Clear();
                }
コード例 #4
0
                /// <summary>
                /// Creates an accumulator for per-position DCG/NDCG sums.
                /// </summary>
                /// <param name="labelGains">
                /// Gain table handed to <c>RankingUtils</c>. It is stored by reference and only
                /// asserted non-null, via <c>Contracts.AssertValue</c>.
                /// </param>
                /// <param name="truncationLevel">
                /// Number of positions tracked; it sizes every per-position array. It is only asserted
                /// to be positive via <c>Contracts.Assert</c>; NOTE(review): presumably a debug-only
                /// check, so confirm callers validate it.
                /// </param>
                /// <param name="groupSummary">When true, per-group DCG, max DCG and NDCG vectors are retained in lists.</param>
                public Counters(Double[] labelGains, int truncationLevel, bool groupSummary)
                {
                    Contracts.Assert(truncationLevel > 0);
                    Contracts.AssertValue(labelGains);

                    _labelGains = labelGains;
                    TruncationLevel = truncationLevel;
                    _discountMap = RankingUtils.GetDiscountMap(truncationLevel);

                    // Running weighted sums, one slot per truncation position.
                    _sumDcgAtN = new Double[TruncationLevel];
                    _sumNdcgAtN = new Double[TruncationLevel];

                    // Scratch state reused for whichever group is currently being accumulated.
                    _groupDcgCur = new Double[TruncationLevel];
                    _groupMaxDcgCur = new Double[TruncationLevel];
                    _queryLabels = new List<short>();
                    _queryOutputs = new List<Single>();

                    // Per-group history lists stay null unless a group summary was requested.
                    if (!groupSummary)
                    {
                        return;
                    }

                    _groupNdcg = new List<Double[]>();
                    _groupDcg = new List<Double[]>();
                    _groupMaxDcg = new List<Double[]>();
                }