|
- #!/usr/bin/python
-
- from nilmdb.utils.printf import *
- import nilmdb.client
- import nilmtools.filter
- from nilmdb.utils.time import (timestamp_to_human,
- timestamp_to_seconds,
- seconds_to_timestamp)
- from nilmdb.utils.interval import Interval
-
- import numpy as np
- import scipy
- import scipy.signal
- from numpy.core.umath_tests import inner1d
- import nilmrun
- from collections import OrderedDict
- import sys
-
class DataError(ValueError):
    """Raised when input configuration or stream data is invalid."""
    pass
-
def build_column_mapping(colinfo, streaminfo):
    """Given the 'columns' list from the JSON data, verify and
    pull out a dictionary mapping for the column names/numbers.

    colinfo: list of dicts, each with 'name' and 'index' keys.
    streaminfo: stream info object with a 'layout_count' attribute.

    Returns an OrderedDict mapping column name -> column index,
    preserving the order given in 'colinfo'.

    Raises DataError on duplicate names or indices, out-of-range
    indices, or an empty column list.
    """
    columns = OrderedDict()
    for c in colinfo:
        # The mapping built so far doubles as the duplicate check:
        # names are its keys, indices are its values.
        if (c['name'] in columns or
            c['index'] in columns.values()):
            raise DataError("duplicated columns")
        if (c['index'] < 0 or c['index'] >= streaminfo.layout_count):
            raise DataError("bad column number")
        columns[c['name']] = c['index']
    if not columns:
        raise DataError("no columns")
    return columns
-
class Exemplar(object):
    """One exemplar pattern, loaded from a NilmDB stream.

    After construction, self.data holds the exemplar data (timestamp
    column removed, per-column mean subtracted) and self.scale holds a
    per-column scale factor: each column's dot product with itself,
    clamped to a nonzero minimum.
    """
    def __init__(self, exinfo, min_rows = 10, max_rows = 100000):
        """Given a dictionary entry from the 'exemplars' input JSON,
        verify the stream, columns, etc.  Then, fetch all the data
        into self.data.

        exinfo: dict with 'name', 'url', 'stream', 'start', 'end',
          'dest_column', and 'columns' keys.
        min_rows, max_rows: bounds on the acceptable number of data
          points in the exemplar.

        Raises DataError if the exemplar cannot be used.
        """
        self.name = exinfo['name']
        self.url = exinfo['url']
        self.stream = exinfo['stream']
        self.start = exinfo['start']
        self.end = exinfo['end']
        self.dest_column = exinfo['dest_column']

        # Get stream info
        self.client = nilmdb.client.numpyclient.NumpyClient(self.url)
        self.info = nilmtools.filter.get_stream_info(self.client, self.stream)
        # get_stream_info returns a false value for a missing stream;
        # fail early with a clear message rather than crashing below.
        if not self.info:
            raise DataError("exemplar stream '" + self.stream +
                            "' does not exist on server '" + self.url + "'")

        # Build up name => index mapping for the columns
        self.columns = build_column_mapping(exinfo['columns'], self.info)

        # Count points
        self.count = self.client.stream_count(self.stream, self.start, self.end)

        # Verify count
        if self.count == 0:
            raise DataError("No data in this exemplar!")
        if self.count < min_rows:
            raise DataError("Too few data points: " + str(self.count))
        if self.count > max_rows:
            raise DataError("Too many data points: " + str(self.count))

        # Extract the data
        datagen = self.client.stream_extract_numpy(self.stream,
                                                   self.start, self.end,
                                                   self.info.layout,
                                                   maxrows = self.count)
        self.data = list(datagen)[0]

        # Discard timestamp
        self.data = self.data[:,1:]

        # Subtract the mean from each column
        self.data = self.data - self.data.mean(axis=0)

        # Get scale factors for each column by computing the dot
        # product of each column with itself.  This is equivalent to
        # the former inner1d(self.data.T, self.data.T);
        # numpy.core.umath_tests was removed from modern numpy, so the
        # top-level inner1d import can be dropped once nothing else
        # uses it.
        self.scale = (self.data * self.data).sum(axis=0)

        # Ensure a minimum (nonzero) scale and convert to list
        self.scale = np.maximum(self.scale, [1e-9]).tolist()

    def __str__(self):
        # Human-readable summary, e.g.: "fridge" /my/stream [a,b] 42 rows
        return sprintf("\"%s\" %s [%s] %s rows",
                       self.name, self.stream, ",".join(self.columns.keys()),
                       self.count)
-
def peak_detect(data, delta):
    """Simple min/max peak detection algorithm, taken from my code
    in the disagg.m from the 10-8-5 paper.

    Returns (mins, maxs), each a list of (index, value) tuples.  A
    peak is registered once the signal has moved more than 'delta'
    away from the running extreme.
    """
    mins = []
    maxs = []
    cur_min = (None, np.inf)
    cur_max = (None, -np.inf)
    seeking_max = False
    for idx, val in enumerate(data):
        # Track the running extremes seen so far.
        if val > cur_max[1]:
            cur_max = (idx, val)
        if val < cur_min[1]:
            cur_min = (idx, val)
        # Once the signal retreats by more than delta from the current
        # extreme, commit that extreme as a peak and flip direction.
        if seeking_max and val < cur_max[1] - delta:
            maxs.append(cur_max)
            cur_min = (idx, val)
            seeking_max = False
        elif not seeking_max and val > cur_min[1] + delta:
            mins.append(cur_min)
            cur_max = (idx, val)
            seeking_max = True
    return (mins, maxs)
-
def trainola_matcher(data, args):
    """Perform cross-correlation match of 'data' against each exemplar,
    printing any matches found.

    data: 2-D array; column 0 holds timestamps and the remaining
      columns hold the source stream data.
    args: tuple (columns, exemplars), where 'columns' maps source
      column names to 0-based data column numbers (not counting the
      timestamp) and 'exemplars' is a list of Exemplar objects.

    Returns the number of rows consumed (>= 0); the caller should
    carry any remaining rows over into the next chunk.
    """
    ( columns, exemplars ) = args
    nrows = data.shape[0]

    # We want at least 10% more points than the widest exemplar.
    widest = max([ x.count for x in exemplars ])
    if (widest * 1.1) > nrows:
        return 0

    # This is how many points we'll consider valid in the
    # cross-correlation.
    valid = nrows + 1 - widest
    matches = []

    # Try matching against each of the exemplars
    for e_num, e in enumerate(exemplars):
        # Build the scale-factor list in the same order that the
        # correlations are computed below, so index i of 'scales'
        # always corresponds to index i of 'corrs'.  (The old code
        # indexed e.scale with the *source* stream's column numbers,
        # and indexed 'corrs' with exemplar-layout column numbers --
        # two unrelated index spaces.)
        scales = [ e.scale[e.columns[c]] for c in e.columns ]
        corrs = []

        # Compute cross-correlation for each column
        for i, c in enumerate(e.columns):
            a = data[:,columns[c] + 1]
            b = e.data[:,e.columns[c]]
            corr = scipy.signal.fftconvolve(a, np.flipud(b), 'valid')[0:valid]

            # Scale by the norm of the exemplar
            corrs.append(corr / scales[i])

        # Find the peaks using the column with the largest amplitude
        biggest = scales.index(max(scales))
        peaks_minmax = peak_detect(corrs[biggest], 0.1)
        peaks = [ p[0] for p in peaks_minmax[1] ]

        # Now look at every peak
        for p in peaks:
            # Correlation for each column must be close enough to 1.
            for (corr, scale) in zip(corrs, scales):
                # The accepted distance from 1 is based on the relative
                # amplitude of the column.  Use a linear mapping:
                #   scale 1.0 -> distance 0.1
                #   scale 0.0 -> distance 1.0
                distance = 1 - 0.9 * (scale / scales[biggest])
                if abs(corr[p] - 1) > distance:
                    # No match
                    break
            else:
                # Successful match
                matches.append((p, e_num))

    # Print matches
    for (point, e_num) in sorted(matches):
        # Ignore matches that showed up at the very tail of the window,
        # and shorten the window accordingly.  This is an attempt to
        # avoid problems at chunk boundaries.
        if point > (valid - 50):
            valid -= 50
            break
        # Same output as the old py2 print statement, but portable.
        sys.stdout.write("matched %s exemplar %s\n"
                         % (data[point,0], exemplars[e_num].name))

    return max(valid, 0)
-
def trainola(conf):
    """Load the source/destination streams and exemplars described by
    'conf' and prepare to run the cross-correlation matcher.

    conf: dict with keys 'url', 'stream', 'dest_stream', 'start',
      'end', 'columns', and 'exemplars'.

    Raises DataError if any stream or exemplar is invalid.
    Returns "done" on completion.
    """
    # Use printf (as the rest of this file does) instead of the
    # py2-only print statement.
    printf("Trainola %s\n", nilmtools.__version__)

    # Load main stream data
    url = conf['url']
    src_path = conf['stream']
    dest_path = conf['dest_stream']
    start = conf['start']
    end = conf['end']

    # Get info for the src and dest streams
    src_client = nilmdb.client.numpyclient.NumpyClient(url)
    src = nilmtools.filter.get_stream_info(src_client, src_path)
    if not src:
        raise DataError("source path '" + src_path + "' does not exist")
    src_columns = build_column_mapping(conf['columns'], src)

    dest_client = nilmdb.client.numpyclient.NumpyClient(url)
    dest = nilmtools.filter.get_stream_info(dest_client, dest_path)
    if not dest:
        raise DataError("destination path '" + dest_path + "' does not exist")

    printf("Source:\n")
    printf(" %s [%s]\n", src.path, ",".join(src_columns.keys()))
    printf("Destination:\n")
    printf(" %s (%s columns)\n", dest.path, dest.layout_count)

    # Pull in the exemplar data
    exemplars = []
    for n, exinfo in enumerate(conf['exemplars']):
        printf("Loading exemplar %d:\n", n)
        e = Exemplar(exinfo)
        col = e.dest_column
        if col < 0 or col >= dest.layout_count:
            raise DataError(sprintf("bad destination column number %d\n" +
                                    "dest stream only has 0 through %d",
                                    col, dest.layout_count - 1))
        printf(" %s, output column %d\n", str(e), col)
        exemplars.append(e)
    if not exemplars:
        raise DataError("missing exemplars")

    # Verify that the exemplar columns are all represented in the main data
    for n, ex in enumerate(exemplars):
        for col in ex.columns:
            if col not in src_columns:
                raise DataError(sprintf("Exemplar %d column %s is not "
                                        "available in source data", n, col))

    # TODO: process the data in a piecewise manner: walk the intervals
    # present in the source stream but not yet in the destination,
    # extract rows in chunks, run trainola_matcher over each chunk
    # (carrying unprocessed rows forward into the next chunk), and
    # write matches to the destination stream.  See
    # nilmtools.filter.Filter.process_numpy for the general pattern.

    return "done"
-
def main(argv = None):
    """Command-line entry point: parse JSON-formatted arguments and
    run trainola.

    argv may be None (use sys.argv[1:]), a one-element list containing
    a dict (direct in-process call), or a list of command-line
    argument strings.
    """
    import simplejson as json
    import argparse

    if argv is None:
        argv = sys.argv[1:]

    # If the first parameter is a dictionary (passed in by direct call),
    # don't bother parsing it as a JSON string.
    if len(argv) == 1 and isinstance(argv[0], dict):
        return trainola(argv[0])

    # Parse command line arguments as text
    parser = argparse.ArgumentParser(
        formatter_class = argparse.RawDescriptionHelpFormatter,
        description = """Run Trainola using parameters passed in as
        JSON-formatted data.""")
    # ArgumentParser's 'version' keyword is deprecated (and removed in
    # Python 3); use an explicit version action instead.
    parser.add_argument("-v", "--version", action="version",
                        version = nilmtools.__version__)
    parser.add_argument("data", metavar="DATA",
                        help="Arguments, formatted as a JSON string")
    args = parser.parse_args(argv)

    conf = json.loads(args.data)
    return trainola(conf)
-
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|