# -*- coding: utf-8 -*-

"""NilmDB

Object that represents a NILM database file.

Manages both the SQL database and the PyTables storage backend.
"""

# Need absolute_import so that "import nilmdb" won't pull in nilmdb.py,
# but will pull in the nilmdb module instead.
from __future__ import absolute_import
import nilmdb
from nilmdb.printf import *

import sqlite3
import tables
import time
import sys
import os
import errno
import bisect

import pyximport
pyximport.install()
from nilmdb.interval import Interval, DBInterval, IntervalSet, IntervalError

# Note about performance and transactions:
#
# Committing a transaction in the default sync mode (PRAGMA synchronous=FULL)
# takes about 125 msec.  sqlite3 will commit transactions at 3 times:
#   1: explicit con.commit()
#   2: between a series of DML commands and non-DML commands, e.g.
#      after a series of INSERT, SELECT, but before a CREATE TABLE or PRAGMA.
#   3: at the end of an explicit transaction, e.g. "with self.con as con:"
#
# To speed up testing, or if this transaction speed becomes an issue,
# the sync=False option to NilmDB.__init__ will set PRAGMA synchronous=OFF.

# Don't touch old entries -- just add new ones.
_sql_schema_updates = {
    0: """
        -- All streams
        CREATE TABLE streams(
            id INTEGER PRIMARY KEY,     -- stream ID
            path TEXT UNIQUE NOT NULL,  -- path, e.g. '/newton/prep'
            layout TEXT NOT NULL        -- layout name, e.g. float32_8
        );

        -- Individual timestamped ranges in those streams.
        -- For a given start_time and end_time, this tells us that the
        -- data is stored between start_pos and end_pos.
        -- Times are stored as μs since the Unix epoch.
        -- Positions are opaque: PyTables rows, file offsets, etc.
        --
        -- Note: end_pos points to the row _after_ end_time, so end_pos-1
        -- is the last valid row.
        CREATE TABLE ranges(
            stream_id INTEGER NOT NULL,
            start_time INTEGER NOT NULL,
            end_time INTEGER NOT NULL,
            start_pos INTEGER NOT NULL,
            end_pos INTEGER NOT NULL
        );
        CREATE INDEX _ranges_index ON ranges (stream_id, start_time, end_time);
        """,

    1: """
        -- Generic dictionary-type metadata that can be associated with a stream
        CREATE TABLE metadata(
            stream_id INTEGER NOT NULL,
            key TEXT NOT NULL,
            value TEXT
        );
        """,
}

class NilmDBError(Exception):
    """Base exception for NilmDB errors"""
    def __init__(self, message = "Unspecified error"):
        Exception.__init__(self, self.__class__.__name__ + ": " + message)

class StreamError(NilmDBError):
    pass

class OverlapError(NilmDBError):
    pass

# Helper that lets us pass a PyTables table into bisect; indexing it
# returns the timestamp (first column) of the corresponding row.
class BisectableTable(object):
    def __init__(self, table):
        self.table = table
    def __getitem__(self, index):
        return self.table[index][0]
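
# For example (hypothetical values), to find the first row of `table`
# whose timestamp is >= ts, searching only rows [lo, hi):
#     row = bisect.bisect_left(BisectableTable(table), ts, lo, hi)
# This is how _find_start and _find_end below locate row ranges without
# reading each row individually.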

class NilmDB(object):
    verbose = 0

    def __init__(self, basepath, sync=True, max_results=None):
        # Set up path
        self.basepath = os.path.abspath(basepath.rstrip('/'))

        # Create the database path if it doesn't exist
        try:
            os.makedirs(self.basepath)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise IOError("can't create tree " + self.basepath)

        # Our HDF5 file goes inside it
        h5filename = os.path.abspath(self.basepath + "/data.h5")
        self.h5file = tables.openFile(h5filename, "a", "NILM Database")

        # SQLite database too
        sqlfilename = os.path.abspath(self.basepath + "/data.sql")
        # We use check_same_thread = False, assuming that the rest
        # of the code (e.g. Server) will be smart and not access this
        # database from multiple threads simultaneously.  That requirement
        # may be relaxed later.
        self.con = sqlite3.connect(sqlfilename, check_same_thread = False)
        self._sql_schema_update()

        # See the big comment at top about the performance implications of this
        if sync:
            self.con.execute("PRAGMA synchronous=FULL")
        else:
            self.con.execute("PRAGMA synchronous=OFF")

        # Approximate largest number of elements that we want to send
        # in a single reply (for stream_intervals, stream_extract)
        if max_results:
            self.max_results = max_results
        else:
            self.max_results = 16384

        self.opened = True

        # Cached intervals
        self._cached_iset = {}

    def __del__(self):
        if "opened" in self.__dict__: # pragma: no cover
            fprintf(sys.stderr,
                    "error: NilmDB.close() wasn't called, path %s",
                    self.basepath)

    def get_basepath(self):
        return self.basepath

    def close(self):
        if self.con:
            self.con.commit()
            self.con.close()
        self.h5file.close()
        del self.opened

    def _sql_schema_update(self):
        cur = self.con.cursor()
        version = cur.execute("PRAGMA user_version").fetchone()[0]
        oldversion = version

        while version in _sql_schema_updates:
            cur.executescript(_sql_schema_updates[version])
            version = version + 1
            if self.verbose: # pragma: no cover
                printf("Schema updated to %d\n", version)

        if version != oldversion:
            with self.con:
                cur.execute("PRAGMA user_version = {v:d}".format(v=version))

    def _get_intervals(self, stream_id):
        """
        Return a mutable IntervalSet corresponding to the given stream ID.
        """
        # Load from database if not cached
        if stream_id not in self._cached_iset:
            iset = IntervalSet()
            result = self.con.execute("SELECT start_time, end_time, "
                                      "start_pos, end_pos "
                                      "FROM ranges "
                                      "WHERE stream_id=?", (stream_id,))
            try:
                for (start_time, end_time, start_pos, end_pos) in result:
                    iset += DBInterval(start_time, end_time,
                                       start_time, end_time,
                                       start_pos, end_pos)
            except IntervalError as e: # pragma: no cover
                raise NilmDBError("unexpected overlap in ranges table!")
            self._cached_iset[stream_id] = iset
        # Return cached value
        return self._cached_iset[stream_id]

    # TODO: Split _add_interval into two pieces, one to add and one to
    # flush to disk?  Needs thought: the basic problem is that we can't
    # mess with intervals once they're in the IntervalSet without
    # mucking with bxinterval internals.  Maybe add a separate
    # optimization step that joins intervals separated by small gaps?
    def _add_interval(self, stream_id, interval, start_pos, end_pos):
        """
        Add interval to the internal interval cache, and to the database.
        Note: arguments must be ints (not numpy.int64, etc)
        """
        # Ensure this stream's intervals are cached, and add the new
        # interval to that cache.
        iset = self._get_intervals(stream_id)
        try:
            iset += DBInterval(interval.start, interval.end,
                               interval.start, interval.end,
                               start_pos, end_pos)
        except IntervalError as e: # pragma: no cover
            raise NilmDBError("new interval overlaps existing data")

        # Insert into the database
        self.con.execute("INSERT INTO ranges "
                         "(stream_id,start_time,end_time,start_pos,end_pos) "
                         "VALUES (?,?,?,?,?)",
                         (stream_id, interval.start, interval.end,
                          int(start_pos), int(end_pos)))
        self.con.commit()

    def stream_list(self, path = None, layout = None):
        """Return a list of [path, layout] lists for all streams
        in the database.

        If path is specified, include only streams with a path that
        matches the given string.

        If layout is specified, include only streams with a layout
        that matches the given string.
        """
        where = "WHERE 1=1"
        params = ()
        if layout:
            where += " AND layout=?"
            params += (layout,)
        if path:
            where += " AND path=?"
            params += (path,)
        result = self.con.execute("SELECT path, layout "
                                  "FROM streams " + where, params).fetchall()
        return sorted(list(x) for x in result)

    def stream_intervals(self, path, start = None, end = None):
        """
        Return an (intervals, restart) tuple.

        intervals is a list of [start,end] timestamps of all intervals
        that exist for path, between start and end.

        restart, if nonzero, means that there were too many results to
        return in a single request.  The data is complete from the
        starting timestamp to the point at which it was truncated,
        and a new request with a start time of 'restart' will fetch
        the next block of data.
        """
        stream_id = self._stream_id(path)
        intervals = self._get_intervals(stream_id)
        requested = Interval(start or 0, end or 1e12)
        result = []
        for n, i in enumerate(intervals.intersection(requested)):
            if n >= self.max_results:
                restart = i.start
                break
            result.append([i.start, i.end])
        else:
            restart = 0
        return (result, restart)
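
    # A caller can page through all intervals by feeding the returned
    # restart value back in as the next start; a minimal sketch
    # (hypothetical `db` and `path`):
    #     start = None
    #     while True:
    #         (ivals, restart) = db.stream_intervals(path, start)
    #         ...  # process ivals
    #         if not restart:
    #             break
    #         start = restart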

    def stream_create(self, path, layout_name):
        """Create a new table in the database.

        path: path to the data (e.g. '/newton/prep').
        Paths must contain at least two elements, e.g.:
            /newton/prep
            /newton/raw
            /newton/upstairs/prep
            /newton/upstairs/raw

        layout_name: string for nilmdb.layout.get_named(), e.g. 'float32_8'
        """
        if path[0] != '/':
            raise ValueError("paths must start with /")
        [ group, node ] = path.rsplit("/", 1)
        if group == '':
            raise ValueError("invalid path")

        # Get description
        try:
            desc = nilmdb.layout.get_named(layout_name).description()
        except KeyError:
            raise ValueError("no such layout")

        # Estimated table size (for PyTables optimization purposes): assume
        # 3 months worth of data at 8 kHz.  It's OK if this is wrong.
        exp_rows = 8000 * 60*60*24*30*3

        # Create the table
        table = self.h5file.createTable(group, node,
                                        description = desc,
                                        expectedrows = exp_rows,
                                        createparents = True)

        # Insert into the SQL database once PyTables is happy
        with self.con as con:
            con.execute("INSERT INTO streams (path, layout) VALUES (?,?)",
                        (path, layout_name))

    def _stream_id(self, path):
        """Return the unique stream ID for the given path"""
        result = self.con.execute("SELECT id FROM streams WHERE path=?",
                                  (path,)).fetchone()
        if result is None:
            raise StreamError("No stream at path " + path)
        return result[0]

    def stream_set_metadata(self, path, data):
        """Set stream metadata from a dictionary, e.g.
            { 'description': 'Downstairs lighting',
              'v_scaling': 123.45 }
        This replaces all existing metadata.
        """
        stream_id = self._stream_id(path)
        with self.con as con:
            con.execute("DELETE FROM metadata "
                        "WHERE stream_id=?", (stream_id,))
            for key in data:
                if data[key] != '':
                    con.execute("INSERT INTO metadata VALUES (?, ?, ?)",
                                (stream_id, key, data[key]))

    def stream_get_metadata(self, path):
        """Return stream metadata as a dictionary."""
        stream_id = self._stream_id(path)
        result = self.con.execute("SELECT metadata.key, metadata.value "
                                  "FROM metadata "
                                  "WHERE metadata.stream_id=?", (stream_id,))
        data = {}
        for (key, value) in result:
            data[key] = value
        return data

    def stream_update_metadata(self, path, newdata):
        """Update stream metadata from a dictionary, merging new keys
        with existing metadata."""
        data = self.stream_get_metadata(path)
        data.update(newdata)
        self.stream_set_metadata(path, data)
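
    # For example (hypothetical stream), updating one key leaves the
    # others intact:
    #     db.stream_set_metadata('/newton/prep', { 'v_scaling': '123.45' })
    #     db.stream_update_metadata('/newton/prep', { 'description': 'x' })
    #     db.stream_get_metadata('/newton/prep')
    #     # -> { 'v_scaling': '123.45', 'description': 'x' }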

    def stream_insert(self, path, parser, old_timestamp = None):
        """Insert new data into the database.

        path: Path at which to add the data
        parser: nilmdb.layout.Parser instance full of data to insert
        old_timestamp: max_timestamp of the previous insert, if this is
        one of a series of calls covering a contiguous block of time
        """
        if (not parser.min_timestamp or not parser.max_timestamp or
            not len(parser.data)):
            raise StreamError("no data provided")

        # If we were provided with an old timestamp, the expectation
        # is that the client has a contiguous block of time it is sending,
        # but it's doing it over multiple calls to stream_insert.
        # old_timestamp is the max_timestamp of the previous insert.
        # To make things continuous, use that as our starting timestamp
        # instead of what the parser found.
        if old_timestamp:
            min_timestamp = old_timestamp
        else:
            min_timestamp = parser.min_timestamp

        # First check for basic overlap using the timestamp info given.
        stream_id = self._stream_id(path)
        iset = self._get_intervals(stream_id)
        interval = Interval(min_timestamp, parser.max_timestamp)
        if iset.intersects(interval):
            raise OverlapError("new data overlaps existing data: "
                               + str(iset & interval))

        # Insert the data into PyTables
        table = self.h5file.getNode(path)
        row_start = table.nrows
        table.append(parser.data)
        row_end = table.nrows
        table.flush()

        # Insert the record into the SQL database.
        # Casts are to convert from numpy.int64.
        self._add_interval(stream_id, interval, int(row_start), int(row_end))

        # And that's all
        return "ok"

    def _find_start(self, table, interval):
        """
        Given a DBInterval, find the row in the database that
        corresponds to the start time.  Return the first database
        position with a timestamp (first element) greater than or
        equal to 'start'.
        """
        # Optimization for the common case where an interval wasn't truncated
        if interval.start == interval.db_start:
            return interval.db_startpos
        return bisect.bisect_left(BisectableTable(table),
                                  interval.start,
                                  interval.db_startpos,
                                  interval.db_endpos)

    def _find_end(self, table, interval):
        """
        Given a DBInterval, find the row in the database that follows
        the end time.  Return the first database position after the
        row with timestamp (first element) greater than or equal
        to 'end'.
        """
        # Optimization for the common case where an interval wasn't truncated
        if interval.end == interval.db_end:
            return interval.db_endpos
        # Note that we still use bisect_left here, because we don't
        # want to include the given timestamp in the results.  This is
        # so queries like 1:00 -> 2:00 and 2:00 -> 3:00 return
        # non-overlapping data.
        return bisect.bisect_left(BisectableTable(table),
                                  interval.end,
                                  interval.db_startpos,
                                  interval.db_endpos)

    def stream_extract(self, path, start = None, end = None, count = False):
        """
        Returns a (data, restart) tuple, or just a row count if
        count is true.

        data is a list of raw data from the database, suitable for
        passing to e.g. nilmdb.layout.Formatter to translate into
        textual form.

        restart, if nonzero, means that there were too many results to
        return in a single request.  The data is complete from the
        starting timestamp to the point at which it was truncated,
        and a new request with a start time of 'restart' will fetch
        the next block of data.

        count, if true, means to not return raw data, but just the count
        of rows that would have been returned.  This is much faster
        than actually fetching the data.  It is not limited by
        max_results.
        """
        table = self.h5file.getNode(path)
        stream_id = self._stream_id(path)
        intervals = self._get_intervals(stream_id)
        requested = Interval(start or 0, end or 1e12)
        result = []
        matched = 0
        remaining = self.max_results
        restart = 0

        for interval in intervals.intersection(requested):
            # Reading single rows from the table is too slow, so
            # we use two bisections to find both the starting and
            # ending row for this particular interval, then
            # read the entire range as one slice.
            row_start = self._find_start(table, interval)
            row_end = self._find_end(table, interval)

            if count:
                matched += row_end - row_start
                continue

            # Shorten it if we'll hit the maximum number of results
            row_max = row_start + remaining
            if row_max < row_end:
                row_end = row_max
                restart = table[row_max][0]

            # Gather these results up
            result.extend(table[row_start:row_end])

            # Count them
            remaining -= row_end - row_start
            if restart:
                break

        if count:
            return matched
        return (result, restart)
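
    # A caller can count rows in a range cheaply before extracting them;
    # a sketch (hypothetical `db`):
    #     n = db.stream_extract('/newton/prep', start, end, count = True)
    # With count = True the return value is just the row count rather
    # than a (data, restart) tuple, and it is not limited by max_results.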