diff --git a/nilmdb/utils/interval.py b/nilmdb/utils/interval.py index cb16eee..2e9baaa 100644 --- a/nilmdb/utils/interval.py +++ b/nilmdb/utils/interval.py @@ -58,18 +58,11 @@ class Interval: raise IntervalError("not a subset") return Interval(start, end) -def set_difference(a, b): - """ - Compute the difference (a \\ b) between the intervals in 'a' and - the intervals in 'b'; i.e., the ranges that are present in 'self' - but not 'other'. - - 'a' and 'b' must both be iterables. - - Returns a generator that yields each interval in turn. - Output intervals are built as subsets of the intervals in the - first argument (a). - """ +def _interval_math_helper(a, b, op, subset = True): + """Helper for set_difference, intersection functions, + to compute interval subsets based on a math operator on ranges + present in A and B. Subsets are computed from A, or new intervals + are generated if subset = False.""" # Iterate through all starts and ends in sorted order. Add a # tag to the iterator so that we can figure out which one they # were, after sorting. @@ -84,31 +77,57 @@ def set_difference(a, b): # At each point, evaluate which type of end it is, to determine # how to build up the output intervals. a_interval = None - b_interval = None + in_a = False + in_b = False out_start = None for (ts, k, i) in nilmdb.utils.iterator.imerge(a_iter, b_iter): if k == 0: - # start a interval a_interval = i - if b_interval is None: - out_start = ts + in_a = True elif k == 1: - # start b interval - b_interval = i - if out_start is not None and out_start != ts: - yield a_interval.subset(out_start, ts) - out_start = None + in_b = True elif k == 2: - # end a interval + in_a = False + elif k == 3: + in_b = False + include = op(in_a, in_b) + if include and out_start is None: + out_start = ts + elif not include: if out_start is not None and out_start != ts: - yield a_interval.subset(out_start, ts) + if subset: + yield a_interval.subset(out_start, ts) + else: + yield Interval(out_start, ts) out_start = None - a_interval = None - elif k == 3: - # end b interval - b_interval = None - if a_interval: - out_start = ts + +def set_difference(a, b): + """ + Compute the difference (a \\ b) between the intervals in 'a' and + the intervals in 'b'; i.e., the ranges that are present in 'self' + but not 'other'. + + 'a' and 'b' must both be iterables. + + Returns a generator that yields each interval in turn. + Output intervals are built as subsets of the intervals in the + first argument (a). + """ + return _interval_math_helper(a, b, (lambda a, b: a and not b)) + +def intersection(a, b): + """ + Compute the intersection between the intervals in 'a' and the + intervals in 'b'; i.e., the ranges that are present in both 'a' + and 'b'. + + 'a' and 'b' must both be iterables. + + Returns a generator that yields each interval in turn. + Output intervals are built as subsets of the intervals in the + first argument (a). + """ + return _interval_math_helper(a, b, (lambda a, b: a and b)) def optimize(it): """ diff --git a/tests/test_interval.py b/tests/test_interval.py index f32d9d0..82ddfa9 100644 --- a/tests/test_interval.py +++ b/tests/test_interval.py @@ -234,13 +234,16 @@ class TestInterval: x = makeset("[--)") & 1234 def do_test(a, b, c, d): - # a & b == c + # a & b == c (using nilmdb.server.interval) ab = IntervalSet() for x in b: for i in (a & x): ab += i eq_(ab,c) + # a & b == c (using nilmdb.utils.interval) + eq_(IntervalSet(nilmdb.utils.interval.intersection(a,b)), c) + # a \ b == d eq_(IntervalSet(nilmdb.utils.interval.set_difference(a,b)), d) @@ -310,6 +313,17 @@ class TestInterval: eq_(nilmdb.utils.interval.set_difference( a.intersection(list(c)[0]), b.intersection(list(c)[0])), d) + # Fill out test coverage for non-subsets + def diff2(a,b, subset): + return nilmdb.utils.interval._interval_math_helper( + a, b, (lambda a, b: b and not a), subset=subset) + with assert_raises(nilmdb.utils.interval.IntervalError): + list(diff2(a,b,True)) + list(diff2(a,b,False)) + + # Empty second set + eq_(nilmdb.utils.interval.set_difference(a, IntervalSet()), a) + # Empty second set eq_(nilmdb.utils.interval.set_difference(a, IntervalSet()), a)