#!/usr/bin/env python3

from nilmdb.utils.printf import printf, sprintf
import nilmdb.client
import nilmtools.filter
import nilmtools.math
from nilmdb.utils.time import timestamp_to_seconds
import datetime_tz
from nilmdb.utils.interval import Interval

import numpy as np
import scipy
import scipy.signal
from numpy.core.umath_tests import inner1d
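# Note: numpy.core.umath_tests is a private NumPy module, and inner1d has
# been removed in recent NumPy releases; np.einsum('ij,ij->i', x, y)
# computes the same row-wise dot products if this import ever fails.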
from collections import OrderedDict
import time
import functools
import collections


class DataError(ValueError):
    pass


def build_column_mapping(colinfo, streaminfo):
    """Given the 'columns' list from the JSON data, verify and
    pull out a dictionary mapping for the column names/numbers."""
    columns = OrderedDict()
    for c in colinfo:
        col_num = c['index'] + 1  # skip timestamp
        if (c['name'] in list(columns.keys()) or
                col_num in list(columns.values())):
            raise DataError("duplicated columns")
        if (c['index'] < 0 or c['index'] >= streaminfo.layout_count):
            raise DataError("bad column number")
        columns[c['name']] = col_num
    if not len(columns):
        raise DataError("no columns")
    return columns
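

# For illustration (hypothetical config fragment, not from the repository):
# a 'columns' list such as
#     [{"name": "P1", "index": 0}, {"name": "Q1", "index": 1}]
# maps to OrderedDict([("P1", 1), ("Q1", 2)]); each index is shifted by one
# because column 0 of the extracted data holds the timestamp.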


class Exemplar(object):
    def __init__(self, exinfo, min_rows=10, max_rows=100000):
        """Given a dictionary entry from the 'exemplars' input JSON,
        verify the stream, columns, etc. Then, fetch all the data
        into self.data."""

        self.name = exinfo['name']
        self.url = exinfo['url']
        self.stream = exinfo['stream']
        self.start = exinfo['start']
        self.end = exinfo['end']
        self.dest_column = exinfo['dest_column']

        # Get stream info
        self.client = nilmdb.client.numpyclient.NumpyClient(self.url)
        self.info = nilmtools.filter.get_stream_info(self.client, self.stream)
        if not self.info:
            raise DataError(sprintf("exemplar stream '%s' does not exist " +
                                    "on server '%s'", self.stream, self.url))

        # Build up name => index mapping for the columns
        self.columns = build_column_mapping(exinfo['columns'], self.info)

        # Count points
        self.count = self.client.stream_count(self.stream,
                                              self.start, self.end)

        # Verify count
        if self.count == 0:
            raise DataError("No data in this exemplar!")
        if self.count < min_rows:
            raise DataError("Too few data points: " + str(self.count))
        if self.count > max_rows:
            raise DataError("Too many data points: " + str(self.count))

        # Extract the data
        datagen = self.client.stream_extract_numpy(self.stream,
                                                   self.start, self.end,
                                                   self.info.layout,
                                                   maxrows=self.count)
        self.data = list(datagen)[0]

        # Extract just the columns that were specified in self.columns,
        # skipping the timestamp.
        extract_cols = [value for (key, value) in list(self.columns.items())]
        self.data = self.data[:, extract_cols]

        # Fix the column indices in self.columns, since we removed/reordered
        # columns in self.data
        for n, k in enumerate(self.columns):
            self.columns[k] = n

        # Subtract the means from each column
        self.data = self.data - self.data.mean(axis=0)

        # Get scale factors for each column by computing dot product
        # of each column with itself.
        self.scale = inner1d(self.data.T, self.data.T)

        # Ensure a minimum (nonzero) scale and convert to list
        self.scale = np.maximum(self.scale, [1e-9]).tolist()
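
        # Note: inner1d of each mean-removed column with itself is that
        # column's sum of squares (equivalent to (self.data ** 2).sum(axis=0));
        # trainola_matcher uses it below to normalize the cross-correlations.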

    def __str__(self):
        return sprintf("\"%s\" %s [%s] %s rows",
                       self.name, self.stream,
                       ",".join(list(self.columns.keys())),
                       self.count)


def timestamp_to_short_human(timestamp):
    dt = datetime_tz.datetime_tz.fromtimestamp(timestamp_to_seconds(timestamp))
    return dt.strftime("%H:%M:%S")


def trainola_matcher(data, interval, args, insert_func, final_chunk):
    """Perform cross-correlation match"""
    (src_columns, dest_count, exemplars) = args
    nrows = data.shape[0]

    # We want at least 10% more points than the widest exemplar.
    widest = max([x.count for x in exemplars])
    if (widest * 1.1) > nrows:
        return 0

    # This is how many points we'll consider valid in the
    # cross-correlation.
    valid = nrows + 1 - widest
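    # (For example, with nrows = 1000 and widest = 100, there are
    # 1000 + 1 - 100 = 901 starting offsets at which the widest exemplar
    # still fits entirely within this chunk.)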
    matches = collections.defaultdict(list)

    # Try matching against each of the exemplars
    for e in exemplars:
        corrs = []

        # Compute cross-correlation for each column
        for col_name in e.columns:
            a = data[:, src_columns[col_name]]
            b = e.data[:, e.columns[col_name]]
            corr = scipy.signal.fftconvolve(a, np.flipud(b), 'valid')[0:valid]

            # Scale by the norm of the exemplar
            corr = corr / e.scale[e.columns[col_name]]
            corrs.append(corr)
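
        # Convolving with the time-reversed exemplar is the same as
        # cross-correlating with the exemplar itself. Since each exemplar
        # column is mean-removed and the divisor is its sum of squares, a
        # window whose shape matches the exemplar yields a value near 1,
        # regardless of any DC offset in the source data.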

        # Find the peaks using the column with the largest amplitude
        biggest = e.scale.index(max(e.scale))
        peaks = nilmtools.math.peak_detect(corrs[biggest], 0.1)

        # To try to reduce false positives, discard peaks where
        # there's a higher-magnitude peak (either min or max) within
        # one exemplar width nearby.
        good_peak_locations = []
        for (i, (n, p, is_max)) in enumerate(peaks):
            if not is_max:
                continue
            ok = True
            # check up to 'e.count' rows before this one
            j = i-1
            while ok and j >= 0 and peaks[j][0] > (n - e.count):
                if abs(peaks[j][1]) > abs(p):
                    ok = False
                j -= 1

            # check up to 'e.count' rows after this one
            j = i+1
            while ok and j < len(peaks) and peaks[j][0] < (n + e.count):
                if abs(peaks[j][1]) > abs(p):
                    ok = False
                j += 1

            if ok:
                good_peak_locations.append(n)

        # Now look at all good peaks
        for row in good_peak_locations:
            # Correlation for each column must be close enough to 1.
            for (corr, scale) in zip(corrs, e.scale):
                # The accepted distance from 1 is based on the relative
                # amplitude of the column. Use a linear mapping:
                #   scale 1.0 -> distance 0.1
                #   scale 0.0 -> distance 1.0
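                # For example, a column at half the biggest column's scale
                # gets distance = 1 - 0.9 * 0.5 = 0.55.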
                distance = 1 - 0.9 * (scale / e.scale[biggest])
                if abs(corr[row] - 1) > distance:
                    # No match
                    break
            else:
                # Successful match
                matches[row].append(e)

    # Insert matches into destination stream.
    matched_rows = sorted(matches.keys())
    out = np.zeros((len(matched_rows), dest_count + 1))

    for n, row in enumerate(matched_rows):
        # Fill timestamp
        out[n][0] = data[row, 0]

        # Mark matched exemplars
        for exemplar in matches[row]:
            out[n, exemplar.dest_column + 1] = 1.0

    # Insert it
    insert_func(out)

    # Return how many rows we processed
    valid = max(valid, 0)
    printf(" [%s] matched %d exemplars in %d rows\n",
           timestamp_to_short_human(data[0][0]), np.sum(out[:, 1:]), valid)
    return valid


def trainola(conf):
    print("Trainola", nilmtools.__version__)

    # Load main stream data
    url = conf['url']
    src_path = conf['stream']
    dest_path = conf['dest_stream']
    start = conf['start']
    end = conf['end']

    # Get info for the src and dest streams
    src_client = nilmdb.client.numpyclient.NumpyClient(url)
    src = nilmtools.filter.get_stream_info(src_client, src_path)
    if not src:
        raise DataError("source path '" + src_path + "' does not exist")
    src_columns = build_column_mapping(conf['columns'], src)

    dest_client = nilmdb.client.numpyclient.NumpyClient(url)
    dest = nilmtools.filter.get_stream_info(dest_client, dest_path)
    if not dest:
        raise DataError("destination path '" + dest_path + "' does not exist")

    printf("Source:\n")
    printf(" %s [%s]\n", src.path, ",".join(list(src_columns.keys())))
    printf("Destination:\n")
    printf(" %s (%s columns)\n", dest.path, dest.layout_count)

    # Pull in the exemplar data
    exemplars = []
    if 'exemplars' not in conf:
        raise DataError("missing exemplars")
    for n, exinfo in enumerate(conf['exemplars']):
        printf("Loading exemplar %d:\n", n)
        e = Exemplar(exinfo)
        col = e.dest_column
        if col < 0 or col >= dest.layout_count:
            raise DataError(sprintf("bad destination column number %d\n" +
                                    "dest stream only has 0 through %d",
                                    col, dest.layout_count - 1))
        printf(" %s, output column %d\n", str(e), col)
        exemplars.append(e)
    if len(exemplars) == 0:
        raise DataError("missing exemplars")

    # Verify that the exemplar columns are all represented in the main data
    for n, ex in enumerate(exemplars):
        for col in ex.columns:
            if col not in src_columns:
                raise DataError(sprintf("Exemplar %d column %s is not "
                                        "available in source data", n, col))

    # Figure out which intervals we should process
    intervals = (Interval(s, e) for (s, e) in
                 src_client.stream_intervals(src_path,
                                             diffpath=dest_path,
                                             start=start, end=end))
    intervals = nilmdb.utils.interval.optimize(intervals)
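    # (stream_intervals with diffpath returns the intervals present in the
    # source stream but not yet in the destination, so already-processed
    # data is skipped; optimize() merges adjoining intervals.)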

    # Do the processing
    rows = 100000
    extractor = functools.partial(src_client.stream_extract_numpy,
                                  src.path, layout=src.layout, maxrows=rows)
    inserter = functools.partial(dest_client.stream_insert_numpy_context,
                                 dest.path)
    start = time.time()
    processed_time = 0
    printf("Processing intervals:\n")
    for interval in intervals:
        printf("%s\n", interval.human_string())
        nilmtools.filter.process_numpy_interval(
            interval, extractor, inserter, rows * 3,
            trainola_matcher, (src_columns, dest.layout_count, exemplars))
        processed_time += (timestamp_to_seconds(interval.end) -
                           timestamp_to_seconds(interval.start))
    elapsed = max(time.time() - start, 1e-3)

    printf("Done. Processed %.2f seconds per second.\n",
           processed_time / elapsed)
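

# Example configuration for trainola(), illustrative only: the URL, paths,
# column names, and timestamps below are made up, but the keys are the ones
# read by trainola() and Exemplar() above.
#
# {
#     "url": "http://localhost/nilmdb",
#     "stream": "/example/prep",
#     "dest_stream": "/example/matches",
#     "start": 1364184839000000,
#     "end": 1364184942000000,
#     "columns": [{"name": "P1", "index": 0}],
#     "exemplars": [
#         {
#             "name": "Pump turn-on",
#             "url": "http://localhost/nilmdb",
#             "stream": "/example/prep",
#             "start": 1364184839000000,
#             "end": 1364184841000000,
#             "dest_column": 0,
#             "columns": [{"name": "P1", "index": 0}]
#         }
#     ]
# }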


def main(argv=None):
    import json
    import sys

    if argv is None:
        argv = sys.argv[1:]
    if len(argv) != 1 or argv[0] == '-h' or argv[0] == '--help':
        printf("usage: %s [-h] [-v] <json-config-dictionary>\n\n", sys.argv[0])
        printf(" Where <json-config-dictionary> is a JSON-encoded " +
               "dictionary string\n")
        printf(" with exemplar and stream data.\n\n")
        printf(" See extras/trainola-test-param*.js in the nilmtools " +
               "repository\n")
        printf(" for examples.\n")
        if len(argv) != 1:
            raise SystemExit(1)
        raise SystemExit(0)

    if argv[0] == '-v' or argv[0] == '--version':
        printf("%s\n", nilmtools.__version__)
        raise SystemExit(0)

    try:
        # Passed in a JSON string (e.g. on the command line)
        conf = json.loads(argv[0])
    except TypeError:
        # Passed in the config dictionary (e.g. from NilmRun)
        conf = argv[0]

    return trainola(conf)


if __name__ == "__main__":
    main()