#!/usr/bin/python
from nilmdb.utils.printf import *
import nilmdb.client
import nilmtools.filter
from nilmdb.utils.time import (timestamp_to_human,
                               timestamp_to_seconds,
                               seconds_to_timestamp)
from nilmdb.utils import datetime_tz
from nilmdb.utils.interval import Interval

import numpy as np
import scipy
import scipy.signal
from numpy.core.umath_tests import inner1d
import nilmrun
from collections import OrderedDict
import sys
import time
import functools
import collections

class DataError(ValueError):
    pass

def build_column_mapping(colinfo, streaminfo):
    """Given the 'columns' list from the JSON data, verify and
    pull out a dictionary mapping for the column names/numbers."""
    columns = OrderedDict()
    for c in colinfo:
        if (c['name'] in columns.keys() or
            c['index'] in columns.values()):
            raise DataError("duplicated columns")
        if (c['index'] < 0 or c['index'] >= streaminfo.layout_count):
            raise DataError("bad column number")
        columns[c['name']] = c['index']
    if not len(columns):
        raise DataError("no columns")
    return columns
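
# Illustrative example (hypothetical column names): with a stream whose
# layout_count is at least 2,
#   build_column_mapping([{'name': 'P1', 'index': 0},
#                         {'name': 'Q1', 'index': 1}], info)
# returns an OrderedDict mapping 'P1' -> 0 and 'Q1' -> 1; duplicate
# names or indices, or out-of-range indices, raise DataError.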

class Exemplar(object):
    def __init__(self, exinfo, min_rows = 10, max_rows = 100000):
        """Given a dictionary entry from the 'exemplars' input JSON,
        verify the stream, columns, etc. Then, fetch all the data
        into self.data."""

        self.name = exinfo['name']
        self.url = exinfo['url']
        self.stream = exinfo['stream']
        self.start = exinfo['start']
        self.end = exinfo['end']
        self.dest_column = exinfo['dest_column']

        # Get stream info
        self.client = nilmdb.client.numpyclient.NumpyClient(self.url)
        self.info = nilmtools.filter.get_stream_info(self.client, self.stream)

        # Build up name => index mapping for the columns
        self.columns = build_column_mapping(exinfo['columns'], self.info)

        # Count points
        self.count = self.client.stream_count(self.stream, self.start, self.end)

        # Verify count
        if self.count == 0:
            raise DataError("No data in this exemplar!")
        if self.count < min_rows:
            raise DataError("Too few data points: " + str(self.count))
        if self.count > max_rows:
            raise DataError("Too many data points: " + str(self.count))

        # Extract the data
        datagen = self.client.stream_extract_numpy(self.stream,
                                                   self.start, self.end,
                                                   self.info.layout,
                                                   maxrows = self.count)
        self.data = list(datagen)[0]

        # Discard timestamp
        self.data = self.data[:,1:]

        # Subtract the mean from each column
        self.data = self.data - self.data.mean(axis=0)

        # Get scale factors for each column by computing dot product
        # of each column with itself.
        self.scale = inner1d(self.data.T, self.data.T)

        # Ensure a minimum (nonzero) scale and convert to list
        self.scale = np.maximum(self.scale, [1e-9]).tolist()

    def __str__(self):
        return sprintf("\"%s\" %s [%s] %s rows",
                       self.name, self.stream, ",".join(self.columns.keys()),
                       self.count)
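
# Note on self.scale: each entry is the dot product of a (mean-removed)
# exemplar column with itself, i.e. its squared L2 norm.  The intent, as
# used in trainola_matcher below, appears to be that dividing the sliding
# cross-correlation by this value brings an exact match of the exemplar
# close to 1.0.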

def peak_detect(data, delta):
    """Simple min/max peak detection algorithm, taken from my code
    in the disagg.m from the 10-8-5 paper"""
    mins = []
    maxs = []
    cur_min = (None, np.inf)
    cur_max = (None, -np.inf)
    lookformax = False
    for (n, p) in enumerate(data):
        if p > cur_max[1]:
            cur_max = (n, p)
        if p < cur_min[1]:
            cur_min = (n, p)
        if lookformax:
            if p < (cur_max[1] - delta):
                maxs.append(cur_max)
                cur_min = (n, p)
                lookformax = False
        else:
            if p > (cur_min[1] + delta):
                mins.append(cur_min)
                cur_max = (n, p)
                lookformax = True
    return (mins, maxs)
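
# Worked example (illustrative input): peak_detect([0, 1, 0, 2, 0], 0.5)
# returns ([(0, 0), (2, 0)], [(1, 1), (3, 2)]).  Each entry is an
# (index, value) pair, and an extremum is only emitted once the signal
# has moved at least 'delta' away from it in the opposite direction.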

def timestamp_to_short_human(timestamp):
    dt = datetime_tz.datetime_tz.fromtimestamp(timestamp_to_seconds(timestamp))
    return dt.strftime("%H:%M:%S")

def trainola_matcher(data, interval, args, insert_func, final_chunk):
    """Perform cross-correlation match"""
    ( src_columns, dest_count, exemplars ) = args
    nrows = data.shape[0]

    # We want at least 10% more points than the widest exemplar.
    widest = max([ x.count for x in exemplars ])
    if (widest * 1.1) > nrows:
        return 0

    # This is how many points we'll consider valid in the
    # cross-correlation.
    valid = nrows + 1 - widest
    matches = collections.defaultdict(list)

    # Try matching against each of the exemplars
    for e in exemplars:
        corrs = []

        # Compute cross-correlation for each column
        for col_name in e.columns:
            a = data[:, src_columns[col_name] + 1]
            b = e.data[:, e.columns[col_name]]
            corr = scipy.signal.fftconvolve(a, np.flipud(b), 'valid')[0:valid]

            # Scale by the norm of the exemplar
            corr = corr / e.scale[e.columns[col_name]]
            corrs.append(corr)

        # Find the peaks using the column with the largest amplitude
        biggest = e.scale.index(max(e.scale))
        peaks_minmax = peak_detect(corrs[biggest], 0.1)
        peaks = [ p[0] for p in peaks_minmax[1] ]

        # Now look at every peak
        for row in peaks:
            # Correlation for each column must be close enough to 1.
            for (corr, scale) in zip(corrs, e.scale):
                # The accepted distance from 1 is based on the relative
                # amplitude of the column.  Use a linear mapping:
                #   scale 1.0 -> distance 0.1
                #   scale 0.0 -> distance 1.0
                distance = 1 - 0.9 * (scale / e.scale[biggest])
                if abs(corr[row] - 1) > distance:
                    # No match
                    break
            else:
                # Successful match
                matches[row].append(e)

    # Insert matches into destination stream.
    matched_rows = sorted(matches.keys())
    out = np.zeros((len(matched_rows), dest_count + 1))

    for n, row in enumerate(matched_rows):
        # Fill timestamp
        out[n][0] = data[row, 0]

        # Mark matched exemplars
        for exemplar in matches[row]:
            out[n, exemplar.dest_column + 1] = 1.0

    # Insert it
    insert_func(out)

    # Return how many rows we processed
    valid = max(valid, 0)
    printf(" [%s] matched %d exemplars in %d rows\n",
           timestamp_to_short_human(data[0][0]), np.sum(out[:,1:]), valid)
    return valid
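
# Note: trainola_matcher is the callback handed to
# nilmtools.filter.process_numpy_interval in trainola() below.
# Returning 'valid' rather than nrows presumably leaves the trailing
# (widest - 1) rows to be re-presented with the next chunk, so matches
# spanning a chunk boundary are not lost; this reading of the callback
# contract is an assumption based on how the return value is used here.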

def trainola(conf):
    print "Trainola", nilmtools.__version__

    # Load main stream data
    url = conf['url']
    src_path = conf['stream']
    dest_path = conf['dest_stream']
    start = conf['start']
    end = conf['end']

    # Get info for the src and dest streams
    src_client = nilmdb.client.numpyclient.NumpyClient(url)
    src = nilmtools.filter.get_stream_info(src_client, src_path)
    if not src:
        raise DataError("source path '" + src_path + "' does not exist")
    src_columns = build_column_mapping(conf['columns'], src)

    dest_client = nilmdb.client.numpyclient.NumpyClient(url)
    dest = nilmtools.filter.get_stream_info(dest_client, dest_path)
    if not dest:
        raise DataError("destination path '" + dest_path + "' does not exist")

    printf("Source:\n")
    printf(" %s [%s]\n", src.path, ",".join(src_columns.keys()))
    printf("Destination:\n")
    printf(" %s (%s columns)\n", dest.path, dest.layout_count)

    # Pull in the exemplar data
    exemplars = []
    for n, exinfo in enumerate(conf['exemplars']):
        printf("Loading exemplar %d:\n", n)
        e = Exemplar(exinfo)
        col = e.dest_column
        if col < 0 or col >= dest.layout_count:
            raise DataError(sprintf("bad destination column number %d\n" +
                                    "dest stream only has 0 through %d",
                                    col, dest.layout_count - 1))
        printf(" %s, output column %d\n", str(e), col)
        exemplars.append(e)
    if len(exemplars) == 0:
        raise DataError("missing exemplars")

    # Verify that the exemplar columns are all represented in the main data
    for n, ex in enumerate(exemplars):
        for col in ex.columns:
            if col not in src_columns:
                raise DataError(sprintf("Exemplar %d column %s is not "
                                        "available in source data", n, col))

    # Figure out which intervals we should process
    intervals = ( Interval(s, e) for (s, e) in
                  src_client.stream_intervals(src_path,
                                              diffpath = dest_path,
                                              start = start, end = end) )
    intervals = nilmdb.utils.interval.optimize(intervals)

    # Do the processing
    rows = 100000
    extractor = functools.partial(src_client.stream_extract_numpy,
                                  src.path, layout = src.layout, maxrows = rows)
    inserter = functools.partial(dest_client.stream_insert_numpy_context,
                                 dest.path)
    start = time.time()
    processed_time = 0
    printf("Processing intervals:\n")
    for interval in intervals:
        printf("%s\n", interval.human_string())
        nilmtools.filter.process_numpy_interval(
            interval, extractor, inserter, rows * 3,
            trainola_matcher, (src_columns, dest.layout_count, exemplars))
        processed_time += (timestamp_to_seconds(interval.end) -
                           timestamp_to_seconds(interval.start))
    elapsed = max(time.time() - start, 1e-3)

    printf("Done. Processed %.2f seconds per second.\n",
           processed_time / elapsed)

def main(argv = None):
    import simplejson as json
    import sys

    if argv is None:
        argv = sys.argv[1:]
    if len(argv) != 1:
        raise DataError("need one argument, either a dictionary or JSON string")

    try:
        # Passed in a JSON string (e.g. on the command line)
        conf = json.loads(argv[0])
    except TypeError as e:
        # Passed in the config dictionary (e.g. from NilmRun)
        conf = argv[0]

    return trainola(conf)

if __name__ == "__main__":
    main()
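
# Example configuration (illustrative only -- the URL, stream paths,
# timestamps, and names below are hypothetical placeholders; the key
# names match what trainola() and Exemplar() read above):
#
#   {
#     "url": "http://localhost/nilmdb/",
#     "stream": "/test/prep",
#     "dest_stream": "/test/matches",
#     "start": 1366111383000000,
#     "end": 1366126163000000,
#     "columns": [ { "name": "P1", "index": 0 } ],
#     "exemplars": [
#       { "name": "turn-on transient",
#         "url": "http://localhost/nilmdb/",
#         "stream": "/test/prep",
#         "start": 1366260494000000,
#         "end": 1366260608000000,
#         "dest_column": 0,
#         "columns": [ { "name": "P1", "index": 0 } ] }
#     ]
#   }
#
# main() takes that JSON as its single command-line argument, or an
# already-parsed dictionary when called programmatically (e.g. from
# NilmRun).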