From c9c2e0d5a8fbf911059e9b75bc43dd9651f6cf30 Mon Sep 17 00:00:00 2001
From: Jim Paris
Date: Tue, 9 Jul 2013 18:09:05 -0400
Subject: [PATCH] Improve split between process_numpy and
 process_numpy_interval

---
 nilmtools/filter.py | 109 ++++++++++++++++++++++++--------------------
 1 file changed, 60 insertions(+), 49 deletions(-)

diff --git a/nilmtools/filter.py b/nilmtools/filter.py
index 34b1bba..1782375 100644
--- a/nilmtools/filter.py
+++ b/nilmtools/filter.py
@@ -19,6 +19,7 @@ import re
 import argparse
 import numpy as np
 import cStringIO
+import functools
 
 class ArgumentError(Exception):
     pass
@@ -255,61 +256,68 @@ class Filter(object):
         self._client_dest.stream_update_metadata(self.dest.path, data)
 
     # Filter processing for a single interval of data.
-    def process_numpy_interval(self, interval, extractor, insert_ctx,
-                               function, args = None, rows = 100000):
+    def process_numpy_interval(self, interval, extractor, inserter,
+                               function, args = None, warn_rows = None):
         """For the given 'interval' of data, extract data, process it
         through 'function', and insert the result.
 
         'extractor' should be a function like NumpyClient.stream_extract_numpy
-        'insert_ctx' should be a class like StreamInserterNumpy, with member
-        functions 'insert', 'send', and 'update_end'.
+        but with the interval 'start' and 'end' as the only parameters,
+        e.g.:
+          extractor = functools.partial(NumpyClient.stream_extract_numpy,
+                                        src_path, layout = l, maxrows = m)
 
-        See process_numpy for details on 'function', 'args', and 'rows'.
+        'inserter' should be a function like
+        NumpyClient.stream_insert_numpy_context but with the interval
+        'start' and 'end' as the only parameters, e.g.:
+          inserter = functools.partial(NumpyClient.stream_insert_numpy_context,
+                                       dest_path)
+
+        If 'warn_rows' is not None, print a warning to stdout when the
+        number of unprocessed rows exceeds this amount.
+
+        See process_numpy for details on 'function' and 'args'.
""" if args is None: args = [] - insert_function = insert_ctx.insert - old_array = np.array([]) - for new_array in extractor(self.src.path, - interval.start, interval.end, - layout = self.src.layout, - maxrows = rows): - # If we still had old data left, combine it + with inserter(interval.start, interval.end) as insert_ctx: + insert_func = insert_ctx.insert + old_array = np.array([]) + for new_array in extractor(interval.start, interval.end): + # If we still had old data left, combine it + if old_array.shape[0] != 0: + array = np.vstack((old_array, new_array)) + else: + array = new_array + + # Pass the data to the user provided function + processed = function(array, interval, args, insert_func, False) + + # Send any pending data that the user function inserted + insert_ctx.send() + + # Save the unprocessed parts + if processed >= 0: + old_array = array[processed:] + else: + raise Exception( + sprintf("%s return value %s must be >= 0", + str(function), str(processed))) + + # Warn if there's too much data remaining + if warn_rows is not None and old_array.shape[0] > warn_rows: + printf("warning: %d unprocessed rows in buffer\n", + old_array.shape[0]) + + # Last call for this contiguous interval if old_array.shape[0] != 0: - array = np.vstack((old_array, new_array)) - else: - array = new_array - - # Pass it to the process function - processed = function(array, interval, args, - insert_function, False) - - # Send any pending data - insert_ctx.send() - - # Save the unprocessed parts - if processed >= 0: - old_array = array[processed:] - else: - raise Exception( - sprintf("%s return value %s must be >= 0", - str(function), str(processed))) - - # Warn if there's too much data remaining - if old_array.shape[0] > 3 * rows: - printf("warning: %d unprocessed rows in buffer\n", - old_array.shape[0]) - - # Last call for this contiguous interval - if old_array.shape[0] != 0: - processed = function(old_array, interval, args, - insert_function, True) - if processed != old_array.shape[0]: - # Truncate the interval we're inserting at the first - # unprocessed data point. This ensures that - # we'll not miss any data when we run again later. - insert_ctx.update_end(old_array[processed][0]) + processed = function(old_array, interval, args, + insert_func, True) + if processed != old_array.shape[0]: + # Truncate the interval we're inserting at the first + # unprocessed data point. This ensures that + # we'll not miss any data when we run again later. + insert_ctx.update_end(old_array[processed][0]) # The main filter processing method. def process_numpy(self, function, args = None, rows = 100000): @@ -349,12 +357,15 @@ class Filter(object): extractor = NumpyClient(self.src.url).stream_extract_numpy inserter = NumpyClient(self.dest.url).stream_insert_numpy_context + extractor_func = functools.partial(extractor, self.src.path, + layout = self.src.layout, + maxrows = rows) + inserter_func = functools.partial(inserter, self.dest.path) + for interval in self.intervals(): print "Processing", self.interval_string(interval) - with inserter(self.dest.path, - interval.start, interval.end) as insert_ctx: - self.process_numpy_interval(interval, extractor, insert_ctx, - function, args, rows) + self.process_numpy_interval(interval, extractor_func, inserter_func, + function, args, warn_rows = rows * 3) def main(argv = None): # This is just a dummy function; actual filters can use the other