You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

394 lines
16 KiB

  1. #!/usr/bin/env python3
  2. import nilmdb.client
  3. from nilmdb.client import Client
  4. from nilmdb.client.numpyclient import NumpyClient
  5. from nilmdb.utils.printf import printf, sprintf
  6. from nilmdb.utils.interval import Interval
  7. import nilmtools
  8. import os
  9. import argparse
  10. import numpy as np
  11. import functools
  12. class ArgumentError(Exception):
  13. pass
  14. class MissingDestination(Exception):
  15. def __init__(self, args, src, dest):
  16. self.parsed_args = args
  17. self.src = src
  18. self.dest = dest
  19. Exception.__init__(self, f"destination path {dest.path} not found")
  20. class StreamInfo(object):
  21. def __init__(self, url, info):
  22. self.url = url
  23. self.info = info
  24. try:
  25. self.path = info[0]
  26. self.layout = info[1]
  27. self.layout_type = self.layout.split('_')[0]
  28. self.layout_count = int(self.layout.split('_')[1])
  29. self.total_count = self.layout_count + 1
  30. self.timestamp_min = info[2]
  31. self.timestamp_max = info[3]
  32. self.rows = info[4]
  33. self.seconds = nilmdb.utils.time.timestamp_to_seconds(info[5])
  34. except (IndexError, TypeError):
  35. pass
  36. def string(self, interhost):
  37. """Return stream info as a string. If interhost is true,
  38. include the host URL."""
  39. if interhost:
  40. return sprintf("[%s] ", self.url) + str(self)
  41. return str(self)
  42. def __str__(self):
  43. """Return stream info as a string."""
  44. return sprintf("%s (%s), %.2fM rows, %.2f hours",
  45. self.path, self.layout, self.rows / 1e6,
  46. self.seconds / 3600.0)
  47. def get_stream_info(client, path):
  48. """Return a StreamInfo object about the given path, or None if it
  49. doesn't exist"""
  50. streams = client.stream_list(path, extended=True)
  51. if len(streams) != 1:
  52. return None
  53. return StreamInfo(client.geturl(), streams[0])
  54. # Filter processing for a single interval of data.
  55. def process_numpy_interval(interval, extractor, inserter, warn_rows,
  56. function, args=None):
  57. """For the given 'interval' of data, extract data, process it
  58. through 'function', and insert the result.
  59. 'extractor' should be a function like NumpyClient.stream_extract_numpy
  60. but with the the interval 'start' and 'end' as the only parameters,
  61. e.g.:
  62. extractor = functools.partial(NumpyClient.stream_extract_numpy,
  63. src_path, layout = l, maxrows = m)
  64. 'inserter' should be a function like NumpyClient.stream_insert_context
  65. but with the interval 'start' and 'end' as the only parameters, e.g.:
  66. inserter = functools.partial(NumpyClient.stream_insert_context,
  67. dest_path)
  68. If 'warn_rows' is not None, print a warning to stdout when the
  69. number of unprocessed rows exceeds this amount.
  70. See process_numpy for details on 'function' and 'args'.
  71. """
  72. if args is None:
  73. args = []
  74. with inserter(interval.start, interval.end) as insert_ctx:
  75. insert_func = insert_ctx.insert
  76. old_array = np.array([])
  77. for new_array in extractor(interval.start, interval.end):
  78. # If we still had old data left, combine it
  79. if old_array.shape[0] != 0:
  80. array = np.vstack((old_array, new_array))
  81. else:
  82. array = new_array
  83. # Pass the data to the user provided function
  84. processed = function(array, interval, args, insert_func, False)
  85. # Send any pending data that the user function inserted
  86. insert_ctx.send()
  87. # Save the unprocessed parts
  88. if processed >= 0:
  89. old_array = array[processed:]
  90. else:
  91. raise Exception(
  92. sprintf("%s return value %s must be >= 0",
  93. str(function), str(processed)))
  94. # Warn if there's too much data remaining
  95. if warn_rows is not None and old_array.shape[0] > warn_rows:
  96. printf("warning: %d unprocessed rows in buffer\n",
  97. old_array.shape[0])
  98. # Last call for this contiguous interval
  99. if old_array.shape[0] != 0:
  100. processed = function(old_array, interval, args,
  101. insert_func, True)
  102. if processed != old_array.shape[0]:
  103. # Truncate the interval we're inserting at the first
  104. # unprocessed data point. This ensures that
  105. # we'll not miss any data when we run again later.
  106. insert_ctx.update_end(old_array[processed][0])
  107. def example_callback_function(data, interval, args, insert_func, final):
  108. """Example of the signature for the function that gets passed
  109. to process_numpy_interval.
  110. 'data': array of data to process -- may be empty
  111. 'interval': overall interval we're processing (but not necessarily
  112. the interval of this particular chunk of data)
  113. 'args': opaque arguments passed to process_numpy
  114. 'insert_func': function to call in order to insert array of data.
  115. Should be passed a 2-dimensional array of data to insert.
  116. Data timestamps must be within the provided interval.
  117. 'final': True if this is the last bit of data for this
  118. contiguous interval, False otherwise.
  119. Return value of 'function' is the number of data rows processed.
  120. Unprocessed data will be provided again in a subsequent call
  121. (unless 'final' is True).
  122. If unprocessed data remains after 'final' is True, the interval
  123. being inserted will be ended at the timestamp of the first
  124. unprocessed data point.
  125. """
  126. raise NotImplementedError("example_callback_function does nothing")
  127. class Filter(object):
  128. def __init__(self, parser_description=None):
  129. self._parser = None
  130. self._client_src = None
  131. self._client_dest = None
  132. self._using_client = False
  133. self.src = None
  134. self.dest = None
  135. self.start = None
  136. self.end = None
  137. self._interhost = False
  138. self._force_metadata = False
  139. self.def_url = os.environ.get("NILMDB_URL", "http://localhost/nilmdb/")
  140. if parser_description is not None:
  141. self.setup_parser(parser_description)
  142. self.parse_args()
  143. @property
  144. def client_src(self):
  145. if self._using_client:
  146. raise Exception("Filter src client is in use; make another")
  147. return self._client_src
  148. @property
  149. def client_dest(self):
  150. if self._using_client:
  151. raise Exception("Filter dest client is in use; make another")
  152. return self._client_dest
  153. def setup_parser(self, description="Filter data", skip_paths=False):
  154. parser = argparse.ArgumentParser(
  155. formatter_class=argparse.RawDescriptionHelpFormatter,
  156. description=description)
  157. group = parser.add_argument_group("General filter arguments")
  158. group.add_argument("-u", "--url", action="store",
  159. default=self.def_url,
  160. help="Server URL (default: %(default)s)")
  161. group.add_argument("-U", "--dest-url", action="store",
  162. help="Destination server URL "
  163. "(default: same as source)")
  164. group.add_argument("-D", "--dry-run", action="store_true",
  165. default=False,
  166. help="Just print intervals that would be "
  167. "processed")
  168. group.add_argument("-q", "--quiet", action="store_true",
  169. default=False,
  170. help="Don't print source and dest stream info")
  171. group.add_argument("-F", "--force-metadata", action="store_true",
  172. default=False,
  173. help="Force metadata changes if the dest "
  174. "doesn't match")
  175. group.add_argument("-s", "--start",
  176. metavar="TIME", type=self.arg_time,
  177. help="Starting timestamp for intervals "
  178. "(free-form, inclusive)")
  179. group.add_argument("-e", "--end",
  180. metavar="TIME", type=self.arg_time,
  181. help="Ending timestamp for intervals "
  182. "(free-form, noninclusive)")
  183. group.add_argument("-v", "--version", action="version",
  184. version=nilmtools.__version__)
  185. if not skip_paths:
  186. # Individual filter scripts might want to add these arguments
  187. # themselves, to include multiple sources in a different order
  188. # (for example). "srcpath" and "destpath" arguments must exist,
  189. # though.
  190. group.add_argument("srcpath", action="store",
  191. help="Path of source stream, eg. /foo/bar")
  192. group.add_argument("destpath", action="store",
  193. help="Path of destination stream, eg. /foo/bar")
  194. self._parser = parser
  195. return parser
  196. def set_args(self, url, dest_url, srcpath, destpath, start, end,
  197. parsed_args=None, quiet=True):
  198. """Set arguments directly from parameters"""
  199. if dest_url is None:
  200. dest_url = url
  201. if url != dest_url:
  202. self._interhost = True
  203. self._client_src = Client(url)
  204. self._client_dest = Client(dest_url)
  205. if (not self._interhost) and (srcpath == destpath):
  206. raise ArgumentError(
  207. "source and destination path must be different")
  208. # Open the streams
  209. self.src = get_stream_info(self._client_src, srcpath)
  210. if not self.src:
  211. raise ArgumentError("source path " + srcpath + " not found")
  212. self.dest = get_stream_info(self._client_dest, destpath)
  213. if not self.dest:
  214. raise MissingDestination(parsed_args, self.src,
  215. StreamInfo(dest_url, [destpath]))
  216. self.start = start
  217. self.end = end
  218. # Print info
  219. if not quiet:
  220. print("Source:", self.src.string(self._interhost))
  221. print(" Dest:", self.dest.string(self._interhost))
  222. def parse_args(self, argv=None):
  223. """Parse arguments from a command line"""
  224. args = self._parser.parse_args(argv)
  225. self.set_args(args.url, args.dest_url, args.srcpath, args.destpath,
  226. args.start, args.end, quiet=args.quiet, parsed_args=args)
  227. self._force_metadata = args.force_metadata
  228. if args.dry_run:
  229. for interval in self.intervals():
  230. print(interval.human_string())
  231. raise SystemExit(0)
  232. return args
  233. def intervals(self):
  234. """Generate all the intervals that this filter should process"""
  235. self._using_client = True
  236. if self._interhost:
  237. # Do the difference ourselves
  238. s_intervals = (Interval(start, end)
  239. for (start, end) in
  240. self._client_src.stream_intervals(
  241. self.src.path,
  242. start=self.start, end=self.end))
  243. d_intervals = (Interval(start, end)
  244. for (start, end) in
  245. self._client_dest.stream_intervals(
  246. self.dest.path,
  247. start=self.start, end=self.end))
  248. intervals = nilmdb.utils.interval.set_difference(s_intervals,
  249. d_intervals)
  250. else:
  251. # Let the server do the difference for us
  252. intervals = (Interval(start, end)
  253. for (start, end) in
  254. self._client_src.stream_intervals(
  255. self.src.path, diffpath=self.dest.path,
  256. start=self.start, end=self.end))
  257. # Optimize intervals: join intervals that are adjacent
  258. for interval in nilmdb.utils.interval.optimize(intervals):
  259. yield interval
  260. self._using_client = False
  261. # Misc helpers
  262. @staticmethod
  263. def arg_time(toparse):
  264. """Parse a time string argument"""
  265. try:
  266. return nilmdb.utils.time.parse_time(toparse)
  267. except ValueError as e:
  268. raise argparse.ArgumentTypeError(sprintf("%s \"%s\"",
  269. str(e), toparse))
  270. def check_dest_metadata(self, data):
  271. """See if the metadata jives, and complain if it doesn't. For
  272. each key in data, if the stream contains the key, it must match
  273. values. If the stream does not contain the key, it is created."""
  274. metadata = self._client_dest.stream_get_metadata(self.dest.path)
  275. if not self._force_metadata:
  276. for key in data:
  277. wanted = data[key]
  278. if not isinstance(wanted, str):
  279. wanted = str(wanted)
  280. val = metadata.get(key, wanted)
  281. if val != wanted and self.dest.rows > 0:
  282. m = "Metadata in destination stream:\n"
  283. m += " %s = %s\n" % (key, val)
  284. m += "doesn't match desired data:\n"
  285. m += " %s = %s\n" % (key, wanted)
  286. m += "Refusing to change it. To prevent this error, "
  287. m += "change or delete the metadata with nilmtool,\n"
  288. m += "remove existing data from the stream, or "
  289. m += "retry with --force-metadata."
  290. raise Exception(m)
  291. # All good -- write the metadata in case it's not already there
  292. self._client_dest.stream_update_metadata(self.dest.path, data)
  293. # The main filter processing method.
  294. def process_numpy(self, function, args=None, rows=100000,
  295. intervals=None):
  296. """Calls process_numpy_interval for each interval that currently
  297. exists in self.src, but doesn't exist in self.dest. It will
  298. process the data in chunks as follows:
  299. For each chunk of data, call 'function' with a Numpy array
  300. corresponding to the data. The data is converted to a Numpy
  301. array in chunks of 'rows' rows at a time.
  302. If 'intervals' is not None, process those intervals instead of
  303. the default list.
  304. 'function' should be defined with the same interface as
  305. nilmtools.filter.example_callback_function. See the
  306. documentation of that for details. 'args' are passed to
  307. 'function'.
  308. """
  309. extractor = NumpyClient(self.src.url).stream_extract_numpy
  310. inserter = NumpyClient(self.dest.url).stream_insert_numpy_context
  311. extractor_func = functools.partial(extractor, self.src.path,
  312. layout=self.src.layout,
  313. maxrows=rows)
  314. inserter_func = functools.partial(inserter, self.dest.path)
  315. for interval in (intervals or self.intervals()):
  316. print("Processing", interval.human_string())
  317. process_numpy_interval(interval, extractor_func, inserter_func,
  318. rows * 3, function, args)
  319. def main(argv=None):
  320. # This is just a dummy function; actual filters can use the other
  321. # functions to prepare stuff, and then do something with the data.
  322. f = Filter()
  323. parser = f.setup_parser() # noqa: F841
  324. args = f.parse_args(argv) # noqa: F841
  325. for i in f.intervals():
  326. print("Generic filter: need to handle", i.human_string())
  327. if __name__ == "__main__":
  328. main()