
Improve split between process_numpy and process_numpy_interval

tags/nilmtools-1.3.1
Jim Paris · 11 years ago · commit c9c2e0d5a8
1 changed file with 60 additions and 49 deletions:

    nilmtools/filter.py  (+60, -49)

@@ -19,6 +19,7 @@ import re
 import argparse
 import numpy as np
 import cStringIO
+import functools
 
 class ArgumentError(Exception):
     pass
@@ -255,61 +256,68 @@ class Filter(object):
         self._client_dest.stream_update_metadata(self.dest.path, data)
 
     # Filter processing for a single interval of data.
-    def process_numpy_interval(self, interval, extractor, insert_ctx,
-                               function, args = None, rows = 100000):
+    def process_numpy_interval(self, interval, extractor, inserter,
+                               function, args = None, warn_rows = None):
         """For the given 'interval' of data, extract data, process it
         through 'function', and insert the result.
 
         'extractor' should be a function like NumpyClient.stream_extract_numpy
-        'insert_ctx' should be a class like StreamInserterNumpy, with member
-        functions 'insert', 'send', and 'update_end'.
+        but with the interval 'start' and 'end' as the only parameters,
+        e.g.:
+          extractor = functools.partial(NumpyClient.stream_extract_numpy,
+                                        src_path, layout = l, maxrows = m)
 
-        See process_numpy for details on 'function', 'args', and 'rows'.
+        'inserter' should be a function like NumpyClient.stream_insert_context
+        but with the interval 'start' and 'end' as the only parameters, e.g.:
+          inserter = functools.partial(NumpyClient.stream_insert_context,
+                                       dest_path)
+
+        If 'warn_rows' is not None, print a warning to stdout when the
+        number of unprocessed rows exceeds this amount.
+
+        See process_numpy for details on 'function' and 'args'.
         """
         if args is None:
             args = []
 
-        insert_function = insert_ctx.insert
-        old_array = np.array([])
-        for new_array in extractor(self.src.path,
-                                   interval.start, interval.end,
-                                   layout = self.src.layout,
-                                   maxrows = rows):
-            # If we still had old data left, combine it
+        with inserter(interval.start, interval.end) as insert_ctx:
+            insert_func = insert_ctx.insert
+            old_array = np.array([])
+            for new_array in extractor(interval.start, interval.end):
+                # If we still had old data left, combine it
+                if old_array.shape[0] != 0:
+                    array = np.vstack((old_array, new_array))
+                else:
+                    array = new_array
+
+                # Pass the data to the user provided function
+                processed = function(array, interval, args, insert_func, False)
+
+                # Send any pending data that the user function inserted
+                insert_ctx.send()
+
+                # Save the unprocessed parts
+                if processed >= 0:
+                    old_array = array[processed:]
+                else:
+                    raise Exception(
+                        sprintf("%s return value %s must be >= 0",
+                                str(function), str(processed)))
+
+                # Warn if there's too much data remaining
+                if warn_rows is not None and old_array.shape[0] > warn_rows:
+                    printf("warning: %d unprocessed rows in buffer\n",
+                           old_array.shape[0])
+
+            # Last call for this contiguous interval
             if old_array.shape[0] != 0:
-                array = np.vstack((old_array, new_array))
-            else:
-                array = new_array
-
-            # Pass it to the process function
-            processed = function(array, interval, args,
-                                 insert_function, False)
-
-            # Send any pending data
-            insert_ctx.send()
-
-            # Save the unprocessed parts
-            if processed >= 0:
-                old_array = array[processed:]
-            else:
-                raise Exception(
-                    sprintf("%s return value %s must be >= 0",
-                            str(function), str(processed)))
-
-            # Warn if there's too much data remaining
-            if old_array.shape[0] > 3 * rows:
-                printf("warning: %d unprocessed rows in buffer\n",
-                       old_array.shape[0])
-
-        # Last call for this contiguous interval
-        if old_array.shape[0] != 0:
-            processed = function(old_array, interval, args,
-                                 insert_function, True)
-            if processed != old_array.shape[0]:
-                # Truncate the interval we're inserting at the first
-                # unprocessed data point.  This ensures that
-                # we'll not miss any data when we run again later.
-                insert_ctx.update_end(old_array[processed][0])
+                processed = function(old_array, interval, args,
+                                     insert_func, True)
+                if processed != old_array.shape[0]:
+                    # Truncate the interval we're inserting at the first
+                    # unprocessed data point.  This ensures that
+                    # we'll not miss any data when we run again later.
+                    insert_ctx.update_end(old_array[processed][0])
 
     # The main filter processing method.
     def process_numpy(self, function, args = None, rows = 100000):
@@ -349,12 +357,15 @@ class Filter(object):
         extractor = NumpyClient(self.src.url).stream_extract_numpy
         inserter = NumpyClient(self.dest.url).stream_insert_numpy_context
 
+        extractor_func = functools.partial(extractor, self.src.path,
+                                           layout = self.src.layout,
+                                           maxrows = rows)
+        inserter_func = functools.partial(inserter, self.dest.path)
+
         for interval in self.intervals():
             print "Processing", self.interval_string(interval)
-            with inserter(self.dest.path,
-                          interval.start, interval.end) as insert_ctx:
-                self.process_numpy_interval(interval, extractor, insert_ctx,
-                                            function, args, rows)
+            self.process_numpy_interval(interval, extractor_func, inserter_func,
+                                        function, args, warn_rows = rows * 3)
 
 def main(argv = None):
     # This is just a dummy function; actual filters can use the other
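
For reference, the contract that both process_numpy and process_numpy_interval place on the user-supplied 'function', as seen in the loop above: it is called with the extracted numpy array (timestamps in column 0), the current interval, the caller's 'args' list, an insert function bound to the active insert context, and a 'final' flag that is True only on the last call for a contiguous interval; it should insert whatever it consumed and return the number of rows it processed, which must be >= 0. A minimal sketch of such a callback follows; the name 'scale_data' and the scale factor are made up for illustration and are not part of this commit.

    def scale_data(data, interval, args, insert_function, final):
        # 'data' is a 2D numpy array as produced by stream_extract_numpy:
        # column 0 holds timestamps, the remaining columns hold the values.
        (factor,) = args
        out = data.copy()
        out[:, 1:] *= factor        # scale every non-timestamp column
        insert_function(out)        # queue the result on the insert context
        return data.shape[0]        # report that every row was consumed

Returning fewer rows than data.shape[0] leaves the tail buffered and stacked onto the next chunk, which is what the warn_rows check above guards against.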
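
The caller-side pattern from the new docstring, sketched below: bind everything except the interval endpoints with functools.partial, so that process_numpy_interval only ever sees 'start' and 'end'. This mirrors what process_numpy itself now does. The helper name and the assumption that 'f' is an already-configured Filter instance are hypothetical, 'scale_data' is the callback sketched above, and the NumpyClient import path is assumed from nilmdb.

    import functools
    from nilmdb.client.numpyclient import NumpyClient  # assumed import path

    def process_one_stream(f, rows = 100000):
        # 'f' is assumed to be a configured nilmtools.filter.Filter, so
        # f.src, f.dest, and f.intervals() are available as in the code above.
        extractor = functools.partial(
            NumpyClient(f.src.url).stream_extract_numpy,
            f.src.path, layout = f.src.layout, maxrows = rows)
        inserter = functools.partial(
            NumpyClient(f.dest.url).stream_insert_numpy_context,
            f.dest.path)
        for interval in f.intervals():
            f.process_numpy_interval(interval, extractor, inserter,
                                     scale_data, args = [2.0],
                                     warn_rows = rows * 3)

Passing warn_rows = rows * 3 reproduces the old hard-coded "3 * rows" threshold that this commit turns into an explicit parameter.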

