Exemplo n.º 1
0
        /// <summary>
        /// Predicts the class of <paramref name="instance"/> by descending from <paramref name="tree"/> to a leaf,
        /// following the branch that matches the instance's value for each node's split attribute.
        /// </summary>
        /// <param name="instance">The instance's attribute values, encoded as integers and indexed by attribute.</param>
        /// <param name="tree">Node to start the descent from.</param>
        /// <returns>The class stored in the leaf that is reached.</returns>
        /// <exception cref="InvalidOperationException">An internal node has no child branch to fall back to.</exception>
        private int GetClass(int[] instance, Id3Node tree)
        {
            Id3Node node = tree;
            while (!node.IsLeaf)
            {
                int     valueIndex = instance[node.AttributeIndex];
                Id3Node next;
                if (!node.Children.TryGetValue(valueIndex, out next))
                {
                    // No branch for this value: follow the branch that received the most training instances.
                    // The per-branch count is taken from this node's own ValueClassCounts, because leaf children
                    // are created with only their Class set and may carry no counts of their own.
                    next = null;
                    int maxCount = int.MinValue;
                    foreach (var kvp in node.Children)
                    {
                        int count = node.ValueClassCounts.ContainsKey(kvp.Key)
                            ? node.ValueClassCounts[kvp.Key].Values.Sum()
                            : 0;
                        if (count > maxCount)
                        {
                            maxCount = count;
                            next     = kvp.Value;
                        }
                    }

                    if (next == null)
                    {
                        throw new InvalidOperationException("Internal tree node has no children to fall back to.");
                    }
                }

                node = next;
            }

            return node.Class;
        }
Exemplo n.º 2
0
        /// <summary>
        /// Entry point for tree construction with a maximum depth: starts with no attribute flagged as visited
        /// and delegates to the five-argument overload.
        /// </summary>
        /// <remarks>The number of visited-attribute flags is taken from the first instance, so
        /// <paramref name="instances"/> must not be empty.</remarks>
        public static Id3Node BuildTree(List <int[]> instances, int classAttributeIndex, double confidence, int maxDepth)
        {
            int attributeCount = instances.First().Length;
            return BuildTree(instances, classAttributeIndex, confidence, new bool[attributeCount], maxDepth);
        }
Exemplo n.º 3
0
        /// <summary>
        /// Creates a classifier, keeps the training instances and confidence, and builds its tree immediately.
        /// </summary>
        /// <param name="instances">Training instances.</param>
        /// <param name="classIndex">Index of the class attribute within each instance.</param>
        /// <param name="confidence">Confidence threshold forwarded to tree construction.</param>
        public Id3Classifier(List <int[]> instances, int classIndex, double confidence)
        {
            this.Instances  = instances;
            this.Confidence = confidence;
            this.Tree       = Id3Node.BuildTree(instances, classIndex, confidence);
        }
Exemplo n.º 4
0
        /// <summary>
        /// Builds a tree based on the instances received.
        /// </summary>
        /// <param name="instances">The list of instances to train on. Must contain at least one instance.</param>
        /// <param name="classAttributeIndex">The index of the class attribute.</param>
        /// <param name="confidence">The confidence threshold for split stop.</param>
        /// <param name="visitedAttributes">Flag array to keep track of which attributes have already been visited so far.
        /// It is not modified; subtrees receive their own copy.</param>
        /// <returns>The trained tree for the given list of instances and confidence threshold</returns>
        /// <exception cref="ArgumentNullException"><paramref name="instances"/> or <paramref name="visitedAttributes"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="instances"/> is empty.</exception>
        public static Id3Node BuildTree(List <int[]> instances, int classAttributeIndex, double confidence, bool[] visitedAttributes)
        {
            if (instances == null)
            {
                throw new ArgumentNullException(nameof(instances));
            }
            if (visitedAttributes == null)
            {
                throw new ArgumentNullException(nameof(visitedAttributes));
            }
            if (instances.Count == 0)
            {
                throw new ArgumentException("At least one instance is required to build a tree.", nameof(instances));
            }

            // Leaf: all instances are of the same class.
            int classType = instances[0][classAttributeIndex];
            if (instances.All(i => i[classAttributeIndex] == classType))
            {
                return(new Id3Node
                {
                    Class = classType
                });
            }

            // Leaf: every attribute except the class attribute has already been used on this path,
            // so return the class with the highest occurrence.
            bool allAttributesVisited = Enumerable.Range(0, visitedAttributes.Length)
                .All(i => i == classAttributeIndex || visitedAttributes[i]);
            if (allAttributesVisited)
            {
                return(GetClassNodeForMax(instances, classAttributeIndex));
            }

            Id3Node bestNode = BestAttributeNode(instances, visitedAttributes, classAttributeIndex);

            // Leaf: the best split does not pass the chi-squared test at the requested confidence.
            if (!IsValidChiSquared(bestNode, confidence))
            {
                return(GetClassNodeForMax(instances, classAttributeIndex));
            }

            // Mark the chosen attribute as visited for this subtree only; the caller's array stays untouched.
            bool[] localVisitedAttributes = (bool[])visitedAttributes.Clone();
            localVisitedAttributes[bestNode.AttributeIndex] = true;

            // Partition the instances once by the split attribute (order within each group is preserved)
            // instead of re-scanning the whole list for every attribute value.
            int                  splitIndex = bestNode.AttributeIndex;
            ILookup <int, int[]> partitions = instances.ToLookup(i => i[splitIndex]);

            foreach (int attributeValue in bestNode.ValueClassCounts.Keys)
            {
                Id3Node child = BuildTree(partitions[attributeValue].ToList(), classAttributeIndex, confidence, localVisitedAttributes);
                bestNode.Children[attributeValue] = child;
                child.Parent                      = bestNode;
                child.ParentValue                 = attributeValue;
            }

            return(bestNode);
        }
Exemplo n.º 5
0
 /// <summary>
 /// Counts the nodes in the subtree rooted at <paramref name="node"/>: 0 for null, 1 for a leaf,
 /// otherwise the node itself plus the counts of all its children.
 /// </summary>
 private static int GetCount(Id3Node node)
 {
     if (node == null)
     {
         return 0;
     }

     return node.IsLeaf
         ? 1
         : node.Children.Values.Sum(child => GetCount(child)) + 1;
 }
Exemplo n.º 6
0
        /// <summary>
        /// Walks from <paramref name="tree"/> down to a leaf, choosing at each node the branch that matches the
        /// instance's value for that node's attribute, and returns the leaf's class. When no branch matches,
        /// the node's MaxChild is followed instead.
        /// </summary>
        private int GetClass(int[] instance, Id3Node tree)
        {
            Id3Node current = tree;
            while (!current.IsLeaf)
            {
                Id3Node next;
                current = current.Children.TryGetValue(instance[current.AttributeIndex], out next)
                    ? next
                    : current.MaxChild;
            }

            return current.Class;
        }
Exemplo n.º 7
0
        /// <summary>
        /// Renders every root-to-leaf path of <paramref name="tree"/> as a textual rule.
        /// Rules ending in class 0 are appended to <paramref name="sb"/>. For each class group (0 / other),
        /// the rule whose leaf received the most training instances is kept in
        /// <paramref name="sbMaxPositive"/> / <paramref name="sbMaxNegative"/>, with its count in
        /// <paramref name="maxPositive"/> / <paramref name="maxNegative"/>.
        /// </summary>
        /// <param name="header">ARFF header used to turn attribute indexes and nominal value indexes into names.</param>
        private static void PrintTreeAsRules(StringBuilder sb, ref StringBuilder sbMaxPositive, ref StringBuilder sbMaxNegative, ref int maxPositive, ref int maxNegative, Id3Node tree, ArffHeader header)
        {
            if (!tree.IsLeaf)
            {
                foreach (KeyValuePair <int, Id3Node> kvp in tree.Children)
                {
                    PrintTreeAsRules(sb, ref sbMaxPositive, ref sbMaxNegative, ref maxPositive, ref maxNegative, kvp.Value, header);
                }
                return;
            }

            Id3Node       leaf    = tree;
            StringBuilder localSB = new StringBuilder();
            localSB.AppendLine("Rule is:");
            localSB.AppendLine();

            // Instances that reached this leaf, read from the parent's per-value counts. A root that is itself
            // a leaf has no parent to read from (previously a NullReferenceException), so it reports 0.
            int count = leaf.Parent == null
                ? 0
                : leaf.Parent.ValueClassCounts[leaf.ParentValue].Values.Sum();

            // Collect the conditions from the leaf up to the root, then join them so the rule
            // has no dangling trailing " and ".
            List <string> conditions = new List <string>();
            for (Id3Node node = leaf; node.Parent != null; node = node.Parent)
            {
                var    attribute = header.Attributes.ElementAt(node.Parent.AttributeIndex);
                string value     = node.ParentValue == -1 ? "?" : ((ArffNominalAttribute)attribute.Type).Values[node.ParentValue];
                conditions.Add($"<{attribute.Name}> equals to <{value}>");
            }
            localSB.Append(string.Join(" and ", conditions));
            localSB.AppendLine();
            localSB.AppendLine("---------------------------------------");

            if (leaf.Class == 0)
            {
                sb.Append(localSB.ToString());
                if (count > maxPositive)
                {
                    maxPositive   = count;
                    sbMaxPositive = localSB;
                    sbMaxPositive.AppendLine($"COUNT: {count}");
                }
            }
            else if (count > maxNegative)
            {
                maxNegative   = count;
                sbMaxNegative = localSB;
                sbMaxNegative.AppendLine($"COUNT: {count}");
            }
        }
Exemplo n.º 8
0
        /// <summary>
        /// Private constructor; sets up the lazily computed MaxChild and Count values for this node.
        /// </summary>
        private Id3Node()
        {
            // MaxChild: the child branch that received the most training instances (presumably used as the
            // fallback branch for attribute values with no matching child — confirm with callers).
            // The per-branch count is read from this node's own ValueClassCounts, because leaf children are
            // created with only their Class set and may carry no counts of their own, which made summing
            // child.ValueClassCounts undercount leaf branches.
            _maxChild = new Lazy <Id3Node>(() =>
            {
                Id3Node maxChild = null;
                int     maxCount = int.MinValue;
                foreach (KeyValuePair <int, Id3Node> kvp in Children)
                {
                    int count = ValueClassCounts.ContainsKey(kvp.Key)
                        ? ValueClassCounts[kvp.Key].Values.Sum()
                        : 0;
                    if (count > maxCount)
                    {
                        maxCount = count;
                        maxChild = kvp.Value;
                    }
                }

                return(maxChild);
            });

            // Count: number of nodes in the subtree rooted at this node.
            _count = new Lazy <int>(() => GetCount(this));
        }
Exemplo n.º 9
0
        /// <summary>
        /// Picks the candidate split attribute whose per-value class counts have the lowest entropy as computed
        /// by GetEntropy (minimizing conditional entropy maximizes mutual information with the class).
        /// Candidates come from GetAttributeNodes, which presumably skips visited attributes and the class attribute.
        /// </summary>
        /// <returns>The candidate node for the chosen attribute.</returns>
        private static Id3Node BestAttributeNode(IEnumerable <int[]> instances, bool[] visitedAttributes, int classAttributeIndex)
        {
            Dictionary <int, Id3Node> candidates = GetAttributeNodes(instances, visitedAttributes, classAttributeIndex);

            int    bestIndex   = -1;
            double bestEntropy = double.MaxValue;
            foreach (int candidateIndex in candidates.Keys)
            {
                double candidateEntropy = GetEntropy(candidates[candidateIndex].ValueClassCounts);
                if (candidateEntropy < bestEntropy)
                {
                    bestIndex   = candidateIndex;
                    bestEntropy = candidateEntropy;
                }
            }

            // When no candidate improved on the initial bound, bestIndex is still -1 and this lookup throws.
            return candidates[bestIndex];
        }
Exemplo n.º 10
0
        /// <summary>
        /// Creates a classifier whose tree is built immediately with the given confidence and maximum depth.
        /// </summary>
        /// <param name="instances">Training instances.</param>
        /// <param name="classIndex">Index of the class attribute within each instance.</param>
        /// <param name="confidence">Confidence threshold forwarded to tree construction.</param>
        /// <param name="maxDepth">Maximum depth forwarded to tree construction.</param>
        /// <remarks>NOTE(review): this constructor does not assign Instances; if the class also has a constructor
        /// that stores the training instances, confirm the omission here is intended.</remarks>
        public Id3Classifier(List <int[]> instances, int classIndex, double confidence, int maxDepth)
        {
            this.Confidence = confidence;
            this.Tree       = Id3Node.BuildTree(instances, classIndex, confidence, maxDepth);
        }
Exemplo n.º 11
0
        /// <summary>
        /// Pearson chi-squared test of independence between the split attribute of <paramref name="node"/>
        /// and the class, used to decide whether the split is worth making.
        /// </summary>
        /// <param name="node">Candidate split; its ValueClassCounts map attribute value -> (class -> count).</param>
        /// <param name="confidence">Required confidence that the attribute and the class are dependent.</param>
        /// <returns>True when the chi-squared CDF of the statistic is at least <paramref name="confidence"/>.</returns>
        private static bool IsValidChiSquared(Id3Node node, double confidence)
        {
            // Column totals (instances per class) and the grand total.
            double total = 0;
            Dictionary <int, double> classTotals = new Dictionary <int, double>();
            foreach (KeyValuePair <int, Dictionary <int, int> > kvpValueClassCounts in node.ValueClassCounts)
            {
                foreach (KeyValuePair <int, int> kvpClassCount in kvpValueClassCounts.Value)
                {
                    double classTotal;
                    classTotals.TryGetValue(kvpClassCount.Key, out classTotal);
                    classTotals[kvpClassCount.Key] = classTotal + kvpClassCount.Value;
                    total += kvpClassCount.Value;
                }
            }

            // No instances means no evidence of dependence, so the split is not significant
            // (previously every expected count became NaN).
            if (total <= 0)
            {
                return false;
            }

            // Pearson statistic summed over EVERY (value, class) cell of the contingency table. A class absent
            // from a value's counts is an observed zero and still contributes (0 - E)^2 / E = E.
            double statistic      = 0;
            int    nonEmptyValues = 0;
            foreach (KeyValuePair <int, Dictionary <int, int> > kvpValueClassCounts in node.ValueClassCounts)
            {
                Dictionary <int, int> classCounts = kvpValueClassCounts.Value;
                double valueTotal = classCounts.Values.Sum();
                if (valueTotal <= 0)
                {
                    continue;
                }
                nonEmptyValues++;

                foreach (KeyValuePair <int, double> kvpClassTotal in classTotals)
                {
                    double expectedCount = kvpClassTotal.Value * (valueTotal / total);
                    if (expectedCount <= 0)
                    {
                        // Class never observed: observed count is 0 too, so the cell contributes nothing (avoids 0/0).
                        continue;
                    }

                    int observedCount;
                    classCounts.TryGetValue(kvpClassTotal.Key, out observedCount);
                    statistic += Math.Pow(observedCount - expectedCount, 2) / expectedCount;
                }
            }

            int nonEmptyClasses = classTotals.Values.Count(t => t > 0);

            // An r x c contingency table has (r - 1)(c - 1) degrees of freedom. Keep at least 1 as before;
            // a degenerate table (single value or single class) yields a statistic of 0 anyway.
            int degreesOfFreedom = Math.Max(1, (nonEmptyValues - 1) * (nonEmptyClasses - 1));
            ChiSquared chiSquared = new ChiSquared(degreesOfFreedom);

            return(confidence <= chiSquared.CumulativeDistribution(statistic));
        }
Exemplo n.º 12
0
        /// <summary>
        /// Trains one tree per confidence level on the training ARFF file and reports node count and
        /// accuracy on the training and test sets. For confidences above 0.5 it also writes the tree's rules
        /// to a text file in the output folder.
        /// </summary>
        static void Main(string[] args)
        {
            // Training data. Its header describes the schema the trees are trained on and is used to render
            // rules; the test file is expected to share the same schema.
            ArffHeader   header;
            List <int[]> trainingData = ReadArffAsIntegers(_arffFile, out header);

            ArffHeader   testHeader;
            List <int[]> testData = ReadArffAsIntegers(_testArffFile, out testHeader);

            // The class attribute is the last column.
            int classIndex = trainingData[0].Length - 1;

            double[] confidences = { 0.0, 0.1, 0.2, 0.4, 0.6, 0.8, 0.9, 0.95, 0.99, 0.9999 };

            // Export next to the other outputs instead of a hard-coded, machine-specific desktop path.
            Directory.CreateDirectory(_outputFolder);
            PrintAsCsv(header, trainingData, Path.Combine(_outputFolder, "data.csv"));

            Parallel.ForEach(confidences, confidence =>
            {
                Id3Node tree = Id3Node.BuildTree(trainingData, classIndex, confidence);

                Console.WriteLine($"Confidence {confidence}: Num of nodes {GetCount(tree)}");
                // Test accuracy on training
                Console.WriteLine($"Confidence {confidence}: Accuracy on train = {trainingData.Count(instance => GetClass(instance, tree) == instance[classIndex]) / (double)trainingData.Count}");

                // Test accuracy on test
                Console.WriteLine($"Confidence {confidence}: Accuracy on test = {testData.Count(instance => GetClass(instance, tree) == instance[classIndex]) / (double)testData.Count}");

                StringBuilder sb            = new StringBuilder();
                StringBuilder sbMaxPositive = new StringBuilder();
                StringBuilder sbMaxNegative = new StringBuilder();
                int maxPositive             = int.MinValue;
                int maxNegative             = int.MinValue;
                // Only print small trees.
                if (confidence > 0.5)
                {
                    PrintTreeAsRules(sb, ref sbMaxPositive, ref sbMaxNegative, ref maxPositive, ref maxNegative, tree, header);
                    sb.AppendLine("The most max positive rule is:");
                    sb.AppendLine(sbMaxPositive.ToString());
                    sb.AppendLine();
                    sb.AppendLine("The most max negative rule is:");
                    sb.AppendLine(sbMaxNegative.ToString());
                    Directory.CreateDirectory(_outputFolder);
                    File.WriteAllText(Path.Combine(_outputFolder, $"Tree{confidence}.txt"), sb.ToString());
                }
            });
        }

        /// <summary>
        /// Reads every instance of the ARFF file at <paramref name="path"/> and converts each value to an int,
        /// mapping missing values (null) to -1.
        /// </summary>
        /// <param name="path">ARFF file to read.</param>
        /// <param name="header">Receives the file's header.</param>
        /// <returns>The instances as integer arrays, in file order.</returns>
        private static List <int[]> ReadArffAsIntegers(string path, out ArffHeader header)
        {
            List <int[]> data = new List <int[]>();
            using (ArffReader arffReader = new ArffReader(path))
            {
                header = arffReader.ReadHeader();
                object[] instance;
                while ((instance = arffReader.ReadInstance()) != null)
                {
                    data.Add(instance.Select(o => o == null ? -1 : (int)o).ToArray());
                }
            }

            return data;
        }