/// <summary>
/// Training algorithm for the single-feature functions f(x).
/// Creates the objective function, then for <c>Args.NumIterations</c> rounds computes the
/// gradient of the current training-set scores, accumulates it across flocks, trains on every
/// feature and updates the scores. After the last round the per-iteration graphs are combined
/// via <c>CombineGraphs</c>.
/// </summary>
/// <param name="ch">The channel to write to</param>
private void TrainMainEffectsModel(IChannel ch)
{
    Contracts.AssertValue(ch);
    int iterations = Args.NumIterations;

    ch.Info("Starting to train ...");

    using (var pch = Host.StartProgressChannel("GAM training"))
    {
        _objectiveFunction = CreateObjectiveFunction();

        // Total sample weight; 0 when the training set carries no weights.
        var sumWeights = HasWeights ? TrainSet.SampleWeights.Sum() : 0;

        // The counter is declared outside the loop on purpose: the progress-header callback
        // below captures it, so progress reports always see the current iteration.
        int iteration = 0;
        pch.SetHeader(new ProgressHeader("iterations"),
            e => e.SetProgress(0, iteration, iterations));

        for (; iteration < iterations; iteration++)
        {
            using (Timer.Time(TimerEvent.Iteration))
            {
                var gradient = _objectiveFunction.GetGradient(ch, TrainSetScore.Scores);
                var sumTargets = gradient.Sum();

                SumUpsAcrossFlocks(gradient, sumTargets, sumWeights);
                TrainOnEachFeature(gradient, TrainSetScore.Scores, sumTargets, sumWeights, iteration);
                UpdateScores(iteration);
            }
        }
    }

    // NOTE(review): presumably merges the per-iteration shape functions into the final model — confirm in CombineGraphs.
    CombineGraphs(ch);
}
/// <summary>
/// Sets up the per-training state before the main-effects loop runs. It initializes the GAM
/// histograms, then creates the sub-graph (sized by feature count and iteration count), the
/// leaf split candidates over the training set, and the leaf split helper. The whole step is
/// timed as <c>TimerEvent.InitializeTraining</c>.
/// </summary>
/// <param name="ch">The channel to write to (not used by this method's current body).</param>
private void Initialize(IChannel ch)
{
    using (Timer.Time(TimerEvent.InitializeTraining))
    {
        // Histograms come first; the remaining state is built after them, in the original order.
        InitializeGamHistograms();

        var featureCount = TrainSet.NumFeatures;
        var iterationCount = Args.NumIterations;
        _subGraph = new SubGraph(featureCount, iterationCount);

        _leafSplitCandidates = new LeastSquaresRegressionTreeLearner.LeafSplitCandidates(TrainSet);

        // The helper is told whether the training set carries sample weights.
        _leafSplitHelper = new LeafSplitHelper(HasWeights);
    }
}
/// <summary>
/// Runs the two training phases in sequence while holding <c>FastTreeShared.TrainLock</c>.
/// First <c>Initialize</c> runs, timed as <c>TimerEvent.TotalInitialization</c>. Then
/// <c>TrainMainEffectsModel</c> runs, timed as <c>TimerEvent.TotalTrain</c>.
/// </summary>
/// <param name="ch">The channel to write to; must not be null.</param>
private void TrainCore(IChannel ch)
{
    Contracts.CheckValue(ch, nameof(ch));

    // REVIEW: Get rid of this lock when we completely remove all static classes from Gam, such as BlockingThreadPool.
    lock (FastTreeShared.TrainLock)
    {
        using (Timer.Time(TimerEvent.TotalInitialization))
        {
            Initialize(ch);
        }

        using (Timer.Time(TimerEvent.TotalTrain))
        {
            TrainMainEffectsModel(ch);
        }
    }
}