Compare commits
14 Commits
bxinterval
...
bxinterval
Author | SHA1 | Date | |
---|---|---|---|
9b9f392d43 | |||
3c441de498 | |||
7211217f40 | |||
d34b980516 | |||
6aee52d980 | |||
090c8d5315 | |||
1042ff9f4b | |||
bc687969c1 | |||
de27bd3f41 | |||
4dcf713d0e | |||
f9dea53c24 | |||
6cedd7c327 | |||
6278d32f7d | |||
991039903c |
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
.coverage
|
||||
*.pyc
|
@@ -1,26 +1,7 @@
|
||||
# cython: profile=False
|
||||
# This is from bx-python 554:07aca5a9f6fc (BSD licensed), modified to
|
||||
# store interval ranges as doubles rather than 32-bit integers.
|
||||
|
||||
"""
|
||||
Data structure for performing intersect queries on a set of intervals which
|
||||
preserves all information about the intervals (unlike bitset projection methods).
|
||||
|
||||
:Authors: James Taylor (james@jamestaylor.org),
|
||||
Ian Schenk (ian.schenck@gmail.com),
|
||||
Brent Pedersen (bpederse@gmail.com)
|
||||
"""
|
||||
|
||||
# Historical note:
|
||||
# This module original contained an implementation based on sorted endpoints
|
||||
# and a binary search, using an idea from Scott Schwartz and Piotr Berman.
|
||||
# Later an interval tree implementation was implemented by Ian for Galaxy's
|
||||
# join tool (see `bx.intervals.operations.quicksect.py`). This was then
|
||||
# converted to Cython by Brent, who also added support for
|
||||
# upstream/downstream/neighbor queries. This was modified by James to
|
||||
# handle half-open intervals strictly, to maintain sort order, and to
|
||||
# implement the same interface as the original Intersecter.
|
||||
|
||||
# This is based on bxintersect in bx-python 554:07aca5a9f6fc (BSD licensed);
|
||||
# modified to store interval ranges as doubles rather than 32-bit integers,
|
||||
# use fully closed intervals, support deletion, etc.
|
||||
#cython: cdivision=True
|
||||
|
||||
import operator
|
||||
@@ -194,76 +175,6 @@ cdef class IntervalNode:
|
||||
self.cright._intersect( start, end, results )
|
||||
|
||||
|
||||
cdef void _seek_left(IntervalNode self, double position, list results, int n, double max_dist):
|
||||
# we know we can bail in these 2 cases.
|
||||
if self.maxend + max_dist < position:
|
||||
return
|
||||
if self.minstart > position:
|
||||
return
|
||||
|
||||
# the ordering of these 3 blocks makes it so the results are
|
||||
# ordered nearest to farest from the query position
|
||||
if self.cright is not EmptyNode:
|
||||
self.cright._seek_left(position, results, n, max_dist)
|
||||
|
||||
if -1 < position - self.end < max_dist:
|
||||
results.append(self.interval)
|
||||
|
||||
# TODO: can these conditionals be more stringent?
|
||||
if self.cleft is not EmptyNode:
|
||||
self.cleft._seek_left(position, results, n, max_dist)
|
||||
|
||||
|
||||
|
||||
cdef void _seek_right(IntervalNode self, double position, list results, int n, double max_dist):
|
||||
# we know we can bail in these 2 cases.
|
||||
if self.maxend < position: return
|
||||
if self.minstart - max_dist > position: return
|
||||
|
||||
#print "SEEK_RIGHT:",self, self.cleft, self.maxend, self.minstart, position
|
||||
|
||||
# the ordering of these 3 blocks makes it so the results are
|
||||
# ordered nearest to farest from the query position
|
||||
if self.cleft is not EmptyNode:
|
||||
self.cleft._seek_right(position, results, n, max_dist)
|
||||
|
||||
if -1 < self.start - position < max_dist:
|
||||
results.append(self.interval)
|
||||
|
||||
if self.cright is not EmptyNode:
|
||||
self.cright._seek_right(position, results, n, max_dist)
|
||||
|
||||
|
||||
cpdef left(self, position, int n=1, double max_dist=2500):
|
||||
"""
|
||||
find n features with a start > than `position`
|
||||
f: a Interval object (or anything with an `end` attribute)
|
||||
n: the number of features to return
|
||||
max_dist: the maximum distance to look before giving up.
|
||||
"""
|
||||
cdef list results = []
|
||||
# use start - 1 becuase .left() assumes strictly left-of
|
||||
self._seek_left( position - 1, results, n, max_dist )
|
||||
if len(results) == n: return results
|
||||
r = results
|
||||
r.sort(key=operator.attrgetter('end'), reverse=True)
|
||||
return r[:n]
|
||||
|
||||
cpdef right(self, position, int n=1, double max_dist=2500):
|
||||
"""
|
||||
find n features with a end < than position
|
||||
f: a Interval object (or anything with a `start` attribute)
|
||||
n: the number of features to return
|
||||
max_dist: the maximum distance to look before giving up.
|
||||
"""
|
||||
cdef list results = []
|
||||
# use end + 1 becuase .right() assumes strictly right-of
|
||||
self._seek_right(position + 1, results, n, max_dist)
|
||||
if len(results) == n: return results
|
||||
r = results
|
||||
r.sort(key=operator.attrgetter('start'))
|
||||
return r[:n]
|
||||
|
||||
def traverse(self):
|
||||
if self.cleft is not EmptyNode:
|
||||
for node in self.cleft.traverse():
|
||||
@@ -392,6 +303,7 @@ cdef class IntervalTree:
|
||||
|
||||
# ---- Position based interfaces -----------------------------------------
|
||||
|
||||
## KEEP
|
||||
def insert( self, double start, double end, object value=None ):
|
||||
"""
|
||||
Insert the interval [start,end) associated with value `value`.
|
||||
@@ -401,8 +313,14 @@ cdef class IntervalTree:
|
||||
else:
|
||||
self.root = self.root.insert( start, end, value )
|
||||
|
||||
add = insert
|
||||
|
||||
def delete( self, double start, double end, object value=None ):
|
||||
"""
|
||||
Delete the interval [start,end) associated with value `value`.
|
||||
"""
|
||||
if self.root is None:
|
||||
self.root = IntervalNode( start, end, value )
|
||||
else:
|
||||
self.root = self.root.insert( start, end, value )
|
||||
|
||||
def find( self, start, end ):
|
||||
"""
|
||||
@@ -412,26 +330,9 @@ cdef class IntervalTree:
|
||||
return []
|
||||
return self.root.find( start, end )
|
||||
|
||||
def before( self, position, num_intervals=1, max_dist=2500 ):
|
||||
"""
|
||||
Find `num_intervals` intervals that lie before `position` and are no
|
||||
further than `max_dist` positions away
|
||||
"""
|
||||
if self.root is None:
|
||||
return []
|
||||
return self.root.left( position, num_intervals, max_dist )
|
||||
|
||||
def after( self, position, num_intervals=1, max_dist=2500 ):
|
||||
"""
|
||||
Find `num_intervals` intervals that lie after `position` and are no
|
||||
further than `max_dist` positions away
|
||||
"""
|
||||
if self.root is None:
|
||||
return []
|
||||
return self.root.right( position, num_intervals, max_dist )
|
||||
|
||||
# ---- Interval-like object based interfaces -----------------------------
|
||||
|
||||
## KEEP
|
||||
def insert_interval( self, interval ):
|
||||
"""
|
||||
Insert an "interval" like object (one with at least start and end
|
||||
@@ -439,50 +340,6 @@ cdef class IntervalTree:
|
||||
"""
|
||||
self.insert( interval.start, interval.end, interval )
|
||||
|
||||
add_interval = insert_interval
|
||||
|
||||
def before_interval( self, interval, num_intervals=1, max_dist=2500 ):
|
||||
"""
|
||||
Find `num_intervals` intervals that lie completely before `interval`
|
||||
and are no further than `max_dist` positions away
|
||||
"""
|
||||
if self.root is None:
|
||||
return []
|
||||
return self.root.left( interval.start, num_intervals, max_dist )
|
||||
|
||||
def after_interval( self, interval, num_intervals=1, max_dist=2500 ):
|
||||
"""
|
||||
Find `num_intervals` intervals that lie completely after `interval` and
|
||||
are no further than `max_dist` positions away
|
||||
"""
|
||||
if self.root is None:
|
||||
return []
|
||||
return self.root.right( interval.end, num_intervals, max_dist )
|
||||
|
||||
def upstream_of_interval( self, interval, num_intervals=1, max_dist=2500 ):
|
||||
"""
|
||||
Find `num_intervals` intervals that lie completely upstream of
|
||||
`interval` and are no further than `max_dist` positions away
|
||||
"""
|
||||
if self.root is None:
|
||||
return []
|
||||
if interval.strand == -1 or interval.strand == "-":
|
||||
return self.root.right( interval.end, num_intervals, max_dist )
|
||||
else:
|
||||
return self.root.left( interval.start, num_intervals, max_dist )
|
||||
|
||||
def downstream_of_interval( self, interval, num_intervals=1, max_dist=2500 ):
|
||||
"""
|
||||
Find `num_intervals` intervals that lie completely downstream of
|
||||
`interval` and are no further than `max_dist` positions away
|
||||
"""
|
||||
if self.root is None:
|
||||
return []
|
||||
if interval.strand == -1 or interval.strand == "-":
|
||||
return self.root.left( interval.start, num_intervals, max_dist )
|
||||
else:
|
||||
return self.root.right( interval.end, num_intervals, max_dist )
|
||||
|
||||
def traverse(self):
|
||||
"""
|
||||
iterator that traverses the tree
|
||||
|
@@ -8,20 +8,25 @@ Intervals are closed, ie. they include timestamps [start, end]
|
||||
# First implementation kept a sorted list of intervals and used
|
||||
# biesct() to optimize some operations, but this was too slow.
|
||||
|
||||
# This version is based on the quicksect implementation from python-bx,
|
||||
# modified slightly to handle floating point intervals.
|
||||
# Second version was based on the quicksect implementation from
|
||||
# python-bx, modified slightly to handle floating point intervals.
|
||||
# This didn't support deletion.
|
||||
|
||||
# Third version is more similar to the first version, using a rb-tree
|
||||
# instead of a simple sorted list to maintain O(log n) operations.
|
||||
|
||||
# Fourth version is an optimized rb-tree that stores interval starts
|
||||
# and ends directly in the tree, like bxinterval did.
|
||||
|
||||
# Fifth version is back to modified bxintersect...
|
||||
|
||||
import pyximport
|
||||
pyximport.install()
|
||||
import bxintersect
|
||||
|
||||
import bisect
|
||||
|
||||
class IntervalError(Exception):
|
||||
"""Error due to interval overlap, etc"""
|
||||
pass
|
||||
|
||||
class Interval(bxintersect.Interval):
|
||||
class Interval(object):
|
||||
"""Represents an interval of time."""
|
||||
|
||||
def __init__(self, start, end):
|
||||
@@ -30,7 +35,8 @@ class Interval(bxintersect.Interval):
|
||||
"""
|
||||
if start > end:
|
||||
raise IntervalError("start %s must precede end %s" % (start, end))
|
||||
bxintersect.Interval.__init__(self, start, end)
|
||||
self.start = float(start)
|
||||
self.end = float(end)
|
||||
|
||||
def __repr__(self):
|
||||
s = repr(self.start) + ", " + repr(self.end)
|
||||
@@ -39,6 +45,20 @@ class Interval(bxintersect.Interval):
|
||||
def __str__(self):
|
||||
return "[" + str(self.start) + " -> " + str(self.end) + "]"
|
||||
|
||||
def __cmp__(self, other):
|
||||
"""Compare two intervals. If non-equal, order by start then end"""
|
||||
if not isinstance(other, Interval):
|
||||
raise TypeError("bad type")
|
||||
if self.start == other.start:
|
||||
if self.end < other.end:
|
||||
return -1
|
||||
if self.end > other.end:
|
||||
return 1
|
||||
return 0
|
||||
if self.start < other.start:
|
||||
return -1
|
||||
return 1
|
||||
|
||||
def intersects(self, other):
|
||||
"""Return True if two Interval objects intersect"""
|
||||
if (self.end <= other.start or self.start >= other.end):
|
||||
@@ -66,6 +86,7 @@ class DBInterval(Interval):
|
||||
end = 150
|
||||
db_end = 200, db_endpos = 20000
|
||||
"""
|
||||
|
||||
def __init__(self, start, end,
|
||||
db_start, db_end,
|
||||
db_startpos, db_endpos):
|
||||
@@ -109,12 +130,14 @@ class IntervalSet(object):
|
||||
"""
|
||||
'source' is an Interval or IntervalSet to add.
|
||||
"""
|
||||
self.tree = bxintersect.IntervalTree()
|
||||
self.tree = bxinterval.IntervalTree()
|
||||
if source is not None:
|
||||
self += source
|
||||
|
||||
def __iter__(self):
|
||||
return self.tree.traverse()
|
||||
for node in self.tree:
|
||||
if node.obj:
|
||||
yield node.obj
|
||||
|
||||
def __len__(self):
|
||||
return sum(1 for x in self)
|
||||
@@ -195,6 +218,17 @@ class IntervalSet(object):
|
||||
self.__iadd__(x)
|
||||
return self
|
||||
|
||||
def __isub__(self, other):
|
||||
"""Inplace subtract -- modifies self
|
||||
|
||||
Removes an interval from the set. Must exist exactly
|
||||
as provided -- cannot remove a subset of an existing interval."""
|
||||
i = self.tree.find(other.start, other.end)
|
||||
if i is None:
|
||||
raise IntervalError("interval " + str(other) + " not in tree")
|
||||
self.tree.delete(i)
|
||||
return self
|
||||
|
||||
def __add__(self, other):
|
||||
"""Add -- returns a new object"""
|
||||
new = IntervalSet(self)
|
||||
@@ -211,11 +245,12 @@ class IntervalSet(object):
|
||||
out = IntervalSet()
|
||||
|
||||
if not isinstance(other, IntervalSet):
|
||||
other = [ other ]
|
||||
|
||||
for x in other:
|
||||
for i in self.intersection(x):
|
||||
out.tree.insert_interval(i)
|
||||
for i in self.intersection(other):
|
||||
out.tree.insert(rbtree.RBNode(i))
|
||||
else:
|
||||
for x in other:
|
||||
for i in self.intersection(x):
|
||||
out.tree.insert(rbtree.RBNode(i))
|
||||
|
||||
return out
|
||||
|
||||
@@ -229,13 +264,30 @@ class IntervalSet(object):
|
||||
Output intervals are built as subsets of the intervals in the
|
||||
first argument (self).
|
||||
"""
|
||||
for i in self.tree.find(interval.start, interval.end):
|
||||
if i.start > interval.start and i.end < interval.end:
|
||||
yield i
|
||||
else:
|
||||
yield i.subset(max(i.start, interval.start),
|
||||
min(i.end, interval.end))
|
||||
if not isinstance(interval, Interval):
|
||||
raise TypeError("bad type")
|
||||
for n in self.tree.intersect(interval.start, interval.end):
|
||||
i = n.obj
|
||||
if i:
|
||||
if i.start >= interval.start and i.end <= interval.end:
|
||||
yield i
|
||||
elif i.start > interval.end:
|
||||
break
|
||||
else:
|
||||
subset = i.subset(max(i.start, interval.start),
|
||||
min(i.end, interval.end))
|
||||
yield subset
|
||||
|
||||
def intersects(self, other):
|
||||
### PROBABLY WRONG
|
||||
"""Return True if this IntervalSet intersects another interval"""
|
||||
return len(self.tree.find(other.start, other.end)) > 0
|
||||
node = self.tree.find_left(other.start, other.end)
|
||||
if node is None:
|
||||
return False
|
||||
for n in self.tree.inorder(node):
|
||||
if n.obj:
|
||||
if n.obj.intersects(other):
|
||||
return True
|
||||
if n.obj > other:
|
||||
break
|
||||
return False
|
||||
|
392
nilmdb/rbtree.py
Normal file
392
nilmdb/rbtree.py
Normal file
@@ -0,0 +1,392 @@
|
||||
"""Red-black tree, where keys are stored as start/end timestamps."""
|
||||
|
||||
import sys
|
||||
|
||||
class RBNode(object):
|
||||
"""One node of the Red/Black tree. obj points to any object,
|
||||
'start' and 'end' are timestamps that represent the key."""
|
||||
def __init__(self, obj = None, start = None, end = None):
|
||||
"""If given an object but no start/end times, get the
|
||||
start/end times from the object.
|
||||
|
||||
If given start/end times, obj can be anything, including None."""
|
||||
self.obj = obj
|
||||
if start is None:
|
||||
start = obj.start
|
||||
if end is None:
|
||||
end = obj.end
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.red = False
|
||||
self.left = None
|
||||
self.right = None
|
||||
|
||||
def __str__(self):
|
||||
if self.red:
|
||||
color = "R"
|
||||
else:
|
||||
color = "B"
|
||||
return ("[node "
|
||||
+ str(self.start) + " -> " + str(self.end) + " "
|
||||
+ color + "]")
|
||||
|
||||
class RBTree(object):
|
||||
"""Red/Black tree"""
|
||||
|
||||
# Init
|
||||
def __init__(self):
|
||||
self.nil = RBNode(start = sys.float_info.min,
|
||||
end = sys.float_info.min)
|
||||
self.nil.left = self.nil
|
||||
self.nil.right = self.nil
|
||||
self.nil.parent = self.nil
|
||||
self.nil.nil = True
|
||||
|
||||
self.root = RBNode(start = sys.float_info.max,
|
||||
end = sys.float_info.max)
|
||||
self.root.left = self.nil
|
||||
self.root.right = self.nil
|
||||
self.root.parent = self.nil
|
||||
|
||||
# Rotations and basic operations
|
||||
def __rotate_left(self, x):
|
||||
y = x.right
|
||||
x.right = y.left
|
||||
if y.left is not self.nil:
|
||||
y.left.parent = x
|
||||
y.parent = x.parent
|
||||
if x is x.parent.left:
|
||||
x.parent.left = y
|
||||
else:
|
||||
x.parent.right = y
|
||||
y.left = x
|
||||
x.parent = y
|
||||
|
||||
def __rotate_right(self, y):
|
||||
x = y.left
|
||||
y.left = x.right
|
||||
if x.right is not self.nil:
|
||||
x.right.parent = y
|
||||
x.parent = y.parent
|
||||
if y is y.parent.left:
|
||||
y.parent.left = x
|
||||
else:
|
||||
y.parent.right = x
|
||||
x.right = y
|
||||
y.parent = x
|
||||
|
||||
def __successor(self, x):
|
||||
"""Returns the successor of RBNode x"""
|
||||
y = x.right
|
||||
if y is not self.nil:
|
||||
while y.left is not self.nil:
|
||||
y = y.left
|
||||
else:
|
||||
y = x.parent
|
||||
while x is y.right:
|
||||
x = y
|
||||
y = y.parent
|
||||
if y is self.root:
|
||||
return self.nil
|
||||
return y
|
||||
|
||||
def _predecessor(self, x):
|
||||
"""Returns the predecessor of RBNode x"""
|
||||
y = x.left
|
||||
if y is not self.nil:
|
||||
while y.right is not self.nil:
|
||||
y = y.right
|
||||
else:
|
||||
y = x.parent
|
||||
while x is y.left:
|
||||
if y is self.root:
|
||||
y = self.nil
|
||||
break
|
||||
x = y
|
||||
y = y.parent
|
||||
return y
|
||||
|
||||
# Insertion
|
||||
def insert(self, z):
|
||||
"""Insert RBNode z into RBTree and rebalance as necessary"""
|
||||
z.left = self.nil
|
||||
z.right = self.nil
|
||||
y = self.root
|
||||
x = self.root.left
|
||||
while x is not self.nil:
|
||||
y = x
|
||||
if (x.start > z.start or (x.start == z.start and x.end > z.end)):
|
||||
x = x.left
|
||||
else:
|
||||
x = x.right
|
||||
z.parent = y
|
||||
if (y is self.root or
|
||||
(y.start > z.start or (y.start == z.start and y.end > z.end))):
|
||||
y.left = z
|
||||
else:
|
||||
y.right = z
|
||||
# relabel/rebalance
|
||||
self.__insert_fixup(z)
|
||||
|
||||
def __insert_fixup(self, x):
|
||||
"""Rebalance/fix RBTree after a simple insertion of RBNode x"""
|
||||
x.red = True
|
||||
while x.parent.red:
|
||||
if x.parent is x.parent.parent.left:
|
||||
y = x.parent.parent.right
|
||||
if y.red:
|
||||
x.parent.red = False
|
||||
y.red = False
|
||||
x.parent.parent.red = True
|
||||
x = x.parent.parent
|
||||
else:
|
||||
if x is x.parent.right:
|
||||
x = x.parent
|
||||
self.__rotate_left(x)
|
||||
x.parent.red = False
|
||||
x.parent.parent.red = True
|
||||
self.__rotate_right(x.parent.parent)
|
||||
else: # same as above, left/right switched
|
||||
y = x.parent.parent.left
|
||||
if y.red:
|
||||
x.parent.red = False
|
||||
y.red = False
|
||||
x.parent.parent.red = True
|
||||
x = x.parent.parent
|
||||
else:
|
||||
if x is x.parent.left:
|
||||
x = x.parent
|
||||
self.__rotate_right(x)
|
||||
x.parent.red = False
|
||||
x.parent.parent.red = True
|
||||
self.__rotate_left(x.parent.parent)
|
||||
self.root.left.red = False
|
||||
|
||||
# Deletion
|
||||
def delete(self, z):
|
||||
if z.left is None or z.right is None:
|
||||
raise AttributeError("you can only delete a node object "
|
||||
+ "from the tree; use find() to get one")
|
||||
if z.left is self.nil or z.right is self.nil:
|
||||
y = z
|
||||
else:
|
||||
y = self.__successor(z)
|
||||
if y.left is self.nil:
|
||||
x = y.right
|
||||
else:
|
||||
x = y.left
|
||||
x.parent = y.parent
|
||||
if x.parent is self.root:
|
||||
self.root.left = x
|
||||
else:
|
||||
if y is y.parent.left:
|
||||
y.parent.left = x
|
||||
else:
|
||||
y.parent.right = x
|
||||
if y is not z:
|
||||
# y is the node to splice out, x is its child
|
||||
y.left = z.left
|
||||
y.right = z.right
|
||||
y.parent = z.parent
|
||||
z.left.parent = y
|
||||
z.right.parent = y
|
||||
if z is z.parent.left:
|
||||
z.parent.left = y
|
||||
else:
|
||||
z.parent.right = y
|
||||
if not y.red:
|
||||
y.red = z.red
|
||||
self.__delete_fixup(x)
|
||||
else:
|
||||
y.red = z.red
|
||||
else:
|
||||
if not y.red:
|
||||
self.__delete_fixup(x)
|
||||
|
||||
def __delete_fixup(self, x):
|
||||
"""Rebalance/fix RBTree after a deletion. RBNode x is the
|
||||
child of the spliced out node."""
|
||||
rootLeft = self.root.left
|
||||
while not x.red and x is not rootLeft:
|
||||
if x is x.parent.left:
|
||||
w = x.parent.right
|
||||
if w.red:
|
||||
w.red = False
|
||||
x.parent.red = True
|
||||
self.__rotate_left(x.parent)
|
||||
w = x.parent.right
|
||||
if not w.right.red and not w.left.red:
|
||||
w.red = True
|
||||
x = x.parent
|
||||
else:
|
||||
if not w.right.red:
|
||||
w.left.red = False
|
||||
w.red = True
|
||||
self.__rotate_right(w)
|
||||
w = x.parent.right
|
||||
w.red = x.parent.red
|
||||
x.parent.red = False
|
||||
w.right.red = False
|
||||
self.__rotate_left(x.parent)
|
||||
x = rootLeft # exit loop
|
||||
else: # same as above, left/right switched
|
||||
w = x.parent.left
|
||||
if w.red:
|
||||
w.red = False
|
||||
x.parent.red = True
|
||||
self.__rotate_right(x.parent)
|
||||
w = x.parent.left
|
||||
if not w.left.red and not w.right.red:
|
||||
w.red = True
|
||||
x = x.parent
|
||||
else:
|
||||
if not w.left.red:
|
||||
w.right.red = False
|
||||
w.red = True
|
||||
self.__rotate_left(w)
|
||||
w = x.parent.left
|
||||
w.red = x.parent.red
|
||||
x.parent.red = False
|
||||
w.left.red = False
|
||||
self.__rotate_right(x.parent)
|
||||
x = rootLeft # exit loop
|
||||
x.red = False
|
||||
|
||||
# Rendering
|
||||
def __render_dot_node(self, node, max_depth = 20):
|
||||
from printf import sprintf
|
||||
"""Render a single node and its children into a dot graph fragment"""
|
||||
if max_depth == 0:
|
||||
return ""
|
||||
if node is self.nil:
|
||||
return ""
|
||||
def c(red):
|
||||
if red:
|
||||
return 'color="#ff0000", style=filled, fillcolor="#ffc0c0"'
|
||||
else:
|
||||
return 'color="#000000", style=filled, fillcolor="#c0c0c0"'
|
||||
s = sprintf("%d [label=\"%g\\n%g\", %s];\n",
|
||||
id(node),
|
||||
node.start, node.end,
|
||||
c(node.red))
|
||||
|
||||
if node.left is self.nil:
|
||||
s += sprintf("L%d [label=\"-\", %s];\n", id(node), c(False))
|
||||
s += sprintf("%d -> L%d [label=L];\n", id(node), id(node))
|
||||
else:
|
||||
s += sprintf("%d -> %d [label=L];\n", id(node), id(node.left))
|
||||
if node.right is self.nil:
|
||||
s += sprintf("R%d [label=\"-\", %s];\n", id(node), c(False))
|
||||
s += sprintf("%d -> R%d [label=R];\n", id(node), id(node))
|
||||
else:
|
||||
s += sprintf("%d -> %d [label=R];\n", id(node), id(node.right))
|
||||
s += self.__render_dot_node(node.left, max_depth-1)
|
||||
s += self.__render_dot_node(node.right, max_depth-1)
|
||||
return s
|
||||
|
||||
def render_dot(self, title = "RBTree"):
|
||||
"""Render the entire RBTree as a dot graph"""
|
||||
return ("digraph rbtree {\n"
|
||||
+ self.__render_dot_node(self.root.left)
|
||||
+ "}\n");
|
||||
|
||||
def render_dot_live(self, title = "RBTree"):
|
||||
"""Render the entire RBTree as a dot graph, live GTK view"""
|
||||
import gtk
|
||||
import gtk.gdk
|
||||
sys.path.append("/usr/share/xdot")
|
||||
import xdot
|
||||
xdot.Pen.highlighted = lambda pen: pen
|
||||
s = ("digraph rbtree {\n"
|
||||
+ self.__render_dot_node(self.root)
|
||||
+ "}\n");
|
||||
window = xdot.DotWindow()
|
||||
window.set_dotcode(s)
|
||||
window.set_title(title + " - any key to close")
|
||||
window.connect('destroy', gtk.main_quit)
|
||||
def quit(widget, event):
|
||||
if not event.is_modifier:
|
||||
window.destroy()
|
||||
gtk.main_quit()
|
||||
window.widget.connect('key-press-event', quit)
|
||||
gtk.main()
|
||||
|
||||
# Walking, searching
|
||||
def __iter__(self):
|
||||
return self.inorder(self.root.left)
|
||||
|
||||
def inorder(self, x = None):
|
||||
"""Generator that performs an inorder walk for the tree
|
||||
starting at RBNode x"""
|
||||
if x is None:
|
||||
x = self.root.left
|
||||
while x.left is not self.nil:
|
||||
x = x.left
|
||||
while x is not self.nil:
|
||||
yield x
|
||||
x = self.__successor(x)
|
||||
|
||||
def __find_all(self, start, end, x):
|
||||
"""Find node with the specified (start,end) key.
|
||||
Also returns the largest node less than or equal to key,
|
||||
and the smallest node greater or equal to than key."""
|
||||
if x is None:
|
||||
x = self.root.left
|
||||
largest = self.nil
|
||||
smallest = self.nil
|
||||
while x is not self.nil:
|
||||
if start < x.start:
|
||||
smallest = x
|
||||
x = x.left # start <
|
||||
elif start == x.start:
|
||||
if end < x.end:
|
||||
smallest = x
|
||||
x = x.left # start =, end <
|
||||
elif end == x.end: # found it
|
||||
smallest = x
|
||||
largest = x
|
||||
break
|
||||
else:
|
||||
largest = x
|
||||
x = x.right # start =, end >
|
||||
else:
|
||||
largest = x
|
||||
x = x.right # start >
|
||||
return (x, smallest, largest)
|
||||
|
||||
def find(self, start, end, x = None):
|
||||
"""Find node with the key == (start,end), or None"""
|
||||
y = self.__find_all(start, end, x)[1]
|
||||
return y if y is not self.nil else None
|
||||
|
||||
def find_right(self, start, end, x = None):
|
||||
"""Find node with the smallest key >= (start,end), or None"""
|
||||
y = self.__find_all(start, end, x)[1]
|
||||
return y if y is not self.nil else None
|
||||
|
||||
def find_left(self, start, end, x = None):
|
||||
"""Find node with the largest key <= (start,end), or None"""
|
||||
y = self.__find_all(start, end, x)[2]
|
||||
return y if y is not self.nil else None
|
||||
|
||||
# Intersections
|
||||
def intersect(self, start, end):
|
||||
"""Generator that returns nodes that overlap the given
|
||||
(start,end) range, for the tree rooted at RBNode x.
|
||||
|
||||
NOTE: this assumes non-overlapping intervals."""
|
||||
# Start with the leftmost node before the starting point
|
||||
n = self.find_left(start, start)
|
||||
# If we didn't find one, look for the leftmode node before the
|
||||
# ending point instead.
|
||||
if n is None:
|
||||
n = self.find_left(end, end)
|
||||
# If we still didn't find it, there are no intervals that intersect.
|
||||
if n is None:
|
||||
return none
|
||||
|
||||
# Now yield this node and all successors until their endpoints
|
||||
|
||||
if False:
|
||||
yield
|
||||
return
|
@@ -12,6 +12,7 @@ stop=
|
||||
verbosity=2
|
||||
#tests=tests/test_cmdline.py
|
||||
#tests=tests/test_layout.py
|
||||
#tests=tests/test_rbtree.py
|
||||
tests=tests/test_interval.py
|
||||
#tests=tests/test_client.py
|
||||
#tests=tests/test_timestamper.py
|
||||
|
@@ -20,6 +20,7 @@ def makeset(string):
|
||||
[ = interval start
|
||||
| = interval end + adjacent start
|
||||
] = interval end
|
||||
. = zero-width interval (identical start and end)
|
||||
anything else is ignored
|
||||
"""
|
||||
iset = IntervalSet()
|
||||
@@ -33,6 +34,8 @@ def makeset(string):
|
||||
elif (c == "]"):
|
||||
iset += Interval(start, day)
|
||||
del start
|
||||
elif (c == "."):
|
||||
iset += Interval(day, day)
|
||||
return iset
|
||||
|
||||
class TestInterval:
|
||||
@@ -68,7 +71,7 @@ class TestInterval:
|
||||
assert(Interval(d1, d3) < Interval(d2, d3))
|
||||
assert(Interval(d2, d2) > Interval(d1, d3))
|
||||
assert(Interval(d3, d3) == Interval(d3, d3))
|
||||
with assert_raises(AttributeError):
|
||||
with assert_raises(TypeError): # was AttributeError, that's wrong
|
||||
x = (i == 123)
|
||||
|
||||
# subset
|
||||
@@ -182,7 +185,7 @@ class TestInterval:
|
||||
|
||||
def test_intervalset_intersect(self):
|
||||
# Test intersection (&)
|
||||
with assert_raises(AttributeError):
|
||||
with assert_raises(TypeError): # was AttributeError
|
||||
x = makeset("[--]") & 1234
|
||||
|
||||
assert(makeset("[---------]") &
|
||||
@@ -197,10 +200,18 @@ class TestInterval:
|
||||
makeset(" [-----] ") ==
|
||||
makeset(" [--] "))
|
||||
|
||||
assert(makeset(" [--] [--]") &
|
||||
makeset(" [------] ") ==
|
||||
makeset(" [-] [-] "))
|
||||
|
||||
assert(makeset(" [---]") &
|
||||
makeset(" [--] ") ==
|
||||
makeset(" "))
|
||||
|
||||
assert(makeset(" [---]") &
|
||||
makeset(" [----] ") ==
|
||||
makeset(" . "))
|
||||
|
||||
assert(makeset(" [-|---]") &
|
||||
makeset(" [-----|-] ") ==
|
||||
makeset(" [----] "))
|
||||
@@ -211,8 +222,9 @@ class TestInterval:
|
||||
|
||||
assert(makeset(" [----][--]") &
|
||||
makeset("[-] [--] []") ==
|
||||
makeset(" [] [-] []"))
|
||||
makeset(" [] [-]. []"))
|
||||
|
||||
class TestIntervalDB:
|
||||
def test_dbinterval(self):
|
||||
# Test DBInterval class
|
||||
i = DBInterval(100, 200, 100, 200, 10000, 20000)
|
||||
@@ -255,20 +267,47 @@ class TestInterval:
|
||||
for i in IntervalSet(iseta.intersection(Interval(125,250))):
|
||||
assert(isinstance(i, DBInterval))
|
||||
|
||||
class TestIntervalTree:
|
||||
|
||||
def test_interval_tree(self):
|
||||
import random
|
||||
random.seed(1234)
|
||||
|
||||
# make a set of 500 intervals
|
||||
iset = IntervalSet()
|
||||
j = 500
|
||||
for i in random.sample(xrange(j),j):
|
||||
interval = Interval(i, i+1)
|
||||
iset += interval
|
||||
|
||||
# remove about half of them
|
||||
for i in random.sample(xrange(j),j):
|
||||
if random.randint(0,1):
|
||||
iset -= Interval(i, i+1)
|
||||
|
||||
# try removing an interval that doesn't exist
|
||||
with assert_raises(IntervalError):
|
||||
iset -= Interval(1234,5678)
|
||||
|
||||
# show the graph
|
||||
if False:
|
||||
iset.tree.render_dot_live()
|
||||
|
||||
class TestIntervalSpeed:
|
||||
#@unittest.skip("this is slow")
|
||||
@unittest.skip("this is slow")
|
||||
def test_interval_speed(self):
|
||||
import yappi
|
||||
import time
|
||||
import aplotter
|
||||
import random
|
||||
|
||||
print
|
||||
yappi.start()
|
||||
speeds = {}
|
||||
for j in [ 2**x for x in range(5,22) ]:
|
||||
for j in [ 2**x for x in range(5,18) ]:
|
||||
start = time.time()
|
||||
iset = IntervalSet()
|
||||
for i in xrange(j):
|
||||
for i in random.sample(xrange(j),j):
|
||||
interval = Interval(i, i+1)
|
||||
iset += interval
|
||||
speed = (time.time() - start) * 1000000.0
|
||||
@@ -277,3 +316,4 @@ class TestIntervalSpeed:
|
||||
aplotter.plot(speeds.keys(), speeds.values(), plot_slope=True)
|
||||
yappi.stop()
|
||||
yappi.print_stats(sort_type=yappi.SORTTYPE_TTOT, limit=10)
|
||||
|
||||
|
75
tests/test_rbtree.py
Normal file
75
tests/test_rbtree.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import nilmdb
|
||||
from nilmdb.printf import *
|
||||
|
||||
from nose.tools import *
|
||||
from nose.tools import assert_raises
|
||||
|
||||
from nilmdb.rbtree import RBTree, RBNode
|
||||
|
||||
from test_helpers import *
|
||||
import unittest
|
||||
|
||||
render = False
|
||||
|
||||
class TestRBTree:
|
||||
def test_rbtree(self):
|
||||
rb = RBTree()
|
||||
rb.insert(RBNode(None, 10000, 10001))
|
||||
rb.insert(RBNode(None, 10004, 10007))
|
||||
rb.insert(RBNode(None, 10001, 10002))
|
||||
s = rb.render_dot()
|
||||
# There was a typo that gave the RBTree a loop in this case.
|
||||
# Verify that the dot isn't too big.
|
||||
assert(len(s.splitlines()) < 30)
|
||||
|
||||
def test_rbtree_big(self):
|
||||
import random
|
||||
random.seed(1234)
|
||||
|
||||
# make a set of 500 intervals, inserted in order
|
||||
rb = RBTree()
|
||||
j = 500
|
||||
for i in xrange(j):
|
||||
rb.insert(RBNode(None, i, i+1))
|
||||
|
||||
# show the graph
|
||||
if render:
|
||||
rb.render_dot_live("in-order insert")
|
||||
|
||||
# remove about half of them
|
||||
for i in random.sample(xrange(j),j):
|
||||
if random.randint(0,1):
|
||||
rb.delete(rb.find(i, i+1))
|
||||
|
||||
# show the graph
|
||||
if render:
|
||||
rb.render_dot_live("in-order insert, random delete")
|
||||
|
||||
# make a set of 500 intervals, inserted at random
|
||||
rb = RBTree()
|
||||
j = 500
|
||||
for i in random.sample(xrange(j),j):
|
||||
rb.insert(RBNode(None, i, i+1))
|
||||
|
||||
# show the graph
|
||||
if render:
|
||||
rb.render_dot_live("random insert")
|
||||
|
||||
# remove about half of them
|
||||
for i in random.sample(xrange(j),j):
|
||||
if random.randint(0,1):
|
||||
rb.delete(rb.find(i, i+1))
|
||||
|
||||
# show the graph
|
||||
if render:
|
||||
rb.render_dot_live("random insert, random delete")
|
||||
|
||||
# in-order insert of 250 more
|
||||
for i in xrange(250):
|
||||
rb.insert(RBNode(None, i+500, i+501))
|
||||
|
||||
# show the graph
|
||||
if render:
|
||||
rb.render_dot_live("random insert, random delete, in-order insert")
|
54
time-bxintersect
Normal file
54
time-bxintersect
Normal file
@@ -0,0 +1,54 @@
|
||||
nosetests
|
||||
|
||||
32: 386 μs (12.0625 μs each)
|
||||
64: 672.102 μs (10.5016 μs each)
|
||||
128: 1510.86 μs (11.8036 μs each)
|
||||
256: 2782.11 μs (10.8676 μs each)
|
||||
512: 5591.87 μs (10.9216 μs each)
|
||||
1024: 12812.1 μs (12.5119 μs each)
|
||||
2048: 21835.1 μs (10.6617 μs each)
|
||||
4096: 46059.1 μs (11.2449 μs each)
|
||||
8192: 114127 μs (13.9315 μs each)
|
||||
16384: 181217 μs (11.0606 μs each)
|
||||
32768: 419649 μs (12.8067 μs each)
|
||||
65536: 804320 μs (12.2729 μs each)
|
||||
131072: 1.73534e+06 μs (13.2396 μs each)
|
||||
262144: 3.74451e+06 μs (14.2842 μs each)
|
||||
524288: 8.8694e+06 μs (16.917 μs each)
|
||||
1048576: 1.69993e+07 μs (16.2118 μs each)
|
||||
2097152: 3.29387e+07 μs (15.7064 μs each)
|
||||
|
|
||||
+3.29387e+07 *
|
||||
| ----
|
||||
| -----
|
||||
| ----
|
||||
| -----
|
||||
| -----
|
||||
| ----
|
||||
| -----
|
||||
| -----
|
||||
| ----
|
||||
| -----
|
||||
| ----
|
||||
| -----
|
||||
| ---
|
||||
| ---
|
||||
| ---
|
||||
| -------
|
||||
---+386---------------------------------------------------------------------+---
|
||||
+32 +2.09715e+06
|
||||
|
||||
name #n tsub ttot tavg
|
||||
..vl/lees/bucket/nilm/nilmdb/nilmdb/interval.py.__iadd__:184 4194272 10.025323 30.262723 0.000007
|
||||
..evl/lees/bucket/nilm/nilmdb/nilmdb/interval.py.__init__:27 4194272 24.715377 24.715377 0.000006
|
||||
../lees/bucket/nilm/nilmdb/nilmdb/interval.py.intersects:239 4194272 6.705053 12.577620 0.000003
|
||||
..im/devl/lees/bucket/nilm/nilmdb/tests/aplotter.py.plot:404 1 0.000048 0.001412 0.001412
|
||||
../lees/bucket/nilm/nilmdb/tests/aplotter.py.plot_double:311 1 0.000106 0.001346 0.001346
|
||||
..vl/lees/bucket/nilm/nilmdb/tests/aplotter.py.plot_data:201 1 0.000098 0.000672 0.000672
|
||||
..vl/lees/bucket/nilm/nilmdb/tests/aplotter.py.plot_line:241 16 0.000298 0.000496 0.000031
|
||||
..jim/devl/lees/bucket/nilm/nilmdb/nilmdb/printf.py.printf:4 17 0.000252 0.000334 0.000020
|
||||
..vl/lees/bucket/nilm/nilmdb/tests/aplotter.py.transposed:39 1 0.000229 0.000235 0.000235
|
||||
..vl/lees/bucket/nilm/nilmdb/tests/aplotter.py.y_reversed:45 1 0.000151 0.000174 0.000174
|
||||
|
||||
name tid fname ttot scnt
|
||||
_MainThread 47269783682784 ..b/python2.7/threading.py.setprofile:88 64.746000 1
|
Reference in New Issue
Block a user