#!/usr/bin/python

from nilmdb.utils.printf import *
import nilmdb.client
import nilmdb.client.numpyclient
import nilmtools.filter
from nilmdb.utils.time import (timestamp_to_human,
                               timestamp_to_seconds,
                               seconds_to_timestamp)
from nilmdb.utils.interval import Interval

import numpy as np
import scipy
import scipy.signal
from numpy.core.umath_tests import inner1d
import nilmrun
from collections import OrderedDict

class DataError(ValueError):
    pass

class Data(object):
    def __init__(self, name, url, stream, start, end, columns):
        """Initialize, get stream info, check columns"""
        self.name = name
        self.url = url
        self.stream = stream
        self.start = start
        self.end = end

        # Get stream info
        self.client = nilmdb.client.numpyclient.NumpyClient(url)
        self.info = nilmtools.filter.get_stream_info(self.client, stream)

        # Build up name => index mapping for the columns
        self.columns = OrderedDict()
        for c in columns:
            if (c['name'] in self.columns.keys() or
                c['index'] in self.columns.values()):
                raise DataError("duplicated columns")
            if (c['index'] < 0 or c['index'] >= self.info.layout_count):
                raise DataError("bad column number")
            self.columns[c['name']] = c['index']
        if not self.columns:
            raise DataError("no columns")

        # Count points
        self.count = self.client.stream_count(self.stream,
                                              self.start, self.end)

    def __str__(self):
        return sprintf("%-20s: %s%s, %s rows",
                       self.name, self.stream, str(self.columns.keys()),
                       self.count)

    def fetch(self, min_rows = 10, max_rows = 100000):
        """Fetch all the data into self.data.  This is intended for
        exemplars, and can only handle a relatively small number of
        rows."""
        # Verify count
        if self.count == 0:
            raise DataError("No data in this exemplar!")
        if self.count < min_rows:
            raise DataError("Too few data points: " + str(self.count))
        if self.count > max_rows:
            raise DataError("Too many data points: " + str(self.count))

        # Extract the data
        datagen = self.client.stream_extract_numpy(self.stream,
                                                   self.start, self.end,
                                                   self.info.layout,
                                                   maxrows = self.count)
        self.data = list(datagen)[0]

        # Discard timestamp
        self.data = self.data[:,1:]

        # Subtract the mean from each column
        self.data = self.data - self.data.mean(axis=0)

        # Get scale factors for each column by computing the dot product
        # of each column with itself.
        self.scale = inner1d(self.data.T, self.data.T)

        # Ensure a minimum (nonzero) scale and convert to list
        self.scale = np.maximum(self.scale, [1e-9]).tolist()
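        # (Each scale factor is the squared norm ||b||^2 of a zero-mean
        # exemplar column b.  match() below divides the raw
        # cross-correlation a.b by this scale, so a data window that
        # equals the exemplar, up to a constant DC offset, scores
        # exactly 1.)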

def process(main, function, args = None, rows = 200000):
    """Process through the data; similar to
    nilmtools.filter.Filter.process_numpy"""
    if args is None:
        args = []

    extractor = main.client.stream_extract_numpy
    old_array = np.array([])
    for new_array in extractor(main.stream, main.start, main.end,
                               layout = main.info.layout, maxrows = rows):
        # If we still had old data left, combine it
        if old_array.shape[0] != 0:
            array = np.vstack((old_array, new_array))
        else:
            array = new_array

        # Process it
        processed = function(array, args)

        # Save the unprocessed parts
        if processed >= 0:
            old_array = array[processed:]
        else:
            raise Exception(sprintf("%s return value %s must be >= 0",
                                    str(function), str(processed)))

        # Warn if there's too much data remaining
        if old_array.shape[0] > 3 * rows:
            printf("warning: %d unprocessed rows in buffer\n",
                   old_array.shape[0])

    # Handle leftover data
    if old_array.shape[0] != 0:
        function(old_array, args)
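# Usage sketch: the caller passes a chunk-processing callback that
# returns how many rows it consumed; unconsumed rows are prepended to
# the next chunk.  trainola() below invokes it as:
#
#   process(main, match, (main.columns, exemplars))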

def peak_detect(data, delta):
    """Simple min/max peak detection algorithm, taken from my code
    in disagg.m from the 10-8-5 paper"""
    mins = []
    maxs = []
    cur_min = (None, np.inf)
    cur_max = (None, -np.inf)
    lookformax = False
    for (n, p) in enumerate(data):
        if p > cur_max[1]:
            cur_max = (n, p)
        if p < cur_min[1]:
            cur_min = (n, p)
        if lookformax:
            if p < (cur_max[1] - delta):
                maxs.append(cur_max)
                cur_min = (n, p)
                lookformax = False
        else:
            if p > (cur_min[1] + delta):
                mins.append(cur_min)
                cur_max = (n, p)
                lookformax = True
    return (mins, maxs)
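# Worked example: peak_detect([0, 1, 0, -1, 0], 0.5) returns
# ([(0, 0), (3, -1)], [(1, 1)]) -- each extremum is an (index, value)
# pair, and an extremum is only committed once the signal has moved
# more than delta away from it.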

def match(data, args):
    """Perform cross-correlation match"""
    ( columns, exemplars ) = args
    nrows = data.shape[0]

    # We want at least 10% more points than the widest exemplar.
    widest = max([ x.count for x in exemplars ])
    if (widest * 1.1) > nrows:
        return 0

    # This is how many points we'll consider valid in the
    # cross-correlation.
    valid = nrows + 1 - widest
    matches = []

    # Try matching against each of the exemplars
    for e_num, e in enumerate(exemplars):
        corrs = []

        # Scale factors for this exemplar's columns, in the same
        # order as the correlations computed below.
        scales = [ e.scale[i] for i in e.columns.values() ]

        # Compute cross-correlation for each column
        for col_num, c in enumerate(e.columns):
            a = data[:,columns[c] + 1]
            b = e.data[:,e.columns[c]]
            corr = scipy.signal.fftconvolve(a, np.flipud(b), 'valid')[0:valid]

            # Scale by the norm of the exemplar
            corr = corr / scales[col_num]
            corrs.append(corr)

        # Find the peaks using the column with the largest amplitude
        biggest = scales.index(max(scales))
        peaks_minmax = peak_detect(corrs[biggest], 0.1)
        peaks = [ p[0] for p in peaks_minmax[1] ]

        # Now look at every peak
        for p in peaks:
            # Correlation for each column must be close enough to 1.
            for (corr, scale) in zip(corrs, scales):
                # The accepted distance from 1 is based on the relative
                # amplitude of the column.  Use a linear mapping:
                #    scale 1.0 -> distance 0.1
                #    scale 0.0 -> distance 1.0
                distance = 1 - 0.9 * (scale / scales[biggest])
                if abs(corr[p] - 1) > distance:
                    # No match
                    break
            else:
                # The loop completed without a break, so every column
                # was close enough: successful match
                matches.append((p, e_num))

    # Print matches
    for (point, e_num) in sorted(matches):
        # Ignore matches that showed up at the very tail of the window,
        # and shorten the window accordingly.  This is an attempt to
        # avoid problems at chunk boundaries.
        if point > (valid - 50):
            valid -= 50
            break
        print "matched", data[point,0], "exemplar", exemplars[e_num].name

    #from matplotlib import pyplot as p
    #p.plot(data[:,1:3])
    #p.show()

    return max(valid, 0)
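# Note on the acceptance test above: the tolerance band widens linearly
# as a column's amplitude shrinks.  For example, a column with half the
# scale of the largest column is accepted when |corr - 1| <= 0.55
# (1 - 0.9 * 0.5), while the largest column itself must be within 0.1.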

def trainola(conf):
    # Load main stream data
    print "Loading stream data"
    main = Data(None, conf['url'], conf['stream'],
                conf['start'], conf['end'], conf['columns'])

    # Pull in the exemplar data
    exemplars = []
    for n, e in enumerate(conf['exemplars']):
        print sprintf("Loading exemplar %d: %s", n, e['name'])
        ex = Data(e['name'], e['url'], e['stream'],
                  e['start'], e['end'], e['columns'])
        ex.fetch()
        exemplars.append(ex)

    # Verify that the exemplar columns are all represented in the main data
    for n, ex in enumerate(exemplars):
        for col in ex.columns:
            if col not in main.columns:
                raise DataError(sprintf("Exemplar %d column %s is not "
                                        "available in main data", n, col))

    # Process the main data
    process(main, match, (main.columns, exemplars))

    return "done"

def main(argv = None):
    import simplejson as json
    import argparse
    import sys

    parser = argparse.ArgumentParser(
        formatter_class = argparse.RawDescriptionHelpFormatter,
        description = """Run Trainola using parameters passed in as
        JSON-formatted data.""")
    parser.add_argument("-v", "--version", action="version",
                        version = nilmrun.__version__)
    parser.add_argument("file", metavar="FILE", nargs="?",
                        type=argparse.FileType('r'), default=sys.stdin)
    args = parser.parse_args(argv)

    conf = json.loads(args.file.read())
    result = trainola(conf)
    print json.dumps(result, sort_keys = True, indent = 2 * ' ')

if __name__ == "__main__":
    main()