You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

374 lines
12 KiB

  1. # -*- coding: utf-8 -*-
  2. import nilmdb
  3. from nilmdb.utils.printf import *
  4. from nilmdb.utils import datetime_tz
  5. from nose.tools import *
  6. from nose.tools import assert_raises
  7. import itertools
  8. from nilmdb.server.interval import (Interval, DBInterval,
  9. IntervalSet, IntervalError)
  10. from testutil.helpers import *
  11. import unittest
# Module-wide switch consulted by render(): it is and-ed with render()'s
# `live` argument, so while it is False no caller ever requests a live
# render.  Set to True to allow live renders.
do_live_renders = False
  14. def render(iset, description = "", live = True):
  15. import testutil.renderdot as renderdot
  16. r = renderdot.RBTreeRenderer(iset.tree)
  17. return r.render(description, live and do_live_renders)
  18. def makeset(string):
  19. """Build an IntervalSet from a string, for testing purposes
  20. Each character is 1 second
  21. [ = interval start
  22. | = interval end + next start
  23. ] = interval end
  24. . = zero-width interval (identical start and end)
  25. anything else is ignored
  26. """
  27. iset = IntervalSet()
  28. for i, c in enumerate(string):
  29. day = i + 10000
  30. if (c == "["):
  31. start = day
  32. elif (c == "|"):
  33. iset += Interval(start, day)
  34. start = day
  35. elif (c == ")"):
  36. iset += Interval(start, day)
  37. del start
  38. elif (c == "."):
  39. iset += Interval(day, day)
  40. return iset
  41. class TestInterval:
  42. def test_interval(self):
  43. # Test Interval class
  44. os.environ['TZ'] = "America/New_York"
  45. datetime_tz._localtz = None
  46. (d1, d2, d3) = [ datetime_tz.datetime_tz.smartparse(x).totimestamp()
  47. for x in [ "03/24/2012", "03/25/2012", "03/26/2012" ] ]
  48. # basic construction
  49. i = Interval(d1, d2)
  50. i = Interval(d1, d3)
  51. eq_(i.start, d1)
  52. eq_(i.end, d3)
  53. # assignment is allowed, but not verified
  54. i.start = d2
  55. #with assert_raises(IntervalError):
  56. # i.end = d1
  57. i.start = d1
  58. i.end = d2
  59. # end before start
  60. with assert_raises(IntervalError):
  61. i = Interval(d3, d1)
  62. # compare
  63. assert(Interval(d1, d2) == Interval(d1, d2))
  64. assert(Interval(d1, d2) < Interval(d1, d3))
  65. assert(Interval(d1, d3) > Interval(d1, d2))
  66. assert(Interval(d1, d2) < Interval(d2, d3))
  67. assert(Interval(d1, d3) < Interval(d2, d3))
  68. assert(Interval(d2, d2+0.01) > Interval(d1, d3))
  69. assert(Interval(d3, d3+0.01) == Interval(d3, d3+0.01))
  70. #with assert_raises(TypeError): # was AttributeError, that's wrong
  71. # x = (i == 123)
  72. # subset
  73. eq_(Interval(d1, d3).subset(d1, d2), Interval(d1, d2))
  74. with assert_raises(IntervalError):
  75. x = Interval(d2, d3).subset(d1, d2)
  76. # big integers and floats
  77. x = Interval(5000111222, 6000111222)
  78. eq_(str(x), "[5000111222.0 -> 6000111222.0)")
  79. x = Interval(123.45, 234.56)
  80. eq_(str(x), "[123.45 -> 234.56)")
  81. # misc
  82. i = Interval(d1, d2)
  83. eq_(repr(i), repr(eval(repr(i))))
  84. eq_(str(i), "[1332561600.0 -> 1332648000.0)")
  85. def test_interval_intersect(self):
  86. # Test Interval intersections
  87. dates = [ 100, 200, 300, 400 ]
  88. perm = list(itertools.permutations(dates, 2))
  89. prod = list(itertools.product(perm, perm))
  90. should_intersect = {
  91. False: [4, 5, 8, 20, 48, 56, 60, 96, 97, 100],
  92. True: [0, 1, 2, 12, 13, 14, 16, 17, 24, 25, 26, 28, 29,
  93. 32, 49, 50, 52, 53, 61, 62, 64, 65, 68, 98, 101, 104]
  94. }
  95. for i,((a,b),(c,d)) in enumerate(prod):
  96. try:
  97. i1 = Interval(a, b)
  98. i2 = Interval(c, d)
  99. eq_(i1.intersects(i2), i2.intersects(i1))
  100. in_(i, should_intersect[i1.intersects(i2)])
  101. except IntervalError:
  102. assert(i not in should_intersect[True] and
  103. i not in should_intersect[False])
  104. with assert_raises(TypeError):
  105. x = i1.intersects(1234)
  106. def test_intervalset_construct(self):
  107. # Test IntervalSet construction
  108. dates = [ 100, 200, 300, 400 ]
  109. a = Interval(dates[0], dates[1])
  110. b = Interval(dates[1], dates[2])
  111. c = Interval(dates[0], dates[2])
  112. d = Interval(dates[2], dates[3])
  113. iseta = IntervalSet(a)
  114. isetb = IntervalSet([a, b])
  115. isetc = IntervalSet([a])
  116. ne_(iseta, isetb)
  117. eq_(iseta, isetc)
  118. with assert_raises(TypeError):
  119. x = iseta != 3
  120. ne_(IntervalSet(a), IntervalSet(b))
  121. # Note that assignment makes a new reference (not a copy)
  122. isetd = IntervalSet(isetb)
  123. isete = isetd
  124. eq_(isetd, isetb)
  125. eq_(isetd, isete)
  126. isetd -= a
  127. ne_(isetd, isetb)
  128. eq_(isetd, isete)
  129. # test iterator
  130. for interval in iseta:
  131. pass
  132. # overlap
  133. with assert_raises(IntervalError):
  134. x = IntervalSet([a, b, c])
  135. # bad types
  136. with assert_raises(Exception):
  137. x = IntervalSet([1, 2])
  138. iset = IntervalSet(isetb) # test iterator
  139. eq_(iset, isetb)
  140. eq_(len(iset), 2)
  141. eq_(len(IntervalSet()), 0)
  142. # Test adding
  143. iset = IntervalSet(a)
  144. iset += IntervalSet(b)
  145. eq_(iset, IntervalSet([a, b]))
  146. iset = IntervalSet(a)
  147. iset += b
  148. eq_(iset, IntervalSet([a, b]))
  149. iset = IntervalSet(a)
  150. iset.iadd_nocheck(b)
  151. eq_(iset, IntervalSet([a, b]))
  152. iset = IntervalSet(a) + IntervalSet(b)
  153. eq_(iset, IntervalSet([a, b]))
  154. iset = IntervalSet(b) + a
  155. eq_(iset, IntervalSet([a, b]))
  156. # A set consisting of [0-1],[1-2] should match a set consisting of [0-2]
  157. eq_(IntervalSet([a,b]), IntervalSet([c]))
  158. # Etc
  159. ne_(IntervalSet([a,d]), IntervalSet([c]))
  160. ne_(IntervalSet([c]), IntervalSet([a,d]))
  161. ne_(IntervalSet([c,d]), IntervalSet([b,d]))
  162. # misc
  163. eq_(repr(iset), repr(eval(repr(iset))))
  164. eq_(str(iset), "[[100.0 -> 200.0), [200.0 -> 300.0)]")
  165. def test_intervalset_geniset(self):
  166. # Test basic iset construction
  167. eq_(makeset(" [----) "),
  168. makeset(" [-|--) "))
  169. eq_(makeset("[) [--) ") +
  170. makeset(" [) [--)"),
  171. makeset("[|) [-----)"))
  172. eq_(makeset(" [-------)"),
  173. makeset(" [-|-----|"))
  174. def test_intervalset_intersect(self):
  175. # Test intersection (&)
  176. with assert_raises(TypeError): # was AttributeError
  177. x = makeset("[--)") & 1234
  178. # Intersection with interval
  179. eq_(makeset("[---|---)[)") &
  180. list(makeset(" [------) "))[0],
  181. makeset(" [-----) "))
  182. # Intersection with sets
  183. eq_(makeset("[---------)") &
  184. makeset(" [---) "),
  185. makeset(" [---) "))
  186. eq_(makeset(" [---) ") &
  187. makeset("[---------)"),
  188. makeset(" [---) "))
  189. eq_(makeset(" [-----)") &
  190. makeset(" [-----) "),
  191. makeset(" [--) "))
  192. eq_(makeset(" [--) [--)") &
  193. makeset(" [------) "),
  194. makeset(" [-) [-) "))
  195. eq_(makeset(" [---)") &
  196. makeset(" [--) "),
  197. makeset(" "))
  198. eq_(makeset(" [-|---)") &
  199. makeset(" [-----|-) "),
  200. makeset(" [----) "))
  201. eq_(makeset(" [-|-) ") &
  202. makeset(" [-|--|--) "),
  203. makeset(" [---) "))
  204. # Border cases -- will give different results if intervals are
  205. # half open or fully closed. Right now, they are half open,
  206. # although that's a little messy since the database intervals
  207. # often contain a data point at the endpoint.
  208. half_open = True
  209. if half_open:
  210. eq_(makeset(" [---)") &
  211. makeset(" [----) "),
  212. makeset(" "))
  213. eq_(makeset(" [----)[--)") &
  214. makeset("[-) [--) [)"),
  215. makeset(" [) [-) [)"))
  216. else:
  217. eq_(makeset(" [---)") &
  218. makeset(" [----) "),
  219. makeset(" . "))
  220. eq_(makeset(" [----)[--)") &
  221. makeset("[-) [--) [)"),
  222. makeset(" [) [-). [)"))
  223. class TestIntervalDB:
  224. def test_dbinterval(self):
  225. # Test DBInterval class
  226. i = DBInterval(100, 200, 100, 200, 10000, 20000)
  227. eq_(i.start, 100)
  228. eq_(i.end, 200)
  229. eq_(i.db_start, 100)
  230. eq_(i.db_end, 200)
  231. eq_(i.db_startpos, 10000)
  232. eq_(i.db_endpos, 20000)
  233. eq_(repr(i), repr(eval(repr(i))))
  234. # end before start
  235. with assert_raises(IntervalError):
  236. i = DBInterval(200, 100, 100, 200, 10000, 20000)
  237. # db_start too late
  238. with assert_raises(IntervalError):
  239. i = DBInterval(100, 200, 150, 200, 10000, 20000)
  240. # db_end too soon
  241. with assert_raises(IntervalError):
  242. i = DBInterval(100, 200, 100, 150, 10000, 20000)
  243. # actual start, end can be a subset
  244. a = DBInterval(150, 200, 100, 200, 10000, 20000)
  245. b = DBInterval(100, 150, 100, 200, 10000, 20000)
  246. c = DBInterval(150, 160, 100, 200, 10000, 20000)
  247. # Make a set of DBIntervals
  248. iseta = IntervalSet([a, b])
  249. isetc = IntervalSet(c)
  250. assert(iseta.intersects(a))
  251. assert(iseta.intersects(b))
  252. # Test subset
  253. with assert_raises(IntervalError):
  254. x = a.subset(150, 250)
  255. # Subset of those IntervalSets should still contain DBIntervals
  256. for i in IntervalSet(iseta.intersection(Interval(125,250))):
  257. assert(isinstance(i, DBInterval))
  258. class TestIntervalTree:
  259. def test_interval_tree(self):
  260. import random
  261. random.seed(1234)
  262. # make a set of 100 intervals
  263. iset = IntervalSet()
  264. j = 100
  265. for i in random.sample(xrange(j),j):
  266. interval = Interval(i, i+1)
  267. iset += interval
  268. render(iset, "Random Insertion")
  269. # remove about half of them
  270. for i in random.sample(xrange(j),j):
  271. if random.randint(0,1):
  272. iset -= Interval(i, i+1)
  273. # try removing an interval that doesn't exist
  274. with assert_raises(IntervalError):
  275. iset -= Interval(1234,5678)
  276. render(iset, "Random Insertion, deletion")
  277. # make a set of 100 intervals, inserted in order
  278. iset = IntervalSet()
  279. j = 100
  280. for i in xrange(j):
  281. interval = Interval(i, i+1)
  282. iset += interval
  283. render(iset, "In-order insertion")
  284. class TestIntervalSpeed:
  285. @unittest.skip("this is slow")
  286. def test_interval_speed(self):
  287. import yappi
  288. import time
  289. import testutil.aplotter as aplotter
  290. import random
  291. import math
  292. print
  293. yappi.start()
  294. speeds = {}
  295. limit = 10 # was 20
  296. for j in [ 2**x for x in range(5,limit) ]:
  297. start = time.time()
  298. iset = IntervalSet()
  299. for i in random.sample(xrange(j),j):
  300. interval = Interval(i, i+1)
  301. iset += interval
  302. speed = (time.time() - start) * 1000000.0
  303. printf("%d: %g μs (%g μs each, O(n log n) ratio %g)\n",
  304. j,
  305. speed,
  306. speed/j,
  307. speed / (j*math.log(j))) # should be constant
  308. speeds[j] = speed
  309. aplotter.plot(speeds.keys(), speeds.values(), plot_slope=True)
  310. yappi.stop()
  311. yappi.print_stats(sort_type=yappi.SORTTYPE_TTOT, limit=10)