#!/usr/bin/env python3
from nilmdb.utils.printf import printf, sprintf
import nilmdb.client
import nilmtools.filter
import nilmtools.math
from nilmdb.utils.time import timestamp_to_seconds
import datetime_tz
from nilmdb.utils.interval import Interval

import numpy as np
import scipy
import scipy.signal
from numpy.core.umath_tests import inner1d
from collections import OrderedDict
import time
import functools
import collections


class DataError(ValueError):
    pass


def build_column_mapping(colinfo, streaminfo):
    """Given the 'columns' list from the JSON data, verify and
    pull out a dictionary mapping for the column names/numbers."""
    columns = OrderedDict()
    for c in colinfo:
        col_num = c['index'] + 1   # skip timestamp
        if (c['name'] in list(columns.keys()) or
                col_num in list(columns.values())):
            raise DataError("duplicated columns")
        if (c['index'] < 0 or c['index'] >= streaminfo.layout_count):
            raise DataError("bad column number")
        columns[c['name']] = col_num
    if not len(columns):
        raise DataError("no columns")
    return columns
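

# Illustrative sketch of the mapping above (hypothetical names and values):
# given colinfo = [{"name": "P1", "index": 0}, {"name": "Q1", "index": 1}]
# and a stream whose layout_count is at least 2, build_column_mapping()
# returns OrderedDict([("P1", 1), ("Q1", 2)]).  Each value is the column
# offset within the extracted data, where offset 0 holds the timestamp.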


class Exemplar(object):
    def __init__(self, exinfo, min_rows=10, max_rows=100000):
        """Given a dictionary entry from the 'exemplars' input JSON,
        verify the stream, columns, etc.  Then, fetch all the data
        into self.data."""
        self.name = exinfo['name']
        self.url = exinfo['url']
        self.stream = exinfo['stream']
        self.start = exinfo['start']
        self.end = exinfo['end']
        self.dest_column = exinfo['dest_column']

        # Get stream info
        self.client = nilmdb.client.numpyclient.NumpyClient(self.url)
        self.info = nilmtools.filter.get_stream_info(self.client, self.stream)
        if not self.info:
            raise DataError(sprintf("exemplar stream '%s' does not exist " +
                                    "on server '%s'", self.stream, self.url))

        # Build up name => index mapping for the columns
        self.columns = build_column_mapping(exinfo['columns'], self.info)

        # Count points
        self.count = self.client.stream_count(self.stream,
                                              self.start, self.end)

        # Verify count
        if self.count == 0:
            raise DataError("No data in this exemplar!")
        if self.count < min_rows:
            raise DataError("Too few data points: " + str(self.count))
        if self.count > max_rows:
            raise DataError("Too many data points: " + str(self.count))

        # Extract the data
        datagen = self.client.stream_extract_numpy(self.stream,
                                                   self.start, self.end,
                                                   self.info.layout,
                                                   maxrows=self.count)
        self.data = list(datagen)[0]

        # Extract just the columns that were specified in self.columns,
        # skipping the timestamp.
        extract_cols = [value for (key, value) in list(self.columns.items())]
        self.data = self.data[:, extract_cols]

        # Fix the column indices in self.columns, since we removed/reordered
        # columns in self.data
        for n, k in enumerate(self.columns):
            self.columns[k] = n

        # Subtract the means from each column
        self.data = self.data - self.data.mean(axis=0)

        # Get scale factors for each column by computing the dot product
        # of each column with itself.
        self.scale = inner1d(self.data.T, self.data.T)

        # Ensure a minimum (nonzero) scale and convert to list
        self.scale = np.maximum(self.scale, [1e-9]).tolist()

    def __str__(self):
        return sprintf("\"%s\" %s [%s] %s rows",
                       self.name, self.stream,
                       ",".join(list(self.columns.keys())),
                       self.count)
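

# For illustration only (hypothetical server URL, stream path, and
# timestamps): an Exemplar is normally constructed from one entry of the
# "exemplars" list in the JSON config passed to trainola(), e.g.
#   ex = Exemplar({"name": "pump turn-on",
#                  "url": "http://localhost/nilmdb/",
#                  "stream": "/test/prep",
#                  "start": 0, "end": 1000000,
#                  "dest_column": 0,
#                  "columns": [{"name": "P1", "index": 0}]})
# Construction fetches the data, mean-subtracts each column, and stores the
# per-column squared norms in ex.scale for use by trainola_matcher().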


def timestamp_to_short_human(timestamp):
    dt = datetime_tz.datetime_tz.fromtimestamp(timestamp_to_seconds(timestamp))
    return dt.strftime("%H:%M:%S")


def trainola_matcher(data, interval, args, insert_func, final_chunk):
    """Perform cross-correlation match"""
    (src_columns, dest_count, exemplars) = args
    nrows = data.shape[0]

    # We want at least 10% more points than the widest exemplar.
    widest = max([x.count for x in exemplars])
    if (widest * 1.1) > nrows:
        return 0

    # This is how many points we'll consider valid in the
    # cross-correlation.
    valid = nrows + 1 - widest
    matches = collections.defaultdict(list)

    # Try matching against each of the exemplars
    for e in exemplars:
        corrs = []

        # Compute cross-correlation for each column
        for col_name in e.columns:
            a = data[:, src_columns[col_name]]
            b = e.data[:, e.columns[col_name]]
            corr = scipy.signal.fftconvolve(a, np.flipud(b), 'valid')[0:valid]

            # Scale by the norm of the exemplar
            corr = corr / e.scale[e.columns[col_name]]
            corrs.append(corr)

        # Find the peaks using the column with the largest amplitude
        biggest = e.scale.index(max(e.scale))
        peaks = nilmtools.math.peak_detect(corrs[biggest], 0.1)

        # To try to reduce false positives, discard peaks where
        # there's a higher-magnitude peak (either min or max) within
        # one exemplar width nearby.
        good_peak_locations = []
        for (i, (n, p, is_max)) in enumerate(peaks):
            if not is_max:
                continue
            ok = True
            # check up to 'e.count' rows before this one
            j = i - 1
            while ok and j >= 0 and peaks[j][0] > (n - e.count):
                if abs(peaks[j][1]) > abs(p):
                    ok = False
                j -= 1
            # check up to 'e.count' rows after this one
            j = i + 1
            while ok and j < len(peaks) and peaks[j][0] < (n + e.count):
                if abs(peaks[j][1]) > abs(p):
                    ok = False
                j += 1
            if ok:
                good_peak_locations.append(n)

        # Now look at all good peaks
        for row in good_peak_locations:
            # Correlation for each column must be close enough to 1.
            for (corr, scale) in zip(corrs, e.scale):
                # The accepted distance from 1 is based on the relative
                # amplitude of the column.  Use a linear mapping:
                #   scale 1.0 -> distance 0.1
                #   scale 0.0 -> distance 1.0
                distance = 1 - 0.9 * (scale / e.scale[biggest])
                if abs(corr[row] - 1) > distance:
                    # No match
                    break
            else:
                # Successful match
                matches[row].append(e)

    # Insert matches into destination stream.
    matched_rows = sorted(matches.keys())
    out = np.zeros((len(matched_rows), dest_count + 1))
    for n, row in enumerate(matched_rows):
        # Fill timestamp
        out[n][0] = data[row, 0]
        # Mark matched exemplars
        for exemplar in matches[row]:
            out[n, exemplar.dest_column + 1] = 1.0

    # Insert it
    insert_func(out)

    # Return how many rows we processed
    valid = max(valid, 0)
    printf(" [%s] matched %d exemplars in %d rows\n",
           timestamp_to_short_human(data[0][0]), np.sum(out[:, 1:]), valid)
    return valid
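

# A minimal sketch of the configuration dictionary that trainola() expects
# (all URLs, stream paths, and timestamps below are hypothetical
# placeholders):
#   {
#     "url": "http://localhost/nilmdb/",
#     "stream": "/test/raw",
#     "dest_stream": "/test/matches",
#     "start": 0,
#     "end": 1000000000000,
#     "columns": [{"name": "P1", "index": 0}],
#     "exemplars": [{"name": "pump turn-on",
#                    "url": "http://localhost/nilmdb/",
#                    "stream": "/test/prep",
#                    "start": 0, "end": 1000000,
#                    "dest_column": 0,
#                    "columns": [{"name": "P1", "index": 0}]}]
#   }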


def trainola(conf):
    print("Trainola", nilmtools.__version__)

    # Load main stream data
    url = conf['url']
    src_path = conf['stream']
    dest_path = conf['dest_stream']
    start = conf['start']
    end = conf['end']

    # Get info for the src and dest streams
    src_client = nilmdb.client.numpyclient.NumpyClient(url)
    src = nilmtools.filter.get_stream_info(src_client, src_path)
    if not src:
        raise DataError("source path '" + src_path + "' does not exist")
    src_columns = build_column_mapping(conf['columns'], src)

    dest_client = nilmdb.client.numpyclient.NumpyClient(url)
    dest = nilmtools.filter.get_stream_info(dest_client, dest_path)
    if not dest:
        raise DataError("destination path '" + dest_path + "' does not exist")

    printf("Source:\n")
    printf(" %s [%s]\n", src.path, ",".join(list(src_columns.keys())))
    printf("Destination:\n")
    printf(" %s (%s columns)\n", dest.path, dest.layout_count)

    # Pull in the exemplar data
    exemplars = []
    if 'exemplars' not in conf:
        raise DataError("missing exemplars")
    for n, exinfo in enumerate(conf['exemplars']):
        printf("Loading exemplar %d:\n", n)
        e = Exemplar(exinfo)
        col = e.dest_column
        if col < 0 or col >= dest.layout_count:
            raise DataError(sprintf("bad destination column number %d\n" +
                                    "dest stream only has 0 through %d",
                                    col, dest.layout_count - 1))
        printf(" %s, output column %d\n", str(e), col)
        exemplars.append(e)
    if len(exemplars) == 0:
        raise DataError("missing exemplars")

    # Verify that the exemplar columns are all represented in the main data
    for n, ex in enumerate(exemplars):
        for col in ex.columns:
            if col not in src_columns:
                raise DataError(sprintf("Exemplar %d column %s is not "
                                        "available in source data", n, col))

    # Figure out which intervals we should process
    intervals = (Interval(s, e) for (s, e) in
                 src_client.stream_intervals(src_path,
                                             diffpath=dest_path,
                                             start=start, end=end))
    intervals = nilmdb.utils.interval.optimize(intervals)

    # Do the processing
    rows = 100000
    extractor = functools.partial(src_client.stream_extract_numpy,
                                  src.path, layout=src.layout, maxrows=rows)
    inserter = functools.partial(dest_client.stream_insert_numpy_context,
                                 dest.path)
    start = time.time()
    processed_time = 0
    printf("Processing intervals:\n")
    for interval in intervals:
        printf("%s\n", interval.human_string())
        nilmtools.filter.process_numpy_interval(
            interval, extractor, inserter, rows * 3,
            trainola_matcher, (src_columns, dest.layout_count, exemplars))
        processed_time += (timestamp_to_seconds(interval.end) -
                           timestamp_to_seconds(interval.start))
    elapsed = max(time.time() - start, 1e-3)

    printf("Done. Processed %.2f seconds per second.\n",
           processed_time / elapsed)


def main(argv=None):
    import json
    import sys

    if argv is None:
        argv = sys.argv[1:]
    if len(argv) != 1 or argv[0] == '-h' or argv[0] == '--help':
        printf("usage: %s [-h] [-v] <json-config-dictionary>\n\n", sys.argv[0])
        printf(" Where <json-config-dictionary> is a JSON-encoded " +
               "dictionary string\n")
        printf(" with exemplar and stream data.\n\n")
        printf(" See extras/trainola-test-param*.js in the nilmtools " +
               "repository\n")
        printf(" for examples.\n")
        if len(argv) != 1:
            raise SystemExit(1)
        raise SystemExit(0)
    if argv[0] == '-v' or argv[0] == '--version':
        printf("%s\n", nilmtools.__version__)
        raise SystemExit(0)

    try:
        # Passed in a JSON string (e.g. on the command line)
        conf = json.loads(argv[0])
    except TypeError:
        # Passed in the config dictionary (e.g. from NilmRun)
        conf = argv[0]

    return trainola(conf)


if __name__ == "__main__":
    main()
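
# Example invocation from a shell, assuming this file is run directly (the
# JSON argument is a hypothetical placeholder; see the
# extras/trainola-test-param*.js files mentioned in the usage text above
# for real examples):
#   python3 trainola.py '{"url": "http://localhost/nilmdb/", ... }'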