nilmtools/nilmtools/trainola.py
Jim Paris cfc66b6847 Fix flake8 errors throughout code
This found a small number of real bugs too, for example,
this one that looked weird because of a 2to3 conversion,
but was wrong both before and after:
-        except IndexError as TypeError:
+        except (IndexError, TypeError):
2020-08-06 17:58:41 -04:00

326 lines
12 KiB
Python
Executable File

#!/usr/bin/env python3
from nilmdb.utils.printf import printf, sprintf
import nilmdb.client
import nilmtools.filter
import nilmtools.math
from nilmdb.utils.time import timestamp_to_seconds
import datetime_tz
from nilmdb.utils.interval import Interval
import numpy as np
import scipy
import scipy.signal
from numpy.core.umath_tests import inner1d
from collections import OrderedDict
import time
import functools
import collections
class DataError(ValueError):
    """Raised when the supplied configuration or stream data is unusable."""
def build_column_mapping(colinfo, streaminfo):
    """Given the 'columns' list from the JSON data, verify and
    pull out a dictionary mapping for the column names/numbers.

    Args:
        colinfo: iterable of dicts, each providing a 'name' and a
            zero-based 'index' entry.
        streaminfo: object exposing ``layout_count``; an index is
            accepted only if ``0 <= index < layout_count``.

    Returns:
        An OrderedDict mapping each column name to ``index + 1`` (the
        offset skips the timestamp column), in the order given.

    Raises:
        DataError: if a name or column number is repeated, if an index
            is out of range, or if no columns were supplied.
    """
    columns = OrderedDict()
    for c in colinfo:
        col_num = c['index'] + 1  # skip timestamp
        # Test membership directly on the dict and its values view
        # rather than copying them into new lists on every iteration.
        if c['name'] in columns or col_num in columns.values():
            raise DataError("duplicated columns")
        if c['index'] < 0 or c['index'] >= streaminfo.layout_count:
            raise DataError("bad column number")
        columns[c['name']] = col_num
    if not columns:
        raise DataError("no columns")
    return columns
class Exemplar:
    """One exemplar waveform loaded from a stream on a NilmDB server.

    After construction the instance carries:
      name, url, stream, start, end, dest_column -- copied from the
          input dictionary
      client  -- the NumpyClient used to fetch the data
      info    -- stream info for ``stream``
      columns -- OrderedDict of column name => column index in ``data``
      count   -- number of rows reported for [start, end)
      data    -- 2-D array holding only the selected columns, with the
                 per-column mean removed
      scale   -- list with one scale factor per column of ``data``
    """

    def __init__(self, exinfo, min_rows=10, max_rows=100000):
        """Given a dictionary entry from the 'exemplars' input JSON,
        verify the stream, columns, etc.  Then, fetch all the data
        into self.data.

        Raises DataError if the stream does not exist, the column
        list is invalid, or the row count is zero, below ``min_rows``
        or above ``max_rows``.
        """
        self.name = exinfo['name']
        self.url = exinfo['url']
        self.stream = exinfo['stream']
        self.start = exinfo['start']
        self.end = exinfo['end']
        self.dest_column = exinfo['dest_column']

        # Get stream info
        self.client = nilmdb.client.numpyclient.NumpyClient(self.url)
        self.info = nilmtools.filter.get_stream_info(self.client, self.stream)
        if not self.info:
            raise DataError(sprintf("exemplar stream '%s' does not exist " +
                                    "on server '%s'", self.stream, self.url))

        # Build up name => index mapping for the columns
        self.columns = build_column_mapping(exinfo['columns'], self.info)

        # Count points
        self.count = self.client.stream_count(self.stream,
                                              self.start, self.end)

        # Verify count
        if self.count == 0:
            raise DataError("No data in this exemplar!")
        if self.count < min_rows:
            raise DataError("Too few data points: " + str(self.count))
        if self.count > max_rows:
            raise DataError("Too many data points: " + str(self.count))

        # Extract the data; only the first array produced by the
        # extractor is kept.
        datagen = self.client.stream_extract_numpy(self.stream,
                                                   self.start, self.end,
                                                   self.info.layout,
                                                   maxrows=self.count)
        self.data = list(datagen)[0]

        # Extract just the columns that were specified in self.columns,
        # skipping the timestamp.
        extract_cols = list(self.columns.values())
        self.data = self.data[:, extract_cols]

        # Fix the column indices in self.columns, since we
        # removed/reordered columns in self.data
        for n, k in enumerate(self.columns):
            self.columns[k] = n

        # Subtract the means from each column
        self.data = self.data - self.data.mean(axis=0)

        # Get scale factors for each column by computing the dot
        # product of each column with itself.  This uses the public
        # einsum API instead of the internal inner1d test helper.
        self.scale = np.einsum('ij,ij->j', self.data, self.data)

        # Ensure a minimum (nonzero) scale and convert to list
        self.scale = np.maximum(self.scale, [1e-9]).tolist()

    def __str__(self):
        """Return a one-line summary: name, stream, columns, row count."""
        return sprintf("\"%s\" %s [%s] %s rows",
                       self.name, self.stream,
                       ",".join(self.columns),
                       self.count)
def timestamp_to_short_human(timestamp):
    """Render *timestamp* as an ``HH:MM:SS`` string.

    The value is first passed through timestamp_to_seconds() and then
    turned into a datetime_tz object for formatting.
    """
    seconds = timestamp_to_seconds(timestamp)
    moment = datetime_tz.datetime_tz.fromtimestamp(seconds)
    return moment.strftime("%H:%M:%S")
def trainola_matcher(data, interval, args, insert_func, final_chunk):
    """Perform cross-correlation match of one data chunk against exemplars.

    Args:
        data: 2-D array of source rows; column 0 is used as the
            timestamp and the remaining columns are looked up through
            the source column mapping.
        interval: not used by this function.
        args: tuple ``(src_columns, dest_count, exemplars)``.
        insert_func: callable that receives the output array (one row
            per matched location: timestamp followed by ``dest_count``
            flag columns).
        final_chunk: not used by this function.

    Returns:
        The number of rows considered processed, or 0 when the chunk
        is too short to be matched against the widest exemplar.
    """
    src_columns, dest_count, exemplars = args
    total_rows = data.shape[0]

    # We want at least 10% more points than the widest exemplar.
    widest = max(ex.count for ex in exemplars)
    if widest * 1.1 > total_rows:
        return 0

    # This is how many points we'll consider valid in the
    # cross-correlation.
    valid = total_rows + 1 - widest

    # row number => list of exemplars that matched there
    matches = collections.defaultdict(list)

    # Try matching against each of the exemplars
    for ex in exemplars:
        # Compute the cross-correlation for each column, scaled by the
        # norm of the exemplar column.
        corrs = []
        for col_name, ex_col in ex.columns.items():
            signal = data[:, src_columns[col_name]]
            template = ex.data[:, ex_col]
            xcorr = scipy.signal.fftconvolve(
                signal, np.flipud(template), 'valid')[0:valid]
            corrs.append(xcorr / ex.scale[ex_col])

        # Find the peaks using the column with the largest amplitude
        biggest = ex.scale.index(max(ex.scale))
        peaks = nilmtools.math.peak_detect(corrs[biggest], 0.1)

        # To try to reduce false positives, discard peaks where
        # there's a higher-magnitude peak (either min or max) within
        # one exemplar width nearby.
        good_peak_locations = []
        for idx, (pos, height, is_max) in enumerate(peaks):
            if not is_max:
                continue
            magnitude = abs(height)
            dominated = False

            # Walk backwards through peaks less than 'ex.count' rows
            # before this one.
            k = idx - 1
            while not dominated and k >= 0 and peaks[k][0] > (pos - ex.count):
                dominated = abs(peaks[k][1]) > magnitude
                k -= 1

            # Walk forwards through peaks less than 'ex.count' rows
            # after this one.
            k = idx + 1
            while (not dominated and k < len(peaks)
                   and peaks[k][0] < (pos + ex.count)):
                dominated = abs(peaks[k][1]) > magnitude
                k += 1

            if not dominated:
                good_peak_locations.append(pos)

        # Now look at all good peaks.  Correlation for each column
        # must be close enough to 1.  The accepted distance from 1 is
        # based on the relative amplitude of the column, using a
        # linear mapping:
        #   scale 1.0 -> distance 0.1
        #   scale 0.0 -> distance 1.0
        for row in good_peak_locations:
            too_far = any(
                abs(corr[row] - 1) > 1 - 0.9 * (scale / ex.scale[biggest])
                for (corr, scale) in zip(corrs, ex.scale))
            if not too_far:
                # Successful match
                matches[row].append(ex)

    # Insert matches into destination stream.
    matched_rows = sorted(matches)
    out = np.zeros((len(matched_rows), dest_count + 1))
    for n, row in enumerate(matched_rows):
        # Fill timestamp
        out[n, 0] = data[row, 0]
        # Mark matched exemplars
        for exemplar in matches[row]:
            out[n, exemplar.dest_column + 1] = 1.0

    # Insert it
    insert_func(out)

    # Return how many rows we processed
    valid = max(valid, 0)
    printf(" [%s] matched %d exemplars in %d rows\n",
           timestamp_to_short_human(data[0][0]), np.sum(out[:, 1:]), valid)
    return valid
def trainola(conf):
    """Run exemplar matching as described by the configuration dict.

    ``conf`` supplies 'url', 'stream', 'dest_stream', 'start', 'end',
    'columns' and 'exemplars'.  Every exemplar is loaded and checked
    against the source and destination streams, then each interval
    reported by the server for the source stream (with the destination
    as ``diffpath``) is run through trainola_matcher().

    Raises DataError when a stream is missing, no exemplars are given,
    an exemplar's destination column is out of range, or an exemplar
    uses a column that the source mapping lacks.
    """
    print("Trainola", nilmtools.__version__)

    # Load main stream data
    url = conf['url']
    src_path = conf['stream']
    dest_path = conf['dest_stream']
    start = conf['start']
    end = conf['end']

    # Get info for the src and dest streams
    src_client = nilmdb.client.numpyclient.NumpyClient(url)
    src = nilmtools.filter.get_stream_info(src_client, src_path)
    if not src:
        raise DataError("source path '" + src_path + "' does not exist")
    src_columns = build_column_mapping(conf['columns'], src)

    dest_client = nilmdb.client.numpyclient.NumpyClient(url)
    dest = nilmtools.filter.get_stream_info(dest_client, dest_path)
    if not dest:
        raise DataError("destination path '" + dest_path + "' does not exist")

    printf("Source:\n")
    printf(" %s [%s]\n", src.path, ",".join(src_columns))
    printf("Destination:\n")
    printf(" %s (%s columns)\n", dest.path, dest.layout_count)

    # Pull in the exemplar data
    if 'exemplars' not in conf:
        raise DataError("missing exemplars")
    exemplars = []
    for n, exinfo in enumerate(conf['exemplars']):
        printf("Loading exemplar %d:\n", n)
        exemplar = Exemplar(exinfo)
        col = exemplar.dest_column
        if col < 0 or col >= dest.layout_count:
            raise DataError(sprintf("bad destination column number %d\n" +
                                    "dest stream only has 0 through %d",
                                    col, dest.layout_count - 1))
        printf(" %s, output column %d\n", str(exemplar), col)
        exemplars.append(exemplar)
    if not exemplars:
        raise DataError("missing exemplars")

    # Verify that the exemplar columns are all represented in the main data
    for n, ex in enumerate(exemplars):
        for col in ex.columns:
            if col not in src_columns:
                raise DataError(sprintf("Exemplar %d column %s is not "
                                        "available in source data", n, col))

    # Figure out which intervals we should process
    raw_intervals = src_client.stream_intervals(src_path,
                                                diffpath=dest_path,
                                                start=start, end=end)
    intervals = nilmdb.utils.interval.optimize(
        Interval(s, e) for (s, e) in raw_intervals)

    # Do the processing
    rows = 100000
    extractor = functools.partial(src_client.stream_extract_numpy,
                                  src.path, layout=src.layout, maxrows=rows)
    inserter = functools.partial(dest_client.stream_insert_numpy_context,
                                 dest.path)
    matcher_args = (src_columns, dest.layout_count, exemplars)

    # Wall-clock start, used only for the throughput report below
    began = time.time()
    processed_time = 0
    printf("Processing intervals:\n")
    for interval in intervals:
        printf("%s\n", interval.human_string())
        nilmtools.filter.process_numpy_interval(
            interval, extractor, inserter, rows * 3,
            trainola_matcher, matcher_args)
        processed_time += (timestamp_to_seconds(interval.end) -
                           timestamp_to_seconds(interval.start))

    # Clamp the elapsed time so the division below cannot hit zero
    elapsed = max(time.time() - began, 1e-3)
    printf("Done. Processed %.2f seconds per second.\n",
           processed_time / elapsed)
def main(argv=None):
    """Command-line entry point.

    ``argv`` defaults to ``sys.argv[1:]`` and must hold exactly one
    item: '-h'/'--help', '-v'/'--version', a JSON-encoded configuration
    string, or an already-decoded configuration dictionary.

    Exits with status 1 (after printing usage) when the argument count
    is wrong and with status 0 after printing help or the version;
    otherwise returns the result of trainola().
    """
    import json
    import sys

    if argv is None:
        argv = sys.argv[1:]

    wrong_count = len(argv) != 1
    if wrong_count or argv[0] in ('-h', '--help'):
        printf("usage: %s [-h] [-v] <json-config-dictionary>\n\n", sys.argv[0])
        printf(" Where <json-config-dictionary> is a JSON-encoded "
               "dictionary string\n")
        printf(" with exemplar and stream data.\n\n")
        printf(" See extras/trainola-test-param*.js in the nilmtools "
               "repository\n")
        printf(" for examples.\n")
        # A wrong argument count is an error; an explicit help request
        # is not.
        raise SystemExit(1 if wrong_count else 0)

    if argv[0] in ('-v', '--version'):
        printf("%s\n", nilmtools.__version__)
        raise SystemExit(0)

    config_arg = argv[0]
    try:
        # Passed in a JSON string (e.g. on the command line)
        conf = json.loads(config_arg)
    except TypeError:
        # Passed in the config dictionary (e.g. from NilmRun)
        conf = config_arg

    return trainola(conf)
# Run the command-line entry point when this file is executed as a script.
if __name__ == "__main__":
    main()