Update tests etc
git-svn-id: https://bucket.mit.edu/svn/nilm/nilmdb@9268 ddd99763-3ecb-0310-9145-efcb8ce7c51f
This commit is contained in:
parent
5844afed0b
commit
e2f89982cb
|
@ -1,3 +1,3 @@
|
|||
from interval import Interval
|
||||
from interval import *
|
||||
|
||||
del interval
|
||||
|
|
Binary file not shown.
|
@ -1,4 +1,7 @@
|
|||
import datetime
|
||||
from datetime import datetime
|
||||
|
||||
class IntervalException(Exception):
|
||||
pass
|
||||
|
||||
class Interval(object):
|
||||
"""Represents an interval of time"""
|
||||
|
@ -7,9 +10,62 @@ class Interval(object):
|
|||
self.start = start
|
||||
self.end = end
|
||||
|
||||
def __repr__(self):
|
||||
return "Interval(" + repr(self.start) + ", " + repr(self.end) + ")"
|
||||
|
||||
def __str__(self):
|
||||
return "[" + str(self.start) + " -> " + str(self.end) + "]"
|
||||
|
||||
def __getattr__(self, name):
|
||||
return { }
|
||||
|
||||
def __setattr__(self, name):
|
||||
raise Exception("__setattr__ " + name + "called")
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if (type(value).__name__ != 'datetime'):
|
||||
raise IntervalException("Must set datetime values")
|
||||
|
||||
self.__dict__[name] = value
|
||||
if (type(self.start) is type(self.end)):
|
||||
if (self.start > self.end):
|
||||
raise IntervalException("Interval start must precede interval end")
|
||||
|
||||
def intersects(self, other):
|
||||
if (self.end <= other.start or
|
||||
self.start >= other.end):
|
||||
return False;
|
||||
else:
|
||||
return True;
|
||||
|
||||
"""hi
|
||||
if (self.start < other.start &&
|
||||
self.end > other.start &&
|
||||
self.end < other.end)
|
||||
|
||||
start-other.start start-other.end end-other.start end-other.end
|
||||
|
||||
|
||||
---- < < > <
|
||||
----
|
||||
|
||||
---- < < < < n
|
||||
-----
|
||||
|
||||
--- > < > <
|
||||
--------
|
||||
|
||||
------- < < > >
|
||||
---
|
||||
|
||||
---- > < > >
|
||||
-----
|
||||
|
||||
----------- > > > > n
|
||||
-----
|
||||
"""
|
||||
|
||||
|
||||
class IntervalSet(object):
|
||||
"""A non-intersecting set of intervals"""
|
||||
|
||||
def __init__(self, value):
|
||||
print "hello" + value
|
||||
|
||||
|
|
Binary file not shown.
41
src/nilmdb/test_interval.py
Normal file
41
src/nilmdb/test_interval.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
from nilmdb import Interval, IntervalSet, IntervalException
|
||||
from datetime import datetime
|
||||
from nose.tools import assert_raises
|
||||
import itertools
|
||||
|
||||
def test_interval():
|
||||
"""Test the Interval class"""
|
||||
start = datetime.strptime("19801205","%Y%m%d")
|
||||
end = datetime.strptime("20110216","%Y%m%d")
|
||||
|
||||
# basic construction
|
||||
i = Interval(start, end)
|
||||
assert(i.start == start)
|
||||
assert(i.end == end)
|
||||
|
||||
# end before start
|
||||
assert_raises(Exception, Interval, end, start)
|
||||
|
||||
def test_interval_intersect():
|
||||
"""Test Interval intersections"""
|
||||
dates = [ datetime.strptime(year, "%y") for year in [ "00", "01", "02", "03" ] ]
|
||||
perm = list(itertools.permutations(dates, 2))
|
||||
prod = list(itertools.product(perm, perm))
|
||||
should_intersect = {
|
||||
False: [4, 5, 8, 20, 48, 56, 60, 96, 97, 100],
|
||||
True: [0, 1, 2, 12, 13, 14, 16, 17, 24, 25, 26, 28, 29,
|
||||
32, 49, 50, 52, 53, 61, 62, 64, 65, 68, 98, 101, 104]}
|
||||
for i,((a,b),(c,d)) in enumerate(prod):
|
||||
try:
|
||||
i1 = Interval(a, b)
|
||||
i2 = Interval(c, d)
|
||||
assert(i1.intersects(i2) == i2.intersects(i1))
|
||||
assert(i in should_intersect[i1.intersects(i2)])
|
||||
except IntervalException:
|
||||
assert(i not in should_intersect[True] and
|
||||
i not in should_intersect[False])
|
||||
|
||||
def test_intervalset():
|
||||
"""Test interval sets"""
|
||||
#iset = IntervalSet("hi")
|
||||
|
Loading…
Reference in New Issue
Block a user