Example #1
        public void Prune(RegressionNodeTreeModel treeModel, IReadOnlyList <int> trainingRows, IReadOnlyList <int> pruningRows, IScope statescope, CancellationToken cancellationToken)
        {
            var regressionTreeParams = (RegressionTreeParameters)statescope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
            var state = (PruningState)statescope.Variables[PruningStateVariableName].Value;

            // with FastPruning, a LinearLeaf model is used during pruning instead of the configured leaf model
            var leaf = FastPruning ? new LinearLeaf() : regressionTreeParams.LeafModel;

            // state.Code tracks how far pruning has progressed, so a resumed call
            // skips the stages that have already been completed
            if (state.Code <= 1)
            {
                InstallModels(treeModel, state, trainingRows, pruningRows, leaf, regressionTreeParams, cancellationToken);
                cancellationToken.ThrowIfCancellationRequested();
            }
            if (state.Code <= 2)
            {
                AssignPruningThresholds(treeModel, state, PruningDecay);
                cancellationToken.ThrowIfCancellationRequested();
            }
            if (state.Code <= 3)
            {
                UpdateThreshold(treeModel, state);
                cancellationToken.ThrowIfCancellationRequested();
            }
            if (state.Code <= 4)
            {
                Prune(treeModel, state, PruningStrength);
                cancellationToken.ThrowIfCancellationRequested();
            }

            state.Code = 5;
        }
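
The state.Code <= n guards above form a small resumable pipeline: the shared state records how far pruning has progressed, so a call that was cancelled part-way can be repeated and will skip the stages that already finished. A minimal, self-contained sketch of that pattern (toy types only, not the HeuristicLab classes):

    using System.Threading;

    // Toy stand-in for PruningState: only the checkpoint counter is modelled.
    class StageState
    {
        public int Code;
    }

    static class ResumablePipelineSketch
    {
        // Each stage runs only if it has not completed yet; calling Run again with the
        // same state after a cancellation resumes at the first unfinished stage.
        public static void Run(StageState state, CancellationToken cancellationToken)
        {
            if (state.Code <= 1)
            {
                // stage 1 work would go here (e.g. install pruning models)
                state.Code = 2;
                cancellationToken.ThrowIfCancellationRequested();
            }
            if (state.Code <= 2)
            {
                // stage 2 work (e.g. assign pruning thresholds)
                state.Code = 3;
                cancellationToken.ThrowIfCancellationRequested();
            }
            if (state.Code <= 3)
            {
                // stage 3 work (e.g. prune), after which the pipeline is complete
                state.Code = 4;
            }
        }
    }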
Example #2
            public void FillLeafs(RegressionNodeTreeModel tree, IReadOnlyList <int> trainingRows, IDataset data)
            {
                var helperQueue         = new Queue <RegressionNodeModel>();
                var trainingHelperQueue = new Queue <IReadOnlyList <int> >();

                // the leaves and the training rows that reach them end up in
                // nodeQueue / trainingRowsQueue; the helper queues drive the traversal
                nodeQueue.Clear();
                trainingRowsQueue.Clear();

                helperQueue.Enqueue(tree.Root);
                trainingHelperQueue.Enqueue(trainingRows);

                while (helperQueue.Count != 0)
                {
                    var n = helperQueue.Dequeue();
                    var t = trainingHelperQueue.Dequeue();
                    if (n.IsLeaf)
                    {
                        nodeQueue.Enqueue(n);
                        trainingRowsQueue.Enqueue(t);
                        continue;
                    }

                    IReadOnlyList <int> leftTraining, rightTraining;
                    RegressionTreeUtilities.SplitRows(t, data, n.SplitAttribute, n.SplitValue, out leftTraining, out rightTraining);

                    helperQueue.Enqueue(n.Left);
                    helperQueue.Enqueue(n.Right);
                    trainingHelperQueue.Enqueue(leftTraining);
                    trainingHelperQueue.Enqueue(rightTraining);
                }
            }
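
FillLeafs routes each node's training rows down to its children with RegressionTreeUtilities.SplitRows, which is not part of this excerpt. What such a split has to do is partition the row indices by comparing the split attribute against the split value (left = LessEqual, right = Greater, matching the comparisons used in Example #4). A hedged, self-contained sketch, with a generic value accessor standing in for the IDataset API:

    using System;
    using System.Collections.Generic;

    static class RowSplitSketch
    {
        // Partitions 'rows' into rows whose attribute value is <= splitValue (left)
        // and rows whose value is > splitValue (right). 'getValue' stands in for
        // dataset access such as data.GetDoubleValue(attribute, row).
        public static void SplitRows(IReadOnlyList<int> rows, Func<int, double> getValue, double splitValue,
                                     out IReadOnlyList<int> left, out IReadOnlyList<int> right)
        {
            var l = new List<int>();
            var r = new List<int>();
            foreach (var row in rows)
            {
                if (getValue(row) <= splitValue)
                    l.Add(row);
                else
                    r.Add(row);
            }
            left  = l;
            right = r;
        }
    }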
Example #3
 public void FillBottomUp(RegressionNodeTreeModel tree, IReadOnlyList <int> pruningRows, IReadOnlyList <int> trainingRows, IDataset data)
 {
     // fill the queues top-down (level order), then reverse them so that
     // nodes are handed out bottom-up, i.e. children before their parents
     FillTopDown(tree, pruningRows, trainingRows, data);
     nodeQueue         = new Queue <RegressionNodeModel>(nodeQueue.Reverse());
     pruningRowsQueue  = new Queue <IReadOnlyList <int> >(pruningRowsQueue.Reverse());
     trainingRowsQueue = new Queue <IReadOnlyList <int> >(trainingRowsQueue.Reverse());
 }
Example #4
        public void Build(IReadOnlyList <int> trainingRows, IReadOnlyList <int> pruningRows, IScope statescope, ResultCollection results, CancellationToken cancellationToken)
        {
            var regressionTreeParams = (RegressionTreeParameters)statescope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;

            variables = regressionTreeParams.AllowedInputVariables.ToList();

            //build tree and select node with maximum coverage
            var tree = RegressionNodeTreeModel.CreateTreeModel(regressionTreeParams.TargetVariable, regressionTreeParams);

            tree.BuildModel(trainingRows, pruningRows, statescope, results, cancellationToken);
            var nodeModel = tree.Root.EnumerateNodes().Where(x => x.IsLeaf).MaxItems(x => x.NumSamples).First();

            var satts = new List <string>();
            var svals = new List <double>();
            var reops = new List <Comparison>();

            //extract splits
            for (var temp = nodeModel; temp.Parent != null; temp = temp.Parent)
            {
                satts.Add(temp.Parent.SplitAttribute);
                svals.Add(temp.Parent.SplitValue);
                reops.Add(temp.Parent.Left == temp ? Comparison.LessEqual : Comparison.Greater);
            }
            Comparisons     = reops.ToArray();
            SplitAttributes = satts.ToArray();
            SplitValues     = svals.ToArray();
            int np;

            RuleModel = regressionTreeParams.LeafModel.BuildModel(trainingRows.Union(pruningRows).Where(r => Covers(regressionTreeParams.Data, r)).ToArray(), regressionTreeParams, cancellationToken, out np);
        }
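
The loop above walks from the selected leaf back to the root and records, for every ancestor, which attribute was split, at which value, and on which side the path continued. Covers (used in the last line but not shown here) only has to check a row against exactly those recorded conditions. A sketch of such a predicate; the parameter names are hypothetical and a simple row accessor replaces the IDataset API:

    using System;

    enum Comparison { LessEqual, Greater }

    static class RuleCoverSketch
    {
        // Returns true if the row satisfies every recorded split condition, i.e. if it
        // would be routed to the leaf from which the rule was extracted.
        public static bool Covers(string[] splitAttributes, double[] splitValues, Comparison[] comparisons,
                                  Func<string, double> rowValue)
        {
            for (var i = 0; i < splitAttributes.Length; i++)
            {
                var value     = rowValue(splitAttributes[i]);
                var satisfied = comparisons[i] == Comparison.LessEqual ? value <= splitValues[i]
                                                                       : value > splitValues[i];
                if (!satisfied)
                {
                    return false;
                }
            }
            return true;
        }
    }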
Example #5
        private static IScope InitializeScope(IRandom random, IRegressionProblemData problemData, IPruning pruning, int minLeafSize, ILeafModel leafModel, ISplitter splitter, bool generateRules, bool useHoldout, double holdoutSize)
        {
            var stateScope = new Scope("RegressionTreeStateScope");

            //reduce RegressionProblemData to AllowedInput & Target column wise and to TrainingSet row wise
            var doubleVars = new HashSet <string>(problemData.Dataset.DoubleVariables);
            var vars       = problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }).ToArray();

            if (vars.Any(v => !doubleVars.Contains(v)))
            {
                throw new NotSupportedException("Decision tree regression supports only double valued input or output features.");
            }
            var doubles = vars.Select(v => problemData.Dataset.GetDoubleValues(v, problemData.TrainingIndices).ToArray()).ToArray();

            if (doubles.Any(v => v.Any(x => double.IsNaN(x) || double.IsInfinity(x))))
            {
                throw new NotSupportedException("Decision tree regression does not support NaN or infinity values in the input dataset.");
            }
            var trainingData = new Dataset(vars, doubles);
            var pd           = new RegressionProblemData(trainingData, problemData.AllowedInputVariables, problemData.TargetVariable);

            pd.TrainingPartition.End   = pd.TestPartition.Start = pd.TestPartition.End = pd.Dataset.Rows;
            pd.TrainingPartition.Start = 0;

            //store regression tree parameters
            var regressionTreeParams = new RegressionTreeParameters(pruning, minLeafSize, leafModel, pd, random, splitter);

            stateScope.Variables.Add(new Variable(RegressionTreeParameterVariableName, regressionTreeParams));

            //initialize tree operators
            pruning.Initialize(stateScope);
            splitter.Initialize(stateScope);
            leafModel.Initialize(stateScope);

            //store unbuilt model
            IItem model;

            if (generateRules)
            {
                model = RegressionRuleSetModel.CreateRuleModel(problemData.TargetVariable, regressionTreeParams);
                RegressionRuleSetModel.Initialize(stateScope);
            }
            else
            {
                model = RegressionNodeTreeModel.CreateTreeModel(problemData.TargetVariable, regressionTreeParams);
            }
            stateScope.Variables.Add(new Variable(ModelVariableName, model));

            //store training & pruning indices
            IReadOnlyList <int> trainingSet, pruningSet;

            GeneratePruningSet(pd.TrainingIndices.ToArray(), random, useHoldout, holdoutSize, out trainingSet, out pruningSet);
            stateScope.Variables.Add(new Variable(TrainingSetVariableName, new IntArray(trainingSet.ToArray())));
            stateScope.Variables.Add(new Variable(PruningSetVariableName, new IntArray(pruningSet.ToArray())));

            return(stateScope);
        }
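
GeneratePruningSet (called near the end) is not included in this excerpt; all it needs to deliver are the training and pruning index lists, optionally splitting off a random holdout of the given relative size. A sketch under those assumptions; the no-holdout convention below (both sets reuse all rows) is a guess, not confirmed by the code shown here:

    using System;
    using System.Collections.Generic;
    using System.Linq;

    static class PruningSetSketch
    {
        public static void GeneratePruningSet(int[] rows, Random random, bool useHoldout, double holdoutSize,
                                              out IReadOnlyList<int> trainingSet, out IReadOnlyList<int> pruningSet)
        {
            if (!useHoldout)
            {
                // without a holdout, training and pruning simply see the same rows
                trainingSet = rows;
                pruningSet  = rows;
                return;
            }
            // shuffle, then split off the first 'holdoutSize' fraction as the pruning set
            var shuffled = rows.OrderBy(_ => random.Next()).ToArray();
            var nPruning = (int)Math.Round(shuffled.Length * holdoutSize);
            pruningSet   = shuffled.Take(nPruning).ToArray();
            trainingSet  = shuffled.Skip(nPruning).ToArray();
        }
    }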
Example #6
        public static Dictionary <string, int> GetTreeVariableFrequences(RegressionNodeTreeModel treeModel)
        {
            var res  = treeModel.VariablesUsedForPrediction.ToDictionary(x => x, x => 0);
            var root = treeModel.Root;

            foreach (var cur in root.EnumerateNodes().Where(x => !x.IsLeaf))
            {
                res[cur.SplitAttribute]++;
            }
            return(res);
        }
Example #7
        public static Result CreateLeafDepthHistogram(RegressionNodeTreeModel treeModel)
        {
            var list = new List <int>();

            GetLeafDepths(treeModel.Root, 0, list);
            var row = new DataRow("Depths", "", list.Select(x => (double)x))
            {
                VisualProperties = { ChartType = DataRowVisualProperties.DataRowChartType.Histogram }
            };
            var hist = new DataTable("LeafDepths");

            hist.Rows.Add(row);
            return(new Result(hist.Name, hist));
        }
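
GetLeafDepths is not shown; it only has to record the depth of every leaf below the given node so that the histogram above can be filled. A sketch of that recursion against a minimal node interface (the real RegressionNodeModel exposes IsLeaf, Left and Right, as the other examples show):

    using System.Collections.Generic;

    static class LeafDepthSketch
    {
        // Minimal node shape assumed for the sketch.
        public interface INode
        {
            bool  IsLeaf { get; }
            INode Left   { get; }
            INode Right  { get; }
        }

        public static void GetLeafDepths(INode node, int depth, ICollection<int> depths)
        {
            if (node == null)
            {
                return;
            }
            if (node.IsLeaf)
            {
                depths.Add(depth);
                return;
            }
            GetLeafDepths(node.Left, depth + 1, depths);
            GetLeafDepths(node.Right, depth + 1, depths);
        }
    }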
Example #8
 private static void Prune(RegressionNodeTreeModel tree, PruningState state, double pruningStrength)
 {
     if (state.Code == 3)
     {
         state.FillTopDown(tree);
         state.Code = 4;
     }
     while (state.nodeQueue.Count != 0)
     {
         var n = state.nodeQueue.Dequeue();
         if (n.IsLeaf || pruningStrength <= n.PruningStrength)
         {
             continue;
         }
         n.ToLeaf();
     }
 }
Example #9
 private static void UpdateThreshold(RegressionNodeTreeModel tree, PruningState state)
 {
     if (state.Code == 2)
     {
         state.FillTopDown(tree);
         state.Code = 3;
     }
     while (state.nodeQueue.Count != 0)
     {
         var n = state.nodeQueue.Dequeue();
         if (n.IsLeaf || n.Parent == null || double.IsNaN(n.Parent.PruningStrength))
         {
             continue;
         }
         n.PruningStrength = Math.Min(n.PruningStrength, n.Parent.PruningStrength);
     }
 }
Example #10
 private static void AssignPruningThresholds(RegressionNodeTreeModel tree, PruningState state, double pruningDecay)
 {
     if (state.Code == 1)
     {
         state.FillBottomUp(tree);
         state.Code = 2;
     }
     while (state.nodeQueue.Count != 0)
     {
         var n = state.nodeQueue.Dequeue();
         if (n.IsLeaf)
         {
             continue;
         }
         n.PruningStrength = PruningThreshold(state.pruningSizes[n], state.modelComplexities[n], state.nodeComplexities[n], state.modelErrors[n], SubtreeError(n, state.pruningSizes, state.modelErrors), pruningDecay);
     }
 }
Example #11
 private static void InstallModels(RegressionNodeTreeModel tree, PruningState state, IReadOnlyList <int> trainingRows, IReadOnlyList <int> pruningRows, ILeafModel leaf, RegressionTreeParameters regressionTreeParams, CancellationToken cancellationToken)
 {
     if (state.Code == 0)
     {
         state.FillBottomUp(tree, trainingRows, pruningRows, regressionTreeParams.Data);
         state.Code = 1;
     }
     while (state.nodeQueue.Count != 0)
     {
         cancellationToken.ThrowIfCancellationRequested();
         // peek first and only dequeue after the pruning model has been built, so that a
         // cancellation inside BuildPruningModel leaves the node in the queue and the
         // next call resumes with the same node
         var n        = state.nodeQueue.Peek();
         var training = state.trainingRowsQueue.Peek();
         var pruning  = state.pruningRowsQueue.Peek();
         BuildPruningModel(n, leaf, training, pruning, state, regressionTreeParams, cancellationToken);
         state.nodeQueue.Dequeue();
         state.trainingRowsQueue.Dequeue();
         state.pruningRowsQueue.Dequeue();
     }
 }
Example #12
        public void Build(RegressionNodeTreeModel tree, IReadOnlyList <int> trainingRows, IScope stateScope, CancellationToken cancellationToken)
        {
            var parameters = (RegressionTreeParameters)stateScope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
            var state      = (LeafBuildingState)stateScope.Variables[LeafBuildingStateVariableName].Value;

            if (state.Code == 0)
            {
                state.FillLeafs(tree, trainingRows, parameters.Data);
                state.Code = 1;
            }
            while (state.nodeQueue.Count != 0)
            {
                var n = state.nodeQueue.Peek();
                var t = state.trainingRowsQueue.Peek();
                int numP;
                n.SetLeafModel(BuildModel(t, parameters, cancellationToken, out numP));
                state.nodeQueue.Dequeue();
                state.trainingRowsQueue.Dequeue();
            }
        }
Example #13
            public void FillTopDown(RegressionNodeTreeModel tree)
            {
                var helperQueue = new Queue <RegressionNodeModel>();

                nodeQueue.Clear();

                helperQueue.Enqueue(tree.Root);
                nodeQueue.Enqueue(tree.Root);

                while (helperQueue.Count != 0)
                {
                    var n = helperQueue.Dequeue();
                    if (n.IsLeaf)
                    {
                        continue;
                    }
                    helperQueue.Enqueue(n.Left);
                    helperQueue.Enqueue(n.Right);
                    nodeQueue.Enqueue(n.Left);
                    nodeQueue.Enqueue(n.Right);
                }
            }
Example #14
        public static void AnalyzeNodes(RegressionNodeTreeModel tree, ResultCollection results, IRegressionProblemData pd)
        {
            var dict             = new Dictionary <int, RegressionNodeModel>();
            var trainingLeafRows = new Dictionary <int, IReadOnlyList <int> >();
            var testLeafRows     = new Dictionary <int, IReadOnlyList <int> >();
            var modelNumber      = new IntValue(1);
            var symtree          = new SymbolicExpressionTree(MirrorTree(tree.Root, dict, trainingLeafRows, testLeafRows, modelNumber, pd.Dataset, pd.TrainingIndices.ToList(), pd.TestIndices.ToList()));

            results.AddOrUpdateResult("DecisionTree", symtree);

            // for large trees, skip creating the per-node model results
            if (dict.Count > 200)
            {
                return;
            }
            var models = new Scope("NodeModels");

            results.AddOrUpdateResult("NodeModels", models);
            foreach (var m in dict.Keys.OrderBy(x => x))
            {
                models.Variables.Add(new Variable("Model " + m, dict[m].CreateRegressionSolution(Subselect(pd, trainingLeafRows[m], testLeafRows[m]))));
            }
        }
Example #15
        public void Split(RegressionNodeTreeModel tree, IReadOnlyList <int> trainingRows, IScope stateScope, CancellationToken cancellationToken)
        {
            var regressionTreeParams = (RegressionTreeParameters)stateScope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
            var splittingState       = (SplittingState)stateScope.Variables[SplittingStateVariableName].Value;
            var variables            = regressionTreeParams.AllowedInputVariables.ToArray();
            var target = regressionTreeParams.TargetVariable;

            if (splittingState.Code <= 0)
            {
                splittingState.nodeQueue.Enqueue(tree.Root);
                splittingState.trainingRowsQueue.Enqueue(trainingRows);
                splittingState.Code = 1;
            }
            while (splittingState.nodeQueue.Count != 0)
            {
                var n    = splittingState.nodeQueue.Dequeue();
                var rows = splittingState.trainingRowsQueue.Dequeue();

                string attr;
                double splitValue;
                var    isLeaf = !DecideSplit(new RegressionProblemData(RegressionTreeUtilities.ReduceDataset(regressionTreeParams.Data, rows, variables, target), variables, target), regressionTreeParams.MinLeafSize, out attr, out splitValue);
                if (isLeaf)
                {
                    continue;
                }

                IReadOnlyList <int> leftRows, rightRows;
                RegressionTreeUtilities.SplitRows(rows, regressionTreeParams.Data, attr, splitValue, out leftRows, out rightRows);
                n.Split(regressionTreeParams, attr, splitValue, rows.Count);

                splittingState.nodeQueue.Enqueue(n.Left);
                splittingState.nodeQueue.Enqueue(n.Right);
                splittingState.trainingRowsQueue.Enqueue(leftRows);
                splittingState.trainingRowsQueue.Enqueue(rightRows);
                cancellationToken.ThrowIfCancellationRequested();
            }
        }
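
DecideSplit itself is not part of the excerpt; conceptually it searches the reduced dataset for the attribute/value pair that best separates the target and refuses to split once a child would drop below MinLeafSize. The following self-contained sketch uses plain squared-error reduction to illustrate the idea; it is not HeuristicLab's splitter, and the column/target representation is an assumption made for the sketch:

    using System;
    using System.Collections.Generic;
    using System.Linq;

    static class SplitDecisionSketch
    {
        // Picks the (attribute, threshold) pair with the largest reduction in squared error,
        // subject to both children keeping at least 'minLeafSize' rows. Returns false when no
        // such split improves on the unsplit node.
        public static bool DecideSplit(IReadOnlyDictionary<string, double[]> columns, double[] target,
                                       int minLeafSize, out string attribute, out double splitValue)
        {
            attribute  = null;
            splitValue = double.NaN;
            var bestGain = 0.0;
            var totalSse = Sse(target);

            foreach (var column in columns)
            {
                var order = Enumerable.Range(0, target.Length).OrderBy(i => column.Value[i]).ToArray();
                for (var cut = minLeafSize; cut <= target.Length - minLeafSize; cut++)
                {
                    // only cut between two distinct attribute values
                    if (column.Value[order[cut - 1]] == column.Value[order[cut]]) continue;

                    var left  = order.Take(cut).Select(i => target[i]).ToArray();
                    var right = order.Skip(cut).Select(i => target[i]).ToArray();
                    var gain  = totalSse - (Sse(left) + Sse(right));
                    if (gain > bestGain)
                    {
                        bestGain   = gain;
                        attribute  = column.Key;
                        splitValue = (column.Value[order[cut - 1]] + column.Value[order[cut]]) / 2.0;
                    }
                }
            }
            return attribute != null;
        }

        // sum of squared deviations from the mean
        private static double Sse(double[] values)
        {
            if (values.Length == 0) return 0.0;
            var mean = values.Average();
            return values.Sum(v => (v - mean) * (v - mean));
        }
    }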
Example #16
 public void FillBottomUp(RegressionNodeTreeModel tree)
 {
     // reverse the level-order queue so children are handed out before their parents
     FillTopDown(tree);
     nodeQueue = new Queue <RegressionNodeModel>(nodeQueue.Reverse());
 }
Example #17
 public void Prune(RegressionNodeTreeModel treeModel, IReadOnlyList <int> trainingRows, IReadOnlyList <int> pruningRows, IScope scope, CancellationToken cancellationToken)
 {
     // intentionally empty: this pruning variant leaves the tree unchanged
 }
Example #18
 protected RegressionNodeTreeModel(RegressionNodeTreeModel original, Cloner cloner) : base(original, cloner)
 {
     Root = cloner.Clone(original.Root);
 }
Example #19
        public static void PruningChart(RegressionNodeTreeModel tree, ComplexityPruning pruning, ResultCollection results)
        {
            var nodes = new Queue <RegressionNodeModel>();

            nodes.Enqueue(tree.Root);
            var max       = 0.0;
            var strengths = new SortedList <double, int>();

            while (nodes.Count > 0)
            {
                var n = nodes.Dequeue();

                if (n.IsLeaf)
                {
                    max++;
                    continue;
                }

                if (!strengths.ContainsKey(n.PruningStrength))
                {
                    strengths.Add(n.PruningStrength, 0);
                }
                strengths[n.PruningStrength]++;
                nodes.Enqueue(n.Left);
                nodes.Enqueue(n.Right);
            }
            if (strengths.Count == 0)
            {
                return;
            }

            var plot = new ScatterPlot("Pruned Sizes", "")
            {
                VisualProperties =
                {
                    XAxisTitle             = "Pruning Strength",
                    YAxisTitle             = "Tree Size",
                    XAxisMinimumAuto       = false,
                    XAxisMinimumFixedValue = 0
                }
            };
            var row = new ScatterPlotDataRow("TreeSizes", "", new List <Point2D <double> >());

            row.Points.Add(new Point2D <double>(pruning.PruningStrength, max));

            var fillerDots = new Queue <double>();
            var minX       = pruning.PruningStrength;
            var maxX       = strengths.Last().Key;
            var size       = (maxX - minX) / 200;

            for (var x = minX; x <= maxX; x += size)
            {
                fillerDots.Enqueue(x);
            }

            // walk through the distinct pruning strengths in ascending order; between two
            // strengths the plotted size stays constant, so filler points keep the curve readable
            foreach (var strength in strengths.Keys)
            {
                while (fillerDots.Count > 0 && strength > fillerDots.Peek())
                {
                    row.Points.Add(new Point2D <double>(fillerDots.Dequeue(), max));
                }
                max -= strengths[strength];
                row.Points.Add(new Point2D <double>(strength, max));
            }


            row.VisualProperties.PointSize = 6;
            plot.Rows.Add(row);
            results.AddOrUpdateResult("PruningSizes", plot);
        }