/// <summary>
/// Asserts that two decision-tree nodes (and, unless <paramref name="ignoreParent"/> is set,
/// the subtrees below them) are equal.
/// </summary>
/// <param name="n1">Expected node.</param>
/// <param name="n2">Actual node.</param>
/// <param name="ignoreParent">
/// When true only the node's own fields and its edge count are compared and its edges are
/// not followed. This is used when checking <c>Edge.Parent</c> back-links so the comparison
/// does not recurse forever through the parent/child cycle.
/// </param>
private static void AreEqual(Node n1, Node n2, bool ignoreParent)
        {
            Assert.AreEqual(n1.IsLeaf, n2.IsLeaf);
            Assert.AreEqual(n1.Value, n2.Value);
            Assert.AreEqual(n1.Label, n2.Label);
            Assert.AreEqual(n1.Column, n2.Column);
            Assert.AreEqual(n1.Name, n2.Name);

            if (n1.Edges == null)
            {
                Assert.IsNull(n2.Edges);
                return;
            }

            // report a proper assertion failure rather than a NullReferenceException below
            Assert.IsNotNull(n2.Edges);
            Assert.AreEqual(n1.Edges.Length, n2.Edges.Length);

            // shallow comparison requested (we arrived here through a parent back-link):
            // stop before walking the edges again
            if (ignoreParent) return;

            for (int i = 0; i < n1.Edges.Length; i++)
            {
                var n1e = n1.Edges[i];
                var n2e = n2.Edges[i];
                Assert.AreEqual(n1e.Min, n2e.Min);
                Assert.AreEqual(n1e.Max, n2e.Max);
                Assert.AreEqual(n1e.Discrete, n2e.Discrete);
                Assert.AreEqual(n1e.Label, n2e.Label);

                // back-link to the parent: compare shallowly so we never cycle back up the tree
                AreEqual(n1e.Parent, n2e.Parent, true);

                // child: compare the whole subtree; the child's own parent links are again
                // checked shallowly, so the recursion terminates at the leaves
                AreEqual(n1e.Child, n2e.Child, false);
            }
        }
예제 #2
0
        /// <summary>Restores the model's Hint, Descriptor and Tree from XML.</summary>
        /// <param name="reader">Reader positioned at, or just before, the model's wrapping element.</param>
        public override void ReadXml(XmlReader reader)
        {
            reader.MoveToContent();

            // the Hint is stored as an attribute on the wrapping element itself
            string hintAttribute = reader.GetAttribute("Hint");
            Hint = double.Parse(hintAttribute);

            // step inside the wrapping element to reach the serialized children
            reader.ReadStartElement();

            Descriptor = ReadXml<Descriptor>(reader);
            Tree = ReadXml<Node>(reader);

            // rebuild the tree's cycles and derived values after deserialization
            ReLinkNodes(Tree);
        }
예제 #3
0
        /// <summary>Renders a node and its subtree as indented text.</summary>
        /// <param name="n">The node to render.</param>
        /// <param name="pre">Prefix written at the start of each line; it grows by one level per recursion.</param>
        /// <returns>A multi-line string describing the subtree rooted at <paramref name="n"/>.</returns>
        private string PrintNode(Node n, string pre)
        {
            if (n.IsLeaf)
            {
                // "0.####" rather than "#.####": the '#' placeholder drops non-significant zeros,
                // so a value of 0 would otherwise print as an empty string (and 0.5 as ".5")
                return String.Format("{0} +({1}, {2:0.####})\n", pre, n.Label, n.Value);
            }

            StringBuilder sb = new StringBuilder();
            sb.AppendLine(String.Format("{0}[{1}, {2:0.0000}]", pre, n.Name, n.Gain));

            // tolerate an internal node without edges instead of throwing while printing
            if (n.Edges != null)
            {
                foreach (Edge edge in n.Edges)
                {
                    sb.AppendLine(String.Format("{0} |- {1}", pre, edge.Label));
                    sb.Append(PrintNode(edge.Child, String.Format("{0} |\t", pre)));
                }
            }

            return sb.ToString();
        }
예제 #4
0
        /// <summary>Renders a node and everything below it as indented text.</summary>
        /// <param name="n">The node to render.</param>
        /// <param name="pre">Prefix written at the start of every line at this depth.</param>
        /// <returns>The textual rendering of the subtree rooted at <paramref name="n"/>.</returns>
        private string PrintNode(Node n, string pre)
        {
            if (!n.IsLeaf)
            {
                var builder = new StringBuilder();
                builder.AppendLine(String.Format("{0}[{1}, {2:0.0000}]", pre, n.Name, n.Gain));

                // every child is rendered one level deeper using the same prefix
                string childPrefix = String.Format("{0} |\t", pre);
                foreach (Edge outEdge in Tree.GetOutEdges(n))
                {
                    builder.AppendLine(String.Format("{0} |- {1}", pre, outEdge.Label));
                    var child = (Node)Tree.GetVertex(outEdge.ChildId);
                    builder.Append(PrintNode(child, childPrefix));
                }

                return builder.ToString();
            }

            // leaf: show the label converted from the stored value, then the raw value
            var label = Descriptor.Label.Convert(n.Value);
            return String.Format("{0} +({1}, {2})\n", pre, label, n.Value);
        }
예제 #5
0
        /// <summary>
        /// Walks from <paramref name="node"/> down to a leaf, at each internal node following the
        /// edge whose split matches the corresponding feature of <paramref name="v"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// Thrown when an internal node has no feature column (-1), or when no edge matches the
        /// feature value and no Hint has been set.
        /// </exception>
        /// <param name="v">The feature vector being evaluated.</param>
        /// <param name="node">The node to start walking from.</param>
        /// <returns>The value of the leaf reached, or Hint when no edge matches.</returns>
        private double WalkNode(Vector v, Node node)
        {
            if (node.IsLeaf)
                return node.Value;

            // Get the index of the feature for this node.
            var col = node.Column;
            if (col == -1)
                throw new InvalidOperationException("Invalid Feature encountered during node walk!");

            // an internal node without edges cannot match anything; fall through to the Hint handling
            if (node.Edges != null)
            {
                for (int i = 0; i < node.Edges.Length; i++)
                {
                    Edge edge = node.Edges[i];
                    // discrete edges match on the exact value, continuous ones on the half-open range [Min, Max)
                    if (edge.Discrete && v[col] == edge.Min)
                        return WalkNode(v, edge.Child);
                    if (!edge.Discrete && v[col] >= edge.Min && v[col] < edge.Max)
                        return WalkNode(v, edge.Child);
                }
            }

            // double.Epsilon appears to serve as the "no Hint set" sentinel
            if (Hint != double.Epsilon)
                return Hint;

            // "{1}[{2}]" so the column index is actually printed (the format previously read "[2]")
            throw new InvalidOperationException(String.Format("Unable to match split value {0} for feature {1}[{2}]\nConsider setting a Hint in order to avoid this error.", v[col], Descriptor.At(col), col));
        }
예제 #6
0
 /// <summary>
 /// Re-establishes the parent back-link on every edge below <paramref name="n"/>, and sets each
 /// leaf's Label from its Value via the descriptor's label property.
 /// </summary>
 /// <param name="n">Root of the subtree to re-link.</param>
 private void ReLinkNodes(Node n)
 {
     if (n.Edges == null)
         return;

     foreach (Edge link in n.Edges)
     {
         link.Parent = n;

         var child = link.Child;
         if (!child.IsLeaf)
         {
             ReLinkNodes(child);
             continue;
         }

         child.Label = Descriptor.Label.Convert(child.Value);
     }
 }
예제 #7
0
        /// <summary>
        /// Recursively builds the decision (sub)tree for the examples in <paramref name="x"/> and
        /// <paramref name="y"/>, registering the created vertices and edges in <paramref name="tree"/>.
        /// </summary>
        /// <param name="x">Feature matrix whose rows line up with the entries of <paramref name="y"/>.</param>
        /// <param name="y">Target values, one per row of <paramref name="x"/>.</param>
        /// <param name="depth">Remaining depth budget; once it drops below zero a leaf holding the mode of y is returned.</param>
        /// <param name="used">Columns already chosen for a split; this method appends the column it picks.</param>
        /// <param name="tree">Graph that receives the nodes and edges built here.</param>
        /// <returns>
        /// The node for this (sub)tree. An internal node, or one collapsed to a leaf after splitting,
        /// is added to <paramref name="tree"/> here. A leaf returned by the early exits (depth
        /// exhausted, no split found) is not added; the caller does that.
        /// </returns>
        private Node BuildTree(Matrix x, Vector y, int depth, List<int> used, Tree tree)
        {
            if (depth < 0)
                return BuildLeafNode(y.Mode());

            var tuple = GetBestSplit(x, y, used);
            var col = tuple.Item1;
            var gain = tuple.Item2;
            var measure = tuple.Item3;

            // no usable split was found: fall back to a leaf holding the most common answer
            if (col == -1)
                return BuildLeafNode(y.Mode());

            // NOTE(review): `used` is a single list shared by the whole recursion, so a column picked
            // in one branch is also excluded from sibling branches (classic ID3 only excludes columns
            // along the current path). Confirm this is intended.
            used.Add(col);

            Node node = new Node
            {
                Column = col,
                Gain = gain,
                IsLeaf = false,
                Name = Descriptor.ColumnAt(col)
            };

            // populate edges: one candidate per segment of the chosen column
            List<Edge> edges = new List<Edge>(measure.Segments.Length);
            for (int i = 0; i < measure.Segments.Length; i++)
            {
                var segment = measure.Segments[i];
                var edge = new Edge()
                {
                    // NOTE(review): node has not been added to tree yet; this relies on Node.Id
                    // being assigned before AddVertex. Confirm.
                    ParentId = node.Id,
                    Discrete = measure.Discrete,
                    Min = segment.Min,
                    Max = segment.Max
                };

                // Materialised once: the row indices are consumed up to three times below,
                // and a deferred query would be re-evaluated on each use.
                List<int> slice;
                if (edge.Discrete)
                {
                    // discrete label; rows match on the exact value
                    edge.Label = Descriptor.At(col).Convert(segment.Min).ToString();
                    slice = new List<int>(x.Indices(v => v[col] == segment.Min));
                }
                else
                {
                    // range label; rows match on the half-open range [Min, Max)
                    edge.Label = string.Format("{0} <= x < {1}", segment.Min, segment.Max);
                    slice = new List<int>(x.Indices(v => v[col] >= segment.Min && v[col] < segment.Max));
                }

                // no rows fall into this segment: it is a dead end and no edge is built
                if (slice.Count == 0)
                    continue;

                Vector ySlice = y.Slice(slice);
                if (ySlice.Distinct().Count() == 1)
                {
                    // only one answer left: terminate with a leaf
                    var child = BuildLeafNode(ySlice[0]);
                    tree.AddVertex(child);
                    edge.ChildId = child.Id;
                }
                else
                {
                    // otherwise keep splitting the subset
                    var child = BuildTree(x.Slice(slice), ySlice, depth - 1, used, tree);
                    // NOTE(review): when the recursive call built an internal node it already added
                    // that node to the tree itself; confirm Tree.AddVertex tolerates the second add.
                    tree.AddVertex(child);
                    edge.ChildId = child.Id;
                }

                edges.Add(edge);
            }

            // Fewer than two live branches means the split separates nothing:
            // turn this node into a leaf holding the mode of y.
            if (edges.Count <= 1)
            {
                node.IsLeaf = true;
                node.Value = y.Mode();
            }

            tree.AddVertex(node);

            // Edges are attached only when the node stayed internal.
            // NOTE(review): in the collapsed case, children built above remain in the tree as
            // unconnected vertices.
            if (edges.Count > 1)
                foreach (var e in edges)
                    tree.AddEdge(e);

            return node;
        }
예제 #8
0
        /// <summary>Restores the model's Hint, Descriptor and Tree from XML.</summary>
        /// <param name="reader">Reader positioned at, or just before, the model's wrapping element.</param>
        /// <remarks>
        /// Following the IXmlSerializable contract, the wrapping element's end tag is consumed
        /// when the reader arrives on it.
        /// </remarks>
        public void ReadXml(XmlReader reader)
        {
            reader.MoveToContent();
            // the Hint is stored as an attribute on the wrapping element
            Hint = double.Parse(reader.GetAttribute("Hint"));
            reader.ReadStartElement();

            XmlSerializer dserializer = new XmlSerializer(typeof(Descriptor));
            Descriptor = (Descriptor)dserializer.Deserialize(reader);

            // Skip only whitespace/comments up to the next element. A bare reader.Read() here
            // would step *inside* the tree element whenever no whitespace separates the two
            // serialized elements, breaking the Node deserialization below.
            reader.MoveToContent();

            XmlSerializer serializer = new XmlSerializer(typeof(Node));
            Tree = (Node)serializer.Deserialize(reader);
            // re-establish tree cycles and values
            ReLinkNodes(Tree);

            // consume the wrapping element's end tag so the surrounding reader stays in sync
            reader.MoveToContent();
            if (reader.NodeType == XmlNodeType.EndElement)
                reader.ReadEndElement();
        }