You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

356 lines
11 KiB

  1. # -*- coding: utf-8 -*-
  2. import nilmdb
  3. from nilmdb.printf import *
  4. import datetime_tz
  5. from nose.tools import *
  6. from nose.tools import assert_raises
  7. import itertools
  8. from nilmdb.interval import Interval, DBInterval, IntervalSet, IntervalError
  9. from test_helpers import *
  10. import unittest
  11. # set to False to skip live renders
  12. do_live_renders = False
  13. def render(iset, description = "", live = True):
  14. import renderdot
  15. r = renderdot.RBTreeRenderer(iset.tree)
  16. return r.render(description, live and do_live_renders)
  17. def makeset(string):
  18. """Build an IntervalSet from a string, for testing purposes
  19. Each character is 1 second
  20. [ = interval start
  21. | = interval end + adjacent start
  22. ] = interval end
  23. . = zero-width interval (identical start and end)
  24. anything else is ignored
  25. """
  26. iset = IntervalSet()
  27. for i, c in enumerate(string):
  28. day = i + 10000
  29. if (c == "["):
  30. start = day
  31. elif (c == "|"):
  32. iset += Interval(start, day)
  33. start = day
  34. elif (c == "]"):
  35. iset += Interval(start, day)
  36. del start
  37. elif (c == "."):
  38. iset += Interval(day, day)
  39. return iset
  40. class TestInterval:
  41. def test_interval(self):
  42. # Test Interval class
  43. os.environ['TZ'] = "America/New_York"
  44. datetime_tz._localtz = None
  45. (d1, d2, d3) = [ datetime_tz.datetime_tz.smartparse(x).totimestamp()
  46. for x in [ "03/24/2012", "03/25/2012", "03/26/2012" ] ]
  47. # basic construction
  48. i = Interval(d1, d1)
  49. i = Interval(d1, d3)
  50. eq_(i.start, d1)
  51. eq_(i.end, d3)
  52. # assignment is allowed, but not verified
  53. i.start = d2
  54. #with assert_raises(IntervalError):
  55. # i.end = d1
  56. i.start = d1
  57. i.end = d2
  58. # end before start
  59. with assert_raises(IntervalError):
  60. i = Interval(d3, d1)
  61. # compare
  62. assert(Interval(d1, d2) == Interval(d1, d2))
  63. assert(Interval(d1, d2) < Interval(d1, d3))
  64. assert(Interval(d1, d3) > Interval(d1, d2))
  65. assert(Interval(d1, d2) < Interval(d2, d3))
  66. assert(Interval(d1, d3) < Interval(d2, d3))
  67. assert(Interval(d2, d2) > Interval(d1, d3))
  68. assert(Interval(d3, d3) == Interval(d3, d3))
  69. with assert_raises(TypeError): # was AttributeError, that's wrong
  70. x = (i == 123)
  71. # subset
  72. eq_(Interval(d1, d3).subset(d1, d2), Interval(d1, d2))
  73. with assert_raises(IntervalError):
  74. x = Interval(d2, d3).subset(d1, d2)
  75. # big integers and floats
  76. x = Interval(5000111222, 6000111222)
  77. eq_(str(x), "[5000111222.0 -> 6000111222.0]")
  78. x = Interval(123.45, 234.56)
  79. eq_(str(x), "[123.45 -> 234.56]")
  80. # misc
  81. i = Interval(d1, d2)
  82. eq_(repr(i), repr(eval(repr(i))))
  83. eq_(str(i), "[1332561600.0 -> 1332648000.0]")
  84. def test_interval_intersect(self):
  85. # Test Interval intersections
  86. dates = [ 100, 200, 300, 400 ]
  87. perm = list(itertools.permutations(dates, 2))
  88. prod = list(itertools.product(perm, perm))
  89. should_intersect = {
  90. False: [4, 5, 8, 20, 48, 56, 60, 96, 97, 100],
  91. True: [0, 1, 2, 12, 13, 14, 16, 17, 24, 25, 26, 28, 29,
  92. 32, 49, 50, 52, 53, 61, 62, 64, 65, 68, 98, 101, 104]
  93. }
  94. for i,((a,b),(c,d)) in enumerate(prod):
  95. try:
  96. i1 = Interval(a, b)
  97. i2 = Interval(c, d)
  98. eq_(i1.intersects(i2), i2.intersects(i1))
  99. in_(i, should_intersect[i1.intersects(i2)])
  100. except IntervalError:
  101. assert(i not in should_intersect[True] and
  102. i not in should_intersect[False])
  103. with assert_raises(AttributeError):
  104. x = i1.intersects(1234)
  105. def test_intervalset_construct(self):
  106. # Test IntervalSet construction
  107. dates = [ 100, 200, 300, 400 ]
  108. a = Interval(dates[0], dates[1])
  109. b = Interval(dates[1], dates[2])
  110. c = Interval(dates[0], dates[2])
  111. d = Interval(dates[2], dates[3])
  112. iseta = IntervalSet(a)
  113. isetb = IntervalSet([a, b])
  114. isetc = IntervalSet([a])
  115. ne_(iseta, isetb)
  116. eq_(iseta, isetc)
  117. with assert_raises(TypeError):
  118. x = iseta != 3
  119. ne_(IntervalSet(a), IntervalSet(b))
  120. # test iterator
  121. for interval in iseta:
  122. pass
  123. # overlap
  124. with assert_raises(IntervalError):
  125. x = IntervalSet([a, b, c])
  126. # bad types
  127. with assert_raises(Exception):
  128. x = IntervalSet([1, 2])
  129. iset = IntervalSet(isetb) # test iterator
  130. eq_(iset, isetb)
  131. eq_(len(iset), 2)
  132. eq_(len(IntervalSet()), 0)
  133. # Test adding
  134. iset = IntervalSet(a)
  135. iset += IntervalSet(b)
  136. eq_(iset, IntervalSet([a, b]))
  137. iset = IntervalSet(a)
  138. iset += b
  139. eq_(iset, IntervalSet([a, b]))
  140. iset = IntervalSet(a) + IntervalSet(b)
  141. eq_(iset, IntervalSet([a, b]))
  142. iset = IntervalSet(b) + a
  143. eq_(iset, IntervalSet([a, b]))
  144. # A set consisting of [0-1],[1-2] should match a set consisting of [0-2]
  145. eq_(IntervalSet([a,b]), IntervalSet([c]))
  146. # Etc
  147. ne_(IntervalSet([a,d]), IntervalSet([c]))
  148. ne_(IntervalSet([c]), IntervalSet([a,d]))
  149. ne_(IntervalSet([c,d]), IntervalSet([b,d]))
  150. # misc
  151. eq_(repr(iset), repr(eval(repr(iset))))
  152. eq_(str(iset), "[[100.0 -> 200.0], [200.0 -> 300.0]]")
  153. def test_intervalset_geniset(self):
  154. # Test basic iset construction
  155. eq_(makeset(" [----] "),
  156. makeset(" [-|--] "))
  157. eq_(makeset("[] [--] ") +
  158. makeset(" [] [--]"),
  159. makeset("[|] [-----]"))
  160. eq_(makeset(" [-------]"),
  161. makeset(" [-|-----|"))
  162. def test_intervalset_intersect(self):
  163. # Test intersection (&)
  164. with assert_raises(TypeError): # was AttributeError
  165. x = makeset("[--]") & 1234
  166. # Intersection with interval
  167. eq_(makeset("[---|---][]") &
  168. list(makeset(" [------] "))[0],
  169. makeset(" [-----] "))
  170. # Intersection with sets
  171. eq_(makeset("[---------]") &
  172. makeset(" [---] "),
  173. makeset(" [---] "))
  174. eq_(makeset(" [---] ") &
  175. makeset("[---------]"),
  176. makeset(" [---] "))
  177. eq_(makeset(" [-----]") &
  178. makeset(" [-----] "),
  179. makeset(" [--] "))
  180. eq_(makeset(" [--] [--]") &
  181. makeset(" [------] "),
  182. makeset(" [-] [-] "))
  183. eq_(makeset(" [---]") &
  184. makeset(" [--] "),
  185. makeset(" "))
  186. eq_(makeset(" [-|---]") &
  187. makeset(" [-----|-] "),
  188. makeset(" [----] "))
  189. eq_(makeset(" [-|-] ") &
  190. makeset(" [-|--|--] "),
  191. makeset(" [---] "))
  192. # Border cases -- will give different results if intervals are
  193. # half open or fully closed. Right now, they are half open,
  194. # although that's a little messy since the database intervals
  195. # often contain a data point at the endpoint.
  196. half_open = True
  197. if half_open:
  198. eq_(makeset(" [---]") &
  199. makeset(" [----] "),
  200. makeset(" "))
  201. eq_(makeset(" [----][--]") &
  202. makeset("[-] [--] []"),
  203. makeset(" [] [-] []"))
  204. else:
  205. eq_(makeset(" [---]") &
  206. makeset(" [----] "),
  207. makeset(" . "))
  208. eq_(makeset(" [----][--]") &
  209. makeset("[-] [--] []"),
  210. makeset(" [] [-]. []"))
  211. class TestIntervalDB:
  212. def test_dbinterval(self):
  213. # Test DBInterval class
  214. i = DBInterval(100, 200, 100, 200, 10000, 20000)
  215. eq_(i.start, 100)
  216. eq_(i.end, 200)
  217. eq_(i.db_start, 100)
  218. eq_(i.db_end, 200)
  219. eq_(i.db_startpos, 10000)
  220. eq_(i.db_endpos, 20000)
  221. eq_(repr(i), repr(eval(repr(i))))
  222. # end before start
  223. with assert_raises(IntervalError):
  224. i = DBInterval(200, 100, 100, 200, 10000, 20000)
  225. # db_start too late
  226. with assert_raises(IntervalError):
  227. i = DBInterval(100, 200, 150, 200, 10000, 20000)
  228. # db_end too soon
  229. with assert_raises(IntervalError):
  230. i = DBInterval(100, 200, 100, 150, 10000, 20000)
  231. # actual start, end can be a subset
  232. a = DBInterval(150, 200, 100, 200, 10000, 20000)
  233. b = DBInterval(100, 150, 100, 200, 10000, 20000)
  234. c = DBInterval(150, 150, 100, 200, 10000, 20000)
  235. # Make a set of DBIntervals
  236. iseta = IntervalSet([a, b])
  237. isetc = IntervalSet(c)
  238. assert(iseta.intersects(a))
  239. assert(iseta.intersects(b))
  240. # Test subset
  241. with assert_raises(IntervalError):
  242. x = a.subset(150, 250)
  243. # Subset of those IntervalSets should still contain DBIntervals
  244. for i in IntervalSet(iseta.intersection(Interval(125,250))):
  245. assert(isinstance(i, DBInterval))
  246. class TestIntervalTree:
  247. def test_interval_tree(self):
  248. import random
  249. random.seed(1234)
  250. # make a set of 100 intervals
  251. iset = IntervalSet()
  252. j = 100
  253. for i in random.sample(xrange(j),j):
  254. interval = Interval(i, i+1)
  255. iset += interval
  256. render(iset, "Random Insertion")
  257. # remove about half of them
  258. for i in random.sample(xrange(j),j):
  259. if random.randint(0,1):
  260. iset -= Interval(i, i+1)
  261. # try removing an interval that doesn't exist
  262. with assert_raises(IntervalError):
  263. iset -= Interval(1234,5678)
  264. render(iset, "Random Insertion, deletion")
  265. # make a set of 100 intervals, inserted in order
  266. iset = IntervalSet()
  267. j = 100
  268. for i in xrange(j):
  269. interval = Interval(i, i+1)
  270. iset += interval
  271. render(iset, "In-order insertion")
  272. class TestIntervalSpeed:
  273. #@unittest.skip("this is slow")
  274. def test_interval_speed(self):
  275. import yappi
  276. import time
  277. import aplotter
  278. import random
  279. import math
  280. print
  281. yappi.start()
  282. speeds = {}
  283. for j in [ 2**x for x in range(5,21) ]:
  284. start = time.time()
  285. iset = IntervalSet()
  286. for i in random.sample(xrange(j),j):
  287. interval = Interval(i, i+1)
  288. iset += interval
  289. speed = (time.time() - start) * 1000000.0
  290. printf("%d: %g μs (%g μs each, O(n log n) ratio %g)\n",
  291. j,
  292. speed,
  293. speed/j,
  294. speed / (j*math.log(j))) # should be constant
  295. speeds[j] = speed
  296. aplotter.plot(speeds.keys(), speeds.values(), plot_slope=True)
  297. yappi.stop()
  298. yappi.print_stats(sort_type=yappi.SORTTYPE_TTOT, limit=10)