Filling out rbtree tests, search routines
This commit is contained in:
parent
66fa6f3824
commit
0b443f510b
176
nilmdb/rbtree.py
176
nilmdb/rbtree.py
|
@ -3,6 +3,9 @@ This is a basic interval tree that holds half-open intervals:
|
|||
[start, end)
|
||||
Intervals must not overlap. Fixing that would involve making this
|
||||
into an augmented interval tree as described in CLRS 14.3.
|
||||
|
||||
Code that assumes non-overlapping intervals is marked with the
|
||||
string 'non-overlapping'.
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
@ -17,14 +20,17 @@ class RBNode(object):
|
|||
self.red = False
|
||||
self.left = None
|
||||
self.right = None
|
||||
self.nil = False
|
||||
|
||||
def __str__(self):
|
||||
if self.red:
|
||||
color = "R"
|
||||
else:
|
||||
color = "B"
|
||||
return ("[node "
|
||||
+ str(obj) + " "
|
||||
if self.start == sys.float_info.min:
|
||||
return "[node nil]"
|
||||
return ("[node ("
|
||||
+ str(self.obj) + ") "
|
||||
+ str(self.start) + " -> " + str(self.end) + " "
|
||||
+ color + "]")
|
||||
|
||||
|
@ -38,7 +44,6 @@ class RBTree(object):
|
|||
self.nil.left = self.nil
|
||||
self.nil.right = self.nil
|
||||
self.nil.parent = self.nil
|
||||
self.nil.nil = True
|
||||
|
||||
self.root = RBNode(start = sys.float_info.max,
|
||||
end = sys.float_info.max)
|
||||
|
@ -46,6 +51,11 @@ class RBTree(object):
|
|||
self.root.right = self.nil
|
||||
self.root.parent = self.nil
|
||||
|
||||
# We have a dummy root node to simplify operations, so from an
|
||||
# external point of view, its left child is the real root.
|
||||
def getroot(self):
|
||||
return self.root.left
|
||||
|
||||
# Rotations and basic operations
|
||||
def __rotate_left(self, x):
|
||||
"""Rotate left:
|
||||
|
@ -88,7 +98,7 @@ class RBTree(object):
|
|||
y.parent = x
|
||||
|
||||
def __successor(self, x):
|
||||
"""Returns the successor of x"""
|
||||
"""Returns the successor of RBNode x"""
|
||||
y = x.right
|
||||
if y is not self.nil:
|
||||
while y.left is not self.nil:
|
||||
|
@ -101,6 +111,10 @@ class RBTree(object):
|
|||
if y is self.root:
|
||||
return self.nil
|
||||
return y
|
||||
def successor(self, x):
|
||||
"""Returns the successor of RBNode x, or None"""
|
||||
y = self.__successor(x)
|
||||
return y if y is not self.nil else None
|
||||
|
||||
def __predecessor(self, x):
|
||||
"""Returns the predecessor of RBNode x"""
|
||||
|
@ -117,6 +131,10 @@ class RBTree(object):
|
|||
x = y
|
||||
y = y.parent
|
||||
return y
|
||||
def predecessor(self, x):
|
||||
"""Returns the predecessor of RBNode x, or None"""
|
||||
y = self.__predecessor(x)
|
||||
return y if y is not self.nil else None
|
||||
|
||||
# Insertion
|
||||
def insert(self, z):
|
||||
|
@ -264,129 +282,83 @@ class RBTree(object):
|
|||
x = rootLeft # exit loop
|
||||
x.red = False
|
||||
|
||||
# Rendering
|
||||
def __render_dot_node(self, node, max_depth = 20):
|
||||
from printf import sprintf
|
||||
"""Render a single node and its children into a dot graph fragment"""
|
||||
if max_depth == 0:
|
||||
return ""
|
||||
if node is self.nil:
|
||||
return ""
|
||||
def c(red):
|
||||
if red:
|
||||
return 'color="#ff0000", style=filled, fillcolor="#ffc0c0"'
|
||||
else:
|
||||
return 'color="#000000", style=filled, fillcolor="#c0c0c0"'
|
||||
s = sprintf("%d [label=\"%g\\n%g\", %s];\n",
|
||||
id(node),
|
||||
node.start, node.end,
|
||||
c(node.red))
|
||||
|
||||
if node.left is self.nil:
|
||||
s += sprintf("L%d [label=\"-\", %s];\n", id(node), c(False))
|
||||
s += sprintf("%d -> L%d [label=L];\n", id(node), id(node))
|
||||
else:
|
||||
s += sprintf("%d -> %d [label=L];\n", id(node), id(node.left))
|
||||
if node.right is self.nil:
|
||||
s += sprintf("R%d [label=\"-\", %s];\n", id(node), c(False))
|
||||
s += sprintf("%d -> R%d [label=R];\n", id(node), id(node))
|
||||
else:
|
||||
s += sprintf("%d -> %d [label=R];\n", id(node), id(node.right))
|
||||
s += self.__render_dot_node(node.left, max_depth-1)
|
||||
s += self.__render_dot_node(node.right, max_depth-1)
|
||||
return s
|
||||
|
||||
def render_dot(self, title = "RBTree"):
|
||||
"""Render the entire RBTree as a dot graph"""
|
||||
return ("digraph rbtree {\n"
|
||||
+ self.__render_dot_node(self.root.left)
|
||||
+ "}\n");
|
||||
|
||||
def render_dot_live(self, title = "RBTree"):
|
||||
"""Render the entire RBTree as a dot graph, live GTK view"""
|
||||
import gtk
|
||||
import gtk.gdk
|
||||
sys.path.append("/usr/share/xdot")
|
||||
import xdot
|
||||
xdot.Pen.highlighted = lambda pen: pen
|
||||
s = ("digraph rbtree {\n"
|
||||
+ self.__render_dot_node(self.root)
|
||||
+ "}\n");
|
||||
window = xdot.DotWindow()
|
||||
window.set_dotcode(s)
|
||||
window.set_title(title + " - any key to close")
|
||||
window.connect('destroy', gtk.main_quit)
|
||||
def quit(widget, event):
|
||||
if not event.is_modifier:
|
||||
window.destroy()
|
||||
gtk.main_quit()
|
||||
window.widget.connect('key-press-event', quit)
|
||||
gtk.main()
|
||||
|
||||
# Walking, searching
|
||||
def __iter__(self):
|
||||
return self.inorder(self.root.left)
|
||||
return self.inorder()
|
||||
|
||||
def inorder(self, x = None):
|
||||
"""Generator that performs an inorder walk for the tree
|
||||
starting at RBNode x"""
|
||||
rooted at RBNode x"""
|
||||
if x is None:
|
||||
x = self.root.left
|
||||
x = self.getroot()
|
||||
while x.left is not self.nil:
|
||||
x = x.left
|
||||
while x is not self.nil:
|
||||
yield x
|
||||
x = self.__successor(x)
|
||||
|
||||
def __find_all(self, start, end, x):
|
||||
"""Find node with the specified (start,end) key.
|
||||
Also returns the largest node less than or equal to key,
|
||||
and the smallest node greater or equal to than key."""
|
||||
if x is None:
|
||||
x = self.root.left
|
||||
largest = self.nil
|
||||
smallest = self.nil
|
||||
def find(self, start, end):
|
||||
"""Return the node with exactly the given start and end."""
|
||||
x = self.getroot()
|
||||
while x is not self.nil:
|
||||
if start < x.start:
|
||||
smallest = x
|
||||
x = x.left # start <
|
||||
x = x.left
|
||||
elif start == x.start:
|
||||
if end < x.end:
|
||||
smallest = x
|
||||
x = x.left # start =, end <
|
||||
elif end == x.end: # found it
|
||||
smallest = x
|
||||
largest = x
|
||||
break
|
||||
if end == x.end:
|
||||
break # found it
|
||||
elif end < x.end:
|
||||
x = x.left
|
||||
else:
|
||||
largest = x
|
||||
x = x.right # start =, end >
|
||||
x = x.right
|
||||
else:
|
||||
largest = x
|
||||
x = x.right # start >
|
||||
return (x, smallest, largest)
|
||||
x = x.right
|
||||
return x if x is not self.nil else None
|
||||
|
||||
def find(self, start, end, x = None):
|
||||
"""Find node with the key == (start,end), or None"""
|
||||
y = self.__find_all(start, end, x)[1]
|
||||
return y if y is not self.nil else None
|
||||
def find_left_end(self, t):
|
||||
"""Find the leftmode node with end >= t. With non-overlapping
|
||||
intervals, this is the first node that might overlap time t.
|
||||
|
||||
def find_right(self, start, end, x = None):
|
||||
"""Find node with the smallest key >= (start,end), or None"""
|
||||
y = self.__find_all(start, end, x)[1]
|
||||
return y if y is not self.nil else None
|
||||
Note that this relies on non-overlapping intervals, since
|
||||
it assumes that we can use the endpoints to traverse the
|
||||
tree even though it was created using the start points."""
|
||||
x = self.getroot()
|
||||
while x is not self.nil:
|
||||
if t < x.end:
|
||||
if x.left is self.nil:
|
||||
break
|
||||
x = x.left
|
||||
elif t == x.end:
|
||||
break
|
||||
else:
|
||||
if x.right is self.nil:
|
||||
x = self.__successor(x)
|
||||
break
|
||||
x = x.right
|
||||
return x if x is not self.nil else None
|
||||
|
||||
def find_left(self, start, end, x = None):
|
||||
"""Find node with the largest key <= (start,end), or None"""
|
||||
y = self.__find_all(start, end, x)[2]
|
||||
return y if y is not self.nil else None
|
||||
def find_right_start(self, t):
|
||||
"""Find the rightmode node with start <= t. With non-overlapping
|
||||
intervals, this is the last node that might overlap time t."""
|
||||
x = self.getroot()
|
||||
while x is not self.nil:
|
||||
if t < x.start:
|
||||
if x.left is self.nil:
|
||||
x = self.__predecessor(x)
|
||||
break
|
||||
x = x.left
|
||||
elif t == x.start:
|
||||
break
|
||||
else:
|
||||
if x.right is self.nil:
|
||||
break
|
||||
x = x.right
|
||||
return x if x is not self.nil else None
|
||||
|
||||
# Intersections
|
||||
def intersect(self, start, end):
|
||||
"""Generator that returns nodes that overlap the given
|
||||
(start,end) range, for the tree rooted at RBNode x.
|
||||
(start,end) range. Assumes non-overlapping intervals."""
|
||||
|
||||
NOTE: this assumes non-overlapping intervals."""
|
||||
# Start with the leftmost node before the starting point
|
||||
n = self.find_left(start, start)
|
||||
# If we didn't find one, look for the leftmode node before the
|
||||
|
|
|
@ -12,8 +12,8 @@ stop=
|
|||
verbosity=2
|
||||
#tests=tests/test_cmdline.py
|
||||
#tests=tests/test_layout.py
|
||||
#tests=tests/test_rbtree.py
|
||||
tests=tests/test_interval.py
|
||||
tests=tests/test_rbtree.py
|
||||
#tests=tests/test_interval.py
|
||||
#tests=tests/test_client.py
|
||||
#tests=tests/test_timestamper.py
|
||||
#tests=tests/test_serializer.py
|
||||
|
|
|
@ -11,17 +11,33 @@ from nilmdb.rbtree import RBTree, RBNode
|
|||
from test_helpers import *
|
||||
import unittest
|
||||
|
||||
render = False
|
||||
# set to False to skip live renders
|
||||
do_live_renders = False
|
||||
def render(tree, description = "", live = True):
|
||||
import renderdot
|
||||
if not do_live_renders:
|
||||
# If not doing a live render, still render it to a string.
|
||||
live = False
|
||||
r = renderdot.Renderer(lambda node: node.left,
|
||||
lambda node: node.right,
|
||||
lambda node: node.red,
|
||||
lambda node: node.start,
|
||||
lambda node: node.end,
|
||||
tree.nil)
|
||||
if live:
|
||||
r.render_dot_live(tree.getroot(), description)
|
||||
else:
|
||||
return r.render_dot(tree.getroot())
|
||||
|
||||
class TestRBTree:
|
||||
def test_rbtree(self):
|
||||
rb = RBTree()
|
||||
rb.insert(RBNode(None, 10000, 10001))
|
||||
rb.insert(RBNode(None, 10004, 10007))
|
||||
rb.insert(RBNode(None, 10001, 10002))
|
||||
s = rb.render_dot()
|
||||
rb.insert(RBNode(10000, 10001))
|
||||
rb.insert(RBNode(10004, 10007))
|
||||
rb.insert(RBNode(10001, 10002))
|
||||
# There was a typo that gave the RBTree a loop in this case.
|
||||
# Verify that the dot isn't too big.
|
||||
s = render(rb, live = False)
|
||||
assert(len(s.splitlines()) < 30)
|
||||
|
||||
def test_rbtree_big(self):
|
||||
|
@ -32,44 +48,116 @@ class TestRBTree:
|
|||
rb = RBTree()
|
||||
j = 500
|
||||
for i in xrange(j):
|
||||
rb.insert(RBNode(None, i, i+1))
|
||||
|
||||
# show the graph
|
||||
if render:
|
||||
rb.render_dot_live("in-order insert")
|
||||
rb.insert(RBNode(i, i+1))
|
||||
render(rb, "in-order insert")
|
||||
|
||||
# remove about half of them
|
||||
for i in random.sample(xrange(j),j):
|
||||
if random.randint(0,1):
|
||||
rb.delete(rb.find(i, i+1))
|
||||
|
||||
# show the graph
|
||||
if render:
|
||||
rb.render_dot_live("in-order insert, random delete")
|
||||
render(rb, "in-order insert, random delete")
|
||||
|
||||
# make a set of 500 intervals, inserted at random
|
||||
rb = RBTree()
|
||||
j = 500
|
||||
for i in random.sample(xrange(j),j):
|
||||
rb.insert(RBNode(None, i, i+1))
|
||||
|
||||
# show the graph
|
||||
if render:
|
||||
rb.render_dot_live("random insert")
|
||||
rb.insert(RBNode(i, i+1))
|
||||
render(rb, "random insert")
|
||||
|
||||
# remove about half of them
|
||||
for i in random.sample(xrange(j),j):
|
||||
if random.randint(0,1):
|
||||
rb.delete(rb.find(i, i+1))
|
||||
|
||||
# show the graph
|
||||
if render:
|
||||
rb.render_dot_live("random insert, random delete")
|
||||
render(rb, "random insert, random delete")
|
||||
|
||||
# in-order insert of 250 more
|
||||
for i in xrange(250):
|
||||
rb.insert(RBNode(None, i+500, i+501))
|
||||
rb.insert(RBNode(i+500, i+501))
|
||||
render(rb, "random insert, random delete, in-order insert")
|
||||
|
||||
# show the graph
|
||||
if render:
|
||||
rb.render_dot_live("random insert, random delete, in-order insert")
|
||||
def test_rbtree_basics(self):
|
||||
rb = RBTree()
|
||||
vals = [ 7, 14, 1, 2, 8, 11, 5, 15, 4]
|
||||
for n in vals:
|
||||
rb.insert(RBNode(n, n))
|
||||
|
||||
# stringify
|
||||
s = ""
|
||||
for node in rb:
|
||||
s += str(node)
|
||||
assert "[node (None) 1 -> 1 B]" in s
|
||||
assert str(rb.nil) == "[node nil]"
|
||||
|
||||
# inorder traversal, successor and predecessor
|
||||
last = 0
|
||||
for node in rb:
|
||||
assert(node.start > last)
|
||||
last = node.start
|
||||
successor = rb.successor(node)
|
||||
if successor:
|
||||
assert(rb.predecessor(successor) is node)
|
||||
predecessor = rb.predecessor(node)
|
||||
if predecessor:
|
||||
assert(rb.successor(predecessor) is node)
|
||||
|
||||
# Delete node not in the tree
|
||||
with assert_raises(AttributeError):
|
||||
rb.delete(RBNode(1,2))
|
||||
|
||||
# Delete all nodes!
|
||||
for node in rb:
|
||||
rb.delete(node)
|
||||
|
||||
# Build it up again, make sure it matches
|
||||
for n in vals:
|
||||
rb.insert(RBNode(n, n))
|
||||
s2 = ""
|
||||
for node in rb:
|
||||
s2 += str(node)
|
||||
assert(s == s2)
|
||||
|
||||
def test_rbtree_find(self):
|
||||
# Get a little bit of coverage for some overlapping cases,
|
||||
# even though the class doesn't fully support it.
|
||||
rb = RBTree()
|
||||
nodes = [ RBNode(1, 5), RBNode(1, 10), RBNode(1, 15) ]
|
||||
for n in nodes:
|
||||
rb.insert(n)
|
||||
assert(rb.find(1, 5) is nodes[0])
|
||||
assert(rb.find(1, 10) is nodes[1])
|
||||
assert(rb.find(1, 15) is nodes[2])
|
||||
|
||||
def test_rbtree_find_leftright(self):
|
||||
# Now let's get some ranges in there
|
||||
rb = RBTree()
|
||||
vals = [ 7, 14, 1, 2, 8, 11, 5, 15, 4]
|
||||
for n in vals:
|
||||
rb.insert(RBNode(n*10, n*10+5))
|
||||
|
||||
# Check find_end_left, find_right_start
|
||||
for i in range(160):
|
||||
left = rb.find_left_end(i)
|
||||
right = rb.find_right_start(i)
|
||||
if left:
|
||||
# endpoint should be more than i
|
||||
assert(left.end >= i)
|
||||
# all earlier nodes should have a lower endpoint
|
||||
for node in rb:
|
||||
if node is left:
|
||||
break
|
||||
assert(node.end < i)
|
||||
if right:
|
||||
# startpoint should be less than i
|
||||
assert(right.start <= i)
|
||||
# all later nodes should have a higher startpoint
|
||||
for node in reversed(list(rb)):
|
||||
if node is right:
|
||||
break
|
||||
assert(node.start > i)
|
||||
|
||||
def test_rbtree_intersect(self):
|
||||
# Fill with some ranges again
|
||||
rb = RBTree()
|
||||
vals = [ 7, 14, 1, 2, 8, 11, 5, 15, 4]
|
||||
for n in vals:
|
||||
rb.insert(RBNode(n*10, n*10+5))
|
||||
|
|
Loading…
Reference in New Issue
Block a user