/// <summary>
/// Creates the transform, validating the truncation level and label gains, then precomputing
/// the discount map and the column bindings used by the transform.
/// </summary>
/// <param name="env">Host environment forwarded to the base class.</param>
/// <param name="input">Source data view.</param>
/// <param name="labelCol">Name of the label column.</param>
/// <param name="scoreCol">Name of the score column.</param>
/// <param name="groupCol">Name of the group column.</param>
/// <param name="truncationLevel">Number of positions to evaluate; must be greater than 0.</param>
/// <param name="labelGains">Gain values passed to the RankingUtils DCG helpers; must not be null.</param>
public Transform(IHostEnvironment env, IDataView input, string labelCol, string scoreCol, string groupCol, int truncationLevel, Double[] labelGains)
    : base(env, input, labelCol, scoreCol, groupCol, RegistrationName)
{
    // Reject bad arguments before any state is captured.
    Host.CheckParam(0 < truncationLevel, nameof(truncationLevel), "Truncation level must be greater than 0");
    Host.CheckValue(labelGains, nameof(labelGains));

    _labelGains = labelGains;
    _truncationLevel = truncationLevel;

    // Both derived structures depend only on the (validated) truncation level.
    _discountMap = RankingUtils.GetDiscountMap(truncationLevel);
    _bindings = new Bindings(Host, Source.Schema, true, LabelCol, ScoreCol, GroupCol, truncationLevel);
}
/// <summary>
/// Computes the metrics for the query buffered in <paramref name="state"/> and clears its buffers.
/// For each truncation position, the max DCG and DCG are written into
/// <c>state.MaxDcgCur</c> and <c>state.DcgCur</c>. NDCG is then stored in <c>state.NdcgCur</c>
/// as DCG / max DCG, or 0 when max DCG is not positive.
/// </summary>
/// <param name="state">Cursor state holding the buffered labels/outputs and the per-position result arrays.</param>
protected override void UpdateState(RowCursorState state)
{
    // Calculate the current group DCG, NDCG and MaxDcg.
    RankingUtils.QueryMaxDcg(_labelGains, _truncationLevel, _discountMap, state.QueryLabels, state.QueryOutputs, state.MaxDcgCur);
    RankingUtils.QueryDcg(_labelGains, _truncationLevel, _discountMap, state.QueryLabels, state.QueryOutputs, state.DcgCur);

    for (int t = 0; t < _truncationLevel; t++)
    {
        // Guard against division by a non-positive max DCG; such queries score 0.
        Double ndcg = state.MaxDcgCur[t] > 0 ? state.DcgCur[t] / state.MaxDcgCur[t] : 0;
        state.NdcgCur[t] = ndcg;
    }

    // Reset the buffers so the next query starts empty.
    state.QueryLabels.Clear();
    state.QueryOutputs.Clear();
}
/// <summary>
/// Finalizes the current query group. It computes the group's max DCG, DCG and NDCG at each
/// truncation position from the buffered labels and outputs, and adds the weighted DCG and NDCG
/// to the running sums. When group summaries are enabled, it records per-group snapshots.
/// Finally it clears the label/output buffers for the next group.
/// </summary>
/// <param name="weight">Weight applied to this group's DCG and NDCG contributions and added to the total weight.</param>
public void UpdateGroup(Single weight)
{
    RankingUtils.QueryMaxDcg(_labelGains, TruncationLevel, _discountMap, _queryLabels, _queryOutputs, _groupMaxDcgCur);
    if (_groupMaxDcg != null)
    {
        // _groupMaxDcgCur is overwritten for every group, so the stored snapshot must be a copy.
        var maxDcg = new Double[TruncationLevel];
        Array.Copy(_groupMaxDcgCur, maxDcg, TruncationLevel);
        _groupMaxDcg.Add(maxDcg);
    }

    RankingUtils.QueryDcg(_labelGains, TruncationLevel, _discountMap, _queryLabels, _queryOutputs, _groupDcgCur);
    if (_groupDcg != null)
    {
        // Same reasoning as above: _groupDcgCur is a reused scratch buffer.
        var groupDcg = new Double[TruncationLevel];
        Array.Copy(_groupDcgCur, groupDcg, TruncationLevel);
        _groupDcg.Add(groupDcg);
    }

    // Only allocate the per-group NDCG snapshot when group summaries are being collected;
    // otherwise it would be created and immediately discarded on every group.
    Double[] groupNdcg = _groupNdcg != null ? new Double[TruncationLevel] : null;
    for (int t = 0; t < TruncationLevel; t++)
    {
        // A non-positive max DCG means no achievable gain; score such groups as 0.
        Double ndcg = _groupMaxDcgCur[t] > 0 ? _groupDcgCur[t] / _groupMaxDcgCur[t] : 0;
        _sumNdcgAtN[t] += ndcg * weight;
        _sumDcgAtN[t] += _groupDcgCur[t] * weight;
        if (groupNdcg != null)
        {
            groupNdcg[t] = ndcg;
        }
    }
    _sumWeights += weight;

    if (groupNdcg != null)
    {
        _groupNdcg.Add(groupNdcg);
    }

    _queryLabels.Clear();
    _queryOutputs.Clear();
}
/// <summary>
/// Allocates the per-position accumulators and per-query buffers used to compute DCG and NDCG.
/// </summary>
/// <param name="labelGains">Gain values passed to the RankingUtils DCG helpers.</param>
/// <param name="truncationLevel">Number of positions tracked; expected to be greater than 0.</param>
/// <param name="groupSummary">When true, lists for per-group DCG, max DCG and NDCG snapshots are also allocated.</param>
public Counters(Double[] labelGains, int truncationLevel, bool groupSummary)
{
    Contracts.Assert(truncationLevel > 0);
    Contracts.AssertValue(labelGains);

    _labelGains = labelGains;
    TruncationLevel = truncationLevel;
    _discountMap = RankingUtils.GetDiscountMap(truncationLevel);

    // Buffers holding the labels and outputs of the query currently being accumulated.
    _queryLabels = new List<short>();
    _queryOutputs = new List<Single>();

    // One slot per truncation position: running weighted sums plus scratch arrays for the current group.
    _sumDcgAtN = new Double[TruncationLevel];
    _sumNdcgAtN = new Double[TruncationLevel];
    _groupDcgCur = new Double[TruncationLevel];
    _groupMaxDcgCur = new Double[TruncationLevel];

    if (!groupSummary)
    {
        return;
    }

    // Per-group snapshot lists exist only when group summaries are requested.
    _groupNdcg = new List<Double[]>();
    _groupDcg = new List<Double[]>();
    _groupMaxDcg = new List<Double[]>();
}