#!/usr/bin/python
from __future__ import absolute_import
import nilmdb.client
from nilmdb.client import Client
from nilmdb.utils.printf import *
from nilmdb.utils.time import (parse_time, timestamp_to_human,
                               timestamp_to_seconds)
from nilmdb.utils.interval import Interval
import nilmtools
import itertools
import time
import sys
import re
import argparse
import numpy as np
import cStringIO

class MissingDestination(Exception):
    def __init__(self, src, dest):
        self.src = src
        self.dest = dest
        Exception.__init__(self, "destination path " + dest.path + " not found")

class StreamInfo(object):
    def __init__(self, url, info, interhost = False):
        self.url = url
        self.info = info
        self.interhost = interhost
        try:
            self.path = info[0]
            self.layout = info[1]
            self.layout_type = self.layout.split('_')[0]
            self.layout_count = int(self.layout.split('_')[1])
            self.total_count = self.layout_count + 1
            self.timestamp_min = info[2]
            self.timestamp_max = info[3]
            self.rows = info[4]
            self.seconds = nilmdb.utils.time.timestamp_to_seconds(info[5])
        except (IndexError, TypeError):
            pass
    def __str__(self):
        """Print stream info as a string"""
        res = ""
        if self.interhost:
            res = sprintf("[%s] ", self.url)
        res += sprintf("%s (%s), %.2fM rows, %.2f hours",
                       self.path, self.layout, self.rows / 1e6,
                       self.seconds / 3600.0)
        return res

class Filter(object):
    def __init__(self):
        self._parser = None
        self._client_src = None
        self._client_dest = None
        self._using_client = False
        self.src = None
        self.dest = None
        self.start = None
        self.end = None
        self.interhost = False

    @property
    def client_src(self):
        if self._using_client:
            raise Exception("Filter client is in use; make another")
        return self._client_src

    @property
    def client_dest(self):
        if self._using_client:
            raise Exception("Filter client is in use; make another")
        return self._client_dest
    def setup_parser(self, description = "Filter data"):
        parser = argparse.ArgumentParser(
            formatter_class = argparse.RawDescriptionHelpFormatter,
            description = description)
        parser.add_argument("-v", "--version", action="version",
                            version = nilmtools.__version__)
        group = parser.add_argument_group("General filter arguments")
        group.add_argument("-u", "--url", action="store",
                           default="http://localhost/nilmdb/",
                           help="Server URL (default: %(default)s)")
        group.add_argument("-U", "--dest-url", action="store",
                           help="Destination server URL "
                           "(default: same as source)")
        group.add_argument("-D", "--dry-run", action="store_true",
                           default = False,
                           help="Just print intervals that would be "
                           "processed")
        group.add_argument("-s", "--start",
                           metavar="TIME", type=self.arg_time,
                           help="Starting timestamp for intervals "
                           "(free-form, inclusive)")
        group.add_argument("-e", "--end",
                           metavar="TIME", type=self.arg_time,
                           help="Ending timestamp for intervals "
                           "(free-form, noninclusive)")
        group.add_argument("srcpath", action="store",
                           help="Path of source stream, e.g. /foo/bar")
        group.add_argument("destpath", action="store",
                           help="Path of destination stream, e.g. /foo/bar")
        self._parser = parser
        return parser
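
    # A filter script can extend the returned parser with its own
    # options before parse_args() is called; for example (the option
    # name here is hypothetical, not part of this module):
    #
    #   parser = f.setup_parser("My filter")
    #   group = parser.add_argument_group("My filter arguments")
    #   group.add_argument("-f", "--factor", type=float, default=1.0)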
    def interval_string(self, interval):
        return sprintf("[ %s -> %s ]",
                       timestamp_to_human(interval.start),
                       timestamp_to_human(interval.end))
    def parse_args(self, argv = None):
        args = self._parser.parse_args(argv)

        if args.dest_url is None:
            args.dest_url = args.url
        if args.url != args.dest_url:
            self.interhost = True

        self._client_src = Client(args.url)
        self._client_dest = Client(args.dest_url)

        if (not self.interhost) and (args.srcpath == args.destpath):
            raise Exception("source and destination path must be different")

        # Open and print info about the streams
        src = self._client_src.stream_list(args.srcpath, extended = True)
        if len(src) != 1:
            raise Exception("source path " + args.srcpath + " not found")
        self.src = StreamInfo(args.url, src[0], self.interhost)

        dest = self._client_dest.stream_list(args.destpath, extended = True)
        if len(dest) != 1:
            raise MissingDestination(self.src,
                                     StreamInfo(args.dest_url, [args.destpath],
                                                self.interhost))
        self.dest = StreamInfo(args.dest_url, dest[0], self.interhost)

        print "Source:", self.src
        print "  Dest:", self.dest

        if args.dry_run:
            for interval in self.intervals():
                print self.interval_string(interval)
            raise SystemExit(0)

        self.start = args.start
        self.end = args.end
        return args
    def _optimize_int(self, it):
        """Join and yield adjacent intervals from the iterator 'it'"""
        saved_int = None
        for interval in it:
            if saved_int is not None:
                if saved_int.end == interval.start:
                    interval.start = saved_int.start
                else:
                    yield saved_int
            saved_int = interval
        if saved_int is not None:
            yield saved_int
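
    # For example, if 'it' yields the intervals [0,10), [10,20), and
    # [30,40), this generator yields just [0,20) and [30,40), so the
    # extract/insert calls below cover fewer, larger ranges.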
    def intervals(self):
        """Generate all the intervals that this filter should process"""
        self._using_client = True

        if self.interhost:
            # Do the difference ourselves
            s_intervals = ( Interval(start, end)
                            for (start, end) in
                            self._client_src.stream_intervals(
                                self.src.path,
                                start = self.start, end = self.end) )
            d_intervals = ( Interval(start, end)
                            for (start, end) in
                            self._client_dest.stream_intervals(
                                self.dest.path,
                                start = self.start, end = self.end) )
            intervals = nilmdb.utils.interval.set_difference(s_intervals,
                                                             d_intervals)
        else:
            # Let the server do the difference for us
            intervals = ( Interval(start, end)
                          for (start, end) in
                          self._client_src.stream_intervals(
                              self.src.path, diffpath = self.dest.path,
                              start = self.start, end = self.end) )

        # Optimize intervals: join intervals that are adjacent
        for interval in self._optimize_int(intervals):
            yield interval
        self._using_client = False
    # Misc helpers
    def arg_time(self, toparse):
        """Parse a time string argument"""
        try:
            return nilmdb.utils.time.parse_time(toparse)
        except ValueError as e:
            raise argparse.ArgumentTypeError(sprintf("%s \"%s\"",
                                                     str(e), toparse))
    def check_dest_metadata(self, data):
        """See if the metadata jives, and complain if it doesn't.  If
        there's no conflict, update the metadata to match 'data'."""
        metadata = self._client_dest.stream_get_metadata(self.dest.path)
        for key in data:
            wanted = str(data[key])
            val = metadata.get(key, wanted)
            if val != wanted and self.dest.rows > 0:
                m = "Metadata in destination stream:\n"
                m += "  %s = %s\n" % (key, val)
                m += "doesn't match desired data:\n"
                m += "  %s = %s\n" % (key, wanted)
                m += "Refusing to change it.  You can change the stream's "
                m += "metadata manually, or\n"
                m += "remove existing data from the stream, to prevent "
                m += "this error.\n"
                raise Exception(m)
        # All good -- write the metadata in case it's not already there
        self._client_dest.stream_update_metadata(self.dest.path, data)
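
    # Typical use from a filter (the key name here is hypothetical):
    # record where the destination's data came from, so a rerun against
    # a different source stream is refused rather than silently mixed in:
    #
    #   f.check_dest_metadata({ "filter_source": f.src.path })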
    # Main processing helper
    def process_python(self, function, rows, args = None, partial = False):
        """Process data in chunks of 'rows' data at a time.

        This provides data as nested Python lists and expects the same
        back.

        function: function to process the data
        rows: maximum number of rows to pass to 'function' at once
        args: tuple containing extra arguments to pass to 'function'
        partial: if True, less than 'rows' may be passed to 'function'.
                 if False, partial data at the end of an interval will
                 be dropped.

        'function' should be defined like:
           function(data, *args)

        It will be passed a list containing up to 'rows' rows of
        data from the source stream, and any arguments passed in
        'args'.  It should transform the data as desired, and return a
        new list of data, which will be inserted into the destination
        stream.
        """
        if args is None:
            args = []
        extractor = Client(self.src.url).stream_extract
        inserter = Client(self.dest.url).stream_insert_context

        # Parse input data.  We use homogeneous types for now, which
        # means the timestamp type will be either float or int.
        if "int" in self.src.layout_type:
            parser = lambda line: [ int(x) for x in line.split() ]
        else:
            parser = lambda line: [ float(x) for x in line.split() ]

        # Format output data.
        formatter = lambda row: " ".join([repr(x) for x in row]) + "\n"

        for interval in self.intervals():
            print "Processing", self.interval_string(interval)
            with inserter(self.dest.path,
                          interval.start, interval.end) as insert_ctx:
                src_array = []
                for line in extractor(self.src.path,
                                      interval.start, interval.end):
                    # Read in data
                    src_array.append(parser(line))
                    if len(src_array) == rows:
                        # Pass through filter function
                        dest_array = function(src_array, *args)
                        # Write result to destination
                        out = [ formatter(row) for row in dest_array ]
                        insert_ctx.insert("".join(out))
                        # Clear source array
                        src_array = []
                # Take care of partial chunk
                if len(src_array) and partial:
                    dest_array = function(src_array, *args)
                    out = [ formatter(row) for row in dest_array ]
                    insert_ctx.insert("".join(out))
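
    # Example (hypothetical -- 'scale' is not part of this module): a
    # filter function that multiplies every non-timestamp column by a
    # constant, suitable for passing to process_python():
    #
    #   def scale(data, factor):
    #       return [ [ row[0] ] + [ x * factor for x in row[1:] ]
    #                for row in data ]
    #
    #   f.process_python(scale, rows = 1000, args = (2.0,))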
    # Like process_python, but provides Numpy arrays and allows for
    # partial processing.
    def process_numpy(self, function, args = None, rows = 100000):
        """For all intervals that exist in self.src but don't exist in
        self.dest, call 'function' with a Numpy array corresponding to
        the data.  The data is converted to a Numpy array in chunks of
        'rows' rows at a time.

        'function' should be defined as:
           def function(data, interval, args, insert_func, final)

        'data': array of data to process -- may be empty

        'interval': overall interval we're processing (but not necessarily
        the interval of this particular chunk of data)

        'args': opaque arguments passed to process_numpy

        'insert_func': function to call in order to insert an array of
        data.  Should be passed a 2-dimensional array of data to insert.
        Data timestamps must be within the provided interval.

        'final': True if this is the last bit of data for this
        contiguous interval, False otherwise.

        Return value of 'function' is the number of data rows processed.
        Unprocessed data will be provided again in a subsequent call
        (unless 'final' is True).
        """
        if args is None:
            args = []
        extractor = Client(self.src.url).stream_extract
        inserter = Client(self.dest.url).stream_insert_context

        def batch(iterable, size):
            c = itertools.count()
            for k, g in itertools.groupby(iterable, lambda x: c.next() // size):
                yield g

        for interval in self.intervals():
            print "Processing", self.interval_string(interval)
            with inserter(self.dest.path,
                          interval.start, interval.end) as insert_ctx:
                def insert_function(array):
                    s = cStringIO.StringIO()
                    if len(np.shape(array)) != 2:
                        raise Exception("array must be 2-dimensional")
                    np.savetxt(s, array)
                    insert_ctx.insert(s.getvalue())

                extract = extractor(self.src.path, interval.start, interval.end)
                old_array = np.array([])
                for batched in batch(extract, rows):
                    # Read in this batch of data.  ndmin=2 ensures that
                    # a single-line batch still parses as 2-dimensional.
                    new_array = np.loadtxt(batched, ndmin = 2)

                    # If we still had old data left, combine it
                    if old_array.shape[0] != 0:
                        array = np.vstack((old_array, new_array))
                    else:
                        array = new_array

                    # Pass it to the process function
                    processed = function(array, interval, args,
                                         insert_function, False)

                    # Send any pending data
                    insert_ctx.send()

                    # Save the unprocessed parts
                    if processed > 0:
                        old_array = array[processed:]
                    else:
                        old_array = array

                # Last call for this contiguous interval
                if old_array.shape[0] != 0:
                    function(old_array, interval, args, insert_function, True)
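
    # Example 'function' (hypothetical): a two-point difference filter
    # that holds back the last row, relying on process_numpy to
    # re-deliver it at the start of the next chunk:
    #
    #   def diff(data, interval, args, insert_func, final):
    #       if len(data) < 2:
    #           return 0                 # need more rows; redeliver
    #       out = np.hstack((data[:-1,:1], np.diff(data[:,1:], axis=0)))
    #       insert_func(out)
    #       return len(data) - 1         # last row carries over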

def main(argv = None):
    # This is just a dummy function; actual filters can use the other
    # functions to prepare stuff, and then do something with the data.
    f = Filter()
    parser = f.setup_parser()
    args = f.parse_args(argv)
    for i in f.intervals():
        print "Generic filter: need to handle", f.interval_string(i)

if __name__ == "__main__":
    main()
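
# A minimal sketch of a complete filter built on this module.  The
# function name, metadata key, and limit value are illustrative only,
# not part of nilmtools:
#
#   #!/usr/bin/python
#   import numpy as np
#   from nilmtools.filter import Filter
#
#   def clamp(data, interval, args, insert_func, final):
#       (limit,) = args
#       # Clamp all non-timestamp columns in place
#       np.clip(data[:,1:], -limit, limit, out = data[:,1:])
#       insert_func(data)
#       return len(data)              # every row was consumed
#
#   f = Filter()
#   f.setup_parser("Clamp all columns to +/- a fixed limit")
#   f.parse_args()
#   f.check_dest_metadata({ "clamp_source": f.src.path })
#   f.process_numpy(clamp, args = (120.0,))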