Converted rbtree.py to Cython

About 3x faster
This commit is contained in:
Jim Paris 2012-11-29 01:25:51 -05:00
parent f5c60f68dc
commit 99ec0f4946
4 changed files with 56 additions and 39 deletions

View File

@ -1,2 +1,4 @@
sudo apt-get install python-nose python-coverage
sudo apt-get install python-tables cython python-cherrypy3
sudo apt-get install python-tables python-cherrypy3
sudo apt-get install cython # 0.17.1-1 or newer

View File

@ -1,4 +1,10 @@
"""Red-black tree, where keys are stored as start/end timestamps.
# cython: profile=False
# cython: cdivision=True
"""
Jim Paris <jim@jtan.com>
Red-black tree, where keys are stored as start/end timestamps.
This is a basic interval tree that holds half-open intervals:
[start, end)
Intervals must not overlap. Fixing that would involve making this
@ -10,17 +16,21 @@ string 'non-overlapping'.
import sys
class RBNode(object):
cdef class RBNode:
"""One node of the Red/Black tree, containing a key (start, end)
and value (obj)"""
def __init__(self, start, end, obj = None):
cdef public object obj
cdef public double start, end
cdef public int red
cdef public RBNode left, right, parent
def __cinit__(RBNode self, double start, double end, object obj = None):
self.obj = obj
self.start = start
self.end = end
self.red = False
self.left = None
self.right = None
self.nil = False
def __str__(self):
if self.red:
@ -34,11 +44,13 @@ class RBNode(object):
+ str(self.start) + " -> " + str(self.end) + " "
+ color + "]")
class RBTree(object):
cdef class RBTree:
"""Red/Black tree"""
cdef public RBNode nil, root
# Init
def __init__(self):
def __cinit__(RBTree self):
self.nil = RBNode(start = sys.float_info.min,
end = sys.float_info.min)
self.nil.left = self.nil
@ -53,11 +65,11 @@ class RBTree(object):
# We have a dummy root node to simplify operations, so from an
# external point of view, its left child is the real root.
def getroot(self):
cpdef getroot(RBTree self):
return self.root.left
# Rotations and basic operations
def __rotate_left(self, x):
cdef void __rotate_left(RBTree self, RBNode x):
"""Rotate left:
# x y
# / \ --> / \
@ -65,7 +77,7 @@ class RBTree(object):
# / \ / \
# v w z v
"""
y = x.right
cdef RBNode y = x.right
x.right = y.left
if y.left is not self.nil:
y.left.parent = x
@ -77,7 +89,7 @@ class RBTree(object):
y.left = x
x.parent = y
def __rotate_right(self, y):
cdef void __rotate_right(RBTree self, RBNode y):
"""Rotate right:
# y x
# / \ --> / \
@ -85,7 +97,7 @@ class RBTree(object):
# / \ / \
# z v v w
"""
x = y.left
cdef RBNode x = y.left
y.left = x.right
if x.right is not self.nil:
x.right.parent = y
@ -97,9 +109,9 @@ class RBTree(object):
x.right = y
y.parent = x
def __successor(self, x):
cdef RBNode __successor(RBTree self, RBNode x):
"""Returns the successor of RBNode x"""
y = x.right
cdef RBNode y = x.right
if y is not self.nil:
while y.left is not self.nil:
y = y.left
@ -111,14 +123,14 @@ class RBTree(object):
if y is self.root:
return self.nil
return y
def successor(self, x):
cpdef RBNode successor(RBTree self, RBNode x):
"""Returns the successor of RBNode x, or None"""
y = self.__successor(x)
cdef RBNode y = self.__successor(x)
return y if y is not self.nil else None
def __predecessor(self, x):
cdef RBNode __predecessor(RBTree self, RBNode x):
"""Returns the predecessor of RBNode x"""
y = x.left
cdef RBNode y = x.left
if y is not self.nil:
while y.right is not self.nil:
y = y.right
@ -131,18 +143,18 @@ class RBTree(object):
x = y
y = y.parent
return y
def predecessor(self, x):
cpdef RBNode predecessor(RBTree self, RBNode x):
"""Returns the predecessor of RBNode x, or None"""
y = self.__predecessor(x)
cdef RBNode y = self.__predecessor(x)
return y if y is not self.nil else None
# Insertion
def insert(self, z):
cpdef insert(RBTree self, RBNode z):
"""Insert RBNode z into RBTree and rebalance as necessary"""
z.left = self.nil
z.right = self.nil
y = self.root
x = self.root.left
cdef RBNode y = self.root
cdef RBNode x = self.root.left
while x is not self.nil:
y = x
if (x.start > z.start or (x.start == z.start and x.end > z.end)):
@ -158,7 +170,7 @@ class RBTree(object):
# relabel/rebalance
self.__insert_fixup(z)
def __insert_fixup(self, x):
cdef void __insert_fixup(RBTree self, RBNode x):
"""Rebalance/fix RBTree after a simple insertion of RBNode x"""
x.red = True
while x.parent.red:
@ -193,10 +205,11 @@ class RBTree(object):
self.root.left.red = False
# Deletion
def delete(self, z):
cpdef delete(RBTree self, RBNode z):
if z.left is None or z.right is None:
raise AttributeError("you can only delete a node object "
+ "from the tree; use find() to get one")
cdef RBNode x, y
if z.left is self.nil or z.right is self.nil:
y = z
else:
@ -233,10 +246,10 @@ class RBTree(object):
if not y.red:
self.__delete_fixup(x)
def __delete_fixup(self, x):
cdef void __delete_fixup(RBTree self, RBNode x):
"""Rebalance/fix RBTree after a deletion. RBNode x is the
child of the spliced out node."""
rootLeft = self.root.left
cdef RBNode rootLeft = self.root.left
while not x.red and x is not rootLeft:
if x is x.parent.left:
w = x.parent.right
@ -283,10 +296,10 @@ class RBTree(object):
x.red = False
# Walking, searching
def __iter__(self):
def __iter__(RBTree self):
return self.inorder()
def inorder(self, x = None):
def inorder(RBTree self, RBNode x = None):
"""Generator that performs an inorder walk for the tree
rooted at RBNode x"""
if x is None:
@ -297,9 +310,9 @@ class RBTree(object):
yield x
x = self.__successor(x)
def find(self, start, end):
cpdef RBNode find(RBTree self, double start, double end):
"""Return the node with exactly the given start and end."""
x = self.getroot()
cdef RBNode x = self.getroot()
while x is not self.nil:
if start < x.start:
x = x.left
@ -314,14 +327,14 @@ class RBTree(object):
x = x.right
return x if x is not self.nil else None
def find_left_end(self, t):
cpdef RBNode find_left_end(RBTree self, double t):
"""Find the leftmode node with end >= t. With non-overlapping
intervals, this is the first node that might overlap time t.
Note that this relies on non-overlapping intervals, since
it assumes that we can use the endpoints to traverse the
tree even though it was created using the start points."""
x = self.getroot()
cdef RBNode x = self.getroot()
while x is not self.nil:
if t < x.end:
if x.left is self.nil:
@ -336,10 +349,10 @@ class RBTree(object):
x = x.right
return x if x is not self.nil else None
def find_right_start(self, t):
cpdef RBNode find_right_start(RBTree self, double t):
"""Find the rightmode node with start <= t. With non-overlapping
intervals, this is the last node that might overlap time t."""
x = self.getroot()
cdef RBNode x = self.getroot()
while x is not self.nil:
if t < x.start:
if x.left is self.nil:
@ -355,11 +368,11 @@ class RBTree(object):
return x if x is not self.nil else None
# Intersections
def intersect(self, start, end):
def intersect(RBTree self, double start, double end):
"""Generator that returns nodes that overlap the given
(start,end) range. Assumes non-overlapping intervals."""
# Start with the leftmode node that ends after start
n = self.find_left_end(start)
cdef RBNode n = self.find_left_end(start)
while n is not None:
if n.start >= end:
# this node starts after the requested end; we're done

View File

@ -12,6 +12,8 @@ stop=
verbosity=2
#tests=tests/test_cmdline.py
#tests=tests/test_layout.py
#tests=tests/test_rbtree.py
#tests=tests/test_interval.py
tests=tests/test_rbtree.py,tests/test_interval.py
#tests=tests/test_interval.py
#tests=tests/test_client.py

View File

@ -74,8 +74,8 @@ class TestRBTree:
s = ""
for node in rb:
s += str(node)
assert "[node (None) 1 -> 1 B]" in s
assert str(rb.nil) == "[node nil]"
in_("[node (None) 1", s)
eq_(str(rb.nil), "[node nil]")
# inorder traversal, successor and predecessor
last = 0