Filling out rbtree tests, search routines

This commit is contained in:
Jim Paris 2012-11-28 20:57:23 -05:00
parent 66fa6f3824
commit 0b443f510b
3 changed files with 191 additions and 131 deletions

View File

@ -3,6 +3,9 @@ This is a basic interval tree that holds half-open intervals:
[start, end)
Intervals must not overlap. Fixing that would involve making this
into an augmented interval tree as described in CLRS 14.3.
Code that assumes non-overlapping intervals is marked with the
string 'non-overlapping'.
"""
import sys
@ -17,14 +20,17 @@ class RBNode(object):
self.red = False
self.left = None
self.right = None
self.nil = False
def __str__(self):
if self.red:
color = "R"
else:
color = "B"
return ("[node "
+ str(obj) + " "
if self.start == sys.float_info.min:
return "[node nil]"
return ("[node ("
+ str(self.obj) + ") "
+ str(self.start) + " -> " + str(self.end) + " "
+ color + "]")
@ -38,7 +44,6 @@ class RBTree(object):
self.nil.left = self.nil
self.nil.right = self.nil
self.nil.parent = self.nil
self.nil.nil = True
self.root = RBNode(start = sys.float_info.max,
end = sys.float_info.max)
@ -46,6 +51,11 @@ class RBTree(object):
self.root.right = self.nil
self.root.parent = self.nil
# We have a dummy root node to simplify operations, so from an
# external point of view, its left child is the real root.
def getroot(self):
return self.root.left
# Rotations and basic operations
def __rotate_left(self, x):
"""Rotate left:
@ -88,7 +98,7 @@ class RBTree(object):
y.parent = x
def __successor(self, x):
"""Returns the successor of x"""
"""Returns the successor of RBNode x"""
y = x.right
if y is not self.nil:
while y.left is not self.nil:
@ -101,6 +111,10 @@ class RBTree(object):
if y is self.root:
return self.nil
return y
def successor(self, x):
"""Returns the successor of RBNode x, or None"""
y = self.__successor(x)
return y if y is not self.nil else None
def __predecessor(self, x):
"""Returns the predecessor of RBNode x"""
@ -117,6 +131,10 @@ class RBTree(object):
x = y
y = y.parent
return y
def predecessor(self, x):
"""Returns the predecessor of RBNode x, or None"""
y = self.__predecessor(x)
return y if y is not self.nil else None
# Insertion
def insert(self, z):
@ -264,129 +282,83 @@ class RBTree(object):
x = rootLeft # exit loop
x.red = False
# Rendering
def __render_dot_node(self, node, max_depth = 20):
from printf import sprintf
"""Render a single node and its children into a dot graph fragment"""
if max_depth == 0:
return ""
if node is self.nil:
return ""
def c(red):
if red:
return 'color="#ff0000", style=filled, fillcolor="#ffc0c0"'
else:
return 'color="#000000", style=filled, fillcolor="#c0c0c0"'
s = sprintf("%d [label=\"%g\\n%g\", %s];\n",
id(node),
node.start, node.end,
c(node.red))
if node.left is self.nil:
s += sprintf("L%d [label=\"-\", %s];\n", id(node), c(False))
s += sprintf("%d -> L%d [label=L];\n", id(node), id(node))
else:
s += sprintf("%d -> %d [label=L];\n", id(node), id(node.left))
if node.right is self.nil:
s += sprintf("R%d [label=\"-\", %s];\n", id(node), c(False))
s += sprintf("%d -> R%d [label=R];\n", id(node), id(node))
else:
s += sprintf("%d -> %d [label=R];\n", id(node), id(node.right))
s += self.__render_dot_node(node.left, max_depth-1)
s += self.__render_dot_node(node.right, max_depth-1)
return s
def render_dot(self, title = "RBTree"):
"""Render the entire RBTree as a dot graph"""
return ("digraph rbtree {\n"
+ self.__render_dot_node(self.root.left)
+ "}\n");
def render_dot_live(self, title = "RBTree"):
"""Render the entire RBTree as a dot graph, live GTK view"""
import gtk
import gtk.gdk
sys.path.append("/usr/share/xdot")
import xdot
xdot.Pen.highlighted = lambda pen: pen
s = ("digraph rbtree {\n"
+ self.__render_dot_node(self.root)
+ "}\n");
window = xdot.DotWindow()
window.set_dotcode(s)
window.set_title(title + " - any key to close")
window.connect('destroy', gtk.main_quit)
def quit(widget, event):
if not event.is_modifier:
window.destroy()
gtk.main_quit()
window.widget.connect('key-press-event', quit)
gtk.main()
# Walking, searching
def __iter__(self):
return self.inorder(self.root.left)
return self.inorder()
def inorder(self, x = None):
"""Generator that performs an inorder walk for the tree
starting at RBNode x"""
rooted at RBNode x"""
if x is None:
x = self.root.left
x = self.getroot()
while x.left is not self.nil:
x = x.left
while x is not self.nil:
yield x
x = self.__successor(x)
def __find_all(self, start, end, x):
"""Find node with the specified (start,end) key.
Also returns the largest node less than or equal to key,
and the smallest node greater or equal to than key."""
if x is None:
x = self.root.left
largest = self.nil
smallest = self.nil
def find(self, start, end):
"""Return the node with exactly the given start and end."""
x = self.getroot()
while x is not self.nil:
if start < x.start:
smallest = x
x = x.left # start <
x = x.left
elif start == x.start:
if end < x.end:
smallest = x
x = x.left # start =, end <
elif end == x.end: # found it
smallest = x
largest = x
break
if end == x.end:
break # found it
elif end < x.end:
x = x.left
else:
largest = x
x = x.right # start =, end >
x = x.right
else:
largest = x
x = x.right # start >
return (x, smallest, largest)
x = x.right
return x if x is not self.nil else None
def find(self, start, end, x = None):
"""Find node with the key == (start,end), or None"""
y = self.__find_all(start, end, x)[1]
return y if y is not self.nil else None
def find_left_end(self, t):
"""Find the leftmode node with end >= t. With non-overlapping
intervals, this is the first node that might overlap time t.
def find_right(self, start, end, x = None):
"""Find node with the smallest key >= (start,end), or None"""
y = self.__find_all(start, end, x)[1]
return y if y is not self.nil else None
Note that this relies on non-overlapping intervals, since
it assumes that we can use the endpoints to traverse the
tree even though it was created using the start points."""
x = self.getroot()
while x is not self.nil:
if t < x.end:
if x.left is self.nil:
break
x = x.left
elif t == x.end:
break
else:
if x.right is self.nil:
x = self.__successor(x)
break
x = x.right
return x if x is not self.nil else None
def find_left(self, start, end, x = None):
"""Find node with the largest key <= (start,end), or None"""
y = self.__find_all(start, end, x)[2]
return y if y is not self.nil else None
def find_right_start(self, t):
"""Find the rightmode node with start <= t. With non-overlapping
intervals, this is the last node that might overlap time t."""
x = self.getroot()
while x is not self.nil:
if t < x.start:
if x.left is self.nil:
x = self.__predecessor(x)
break
x = x.left
elif t == x.start:
break
else:
if x.right is self.nil:
break
x = x.right
return x if x is not self.nil else None
# Intersections
def intersect(self, start, end):
"""Generator that returns nodes that overlap the given
(start,end) range, for the tree rooted at RBNode x.
(start,end) range. Assumes non-overlapping intervals."""
NOTE: this assumes non-overlapping intervals."""
# Start with the leftmost node before the starting point
n = self.find_left(start, start)
# If we didn't find one, look for the leftmode node before the

View File

@ -12,8 +12,8 @@ stop=
verbosity=2
#tests=tests/test_cmdline.py
#tests=tests/test_layout.py
#tests=tests/test_rbtree.py
tests=tests/test_interval.py
tests=tests/test_rbtree.py
#tests=tests/test_interval.py
#tests=tests/test_client.py
#tests=tests/test_timestamper.py
#tests=tests/test_serializer.py

View File

@ -11,17 +11,33 @@ from nilmdb.rbtree import RBTree, RBNode
from test_helpers import *
import unittest
render = False
# set to False to skip live renders
do_live_renders = False
def render(tree, description = "", live = True):
import renderdot
if not do_live_renders:
# If not doing a live render, still render it to a string.
live = False
r = renderdot.Renderer(lambda node: node.left,
lambda node: node.right,
lambda node: node.red,
lambda node: node.start,
lambda node: node.end,
tree.nil)
if live:
r.render_dot_live(tree.getroot(), description)
else:
return r.render_dot(tree.getroot())
class TestRBTree:
def test_rbtree(self):
rb = RBTree()
rb.insert(RBNode(None, 10000, 10001))
rb.insert(RBNode(None, 10004, 10007))
rb.insert(RBNode(None, 10001, 10002))
s = rb.render_dot()
rb.insert(RBNode(10000, 10001))
rb.insert(RBNode(10004, 10007))
rb.insert(RBNode(10001, 10002))
# There was a typo that gave the RBTree a loop in this case.
# Verify that the dot isn't too big.
s = render(rb, live = False)
assert(len(s.splitlines()) < 30)
def test_rbtree_big(self):
@ -32,44 +48,116 @@ class TestRBTree:
rb = RBTree()
j = 500
for i in xrange(j):
rb.insert(RBNode(None, i, i+1))
# show the graph
if render:
rb.render_dot_live("in-order insert")
rb.insert(RBNode(i, i+1))
render(rb, "in-order insert")
# remove about half of them
for i in random.sample(xrange(j),j):
if random.randint(0,1):
rb.delete(rb.find(i, i+1))
# show the graph
if render:
rb.render_dot_live("in-order insert, random delete")
render(rb, "in-order insert, random delete")
# make a set of 500 intervals, inserted at random
rb = RBTree()
j = 500
for i in random.sample(xrange(j),j):
rb.insert(RBNode(None, i, i+1))
# show the graph
if render:
rb.render_dot_live("random insert")
rb.insert(RBNode(i, i+1))
render(rb, "random insert")
# remove about half of them
for i in random.sample(xrange(j),j):
if random.randint(0,1):
rb.delete(rb.find(i, i+1))
# show the graph
if render:
rb.render_dot_live("random insert, random delete")
render(rb, "random insert, random delete")
# in-order insert of 250 more
for i in xrange(250):
rb.insert(RBNode(None, i+500, i+501))
rb.insert(RBNode(i+500, i+501))
render(rb, "random insert, random delete, in-order insert")
# show the graph
if render:
rb.render_dot_live("random insert, random delete, in-order insert")
def test_rbtree_basics(self):
rb = RBTree()
vals = [ 7, 14, 1, 2, 8, 11, 5, 15, 4]
for n in vals:
rb.insert(RBNode(n, n))
# stringify
s = ""
for node in rb:
s += str(node)
assert "[node (None) 1 -> 1 B]" in s
assert str(rb.nil) == "[node nil]"
# inorder traversal, successor and predecessor
last = 0
for node in rb:
assert(node.start > last)
last = node.start
successor = rb.successor(node)
if successor:
assert(rb.predecessor(successor) is node)
predecessor = rb.predecessor(node)
if predecessor:
assert(rb.successor(predecessor) is node)
# Delete node not in the tree
with assert_raises(AttributeError):
rb.delete(RBNode(1,2))
# Delete all nodes!
for node in rb:
rb.delete(node)
# Build it up again, make sure it matches
for n in vals:
rb.insert(RBNode(n, n))
s2 = ""
for node in rb:
s2 += str(node)
assert(s == s2)
def test_rbtree_find(self):
# Get a little bit of coverage for some overlapping cases,
# even though the class doesn't fully support it.
rb = RBTree()
nodes = [ RBNode(1, 5), RBNode(1, 10), RBNode(1, 15) ]
for n in nodes:
rb.insert(n)
assert(rb.find(1, 5) is nodes[0])
assert(rb.find(1, 10) is nodes[1])
assert(rb.find(1, 15) is nodes[2])
def test_rbtree_find_leftright(self):
# Now let's get some ranges in there
rb = RBTree()
vals = [ 7, 14, 1, 2, 8, 11, 5, 15, 4]
for n in vals:
rb.insert(RBNode(n*10, n*10+5))
# Check find_end_left, find_right_start
for i in range(160):
left = rb.find_left_end(i)
right = rb.find_right_start(i)
if left:
# endpoint should be more than i
assert(left.end >= i)
# all earlier nodes should have a lower endpoint
for node in rb:
if node is left:
break
assert(node.end < i)
if right:
# startpoint should be less than i
assert(right.start <= i)
# all later nodes should have a higher startpoint
for node in reversed(list(rb)):
if node is right:
break
assert(node.start > i)
def test_rbtree_intersect(self):
# Fill with some ranges again
rb = RBTree()
vals = [ 7, 14, 1, 2, 8, 11, 5, 15, 4]
for n in vals:
rb.insert(RBNode(n*10, n*10+5))