Browse Source

More trainola work

tags/nilmtools-1.3.1
Jim Paris 11 years ago
parent
commit
d610deaef0
2 changed files with 32 additions and 72 deletions
  1. +1
    -0
      Makefile
  2. +31
    -72
      nilmtools/trainola.py

+ 1
- 0
Makefile View File

@@ -9,6 +9,7 @@ else
endif

test:
-nilmtool -u http://bucket/nilmdb remove -s min -e max /sharon/prep-a-matches
make -C ../nilmrun

test_trainola:


+ 31
- 72
nilmtools/trainola.py View File

@@ -15,6 +15,7 @@ from numpy.core.umath_tests import inner1d
import nilmrun
from collections import OrderedDict
import sys
import functools

class DataError(ValueError):
pass
@@ -115,7 +116,7 @@ def peak_detect(data, delta):
lookformax = True
return (mins, maxs)

def trainola_matcher(data, args):
def trainola_matcher(data, interval, args, insert_func, final_chunk):
"""Perform cross-correlation match"""
( columns, exemplars ) = args
nrows = data.shape[0]
@@ -170,7 +171,7 @@ def trainola_matcher(data, args):
# Ignore matches that showed up at the very tail of the window,
# and shorten the window accordingly. This is an attempt to avoid
# problems at chunk boundaries.
if point > (valid - 50):
if point > (valid - 50) and not final_chunk:
valid -= 50
break
print "matched", data[point,0], "exemplar", exemplars[e_num].name
@@ -230,86 +231,44 @@ def trainola(conf):
raise DataError(sprintf("Exemplar %d column %s is not "
"available in source data", n, col))

# Process the data in a piecewise manner

# # See which intervals we should process

# intervals = ( Interval(start, end)
# for (start, end) in
# self._client_src.stream_intervals(
# self.src.path, diffpath = self.dest.path,
# start = self.start, end = self.end) )
# # Optimize intervals: join intervals that are adjacent
# for interval in self.optimize_intervals(intervals):
# yield interval

# def process(main, function, args = None, rows = 200000):
# """Process through the data; similar to nilmtools.Filter.process_numpy"""
# if args is None:
# args = []

# extractor = main.client.stream_extract_numpy
# old_array = np.array([])
# for new_array in extractor(main.stream, main.start, main.end,
# layout = main.info.layout, maxrows = rows):
# # If we still had old data left, combine it
# if old_array.shape[0] != 0:
# array = np.vstack((old_array, new_array))
# else:
# array = new_array

# # Process it
# processed = function(array, args)

# # Save the unprocessed parts
# if processed >= 0:
# old_array = array[processed:]
# else:
# raise Exception(sprintf("%s return value %s must be >= 0",
# str(function), str(processed)))

# # Warn if there's too much data remaining
# if old_array.shape[0] > 3 * rows:
# printf("warning: %d unprocessed rows in buffer\n",
# old_array.shape[0])

# # Handle leftover data
# if old_array.shape[0] != 0:
# processed = function(array, args)



# process(src, src_client, dest, dest_client, dest, main, match, (main.columns, exemplars))
# Figure out which intervals we should process
intervals = ( Interval(s, e) for (s, e) in
src_client.stream_intervals(src_path,
diffpath = dest_path,
start = start, end = end) )
intervals = nilmdb.utils.interval.optimize(intervals)

# Do the processing
rows = 100000
extractor = functools.partial(src_client.stream_extract_numpy,
src.path, layout = src.layout, maxrows = rows)
inserter = functools.partial(dest_client.stream_insert_numpy_context,
dest.path)
for interval in intervals:
printf("Processing interval:\n")
printf(" %s\n", interval.human_string())
nilmtools.filter.process_numpy_interval(
interval, extractor, inserter, rows * 3,
trainola_matcher, (src_columns, exemplars))

return "done"

def main(argv = None):
    """Command-line / NilmRun entry point for Trainola.

    argv: either a one-element list containing the configuration
    dictionary itself (when invoked directly, e.g. from NilmRun), or
    command-line arguments holding a JSON-formatted configuration
    string.  Returns whatever trainola() returns.
    """
    import simplejson as json
    import argparse
    import sys

    if argv is None:
        argv = sys.argv[1:]

    # If the first parameter is a dictionary (passed in by direct call),
    # don't bother parsing it as a JSON string.
    if len(argv) == 1 and isinstance(argv[0], dict):
        return trainola(argv[0])

    # Parse command line arguments as text
    parser = argparse.ArgumentParser(
        formatter_class = argparse.RawDescriptionHelpFormatter,
        version = nilmtools.__version__,
        description = """Run Trainola using parameters passed in as
        JSON-formatted data.""")
    parser.add_argument("data", metavar="DATA",
                        help="Arguments, formatted as a JSON string")
    args = parser.parse_args(argv)

    conf = json.loads(args.data)
    return trainola(conf)

if __name__ == "__main__":


Loading…
Cancel
Save