Пример #1
0
        /// <summary>
        /// Training algorithm for the single-feature functions f(x).
        /// Runs <c>Args.NumIterations</c> iterations; each one computes the gradient from the
        /// current training scores, trains on each feature and updates the scores.
        /// After the last iteration the graphs are combined.
        /// </summary>
        /// <param name="ch">The channel to write to</param>
        private void TrainMainEffectsModel(IChannel ch)
        {
            Contracts.AssertValue(ch);
            int iterations = Args.NumIterations;

            ch.Info("Starting to train ...");

            using (var pch = Host.StartProgressChannel("GAM training"))
            {
                _objectiveFunction = CreateObjectiveFunction();

                // Without sample weights the weight sum passed to the helpers is 0.
                var sumWeights = HasWeights ? TrainSet.SampleWeights.Sum() : 0;

                // Declared outside the loop because the progress callback below captures it
                // and reports its current value.
                int iteration = 0;
                pch.SetHeader(new ProgressHeader("iterations"), e => e.SetProgress(0, iteration, iterations));

                // The loop is driven by the captured 'iteration' variable directly; no separate
                // loop index is needed.
                for (; iteration < iterations; iteration++)
                {
                    using (Timer.Time(TimerEvent.Iteration))
                    {
                        var gradient   = _objectiveFunction.GetGradient(ch, TrainSetScore.Scores);
                        var sumTargets = gradient.Sum();

                        SumUpsAcrossFlocks(gradient, sumTargets, sumWeights);
                        TrainOnEachFeature(gradient, TrainSetScore.Scores, sumTargets, sumWeights, iteration);
                        UpdateScores(iteration);
                    }
                }
            }

            CombineGraphs(ch);
        }
Пример #2
0
 /// <summary>
 /// Prepares the training state under the <c>InitializeTraining</c> timer: calls
 /// <c>InitializeGamHistograms</c>, then creates the sub-graph, the leaf split
 /// candidates and the leaf split helper.
 /// </summary>
 /// <param name="ch">The channel supplied by the caller (not read by this method)</param>
 private void Initialize(IChannel ch)
 {
     using (Timer.Time(TimerEvent.InitializeTraining))
     {
         InitializeGamHistograms();

         // The sub-graph is sized by the number of features and the number of iterations.
         var featureCount = TrainSet.NumFeatures;
         var iterationCount = Args.NumIterations;
         _subGraph = new SubGraph(featureCount, iterationCount);

         _leafSplitCandidates = new LeastSquaresRegressionTreeLearner.LeafSplitCandidates(TrainSet);

         var weighted = HasWeights;
         _leafSplitHelper = new LeafSplitHelper(weighted);
     }
 }
Пример #3
0
 /// <summary>
 /// Runs initialization and then main-effects training, each under its own timer,
 /// while holding <c>FastTreeShared.TrainLock</c>.
 /// </summary>
 /// <param name="ch">The channel handed to both phases; validated with <c>Contracts.CheckValue</c></param>
 private void TrainCore(IChannel ch)
 {
     Contracts.CheckValue(ch, nameof(ch));

     // REVIEW: Get rid of this lock once all static classes (such as BlockingThreadPool)
     // have been completely removed from Gam.
     lock (FastTreeShared.TrainLock)
     {
         using (Timer.Time(TimerEvent.TotalInitialization))
         {
             Initialize(ch);
         }

         using (Timer.Time(TimerEvent.TotalTrain))
         {
             TrainMainEffectsModel(ch);
         }
     }
 }