More work on interval sets this time
git-svn-id: https://bucket.mit.edu/svn/nilm/nilmdb@9273 ddd99763-3ecb-0310-9145-efcb8ce7c51f
This commit is contained in:
parent
c033d69836
commit
a84b9850d0
|
@ -1,6 +1,7 @@
|
|||
from datetime import datetime
|
||||
import bisect
|
||||
|
||||
class IntervalException(Exception):
|
||||
class IntervalError(Exception):
|
||||
pass
|
||||
|
||||
class Interval(object):
|
||||
|
@ -20,14 +21,25 @@ class Interval(object):
|
|||
return { }
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if (type(value).__name__ != 'datetime'):
|
||||
raise IntervalException("Must set datetime values")
|
||||
|
||||
if (type(value) is not datetime):
|
||||
raise IntervalError("Must set datetime values")
|
||||
self.__dict__[name] = value
|
||||
if (type(self.start) is type(self.end)):
|
||||
if (self.start > self.end):
|
||||
raise IntervalException("Interval start must precede interval end")
|
||||
|
||||
raise IntervalError("Interval start must precede interval end")
|
||||
|
||||
def __cmp__(self, other):
|
||||
"""Compare two intervals. If non-equal, order determined by start then end"""
|
||||
if (self.start == other.start):
|
||||
if (self.end < other.end):
|
||||
return -1
|
||||
if (self.end > other.end):
|
||||
return 1
|
||||
return 0
|
||||
if (self.start < other.start):
|
||||
return -1
|
||||
return 1
|
||||
|
||||
def intersects(self, other):
|
||||
if (self.end <= other.start or
|
||||
self.start >= other.end):
|
||||
|
@ -35,37 +47,44 @@ class Interval(object):
|
|||
else:
|
||||
return True;
|
||||
|
||||
"""hi
|
||||
if (self.start < other.start &&
|
||||
self.end > other.start &&
|
||||
self.end < other.end)
|
||||
|
||||
start-other.start start-other.end end-other.start end-other.end
|
||||
|
||||
|
||||
---- < < > <
|
||||
----
|
||||
|
||||
---- < < < < n
|
||||
-----
|
||||
|
||||
--- > < > <
|
||||
--------
|
||||
|
||||
------- < < > >
|
||||
---
|
||||
|
||||
---- > < > >
|
||||
-----
|
||||
|
||||
----------- > > > > n
|
||||
-----
|
||||
"""
|
||||
# start-other.start start-other.end end-other.start end-other.end
|
||||
# ---- < < > <
|
||||
# ----
|
||||
#
|
||||
# ---- < < < < n
|
||||
# -----
|
||||
#
|
||||
# --- > < > <
|
||||
# --------
|
||||
#
|
||||
# ------- < < > >
|
||||
# ---
|
||||
#
|
||||
# ---- > < > >
|
||||
# -----
|
||||
#
|
||||
# ----------- > > > > n
|
||||
# -----
|
||||
|
||||
|
||||
class IntervalSet(object):
|
||||
"""A non-intersecting set of intervals"""
|
||||
"""A non-intersecting set of intervals
|
||||
|
||||
def __init__(self, value):
|
||||
print "hello" + value
|
||||
Kept sorted internally"""
|
||||
|
||||
def __init__(self, iterable=None):
|
||||
self.data = []
|
||||
if type(iterable) is not None:
|
||||
if type(iterable) is Interval:
|
||||
iterable = [iterable]
|
||||
self.add_intervals(iterable)
|
||||
|
||||
def add_intervals(self, iterable):
|
||||
for element in iter(iterable):
|
||||
self.add_single_interval(element)
|
||||
|
||||
def add_single_interval(self, interval):
|
||||
for existing in self.data:
|
||||
if existing.intersects(interval):
|
||||
raise IntervalError("Tried to add overlapping interval to this set")
|
||||
bisect.insort(self.data, interval)
|
||||
|
|
|
@ -1,20 +1,41 @@
|
|||
from nilmdb import Interval, IntervalSet, IntervalException
|
||||
from nilmdb import Interval, IntervalSet, IntervalError
|
||||
from datetime import datetime
|
||||
from nose.tools import assert_raises
|
||||
import itertools
|
||||
|
||||
def test_interval():
|
||||
"""Test the Interval class"""
|
||||
start = datetime.strptime("19801205","%Y%m%d")
|
||||
end = datetime.strptime("20110216","%Y%m%d")
|
||||
d1 = datetime.strptime("19801205","%Y%m%d")
|
||||
d2 = datetime.strptime("19900216","%Y%m%d")
|
||||
d3 = datetime.strptime("20111205","%Y%m%d")
|
||||
|
||||
# basic construction
|
||||
i = Interval(start, end)
|
||||
assert(i.start == start)
|
||||
assert(i.end == end)
|
||||
i = Interval(d1, d1)
|
||||
i = Interval(d1, d3)
|
||||
assert(i.start == d1)
|
||||
assert(i.end == d3)
|
||||
|
||||
# assignment should work
|
||||
i.start = d2
|
||||
try:
|
||||
i.end = d1
|
||||
raise Exception("should have bombed out there")
|
||||
except IntervalError:
|
||||
pass
|
||||
|
||||
# end before start
|
||||
assert_raises(Exception, Interval, end, start)
|
||||
assert_raises(IntervalError, Interval, d3, d1)
|
||||
|
||||
# wrong type
|
||||
assert_raises(IntervalError, Interval, 1, 2)
|
||||
|
||||
# compare
|
||||
assert(Interval(d1, d2) == Interval(d1, d2))
|
||||
assert(Interval(d1, d2) < Interval(d1, d3))
|
||||
assert(Interval(d1, d2) < Interval(d2, d3))
|
||||
assert(Interval(d1, d3) < Interval(d2, d3))
|
||||
assert(Interval(d2, d2) > Interval(d1, d3))
|
||||
assert(Interval(d3, d3) == Interval(d3, d3))
|
||||
|
||||
def test_interval_intersect():
|
||||
"""Test Interval intersections"""
|
||||
|
@ -31,11 +52,23 @@ def test_interval_intersect():
|
|||
i2 = Interval(c, d)
|
||||
assert(i1.intersects(i2) == i2.intersects(i1))
|
||||
assert(i in should_intersect[i1.intersects(i2)])
|
||||
except IntervalException:
|
||||
except IntervalError:
|
||||
assert(i not in should_intersect[True] and
|
||||
i not in should_intersect[False])
|
||||
|
||||
def test_intervalset():
|
||||
"""Test interval sets"""
|
||||
#iset = IntervalSet("hi")
|
||||
|
||||
d1 = datetime.strptime("19801205","%Y%m%d")
|
||||
d2 = datetime.strptime("19900216","%Y%m%d")
|
||||
d3 = datetime.strptime("20111205","%Y%m%d")
|
||||
|
||||
a = Interval(d1, d2)
|
||||
b = Interval(d2, d3)
|
||||
c = Interval(d1, d3)
|
||||
|
||||
iset = IntervalSet(a)
|
||||
iset = IntervalSet([a, b])
|
||||
assert_raises(IntervalError, IntervalSet, [a, b, c])
|
||||
|
||||
iset = IntervalSet(iset)
|
||||
|
|
Loading…
Reference in New Issue
Block a user