diff --git a/nilmdb/interval.py b/nilmdb/interval.py index fd2fa27..96d6dc0 100644 --- a/nilmdb/interval.py +++ b/nilmdb/interval.py @@ -210,7 +210,7 @@ class IntervalSet(object): if self.intersects(other): raise IntervalError("Tried to add overlapping interval " "to this set") - self.tree.insert(rbtree.RBNode(other)) + self.tree.insert(rbtree.RBNode(other.start, other.end, other)) else: for x in other: self.__iadd__(x) @@ -244,11 +244,11 @@ class IntervalSet(object): if not isinstance(other, IntervalSet): for i in self.intersection(other): - out.tree.insert(rbtree.RBNode(i)) + out.tree.insert(rbtree.RBNode(i.start, i.end, i)) else: for x in other: for i in self.intersection(x): - out.tree.insert(rbtree.RBNode(i)) + out.tree.insert(rbtree.RBNode(i.start, i.end, i)) return out @@ -269,23 +269,14 @@ class IntervalSet(object): if i: if i.start >= interval.start and i.end <= interval.end: yield i - elif i.start > interval.end: - break else: subset = i.subset(max(i.start, interval.start), min(i.end, interval.end)) yield subset def intersects(self, other): - ### PROBABLY WRONG """Return True if this IntervalSet intersects another interval""" - node = self.tree.find_left(other.start, other.end) - if node is None: - return False - for n in self.tree.inorder(node): - if n.obj: - if n.obj.intersects(other): - return True - if n.obj > other: - break + for n in self.tree.intersect(other.start, other.end): + if n.obj.intersects(other): + return True return False diff --git a/setup.cfg b/setup.cfg index da015c3..8c3fb3d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,7 +12,7 @@ stop= verbosity=2 #tests=tests/test_cmdline.py #tests=tests/test_layout.py -tests=tests/test_rbtree.py +tests=tests/test_rbtree.py,tests/test_interval.py #tests=tests/test_interval.py #tests=tests/test_client.py #tests=tests/test_timestamper.py diff --git a/tests/renderdot.py b/tests/renderdot.py index 597e68b..934e924 100644 --- a/tests/renderdot.py +++ b/tests/renderdot.py @@ -71,3 +71,20 @@ class Renderer(object): gtk.main_quit() window.widget.connect('key-press-event', quit) gtk.main() + +class RBTreeRenderer(Renderer): + def __init__(self, tree): + Renderer.__init__(self, + lambda node: node.left, + lambda node: node.right, + lambda node: node.red, + lambda node: node.start, + lambda node: node.end, + tree.nil) + self.tree = tree + + def render(self, title = "RBTree", live = True): + if live: + return Renderer.render_dot_live(self, self.tree.getroot(), title) + else: + return Renderer.render_dot(self, self.tree.getroot(), title) diff --git a/tests/test_interval.py b/tests/test_interval.py index 7e06746..b207330 100644 --- a/tests/test_interval.py +++ b/tests/test_interval.py @@ -13,6 +13,13 @@ from nilmdb.interval import Interval, DBInterval, IntervalSet, IntervalError from test_helpers import * import unittest +# set to False to skip live renders +do_live_renders = False +def render(iset, description = "", live = True): + import renderdot + r = renderdot.RBTreeRenderer(iset.tree) + return r.render(description, live and do_live_renders) + def makeset(string): """Build an IntervalSet from a string, for testing purposes @@ -75,7 +82,7 @@ class TestInterval: x = (i == 123) # subset - assert(Interval(d1, d3).subset(d1, d2) == Interval(d1, d2)) + eq_(Interval(d1, d3).subset(d1, d2), Interval(d1, d2)) with assert_raises(IntervalError): x = Interval(d2, d3).subset(d1, d2) @@ -172,15 +179,15 @@ class TestInterval: def test_intervalset_geniset(self): # Test basic iset construction - assert(makeset(" [----] ") == - makeset(" [-|--] ")) + eq_(makeset(" [----] "), + makeset(" [-|--] ")) - assert(makeset("[] [--] ") + - makeset(" [] [--]") == - makeset("[|] [-----]")) + eq_(makeset("[] [--] ") + + makeset(" [] [--]"), + makeset("[|] [-----]")) - assert(makeset(" [-------]") == - makeset(" [-|-----|")) + eq_(makeset(" [-------]"), + makeset(" [-|-----|")) def test_intervalset_intersect(self): @@ -188,41 +195,59 @@ class TestInterval: with assert_raises(TypeError): # was AttributeError x = makeset("[--]") & 1234 - assert(makeset("[---------]") & - makeset(" [---] ") == - makeset(" [---] ")) - - assert(makeset(" [---] ") & - makeset("[---------]") == - makeset(" [---] ")) - - assert(makeset(" [-----]") & - makeset(" [-----] ") == - makeset(" [--] ")) - - assert(makeset(" [--] [--]") & - makeset(" [------] ") == - makeset(" [-] [-] ")) - - assert(makeset(" [---]") & - makeset(" [--] ") == - makeset(" ")) - - assert(makeset(" [---]") & - makeset(" [----] ") == - makeset(" . ")) - - assert(makeset(" [-|---]") & - makeset(" [-----|-] ") == - makeset(" [----] ")) - - assert(makeset(" [-|-] ") & - makeset(" [-|--|--] ") == - makeset(" [---] ")) - - assert(makeset(" [----][--]") & - makeset("[-] [--] []") == - makeset(" [] [-]. []")) + # Intersection with interval + eq_(makeset("[---|---][]") & + list(makeset(" [------] "))[0], + makeset(" [-----] ")) + + # Intersection with sets + eq_(makeset("[---------]") & + makeset(" [---] "), + makeset(" [---] ")) + + eq_(makeset(" [---] ") & + makeset("[---------]"), + makeset(" [---] ")) + + eq_(makeset(" [-----]") & + makeset(" [-----] "), + makeset(" [--] ")) + + eq_(makeset(" [--] [--]") & + makeset(" [------] "), + makeset(" [-] [-] ")) + + eq_(makeset(" [---]") & + makeset(" [--] "), + makeset(" ")) + + eq_(makeset(" [-|---]") & + makeset(" [-----|-] "), + makeset(" [----] ")) + + eq_(makeset(" [-|-] ") & + makeset(" [-|--|--] "), + makeset(" [---] ")) + + # Border cases -- will give different results if intervals are + # half open or fully closed. Right now, they are half open, + # although that's a little messy since the database intervals + # often contain a data point at the endpoint. + half_open = True + if half_open: + eq_(makeset(" [---]") & + makeset(" [----] "), + makeset(" ")) + eq_(makeset(" [----][--]") & + makeset("[-] [--] []"), + makeset(" [] [-] []")) + else: + eq_(makeset(" [---]") & + makeset(" [----] "), + makeset(" . ")) + eq_(makeset(" [----][--]") & + makeset("[-] [--] []"), + makeset(" [] [-]. []")) class TestIntervalDB: def test_dbinterval(self): @@ -270,25 +295,16 @@ class TestIntervalDB: class TestIntervalTree: def test_interval_tree(self): - import renderdot import random random.seed(1234) - # make a set of 500 intervals + # make a set of 100 intervals iset = IntervalSet() - j = 500 + j = 100 for i in random.sample(xrange(j),j): interval = Interval(i, i+1) iset += interval - - # Plot it - r = renderdot.Renderer(lambda node: node.cleft, - lambda node: node.cright, - lambda node: False, - lambda node: node.start, - lambda node: node.end, - iset.tree.emptynode()) - r.render_dot_live(iset.tree.rootnode(), "Random insertion") + render(iset, "Random Insertion") # remove about half of them for i in random.sample(xrange(j),j): @@ -298,33 +314,15 @@ class TestIntervalTree: # try removing an interval that doesn't exist with assert_raises(IntervalError): iset -= Interval(1234,5678) + render(iset, "Random Insertion, deletion") - # Plot it - r = renderdot.Renderer(lambda node: node.cleft, - lambda node: node.cright, - lambda node: False, - lambda node: node.start, - lambda node: node.end, - iset.tree.emptynode()) - r.render_dot_live(iset.tree.rootnode(), "Random insertion, deletion") - - # make a set of 500 intervals, inserted in order + # make a set of 100 intervals, inserted in order iset = IntervalSet() - j = 500 + j = 100 for i in xrange(j): interval = Interval(i, i+1) iset += interval - - # Plot it - r = renderdot.Renderer(lambda node: node.cleft, - lambda node: node.cright, - lambda node: False, - lambda node: node.start, - lambda node: node.end, - iset.tree.emptynode()) - r.render_dot_live(iset.tree.rootnode(), "In-order insertion") - - assert(False) + render(iset, "In-order insertion") class TestIntervalSpeed: @unittest.skip("this is slow") diff --git a/tests/test_rbtree.py b/tests/test_rbtree.py index 98d9dca..08091c4 100644 --- a/tests/test_rbtree.py +++ b/tests/test_rbtree.py @@ -15,19 +15,8 @@ import unittest do_live_renders = False def render(tree, description = "", live = True): import renderdot - if not do_live_renders: - # If not doing a live render, still render it to a string. - live = False - r = renderdot.Renderer(lambda node: node.left, - lambda node: node.right, - lambda node: node.red, - lambda node: node.start, - lambda node: node.end, - tree.nil) - if live: - r.render_dot_live(tree.getroot(), description) - else: - return r.render_dot(tree.getroot()) + r = renderdot.RBTreeRenderer(tree) + return r.render(description, live and do_live_renders) class TestRBTree: def test_rbtree(self): @@ -44,9 +33,9 @@ class TestRBTree: import random random.seed(1234) - # make a set of 500 intervals, inserted in order + # make a set of 100 intervals, inserted in order rb = RBTree() - j = 500 + j = 100 for i in xrange(j): rb.insert(RBNode(i, i+1)) render(rb, "in-order insert") @@ -57,9 +46,9 @@ class TestRBTree: rb.delete(rb.find(i, i+1)) render(rb, "in-order insert, random delete") - # make a set of 500 intervals, inserted at random + # make a set of 100 intervals, inserted at random rb = RBTree() - j = 500 + j = 100 for i in random.sample(xrange(j),j): rb.insert(RBNode(i, i+1)) render(rb, "random insert") @@ -70,8 +59,8 @@ class TestRBTree: rb.delete(rb.find(i, i+1)) render(rb, "random insert, random delete") - # in-order insert of 250 more - for i in xrange(250): + # in-order insert of 50 more + for i in xrange(50): rb.insert(RBNode(i+500, i+501)) render(rb, "random insert, random delete, in-order insert")