Browse Source

Add trainola from nilmrun

tags/nilmtools-1.3.1
Jim Paris 8 years ago
parent
commit
706c3933f9
1 changed files with 260 additions and 0 deletions
  1. +260
    -0
      nilmtools/trainola.py

+ 260
- 0
nilmtools/trainola.py View File

@@ -0,0 +1,260 @@
#!/usr/bin/python

from nilmdb.utils.printf import *
import nilmdb.client
import nilmtools.filter
from nilmdb.utils.time import (timestamp_to_human,
timestamp_to_seconds,
seconds_to_timestamp)
from nilmdb.utils.interval import Interval

import numpy as np
import scipy
import scipy.signal
from numpy.core.umath_tests import inner1d
import nilmrun
from collections import OrderedDict

class DataError(ValueError):
pass

class Data(object):
def __init__(self, name, url, stream, start, end, columns):
"""Initialize, get stream info, check columns"""
self.name = name
self.url = url
self.stream = stream
self.start = start
self.end = end

# Get stream info
self.client = nilmdb.client.numpyclient.NumpyClient(url)
self.info = nilmtools.filter.get_stream_info(self.client, stream)

# Build up name => index mapping for the columns
self.columns = OrderedDict()
for c in columns:
if (c['name'] in self.columns.keys() or
c['index'] in self.columns.values()):
raise DataError("duplicated columns")
if (c['index'] < 0 or c['index'] >= self.info.layout_count):
raise DataError("bad column number")
self.columns[c['name']] = c['index']
if not len(self.columns):
raise DataError("no columns")

# Count points
self.count = self.client.stream_count(self.stream, self.start, self.end)

def __str__(self):
return sprintf("%-20s: %s%s, %s rows",
self.name, self.stream, str(self.columns.keys()),
self.count)

def fetch(self, min_rows = 10, max_rows = 100000):
"""Fetch all the data into self.data. This is intended for
exemplars, and can only handle a relatively small number of
rows"""
# Verify count
if self.count == 0:
raise DataError("No data in this exemplar!")
if self.count < min_rows:
raise DataError("Too few data points: " + str(self.count))
if self.count > max_rows:
raise DataError("Too many data points: " + str(self.count))

# Extract the data
datagen = self.client.stream_extract_numpy(self.stream,
self.start, self.end,
self.info.layout,
maxrows = self.count)
self.data = list(datagen)[0]

# Discard timestamp
self.data = self.data[:,1:]

# Subtract the mean from each column
self.data = self.data - self.data.mean(axis=0)

# Get scale factors for each column by computing dot product
# of each column with itself.
self.scale = inner1d(self.data.T, self.data.T)

# Ensure a minimum (nonzero) scale and convert to list
self.scale = np.maximum(self.scale, [1e-9]).tolist()

def process(main, function, args = None, rows = 200000):
"""Process through the data; similar to nilmtools.Filter.process_numpy"""
if args is None:
args = []

extractor = main.client.stream_extract_numpy
old_array = np.array([])
for new_array in extractor(main.stream, main.start, main.end,
layout = main.info.layout, maxrows = rows):
# If we still had old data left, combine it
if old_array.shape[0] != 0:
array = np.vstack((old_array, new_array))
else:
array = new_array

# Process it
processed = function(array, args)

# Save the unprocessed parts
if processed >= 0:
old_array = array[processed:]
else:
raise Exception(sprintf("%s return value %s must be >= 0",
str(function), str(processed)))

# Warn if there's too much data remaining
if old_array.shape[0] > 3 * rows:
printf("warning: %d unprocessed rows in buffer\n",
old_array.shape[0])

# Handle leftover data
if old_array.shape[0] != 0:
processed = function(array, args)

def peak_detect(data, delta):
"""Simple min/max peak detection algorithm, taken from my code
in the disagg.m from the 10-8-5 paper"""
mins = [];
maxs = [];
cur_min = (None, np.inf)
cur_max = (None, -np.inf)
lookformax = False
for (n, p) in enumerate(data):
if p > cur_max[1]:
cur_max = (n, p)
if p < cur_min[1]:
cur_min = (n, p)
if lookformax:
if p < (cur_max[1] - delta):
maxs.append(cur_max)
cur_min = (n, p)
lookformax = False
else:
if p > (cur_min[1] + delta):
mins.append(cur_min)
cur_max = (n, p)
lookformax = True
return (mins, maxs)

def match(data, args):
"""Perform cross-correlation match"""
( columns, exemplars ) = args
nrows = data.shape[0]

# We want at least 10% more points than the widest exemplar.
widest = max([ x.count for x in exemplars ])
if (widest * 1.1) > nrows:
return 0

# This is how many points we'll consider valid in the
# cross-correlation.
valid = nrows + 1 - widest
matches = []

# Try matching against each of the exemplars
for e_num, e in enumerate(exemplars):
corrs = []

# Compute cross-correlation for each column
for c in e.columns:
a = data[:,columns[c] + 1]
b = e.data[:,e.columns[c]]
corr = scipy.signal.fftconvolve(a, np.flipud(b), 'valid')[0:valid]

# Scale by the norm of the exemplar
corr = corr / e.scale[columns[c]]
corrs.append(corr)

# Find the peaks using the column with the largest amplitude
biggest = e.scale.index(max(e.scale))
peaks_minmax = peak_detect(corrs[biggest], 0.1)
peaks = [ p[0] for p in peaks_minmax[1] ]

# Now look at every peak
for p in peaks:
# Correlation for each column must be close enough to 1.
for (corr, scale) in zip(corrs, e.scale):
# The accepted distance from 1 is based on the relative
# amplitude of the column. Use a linear mapping:
# scale 1.0 -> distance 0.1
# scale 0.0 -> distance 1.0
distance = 1 - 0.9 * (scale / e.scale[biggest])
if abs(corr[p] - 1) > distance:
# No match
break
else:
# Successful match
matches.append((p, e_num))

# Print matches
for (point, e_num) in sorted(matches):
# Ignore matches that showed up at the very tail of the window,
# and shorten the window accordingly. This is an attempt to avoid
# problems at chunk boundaries.
if point > (valid - 50):
valid -= 50
break
print "matched", data[point,0], "exemplar", exemplars[e_num].name

#from matplotlib import pyplot as p
#p.plot(data[:,1:3])
#p.show()

return max(valid, 0)

def trainola(conf):
# Load main stream data
print "Loading stream data"
main = Data(None, conf['url'], conf['stream'],
conf['start'], conf['end'], conf['columns'])

# Pull in the exemplar data
exemplars = []
for n, e in enumerate(conf['exemplars']):
print sprintf("Loading exemplar %d: %s", n, e['name'])
ex = Data(e['name'], e['url'], e['stream'],
e['start'], e['end'], e['columns'])
ex.fetch()
exemplars.append(ex)

# Verify that the exemplar columns are all represented in the main data
for n, ex in enumerate(exemplars):
for col in ex.columns:
if col not in main.columns:
raise DataError(sprintf("Exemplar %d column %s is not "
"available in main data", n, col))

# Process the main data
process(main, match, (main.columns, exemplars))

return "done"

filterfunc = trainola

def main(argv = None):
import simplejson as json
import argparse
import sys

parser = argparse.ArgumentParser(
formatter_class = argparse.RawDescriptionHelpFormatter,
version = nilmrun.__version__,
description = """Run Trainola using parameters passed in as
JSON-formatted data.""")
parser.add_argument("file", metavar="FILE", nargs="?",
type=argparse.FileType('r'), default=sys.stdin)
args = parser.parse_args(argv)

conf = json.loads(args.file.read())
result = trainola(conf)
print json.dumps(result, sort_keys = True, indent = 2 * ' ')

if __name__ == "__main__":
main()


Loading…
Cancel
Save