diff --git a/nilmdb/rbtree.py b/nilmdb/rbtree.py index d935f1a..ba1f723 100644 --- a/nilmdb/rbtree.py +++ b/nilmdb/rbtree.py @@ -3,6 +3,9 @@ This is a basic interval tree that holds half-open intervals: [start, end) Intervals must not overlap. Fixing that would involve making this into an augmented interval tree as described in CLRS 14.3. + +Code that assumes non-overlapping intervals is marked with the +string 'non-overlapping'. """ import sys @@ -17,14 +20,17 @@ class RBNode(object): self.red = False self.left = None self.right = None + self.nil = False def __str__(self): if self.red: color = "R" else: color = "B" - return ("[node " - + str(obj) + " " + if self.start == sys.float_info.min: + return "[node nil]" + return ("[node (" + + str(self.obj) + ") " + str(self.start) + " -> " + str(self.end) + " " + color + "]") @@ -38,7 +44,6 @@ class RBTree(object): self.nil.left = self.nil self.nil.right = self.nil self.nil.parent = self.nil - self.nil.nil = True self.root = RBNode(start = sys.float_info.max, end = sys.float_info.max) @@ -46,6 +51,11 @@ class RBTree(object): self.root.right = self.nil self.root.parent = self.nil + # We have a dummy root node to simplify operations, so from an + # external point of view, its left child is the real root. 
+ def getroot(self): + return self.root.left + # Rotations and basic operations def __rotate_left(self, x): """Rotate left: @@ -88,7 +98,7 @@ class RBTree(object): y.parent = x def __successor(self, x): - """Returns the successor of x""" + """Returns the successor of RBNode x""" y = x.right if y is not self.nil: while y.left is not self.nil: @@ -101,6 +111,10 @@ class RBTree(object): if y is self.root: return self.nil return y + def successor(self, x): + """Returns the successor of RBNode x, or None""" + y = self.__successor(x) + return y if y is not self.nil else None def __predecessor(self, x): """Returns the predecessor of RBNode x""" @@ -117,6 +131,10 @@ class RBTree(object): x = y y = y.parent return y + def predecessor(self, x): + """Returns the predecessor of RBNode x, or None""" + y = self.__predecessor(x) + return y if y is not self.nil else None # Insertion def insert(self, z): @@ -264,129 +282,83 @@ class RBTree(object): x = rootLeft # exit loop x.red = False - # Rendering - def __render_dot_node(self, node, max_depth = 20): - from printf import sprintf - """Render a single node and its children into a dot graph fragment""" - if max_depth == 0: - return "" - if node is self.nil: - return "" - def c(red): - if red: - return 'color="#ff0000", style=filled, fillcolor="#ffc0c0"' - else: - return 'color="#000000", style=filled, fillcolor="#c0c0c0"' - s = sprintf("%d [label=\"%g\\n%g\", %s];\n", - id(node), - node.start, node.end, - c(node.red)) - - if node.left is self.nil: - s += sprintf("L%d [label=\"-\", %s];\n", id(node), c(False)) - s += sprintf("%d -> L%d [label=L];\n", id(node), id(node)) - else: - s += sprintf("%d -> %d [label=L];\n", id(node), id(node.left)) - if node.right is self.nil: - s += sprintf("R%d [label=\"-\", %s];\n", id(node), c(False)) - s += sprintf("%d -> R%d [label=R];\n", id(node), id(node)) - else: - s += sprintf("%d -> %d [label=R];\n", id(node), id(node.right)) - s += self.__render_dot_node(node.left, max_depth-1) - s += 
self.__render_dot_node(node.right, max_depth-1) - return s - - def render_dot(self, title = "RBTree"): - """Render the entire RBTree as a dot graph""" - return ("digraph rbtree {\n" - + self.__render_dot_node(self.root.left) - + "}\n"); - - def render_dot_live(self, title = "RBTree"): - """Render the entire RBTree as a dot graph, live GTK view""" - import gtk - import gtk.gdk - sys.path.append("/usr/share/xdot") - import xdot - xdot.Pen.highlighted = lambda pen: pen - s = ("digraph rbtree {\n" - + self.__render_dot_node(self.root) - + "}\n"); - window = xdot.DotWindow() - window.set_dotcode(s) - window.set_title(title + " - any key to close") - window.connect('destroy', gtk.main_quit) - def quit(widget, event): - if not event.is_modifier: - window.destroy() - gtk.main_quit() - window.widget.connect('key-press-event', quit) - gtk.main() - # Walking, searching def __iter__(self): - return self.inorder(self.root.left) + return self.inorder() def inorder(self, x = None): """Generator that performs an inorder walk for the tree - starting at RBNode x""" + rooted at RBNode x""" if x is None: - x = self.root.left + x = self.getroot() while x.left is not self.nil: x = x.left while x is not self.nil: yield x x = self.__successor(x) - def __find_all(self, start, end, x): - """Find node with the specified (start,end) key. 
- Also returns the largest node less than or equal to key, - and the smallest node greater or equal to than key.""" - if x is None: - x = self.root.left - largest = self.nil - smallest = self.nil + def find(self, start, end): + """Return the node with exactly the given start and end.""" + x = self.getroot() while x is not self.nil: if start < x.start: - smallest = x - x = x.left # start < + x = x.left elif start == x.start: - if end < x.end: - smallest = x - x = x.left # start =, end < - elif end == x.end: # found it - smallest = x - largest = x - break + if end == x.end: + break # found it + elif end < x.end: + x = x.left else: - largest = x - x = x.right # start =, end > + x = x.right else: - largest = x - x = x.right # start > - return (x, smallest, largest) + x = x.right + return x if x is not self.nil else None - def find(self, start, end, x = None): - """Find node with the key == (start,end), or None""" - y = self.__find_all(start, end, x)[1] - return y if y is not self.nil else None + def find_left_end(self, t): + """Find the leftmost node with end >= t. With non-overlapping + intervals, this is the first node that might overlap time t. 
- def find_right(self, start, end, x = None): - """Find node with the smallest key >= (start,end), or None""" - y = self.__find_all(start, end, x)[1] - return y if y is not self.nil else None + Note that this relies on non-overlapping intervals, since + it assumes that we can use the endpoints to traverse the + tree even though it was created using the start points.""" + x = self.getroot() + while x is not self.nil: + if t < x.end: + if x.left is self.nil: + break + x = x.left + elif t == x.end: + break + else: + if x.right is self.nil: + x = self.__successor(x) + break + x = x.right + return x if x is not self.nil else None - def find_left(self, start, end, x = None): - """Find node with the largest key <= (start,end), or None""" - y = self.__find_all(start, end, x)[2] - return y if y is not self.nil else None + def find_right_start(self, t): + """Find the rightmost node with start <= t. With non-overlapping + intervals, this is the last node that might overlap time t.""" + x = self.getroot() + while x is not self.nil: + if t < x.start: + if x.left is self.nil: + x = self.__predecessor(x) + break + x = x.left + elif t == x.start: + break + else: + if x.right is self.nil: + break + x = x.right + return x if x is not self.nil else None # Intersections def intersect(self, start, end): """Generator that returns nodes that overlap the given - (start,end) range, for the tree rooted at RBNode x. + (start,end) range. 
Assumes non-overlapping intervals.""" - NOTE: this assumes non-overlapping intervals.""" # Start with the leftmost node before the starting point n = self.find_left(start, start) # If we didn't find one, look for the leftmode node before the diff --git a/setup.cfg b/setup.cfg index 45f94e4..da015c3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,8 +12,8 @@ stop= verbosity=2 #tests=tests/test_cmdline.py #tests=tests/test_layout.py -#tests=tests/test_rbtree.py -tests=tests/test_interval.py +tests=tests/test_rbtree.py +#tests=tests/test_interval.py #tests=tests/test_client.py #tests=tests/test_timestamper.py #tests=tests/test_serializer.py diff --git a/tests/test_rbtree.py b/tests/test_rbtree.py index 2dda853..0171070 100644 --- a/tests/test_rbtree.py +++ b/tests/test_rbtree.py @@ -11,17 +11,33 @@ from nilmdb.rbtree import RBTree, RBNode from test_helpers import * import unittest -render = False +# set to False to skip live renders +do_live_renders = False +def render(tree, description = "", live = True): + import renderdot + if not do_live_renders: + # If not doing a live render, still render it to a string. + live = False + r = renderdot.Renderer(lambda node: node.left, + lambda node: node.right, + lambda node: node.red, + lambda node: node.start, + lambda node: node.end, + tree.nil) + if live: + r.render_dot_live(tree.getroot(), description) + else: + return r.render_dot(tree.getroot()) class TestRBTree: def test_rbtree(self): rb = RBTree() - rb.insert(RBNode(None, 10000, 10001)) - rb.insert(RBNode(None, 10004, 10007)) - rb.insert(RBNode(None, 10001, 10002)) - s = rb.render_dot() + rb.insert(RBNode(10000, 10001)) + rb.insert(RBNode(10004, 10007)) + rb.insert(RBNode(10001, 10002)) # There was a typo that gave the RBTree a loop in this case. # Verify that the dot isn't too big. 
+ s = render(rb, live = False) assert(len(s.splitlines()) < 30) def test_rbtree_big(self): @@ -32,44 +48,116 @@ class TestRBTree: rb = RBTree() j = 500 for i in xrange(j): - rb.insert(RBNode(None, i, i+1)) - - # show the graph - if render: - rb.render_dot_live("in-order insert") + rb.insert(RBNode(i, i+1)) + render(rb, "in-order insert") # remove about half of them for i in random.sample(xrange(j),j): if random.randint(0,1): rb.delete(rb.find(i, i+1)) - - # show the graph - if render: - rb.render_dot_live("in-order insert, random delete") + render(rb, "in-order insert, random delete") # make a set of 500 intervals, inserted at random rb = RBTree() j = 500 for i in random.sample(xrange(j),j): - rb.insert(RBNode(None, i, i+1)) - - # show the graph - if render: - rb.render_dot_live("random insert") + rb.insert(RBNode(i, i+1)) + render(rb, "random insert") # remove about half of them for i in random.sample(xrange(j),j): if random.randint(0,1): rb.delete(rb.find(i, i+1)) - - # show the graph - if render: - rb.render_dot_live("random insert, random delete") + render(rb, "random insert, random delete") # in-order insert of 250 more for i in xrange(250): - rb.insert(RBNode(None, i+500, i+501)) + rb.insert(RBNode(i+500, i+501)) + render(rb, "random insert, random delete, in-order insert") - # show the graph - if render: - rb.render_dot_live("random insert, random delete, in-order insert") + def test_rbtree_basics(self): + rb = RBTree() + vals = [ 7, 14, 1, 2, 8, 11, 5, 15, 4] + for n in vals: + rb.insert(RBNode(n, n)) + + # stringify + s = "" + for node in rb: + s += str(node) + assert "[node (None) 1 -> 1 B]" in s + assert str(rb.nil) == "[node nil]" + + # inorder traversal, successor and predecessor + last = 0 + for node in rb: + assert(node.start > last) + last = node.start + successor = rb.successor(node) + if successor: + assert(rb.predecessor(successor) is node) + predecessor = rb.predecessor(node) + if predecessor: + assert(rb.successor(predecessor) is node) + + # 
Delete node not in the tree + with assert_raises(AttributeError): + rb.delete(RBNode(1,2)) + + # Delete all nodes! + for node in rb: + rb.delete(node) + + # Build it up again, make sure it matches + for n in vals: + rb.insert(RBNode(n, n)) + s2 = "" + for node in rb: + s2 += str(node) + assert(s == s2) + + def test_rbtree_find(self): + # Get a little bit of coverage for some overlapping cases, + # even though the class doesn't fully support it. + rb = RBTree() + nodes = [ RBNode(1, 5), RBNode(1, 10), RBNode(1, 15) ] + for n in nodes: + rb.insert(n) + assert(rb.find(1, 5) is nodes[0]) + assert(rb.find(1, 10) is nodes[1]) + assert(rb.find(1, 15) is nodes[2]) + + def test_rbtree_find_leftright(self): + # Now let's get some ranges in there + rb = RBTree() + vals = [ 7, 14, 1, 2, 8, 11, 5, 15, 4] + for n in vals: + rb.insert(RBNode(n*10, n*10+5)) + + # Check find_left_end, find_right_start + for i in range(160): + left = rb.find_left_end(i) + right = rb.find_right_start(i) + if left: + # endpoint should be at least i + assert(left.end >= i) + # all earlier nodes should have a lower endpoint + for node in rb: + if node is left: + break + assert(node.end < i) + if right: + # startpoint should be at most i + assert(right.start <= i) + # all later nodes should have a higher startpoint + for node in reversed(list(rb)): + if node is right: + break + assert(node.start > i) + + def test_rbtree_intersect(self): + # Fill with some ranges again + rb = RBTree() + vals = [ 7, 14, 1, 2, 8, 11, 5, 15, 4] + for n in vals: + rb.insert(RBNode(n*10, n*10+5))