#!/usr/bin/python

from nilmdb.utils.printf import *
import nilmdb.client
import nilmtools.filter
import nilmtools.math
from nilmdb.utils.time import (timestamp_to_human,
                               timestamp_to_seconds,
                               seconds_to_timestamp)
import datetime_tz
from nilmdb.utils.interval import Interval

import numpy as np
import scipy
import scipy.signal
from numpy.core.umath_tests import inner1d
import nilmrun

from collections import OrderedDict
import sys
import time
import functools
import collections


class DataError(ValueError):
    pass


def build_column_mapping(colinfo, streaminfo):
    """Given the 'columns' list from the JSON data, verify and
    pull out a dictionary mapping for the column names/numbers."""
    columns = OrderedDict()
    for c in colinfo:
        col_num = c['index'] + 1  # skip timestamp
        if (c['name'] in list(columns.keys()) or
                col_num in list(columns.values())):
            raise DataError("duplicated columns")
        if (c['index'] < 0 or c['index'] >= streaminfo.layout_count):
            raise DataError("bad column number")
        columns[c['name']] = col_num
    if not len(columns):
        raise DataError("no columns")
    return columns
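
# A hypothetical example of the 'columns' list consumed above, inferred from
# the keys this function reads (names and indices are placeholders):
#     [ { "name": "P1", "index": 0 }, { "name": "Q1", "index": 1 } ]
# This would yield OrderedDict([("P1", 1), ("Q1", 2)]), since column 0 of the
# extracted data always holds the timestamp.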


class Exemplar(object):
    def __init__(self, exinfo, min_rows=10, max_rows=100000):
        """Given a dictionary entry from the 'exemplars' input JSON,
        verify the stream, columns, etc.  Then, fetch all the data
        into self.data."""

        self.name = exinfo['name']
        self.url = exinfo['url']
        self.stream = exinfo['stream']
        self.start = exinfo['start']
        self.end = exinfo['end']
        self.dest_column = exinfo['dest_column']

        # Get stream info
        self.client = nilmdb.client.numpyclient.NumpyClient(self.url)
        self.info = nilmtools.filter.get_stream_info(self.client, self.stream)
        if not self.info:
            raise DataError(sprintf("exemplar stream '%s' does not exist " +
                                    "on server '%s'", self.stream, self.url))

        # Build up name => index mapping for the columns
        self.columns = build_column_mapping(exinfo['columns'], self.info)

        # Count points
        self.count = self.client.stream_count(self.stream, self.start,
                                              self.end)

        # Verify count
        if self.count == 0:
            raise DataError("No data in this exemplar!")
        if self.count < min_rows:
            raise DataError("Too few data points: " + str(self.count))
        if self.count > max_rows:
            raise DataError("Too many data points: " + str(self.count))

        # Extract the data
        datagen = self.client.stream_extract_numpy(self.stream,
                                                   self.start, self.end,
                                                   self.info.layout,
                                                   maxrows=self.count)
        self.data = list(datagen)[0]

        # Extract just the columns that were specified in self.columns,
        # skipping the timestamp.
        extract_columns = [value for (key, value) in list(self.columns.items())]
        self.data = self.data[:, extract_columns]

        # Fix the column indices in self.columns, since we removed/reordered
        # columns in self.data
        for n, k in enumerate(self.columns):
            self.columns[k] = n

        # Subtract the means from each column
        self.data = self.data - self.data.mean(axis=0)

        # Get scale factors for each column by computing dot product
        # of each column with itself.
        self.scale = inner1d(self.data.T, self.data.T)

        # Ensure a minimum (nonzero) scale and convert to list
        self.scale = np.maximum(self.scale, [1e-9]).tolist()

    def __str__(self):
        return sprintf("\"%s\" %s [%s] %s rows",
                       self.name, self.stream,
                       ",".join(list(self.columns.keys())),
                       self.count)


def timestamp_to_short_human(timestamp):
    dt = datetime_tz.datetime_tz.fromtimestamp(timestamp_to_seconds(timestamp))
    return dt.strftime("%H:%M:%S")


def trainola_matcher(data, interval, args, insert_func, final_chunk):
    """Perform cross-correlation match"""
    (src_columns, dest_count, exemplars) = args
    nrows = data.shape[0]

    # We want at least 10% more points than the widest exemplar.
    widest = max([x.count for x in exemplars])
    if (widest * 1.1) > nrows:
        return 0

    # This is how many points we'll consider valid in the
    # cross-correlation.
    valid = nrows + 1 - widest
    matches = collections.defaultdict(list)

    # Try matching against each of the exemplars
    for e in exemplars:
        corrs = []

        # Compute cross-correlation for each column
        for col_name in e.columns:
            a = data[:, src_columns[col_name]]
            b = e.data[:, e.columns[col_name]]
            corr = scipy.signal.fftconvolve(a, np.flipud(b), 'valid')[0:valid]

            # Scale by the norm of the exemplar
            corr = corr / e.scale[e.columns[col_name]]
            corrs.append(corr)

        # Find the peaks using the column with the largest amplitude
        biggest = e.scale.index(max(e.scale))
        peaks = nilmtools.math.peak_detect(corrs[biggest], 0.1)

        # To try to reduce false positives, discard peaks where
        # there's a higher-magnitude peak (either min or max) within
        # one exemplar width nearby.
        good_peak_locations = []
        for (i, (n, p, is_max)) in enumerate(peaks):
            if not is_max:
                continue
            ok = True
            # check up to 'e.count' rows before this one
            j = i - 1
            while ok and j >= 0 and peaks[j][0] > (n - e.count):
                if abs(peaks[j][1]) > abs(p):
                    ok = False
                j -= 1
            # check up to 'e.count' rows after this one
            j = i + 1
            while ok and j < len(peaks) and peaks[j][0] < (n + e.count):
                if abs(peaks[j][1]) > abs(p):
                    ok = False
                j += 1
            if ok:
                good_peak_locations.append(n)

        # Now look at all good peaks
        for row in good_peak_locations:
            # Correlation for each column must be close enough to 1.
            for (corr, scale) in zip(corrs, e.scale):
                # The accepted distance from 1 is based on the relative
                # amplitude of the column.  Use a linear mapping:
                #   scale 1.0 -> distance 0.1
                #   scale 0.0 -> distance 1.0
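                # For example (hypothetical numbers): with e.scale equal to
                # [4.0, 1.0] and biggest = 0, the dominant column must satisfy
                # |corr - 1| <= 0.1, while the weaker column only needs
                # |corr - 1| <= 0.775.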
                distance = 1 - 0.9 * (scale / e.scale[biggest])
                if abs(corr[row] - 1) > distance:
                    # No match
                    break
            else:
                # Successful match
                matches[row].append(e)

    # Insert matches into destination stream.
    matched_rows = sorted(matches.keys())
    out = np.zeros((len(matched_rows), dest_count + 1))
    for n, row in enumerate(matched_rows):
        # Fill timestamp
        out[n][0] = data[row, 0]
        # Mark matched exemplars
        for exemplar in matches[row]:
            out[n, exemplar.dest_column + 1] = 1.0

    # Insert it
    insert_func(out)

    # Return how many rows we processed
    valid = max(valid, 0)
    printf(" [%s] matched %d exemplars in %d rows\n",
           timestamp_to_short_human(data[0][0]), np.sum(out[:, 1:]), valid)
    return valid


def trainola(conf):
    print("Trainola", nilmtools.__version__)

    # Load main stream data
    url = conf['url']
    src_path = conf['stream']
    dest_path = conf['dest_stream']
    start = conf['start']
    end = conf['end']

    # Get info for the src and dest streams
    src_client = nilmdb.client.numpyclient.NumpyClient(url)
    src = nilmtools.filter.get_stream_info(src_client, src_path)
    if not src:
        raise DataError("source path '" + src_path + "' does not exist")
    src_columns = build_column_mapping(conf['columns'], src)

    dest_client = nilmdb.client.numpyclient.NumpyClient(url)
    dest = nilmtools.filter.get_stream_info(dest_client, dest_path)
    if not dest:
        raise DataError("destination path '" + dest_path + "' does not exist")

    printf("Source:\n")
    printf(" %s [%s]\n", src.path, ",".join(list(src_columns.keys())))
    printf("Destination:\n")
    printf(" %s (%s columns)\n", dest.path, dest.layout_count)

    # Pull in the exemplar data
    exemplars = []
    for n, exinfo in enumerate(conf['exemplars']):
        printf("Loading exemplar %d:\n", n)
        e = Exemplar(exinfo)
        col = e.dest_column
        if col < 0 or col >= dest.layout_count:
            raise DataError(sprintf("bad destination column number %d\n" +
                                    "dest stream only has 0 through %d",
                                    col, dest.layout_count - 1))
        printf(" %s, output column %d\n", str(e), col)
        exemplars.append(e)
    if len(exemplars) == 0:
        raise DataError("missing exemplars")

    # Verify that the exemplar columns are all represented in the main data
    for n, ex in enumerate(exemplars):
        for col in ex.columns:
            if col not in src_columns:
                raise DataError(sprintf("Exemplar %d column %s is not "
                                        "available in source data", n, col))

    # Figure out which intervals we should process
    intervals = (Interval(s, e) for (s, e) in
                 src_client.stream_intervals(src_path,
                                             diffpath=dest_path,
                                             start=start, end=end))
    intervals = nilmdb.utils.interval.optimize(intervals)

    # Do the processing
    rows = 100000
    extractor = functools.partial(src_client.stream_extract_numpy,
                                  src.path, layout=src.layout, maxrows=rows)
    inserter = functools.partial(dest_client.stream_insert_numpy_context,
                                 dest.path)
    start = time.time()
    processed_time = 0
    printf("Processing intervals:\n")
    for interval in intervals:
        printf("%s\n", interval.human_string())
        nilmtools.filter.process_numpy_interval(
            interval, extractor, inserter, rows * 3,
            trainola_matcher, (src_columns, dest.layout_count, exemplars))
        processed_time += (timestamp_to_seconds(interval.end) -
                           timestamp_to_seconds(interval.start))
    elapsed = max(time.time() - start, 1e-3)

    printf("Done. Processed %.2f seconds per second.\n",
           processed_time / elapsed)


def main(argv=None):
    import simplejson as json
    import sys

    if argv is None:
        argv = sys.argv[1:]
    if len(argv) != 1 or argv[0] == '-h' or argv[0] == '--help':
        printf("usage: %s [-h] [-v] <json-config-dictionary>\n\n", sys.argv[0])
        printf(" Where <json-config-dictionary> is a JSON-encoded " +
               "dictionary string\n")
        printf(" with exemplar and stream data.\n\n")
        printf(" See extras/trainola-test-param*.js in the nilmtools " +
               "repository\n")
        printf(" for examples.\n")
        if len(argv) != 1:
            raise SystemExit(1)
        raise SystemExit(0)
    if argv[0] == '-v' or argv[0] == '--version':
        printf("%s\n", nilmtools.__version__)
        raise SystemExit(0)

    try:
        # Passed in a JSON string (e.g. on the command line)
        conf = json.loads(argv[0])
    except TypeError as e:
        # Passed in the config dictionary (e.g. from NilmRun)
        conf = argv[0]

    return trainola(conf)


if __name__ == "__main__":
    main()
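
# A minimal sketch of the JSON config dictionary this tool expects, inferred
# from the keys read in trainola() and Exemplar(); the URL, stream paths,
# column names, and timestamps below are hypothetical placeholders (see
# extras/trainola-test-param*.js in the nilmtools repository for real
# examples):
#
#   {
#     "url": "http://localhost/nilmdb/",
#     "stream": "/example/source",
#     "dest_stream": "/example/matches",
#     "start": 0,
#     "end": 1000000000000000,
#     "columns": [ { "name": "P1", "index": 0 } ],
#     "exemplars": [
#       {
#         "name": "pump on",
#         "url": "http://localhost/nilmdb/",
#         "stream": "/example/source",
#         "start": 100000000000000,
#         "end": 100000200000000,
#         "dest_column": 0,
#         "columns": [ { "name": "P1", "index": 0 } ]
#       }
#     ]
#   }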