/// <summary>
/// Runs the complexity-pruning pipeline on the tree, resuming from the stage recorded in
/// the persisted PruningState (state.Code) so a cancelled run can be continued later.
/// </summary>
public void Prune(RegressionNodeTreeModel treeModel, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, IScope statescope, CancellationToken cancellationToken) {
  var parameters = (RegressionTreeParameters)statescope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
  var pruningState = (PruningState)statescope.Variables[PruningStateVariableName].Value;
  // FastPruning substitutes a cheap linear model for the configured leaf model.
  var leafModel = FastPruning ? new LinearLeaf() : parameters.LeafModel;

  // The stages are ordered; the <= checks allow re-entry into a partially completed stage.
  if (pruningState.Code <= 1) {
    InstallModels(treeModel, pruningState, trainingRows, pruningRows, leafModel, parameters, cancellationToken);
    cancellationToken.ThrowIfCancellationRequested();
  }
  if (pruningState.Code <= 2) {
    AssignPruningThresholds(treeModel, pruningState, PruningDecay);
    cancellationToken.ThrowIfCancellationRequested();
  }
  if (pruningState.Code <= 3) {
    UpdateThreshold(treeModel, pruningState);
    cancellationToken.ThrowIfCancellationRequested();
  }
  if (pruningState.Code <= 4) {
    Prune(treeModel, pruningState, PruningStrength);
    cancellationToken.ThrowIfCancellationRequested();
  }
  // Mark the whole pipeline as finished.
  pruningState.Code = 5;
}
/// <summary>
/// Distributes the training rows over the tree's leaves: walks the tree breadth-first,
/// splitting the row set at every inner node, and enqueues each leaf together with the
/// rows that reach it into nodeQueue / trainingRowsQueue.
/// </summary>
public void FillLeafs(RegressionNodeTreeModel tree, IReadOnlyList<int> trainingRows, IDataset data) {
  var pendingNodes = new Queue<RegressionNodeModel>();
  var pendingRows = new Queue<IReadOnlyList<int>>();
  nodeQueue.Clear();
  trainingRowsQueue.Clear();
  pendingNodes.Enqueue(tree.Root);
  pendingRows.Enqueue(trainingRows);

  while (pendingNodes.Count > 0) {
    var node = pendingNodes.Dequeue();
    var rows = pendingRows.Dequeue();
    if (node.IsLeaf) {
      nodeQueue.Enqueue(node);
      trainingRowsQueue.Enqueue(rows);
    } else {
      // Inner node: partition the rows along the node's split and recurse breadth-first.
      IReadOnlyList<int> leftRows, rightRows;
      RegressionTreeUtilities.SplitRows(rows, data, node.SplitAttribute, node.SplitValue, out leftRows, out rightRows);
      pendingNodes.Enqueue(node.Left);
      pendingNodes.Enqueue(node.Right);
      pendingRows.Enqueue(leftRows);
      pendingRows.Enqueue(rightRows);
    }
  }
}
/// <summary>
/// Fills the state queues in bottom-up (children before parents) order by creating the
/// top-down ordering first and then reversing all three queues in lockstep.
/// </summary>
public void FillBottomUp(RegressionNodeTreeModel tree, IReadOnlyList<int> pruningRows, IReadOnlyList<int> trainingRows, IDataset data) {
  FillTopDown(tree, pruningRows, trainingRows, data);
  // Enumerating a Stack yields its elements in pop (i.e. reversed) order.
  nodeQueue = new Queue<RegressionNodeModel>(new Stack<RegressionNodeModel>(nodeQueue));
  pruningRowsQueue = new Queue<IReadOnlyList<int>>(new Stack<IReadOnlyList<int>>(pruningRowsQueue));
  trainingRowsQueue = new Queue<IReadOnlyList<int>>(new Stack<IReadOnlyList<int>>(trainingRowsQueue));
}
/// <summary>
/// Builds a single rule: grows a full tree, picks the leaf covering the most samples,
/// records the split chain from that leaf up to the root as the rule's conditions and
/// fits the rule's model on all rows covered by those conditions.
/// </summary>
public void Build(IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, IScope statescope, ResultCollection results, CancellationToken cancellationToken) {
  var parameters = (RegressionTreeParameters)statescope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
  variables = parameters.AllowedInputVariables.ToList();

  // Grow a tree and locate the leaf with maximum coverage.
  var tree = RegressionNodeTreeModel.CreateTreeModel(parameters.TargetVariable, parameters);
  tree.BuildModel(trainingRows, pruningRows, statescope, results, cancellationToken);
  var bestLeaf = tree.Root.EnumerateNodes().Where(x => x.IsLeaf).MaxItems(x => x.NumSamples).First();

  // Collect the split conditions along the path from the chosen leaf back to the root.
  var attributes = new List<string>();
  var thresholds = new List<double>();
  var comparisons = new List<Comparison>();
  for (var node = bestLeaf; node.Parent != null; node = node.Parent) {
    attributes.Add(node.Parent.SplitAttribute);
    thresholds.Add(node.Parent.SplitValue);
    comparisons.Add(node.Parent.Left == node ? Comparison.LessEqual : Comparison.Greater);
  }
  Comparisons = comparisons.ToArray();
  SplitAttributes = attributes.ToArray();
  SplitValues = thresholds.ToArray();

  // Fit the rule model on every row (training + pruning) that satisfies the conditions.
  int numParams;
  RuleModel = parameters.LeafModel.BuildModel(trainingRows.Union(pruningRows).Where(r => Covers(parameters.Data, r)).ToArray(), parameters, cancellationToken, out numParams);
}
/// <summary>
/// Creates the state scope for a decision tree run: validates and reduces the problem data
/// to the allowed double-valued inputs plus target (training rows only), stores the tree
/// parameters, initializes the tree operators, creates the (still unbuilt) model and
/// splits the training indices into a training and a pruning set.
/// </summary>
private static IScope InitializeScope(IRandom random, IRegressionProblemData problemData, IPruning pruning, int minLeafSize, ILeafModel leafModel, ISplitter splitter, bool generateRules, bool useHoldout, double holdoutSize) {
  var stateScope = new Scope("RegressionTreeStateScope");

  // Only double-valued features are supported; reject anything else up front.
  var doubleVariables = new HashSet<string>(problemData.Dataset.DoubleVariables);
  var usedVariables = problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable }).ToArray();
  if (usedVariables.Any(v => !doubleVariables.Contains(v)))
    throw new NotSupportedException("Decision tree regression supports only double valued input or output features.");

  // Copy the used columns restricted to the training rows and reject non-finite values.
  var columns = usedVariables.Select(v => problemData.Dataset.GetDoubleValues(v, problemData.TrainingIndices).ToArray()).ToArray();
  if (columns.Any(c => c.Any(x => double.IsNaN(x) || double.IsInfinity(x))))
    throw new NotSupportedException("Decision tree regression does not support NaN or infinity values in the input dataset.");

  var trainingData = new Dataset(usedVariables, columns);
  var pd = new RegressionProblemData(trainingData, problemData.AllowedInputVariables, problemData.TargetVariable);
  // The reduced dataset contains only training rows: train on everything, leave the test partition empty.
  pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = pd.Dataset.Rows;
  pd.TrainingPartition.Start = 0;

  // Store the shared regression tree parameters.
  var regressionTreeParams = new RegressionTreeParameters(pruning, minLeafSize, leafModel, pd, random, splitter);
  stateScope.Variables.Add(new Variable(RegressionTreeParameterVariableName, regressionTreeParams));

  // Initialize the tree operators.
  pruning.Initialize(stateScope);
  splitter.Initialize(stateScope);
  leafModel.Initialize(stateScope);

  // Create the unbuilt model (rule set or tree).
  IItem model;
  if (generateRules) {
    model = RegressionRuleSetModel.CreateRuleModel(problemData.TargetVariable, regressionTreeParams);
    RegressionRuleSetModel.Initialize(stateScope);
  } else {
    model = RegressionNodeTreeModel.CreateTreeModel(problemData.TargetVariable, regressionTreeParams);
  }
  stateScope.Variables.Add(new Variable(ModelVariableName, model));

  // Split the training indices into a training and a pruning set.
  IReadOnlyList<int> trainingSet, pruningSet;
  GeneratePruningSet(pd.TrainingIndices.ToArray(), random, useHoldout, holdoutSize, out trainingSet, out pruningSet);
  stateScope.Variables.Add(new Variable(TrainingSetVariableName, new IntArray(trainingSet.ToArray())));
  stateScope.Variables.Add(new Variable(PruningSetVariableName, new IntArray(pruningSet.ToArray())));

  return stateScope;
}
/// <summary>
/// Counts how often each input variable is used as a split attribute in the tree.
/// Variables that never appear in a split keep a count of zero.
/// </summary>
public static Dictionary<string, int> GetTreeVariableFrequences(RegressionNodeTreeModel treeModel) {
  var frequencies = new Dictionary<string, int>();
  foreach (var variable in treeModel.VariablesUsedForPrediction)
    frequencies.Add(variable, 0);
  foreach (var node in treeModel.Root.EnumerateNodes())
    if (!node.IsLeaf)
      frequencies[node.SplitAttribute]++;
  return frequencies;
}
/// <summary>
/// Builds a histogram of the depths of all leaves in the tree and wraps it in a Result.
/// </summary>
public static Result CreateLeafDepthHistogram(RegressionNodeTreeModel treeModel) {
  var depths = new List<int>();
  GetLeafDepths(treeModel.Root, 0, depths);
  var row = new DataRow("Depths", "", depths.Select(x => (double)x));
  row.VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Histogram;
  var hist = new DataTable("LeafDepths");
  hist.Rows.Add(row);
  return new Result(hist.Name, hist);
}
/// <summary>
/// Converts every inner node whose pruning strength is below the given threshold into a leaf.
/// Resumable: on first entry (Code == 3) the node queue is (re)filled top-down.
/// </summary>
private static void Prune(RegressionNodeTreeModel tree, PruningState state, double pruningStrength) {
  if (state.Code == 3) {
    state.FillTopDown(tree);
    state.Code = 4;
  }
  while (state.nodeQueue.Count > 0) {
    var node = state.nodeQueue.Dequeue();
    // Keep leaves and nodes that are strong enough. The comparison is kept in this exact
    // form on purpose: pruningStrength <= NaN is false, so NaN-strength nodes are pruned.
    if (node.IsLeaf || pruningStrength <= node.PruningStrength) continue;
    node.ToLeaf();
  }
}
/// <summary>
/// Caps every node's pruning strength at its parent's strength so that pruning a parent
/// always implies pruning its children. Parents with NaN strength are ignored.
/// Resumable: on first entry (Code == 2) the node queue is (re)filled top-down.
/// </summary>
private static void UpdateThreshold(RegressionNodeTreeModel tree, PruningState state) {
  if (state.Code == 2) {
    state.FillTopDown(tree);
    state.Code = 3;
  }
  while (state.nodeQueue.Count > 0) {
    var node = state.nodeQueue.Dequeue();
    if (node.IsLeaf) continue;
    var parent = node.Parent;
    if (parent == null || double.IsNaN(parent.PruningStrength)) continue;
    node.PruningStrength = Math.Min(node.PruningStrength, parent.PruningStrength);
  }
}
/// <summary>
/// Computes a pruning threshold for every inner node from its pruning-set size, model and
/// subtree complexities, model error and the accumulated error of its subtree.
/// Resumable: on first entry (Code == 1) the node queue is (re)filled bottom-up.
/// </summary>
private static void AssignPruningThresholds(RegressionNodeTreeModel tree, PruningState state, double pruningDecay) {
  if (state.Code == 1) {
    state.FillBottomUp(tree);
    state.Code = 2;
  }
  while (state.nodeQueue.Count > 0) {
    var node = state.nodeQueue.Dequeue();
    if (node.IsLeaf) continue;
    var subtreeError = SubtreeError(node, state.pruningSizes, state.modelErrors);
    node.PruningStrength = PruningThreshold(state.pruningSizes[node], state.modelComplexities[node], state.nodeComplexities[node], state.modelErrors[node], subtreeError, pruningDecay);
  }
}
/// <summary>
/// Builds a pruning model for every queued node.
/// Resumable: on first entry (Code == 0) the queues are filled bottom-up; an entry is only
/// dequeued after its model has been built, so a cancellation mid-build leaves the current
/// node in the queues and the state consistent for a later resume.
/// </summary>
private static void InstallModels(RegressionNodeTreeModel tree, PruningState state, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, ILeafModel leaf, RegressionTreeParameters regressionTreeParams, CancellationToken cancellationToken) {
  if (state.Code == 0) {
    state.FillBottomUp(tree, trainingRows, pruningRows, regressionTreeParams.Data);
    state.Code = 1;
  }
  while (state.nodeQueue.Count > 0) {
    cancellationToken.ThrowIfCancellationRequested();
    BuildPruningModel(state.nodeQueue.Peek(), leaf, state.trainingRowsQueue.Peek(), state.pruningRowsQueue.Peek(), state, regressionTreeParams, cancellationToken);
    // Remove the finished entry from all three queues only after the model exists.
    state.nodeQueue.Dequeue();
    state.trainingRowsQueue.Dequeue();
    state.pruningRowsQueue.Dequeue();
  }
}
/// <summary>
/// Fits a leaf model for every leaf of the tree using that leaf's training rows.
/// Resumable: on first entry (Code == 0) the leaf queues are filled; a leaf is only
/// dequeued after its model has been installed, keeping the state consistent if
/// BuildModel is cancelled mid-way.
/// </summary>
public void Build(RegressionNodeTreeModel tree, IReadOnlyList<int> trainingRows, IScope stateScope, CancellationToken cancellationToken) {
  var parameters = (RegressionTreeParameters)stateScope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
  var state = (LeafBuildingState)stateScope.Variables[LeafBuildingStateVariableName].Value;

  if (state.Code == 0) {
    state.FillLeafs(tree, trainingRows, parameters.Data);
    state.Code = 1;
  }
  while (state.nodeQueue.Count > 0) {
    var leaf = state.nodeQueue.Peek();
    var rows = state.trainingRowsQueue.Peek();
    int numParams;
    leaf.SetLeafModel(BuildModel(rows, parameters, cancellationToken, out numParams));
    state.nodeQueue.Dequeue();
    state.trainingRowsQueue.Dequeue();
  }
}
/// <summary>
/// Fills nodeQueue with all nodes of the tree in breadth-first (top-down) order,
/// starting at the root.
/// </summary>
public void FillTopDown(RegressionNodeTreeModel tree) {
  var pending = new Queue<RegressionNodeModel>();
  nodeQueue.Clear();
  pending.Enqueue(tree.Root);
  nodeQueue.Enqueue(tree.Root);
  while (pending.Count > 0) {
    var node = pending.Dequeue();
    if (node.IsLeaf) continue;
    pending.Enqueue(node.Left);
    pending.Enqueue(node.Right);
    nodeQueue.Enqueue(node.Left);
    nodeQueue.Enqueue(node.Right);
  }
}
/// <summary>
/// Publishes an analysis of the tree's nodes: a symbolic-expression mirror of the tree and,
/// for trees with at most 200 node models, a scope holding a regression solution per model.
/// </summary>
public static void AnalyzeNodes(RegressionNodeTreeModel tree, ResultCollection results, IRegressionProblemData pd) {
  var nodeModels = new Dictionary<int, RegressionNodeModel>();
  var trainingLeafRows = new Dictionary<int, IReadOnlyList<int>>();
  var testLeafRows = new Dictionary<int, IReadOnlyList<int>>();
  var modelNumber = new IntValue(1);
  var symtree = new SymbolicExpressionTree(MirrorTree(tree.Root, nodeModels, trainingLeafRows, testLeafRows, modelNumber, pd.Dataset, pd.TrainingIndices.ToList(), pd.TestIndices.ToList()));
  results.AddOrUpdateResult("DecisionTree", symtree);

  // Publishing a solution per node would be overwhelming for very large trees.
  if (nodeModels.Count > 200) return;

  var models = new Scope("NodeModels");
  results.AddOrUpdateResult("NodeModels", models);
  foreach (var number in nodeModels.Keys.OrderBy(x => x)) {
    var solution = nodeModels[number].CreateRegressionSolution(Subselect(pd, trainingLeafRows[number], testLeafRows[number]));
    models.Variables.Add(new Variable("Model " + number, solution));
  }
}
/// <summary>
/// Grows the tree breadth-first: for each queued node the splitter decides whether to
/// split its training rows; nodes without a valid split stay leaves.
/// Resumable via the persisted SplittingState.
/// </summary>
public void Split(RegressionNodeTreeModel tree, IReadOnlyList<int> trainingRows, IScope stateScope, CancellationToken cancellationToken) {
  var parameters = (RegressionTreeParameters)stateScope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
  var state = (SplittingState)stateScope.Variables[SplittingStateVariableName].Value;
  var inputVariables = parameters.AllowedInputVariables.ToArray();
  var target = parameters.TargetVariable;

  if (state.Code <= 0) {
    state.nodeQueue.Enqueue(tree.Root);
    state.trainingRowsQueue.Enqueue(trainingRows);
    state.Code = 1;
  }
  while (state.nodeQueue.Count > 0) {
    var node = state.nodeQueue.Dequeue();
    var rows = state.trainingRowsQueue.Dequeue();

    // Ask the splitter for an attribute/value on the dataset reduced to this node's rows.
    string attribute;
    double splitValue;
    var reduced = new RegressionProblemData(RegressionTreeUtilities.ReduceDataset(parameters.Data, rows, inputVariables, target), inputVariables, target);
    if (!DecideSplit(reduced, parameters.MinLeafSize, out attribute, out splitValue)) continue;

    IReadOnlyList<int> leftRows, rightRows;
    RegressionTreeUtilities.SplitRows(rows, parameters.Data, attribute, splitValue, out leftRows, out rightRows);
    node.Split(parameters, attribute, splitValue, rows.Count);
    state.nodeQueue.Enqueue(node.Left);
    state.nodeQueue.Enqueue(node.Right);
    state.trainingRowsQueue.Enqueue(leftRows);
    state.trainingRowsQueue.Enqueue(rightRows);
    cancellationToken.ThrowIfCancellationRequested();
  }
}
/// <summary>
/// Fills nodeQueue with all tree nodes in bottom-up order (reverse breadth-first),
/// so children are always dequeued before their parents.
/// </summary>
public void FillBottomUp(RegressionNodeTreeModel tree) {
  FillTopDown(tree);
  // Enumerating a Stack yields its elements in pop (i.e. reversed) order.
  nodeQueue = new Queue<RegressionNodeModel>(new Stack<RegressionNodeModel>(nodeQueue));
}
// Intentional no-op: this pruning strategy leaves the generated tree completely unchanged.
public void Prune(RegressionNodeTreeModel treeModel, IReadOnlyList <int> trainingRows, IReadOnlyList <int> pruningRows, IScope scope, CancellationToken cancellationToken) { }
// Cloning constructor (deep-copy pattern): delegates to the base cloning constructor and
// deep-clones the node hierarchy reachable from Root via the provided cloner.
protected RegressionNodeTreeModel(RegressionNodeTreeModel original, Cloner cloner) : base(original, cloner) { Root = cloner.Clone(original.Root); }
/// <summary>
/// Adds a "PruningSizes" scatter plot to the results showing how the tree size (number of
/// leaves) shrinks as the pruning strength increases.
/// </summary>
public static void PruningChart(RegressionNodeTreeModel tree, ComplexityPruning pruning, ResultCollection results) {
  // Count the leaves (initial tree size) and tally how many inner nodes carry each
  // distinct pruning strength.
  var nodes = new Queue<RegressionNodeModel>();
  nodes.Enqueue(tree.Root);
  var max = 0.0;
  var strengths = new SortedList<double, int>();
  while (nodes.Count > 0) {
    var n = nodes.Dequeue();
    if (n.IsLeaf) {
      max++;
      continue;
    }
    if (!strengths.ContainsKey(n.PruningStrength)) strengths.Add(n.PruningStrength, 0);
    strengths[n.PruningStrength]++;
    nodes.Enqueue(n.Left);
    nodes.Enqueue(n.Right);
  }
  if (strengths.Count == 0) return;

  var plot = new ScatterPlot("Pruned Sizes", "") {
    VisualProperties = {
      XAxisTitle = "Pruning Strength",
      YAxisTitle = "Tree Size",
      XAxisMinimumAuto = false,
      XAxisMinimumFixedValue = 0
    }
  };
  var row = new ScatterPlotDataRow("TreeSizes", "", new List<Point2D<double>>());
  row.Points.Add(new Point2D<double>(pruning.PruningStrength, max));

  // Filler dots keep the plotted curve visually continuous between distinct strengths.
  var fillerDots = new Queue<double>();
  var minX = pruning.PruningStrength;
  var maxX = strengths.Last().Key;
  var step = (maxX - minX) / 200;
  // BUGFIX: when maxX == minX the step is 0 and the original loop never terminated.
  // Only generate filler dots when the strength range is actually non-degenerate.
  if (step > 0) {
    for (var x = minX; x <= maxX; x += step) fillerDots.Enqueue(x);
  }

  foreach (var strength in strengths.Keys) {
    while (fillerDots.Count > 0 && strength > fillerDots.Peek()) {
      row.Points.Add(new Point2D<double>(fillerDots.Dequeue(), max));
    }
    max -= strengths[strength];
    row.Points.Add(new Point2D<double>(strength, max));
  }
  row.VisualProperties.PointSize = 6;
  plot.Rows.Add(row);
  results.AddOrUpdateResult("PruningSizes", plot);
}