//////////////////////////////////////////////////////
// Function:
//      Classify
//
// Purpose:
//      Returns the classification of the given data.
//
// Note:
//      Data is given as an example object, but the
//      classification calculated is not placed
//      into the query object.
//////////////////////////////////////////////////////
public override string Classify(CClassificationData inQuery)
{
    var szOutClassification = "UNKNOWN";

    // The copy constructor performs a shallow copy, so this simply
    // starts the walk at the shared root of the trained tree
    CDTNode pNodeToInspect = new CDTNode(m_decisionTreeRootNode);

    while (pNodeToInspect != null)
    {
        int iChildCount = pNodeToInspect.GetChildren().Count();

        if (iChildCount == 0)
        {
            // Leaf node: report its stored classification
            szOutClassification = m_rwClasses[pNodeToInspect.GetClassIdentifier()];
            break;
        }
        else
        {
            // Interior node: descend into the child that matches the
            // query's value for the attribute this node splits on
            int iAttributeID = pNodeToInspect.GetAttributeIdentifier();
            int iValueID = inQuery.GetValueIdentifier(iAttributeID);
            pNodeToInspect = pNodeToInspect.GetChildren()[iValueID];
        }
    }

    return szOutClassification;
}
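//////////////////////////////////////////////////////
// Sketch (not part of the original listing):
//      TryClassify
//
// Purpose:
//      Classify() indexes the children list directly with the
//      query's value identifier, so a query holding an
//      out-of-range value would throw. This hypothetical
//      bounds-checked variant, built only from the calls used
//      above, falls back to "UNKNOWN" instead.
//////////////////////////////////////////////////////
public string TryClassify(CClassificationData inQuery)
{
    CDTNode pNode = m_decisionTreeRootNode;

    while (pNode != null && pNode.GetChildren().Count() > 0)
    {
        int iValueID = inQuery.GetValueIdentifier(pNode.GetAttributeIdentifier());
        if (iValueID < 0 || iValueID >= pNode.GetChildren().Count())
        {
            return "UNKNOWN"; // malformed query for this attribute
        }
        pNode = pNode.GetChildren()[iValueID];
    }

    return (pNode != null) ? m_rwClasses[pNode.GetClassIdentifier()] : "UNKNOWN";
}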
//////////////////////////////////////////////////////
// Function:
//      CDTNode copy constructor
//
// Purpose:
//      Copies the state of another node. Note that this is a
//      shallow copy: the list members and the owning tree are
//      shared by reference, not duplicated.
//////////////////////////////////////////////////////
public CDTNode(CDTNode inOtherNode)
{
    m_pDecisionTree = inOtherNode.m_pDecisionTree;
    m_fEntropy = inOtherNode.m_fEntropy;
    m_iClass = inOtherNode.m_iClass;
    m_bIsLeaf = inOtherNode.m_bIsLeaf;
    m_rwRemainingAttributeIDs = inOtherNode.m_rwRemainingAttributeIDs;
    m_rwExampleIDs = inOtherNode.m_rwExampleIDs;
    m_rwChildren = inOtherNode.m_rwChildren;
    m_rwInformationGain = inOtherNode.m_rwInformationGain;
    m_iAttributeID = inOtherNode.m_iAttributeID;
    m_rwClassCount = inOtherNode.m_rwClassCount;
}
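//////////////////////////////////////////////////////
// Sketch (not part of the original listing):
//      DeepCopy
//
// Purpose:
//      Because the copy constructor above is shallow, a copy and
//      its source mutate the same lists and children. If an
//      independent clone were ever needed, it might look like this
//      hypothetical helper; only the owning tree reference stays
//      shared.
//////////////////////////////////////////////////////
public CDTNode DeepCopy()
{
    CDTNode pCopy = new CDTNode();

    pCopy.m_pDecisionTree = m_pDecisionTree; // intentionally shared
    pCopy.m_fEntropy = m_fEntropy;
    pCopy.m_iClass = m_iClass;
    pCopy.m_bIsLeaf = m_bIsLeaf;
    pCopy.m_iAttributeID = m_iAttributeID;

    // Duplicate the list members so the clone can be mutated freely
    pCopy.m_rwRemainingAttributeIDs = new List<int>(m_rwRemainingAttributeIDs);
    pCopy.m_rwExampleIDs = new List<int>(m_rwExampleIDs);
    pCopy.m_rwInformationGain = new List<float>(m_rwInformationGain);
    pCopy.m_rwClassCount = new List<int>(m_rwClassCount);

    // Recursively clone the subtree
    pCopy.m_rwChildren = new List<CDTNode>();
    foreach (CDTNode pChild in m_rwChildren)
    {
        pCopy.m_rwChildren.Add(pChild.DeepCopy());
    }

    return pCopy;
}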
public CDecisionTree()
{
    m_decisionTreeRootNode = new CDTNode();
}
//////////////////////////////////////////////////////
// Mutator:
//      CDTNode::Train
//
// Purpose:
//      Calculates the information gain for the remaining
//      attributes and the entropy of the remaining examples,
//      and creates splits/sub-nodes
//////////////////////////////////////////////////////
public void Train()
{
    // Note:
    //      The awkward situation where no more attributes are left
    //      to split on, but multiple classes still exist in the
    //      examples, is handled below by falling back to the most
    //      common class.

    // If we don't have any examples to train on, then
    // this node HAS to be a leaf. We will keep the classification
    // given to us, which was the most common class from the parent node
    if (m_rwExampleIDs.Count() == 0)
    {
        m_bIsLeaf = true;
        return;
    }

    m_bIsLeaf = false;
    m_fEntropy = GetEntropy();

    //assert( m_fEntropy >= 0.0f );
    //assert( m_fEntropy <= ( (float)m_pDecisionTree.GetNumClasses() - 1.0f ) );

    // If we don't have any entropy, then go ahead and treat this like a leaf
    // since we don't have any reason to calculate any more
    if (m_fEntropy == 0.0f)
    {
        m_bIsLeaf = true;
        m_iClass = m_pDecisionTree.GetExample(m_rwExampleIDs[0]).GetClassIdentifier();
        return;
    }

    // If no attributes are left to split on, this node has to be a
    // leaf even though the examples are still mixed; classify it as
    // the most common class among them. (Without this guard the
    // indexing below would fail on an empty attribute list.)
    if (!m_rwRemainingAttributeIDs.Any())
    {
        m_bIsLeaf = true;

        m_rwClassCount = new List<int>();
        for (int iClassIndex = 0; iClassIndex < m_pDecisionTree.GetNumClasses(); ++iClassIndex)
        {
            m_rwClassCount.Add(0);
        }

        int iTopClassCount = 0;
        foreach (int iLeafExampleID in m_rwExampleIDs)
        {
            int iLeafClassID = m_pDecisionTree.GetExample(iLeafExampleID).GetClassIdentifier();
            ++m_rwClassCount[iLeafClassID];
            if (m_rwClassCount[iLeafClassID] > iTopClassCount)
            {
                iTopClassCount = m_rwClassCount[iLeafClassID];
                m_iClass = iLeafClassID;
            }
        }

        return;
    }

    // Calculate the gains for the attributes
    int iAttributeID = 0;
    int i = 0;
    int iBestClassCount = 0;                                 // The number of times the most common class appears
    int iBestAttributeID = m_rwRemainingAttributeIDs[0];     // Treat the first attribute initially as the best
    float fBestGain = GetInformationGain(iBestAttributeID);  // Save the information gain from the first attribute

    m_rwInformationGain = new List<float>();
    for (i = 0; i < m_pDecisionTree.GetNumAttributes(); ++i)
    {
        m_rwInformationGain.Add(0.0f);
    }

    // Put the gain we just calculated into the proper place
    m_rwInformationGain[iBestAttributeID] = fBestGain;

    // Find the information gain for each attribute and store it,
    // while keeping track of the best gain and the attribute
    // that it goes with.
    int iRemainingAttributeCount = m_rwRemainingAttributeIDs.Count();
    for (i = 1; i < iRemainingAttributeCount; ++i)
    {
        iAttributeID = m_rwRemainingAttributeIDs[i];
        m_rwInformationGain[iAttributeID] = GetInformationGain(iAttributeID);

        // If we find a better gain, store it
        // and remember which attribute it came from
        if (m_rwInformationGain[iAttributeID] > fBestGain)
        {
            fBestGain = m_rwInformationGain[iAttributeID];
            iBestAttributeID = iAttributeID;
        }
    }

    // Set the class counts to 0
    m_rwClassCount = new List<int>();
    int iTreeClassCount = m_pDecisionTree.GetNumClasses();
    for (i = 0; i < iTreeClassCount; ++i)
    {
        m_rwClassCount.Add(0);
    }

    // Store the attribute this node splits on
    m_iAttributeID = iBestAttributeID;

    // Modify the list of attributes to give to the new children
    List<int> rwNewAttributeList = m_rwRemainingAttributeIDs;

    // If we don't have any attributes left to split on, then this
    // node has to be a leaf. (The guard at the top of Train() already
    // returns in that case, so this is kept only as a safety net.)
    m_bIsLeaf = !m_rwRemainingAttributeIDs.Any();

    // If we are not a leaf, then generate the children nodes
    // and give them a list of remaining attributes they can split on
    if (!m_bIsLeaf)
    {
        rwNewAttributeList = (from attributeId in rwNewAttributeList
                              where attributeId != m_iAttributeID
                              select attributeId).ToList();

        // Create the sub-nodes
        CAttribute attributeToSplitOn = m_pDecisionTree.GetAttribute(m_iAttributeID);
        int iAttributeValueCount = attributeToSplitOn.GetNumValues();
        for (i = 0; i < iAttributeValueCount; ++i)
        {
            CDTNode pNewNode = new CDTNode(m_pDecisionTree, rwNewAttributeList);
            m_rwChildren.Add(pNewNode);
        }
    }

    // Now that we know the best attribute to split/branch on,
    // send the examples to their appropriate child nodes
    // while finding the most common classification
    // for this node's examples
    int iExampleCount = m_rwExampleIDs.Count();
    for (i = 0; i < iExampleCount; ++i)
    {
        int iExampleID = m_rwExampleIDs[i];
        CExample curExample = m_pDecisionTree.GetExample(iExampleID);
        int iExampleClassID = curExample.GetClassIdentifier();

        // Keep a tally of the occurrence of each classification
        // and which class is the most common
        ++m_rwClassCount[iExampleClassID];
        if (iBestClassCount < m_rwClassCount[iExampleClassID])
        {
            iBestClassCount = m_rwClassCount[iExampleClassID];
            m_iClass = iExampleClassID;
        }

        // Match the example with the subtree that matches the attribute's value
        if (!m_bIsLeaf)
        {
            int iValueID = curExample.GetValueIdentifier(m_iAttributeID);

            // Add the example into the correct subtree and remove it from this level
            m_rwChildren[iValueID].GetExampleIdentifierList().Add(iExampleID);
        }
    }

    // If we are a leaf, then we don't have any children
    // nodes to calculate, so just return
    if (m_bIsLeaf)
    {
        return;
    }

    // No more examples should be associated with this node.
    m_rwExampleIDs.Clear();

    // Calculate all the subtrees for this node
    int iNumValues = m_pDecisionTree.GetAttribute(m_iAttributeID).GetNumValues();
    for (i = 0; i < iNumValues; ++i)
    {
        // Seed each child with this node's most common class so that
        // a child that ends up with no examples inherits a sensible default
        m_rwChildren[i].m_iClass = m_iClass;
        m_rwChildren[i].Train();
    }
}
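//////////////////////////////////////////////////////
// Sketch (not part of the original listing):
//      GetEntropyOfExamples / GetEntropy / GetInformationGain
//
// Purpose:
//      Train() relies on GetEntropy() and GetInformationGain(),
//      which are not shown above. The bodies below are a minimal
//      sketch of the standard ID3 measures they presumably
//      implement; the GetEntropyOfExamples() helper is hypothetical.
//
//          Entropy(S) = -sum over classes c of p(c) * log2( p(c) )
//          Gain(S, A) = Entropy(S)
//                       - sum over values v of ( |S_v| / |S| ) * Entropy(S_v)
//////////////////////////////////////////////////////

// Hypothetical helper: entropy of an arbitrary set of example identifiers
private float GetEntropyOfExamples(List<int> inExampleIDs)
{
    // Tally how often each class appears among the given examples
    int[] rwCounts = new int[m_pDecisionTree.GetNumClasses()];
    foreach (int iExampleID in inExampleIDs)
    {
        ++rwCounts[m_pDecisionTree.GetExample(iExampleID).GetClassIdentifier()];
    }

    float fEntropy = 0.0f;
    float fTotal = (float)inExampleIDs.Count;
    foreach (int iCount in rwCounts)
    {
        if (iCount > 0)
        {
            float fProportion = (float)iCount / fTotal;
            fEntropy -= fProportion * (float)Math.Log(fProportion, 2.0);
        }
    }
    return fEntropy;
}

private float GetEntropy()
{
    return GetEntropyOfExamples(m_rwExampleIDs);
}

private float GetInformationGain(int inAttributeID)
{
    // Partition this node's examples by the value they take for the attribute
    int iValueCount = m_pDecisionTree.GetAttribute(inAttributeID).GetNumValues();
    List<List<int>> rwPartitions = new List<List<int>>();
    for (int i = 0; i < iValueCount; ++i)
    {
        rwPartitions.Add(new List<int>());
    }
    foreach (int iExampleID in m_rwExampleIDs)
    {
        int iValueID = m_pDecisionTree.GetExample(iExampleID).GetValueIdentifier(inAttributeID);
        rwPartitions[iValueID].Add(iExampleID);
    }

    // Gain = entropy before the split minus the weighted entropy after it
    float fGain = GetEntropy();
    float fTotal = (float)m_rwExampleIDs.Count;
    foreach (List<int> rwPartition in rwPartitions)
    {
        if (rwPartition.Count > 0)
        {
            fGain -= ((float)rwPartition.Count / fTotal) * GetEntropyOfExamples(rwPartition);
        }
    }
    return fGain;
}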
//////////////////////////////////////////////////////
// Function:
//      LoadFromFile
//
// Purpose:
//      Loads the node from the given IO file.
//      Returns false in case of error.
//////////////////////////////////////////////////////
public bool LoadFromFile(StreamReader pInInputFile)
{
    m_iClass = 0;
    m_iAttributeID = 0;
    m_fEntropy = 0.0f;
    m_rwInformationGain.Clear();
    m_rwExampleIDs.Clear();
    m_rwChildren.Clear();

    // Guard against a truncated file; ReadLine() returns null at EOF
    string inputLine = pInInputFile.ReadLine();
    if (inputLine == null)
    {
        return false;
    }

    string[] tokens = inputLine.Trim().Split(new[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
    if (tokens.Length != 2)
    {
        return false;
    }

    string szNodeType = tokens[0].Trim();
    string szAttributeName = tokens[1].Trim();

    if (string.Compare(m_szSplitKeyword, szNodeType, StringComparison.InvariantCultureIgnoreCase) == 0)
    {
        // Interior node that splits on an attribute
        int iAttributeIdentifier = 0;
        if (m_pDecisionTree.GetAttributeIdentifier(szAttributeName, ref iAttributeIdentifier))
        {
            int iNumChildren = m_pDecisionTree.GetAttribute(iAttributeIdentifier).GetNumValues();
            m_iAttributeID = iAttributeIdentifier;

            // Create the child nodes first, since we could read the
            // tree in an arbitrary order.
            for (int iChildIndex = 0; iChildIndex < iNumChildren; ++iChildIndex)
            {
                var pNewNode = new CDTNode();
                pNewNode.SetDecisionTree(m_pDecisionTree);
                m_rwChildren.Add(pNewNode);
            }
            m_bIsLeaf = false;
        }
        else
        {
            return false;
        }

        // Each subtree is stored as its attribute value on one line,
        // followed by the serialized child node. The value just read
        // selects which child to load, so the file may list the
        // values in any order.
        for (int iChildIndex = 0; iChildIndex < m_rwChildren.Count; ++iChildIndex)
        {
            string szAttributeValue = pInInputFile.ReadLine();
            if (szAttributeValue == null)
            {
                return false;
            }

            int iValueIdentifier = 0;
            if (m_pDecisionTree.GetAttribute(m_iAttributeID).GetValueIdentifier(szAttributeValue.Trim(), ref iValueIdentifier))
            {
                if (!m_rwChildren[iValueIdentifier].LoadFromFile(pInInputFile))
                {
                    return false;
                }
            }
            else
            {
                return false;
            }
        }
    }
    else if (string.Compare(m_szOutcomeKeyword, szNodeType, StringComparison.InvariantCultureIgnoreCase) == 0)
    {
        // Leaf in the tree
        int iClassIdentifier = 0;
        m_bIsLeaf = true;
        if (m_pDecisionTree.GetClassIdentifier(szAttributeName, ref iClassIdentifier))
        {
            m_iClass = iClassIdentifier;
        }
        else
        {
            return false;
        }
    }
    else
    {
        // Unexpected input
        return false;
    }

    return true;
}
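//////////////////////////////////////////////////////
// Sketch (not part of the original listing):
//      SaveToFile
//
// Purpose:
//      A writer matching the format LoadFromFile expects: one
//      "<split-keyword> <attribute>" or "<outcome-keyword> <class>"
//      header line per node and, for a split, each attribute value
//      on its own line followed by that child's subtree.
//      GetName(), GetValueName(), and GetClassName() are assumed
//      accessors that reverse the identifier lookups used while
//      loading; they are not shown in the original listing. Note
//      that the loader splits the header on whitespace into exactly
//      two tokens, so names must be single words.
//////////////////////////////////////////////////////
public bool SaveToFile(StreamWriter pInOutputFile)
{
    if (m_bIsLeaf)
    {
        // Leaf: record the classification and stop recursing
        pInOutputFile.WriteLine(m_szOutcomeKeyword + " " + m_pDecisionTree.GetClassName(m_iClass));
        return true;
    }

    // Split node: record the attribute, then each value with its subtree
    CAttribute attribute = m_pDecisionTree.GetAttribute(m_iAttributeID);
    pInOutputFile.WriteLine(m_szSplitKeyword + " " + attribute.GetName());

    for (int iValueID = 0; iValueID < m_rwChildren.Count; ++iValueID)
    {
        pInOutputFile.WriteLine(attribute.GetValueName(iValueID));
        if (!m_rwChildren[iValueID].SaveToFile(pInOutputFile))
        {
            return false;
        }
    }

    return true;
}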