Update tests etc

git-svn-id: https://bucket.mit.edu/svn/nilm/nilmdb@9268 ddd99763-3ecb-0310-9145-efcb8ce7c51f
This commit is contained in:
Jim Paris 2011-02-16 22:30:26 +00:00
parent 5844afed0b
commit e2f89982cb
5 changed files with 102 additions and 5 deletions

View File

@ -1,3 +1,3 @@
from interval import Interval
from interval import *
del interval

Binary file not shown.

View File

@ -1,4 +1,7 @@
import datetime
from datetime import datetime
class IntervalException(Exception):
pass
class Interval(object):
"""Represents an interval of time"""
@ -7,9 +10,62 @@ class Interval(object):
self.start = start
self.end = end
def __repr__(self):
return "Interval(" + repr(self.start) + ", " + repr(self.end) + ")"
def __str__(self):
return "[" + str(self.start) + " -> " + str(self.end) + "]"
def __getattr__(self, name):
return { }
def __setattr__(self, name):
raise Exception("__setattr__ " + name + "called")
def __setattr__(self, name, value):
if (type(value).__name__ != 'datetime'):
raise IntervalException("Must set datetime values")
self.__dict__[name] = value
if (type(self.start) is type(self.end)):
if (self.start > self.end):
raise IntervalException("Interval start must precede interval end")
def intersects(self, other):
if (self.end <= other.start or
self.start >= other.end):
return False;
else:
return True;
"""hi
if (self.start < other.start &&
self.end > other.start &&
self.end < other.end)
start-other.start start-other.end end-other.start end-other.end
---- < < > <
----
---- < < < < n
-----
--- > < > <
--------
------- < < > >
---
---- > < > >
-----
----------- > > > > n
-----
"""
class IntervalSet(object):
"""A non-intersecting set of intervals"""
def __init__(self, value):
print "hello" + value

Binary file not shown.

View File

@ -0,0 +1,41 @@
from nilmdb import Interval, IntervalSet, IntervalException
from datetime import datetime
from nose.tools import assert_raises
import itertools
def test_interval():
"""Test the Interval class"""
start = datetime.strptime("19801205","%Y%m%d")
end = datetime.strptime("20110216","%Y%m%d")
# basic construction
i = Interval(start, end)
assert(i.start == start)
assert(i.end == end)
# end before start
assert_raises(Exception, Interval, end, start)
def test_interval_intersect():
"""Test Interval intersections"""
dates = [ datetime.strptime(year, "%y") for year in [ "00", "01", "02", "03" ] ]
perm = list(itertools.permutations(dates, 2))
prod = list(itertools.product(perm, perm))
should_intersect = {
False: [4, 5, 8, 20, 48, 56, 60, 96, 97, 100],
True: [0, 1, 2, 12, 13, 14, 16, 17, 24, 25, 26, 28, 29,
32, 49, 50, 52, 53, 61, 62, 64, 65, 68, 98, 101, 104]}
for i,((a,b),(c,d)) in enumerate(prod):
try:
i1 = Interval(a, b)
i2 = Interval(c, d)
assert(i1.intersects(i2) == i2.intersects(i1))
assert(i in should_intersect[i1.intersects(i2)])
except IntervalException:
assert(i not in should_intersect[True] and
i not in should_intersect[False])
def test_intervalset():
"""Test interval sets"""
#iset = IntervalSet("hi")