Exemplo n.º 1
0
        /// <summary>
        /// Runs one training iteration with the tracker's XK and YK score buffers temporarily
        /// swapped, so that the base implementation (which works on XK) operates on YK.
        /// After the base call, the swap is reversed, the new tree's scores are added to the
        /// training scores, and every previously learnt tree is rescaled for the new ensemble size.
        /// </summary>
        /// <param name="ch">Channel forwarded to the base implementation; must not be null.</param>
        /// <param name="activeFeatures">Feature mask forwarded unchanged to the base implementation.</param>
        /// <returns>The tree produced by the base iteration, or null when no tree was learnt.</returns>
        /// <exception cref="System.InvalidOperationException">
        /// Thrown when <c>TrainingScores</c> is not an <c>AgdScoreTracker</c>.
        /// </exception>
        internal override InternalRegressionTree TrainingIteration(IChannel ch, bool[] activeFeatures)
        {
            Contracts.CheckValue(ch, nameof(ch));

            AgdScoreTracker trainingScores = TrainingScores as AgdScoreTracker;
            if (trainingScores == null)
            {
                // Fail with a descriptive error instead of a NullReferenceException on first use.
                throw new System.InvalidOperationException("TrainingScores is not an AgdScoreTracker.");
            }

            // First make XK = YK, because we want to fit YK and line-search YK while the
            // base class fits XK (thanks to the swap it will in fact fit YK).
            var xk = trainingScores.XK;

            trainingScores.XK = trainingScores.YK;
            trainingScores.YK = null;

            InternalRegressionTree tree;
            try
            {
                // Invoke standard gradient descent on YK rather than XK (Scores).
                tree = base.TrainingIteration(ch, activeFeatures);
            }
            finally
            {
                // Reverse the XK/YK swap even when the base iteration throws, so the
                // tracker is never left with XK aliased to YK and YK set to null.
                trainingScores.YK = trainingScores.XK;
                trainingScores.XK = xk;
            }

            if (tree == null)
            {
                return null; // No tree was actually learnt. Give up.
            }

            // ... and update the training scores that were omitted from the update
            // in AcceleratedGradientDescent.UpdateScores.
            // A faster way of computing the train scores would be to reuse the scores
            // precomputed by LineSearch, but that would make this code even more complex.
            trainingScores.AddScores(tree, TreeLearner.Partitioning, 1.0);

            // Now rescale all previous trees (every tree except the last one) by the ratio
            // new_desired_tree_scale / previous_tree_scale.
            for (int t = 0; t < Ensemble.NumTrees - 1; t++)
            {
                Ensemble.GetTreeAt(t).ScaleOutputsBy(AgdScoreTracker.TreeMultiplier(t, Ensemble.NumTrees) / AgdScoreTracker.TreeMultiplier(t, Ensemble.NumTrees - 1));
            }

            return tree;
        }
Exemplo n.º 2
0
 /// <summary>
 /// When the best iteration differs from the current number of trees, rescales the first
 /// <paramref name="bestIteration"/> trees by the ratio of their multiplier at
 /// <paramref name="bestIteration"/> to their multiplier at the current tree count,
 /// then delegates to the base implementation.
 /// </summary>
 /// <param name="bestIteration">Iteration whose per-tree multipliers should be restored.</param>
 public override void FinalizeLearning(int bestIteration)
 {
     if (bestIteration != Ensemble.NumTrees)
     {
         // Put each tree's multiplier back to the value it had at bestIteration.
         int treeIndex = 0;
         while (treeIndex < bestIteration)
         {
             var currentTree = Ensemble.GetTreeAt(treeIndex);
             var multiplierAtBest = AgdScoreTracker.TreeMultiplier(treeIndex, bestIteration);
             var multiplierNow = AgdScoreTracker.TreeMultiplier(treeIndex, Ensemble.NumTrees);
             currentTree.ScaleOutputsBy(multiplierAtBest / multiplierNow);
             treeIndex++;
         }
     }

     base.FinalizeLearning(bestIteration);
 }