/// <summary>
/// Recursively grows a decision tree below <paramref name="root"/> from the examples in <paramref name="S"/>.
/// </summary>
/// <remarks>
/// The node is turned into a leaf when every example has the same target value, when
/// <paramref name="attributeList"/> is empty, or when the chi-squared test rejects the best split.
/// Otherwise the node is split on the attribute with the highest gain ratio among the top 20% of
/// attributes ranked by gain, and one child is created per value of that attribute.
/// NOTE(review): the first instance of <paramref name="S"/> is read before any emptiness check, so
/// callers presumably never pass an empty example set - confirm.
/// </remarks>
/// <param name="root">Node to populate (leaf flag and target value, or split attribute and children).</param>
/// <param name="S">Examples that reached this node.</param>
/// <param name="targetAttributeIndex">Index of the target (class) attribute within <paramref name="S"/>.</param>
/// <param name="attributeList">
/// Indices of the attributes still available for splitting. The list is not modified; children receive a copy
/// without the attribute chosen here.
/// </param>
/// <param name="confidenceLevel">Confidence level forwarded to <c>ChiSquare.ChiSquaredTest</c>.</param>
/// <param name="maxDepth">
/// When greater than zero, children whose depth exceeds this value are made leaves instead of being split further.
/// Zero (the default) disables the limit.
/// </param>
/// <exception cref="Exception">An example has a NaN target value.</exception>
public void TrainRecursive(ID3Node root, Instances S, int targetAttributeIndex, List<int> attributeList, double confidenceLevel, int maxDepth = 0)
{
    // Bucket the examples by target value so we can later find the most common one.
    int targetValueCount = S.attribute(targetAttributeIndex).numValues();
    Dictionary<int, Instances> targetValueCounts = new Dictionary<int, Instances>();
    for (int i = 0; i < targetValueCount; i++)
    {
        targetValueCounts.Add(i, new Instances(S, 0, 0));
    }

    // Fill the buckets and track whether every example shares the same target value.
    int countOfS = S.numInstances();
    int firstTargetValue = (int)S.instance(0).value(targetAttributeIndex);
    bool allTargetValuesAreEqual = true;
    for (int i = 0; i < countOfS; i++)
    {
        double rawTargetValue = S.instance(i).value(targetAttributeIndex);
        if (Double.IsNaN(rawTargetValue))
        {
            // For target values, this shouldn't happen
            throw new Exception(String.Format("Value at targetAttributeIndex {0} is NaN", targetAttributeIndex));
        }

        int value = (int)rawTargetValue;
        targetValueCounts[value].add(S.instance(i));
        if (firstTargetValue != value)
        {
            allTargetValuesAreEqual = false;
        }
    }

    // A pure node becomes a leaf carrying the shared target value.
    if (allTargetValuesAreEqual)
    {
        root.IsLeaf = true;
        root.TargetValue = firstTargetValue;
        Log.LogInfo("All Targets Equal. Node with split {0}, value {1}, leaf {2}, weight {3}", root.SplitAttributeIndex, root.TargetValue, root.IsLeaf, root.Weight);
        return;
    }

    // Find the most common target value (ties keep the lowest value index).
    int mostCommonTargetValue = 0;
    for (int i = 0; i < targetValueCounts.Count; i++)
    {
        if (targetValueCounts[i].numInstances() > targetValueCounts[mostCommonTargetValue].numInstances())
        {
            mostCommonTargetValue = i;
        }
    }

    // Nothing left to split on: fall back to the most common target value.
    if (attributeList.Count == 0)
    {
        root.IsLeaf = true;
        root.TargetValue = mostCommonTargetValue;
        Log.LogInfo("Attribute List Empty. Node with split {0}, value {1}, leaf {2}, weight {3}", root.SplitAttributeIndex, root.TargetValue, root.IsLeaf, root.Weight);
        return;
    }

    // Rank the remaining attributes by gain. The list is keyed by gain and stores the ATTRIBUTE INDEX
    // (the element of attributeList, not its position), so the ranking stays valid in recursive calls
    // where attributes have already been removed from the list.
    double gainSum = 0;
    SortedList<double, int> sortedGainList = new SortedList<double, int>();
    foreach (int attributeIndex in attributeList)
    {
        double gain = this.CalculateGain(S, attributeIndex, targetAttributeIndex);
        gainSum += gain;

        // SortedList keys must be unique, but two attributes can have the same gain. On a tie keep the
        // attribute with the higher gain ratio; the final pick is by gain ratio anyway, so nothing is lost.
        int tiedAttributeIndex;
        if (sortedGainList.TryGetValue(gain, out tiedAttributeIndex))
        {
            double oldGainRatio = this.CalculateGainRatio(S, tiedAttributeIndex, targetAttributeIndex);
            double newGainRatio = this.CalculateGainRatio(S, attributeIndex, targetAttributeIndex);
            if (newGainRatio > oldGainRatio)
            {
                sortedGainList[gain] = attributeIndex;
            }
        }
        else
        {
            sortedGainList.Add(gain, attributeIndex);
        }
    }

    // The list is sorted ascending, so the best gain sits at the last position.
    int lastGainPosition = sortedGainList.Count - 1;
    double maxGain = sortedGainList.Keys[lastGainPosition];
    int maxGainAttribute = sortedGainList.Values[lastGainPosition];
    double averageGain = gainSum / attributeList.Count;

    // Among the top N% of attributes by gain, choose the one with the highest gain ratio.
    // If no candidate has a positive gain ratio, fall back to the attribute with the largest gain.
    double maxGainRatio = 0;
    int maxGainRatioAttribute = maxGainAttribute;
    const double NPercent = 0.2;
    int topNPercent = (int)Math.Ceiling(NPercent * sortedGainList.Count);
    for (int i = 0; i < topNPercent; i++)
    {
        // Walk from the highest gain downwards.
        int candidateAttribute = sortedGainList.Values[lastGainPosition - i];
        double gainRatio = this.CalculateGainRatio(S, candidateAttribute, targetAttributeIndex);
        if (gainRatio > maxGainRatio)
        {
            maxGainRatio = gainRatio;
            maxGainRatioAttribute = candidateAttribute;
        }
    }

    // Now we know which attribute to split on
    Log.LogGain("MaxGainRatio {0} from attrib {1}. Max Gain {2} from attrib {3}. Avg Gain {4}.", maxGainRatio, maxGainRatioAttribute, maxGain, maxGainAttribute, averageGain);

    // Stop splitting when the chi-squared test rejects the chosen attribute.
    if (!ChiSquare.ChiSquaredTest(confidenceLevel, S, maxGainRatioAttribute, targetAttributeIndex))
    {
        root.IsLeaf = true;
        root.TargetValue = mostCommonTargetValue;
        Log.LogInfo("ChiSquared stop split. Node with split {0}, value {1}, leaf {2}, weight {3}", root.SplitAttributeIndex, root.TargetValue, root.IsLeaf, root.Weight);
        return;
    }

    // We are going to split. Children get a copy of the attribute list without the chosen attribute
    // (removed by value, because the list holds attribute indices rather than positions).
    root.SplitAttributeIndex = maxGainRatioAttribute;
    List<int> newAttributeList = new List<int>(attributeList);
    newAttributeList.Remove(maxGainRatioAttribute);

    // Partition the examples by their value of the split attribute.
    int splitValueCount = S.attribute(maxGainRatioAttribute).numValues();
    Dictionary<int, Instances> examplesVi = new Dictionary<int, Instances>();
    for (int i = 0; i < splitValueCount; i++)
    {
        examplesVi.Add(i, new Instances(S, 0, 0));
    }

    // Examples with a NaN value for the split attribute are skipped and do not count towards the weights.
    int totalExamplesVi = 0;
    for (int i = 0; i < countOfS; i++)
    {
        double rawSplitValue = S.instance(i).value(maxGainRatioAttribute);
        if (Double.IsNaN(rawSplitValue))
        {
            Log.LogVerbose("IsNaN encountered for instance {0} of maxGainAttribute {1}", i, maxGainRatioAttribute);
            continue;
        }

        examplesVi[(int)rawSplitValue].add(S.instance(i));
        totalExamplesVi++;
    }

    // Create one child per value of the split attribute.
    for (int i = 0; i < splitValueCount; i++)
    {
        ID3Node newChild = new ID3Node();
        newChild.Depth = root.Depth + 1;
        root.ChildNodes.Add(newChild);

        if (examplesVi[i].numInstances() == 0)
        {
            // No examples carry this value: predict the parent's most common target value.
            newChild.IsLeaf = true;
            newChild.TargetValue = mostCommonTargetValue;
            Log.LogInfo("No instances to split on. Create new leaf child from parent split {0}, new value {1}", root.SplitAttributeIndex, newChild.TargetValue, root.IsLeaf, root.Weight);
        }
        else if (maxDepth > 0 && newChild.Depth > maxDepth)
        {
            // Depth limit reached: stop here with the parent's most common target value.
            newChild.IsLeaf = true;
            newChild.TargetValue = mostCommonTargetValue;
            Log.LogInfo("Hit max depth of {0}. Create new leaf child from parent split {1}, new value {2}", maxDepth, root.SplitAttributeIndex, newChild.TargetValue, root.IsLeaf, root.Weight);
        }
        else
        {
            Log.LogInfo("Splitting from node with split {0}, value {1}, leaf {2}, weight {3}", root.SplitAttributeIndex, root.TargetValue, root.IsLeaf, root.Weight);
            newChild.IsLeaf = false;

            // NOTE(review): this stores the parent's attribute VALUE index in SplitAttributeIndex; the
            // recursive call overwrites it only if the child itself splits - confirm this is intended.
            newChild.SplitAttributeIndex = i;

            // Fraction of the non-NaN examples that follow this branch.
            newChild.Weight = examplesVi[i].numInstances() / (double)totalExamplesVi;
            this.TrainRecursive(newChild, examplesVi[i], targetAttributeIndex, newAttributeList, confidenceLevel, maxDepth);
        }
    }
}