/// <summary>
/// Recomputes the size-adjusted log-loss gradient for every document of a single query.
/// </summary>
/// <param name="query">
/// Index of the query; its documents occupy the range
/// [<c>Dataset.Boundaries[query]</c>, <c>Dataset.Boundaries[query + 1]</c>).
/// </param>
/// <param name="threadIndex">Index of the calling worker; not referenced by this implementation.</param>
/// <remarks>
/// The gradient slots of the query are zeroed first (queries with at most one document are skipped).
/// Then, for each document i, every later partner j is visited; in <c>Pairwise</c> mode only the
/// immediate successor i + 1 is visited. For each pair a size-adjusted target L is found.
/// The walk starts from the size label of i and subtracts the size labels of documents j, j + 1, ...:
/// <list type="bullet">
/// <item>A positively-rated document reached with the remaining budget still &gt;= 0 yields L = 1.</item>
/// <item>One that pushes the budget below zero yields the fraction 1 + budget / size.</item>
/// <item>If the budget runs out (or the documents end) before that, L stays 0.</item>
/// </list>
/// The pair's delta is added to i's gradient and subtracted from j's.
/// </remarks>
protected override void GetGradientInOneQuery(int query, int threadIndex)
{
    int first = Dataset.Boundaries[query];
    int limit = Dataset.Boundaries[query + 1];
    short[] ratings = Dataset.Ratings;
    // NOTE(review): GetSizeLabels is invoked on every call, i.e. once per query; if it does not
    // cache internally, consider computing the size labels once per dataset instead.
    float[] sizeLabels = SizeAdjustedLogLossUtil.GetSizeLabels(Dataset);
    Contracts.Check(Dataset.NumDocs == sizeLabels.Length, "Mismatch between dataset and labels");

    int docCount = limit - first;
    if (docCount <= 1)
    {
        // A query with fewer than two documents has no pairs to compare.
        return;
    }

    Array.Clear(_gradient, first, docCount);

    bool pageOrdering = _algo == SizeAdjustedLogLossCommandLineArgs.CostFunctionMode.SizeAdjustedPageOrdering;
    bool pairwiseOnly = _mode == SizeAdjustedLogLossCommandLineArgs.LogLossMode.Pairwise;

    for (int i = first; i < limit - 1; ++i)
    {
        // Pairwise mode compares a document only with its immediate successor.
        int lastPartner = pairwiseOnly ? i + 1 : limit - 1;

        for (int j = i + 1; j <= lastPartner; ++j)
        {
            // Compute Lij: spend i's size budget on documents j, j + 1, ... until a
            // positively-rated document is found or the budget is exhausted.
            float budget = sizeLabels[i];
            float target = 0.0F;
            for (int k = j; k < limit; ++k)
            {
                budget -= sizeLabels[k];
                if (ratings[k] > 0)
                {
                    // Both comparisons are spelled out so a NaN budget assigns nothing.
                    if (budget >= 0.0F)
                    {
                        target = 1.0F;
                    }
                    else if (budget < 0.0F)
                    {
                        target = 1.0F + (budget / sizeLabels[k]);
                    }
                }

                // Stop once the size budget is used up or the target has been determined.
                if (budget <= 0.0F || target > 0.0F)
                {
                    break;
                }
            }

            double scoreDiff = _scores[i] - _scores[j];
            float labelDiff = (float)ratings[i] - target;
            double delta = pageOrdering
                ? (_llc * labelDiff) / (1.0 + Math.Exp(_llc * labelDiff * scoreDiff))
                : (double)ratings[i] - ((double)(ratings[i] + target) / (1.0 + Math.Exp(-scoreDiff)));

            _gradient[i] += delta;
            _gradient[j] -= delta;
        }
    }
}
/// <summary>
/// Creates a size-adjusted log-loss test that evaluates the training set.
/// </summary>
/// <returns>
/// A <c>SizeAdjustedLogLossTest</c> built from a score tracker for <c>TrainSet</c>, the
/// configured score-range file name, the training set's size labels and the configured
/// log-loss coefficient.
/// </returns>
protected override Test ConstructTestForTrainingData()
{
    // Arguments are gathered in the same order the constructor call evaluated them.
    var trainTracker = ConstructScoreTracker(TrainSet);
    var rangeFileName = cmd.scoreRangeFileName;
    var trainSizeLabels = SizeAdjustedLogLossUtil.GetSizeLabels(TrainSet);
    var logLossCoefficient = cmd.loglosscoef;

    return new SizeAdjustedLogLossTest(trainTracker, rangeFileName, trainSizeLabels, logLossCoefficient);
}
/// <summary>
/// Registers one size-adjusted log-loss test per evaluation set. The training set is always
/// registered first, then the validation set when present, then each entry of <c>TestSets</c>
/// in order.
/// </summary>
/// <remarks>
/// Every test uses the configured score-range file name and log-loss coefficient, together with
/// the size labels of the set it evaluates.
/// </remarks>
protected override void InitializeTests()
{
    Tests.Add(new SizeAdjustedLogLossTest(
        ConstructScoreTracker(TrainSet),
        cmd.scoreRangeFileName,
        SizeAdjustedLogLossUtil.GetSizeLabels(TrainSet),
        cmd.loglosscoef));

    if (ValidSet != null)
    {
        Tests.Add(new SizeAdjustedLogLossTest(
            ConstructScoreTracker(ValidSet),
            cmd.scoreRangeFileName,
            SizeAdjustedLogLossUtil.GetSizeLabels(ValidSet),
            cmd.loglosscoef));
    }

    if (TestSets == null)
    {
        return;
    }

    // An empty array simply contributes no tests.
    foreach (var testSet in TestSets)
    {
        Tests.Add(new SizeAdjustedLogLossTest(
            ConstructScoreTracker(testSet),
            cmd.scoreRangeFileName,
            SizeAdjustedLogLossUtil.GetSizeLabels(testSet),
            cmd.loglosscoef));
    }
}
/// <summary>
/// Computes the size labels of the training set and stores them in <c>trainSetSizeLabels</c>.
/// </summary>
/// <param name="ch">Channel supplied by the caller; not referenced by this implementation.</param>
protected override void PrepareLabels(IChannel ch)
{
    var sizeLabels = SizeAdjustedLogLossUtil.GetSizeLabels(TrainSet);
    this.trainSetSizeLabels = sizeLabels;
}