#!/usr/bin/python

from nilmdb.utils.printf import *
import nilmdb.client
import nilmtools.filter
from nilmdb.utils.time import (timestamp_to_human,
                               timestamp_to_seconds,
                               seconds_to_timestamp)
from nilmdb.utils.interval import Interval

import numpy as np
import scipy
import scipy.signal
from numpy.core.umath_tests import inner1d
import nilmrun

from collections import OrderedDict

class DataError(ValueError):
    pass

class Data(object):
    def __init__(self, name, url, stream, start, end, columns):
        """Initialize, get stream info, check columns"""
        self.name = name
        self.url = url
        self.stream = stream
        self.start = start
        self.end = end

        # Get stream info
        self.client = nilmdb.client.numpyclient.NumpyClient(url)
        self.info = nilmtools.filter.get_stream_info(self.client, stream)

        # Build up name => index mapping for the columns
        self.columns = OrderedDict()
        for c in columns:
            if (c['name'] in self.columns.keys() or
                c['index'] in self.columns.values()):
                raise DataError("duplicated columns")
            if (c['index'] < 0 or c['index'] >= self.info.layout_count):
                raise DataError("bad column number")
            self.columns[c['name']] = c['index']
        if not len(self.columns):
            raise DataError("no columns")

        # Count points
        self.count = self.client.stream_count(self.stream, self.start, self.end)

    def __str__(self):
        return sprintf("%-20s: %s%s, %s rows",
                       self.name, self.stream, str(self.columns.keys()),
                       self.count)

    def fetch(self, min_rows = 10, max_rows = 100000):
        """Fetch all the data into self.data.  This is intended for
        exemplars, and can only handle a relatively small number of
        rows"""
        # Verify count
        if self.count == 0:
            raise DataError("No data in this exemplar!")
        if self.count < min_rows:
            raise DataError("Too few data points: " + str(self.count))
        if self.count > max_rows:
            raise DataError("Too many data points: " + str(self.count))

        # Extract the data
        datagen = self.client.stream_extract_numpy(self.stream,
                                                   self.start, self.end,
                                                   self.info.layout,
                                                   maxrows = self.count)
        self.data = list(datagen)[0]

        # Discard timestamp
        self.data = self.data[:,1:]

        # Subtract the mean from each column
        self.data = self.data - self.data.mean(axis=0)

        # Get scale factors for each column by computing dot product
        # of each column with itself.
        self.scale = inner1d(self.data.T, self.data.T)

        # Ensure a minimum (nonzero) scale and convert to list
        self.scale = np.maximum(self.scale, [1e-9]).tolist()
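# Note on the scale factors above: for each mean-subtracted exemplar column b,
# self.scale holds the dot product b . b (its squared L2 norm).  match() below
# normalizes its cross-correlations by these values, so a stretch of main data
# that reproduces the exemplar shape (even with a constant offset, since b is
# zero-mean) should score roughly 1 at the aligned position.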
def process(main, function, args = None, rows = 200000):
    """Process through the data; similar to nilmtools.Filter.process_numpy"""
    if args is None:
        args = []
    extractor = main.client.stream_extract_numpy
    old_array = np.array([])
    for new_array in extractor(main.stream, main.start, main.end,
                               layout = main.info.layout, maxrows = rows):
        # If we still had old data left, combine it
        if old_array.shape[0] != 0:
            array = np.vstack((old_array, new_array))
        else:
            array = new_array

        # Process it
        processed = function(array, args)

        # Save the unprocessed parts
        if processed >= 0:
            old_array = array[processed:]
        else:
            raise Exception(sprintf("%s return value %s must be >= 0",
                                    str(function), str(processed)))

        # Warn if there's too much data remaining
        if old_array.shape[0] > 3 * rows:
            printf("warning: %d unprocessed rows in buffer\n",
                   old_array.shape[0])
    # Handle leftover data: run the still-unconsumed rows through one last time
    if old_array.shape[0] != 0:
        processed = function(old_array, args)
def peak_detect(data, delta):
    """Simple min/max peak detection algorithm, taken from my code
    in the disagg.m from the 10-8-5 paper"""
    mins = []
    maxs = []
    cur_min = (None, np.inf)
    cur_max = (None, -np.inf)
    lookformax = False
    for (n, p) in enumerate(data):
        if p > cur_max[1]:
            cur_max = (n, p)
        if p < cur_min[1]:
            cur_min = (n, p)
        if lookformax:
            if p < (cur_max[1] - delta):
                maxs.append(cur_max)
                cur_min = (n, p)
                lookformax = False
        else:
            if p > (cur_min[1] + delta):
                mins.append(cur_min)
                cur_max = (n, p)
                lookformax = True
    return (mins, maxs)
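# For illustration, with made-up input values:
#   peak_detect([0, 1, 0.2, 2, 0.1], 0.5)
# returns ([(0, 0), (2, 0.2)], [(1, 1), (3, 2)]) -- lists of (index, value)
# pairs for the local minima and maxima that stand out by more than delta.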
def match(data, args):
    """Perform cross-correlation match"""
    ( columns, exemplars ) = args
    nrows = data.shape[0]

    # We want at least 10% more points than the widest exemplar.
    widest = max([ x.count for x in exemplars ])
    if (widest * 1.1) > nrows:
        return 0

    # This is how many points we'll consider valid in the
    # cross-correlation.
    valid = nrows + 1 - widest
    matches = []

    # Try matching against each of the exemplars
    for e_num, e in enumerate(exemplars):
        corrs = []

        # Compute cross-correlation for each column
        for c in e.columns:
            a = data[:,columns[c] + 1]
            b = e.data[:,e.columns[c]]
            corr = scipy.signal.fftconvolve(a, np.flipud(b), 'valid')[0:valid]

            # Scale by the norm of the exemplar
            corr = corr / e.scale[columns[c]]
            corrs.append(corr)

        # Find the peaks using the column with the largest amplitude
        biggest = e.scale.index(max(e.scale))
        peaks_minmax = peak_detect(corrs[biggest], 0.1)
        peaks = [ p[0] for p in peaks_minmax[1] ]

        # Now look at every peak
        for p in peaks:
            # Correlation for each column must be close enough to 1.
            for (corr, scale) in zip(corrs, e.scale):
                # The accepted distance from 1 is based on the relative
                # amplitude of the column.  Use a linear mapping:
                #   scale 1.0 -> distance 0.1
                #   scale 0.0 -> distance 1.0
                distance = 1 - 0.9 * (scale / e.scale[biggest])
                if abs(corr[p] - 1) > distance:
                    # No match
                    break
            else:
                # Successful match
                matches.append((p, e_num))

    # Print matches
    for (point, e_num) in sorted(matches):
        # Ignore matches that showed up at the very tail of the window,
        # and shorten the window accordingly.  This is an attempt to avoid
        # problems at chunk boundaries.
        if point > (valid - 50):
            valid -= 50
            break
        print "matched", data[point,0], "exemplar", exemplars[e_num].name

    #from matplotlib import pyplot as p
    #p.plot(data[:,1:3])
    #p.show()

    return max(valid, 0)
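# Note: match() returns the number of rows it fully considered (at most
# nrows + 1 - widest).  process() keeps the rows past that point and prepends
# them to the next chunk, so an event that straddles a chunk boundary should
# still land entirely inside a later window.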
def trainola(conf):
    # Load main stream data
    print "Loading stream data"
    main = Data(None, conf['url'], conf['stream'],
                conf['start'], conf['end'], conf['columns'])

    # Pull in the exemplar data
    exemplars = []
    for n, e in enumerate(conf['exemplars']):
        print sprintf("Loading exemplar %d: %s", n, e['name'])
        ex = Data(e['name'], e['url'], e['stream'],
                  e['start'], e['end'], e['columns'])
        ex.fetch()
        exemplars.append(ex)

    # Verify that the exemplar columns are all represented in the main data
    for n, ex in enumerate(exemplars):
        for col in ex.columns:
            if col not in main.columns:
                raise DataError(sprintf("Exemplar %d column %s is not "
                                        "available in main data", n, col))

    # Process the main data
    process(main, match, (main.columns, exemplars))

    return "done"
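# A hypothetical example of the configuration decoded by main() below.  Only
# the key layout comes from the code above; the URL, stream paths, timestamps,
# and column names are made up:
#
# {
#   "url": "http://localhost/nilmdb/",
#   "stream": "/house/prep-a",
#   "start": 1366111383280463,
#   "end": 1366126163457797,
#   "columns": [ { "name": "P1", "index": 0 } ],
#   "exemplars": [
#     { "name": "Pump turn-on",
#       "url": "http://localhost/nilmdb/",
#       "stream": "/house/prep-a",
#       "start": 1366260494269078,
#       "end": 1366260608185031,
#       "columns": [ { "name": "P1", "index": 0 } ] }
#   ]
# }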
def main(argv = None):
    import simplejson as json
    import argparse
    import sys

    parser = argparse.ArgumentParser(
        formatter_class = argparse.RawDescriptionHelpFormatter,
        version = nilmrun.__version__,
        description = """Run Trainola using parameters passed in as
        JSON-formatted data.""")
    parser.add_argument("file", metavar="FILE", nargs="?",
                        type=argparse.FileType('r'), default=sys.stdin)
    args = parser.parse_args(argv)

    conf = json.loads(args.file.read())
    result = trainola(conf)
    print json.dumps(result, sort_keys = True, indent = 2 * ' ')

if __name__ == "__main__":
    main()
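# Example invocation (the script file name here is hypothetical): the JSON
# configuration is read from the optional FILE argument, or from stdin if no
# argument is given, e.g.
#   python trainola.py config.json
#   cat config.json | python trainola.py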