You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

319 lines
11 KiB

  1. import nilmdb.server
  2. from nose.tools import *
  3. from nose.tools import assert_raises
  4. import distutils.version
  5. import json
  6. import itertools
  7. import os
  8. import sys
  9. import threading
  10. import urllib.request, urllib.error, urllib.parse
  11. from urllib.request import urlopen
  12. from urllib.error import HTTPError
  13. import io
  14. import time
  15. import requests
  16. import socket
  17. import sqlite3
  18. import cherrypy
  19. from nilmdb.utils import serializer_proxy
  20. from nilmdb.server.interval import Interval
  21. testdb = "tests/testdb"
  22. #@atexit.register
  23. #def cleanup():
  24. # os.unlink(testdb)
  25. from testutil.helpers import *
  26. def setup_module():
  27. # Make sure port is free
  28. sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  29. sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  30. try:
  31. sock.bind(("127.0.0.1", 32180))
  32. except OSError:
  33. raise AssertionError("port 32180 must be free for tests")
  34. sock.close()
  35. class Test00Nilmdb(object): # named 00 so it runs first
  36. def test_NilmDB(self):
  37. recursive_unlink(testdb)
  38. db = nilmdb.server.NilmDB(testdb)
  39. db.close()
  40. db = nilmdb.server.NilmDB(testdb)
  41. db.close()
  42. db.close()
  43. def test_error_cases(self):
  44. # Test some misc error cases to get better code coverage
  45. with assert_raises(OSError) as e:
  46. nilmdb.server.NilmDB("/dev/null/bogus")
  47. in_("can't create tree", str(e.exception))
  48. # Version upgrades
  49. con = sqlite3.connect(os.path.join(testdb, "data.sql"))
  50. con.execute("PRAGMA user_version = 2");
  51. con.close()
  52. with assert_raises(Exception) as e:
  53. db = nilmdb.server.NilmDB(testdb)
  54. in_("can't use database version 2", str(e.exception))
  55. con = sqlite3.connect(os.path.join(testdb, "data.sql"))
  56. con.execute("PRAGMA user_version = -1234");
  57. con.close()
  58. with assert_raises(Exception) as e:
  59. db = nilmdb.server.NilmDB(testdb)
  60. in_("unknown database version -1234", str(e.exception))
  61. recursive_unlink(testdb)
  62. nilmdb.server.NilmDB.verbose = 1
  63. (old, sys.stdout) = (sys.stdout, io.StringIO())
  64. db = nilmdb.server.NilmDB(testdb)
  65. (output, sys.stdout) = (sys.stdout.getvalue(), old)
  66. nilmdb.server.NilmDB.verbose = 0
  67. db.close()
  68. in_("Database schema updated to 1", output)
  69. # Corrupted database (bad ranges)
  70. recursive_unlink(testdb)
  71. db = nilmdb.server.NilmDB(testdb)
  72. db.con.executescript("""
  73. INSERT INTO streams VALUES (1, "/test", "int32_1");
  74. INSERT INTO ranges VALUES (1, 100, 200, 100, 200);
  75. INSERT INTO ranges VALUES (1, 150, 250, 150, 250);
  76. """)
  77. db.close()
  78. db = nilmdb.server.NilmDB(testdb)
  79. with assert_raises(nilmdb.server.NilmDBError):
  80. db.stream_intervals("/test")
  81. db.close()
  82. recursive_unlink(testdb)
  83. def test_stream(self):
  84. db = nilmdb.server.NilmDB(testdb)
  85. eq_(db.stream_list(), [])
  86. # Bad path
  87. with assert_raises(ValueError):
  88. db.stream_create("foo/bar/baz", "float32_8")
  89. with assert_raises(ValueError):
  90. db.stream_create("/foo", "float32_8")
  91. # Bad layout type
  92. with assert_raises(ValueError):
  93. db.stream_create("/newton/prep", "NoSuchLayout")
  94. db.stream_create("/newton/prep", "float32_8")
  95. db.stream_create("/newton/raw", "uint16_6")
  96. db.stream_create("/newton/zzz/rawnotch", "uint16_9")
  97. # Verify we got 3 streams
  98. eq_(db.stream_list(), [ ["/newton/prep", "float32_8"],
  99. ["/newton/raw", "uint16_6"],
  100. ["/newton/zzz/rawnotch", "uint16_9"]
  101. ])
  102. # Match just one type or one path
  103. eq_(db.stream_list(layout="uint16_6"), [ ["/newton/raw", "uint16_6"] ])
  104. eq_(db.stream_list(path="/newton/raw"), [ ["/newton/raw", "uint16_6"] ])
  105. # Set / get metadata
  106. eq_(db.stream_get_metadata("/newton/prep"), {})
  107. eq_(db.stream_get_metadata("/newton/raw"), {})
  108. meta1 = { "description": "The Data",
  109. "v_scale": "1.234" }
  110. meta2 = { "description": "The Data" }
  111. meta3 = { "v_scale": "1.234" }
  112. db.stream_set_metadata("/newton/prep", meta1)
  113. db.stream_update_metadata("/newton/prep", {})
  114. db.stream_update_metadata("/newton/raw", meta2)
  115. db.stream_update_metadata("/newton/raw", meta3)
  116. eq_(db.stream_get_metadata("/newton/prep"), meta1)
  117. eq_(db.stream_get_metadata("/newton/raw"), meta1)
  118. # fill in some misc. test coverage
  119. with assert_raises(nilmdb.server.NilmDBError):
  120. db.stream_remove("/newton/prep", 0, 0)
  121. with assert_raises(nilmdb.server.NilmDBError):
  122. db.stream_remove("/newton/prep", 1, 0)
  123. db.stream_remove("/newton/prep", 0, 1)
  124. with assert_raises(nilmdb.server.NilmDBError):
  125. db.stream_extract("/newton/prep", count = True, binary = True)
  126. db.close()
  127. class TestBlockingServer(object):
  128. def setUp(self):
  129. self.db = serializer_proxy(nilmdb.server.NilmDB)(testdb)
  130. def tearDown(self):
  131. self.db.close()
  132. def test_blocking_server(self):
  133. # Server should fail if the database doesn't have a "_thread_safe"
  134. # property.
  135. with assert_raises(KeyError):
  136. nilmdb.server.Server(object())
  137. # Start web app on a custom port
  138. self.server = nilmdb.server.Server(self.db, host = "127.0.0.1",
  139. port = 32180, stoppable = True)
  140. def start_server():
  141. # Run it
  142. event = threading.Event()
  143. def run_server():
  144. self.server.start(blocking = True, event = event)
  145. thread = threading.Thread(target = run_server)
  146. thread.start()
  147. if not event.wait(timeout = 10):
  148. raise AssertionError("server didn't start in 10 seconds")
  149. return thread
  150. # Start server and request for it to exit
  151. thread = start_server()
  152. req = urlopen("http://127.0.0.1:32180/exit/", timeout = 1)
  153. thread.join()
  154. # Mock some signals that should kill the server
  155. def try_signal(sig):
  156. old = cherrypy.engine.wait
  157. def raise_sig(*args, **kwargs):
  158. raise sig()
  159. cherrypy.engine.wait = raise_sig
  160. thread = start_server()
  161. thread.join()
  162. cherrypy.engine.wait = old
  163. try_signal(SystemExit)
  164. try_signal(KeyboardInterrupt)
  165. def geturl(path):
  166. resp = urlopen("http://127.0.0.1:32180" + path, timeout = 10)
  167. body = resp.read()
  168. return body.decode(resp.headers.get_content_charset() or 'utf-8')
  169. def getjson(path):
  170. return json.loads(geturl(path))
  171. class TestServer(object):
  172. def setUp(self):
  173. # Start web app on a custom port
  174. self.db = serializer_proxy(nilmdb.server.NilmDB)(testdb)
  175. self.server = nilmdb.server.Server(self.db, host = "127.0.0.1",
  176. port = 32180, stoppable = False)
  177. self.server.start(blocking = False)
  178. def tearDown(self):
  179. # Close web app
  180. self.server.stop()
  181. self.db.close()
  182. def test_server(self):
  183. # Make sure we can't force an exit, and test other 404 errors
  184. for url in [ "/exit", "/favicon.ico" ]:
  185. with assert_raises(HTTPError) as e:
  186. geturl(url)
  187. eq_(e.exception.code, 404)
  188. # Root page
  189. in_("This is NilmDB", geturl("/"))
  190. # Check version
  191. eq_(distutils.version.LooseVersion(getjson("/version")),
  192. distutils.version.LooseVersion(nilmdb.__version__))
  193. def test_stream_list(self):
  194. # Known streams that got populated by an earlier test (test_nilmdb)
  195. streams = getjson("/stream/list")
  196. eq_(streams, [
  197. ['/newton/prep', 'float32_8'],
  198. ['/newton/raw', 'uint16_6'],
  199. ['/newton/zzz/rawnotch', 'uint16_9'],
  200. ])
  201. streams = getjson("/stream/list?layout=uint16_6")
  202. eq_(streams, [['/newton/raw', 'uint16_6']])
  203. streams = getjson("/stream/list?layout=NoSuchLayout")
  204. eq_(streams, [])
  205. def test_stream_metadata(self):
  206. with assert_raises(HTTPError) as e:
  207. getjson("/stream/get_metadata?path=foo")
  208. eq_(e.exception.code, 404)
  209. data = getjson("/stream/get_metadata?path=/newton/prep")
  210. eq_(data, {'description': 'The Data', 'v_scale': '1.234'})
  211. data = getjson("/stream/get_metadata?path=/newton/prep"
  212. "&key=v_scale")
  213. eq_(data, {'v_scale': '1.234'})
  214. data = getjson("/stream/get_metadata?path=/newton/prep"
  215. "&key=v_scale&key=description")
  216. eq_(data, {'description': 'The Data', 'v_scale': '1.234'})
  217. data = getjson("/stream/get_metadata?path=/newton/prep"
  218. "&key=v_scale&key=foo")
  219. eq_(data, {'foo': None, 'v_scale': '1.234'})
  220. data = getjson("/stream/get_metadata?path=/newton/prep"
  221. "&key=foo")
  222. eq_(data, {'foo': None})
  223. def test_cors_headers(self):
  224. # Test that CORS headers are being set correctly
  225. # Normal GET should send simple response
  226. url = "http://127.0.0.1:32180/stream/list"
  227. r = requests.get(url, headers = { "Origin": "http://google.com/" })
  228. eq_(r.status_code, 200)
  229. if "access-control-allow-origin" not in r.headers:
  230. raise AssertionError("No Access-Control-Allow-Origin (CORS) "
  231. "header in response:\n", r.headers)
  232. eq_(r.headers["access-control-allow-origin"], "http://google.com/")
  233. # OPTIONS without CORS preflight headers should result in 405
  234. r = requests.options(url, headers = {
  235. "Origin": "http://google.com/",
  236. })
  237. eq_(r.status_code, 405)
  238. # OPTIONS with preflight headers should give preflight response
  239. r = requests.options(url, headers = {
  240. "Origin": "http://google.com/",
  241. "Access-Control-Request-Method": "POST",
  242. "Access-Control-Request-Headers": "X-Custom",
  243. })
  244. eq_(r.status_code, 200)
  245. if "access-control-allow-origin" not in r.headers:
  246. raise AssertionError("No Access-Control-Allow-Origin (CORS) "
  247. "header in response:\n", r.headers)
  248. eq_(r.headers["access-control-allow-methods"], "GET, HEAD")
  249. eq_(r.headers["access-control-allow-headers"], "X-Custom")
  250. def test_post_bodies(self):
  251. # Test JSON post bodies
  252. r = requests.post("http://127.0.0.1:32180/stream/set_metadata",
  253. headers = { "Content-Type": "application/json" },
  254. data = '{"hello": 1}')
  255. eq_(r.status_code, 404) # wrong parameters
  256. r = requests.post("http://127.0.0.1:32180/stream/set_metadata",
  257. headers = { "Content-Type": "application/json" },
  258. data = '["hello"]')
  259. eq_(r.status_code, 415) # not a dict
  260. r = requests.post("http://127.0.0.1:32180/stream/set_metadata",
  261. headers = { "Content-Type": "application/json" },
  262. data = '[hello]')
  263. eq_(r.status_code, 400) # badly formatted JSON