/// <summary>
/// Initializes the aggregator by recording its construction parameters and
/// creating its two child aggregators: a <c>HistogramAggregator</c> over
/// <paramref name="nClasses"/> classes and a <c>GaussianAggregator2d</c>
/// configured with <paramref name="a"/> and <paramref name="b"/>.
/// </summary>
/// <param name="nClasses">Number of classes; stored and forwarded to the histogram aggregator.</param>
/// <param name="a">First parameter forwarded to the Gaussian aggregator (presumably a prior hyperparameter — confirm in <c>GaussianAggregator2d</c>).</param>
/// <param name="b">Second parameter forwarded to the Gaussian aggregator (presumably a prior hyperparameter — confirm in <c>GaussianAggregator2d</c>).</param>
public SemiSupervisedClassificationStatisticsAggregator(int nClasses, double a, double b)
{
    // Keep the raw construction parameters on the instance.
    this.nClasses_ = nClasses;
    this.a_ = a;
    this.b_ = b;

    // Build the child aggregators from the same arguments
    // (Gaussian first, then histogram, as in the original ordering).
    this.GaussianAggregator2d = new GaussianAggregator2d(a, b);
    this.HistogramAggregator = new HistogramAggregator(nClasses);
}
/// <summary>
/// Scores a candidate split for semi-supervised classification as the gain in
/// class-histogram entropy plus <c>alpha_</c> times the gain in Gaussian
/// (density) entropy.
/// </summary>
/// <param name="allStatistics">Statistics aggregated over the parent node, before the split.</param>
/// <param name="leftStatistics">Statistics aggregated over the left child.</param>
/// <param name="rightStatistics">Statistics aggregated over the right child.</param>
/// <returns>
/// <c>informationGainLabelled + alpha_ * informationGainUnlabelled</c>, where each
/// term is (entropy before) minus (sample-weighted mean child entropy). The
/// labelled term is 0 when the children hold at most one histogram sample in
/// total. The unlabelled term is 0 when the children hold no Gaussian samples.
/// </returns>
public double ComputeInformationGain(
    SemiSupervisedClassificationStatisticsAggregator allStatistics,
    SemiSupervisedClassificationStatisticsAggregator leftStatistics,
    SemiSupervisedClassificationStatisticsAggregator rightStatistics)
{
    // Gain from the class histograms (labelled data).
    double informationGainLabelled;
    {
        double entropyBefore = allStatistics.HistogramAggregator.Entropy();
        HistogramAggregator leftHistogram = leftStatistics.HistogramAggregator;
        HistogramAggregator rightHistogram = rightStatistics.HistogramAggregator;
        int nLeft = leftHistogram.SampleCount;
        int nRight = rightHistogram.SampleCount;
        int nTotalSamples = nLeft + nRight;
        if (nTotalSamples <= 1)
        {
            // Too few samples for the split to be informative.
            informationGainLabelled = 0;
        }
        else
        {
            // An empty child carries zero weight. Skip its Entropy() so that a
            // non-finite value there cannot turn 0 * x into NaN.
            double weightedLeft = nLeft > 0 ? nLeft * leftHistogram.Entropy() : 0.0;
            double weightedRight = nRight > 0 ? nRight * rightHistogram.Entropy() : 0.0;
            double entropyAfter = (weightedLeft + weightedRight) / nTotalSamples;
            informationGainLabelled = entropyBefore - entropyAfter;
        }
    }

    // Gain from the Gaussian density models (unlabelled term).
    double informationGainUnlabelled;
    {
        GaussianAggregator2d leftGaussian = leftStatistics.GaussianAggregator2d;
        GaussianAggregator2d rightGaussian = rightStatistics.GaussianAggregator2d;
        int nLeft = leftGaussian.SampleCount;
        int nRight = rightGaussian.SampleCount;
        int nTotalSamples = nLeft + nRight;
        if (nTotalSamples == 0)
        {
            // Previously computed 0.0 / 0 == NaN, which made the whole gain NaN.
            informationGainUnlabelled = 0;
        }
        else
        {
            double entropyBefore = allStatistics.GaussianAggregator2d.GetPdf().Entropy();
            // The PDF of an empty aggregator may not have finite entropy
            // (not visible here — confirm in GaussianAggregator2d.GetPdf),
            // so an empty child contributes exactly zero weight.
            double weightedLeft = nLeft > 0 ? nLeft * leftGaussian.GetPdf().Entropy() : 0.0;
            double weightedRight = nRight > 0 ? nRight * rightGaussian.GetPdf().Entropy() : 0.0;
            double entropyAfter = (weightedLeft + weightedRight) / nTotalSamples;
            informationGainUnlabelled = entropyBefore - entropyAfter;
        }
    }

    // alpha_ weights the unlabelled (density) term relative to the labelled one.
    return informationGainLabelled + alpha_ * informationGainUnlabelled;
}