Add interval.set_difference function and associated tests
This commit is contained in:
parent
755255030b
commit
1d61d61a81
|
@ -20,6 +20,8 @@ Intervals are half-open, ie. they include data points with timestamps
|
|||
# and ends directly in the tree, like bxinterval did.
|
||||
|
||||
from ..utils.time import float_time_to_string as ftts
|
||||
from ..utils.iterator import imerge
|
||||
import itertools
|
||||
|
||||
cimport rbtree
|
||||
cdef extern from "stdint.h":
|
||||
|
@ -307,6 +309,62 @@ cdef class IntervalSet:
|
|||
else:
|
||||
yield subset
|
||||
|
||||
def set_difference(self, IntervalSet other not None,
                   Interval bounds = None):
    """
    Generate the ranges covered by this IntervalSet but not by
    'other' (the set difference, self minus other).

    If 'bounds' is given, only the parts of both sets that
    intersect the interval 'bounds' are considered.  If it is
    None, Interval(-1e12, 1e12) is used as the bounds.

    This is a generator.  Every interval it yields is obtained by
    calling subset() on one of the intervals of 'self'.
    """
    # Tags identifying each endpoint.  Merged events are ordered by
    # (timestamp, tag, ...), so at equal timestamps a start of
    # 'self' is seen first and an end of 'other' is seen last.
    SELF_START, OTHER_START, SELF_END, OTHER_END = 0, 1, 2, 3

    def endpoints(intervals, start_tag, end_tag):
        # Emit a tagged (timestamp, tag, interval) event for both
        # ends of every interval.
        for interval in intervals:
            yield interval.start, start_tag, interval
            yield interval.end, end_tag, interval

    if bounds is None:
        bounds = Interval(-1e12, 1e12)
    self_events = endpoints(self.intersection(bounds),
                            SELF_START, SELF_END)
    other_events = endpoints(other.intersection(bounds),
                             OTHER_START, OTHER_END)

    # Walk every endpoint in order, tracking which interval of each
    # set we are currently inside of, and where the output piece
    # being built (if any) began.
    current = None       # interval of 'self' we are inside of
    blocker = None       # interval of 'other' we are inside of
    piece_start = None   # start of the pending output piece
    for (ts, tag, interval) in imerge(self_events, other_events):
        if tag == SELF_START:
            current = interval
            if blocker is None:
                piece_start = ts
        elif tag == OTHER_END:
            blocker = None
            if current:
                piece_start = ts
        elif tag in (OTHER_START, SELF_END):
            # Either event closes the pending piece.  A piece that
            # would start and end at the same timestamp is dropped.
            if tag == OTHER_START:
                blocker = interval
            if piece_start is not None and piece_start != ts:
                yield current.subset(piece_start, ts)
            piece_start = None
            if tag == SELF_END:
                current = None
|
||||
|
||||
cpdef intersects(self, Interval other):
|
||||
"""Return True if this IntervalSet intersects another interval"""
|
||||
for n in self.tree.intersect(other.start, other.end):
|
||||
|
|
|
@ -10,3 +10,4 @@ from nilmdb.utils import atomic
|
|||
import nilmdb.utils.threadsafety
|
||||
import nilmdb.utils.fallocate
|
||||
import nilmdb.utils.time
|
||||
import nilmdb.utils.iterator
|
||||
|
|
36
nilmdb/utils/iterator.py
Normal file
36
nilmdb/utils/iterator.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
# Misc iterator tools
|
||||
|
||||
# Iterator merging, based on http://code.activestate.com/recipes/491285/
|
||||
import heapq
|
||||
def imerge(*iterables):
    '''Merge multiple sorted inputs into a single sorted output.

    Equivalent to:  sorted(itertools.chain(*iterables))

    Each input must already be sorted.  The inputs are consumed
    lazily, and nothing is read from them until the first item is
    requested.  When items from different inputs compare equal, the
    one from the earlier argument is yielded first.

    >>> list(imerge([1,3,5,7], [0,2,4,8], [5,10,15,20], [], [25]))
    [0, 1, 2, 3, 4, 5, 5, 7, 8, 10, 15, 20, 25]

    '''
    # heapq.merge is the standard-library version of this recipe.
    # Using it avoids the Python-2-only iterator ".next" attribute
    # and the private heapq._siftup helper.  It also breaks ties with
    # the input's position, so equal items never fall through to
    # comparing iterator methods.
    for item in heapq.merge(*iterables):
        yield item
|
|
@ -208,64 +208,89 @@ class TestInterval:
|
|||
makeset(" [-|-----|"))
|
||||
|
||||
|
||||
def test_intervalset_intersect(self):
|
||||
def test_intervalset_intersect_difference(self):
|
||||
# Test intersection (&)
|
||||
with assert_raises(TypeError): # was AttributeError
|
||||
x = makeset("[--)") & 1234
|
||||
|
||||
# Intersection with interval
|
||||
eq_(makeset("[---|---)[)") &
|
||||
list(makeset(" [------) "))[0],
|
||||
makeset(" [-----) "))
|
||||
def do_test(a, b, c, d):
|
||||
# a & b == c
|
||||
ab = IntervalSet()
|
||||
for x in b:
|
||||
for i in (a & x):
|
||||
ab += i
|
||||
eq_(ab,c)
|
||||
|
||||
# Intersection with sets
|
||||
eq_(makeset("[---------)") &
|
||||
makeset(" [---) "),
|
||||
makeset(" [---) "))
|
||||
# a \ b == d
|
||||
eq_(IntervalSet(a.set_difference(b)), d)
|
||||
|
||||
eq_(makeset(" [---) ") &
|
||||
makeset("[---------)"),
|
||||
makeset(" [---) "))
|
||||
# Intersection with intervals
|
||||
do_test(makeset("[---|---)[)"),
|
||||
makeset(" [------) "),
|
||||
makeset(" [-----) "), # intersection
|
||||
makeset("[-) [)")) # difference
|
||||
|
||||
eq_(makeset(" [-----)") &
|
||||
makeset(" [-----) "),
|
||||
makeset(" [--) "))
|
||||
do_test(makeset("[---------)"),
|
||||
makeset(" [---) "),
|
||||
makeset(" [---) "), # intersection
|
||||
makeset("[) [----)")) # difference
|
||||
|
||||
eq_(makeset(" [--) [--)") &
|
||||
makeset(" [------) "),
|
||||
makeset(" [-) [-) "))
|
||||
do_test(makeset(" [---) "),
|
||||
makeset("[---------)"),
|
||||
makeset(" [---) "), # intersection
|
||||
makeset(" ")) # difference
|
||||
|
||||
eq_(makeset(" [---)") &
|
||||
makeset(" [--) "),
|
||||
makeset(" "))
|
||||
do_test(makeset(" [-----)"),
|
||||
makeset(" [-----) "),
|
||||
makeset(" [--) "), # intersection
|
||||
makeset(" [--)")) # difference
|
||||
|
||||
eq_(makeset(" [-|---)") &
|
||||
makeset(" [-----|-) "),
|
||||
makeset(" [----) "))
|
||||
do_test(makeset(" [--) [--)"),
|
||||
makeset(" [------) "),
|
||||
makeset(" [-) [-) "), # intersection
|
||||
makeset(" [) [)")) # difference
|
||||
|
||||
eq_(makeset(" [-|-) ") &
|
||||
makeset(" [-|--|--) "),
|
||||
makeset(" [---) "))
|
||||
do_test(makeset(" [---)"),
|
||||
makeset(" [--) "),
|
||||
makeset(" "), # intersection
|
||||
makeset(" [---)")) # difference
|
||||
|
||||
do_test(makeset(" [-|---)"),
|
||||
makeset(" [-----|-) "),
|
||||
makeset(" [----) "), # intersection
|
||||
makeset(" [)")) # difference
|
||||
|
||||
do_test(makeset(" [-|-) "),
|
||||
makeset(" [-|--|--) "),
|
||||
makeset(" [---) "), # intersection
|
||||
makeset(" ")) # difference
|
||||
|
||||
do_test(makeset("[-)[-)[-)[)"),
|
||||
makeset(" [) [|)[) "),
|
||||
makeset(" [) [) "), # intersection
|
||||
makeset("[) [-) [)[)")) # difference
|
||||
|
||||
# Border cases -- will give different results if intervals are
|
||||
# half open or fully closed. Right now, they are half open,
|
||||
# although that's a little messy since the database intervals
|
||||
# often contain a data point at the endpoint.
|
||||
half_open = True
|
||||
if half_open:
|
||||
eq_(makeset(" [---)") &
|
||||
# half open or fully closed. In nilmdb, they are half open.
|
||||
do_test(makeset(" [---)"),
|
||||
makeset(" [----) "),
|
||||
makeset(" "))
|
||||
eq_(makeset(" [----)[--)") &
|
||||
makeset(" "), # intersection
|
||||
makeset(" [---)")) # difference
|
||||
|
||||
do_test(makeset(" [----)[--)"),
|
||||
makeset("[-) [--) [)"),
|
||||
makeset(" [) [-) [)"))
|
||||
else:
|
||||
eq_(makeset(" [---)") &
|
||||
makeset(" [----) "),
|
||||
makeset(" . "))
|
||||
eq_(makeset(" [----)[--)") &
|
||||
makeset("[-) [--) [)"),
|
||||
makeset(" [) [-). [)"))
|
||||
makeset(" [) [-) [)"), # intersection
|
||||
makeset(" [-) [-) ")) # difference
|
||||
|
||||
# Set difference with bounds
|
||||
a = makeset(" [----)[--)")
|
||||
b = makeset("[-) [--) [)")
|
||||
c = makeset("[----) ")
|
||||
d = makeset(" [-) ")
|
||||
eq_(a.set_difference(b, list(c)[0]), d)
|
||||
|
||||
# Empty second set
|
||||
eq_(a.set_difference(IntervalSet()), a)
|
||||
|
||||
class TestIntervalDB:
|
||||
def test_dbinterval(self):
|
||||
|
@ -371,4 +396,3 @@ class TestIntervalSpeed:
|
|||
aplotter.plot(speeds.keys(), speeds.values(), plot_slope=True)
|
||||
yappi.stop()
|
||||
yappi.print_stats(sort_type=yappi.SORTTYPE_TTOT, limit=10)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user