Example #1
0
        /// <summary>
        /// Accumulate, per sample, the pair-wise cross-entropy gradient (pseudo-response) and
        /// second derivative (weight) over every rank pair of every group in the data set.
        /// The pseudo-responses are used to build the decision tree, except for computing the
        /// response value of a leaf node.
        /// </summary>
        /// <param name="dataSet">all training data; its groups are enumerated to generate the rank pairs</param>
        public void ComputePseudoResponse(DataSet dataSet)
        {
            // NOTE(review): presumably clears pseudoResponses and weights before accumulating - confirm
            ResetParameters();

            for (int groupIdx = 0; groupIdx < dataSet.NumGroups; groupIdx++)
            {
                DataGroup dataGroup = dataSet.GetDataGroup(groupIdx);
                RankPairGenerator pairs = new RankPairGenerator(dataGroup, this.labels);

                foreach (RankPair pair in pairs)
                {
                    int idxHigh = pair.IdxH;
                    int idxLow = pair.IdxL;
                    float scoreGap = this.score[idxHigh] - this.score[idxLow];

                    // First derivative of the pair loss: added to the H sample, subtracted from the L sample.
                    float firstDerivative = RankPair.CrossEntropyDerivative(scoreGap);
                    this.pseudoResponses[idxHigh] += firstDerivative;
                    this.pseudoResponses[idxLow] -= firstDerivative;

                    // Second derivative of the pair loss: added to both samples of the pair.
                    float secondDerivative = RankPair.CrossEntropy2ndDerivative(scoreGap);
                    this.weights[idxHigh] += secondDerivative;
                    this.weights[idxLow] += secondDerivative;
                }
            }
        }
Example #2
0
        /// <summary>
        /// Accumulate, per sample, the pair-wise cross-entropy gradient (pseudo-response) and
        /// second derivative (weight), each scaled by the value returned by AbsDeltaPosition for
        /// the pair, and then blend both with a point-wise label term controlled by labelWeight.
        /// The pseudo-responses are used to build the decision tree, except for computing the
        /// response value of a leaf node.
        /// </summary>
        /// <param name="dataSet">all training data</param>
        public void ComputePseudoResponse(DataSet dataSet)
        {
            // Start a new iteration with zeroed pseudo-responses and weights.
            ResetParameters();

            // Pass 1: marginalize the pair-wise terms into point-wise accumulators.
            for (int groupIdx = 0; groupIdx < dataSet.NumGroups; groupIdx++)
            {
                DataGroup dataGroup = dataSet.GetDataGroup(groupIdx);
                RankPairGenerator pairs = new RankPairGenerator(dataGroup, this.labels);

                // Refresh the query's scores and ranking before any pair of the group is evaluated.
                Query rankedQuery = this.qcAccum.queries[dataSet.GroupIndex[groupIdx]];
                rankedQuery.UpdateScores(this.score, dataGroup.iStart);
                rankedQuery.ComputeRank();

                foreach (RankPair pair in pairs)
                {
                    float scoreGap = this.score[pair.IdxH] - this.score[pair.IdxL];

                    // Cross-entropy gradient of the pair.
                    float pairGradient = RankPair.CrossEntropyDerivative(scoreGap);

                    // Absolute change in NDCG if the pair were swapped in the current ordering.
                    float swapDelta = AbsDeltaPosition(pair, dataGroup, rankedQuery);

                    // The sample with the higher relevance label (IdxH) always receives the
                    // upward push; the lower one (IdxL) receives the opposite push.
                    float scaledGradient = pairGradient * swapDelta;
                    this.pseudoResponses[pair.IdxH] += scaledGradient;
                    this.pseudoResponses[pair.IdxL] -= scaledGradient;

                    // The same pair weight is added to both samples.
                    float pairWeight = swapDelta * RankPair.CrossEntropy2ndDerivative(scoreGap);
                    this.weights[pair.IdxH] += pairWeight;
                    this.weights[pair.IdxL] += pairWeight;
                }
            }

            // Pass 2: mix in the gradient of the label term for every sample of the data set.
            for (int sampleIdx = 0; sampleIdx < dataSet.NumSamples; sampleIdx++)
            {
                int dataIdx = dataSet.DataIndex[sampleIdx];
                float labelResidual = this.labels[dataIdx] - this.score[dataIdx];

                this.pseudoResponses[dataIdx] =
                    (1 - this.labelWeight) * this.pseudoResponses[dataIdx] + this.labelWeight * labelResidual;

                // The label term contributes a constant weight of 1, scaled by labelWeight.
                this.weights[dataIdx] =
                    (1 - this.labelWeight) * this.weights[dataIdx] + this.labelWeight;
            }
        }
Example #3
0
        /// <summary>
        /// Compute the query-averaged pairwise error rate and pairwise cross-entropy for one
        /// data partition and store them in <paramref name="metrics"/>.
        /// </summary>
        /// <param name="labels">labels handed to the rank-pair generator of every query</param>
        /// <param name="inScores">model scores; only the first array (inScores[0]) is read</param>
        /// <param name="dataType">partition whose query groups are evaluated</param>
        /// <param name="metrics">
        /// output array; every slot is reset to zero, then the PairCrossEnt and PairError slots are
        /// filled in. Both stay at zero when the partition has no groups or no query is counted.
        /// </param>
        protected override void ComputeMetrics(float[] labels, float[][] inScores, DataPartitionType dataType, float[] metrics)
        {
            // Results initialization: any slot that is not written below reports zero.
            for (int i = 0; i < metrics.Length; i++)
            {
                metrics[i] = 0;
            }

            DataSet dataSet = this.labelFeatureCore.DataGroups.GetDataPartition(dataType);
            int[] groupIndex = dataSet.GroupIndex;
            if (groupIndex == null || groupIndex.Length == 0)
            {
                // Nothing to evaluate; keep the zeroed metrics.
                return;
            }

            float[] scores = inScores[0];

            double cQueries = 0;
            double totalErrRate = 0;
            double totalCrossEnt = 0;

            for (int i = 0; i < groupIndex.Length; i++)
            {
                DataGroup query = this.labelFeatureCore.DataGroups[groupIndex[i]];
                RankPairGenerator rankPairs = new RankPairGenerator(query, labels);

                double cErr = 0;
                double crossEnt = 0;
                double cPairs = 0;
                foreach (RankPair rankPair in rankPairs)
                {
                    float scoreH_minus_scoreL = scores[rankPair.IdxH] - scores[rankPair.IdxL];
                    crossEnt += RankPair.CrossEntropy(scoreH_minus_scoreL);

                    // A tie (difference of exactly zero) is counted as an error as well.
                    if (scoreH_minus_scoreL <= 0)
                    {
                        cErr++;
                    }
                    cPairs++;
                }

                if (cPairs > 0) // equivalent to !emptyQuery
                {
                    // Average within the query first so every query carries the same weight.
                    totalErrRate += (cErr / cPairs);
                    totalCrossEnt += (crossEnt / cPairs);
                    cQueries++;
                }
                else if (!ndcg.DropEmptyQueries)
                {
                    // An empty query adds zero error and zero cross-entropy but still counts
                    // towards the number of queries in the average.
                    cQueries++;
                }
            }

            // When every query was empty and dropped there is nothing to average. Dividing
            // would yield 0.0 / 0.0 = NaN, so keep the zeroed metrics instead.
            if (cQueries > 0)
            {
                metrics[(int)PairwiseType.PairCrossEnt] = (float)(totalCrossEnt / cQueries);
                metrics[(int)PairwiseType.PairError] = (float)(totalErrRate / cQueries);
            }
        }