/// <summary>
/// Compute the point-wise pseudo-response of the loss function to be optimized.
/// It is used to build the decision tree - except from computing the response value of a leaf node.
/// </summary>
/// <param name="dataSet">all training data</param>
public void ComputePseudoResponse(DataSet dataSet)
{
    // Reset (zero out) pseudoResponses and weights for a new boosting iteration.
    ResetParameters();

    for (int qIdx = 0; qIdx < dataSet.NumGroups; qIdx++)
    {
        DataGroup query = dataSet.GetDataGroup(qIdx);
        RankPairGenerator rankPairs = new RankPairGenerator(query, this.labels);

        foreach (RankPair rankPair in rankPairs)
        {
            float scoreH_minus_scoreL = this.score[rankPair.IdxH] - this.score[rankPair.IdxL];

            // Marginalize the pair-wise cross-entropy gradient onto the two points:
            // the point with the higher relevance label (IdxH) always gets a positive
            // push (upwards), the lower one (IdxL) the opposite push.
            float gradient = RankPair.CrossEntropyDerivative(scoreH_minus_scoreL);
            this.pseudoResponses[rankPair.IdxH] += gradient;
            this.pseudoResponses[rankPair.IdxL] -= gradient;

            // The second derivative is symmetric in the pair and always positive, so
            // both points accumulate the same weight. Reuse the score difference
            // computed above instead of re-reading both score entries.
            float weight = RankPair.CrossEntropy2ndDerivative(scoreH_minus_scoreL);
            this.weights[rankPair.IdxH] += weight;
            this.weights[rankPair.IdxL] += weight;
        }
    }
}
/// <summary>
/// Compute the point-wise pseudo-response of the loss function to be optimized.
/// It is used to build the decision tree - except from computing the response value of a leaf node.
/// The pair-wise gradients are scaled LambdaRank-style by the absolute position delta
/// of the pair in the current ranking, then blended with a point-wise
/// (regression-to-label) gradient controlled by labelWeight.
/// </summary>
/// <param name="dataSet">all training data</param>
public void ComputePseudoResponse(DataSet dataSet)
{
    // Reset (zero out) pseudoResponses and weights for a new iteration.
    ResetParameters();

    for (int qIdx = 0; qIdx < dataSet.NumGroups; qIdx++)
    {
        DataGroup queryGroup = dataSet.GetDataGroup(qIdx);
        RankPairGenerator rankPairs = new RankPairGenerator(queryGroup, this.labels);

        // Refresh this query's document scores and the ranking they induce, so the
        // position-based deltas below reflect the current model.
        Query query = this.qcAccum.queries[dataSet.GroupIndex[qIdx]];
        query.UpdateScores(this.score, queryGroup.iStart);
        query.ComputeRank();

        foreach (RankPair rankPair in rankPairs)
        {
            float scoreH_minus_scoreL = this.score[rankPair.IdxH] - this.score[rankPair.IdxL];

            // Compute the cross-entropy gradient of the pair.
            float gradient = RankPair.CrossEntropyDerivative(scoreH_minus_scoreL);

            // Compute the absolute change in NDCG if we swap the pair in the current ordering.
            float absDeltaPosition = AbsDeltaPosition(rankPair, queryGroup, query);

            // Marginalize the pair-wise gradient to get the point-wise gradient. The point
            // with the higher relevance label (IdxH) always gets a positive push (i.e. upwards).
            this.pseudoResponses[rankPair.IdxH] += gradient * absDeltaPosition;
            this.pseudoResponses[rankPair.IdxL] -= gradient * absDeltaPosition;

            // Note that the weights are automatically always positive. Reuse the score
            // difference computed above instead of re-reading both score entries.
            float weight = absDeltaPosition * RankPair.CrossEntropy2ndDerivative(scoreH_minus_scoreL);
            this.weights[rankPair.IdxH] += weight;
            this.weights[rankPair.IdxL] += weight;
        }
    }

    for (int i = 0; i < dataSet.NumSamples; i++)
    {
        int dataIdx = dataSet.DataIndex[i];

        // Incorporate the gradient of the label: convex blend of the pair-wise
        // pseudo-response with the point-wise residual (label - score).
        this.pseudoResponses[dataIdx] = (1 - this.labelWeight) * this.pseudoResponses[dataIdx]
                                        + this.labelWeight * (this.labels[dataIdx] - this.score[dataIdx]);

        // The point-wise (squared-loss) component contributes a constant second derivative of 1.
        this.weights[dataIdx] = (1 - this.labelWeight) * this.weights[dataIdx] + this.labelWeight;
    }
}
/// <summary>
/// Compute pair-wise ranking metrics for the given data partition: the mean
/// per-query pair error rate and the mean per-query pair cross-entropy.
/// Results are written into <paramref name="metrics"/> at the
/// <see cref="PairwiseType"/> slots; all slots are zeroed first.
/// </summary>
/// <param name="labels">relevance labels, indexed by global data index</param>
/// <param name="inScores">model scores; only inScores[0] is used</param>
/// <param name="dataType">which data partition (train/valid/test) to evaluate</param>
/// <param name="metrics">output buffer receiving the computed metric values</param>
protected override void ComputeMetrics(float[] labels, float[][] inScores, DataPartitionType dataType, float[] metrics)
{
    // Results initialization.
    for (int i = 0; i < metrics.Length; i++)
    {
        metrics[i] = 0;
    }

    // Compute pairwise error and cross-entropy over the requested partition.
    DataSet dataSet = this.labelFeatureCore.DataGroups.GetDataPartition(dataType);
    int[] groupIndex = dataSet.GroupIndex;
    if (groupIndex == null || groupIndex.Length == 0)
    {
        return;
    }

    float[] scores = inScores[0];
    double cQueries = 0;
    double totalErrRate = 0;
    double totalCrossEnt = 0;

    for (int i = 0; i < groupIndex.Length; i++)
    {
        DataGroup query = this.labelFeatureCore.DataGroups[groupIndex[i]];
        RankPairGenerator rankPairs = new RankPairGenerator(query, labels);

        double cErr = 0;
        double crossEnt = 0;
        double cPairs = 0;
        foreach (RankPair rankPair in rankPairs)
        {
            float scoreH_minus_scoreL = scores[rankPair.IdxH] - scores[rankPair.IdxL];
            crossEnt += RankPair.CrossEntropy(scoreH_minus_scoreL);

            // A non-positive margin means the pair is mis-ordered (or tied) by the scores.
            if (scoreH_minus_scoreL <= 0)
            {
                cErr++;
            }
            cPairs++;
        }

        if (cPairs > 0) // equivalent to !emptyQuery
        {
            totalErrRate += cErr / cPairs;
            totalCrossEnt += crossEnt / cPairs;
            cQueries++;
        }
        else if (!ndcg.DropEmptyQueries)
        {
            // An empty query kept in the average contributes zero error and zero cross-entropy.
            cQueries++;
        }
    }

    // Guard the division: if every query was empty and dropped, cQueries is 0 and the
    // unguarded division would write NaN into the metrics. Leave them at 0 instead.
    if (cQueries > 0)
    {
        metrics[(int)PairwiseType.PairCrossEnt] = (float)(totalCrossEnt / cQueries);
        metrics[(int)PairwiseType.PairError] = (float)(totalErrRate / cQueries);
    }
}