#!/usr/bin/python
from nilmdb.utils.printf import *
import nilmdb.client
import nilmtools.filter
from nilmdb.utils.time import (timestamp_to_human,
                               timestamp_to_seconds,
                               seconds_to_timestamp)
from nilmdb.utils import datetime_tz
from nilmdb.utils.interval import Interval

import numpy as np
import scipy
import scipy.signal
from numpy.core.umath_tests import inner1d
import nilmrun
from collections import OrderedDict
import sys
import time
import functools
import collections

class DataError(ValueError):
    pass

def build_column_mapping(colinfo, streaminfo):
    """Given the 'columns' list from the JSON data, verify and
    pull out a dictionary mapping for the column names/numbers."""
    columns = OrderedDict()
    for c in colinfo:
        col_num = c['index'] + 1   # skip timestamp
        if (c['name'] in columns.keys() or col_num in columns.values()):
            raise DataError("duplicated columns")
        if (c['index'] < 0 or c['index'] >= streaminfo.layout_count):
            raise DataError("bad column number")
        columns[c['name']] = col_num
    if not len(columns):
        raise DataError("no columns")
    return columns
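
# Illustrative sketch (added for clarity; not part of the original module).
# For a hypothetical stream whose layout has two data columns named "P1" and
# "Q1", the JSON 'columns' list and the resulting mapping would look like:
#
#   colinfo = [ { "name": "P1", "index": 0 }, { "name": "Q1", "index": 1 } ]
#   build_column_mapping(colinfo, streaminfo)
#   # => OrderedDict([('P1', 1), ('Q1', 2)])
#
# Indices are 0-based in the JSON and shifted by one in the mapping to skip
# the timestamp column.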
class Exemplar(object):
    def __init__(self, exinfo, min_rows = 10, max_rows = 100000):
        """Given a dictionary entry from the 'exemplars' input JSON,
        verify the stream, columns, etc.  Then, fetch all the data
        into self.data."""
        self.name = exinfo['name']
        self.url = exinfo['url']
        self.stream = exinfo['stream']
        self.start = exinfo['start']
        self.end = exinfo['end']
        self.dest_column = exinfo['dest_column']

        # Get stream info
        self.client = nilmdb.client.numpyclient.NumpyClient(self.url)
        self.info = nilmtools.filter.get_stream_info(self.client, self.stream)
        if not self.info:
            raise DataError(sprintf("exemplar stream '%s' does not exist " +
                                    "on server '%s'", self.stream, self.url))

        # Build up name => index mapping for the columns
        self.columns = build_column_mapping(exinfo['columns'], self.info)

        # Count points
        self.count = self.client.stream_count(self.stream, self.start, self.end)

        # Verify count
        if self.count == 0:
            raise DataError("No data in this exemplar!")
        if self.count < min_rows:
            raise DataError("Too few data points: " + str(self.count))
        if self.count > max_rows:
            raise DataError("Too many data points: " + str(self.count))

        # Extract the data
        datagen = self.client.stream_extract_numpy(self.stream,
                                                   self.start, self.end,
                                                   self.info.layout,
                                                   maxrows = self.count)
        self.data = list(datagen)[0]

        # Extract just the columns that were specified in self.columns,
        # skipping the timestamp.
        extract_columns = [ value for (key, value) in self.columns.items() ]
        self.data = self.data[:,extract_columns]

        # Fix the column indices in e.columns, since we removed/reordered
        # columns in self.data
        for n, k in enumerate(self.columns):
            self.columns[k] = n

        # Subtract the means from each column
        self.data = self.data - self.data.mean(axis=0)

        # Get scale factors for each column by computing dot product
        # of each column with itself.
        self.scale = inner1d(self.data.T, self.data.T)

        # Ensure a minimum (nonzero) scale and convert to list
        self.scale = np.maximum(self.scale, [1e-9]).tolist()

    def __str__(self):
        return sprintf("\"%s\" %s [%s] %s rows",
                       self.name, self.stream, ",".join(self.columns.keys()),
                       self.count)
def peak_detect(data, delta):
    """Simple min/max peak detection algorithm, taken from my code
    in the disagg.m from the 10-8-5 paper.

    Returns an array of peaks: each peak is a tuple
        (n, p, is_max)
    where n is the row number in 'data', and p is 'data[n]',
    and is_max is True if this is a maximum, False if it's a minimum.
    """
    peaks = []
    cur_min = (None, np.inf)
    cur_max = (None, -np.inf)
    lookformax = False
    for (n, p) in enumerate(data):
        if p > cur_max[1]:
            cur_max = (n, p)
        if p < cur_min[1]:
            cur_min = (n, p)
        if lookformax:
            if p < (cur_max[1] - delta):
                peaks.append((cur_max[0], cur_max[1], True))
                cur_min = (n, p)
                lookformax = False
        else:
            if p > (cur_min[1] + delta):
                peaks.append((cur_min[0], cur_min[1], False))
                cur_max = (n, p)
                lookformax = True
    return peaks
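
# Illustrative example (added commentary, not in the original source): with a
# small alternating signal and delta = 0.5, every swing larger than delta is
# reported, alternating between minima and maxima:
#
#   peak_detect([0, 1, 0, 1, 0], 0.5)
#   # => [(0, 0, False), (1, 1, True), (2, 0, False), (3, 1, True)]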
def timestamp_to_short_human(timestamp):
    dt = datetime_tz.datetime_tz.fromtimestamp(timestamp_to_seconds(timestamp))
    return dt.strftime("%H:%M:%S")

def trainola_matcher(data, interval, args, insert_func, final_chunk):
    """Perform cross-correlation match"""
    ( src_columns, dest_count, exemplars ) = args
    nrows = data.shape[0]

    # We want at least 10% more points than the widest exemplar.
    widest = max([ x.count for x in exemplars ])
    if (widest * 1.1) > nrows:
        return 0

    # This is how many points we'll consider valid in the
    # cross-correlation.
    valid = nrows + 1 - widest
    matches = collections.defaultdict(list)

    # Try matching against each of the exemplars
    for e in exemplars:
        corrs = []

        # Compute cross-correlation for each column
        for col_name in e.columns:
            a = data[:, src_columns[col_name]]
            b = e.data[:, e.columns[col_name]]
            corr = scipy.signal.fftconvolve(a, np.flipud(b), 'valid')[0:valid]

            # Scale by the norm of the exemplar
            corr = corr / e.scale[e.columns[col_name]]
            corrs.append(corr)

        # Find the peaks using the column with the largest amplitude
        biggest = e.scale.index(max(e.scale))
        peaks = peak_detect(corrs[biggest], 0.1)

        # To try to reduce false positives, discard peaks where
        # there's a higher-magnitude peak (either min or max) within
        # one exemplar width nearby.
        good_peak_locations = []
        for (i, (n, p, is_max)) in enumerate(peaks):
            if not is_max:
                continue
            ok = True
            # check up to 'e.count' rows before this one
            j = i-1
            while ok and j >= 0 and peaks[j][0] > (n - e.count):
                if abs(peaks[j][1]) > abs(p):
                    ok = False
                j -= 1
            # check up to 'e.count' rows after this one
            j = i+1
            while ok and j < len(peaks) and peaks[j][0] < (n + e.count):
                if abs(peaks[j][1]) > abs(p):
                    ok = False
                j += 1
            if ok:
                good_peak_locations.append(n)

        # Now look at all good peaks
        for row in good_peak_locations:
            # Correlation for each column must be close enough to 1.
            for (corr, scale) in zip(corrs, e.scale):
                # The accepted distance from 1 is based on the relative
                # amplitude of the column.  Use a linear mapping:
                #   scale 1.0 -> distance 0.1
                #   scale 0.0 -> distance 1.0
                distance = 1 - 0.9 * (scale / e.scale[biggest])
                if abs(corr[row] - 1) > distance:
                    # No match
                    break
            else:
                # Successful match
                matches[row].append(e)

    # Insert matches into destination stream.
    matched_rows = sorted(matches.keys())
    out = np.zeros((len(matched_rows), dest_count + 1))
    for n, row in enumerate(matched_rows):
        # Fill timestamp
        out[n][0] = data[row, 0]
        # Mark matched exemplars
        for exemplar in matches[row]:
            out[n, exemplar.dest_column + 1] = 1.0

    # Insert it
    insert_func(out)

    # Return how many rows we processed
    valid = max(valid, 0)
    printf(" [%s] matched %d exemplars in %d rows\n",
           timestamp_to_short_human(data[0][0]), np.sum(out[:,1:]), valid)
    return valid
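
# Note on the acceptance threshold above (added commentary, not original):
# distance = 1 - 0.9 * (scale / e.scale[biggest]) is a linear map from a
# column's relative amplitude to its allowed deviation from 1.  For example,
# a column whose scale is half of the biggest gets 1 - 0.9 * 0.5 = 0.55, so
# its correlation may stray up to 0.55 from 1, while the dominant column
# (scale ratio 1.0) must stay within 0.1.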
def trainola(conf):
    print "Trainola", nilmtools.__version__

    # Load main stream data
    url = conf['url']
    src_path = conf['stream']
    dest_path = conf['dest_stream']
    start = conf['start']
    end = conf['end']

    # Get info for the src and dest streams
    src_client = nilmdb.client.numpyclient.NumpyClient(url)
    src = nilmtools.filter.get_stream_info(src_client, src_path)
    if not src:
        raise DataError("source path '" + src_path + "' does not exist")
    src_columns = build_column_mapping(conf['columns'], src)

    dest_client = nilmdb.client.numpyclient.NumpyClient(url)
    dest = nilmtools.filter.get_stream_info(dest_client, dest_path)
    if not dest:
        raise DataError("destination path '" + dest_path + "' does not exist")

    printf("Source:\n")
    printf(" %s [%s]\n", src.path, ",".join(src_columns.keys()))
    printf("Destination:\n")
    printf(" %s (%s columns)\n", dest.path, dest.layout_count)

    # Pull in the exemplar data
    exemplars = []
    for n, exinfo in enumerate(conf['exemplars']):
        printf("Loading exemplar %d:\n", n)
        e = Exemplar(exinfo)
        col = e.dest_column
        if col < 0 or col >= dest.layout_count:
            raise DataError(sprintf("bad destination column number %d\n" +
                                    "dest stream only has 0 through %d",
                                    col, dest.layout_count - 1))
        printf(" %s, output column %d\n", str(e), col)
        exemplars.append(e)
    if len(exemplars) == 0:
        raise DataError("missing exemplars")

    # Verify that the exemplar columns are all represented in the main data
    for n, ex in enumerate(exemplars):
        for col in ex.columns:
            if col not in src_columns:
                raise DataError(sprintf("Exemplar %d column %s is not "
                                        "available in source data", n, col))

    # Figure out which intervals we should process
    intervals = ( Interval(s, e) for (s, e) in
                  src_client.stream_intervals(src_path,
                                              diffpath = dest_path,
                                              start = start, end = end) )
    intervals = nilmdb.utils.interval.optimize(intervals)

    # Do the processing
    rows = 100000
    extractor = functools.partial(src_client.stream_extract_numpy,
                                  src.path, layout = src.layout, maxrows = rows)
    inserter = functools.partial(dest_client.stream_insert_numpy_context,
                                 dest.path)
    start = time.time()
    processed_time = 0
    printf("Processing intervals:\n")
    for interval in intervals:
        printf("%s\n", interval.human_string())
        nilmtools.filter.process_numpy_interval(
            interval, extractor, inserter, rows * 3,
            trainola_matcher, (src_columns, dest.layout_count, exemplars))
        processed_time += (timestamp_to_seconds(interval.end) -
                           timestamp_to_seconds(interval.start))
    elapsed = max(time.time() - start, 1e-3)

    printf("Done. Processed %.2f seconds per second.\n",
           processed_time / elapsed)
def main(argv = None):
    import simplejson as json
    import sys

    if argv is None:
        argv = sys.argv[1:]
    if len(argv) != 1:
        raise DataError("need one argument, either a dictionary or JSON string")

    try:
        # Passed in a JSON string (e.g. on the command line)
        conf = json.loads(argv[0])
    except TypeError as e:
        # Passed in the config dictionary (e.g. from NilmRun)
        conf = argv[0]

    return trainola(conf)

if __name__ == "__main__":
    main()
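
# Example invocation (a sketch; paths, timestamps, and column names below are
# hypothetical placeholders, not from the original source).  The single
# argument is a JSON object with the fields read by trainola() and Exemplar():
#
#   python trainola.py '{
#       "url": "http://localhost/nilmdb/",
#       "stream": "/data/prep-a",
#       "dest_stream": "/data/matches",
#       "start": 1370000000000000,
#       "end": 1370001000000000,
#       "columns": [ { "name": "P1", "index": 0 } ],
#       "exemplars": [
#           { "name": "pump on",
#             "url": "http://localhost/nilmdb/",
#             "stream": "/data/prep-a",
#             "start": 1369000000000000,
#             "end": 1369000010000000,
#             "dest_column": 0,
#             "columns": [ { "name": "P1", "index": 0 } ] }
#       ]
#   }'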