#!/usr/bin/python

from nilmdb.utils.printf import *
import nilmdb.client
import nilmdb.client.numpyclient   # ensure the NumpyClient submodule is loaded
import nilmtools.filter
from nilmdb.utils.time import (timestamp_to_human,
                               timestamp_to_seconds,
                               seconds_to_timestamp)
from nilmdb.utils import datetime_tz
from nilmdb.utils.interval import Interval

import numpy as np
import scipy
import scipy.signal
from numpy.core.umath_tests import inner1d
import nilmrun
from collections import OrderedDict
import sys
import time
import functools
import collections

class DataError(ValueError):
    pass

def build_column_mapping(colinfo, streaminfo):
    """Given the 'columns' list from the JSON data, verify it and
    return an OrderedDict mapping column names to column numbers."""
    columns = OrderedDict()
    for c in colinfo:
        col_num = c['index'] + 1   # skip timestamp
        if (c['name'] in columns.keys() or col_num in columns.values()):
            raise DataError("duplicated columns")
        if (c['index'] < 0 or c['index'] >= streaminfo.layout_count):
            raise DataError("bad column number")
        columns[c['name']] = col_num
    if not len(columns):
        raise DataError("no columns")
    return columns
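
# Example (illustrative; the column names here are hypothetical):
#   build_column_mapping([ {'name': 'P1', 'index': 0},
#                          {'name': 'Q1', 'index': 2} ], streaminfo)
# would return OrderedDict([('P1', 1), ('Q1', 3)]), since column numbers
# are offset by one to account for the timestamp column.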

class Exemplar(object):
    def __init__(self, exinfo, min_rows = 10, max_rows = 100000):
        """Given a dictionary entry from the 'exemplars' input JSON,
        verify the stream, columns, etc. Then, fetch all the data
        into self.data."""

        self.name = exinfo['name']
        self.url = exinfo['url']
        self.stream = exinfo['stream']
        self.start = exinfo['start']
        self.end = exinfo['end']
        self.dest_column = exinfo['dest_column']

        # Get stream info
        self.client = nilmdb.client.numpyclient.NumpyClient(self.url)
        self.info = nilmtools.filter.get_stream_info(self.client,
                                                     self.stream)
        if not self.info:
            raise DataError(sprintf("exemplar stream '%s' does not exist " +
                                    "on server '%s'", self.stream, self.url))

        # Build up name => index mapping for the columns
        self.columns = build_column_mapping(exinfo['columns'], self.info)

        # Count points
        self.count = self.client.stream_count(self.stream,
                                              self.start, self.end)

        # Verify count
        if self.count == 0:
            raise DataError("No data in this exemplar!")
        if self.count < min_rows:
            raise DataError("Too few data points: " + str(self.count))
        if self.count > max_rows:
            raise DataError("Too many data points: " + str(self.count))

        # Extract the data
        datagen = self.client.stream_extract_numpy(self.stream,
                                                   self.start, self.end,
                                                   self.info.layout,
                                                   maxrows = self.count)
        self.data = list(datagen)[0]

        # Extract just the columns that were specified in self.columns,
        # skipping the timestamp.
        extract_columns = [ value for (key, value) in self.columns.items() ]
        self.data = self.data[:, extract_columns]

        # Fix the column indices in self.columns, since we removed and
        # reordered columns in self.data
        for n, k in enumerate(self.columns):
            self.columns[k] = n

        # Subtract the means from each column
        self.data = self.data - self.data.mean(axis=0)

        # Get scale factors for each column by computing dot product
        # of each column with itself.
        self.scale = inner1d(self.data.T, self.data.T)

        # Ensure a minimum (nonzero) scale and convert to list
        self.scale = np.maximum(self.scale, [1e-9]).tolist()

    def __str__(self):
        return sprintf("\"%s\" %s [%s] %s rows",
                       self.name, self.stream,
                       ",".join(self.columns.keys()), self.count)
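
# Illustrative sketch of one 'exemplars' entry as consumed above; all
# values here are hypothetical:
#   { "name": "Pump on", "url": "http://localhost/nilmdb",
#     "stream": "/test/pump", "start": 1400000000000000,
#     "end": 1400000010000000, "dest_column": 0,
#     "columns": [ {"name": "P1", "index": 0} ] }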

def peak_detect(data, delta):
    """Simple min/max peak detection algorithm, taken from my code
    in the disagg.m from the 10-8-5 paper.

    Returns an array of peaks: each peak is a tuple
        (n, p, is_max)
    where n is the row number in 'data', p is 'data[n]', and is_max is
    True if this is a maximum, False if it's a minimum.
    """
    peaks = []
    cur_min = (None, np.inf)
    cur_max = (None, -np.inf)
    lookformax = False
    for (n, p) in enumerate(data):
        if p > cur_max[1]:
            cur_max = (n, p)
        if p < cur_min[1]:
            cur_min = (n, p)
        if lookformax:
            if p < (cur_max[1] - delta):
                peaks.append((cur_max[0], cur_max[1], True))
                cur_min = (n, p)
                lookformax = False
        else:
            if p > (cur_min[1] + delta):
                peaks.append((cur_min[0], cur_min[1], False))
                cur_max = (n, p)
                lookformax = True
    return peaks
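
# Example (illustrative, hypothetical input), with delta = 0.5:
#   peak_detect([0, 1, 0.2, 1.5, 0], 0.5)
# returns alternating minima and maxima that differ by at least delta:
#   [(0, 0, False), (1, 1, True), (2, 0.2, False), (3, 1.5, True)]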

def timestamp_to_short_human(timestamp):
    dt = datetime_tz.datetime_tz.fromtimestamp(timestamp_to_seconds(timestamp))
    return dt.strftime("%H:%M:%S")

def trainola_matcher(data, interval, args, insert_func, final_chunk):
    """Perform cross-correlation match"""
    ( src_columns, dest_count, exemplars ) = args
    nrows = data.shape[0]

    # We want at least 10% more points than the widest exemplar.
    widest = max([ x.count for x in exemplars ])
    if (widest * 1.1) > nrows:
        return 0

    # This is how many points we'll consider valid in the
    # cross-correlation.
    valid = nrows + 1 - widest
    matches = collections.defaultdict(list)

    # Try matching against each of the exemplars
    for e in exemplars:
        corrs = []

        # Compute cross-correlation for each column
        for col_name in e.columns:
            a = data[:, src_columns[col_name]]
            b = e.data[:, e.columns[col_name]]
            corr = scipy.signal.fftconvolve(a, np.flipud(b), 'valid')[0:valid]

            # Scale by the norm of the exemplar
            corr = corr / e.scale[e.columns[col_name]]
            corrs.append(corr)

        # Find the peaks using the column with the largest amplitude
        biggest = e.scale.index(max(e.scale))
        peaks = peak_detect(corrs[biggest], 0.1)

        # To try to reduce false positives, discard peaks where
        # there's a higher-magnitude peak (either min or max) within
        # one exemplar width nearby.
        good_peak_locations = []
        for (i, (n, p, is_max)) in enumerate(peaks):
            if not is_max:
                continue
            ok = True
            # check up to 'e.count' rows before this one
            j = i - 1
            while ok and j >= 0 and peaks[j][0] > (n - e.count):
                if abs(peaks[j][1]) > abs(p):
                    ok = False
                j -= 1

            # check up to 'e.count' rows after this one
            j = i + 1
            while ok and j < len(peaks) and peaks[j][0] < (n + e.count):
                if abs(peaks[j][1]) > abs(p):
                    ok = False
                j += 1

            if ok:
                good_peak_locations.append(n)

        # Now look at all good peaks
        for row in good_peak_locations:
            # Correlation for each column must be close enough to 1.
            for (corr, scale) in zip(corrs, e.scale):
                # The accepted distance from 1 is based on the relative
                # amplitude of the column.  Use a linear mapping:
                #   scale 1.0 -> distance 0.1
                #   scale 0.0 -> distance 1.0
                distance = 1 - 0.9 * (scale / e.scale[biggest])
                if abs(corr[row] - 1) > distance:
                    # No match
                    break
            else:
                # The inner loop finished without hitting 'break', so
                # every column was close enough: successful match.
                matches[row].append(e)

    # Insert matches into destination stream.
    matched_rows = sorted(matches.keys())
    out = np.zeros((len(matched_rows), dest_count + 1))

    for n, row in enumerate(matched_rows):
        # Fill timestamp
        out[n][0] = data[row, 0]

        # Mark matched exemplars
        for exemplar in matches[row]:
            out[n, exemplar.dest_column + 1] = 1.0

    # Insert it
    insert_func(out)

    # Return how many rows we processed
    valid = max(valid, 0)
    printf(" [%s] matched %d exemplars in %d rows\n",
           timestamp_to_short_human(data[0][0]), np.sum(out[:, 1:]), valid)
    return valid

def trainola(conf):
    printf("Trainola %s\n", nilmtools.__version__)

    # Load main stream data
    url = conf['url']
    src_path = conf['stream']
    dest_path = conf['dest_stream']
    start = conf['start']
    end = conf['end']

    # Get info for the src and dest streams
    src_client = nilmdb.client.numpyclient.NumpyClient(url)
    src = nilmtools.filter.get_stream_info(src_client, src_path)
    if not src:
        raise DataError("source path '" + src_path + "' does not exist")
    src_columns = build_column_mapping(conf['columns'], src)

    dest_client = nilmdb.client.numpyclient.NumpyClient(url)
    dest = nilmtools.filter.get_stream_info(dest_client, dest_path)
    if not dest:
        raise DataError("destination path '" + dest_path + "' does not exist")

    printf("Source:\n")
    printf(" %s [%s]\n", src.path, ",".join(src_columns.keys()))
    printf("Destination:\n")
    printf(" %s (%s columns)\n", dest.path, dest.layout_count)

    # Pull in the exemplar data
    exemplars = []
    for n, exinfo in enumerate(conf['exemplars']):
        printf("Loading exemplar %d:\n", n)
        e = Exemplar(exinfo)
        col = e.dest_column
        if col < 0 or col >= dest.layout_count:
            raise DataError(sprintf("bad destination column number %d\n" +
                                    "dest stream only has 0 through %d",
                                    col, dest.layout_count - 1))
        printf(" %s, output column %d\n", str(e), col)
        exemplars.append(e)
    if len(exemplars) == 0:
        raise DataError("missing exemplars")

    # Verify that the exemplar columns are all represented in the main data
    for n, ex in enumerate(exemplars):
        for col in ex.columns:
            if col not in src_columns:
                raise DataError(sprintf("Exemplar %d column %s is not "
                                        "available in source data", n, col))

    # Figure out which intervals we should process
    intervals = ( Interval(s, e) for (s, e) in
                  src_client.stream_intervals(src_path,
                                              diffpath = dest_path,
                                              start = start, end = end) )
    intervals = nilmdb.utils.interval.optimize(intervals)

    # Do the processing
    rows = 100000
    extractor = functools.partial(src_client.stream_extract_numpy,
                                  src.path, layout = src.layout,
                                  maxrows = rows)
    inserter = functools.partial(dest_client.stream_insert_numpy_context,
                                 dest.path)
    start = time.time()
    processed_time = 0
    printf("Processing intervals:\n")
    for interval in intervals:
        printf("%s\n", interval.human_string())
        nilmtools.filter.process_numpy_interval(
            interval, extractor, inserter, rows * 3,
            trainola_matcher, (src_columns, dest.layout_count, exemplars))
        processed_time += (timestamp_to_seconds(interval.end) -
                           timestamp_to_seconds(interval.start))
    elapsed = max(time.time() - start, 1e-3)

    printf("Done. Processed %.2f seconds per second.\n",
           processed_time / elapsed)

def main(argv = None):
    import simplejson as json
    import sys

    if argv is None:
        argv = sys.argv[1:]
    if len(argv) != 1:
        raise DataError("need one argument, either a dictionary or JSON string")

    try:
        # Passed in a JSON string (e.g. on the command line)
        conf = json.loads(argv[0])
    except TypeError:
        # Passed in the config dictionary (e.g. from NilmRun)
        conf = argv[0]

    return trainola(conf)
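
# Illustrative sketch of the configuration dictionary expected by
# trainola(); all paths, URLs, and timestamps here are hypothetical:
#   { "url": "http://localhost/nilmdb",
#     "stream": "/test/raw",
#     "dest_stream": "/test/matches",
#     "start": 1400000000000000,
#     "end": 1400000100000000,
#     "columns": [ {"name": "P1", "index": 0} ],
#     "exemplars": [ ... ] }     # entries as in the Exemplar sketch above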

if __name__ == "__main__":
    main()