Browse Source

WIP on trainola improvements

tags/nilmtools-1.3.1
Jim Paris 9 years ago
parent
commit
5a2a32bec5
1 changed files with 142 additions and 85 deletions
  1. +142
    -85
      nilmtools/trainola.py

+ 142
- 85
nilmtools/trainola.py View File

@@ -14,47 +14,49 @@ import scipy.signal
from numpy.core.umath_tests import inner1d
import nilmrun
from collections import OrderedDict
import sys

class DataError(ValueError):
pass

class Data(object):
def __init__(self, name, url, stream, start, end, columns):
"""Initialize, get stream info, check columns"""
self.name = name
self.url = url
self.stream = stream
self.start = start
self.end = end
def build_column_mapping(colinfo, streaminfo):
"""Given the 'columns' list from the JSON data, verify and
pull out a dictionary mapping for the column names/numbers."""
columns = OrderedDict()
for c in colinfo:
if (c['name'] in columns.keys() or
c['index'] in columns.values()):
raise DataError("duplicated columns")
if (c['index'] < 0 or c['index'] >= streaminfo.layout_count):
raise DataError("bad column number")
columns[c['name']] = c['index']
if not len(columns):
raise DataError("no columns")
return columns

class Exemplar(object):
def __init__(self, exinfo, min_rows = 10, max_rows = 100000):
"""Given a dictionary entry from the 'exemplars' input JSON,
verify the stream, columns, etc. Then, fetch all the data
into self.data."""

self.name = exinfo['name']
self.url = exinfo['url']
self.stream = exinfo['stream']
self.start = exinfo['start']
self.end = exinfo['end']
self.dest_column = exinfo['dest_column']

# Get stream info
self.client = nilmdb.client.numpyclient.NumpyClient(url)
self.info = nilmtools.filter.get_stream_info(self.client, stream)
self.client = nilmdb.client.numpyclient.NumpyClient(self.url)
self.info = nilmtools.filter.get_stream_info(self.client, self.stream)

# Build up name => index mapping for the columns
self.columns = OrderedDict()
for c in columns:
if (c['name'] in self.columns.keys() or
c['index'] in self.columns.values()):
raise DataError("duplicated columns")
if (c['index'] < 0 or c['index'] >= self.info.layout_count):
raise DataError("bad column number")
self.columns[c['name']] = c['index']
if not len(self.columns):
raise DataError("no columns")
self.columns = build_column_mapping(exinfo['columns'], self.info)

# Count points
self.count = self.client.stream_count(self.stream, self.start, self.end)

def __str__(self):
return sprintf("%-20s: %s%s, %s rows",
self.name, self.stream, str(self.columns.keys()),
self.count)

def fetch(self, min_rows = 10, max_rows = 100000):
"""Fetch all the data into self.data. This is intended for
exemplars, and can only handle a relatively small number of
rows"""
# Verify count
if self.count == 0:
raise DataError("No data in this exemplar!")
@@ -83,39 +85,10 @@ class Data(object):
# Ensure a minimum (nonzero) scale and convert to list
self.scale = np.maximum(self.scale, [1e-9]).tolist()

def process(main, function, args = None, rows = 200000):
"""Process through the data; similar to nilmtools.Filter.process_numpy"""
if args is None:
args = []

extractor = main.client.stream_extract_numpy
old_array = np.array([])
for new_array in extractor(main.stream, main.start, main.end,
layout = main.info.layout, maxrows = rows):
# If we still had old data left, combine it
if old_array.shape[0] != 0:
array = np.vstack((old_array, new_array))
else:
array = new_array

# Process it
processed = function(array, args)

# Save the unprocessed parts
if processed >= 0:
old_array = array[processed:]
else:
raise Exception(sprintf("%s return value %s must be >= 0",
str(function), str(processed)))

# Warn if there's too much data remaining
if old_array.shape[0] > 3 * rows:
printf("warning: %d unprocessed rows in buffer\n",
old_array.shape[0])

# Handle leftover data
if old_array.shape[0] != 0:
processed = function(array, args)
def __str__(self):
return sprintf("\"%s\" %s [%s] %s rows",
self.name, self.stream, ",".join(self.columns.keys()),
self.count)

def peak_detect(data, delta):
"""Simple min/max peak detection algorithm, taken from my code
@@ -142,7 +115,7 @@ def peak_detect(data, delta):
lookformax = True
return (mins, maxs)

def match(data, args):
def trainola_matcher(data, args):
"""Perform cross-correlation match"""
( columns, exemplars ) = args
nrows = data.shape[0]
@@ -209,51 +182,135 @@ def match(data, args):
return max(valid, 0)

def trainola(conf):
print "Trainola", nilmtools.__version__

# Load main stream data
print "Loading stream data"
main = Data(None, conf['url'], conf['stream'],
conf['start'], conf['end'], conf['columns'])
url = conf['url']
src_path = conf['stream']
dest_path = conf['dest_stream']
start = conf['start']
end = conf['end']

# Get info for the src and dest streams
src_client = nilmdb.client.numpyclient.NumpyClient(url)
src = nilmtools.filter.get_stream_info(src_client, src_path)
if not src:
raise DataError("source path '" + src_path + "' does not exist")
src_columns = build_column_mapping(conf['columns'], src)

dest_client = nilmdb.client.numpyclient.NumpyClient(url)
dest = nilmtools.filter.get_stream_info(dest_client, dest_path)
if not dest:
raise DataError("destination path '" + dest_path + "' does not exist")

printf("Source:\n")
printf(" %s [%s]\n", src.path, ",".join(src_columns.keys()))
printf("Destination:\n")
printf(" %s (%s columns)\n", dest.path, dest.layout_count)

# Pull in the exemplar data
exemplars = []
for n, e in enumerate(conf['exemplars']):
print sprintf("Loading exemplar %d: %s", n, e['name'])
ex = Data(e['name'], e['url'], e['stream'],
e['start'], e['end'], e['columns'])
ex.fetch()
exemplars.append(ex)
for n, exinfo in enumerate(conf['exemplars']):
printf("Loading exemplar %d:\n", n)
e = Exemplar(exinfo)
col = e.dest_column
if col < 0 or col >= dest.layout_count:
raise DataError(sprintf("bad destination column number %d\n" +
"dest stream only has 0 through %d",
col, dest.layout_count - 1))
printf(" %s, output column %d\n", str(e), col)
exemplars.append(e)
if len(exemplars) == 0:
raise DataError("missing exemplars")

# Verify that the exemplar columns are all represented in the main data
for n, ex in enumerate(exemplars):
for col in ex.columns:
if col not in main.columns:
if col not in src_columns:
raise DataError(sprintf("Exemplar %d column %s is not "
"available in main data", n, col))

# Process the main data
process(main, match, (main.columns, exemplars))
"available in source data", n, col))

# Process the data in a piecewise manner

# # See which intervals we should processs

# intervals = ( Interval(start, end)
# for (start, end) in
# self._client_src.stream_intervals(
# self.src.path, diffpath = self.dest.path,
# start = self.start, end = self.end) )
# # Optimize intervals: join intervals that are adjacent
# for interval in self.optimize_intervals(intervals):
# yield interval

# def process(main, function, args = None, rows = 200000):
# """Process through the data; similar to nilmtools.Filter.process_numpy"""
# if args is None:
# args = []

# extractor = main.client.stream_extract_numpy
# old_array = np.array([])
# for new_array in extractor(main.stream, main.start, main.end,
# layout = main.info.layout, maxrows = rows):
# # If we still had old data left, combine it
# if old_array.shape[0] != 0:
# array = np.vstack((old_array, new_array))
# else:
# array = new_array

# # Process it
# processed = function(array, args)

# # Save the unprocessed parts
# if processed >= 0:
# old_array = array[processed:]
# else:
# raise Exception(sprintf("%s return value %s must be >= 0",
# str(function), str(processed)))

# # Warn if there's too much data remaining
# if old_array.shape[0] > 3 * rows:
# printf("warning: %d unprocessed rows in buffer\n",
# old_array.shape[0])

# # Handle leftover data
# if old_array.shape[0] != 0:
# processed = function(array, args)



# process(src, src_client, dest, dest_client, dest, main, match, (main.columns, exemplars))

return "done"

filterfunc = trainola

def main(argv = None):
import simplejson as json
import argparse
import sys

if argv is None:
argv = sys.argv[1:]

# If the first parameter is a dictionary (passed in by direct call),
# don't both parsing it as a JSON string.
if len(argv) == 1 and isinstance(argv[0], dict):
return trainola(argv[0])

# Parse command line arguments as text
parser = argparse.ArgumentParser(
formatter_class = argparse.RawDescriptionHelpFormatter,
version = nilmrun.__version__,
version = nilmtools.__version__,
description = """Run Trainola using parameters passed in as
JSON-formatted data.""")
parser.add_argument("file", metavar="FILE", nargs="?",
type=argparse.FileType('r'), default=sys.stdin)
parser.add_argument("data", metavar="DATA",
help="Arguments, formatted as a JSON string")
args = parser.parse_args(argv)

conf = json.loads(args.file.read())
result = trainola(conf)
print json.dumps(result, sort_keys = True, indent = 2 * ' ')
conf = json.loads(args.data)
return trainola(conf)

if __name__ == "__main__":
main()


Loading…
Cancel
Save