Add trainola from nilmrun
This commit is contained in:
parent cfd1719152
commit 706c3933f9
nilmtools/trainola.py (new executable file, 260 lines added)

@@ -0,0 +1,260 @@
#!/usr/bin/python

from nilmdb.utils.printf import *
import nilmdb.client
import nilmtools.filter
from nilmdb.utils.time import (timestamp_to_human,
                               timestamp_to_seconds,
                               seconds_to_timestamp)
from nilmdb.utils.interval import Interval

import numpy as np
import scipy
import scipy.signal
from numpy.core.umath_tests import inner1d
import nilmrun
from collections import OrderedDict

class DataError(ValueError):
    pass

class Data(object):
    def __init__(self, name, url, stream, start, end, columns):
        """Initialize, get stream info, check columns"""
        self.name = name
        self.url = url
        self.stream = stream
        self.start = start
        self.end = end

        # Get stream info
        self.client = nilmdb.client.numpyclient.NumpyClient(url)
        self.info = nilmtools.filter.get_stream_info(self.client, stream)

        # Build up name => index mapping for the columns
        self.columns = OrderedDict()
        for c in columns:
            if (c['name'] in self.columns.keys() or
                c['index'] in self.columns.values()):
                raise DataError("duplicated columns")
            if (c['index'] < 0 or c['index'] >= self.info.layout_count):
                raise DataError("bad column number")
            self.columns[c['name']] = c['index']
        if not len(self.columns):
            raise DataError("no columns")

        # Count points
        self.count = self.client.stream_count(self.stream, self.start, self.end)

    def __str__(self):
        return sprintf("%-20s: %s%s, %s rows",
                       self.name, self.stream, str(self.columns.keys()),
                       self.count)

    def fetch(self, min_rows = 10, max_rows = 100000):
        """Fetch all the data into self.data.  This is intended for
        exemplars, and can only handle a relatively small number of
        rows"""
        # Verify count
        if self.count == 0:
            raise DataError("No data in this exemplar!")
        if self.count < min_rows:
            raise DataError("Too few data points: " + str(self.count))
        if self.count > max_rows:
            raise DataError("Too many data points: " + str(self.count))

        # Extract the data
        datagen = self.client.stream_extract_numpy(self.stream,
                                                   self.start, self.end,
                                                   self.info.layout,
                                                   maxrows = self.count)
        self.data = list(datagen)[0]

        # Discard timestamp
        self.data = self.data[:,1:]

        # Subtract the mean from each column
        self.data = self.data - self.data.mean(axis=0)

        # Get scale factors for each column by computing dot product
        # of each column with itself.
        self.scale = inner1d(self.data.T, self.data.T)

        # Ensure a minimum (nonzero) scale and convert to list
        self.scale = np.maximum(self.scale, [1e-9]).tolist()
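        # (Note: for a 2-D array d, inner1d(d.T, d.T) gives the per-column
        # sum of squares, the same as (d * d).sum(axis=0); dividing by it in
        # match() below normalizes a perfect match to a correlation of
        # roughly 1.)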

def process(main, function, args = None, rows = 200000):
    """Process through the data; similar to nilmtools.Filter.process_numpy"""
    if args is None:
        args = []

    extractor = main.client.stream_extract_numpy
    old_array = np.array([])
    for new_array in extractor(main.stream, main.start, main.end,
                               layout = main.info.layout, maxrows = rows):
        # If we still had old data left, combine it
        if old_array.shape[0] != 0:
            array = np.vstack((old_array, new_array))
        else:
            array = new_array

        # Process it
        processed = function(array, args)

        # Save the unprocessed parts
        if processed >= 0:
            old_array = array[processed:]
        else:
            raise Exception(sprintf("%s return value %s must be >= 0",
                                    str(function), str(processed)))

        # Warn if there's too much data remaining
        if old_array.shape[0] > 3 * rows:
            printf("warning: %d unprocessed rows in buffer\n",
                   old_array.shape[0])

    # Handle leftover data
    if old_array.shape[0] != 0:
        processed = function(old_array, args)
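
# (For reference: trainola() below calls process(main, match,
#  (main.columns, exemplars)), so the callback receives each chunk as a 2-D
#  array whose first column is the raw timestamp.)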

def peak_detect(data, delta):
    """Simple min/max peak detection algorithm, taken from my code
    in the disagg.m from the 10-8-5 paper"""
    mins = []
    maxs = []
    cur_min = (None, np.inf)
    cur_max = (None, -np.inf)
    lookformax = False
    for (n, p) in enumerate(data):
        if p > cur_max[1]:
            cur_max = (n, p)
        if p < cur_min[1]:
            cur_min = (n, p)
        if lookformax:
            if p < (cur_max[1] - delta):
                maxs.append(cur_max)
                cur_min = (n, p)
                lookformax = False
        else:
            if p > (cur_min[1] + delta):
                mins.append(cur_min)
                cur_max = (n, p)
                lookformax = True
    return (mins, maxs)
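
# (Worked example: peak_detect([0, 1, 0, 2, 0], 0.5) yields
#  mins = [(0, 0), (2, 0)] and maxs = [(1, 1), (3, 2)], i.e. (index, value)
#  pairs for each detected minimum and maximum.)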

def match(data, args):
    """Perform cross-correlation match"""
    ( columns, exemplars ) = args
    nrows = data.shape[0]

    # We want at least 10% more points than the widest exemplar.
    widest = max([ x.count for x in exemplars ])
    if (widest * 1.1) > nrows:
        return 0

    # This is how many points we'll consider valid in the
    # cross-correlation.
    valid = nrows + 1 - widest
    matches = []

    # Try matching against each of the exemplars
    for e_num, e in enumerate(exemplars):
        corrs = []

        # Compute cross-correlation for each column
        for c in e.columns:
            a = data[:,columns[c] + 1]
            b = e.data[:,e.columns[c]]
            corr = scipy.signal.fftconvolve(a, np.flipud(b), 'valid')[0:valid]
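            # (fftconvolve against the reversed exemplar is equivalent to
            # sliding the exemplar along the data and taking a dot product
            # at each offset, i.e. an unnormalized cross-correlation.)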

            # Scale by the norm of the exemplar
            corr = corr / e.scale[columns[c]]
            corrs.append(corr)

        # Find the peaks using the column with the largest amplitude
        biggest = e.scale.index(max(e.scale))
        peaks_minmax = peak_detect(corrs[biggest], 0.1)
        peaks = [ p[0] for p in peaks_minmax[1] ]

        # Now look at every peak
        for p in peaks:
            # Correlation for each column must be close enough to 1.
            for (corr, scale) in zip(corrs, e.scale):
                # The accepted distance from 1 is based on the relative
                # amplitude of the column.  Use a linear mapping:
                #   scale 1.0 -> distance 0.1
                #   scale 0.0 -> distance 1.0
                distance = 1 - 0.9 * (scale / e.scale[biggest])
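                # (For example, a column whose scale is half of the largest
                # one gets distance = 1 - 0.9 * 0.5 = 0.55.)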
                if abs(corr[p] - 1) > distance:
                    # No match
                    break
            else:
                # Successful match
                matches.append((p, e_num))

    # Print matches
    for (point, e_num) in sorted(matches):
        # Ignore matches that showed up at the very tail of the window,
        # and shorten the window accordingly.  This is an attempt to avoid
        # problems at chunk boundaries.
        if point > (valid - 50):
            valid -= 50
            break
        print "matched", data[point,0], "exemplar", exemplars[e_num].name

    #from matplotlib import pyplot as p
    #p.plot(data[:,1:3])
    #p.show()
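    # (Note: the value returned below tells process() how many rows of this
    #  chunk were consumed; rows beyond that point are carried over and
    #  prepended to the next chunk, so the tail is re-examined later.)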

    return max(valid, 0)

def trainola(conf):
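    # conf is the dictionary decoded from the JSON input.  A hypothetical
    # example (the URL, stream paths, timestamps, and column names below are
    # made up for illustration; they are not from any real installation):
    #
    #   { "url": "http://localhost/nilmdb",
    #     "stream": "/example/prep",
    #     "start": 0,
    #     "end": 1000000000,
    #     "columns": [ { "name": "P1", "index": 0 } ],
    #     "exemplars": [
    #       { "name": "turn-on",
    #         "url": "http://localhost/nilmdb",
    #         "stream": "/example/exemplar-prep",
    #         "start": 0,
    #         "end": 5000000,
    #         "columns": [ { "name": "P1", "index": 0 } ] } ] }
    #
    # Columns are name/index pairs; exemplar column names must also exist in
    # the main data's columns (checked below).
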
    # Load main stream data
    print "Loading stream data"
    main = Data(None, conf['url'], conf['stream'],
                conf['start'], conf['end'], conf['columns'])

    # Pull in the exemplar data
    exemplars = []
    for n, e in enumerate(conf['exemplars']):
        print sprintf("Loading exemplar %d: %s", n, e['name'])
        ex = Data(e['name'], e['url'], e['stream'],
                  e['start'], e['end'], e['columns'])
        ex.fetch()
        exemplars.append(ex)

    # Verify that the exemplar columns are all represented in the main data
    for n, ex in enumerate(exemplars):
        for col in ex.columns:
            if col not in main.columns:
                raise DataError(sprintf("Exemplar %d column %s is not "
                                        "available in main data", n, col))

    # Process the main data
    process(main, match, (main.columns, exemplars))

    return "done"

filterfunc = trainola

def main(argv = None):
    import simplejson as json
    import argparse
    import sys

    parser = argparse.ArgumentParser(
        formatter_class = argparse.RawDescriptionHelpFormatter,
        version = nilmrun.__version__,
        description = """Run Trainola using parameters passed in as
        JSON-formatted data.""")
    parser.add_argument("file", metavar="FILE", nargs="?",
                        type=argparse.FileType('r'), default=sys.stdin)
    args = parser.parse_args(argv)

    conf = json.loads(args.file.read())
    result = trainola(conf)
    print json.dumps(result, sort_keys = True, indent = 2 * ' ')

if __name__ == "__main__":
    main()
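
# (Usage sketch: run this file with a JSON configuration file as its single
#  argument, or pipe the JSON document on stdin; the argument parsing above
#  defaults FILE to sys.stdin.)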