nilmdb/tests/test_interval.py

432 lines
14 KiB
Python

# -*- coding: utf-8 -*-
import nilmdb
from nilmdb.utils.printf import *
from nilmdb.utils import datetime_tz
from nose.tools import *
from nose.tools import assert_raises
import itertools
from nilmdb.utils.interval import IntervalError
from nilmdb.server.interval import Interval, DBInterval, IntervalSet
# so we can test them separately
from nilmdb.utils.interval import Interval as UtilsInterval
from testutil.helpers import *
import unittest
# Module-wide switch consulted by render(): a live render is requested only
# when this is True (and the caller passes live=True).  False skips them.
do_live_renders = False
def render(iset, description = "", live = True):
    """Render the tree behind an IntervalSet with testutil.renderdot.

    iset        -- IntervalSet whose .tree attribute is handed to
                   renderdot.RBTreeRenderer
    description -- label passed through to the renderer
    live        -- request a live render; honored only when the module
                   flag do_live_renders is also true

    Returns whatever RBTreeRenderer.render() returns.
    """
    import testutil.renderdot as renderdot
    renderer = renderdot.RBTreeRenderer(iset.tree)
    want_live = live and do_live_renders
    return renderer.render(description, want_live)
def makeset(string):
    """Build an IntervalSet from a string, for testing purposes.

    Character position i maps to the integer timestamp 10000 + i, so each
    character is one time unit:

      [ = interval start
      | = interval end + next start
      ) = interval end
      . = zero-width interval (identical start and end)

    Anything else (typically ' ' and '-') is ignored.  Because positions
    are significant, whitespace inside the string matters.

    An interval that is opened but never closed is silently dropped.
    Raises ValueError if '|' or ')' appears while no interval is open.
    """
    iset = IntervalSet()
    start = None
    for i, c in enumerate(string):
        t = i + 10000
        if c == "[":
            start = t
        elif c == "|" or c == ")":
            if start is None:
                raise ValueError("makeset: '%s' at position %d closes no "
                                 "open interval in %r" % (c, i, string))
            iset += Interval(start, t)
            # '|' immediately opens the next interval; ')' leaves none open
            start = t if c == "|" else None
        elif c == ".":
            iset += Interval(t, t)
    return iset
class TestInterval:
    def test_client_interval(self):
        # Run interval tests against the Python version of Interval
        # (nilmdb.utils.interval) by temporarily rebinding the module-level
        # name that test_interval() and test_interval_intersect() use.
        global Interval
        NilmdbInterval = Interval
        Interval = UtilsInterval
        try:
            self.test_interval()
            self.test_interval_intersect()
        finally:
            # Always restore, so a failure above cannot leak UtilsInterval
            # into tests that run afterwards.
            Interval = NilmdbInterval

        # Other helpers in nilmdb.utils.interval: optimize() merges the
        # adjacent [1,2) and [2,3) into [1,3).
        i = [ UtilsInterval(1,2), UtilsInterval(2,3), UtilsInterval(4,5) ]
        eq_(list(nilmdb.utils.interval.optimize(i)),
            [ UtilsInterval(1,3), UtilsInterval(4,5) ])
        # Expected text is America/New_York local time; test_interval()
        # above sets TZ to that zone.
        eq_(UtilsInterval(1234567890123456, 1234567890654321).human_string(),
            "[ Fri, 13 Feb 2009 18:31:30.123456 -0500 -> " +
            "Fri, 13 Feb 2009 18:31:30.654321 -0500 ]")

    def test_interval(self):
        # Test Interval class
        import os   # explicit, rather than relying on the star imports
        os.environ['TZ'] = "America/New_York"
        # presumably clears datetime_tz's cached local timezone so the TZ
        # change above takes effect -- confirm against datetime_tz
        datetime_tz._localtz = None
        (d1, d2, d3) = [ nilmdb.utils.time.parse_time(x)
                         for x in [ "03/24/2012", "03/25/2012", "03/26/2012" ] ]

        # basic construction
        i = Interval(d1, d2)
        i = Interval(d1, d3)
        eq_(i.start, d1)
        eq_(i.end, d3)

        # assignment is allowed, but not verified: there is deliberately no
        # check here that e.g. "i.end = d1" (end before start) raises.
        i.start = d2
        i.start = d1
        i.end = d2

        # end before start
        with assert_raises(IntervalError):
            i = Interval(d3, d1)

        # compare
        assert(Interval(d1, d2) == Interval(d1, d2))
        assert(Interval(d1, d2) < Interval(d1, d3))
        assert(Interval(d1, d3) > Interval(d1, d2))
        assert(Interval(d1, d2) < Interval(d2, d3))
        assert(Interval(d1, d3) < Interval(d2, d3))
        assert(Interval(d2, d2+1) > Interval(d1, d3))
        assert(Interval(d3, d3+1) == Interval(d3, d3+1))
        # NOTE: comparison against a non-Interval (e.g. "i == 123") is not
        # tested; an earlier expectation of AttributeError was wrong.

        # subset
        eq_(Interval(d1, d3).subset(d1, d2), Interval(d1, d2))
        with assert_raises(IntervalError):
            x = Interval(d2, d3).subset(d1, d2)

        # big integers, negative integers
        x = Interval(5000111222000000, 6000111222000000)
        eq_(str(x), "[5000111222000000 -> 6000111222000000)")
        x = Interval(-5000111222000000, -4000111222000000)
        eq_(str(x), "[-5000111222000000 -> -4000111222000000)")

        # misc
        i = Interval(d1, d2)
        eq_(repr(i), repr(eval(repr(i))))
        eq_(str(i), "[1332561600000000 -> 1332648000000000)")

    def test_interval_intersect(self):
        # Test Interval intersections
        dates = [ 100, 200, 300, 400 ]
        perm = list(itertools.permutations(dates, 2))
        prod = list(itertools.product(perm, perm))
        # Indices into prod, keyed by the expected intersects() result.
        # Indices in neither list are pairs where building an Interval
        # raises IntervalError (end before start).
        should_intersect = {
            False: [4, 5, 8, 20, 48, 56, 60, 96, 97, 100],
            True: [0, 1, 2, 12, 13, 14, 16, 17, 24, 25, 26, 28, 29,
                   32, 49, 50, 52, 53, 61, 62, 64, 65, 68, 98, 101, 104]
        }
        for i,((a,b),(c,d)) in enumerate(prod):
            try:
                i1 = Interval(a, b)
                i2 = Interval(c, d)
                eq_(i1.intersects(i2), i2.intersects(i1))
                in_(i, should_intersect[i1.intersects(i2)])
            except IntervalError:
                assert(i not in should_intersect[True] and
                       i not in should_intersect[False])
        # i1 is the last Interval successfully built in the loop above
        with assert_raises(TypeError):
            x = i1.intersects(1234)

    def test_intervalset_construct(self):
        # Test IntervalSet construction
        dates = [ 100, 200, 300, 400 ]
        a = Interval(dates[0], dates[1])
        b = Interval(dates[1], dates[2])
        c = Interval(dates[0], dates[2])
        d = Interval(dates[2], dates[3])

        iseta = IntervalSet(a)
        isetb = IntervalSet([a, b])
        isetc = IntervalSet([a])
        ne_(iseta, isetb)
        eq_(iseta, isetc)
        with assert_raises(TypeError):
            x = iseta != 3
        ne_(IntervalSet(a), IntervalSet(b))

        # IntervalSet(isetb) makes a copy, but plain assignment
        # (isete = isetd) only makes a new reference to the same set.
        isetd = IntervalSet(isetb)
        isete = isetd
        eq_(isetd, isetb)
        eq_(isetd, isete)
        isetd -= a
        ne_(isetd, isetb)
        eq_(isetd, isete)

        # test iterator
        for interval in iseta:
            pass

        # overlap
        with assert_raises(IntervalError):
            x = IntervalSet([a, b, c])

        # bad types
        with assert_raises(Exception):
            x = IntervalSet([1, 2])

        iset = IntervalSet(isetb) # test iterator
        eq_(iset, isetb)
        eq_(len(iset), 2)
        eq_(len(IntervalSet()), 0)

        # Test adding
        iset = IntervalSet(a)
        iset += IntervalSet(b)
        eq_(iset, IntervalSet([a, b]))
        iset = IntervalSet(a)
        iset += b
        eq_(iset, IntervalSet([a, b]))
        iset = IntervalSet(a)
        iset.iadd_nocheck(b)
        eq_(iset, IntervalSet([a, b]))
        iset = IntervalSet(a) + IntervalSet(b)
        eq_(iset, IntervalSet([a, b]))
        iset = IntervalSet(b) + a
        eq_(iset, IntervalSet([a, b]))

        # A set consisting of [0-1],[1-2] should match a set consisting of [0-2]
        eq_(IntervalSet([a,b]), IntervalSet([c]))
        # Etc
        ne_(IntervalSet([a,d]), IntervalSet([c]))
        ne_(IntervalSet([c]), IntervalSet([a,d]))
        ne_(IntervalSet([c,d]), IntervalSet([b,d]))

        # misc
        eq_(repr(iset), repr(eval(repr(iset))))
        eq_(str(iset),
            "[[100 -> 200), [200 -> 300)]")

    def test_intervalset_geniset(self):
        # Test basic iset construction.  Each character is one time unit
        # (see makeset), so whitespace inside these strings is significant.
        eq_(makeset(" [----)   "),
            makeset(" [-|--)   "))

        eq_(makeset("[)   [--)   ") +
            makeset(" [)     [--)"),
            makeset("[|)  [-----)"))

        eq_(makeset(" [-------)"),
            makeset(" [-|-----|"))

    def test_intervalset_intersect_difference(self):
        # Test intersection (&)
        with assert_raises(TypeError): # was AttributeError
            x = makeset("[--)") & 1234

        def do_test(a, b, c, d):
            # a & b == c (using nilmdb.server.interval)
            ab = IntervalSet()
            for x in b:
                for i in (a & x):
                    ab += i
            eq_(ab,c)

            # a & b == c (using nilmdb.utils.interval)
            eq_(IntervalSet(nilmdb.utils.interval.intersection(a,b)), c)

            # a \ b == d
            eq_(IntervalSet(nilmdb.utils.interval.set_difference(a,b)), d)

        # Intersection with intervals.  The four strings of each call share
        # the same character columns (see makeset), so whitespace inside
        # these literals is significant.
        do_test(makeset("[---|---)[)"),
                makeset("  [------) "),
                makeset("  [-----)  "), # intersection
                makeset("[-)      [)")) # difference

        do_test(makeset("[---------)"),
                makeset(" [---)     "),
                makeset(" [---)     "), # intersection
                makeset("[)   [----)")) # difference

        do_test(makeset(" [---)     "),
                makeset("[---------)"),
                makeset(" [---)     "), # intersection
                makeset("           ")) # difference

        do_test(makeset("    [-----)"),
                makeset(" [-----)   "),
                makeset("    [--)   "), # intersection
                makeset("       [--)")) # difference

        do_test(makeset(" [--)  [--)"),
                makeset("  [------) "),
                makeset("  [-)  [-) "), # intersection
                makeset(" [)      [)")) # difference

        do_test(makeset("      [---)"),
                makeset(" [--)      "),
                makeset("           "), # intersection
                makeset("      [---)")) # difference

        do_test(makeset("    [-|---)"),
                makeset(" [-----|-) "),
                makeset("    [----) "), # intersection
                makeset("         [)")) # difference

        do_test(makeset(" [-|-)     "),
                makeset(" [-|--|--) "),
                makeset(" [---)     "), # intersection
                makeset("           ")) # difference

        do_test(makeset("[-)[-)[-)[)"),
                makeset(" [)  [|)[) "),
                makeset(" [)   [)   "), # intersection
                makeset("[) [-) [)[)")) # difference

        # Border cases -- will give different results if intervals are
        # half open or fully closed.  In nilmdb, they are half open.
        do_test(makeset("      [---)"),
                makeset(" [----)    "),
                makeset("           "), # intersection
                makeset("      [---)")) # difference

        do_test(makeset(" [----)[--)"),
                makeset("[-) [--) [)"),
                makeset(" [) [-)  [)"), # intersection
                makeset("  [-)  [-) ")) # difference

        # Set difference with bounds: clip both sets to c's single
        # interval before taking the difference.
        a = makeset(" [----)[--)")
        b = makeset("[-) [--) [)")
        c = makeset("[----)     ")
        d = makeset("  [-)      ")
        eq_(nilmdb.utils.interval.set_difference(
            a.intersection(list(c)[0]), b.intersection(list(c)[0])), d)

        # Fill out test coverage for non-subsets
        def diff2(a,b, subset):
            return nilmdb.utils.interval._interval_math_helper(
                a, b, (lambda a, b: b and not a), subset=subset)
        with assert_raises(nilmdb.utils.interval.IntervalError):
            list(diff2(a,b,True))
        list(diff2(a,b,False))

        # Empty second set
        eq_(nilmdb.utils.interval.set_difference(a, IntervalSet()), a)
class TestIntervalDB:
    def test_dbinterval(self):
        """Exercise DBInterval construction, validation, and set handling."""
        # Positional arguments, as checked below:
        #   start, end, db_start, db_end, db_startpos, db_endpos
        i = DBInterval(100, 200, 100, 200, 10000, 20000)
        expected_attrs = [ ("start", 100),
                           ("end", 200),
                           ("db_start", 100),
                           ("db_end", 200),
                           ("db_startpos", 10000),
                           ("db_endpos", 20000) ]
        for attr_name, attr_value in expected_attrs:
            eq_(getattr(i, attr_name), attr_value)
        eq_(repr(i), repr(eval(repr(i))))

        # Each of these must be rejected with IntervalError
        invalid_args = [
            (200, 100, 100, 200, 10000, 20000),   # end before start
            (100, 200, 150, 200, 10000, 20000),   # db_start too late
            (100, 200, 100, 150, 10000, 20000),   # db_end too soon
        ]
        for args in invalid_args:
            with assert_raises(IntervalError):
                DBInterval(*args)

        # The actual start, end may be a subset of db_start .. db_end
        a = DBInterval(150, 200, 100, 200, 10000, 20000)
        b = DBInterval(100, 150, 100, 200, 10000, 20000)
        c = DBInterval(150, 160, 100, 200, 10000, 20000)

        # Sets built from DBIntervals
        iseta = IntervalSet([a, b])
        isetc = IntervalSet(c)   # only checked for constructing cleanly
        for member in (a, b):
            assert(iseta.intersects(member))

        # subset() beyond the interval's own end is an error
        with assert_raises(IntervalError):
            x = a.subset(150, 250)

        # Clipping a set of DBIntervals must still yield DBIntervals
        clipped = IntervalSet(iseta.intersection(Interval(125,250)))
        for piece in clipped:
            assert(isinstance(piece, DBInterval))
class TestIntervalTree:
    def test_interval_tree(self):
        """Build, prune, and rebuild sets of unit intervals, rendering each."""
        import random
        random.seed(1234)
        count = 100

        # Insert [k, k+1) for every k in 0..count-1, in random order
        iset = IntervalSet()
        for k in random.sample(xrange(count), count):
            iset += Interval(k, k + 1)
        render(iset, "Random Insertion")

        # Remove roughly half of them, chosen by coin flip
        for k in random.sample(xrange(count), count):
            if random.randint(0, 1):
                iset -= Interval(k, k + 1)

        # Removing an interval that isn't in the set is an error
        with assert_raises(IntervalError):
            iset -= Interval(1234, 5678)
        render(iset, "Random Insertion, deletion")

        # Same number of intervals again, this time in ascending order
        iset = IntervalSet()
        for k in xrange(count):
            iset += Interval(k, k + 1)
        render(iset, "In-order insertion")
class TestIntervalSpeed:
    @unittest.skip("this is slow")
    def test_interval_speed(self):
        """Time and profile random-order insertion for sizes 2**5 .. 2**21."""
        import yappi
        import time
        import random
        import math

        print
        yappi.start()
        speeds = {}
        limit = 22 # was 20
        for size in [ 2**x for x in range(5,limit) ]:
            t0 = time.time()
            iset = IntervalSet()
            for k in random.sample(xrange(size), size):
                iset += Interval(k, k + 1)
            elapsed_us = (time.time() - t0) * 1000000.0
            # The final ratio should stay roughly constant if total
            # insertion cost grows as n log n.
            printf("%d: %g μs (%g μs each, O(n log n) ratio %g)\n",
                   size,
                   elapsed_us,
                   elapsed_us / size,
                   elapsed_us / (size * math.log(size)))
            speeds[size] = elapsed_us
        yappi.stop()
        yappi.print_stats(sort_type=yappi.SORTTYPE_TTOT, limit=10)