@@ -14,47 +14,49 @@ import scipy.signal
from numpy.core.umath_tests import inner1d
import nilmrun
from collections import OrderedDict
import sys


class DataError(ValueError):
    pass


def build_column_mapping(colinfo, streaminfo):
    """Given the 'columns' list from the JSON data, verify and
    pull out a dictionary mapping for the column names/numbers."""
    columns = OrderedDict()
    for c in colinfo:
        if (c['name'] in columns.keys() or
                c['index'] in columns.values()):
            raise DataError("duplicated columns")
        if (c['index'] < 0 or c['index'] >= streaminfo.layout_count):
            raise DataError("bad column number")
        columns[c['name']] = c['index']
    if not len(columns):
        raise DataError("no columns")
    return columns
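

# A usage sketch with hypothetical values: given column info parsed
# from the input JSON, e.g.
#     colinfo = [ { 'name': 'P1', 'index': 0 },
#                 { 'name': 'Q1', 'index': 1 } ]
# and stream info whose layout_count is >= 2, build_column_mapping
# returns OrderedDict([('P1', 0), ('Q1', 1)]).  Duplicate names or
# indices, out-of-range indices, and an empty list all raise DataError.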
class Exemplar(object):
    def __init__(self, exinfo, min_rows = 10, max_rows = 100000):
        """Given a dictionary entry from the 'exemplars' input JSON,
        verify the stream, columns, etc.  Then, fetch all the data
        into self.data."""

        self.name = exinfo['name']
        self.url = exinfo['url']
        self.stream = exinfo['stream']
        self.start = exinfo['start']
        self.end = exinfo['end']
        self.dest_column = exinfo['dest_column']

        # Get stream info
        self.client = nilmdb.client.numpyclient.NumpyClient(self.url)
        self.info = nilmtools.filter.get_stream_info(self.client, self.stream)
        if not self.info:
            # Mirrors the src/dest existence checks in trainola() below
            raise DataError("exemplar stream '" + self.stream +
                            "' does not exist")

        # Build up name => index mapping for the columns
        self.columns = build_column_mapping(exinfo['columns'], self.info)

        # Count points
        self.count = self.client.stream_count(self.stream,
                                              self.start, self.end)

        # Fetch the data, as the docstring promises; min_rows/max_rows
        # bound how much data an exemplar may contain.
        self.fetch(min_rows, max_rows)

    def fetch(self, min_rows = 10, max_rows = 100000):
        """Fetch all the data into self.data.  This is intended for
        exemplars, and can only handle a relatively small number of
        rows."""
        # Verify count
        if self.count == 0:
            raise DataError("No data in this exemplar!")
@@ -83,39 +85,10 @@ class Data(object):
        # Ensure a minimum (nonzero) scale and convert to list
        self.scale = np.maximum(self.scale, [1e-9]).tolist()
    def __str__(self):
        return sprintf("\"%s\" %s [%s] %s rows",
                       self.name, self.stream,
                       ",".join(self.columns.keys()),
                       self.count)
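

# Each 'exemplars' entry in the input JSON is a dictionary shaped like
# the following sketch (names, paths, and timestamps are hypothetical):
#     { "name": "pump turn-on",
#       "url": "http://localhost/nilmdb/",
#       "stream": "/test/raw",
#       "start": 1234567890000000,
#       "end": 1234567899000000,
#       "dest_column": 0,
#       "columns": [ { "name": "P1", "index": 0 } ] }
# Exemplar(entry) verifies the stream and columns, then fetch() pulls
# the rows into .data.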
def peak_detect(data, delta):
    """Simple min/max peak detection algorithm, taken from my code
@@ -142,7 +115,7 @@ def peak_detect(data, delta):
                lookformax = True
    return (mins, maxs)
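

# Rough usage sketch (assuming the mins/maxs lists hold (index, value)
# pairs, the usual convention for this kind of delta-based detector):
#     (mins, maxs) = peak_detect(correlation, 0.1)
# 'delta' sets how far the signal must fall below a running maximum
# (or rise above a running minimum) before that extremum counts as a
# peak.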
def trainola_matcher(data, args):
    """Perform cross-correlation match"""
    ( columns, exemplars ) = args
    nrows = data.shape[0]
@@ -209,51 +182,135 @@ def match(data, args):
    return max(valid, 0)
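

# trainola_matcher follows the chunk-processing contract of the
# (currently commented-out) process() driver below: it is handed a
# block of rows plus (columns, exemplars) as args, and returns how
# many leading rows it fully consumed; the remainder is re-presented
# together with the next chunk.  For illustration only, a hypothetical
# function with the same contract:
#     def null_matcher(data, args):
#         return data.shape[0]    # consume everything, match nothing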
def trainola(conf):
    print "Trainola", nilmtools.__version__

    url = conf['url']
    src_path = conf['stream']
    dest_path = conf['dest_stream']
    start = conf['start']
    end = conf['end']

    # Get info for the src and dest streams
    src_client = nilmdb.client.numpyclient.NumpyClient(url)
    src = nilmtools.filter.get_stream_info(src_client, src_path)
    if not src:
        raise DataError("source path '" + src_path + "' does not exist")
    src_columns = build_column_mapping(conf['columns'], src)

    dest_client = nilmdb.client.numpyclient.NumpyClient(url)
    dest = nilmtools.filter.get_stream_info(dest_client, dest_path)
    if not dest:
        raise DataError("destination path '" + dest_path + "' does not exist")

printf("Source:\n") |
|
|
|
printf(" %s [%s]\n", src.path, ",".join(src_columns.keys())) |
|
|
|
printf("Destination:\n") |
|
|
|
printf(" %s (%s columns)\n", dest.path, dest.layout_count) |
|
|
|
|
|
|
|
    # Pull in the exemplar data
    exemplars = []
    for n, exinfo in enumerate(conf['exemplars']):
        printf("Loading exemplar %d:\n", n)
        e = Exemplar(exinfo)
        col = e.dest_column
        if col < 0 or col >= dest.layout_count:
            raise DataError(sprintf("bad destination column number %d\n" +
                                    "dest stream only has 0 through %d",
                                    col, dest.layout_count - 1))
        printf("  %s, output column %d\n", str(e), col)
        exemplars.append(e)
    if len(exemplars) == 0:
        raise DataError("missing exemplars")

    # Verify that the exemplar columns are all represented in the
    # source data
    for n, ex in enumerate(exemplars):
        for col in ex.columns:
            if col not in src_columns:
                raise DataError(sprintf("Exemplar %d column %s is not "
                                        "available in source data", n, col))

    # Process the data in a piecewise manner

    # # See which intervals we should process
    # intervals = ( Interval(start, end)
    #               for (start, end) in
    #               self._client_src.stream_intervals(
    #                   self.src.path, diffpath = self.dest.path,
    #                   start = self.start, end = self.end) )
    # # Optimize intervals: join intervals that are adjacent
    # for interval in self.optimize_intervals(intervals):
    #     yield interval

    # def process(main, function, args = None, rows = 200000):
    #     """Process through the data; similar to
    #     nilmtools.Filter.process_numpy"""
    #     if args is None:
    #         args = []
    #
    #     extractor = main.client.stream_extract_numpy
    #     old_array = np.array([])
    #     for new_array in extractor(main.stream, main.start, main.end,
    #                                layout = main.info.layout,
    #                                maxrows = rows):
    #         # If we still had old data left, combine it
    #         if old_array.shape[0] != 0:
    #             array = np.vstack((old_array, new_array))
    #         else:
    #             array = new_array
    #
    #         # Process it
    #         processed = function(array, args)
    #
    #         # Save the unprocessed parts
    #         if processed >= 0:
    #             old_array = array[processed:]
    #         else:
    #             raise Exception(sprintf("%s return value %s must be >= 0",
    #                                     str(function), str(processed)))
    #
    #         # Warn if there's too much data remaining
    #         if old_array.shape[0] > 3 * rows:
    #             printf("warning: %d unprocessed rows in buffer\n",
    #                    old_array.shape[0])
    #
    #     # Handle leftover data; note that this must go through
    #     # 'old_array', not 'array', or the final chunk is dropped
    #     if old_array.shape[0] != 0:
    #         processed = function(old_array, args)
    # process(src, src_client, dest, dest_client,
    #         trainola_matcher, (src_columns, exemplars))

return "done" |
|
|
|
|
|
|
|
filterfunc = trainola |
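

# The 'conf' dictionary that trainola() expects looks roughly like this
# sketch (all values hypothetical; 'exemplars' entries are shaped as in
# the sketch after the Exemplar class above):
#     { "url": "http://localhost/nilmdb/",
#       "stream": "/test/raw",
#       "dest_stream": "/test/matches",
#       "start": 1234567890000000,
#       "end": 1234567899000000,
#       "columns": [ { "name": "P1", "index": 0 } ],
#       "exemplars": [ ... ] }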
def main(argv = None):
    import simplejson as json
    import argparse

    if argv is None:
        argv = sys.argv[1:]

    # If the first parameter is a dictionary (passed in by direct call),
    # don't bother parsing it as a JSON string.
    if len(argv) == 1 and isinstance(argv[0], dict):
        return trainola(argv[0])

    # Parse command line arguments as text
    parser = argparse.ArgumentParser(
        formatter_class = argparse.RawDescriptionHelpFormatter,
        version = nilmtools.__version__,
        description = """Run Trainola using parameters passed in as
        JSON-formatted data.""")
    parser.add_argument("data", metavar="DATA",
                        help="Arguments, formatted as a JSON string")
    args = parser.parse_args(argv)

    conf = json.loads(args.data)
    return trainola(conf)

if __name__ == "__main__":
    main()
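
# Example invocation (the console-script name "nilm-trainola" is an
# assumption; any wrapper that passes the JSON string as the single
# argument, or calls main([conf_dict]) directly, works):
#     nilm-trainola '{ "url": "http://localhost/nilmdb/", ... }'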