Compare commits

...

59 Commits

Author SHA1 Message Date
2d45466f66 Print version at server startup 2013-03-04 15:43:45 -05:00
c6a0e6e96f More complete CORS handling, including preflight requests (hopefully) 2013-03-04 15:40:35 -05:00
79755dc624 Fix Allow: header by switching to cherrypy's built in tools.allow().
Replaces custom tools.allow_methods which didn't return the Allow: header.
2013-03-04 14:08:37 -05:00
c512631184 bulkdata: Build up rows and write to disk all at once 2013-03-03 12:03:44 -05:00
19d27c31bc Fix streaming requests like stream_extract 2013-03-03 11:37:47 -05:00
28310fe886 Add test for extents 2013-03-02 15:19:25 -05:00
1ccc2bce7e Add commandline support for listing extents 2013-03-02 15:19:19 -05:00
00237e30b2 Add "extent" option to stream_list in client, server, and nilmdb 2013-03-02 15:18:54 -05:00
521ff88f7c Support 'nilmtool help command' just like 'nilmtool command --help' 2013-03-02 13:56:03 -05:00
64897a1dd1 Change port from 12380 -> 32180 when running tests
This is so tests can be run without interfering with a normal server.
2013-03-02 13:19:44 -05:00
41ce8480bb cmdline: Support NILMDB_URL environment variable for default URL 2013-03-02 13:18:33 -05:00
204a6ecb15 Optimize bulkdata.append() by postponing flushes & mmap resize
Rather than flushing and resizing after each row is written to the
file, have the file object iterate by itself and do all of the
writes.  Only flush and resize the mmap after finishing.  This should
be pretty safe to do, especially since nothing is concurrent at the
moment.
2013-03-01 16:30:49 -05:00
5db3b186a4 Make test_mustclose more complete 2013-03-01 16:30:22 -05:00
fe640cf421 Remove must_close verification wrappers on bulkdata
At this point we know that the close() behavior is correct, so it's
not worth slowing everything down for these checks.
2013-03-01 16:11:44 -05:00
ca67c79fe4 Improve test_layout_speed 2013-03-01 16:04:10 -05:00
8917bcd4bf Fix test case failures due to increased client chunk size 2013-03-01 16:04:00 -05:00
a75ec98673 Slight speed improvements in layout.pyx 2013-03-01 16:03:38 -05:00
e476338d61 Remove outdated numpy dependency 2013-03-01 16:03:19 -05:00
d752b882f2 Bump up block sizes in client
This will help amortize the sqlite synchronization costs.
2013-02-28 21:11:57 -05:00
ade27773e6 Add --nosync option to nilmdb-server script 2013-02-28 20:45:08 -05:00
0c1a1d2388 Fix nilmdb-server script 2013-02-28 18:53:06 -05:00
e3f335dfe5 Move time parsing from cmdline into nilmdb.utils.time 2013-02-28 17:09:26 -05:00
7a191c0ebb Fix versioneer to update versions on install 2013-02-28 14:50:53 -05:00
55bf11e393 Fix error when pyximport is too old 2013-02-26 22:21:23 -05:00
e90dcd10f3 Update README and setup.py with python-requests dependency 2013-02-26 22:00:42 -05:00
7d44f4eaa0 Cleanup Makefile; make tests run through setup.py when outside emacs 2013-02-26 22:00:42 -05:00
f541432d44 Merge branch 'requests' 2013-02-26 21:59:15 -05:00
aa4e32f78a Merge branch 'curl-multi' 2013-02-26 21:59:03 -05:00
2bc1416c00 Merge branch 'fixups' 2013-02-26 21:58:55 -05:00
68bbbf757d Remove nilmdb.utils.urllib
python-requests seems to handle UTF-8 just fine.
2013-02-26 19:46:22 -05:00
3df96fdfdd Reorder code 2013-02-26 19:41:55 -05:00
740ab76eaf Re-add persistent connection test for Requests based httpclient 2013-02-26 19:41:27 -05:00
ce13a47fea Save full response object for tests 2013-02-26 17:45:41 -05:00
50a4a60786 Replace pyCurl with Requests
Only tested with v1.1.0.  It's not clear how well older versions will
work.
2013-02-26 17:45:40 -05:00
14afa02db6 Temporarily remove curl-specific keepalive tests 2013-02-26 17:45:40 -05:00
cc990d6ce4 Test persistent connections 2013-02-26 13:41:40 -05:00
0f5162e0c0 Always use the curl multi interface
.. even for non-generator requests
2013-02-26 13:39:33 -05:00
b26cd52f8c Work around curl multi bug 2013-02-26 13:38:42 -05:00
236d925a1d Make sure we use POST when requested, even if the body is empty 2013-02-25 21:05:01 -05:00
a4a4bc61ba Switch to using pycurl.Multi instead of Iteratorizer 2013-02-25 21:05:01 -05:00
3d82888580 Enforce method types, and require POST for actions that change things.
This is a pretty big change that will render existing clients unable
to modify the database, but it's important that we use POST or PUT
instead of GET for anything that may change state, in case this
is ever put behind a cache.
2013-02-25 21:05:01 -05:00
749b878904 Add an explicit lock to httpclient's public methods
This is to prevent possible reentrancy problems.
2013-02-25 18:06:00 -05:00
f396e3934c Remove cherrypy version check
Dependencies should be handled by installation, not at runtime.
2013-02-25 16:50:19 -05:00
dd7594b5fa Fix issue where PUT responses were being dropped
PUTs generate a "HTTP/1.1 100 Continue" response before the
"HTTP/1.1 200 OK" response, and so we were mistakenly picking up
the 100 status code and not returning any data.  Improve the
header callback to correctly process any number of status codes.
2013-02-23 17:51:59 -05:00
4ac1beee6d layout: allow zero and negative timestamps in parser 2013-02-23 16:58:49 -05:00
8c0ce736d8 Disable use of signals in Curl
Various places suggest that this is needed for better thread-safety,
and the only drawback is that some systems cannot timeout properly on
DNS lookups.
2013-02-23 16:15:28 -05:00
8858c9426f Fix error message text in nilmdb.server.Server 2013-02-23 16:13:47 -05:00
9123ccb583 Merge branch 'decorator-work' 2013-02-23 14:38:36 -05:00
5b0441de6b Give serializer and iteratorizer threads names 2013-02-23 14:28:37 -05:00
317c53ab6f Improve serializer_proxy and verify_thread_proxy
These functions can now take an object or a type (class).

If given an object, they will wrap subsequent calls to that object.
If given a type, they will return an object that can be instantiated
to create a new object, and all calls including __init__ will be
covered by the serialization or thread verification.
2013-02-23 14:28:37 -05:00
7db4411462 Cleanup nilmdb.utils.must_close a bit 2013-02-23 11:28:03 -05:00
422317850e Replace threadsafety class decorator version, add explicit proxy version
Like the serializer changes, the class decorator was too fragile.
2013-02-23 11:25:40 -05:00
965537d8cb Implement verify_thread_safety to check for unsafe access patterns
Occasional segfaults may be the result of performing thread-unsafe
operations.  This class decorator verifies that all of its methods
are called in a thread-safe manner.

It can separately warn about:
- two threads calling methods in a function (the kind of thing sqlite
  doesn't like)
- recursion
- concurrency (two different threads functions at the same time)
2013-02-23 11:25:02 -05:00
0dcdec5949 Turn on sqlite thread safety checks -- serializer should fully protect it 2013-02-23 11:25:01 -05:00
7fce305a1d Make server check that the db object has been wrapped in a serializer
It's only the server that calls it in multiple threads.
2013-02-23 11:25:01 -05:00
dfbbe23512 Switch to explicitly wrapping nilmdb objects in a serializer_proxy
This is quite a bit simpler than the class decorator method, so it
may be more reliable.
2013-02-23 11:23:54 -05:00
7761a91242 Remove class decorator version of the serializer; it's too fragile 2013-02-23 11:23:54 -05:00
9b06e46bf1 Add back a proxy version of the Serializer, which is much simpler. 2013-02-23 11:23:54 -05:00
171e6f1871 Replace "serializer" function with a "serialized" decorator
This decorator makes a class always be serialized, including its
instantiation, in a separate thread.  This is an improvement over
the old Serializer() object wrapper, which didn't put the
instantiation into the new thread.
2013-02-23 11:23:54 -05:00
39 changed files with 904 additions and 560 deletions

View File

@@ -21,7 +21,13 @@ lint:
pylint --rcfile=.pylintrc nilmdb pylint --rcfile=.pylintrc nilmdb
test: test:
ifeq ($(INSIDE_EMACS), t)
# Use the slightly more flexible script
python tests/runtests.py python tests/runtests.py
else
# Let setup.py check dependencies, build stuff, and run the test
python setup.py nosetests
endif
clean:: clean::
find . -name '*pyc' | xargs rm -f find . -name '*pyc' | xargs rm -f
@@ -33,4 +39,4 @@ clean::
gitclean:: gitclean::
git clean -dXf git clean -dXf
.PHONY: all build dist sdist install docs lint test clean .PHONY: all version build dist sdist install docs lint test clean

View File

@@ -7,11 +7,15 @@ Prerequisites:
sudo apt-get install python2.7 python2.7-dev python-setuptools cython sudo apt-get install python2.7 python2.7-dev python-setuptools cython
# Base NilmDB dependencies # Base NilmDB dependencies
sudo apt-get install python-cherrypy3 python-decorator python-simplejson python-pycurl python-dateutil python-tz python-psutil sudo apt-get install python-cherrypy3 python-decorator python-simplejson
sudo apt-get install python-requests python-dateutil python-tz python-psutil
# Tools for running tests # Tools for running tests
sudo apt-get install python-nose python-coverage sudo apt-get install python-nose python-coverage
Test:
python setup.py nosetests
Install: Install:
python setup.py install python setup.py install

View File

@@ -52,12 +52,14 @@ class Client(object):
as a dictionary.""" as a dictionary."""
return self.http.get("dbinfo") return self.http.get("dbinfo")
def stream_list(self, path = None, layout = None): def stream_list(self, path = None, layout = None, extent = False):
params = {} params = {}
if path is not None: if path is not None:
params["path"] = path params["path"] = path
if layout is not None: if layout is not None:
params["layout"] = layout params["layout"] = layout
if extent:
params["extent"] = 1
return self.http.get("stream/list", params) return self.http.get("stream/list", params)
def stream_get_metadata(self, path, keys = None): def stream_get_metadata(self, path, keys = None):
@@ -73,7 +75,7 @@ class Client(object):
"path": path, "path": path,
"data": self._json_param(data) "data": self._json_param(data)
} }
return self.http.get("stream/set_metadata", params) return self.http.post("stream/set_metadata", params)
def stream_update_metadata(self, path, data): def stream_update_metadata(self, path, data):
"""Update stream metadata from a dictionary""" """Update stream metadata from a dictionary"""
@@ -81,18 +83,18 @@ class Client(object):
"path": path, "path": path,
"data": self._json_param(data) "data": self._json_param(data)
} }
return self.http.get("stream/update_metadata", params) return self.http.post("stream/update_metadata", params)
def stream_create(self, path, layout): def stream_create(self, path, layout):
"""Create a new stream""" """Create a new stream"""
params = { "path": path, params = { "path": path,
"layout" : layout } "layout" : layout }
return self.http.get("stream/create", params) return self.http.post("stream/create", params)
def stream_destroy(self, path): def stream_destroy(self, path):
"""Delete stream and its contents""" """Delete stream and its contents"""
params = { "path": path } params = { "path": path }
return self.http.get("stream/destroy", params) return self.http.post("stream/destroy", params)
def stream_remove(self, path, start = None, end = None): def stream_remove(self, path, start = None, end = None):
"""Remove data from the specified time range""" """Remove data from the specified time range"""
@@ -103,7 +105,7 @@ class Client(object):
params["start"] = float_to_string(start) params["start"] = float_to_string(start)
if end is not None: if end is not None:
params["end"] = float_to_string(end) params["end"] = float_to_string(end)
return self.http.get("stream/remove", params) return self.http.post("stream/remove", params)
@contextlib.contextmanager @contextlib.contextmanager
def stream_insert_context(self, path, start = None, end = None): def stream_insert_context(self, path, start = None, end = None):
@@ -156,7 +158,7 @@ class Client(object):
params["start"] = float_to_string(start) params["start"] = float_to_string(start)
if end is not None: if end is not None:
params["end"] = float_to_string(end) params["end"] = float_to_string(end)
return self.http.get_gen("stream/intervals", params, retjson = True) return self.http.get_gen("stream/intervals", params)
def stream_extract(self, path, start = None, end = None, count = False): def stream_extract(self, path, start = None, end = None, count = False):
""" """
@@ -176,8 +178,7 @@ class Client(object):
params["end"] = float_to_string(end) params["end"] = float_to_string(end)
if count: if count:
params["count"] = 1 params["count"] = 1
return self.http.get_gen("stream/extract", params)
return self.http.get_gen("stream/extract", params, retjson = False)
def stream_count(self, path, start = None, end = None): def stream_count(self, path, start = None, end = None):
""" """
@@ -221,7 +222,7 @@ class StreamInserter(object):
# These are soft limits -- actual data might be rounded up. # These are soft limits -- actual data might be rounded up.
# We send when we have a certain amount of data queued, or # We send when we have a certain amount of data queued, or
# when a certain amount of time has passed since the last send. # when a certain amount of time has passed since the last send.
_max_data = 1048576 _max_data = 2 * 1024 * 1024
_max_time = 30 _max_time = 30
# Delta to add to the final timestamp, if "end" wasn't given # Delta to add to the final timestamp, if "end" wasn't given

View File

@@ -6,8 +6,7 @@ from nilmdb.client.errors import ClientError, ServerError, Error
import simplejson as json import simplejson as json
import urlparse import urlparse
import pycurl import requests
import cStringIO
class HTTPClient(object): class HTTPClient(object):
"""Class to manage and perform HTTP requests from the client""" """Class to manage and perform HTTP requests from the client"""
@@ -19,40 +18,19 @@ class HTTPClient(object):
if '://' not in reparsed: if '://' not in reparsed:
reparsed = urlparse.urlparse("http://" + baseurl).geturl() reparsed = urlparse.urlparse("http://" + baseurl).geturl()
self.baseurl = reparsed self.baseurl = reparsed
self.curl = pycurl.Curl()
self.curl.setopt(pycurl.SSL_VERIFYHOST, 2)
self.curl.setopt(pycurl.FOLLOWLOCATION, 1)
self.curl.setopt(pycurl.MAXREDIRS, 5)
self._setup_url()
def _setup_url(self, url = "", params = ""): # Build Requests session object, enable SSL verification
url = urlparse.urljoin(self.baseurl, url) self.session = requests.Session()
if params: self.session.verify = True
url = urlparse.urljoin(
url, "?" + nilmdb.utils.urllib.urlencode(params))
self.curl.setopt(pycurl.URL, url)
self.url = url
def _check_busy_and_set_upload(self, upload): # Saved response, so that tests can verify a few things.
"""Sets the pycurl.UPLOAD option, but also raises a more self._last_response = {}
friendly exception if the client is already serving a request."""
try:
self.curl.setopt(pycurl.UPLOAD, upload)
except pycurl.error as e:
if "is currently running" in str(e):
raise Exception("Client is already performing a request, and "
"nesting calls is not supported.")
else: # pragma: no cover (shouldn't happen)
raise
def _check_error(self, body = None): def _handle_error(self, url, code, body):
code = self.curl.getinfo(pycurl.RESPONSE_CODE)
if code == 200:
return
# Default variables for exception. We use the entire body as # Default variables for exception. We use the entire body as
# the default message, in case we can't extract it from a JSON # the default message, in case we can't extract it from a JSON
# response. # response.
args = { "url" : self.url, args = { "url" : url,
"status" : str(code), "status" : str(code),
"message" : body, "message" : body,
"traceback" : None } "traceback" : None }
@@ -76,133 +54,68 @@ class HTTPClient(object):
else: else:
raise Error(**args) raise Error(**args)
def _req_generator(self, url, params):
"""
Like self._req(), but runs the perform in a separate thread.
It returns a generator that spits out arbitrary-sized chunks
of the resulting data, instead of using the WRITEFUNCTION
callback.
"""
self._setup_url(url, params)
self._status = None
error_body = ""
self._headers = ""
def header_callback(data):
if self._status is None:
self._status = int(data.split(" ")[1])
self._headers += data
self.curl.setopt(pycurl.HEADERFUNCTION, header_callback)
def perform(callback):
self.curl.setopt(pycurl.WRITEFUNCTION, callback)
self.curl.perform()
try:
with nilmdb.utils.Iteratorizer(perform, curl_hack = True) as it:
for i in it:
if self._status == 200:
# If we had a 200 response, yield the data to caller.
yield i
else:
# Otherwise, collect it into an error string.
error_body += i
except pycurl.error as e:
raise ServerError(status = "502 Error",
url = self.url,
message = e[1])
# Raise an exception if there was an error
self._check_error(error_body)
def _req(self, url, params):
"""
GET or POST that returns raw data. Returns the body
data as a string, or raises an error if it contained an error.
"""
self._setup_url(url, params)
body = cStringIO.StringIO()
self.curl.setopt(pycurl.WRITEFUNCTION, body.write)
self._headers = ""
def header_callback(data):
self._headers += data
self.curl.setopt(pycurl.HEADERFUNCTION, header_callback)
try:
self.curl.perform()
except pycurl.error as e:
raise ServerError(status = "502 Error",
url = self.url,
message = e[1])
body_str = body.getvalue()
# Raise an exception if there was an error
self._check_error(body_str)
return body_str
def close(self): def close(self):
self.curl.close() self.session.close()
def _iterate_lines(self, it): def _do_req(self, method, url, query_data, body_data, stream):
url = urlparse.urljoin(self.baseurl, url)
try:
response = self.session.request(method, url,
params = query_data,
data = body_data,
stream = stream)
except requests.RequestException as e:
raise ServerError(status = "502 Error", url = url,
message = str(e.message))
if response.status_code != 200:
self._handle_error(url, response.status_code, response.content)
self._last_response = response
if response.headers["content-type"] in ("application/json",
"application/x-json-stream"):
return (response, True)
else:
return (response, False)
# Normal versions that return data directly
def _req(self, method, url, query = None, body = None):
""" """
Given an iterator that returns arbitrarily-sized chunks Make a request and return the body data as a string or parsed
of data, return '\n'-delimited lines of text JSON object, or raise an error if it contained an error.
""" """
partial = "" (response, isjson) = self._do_req(method, url, query, body, False)
for chunk in it: if isjson:
partial += chunk return json.loads(response.content)
lines = partial.split("\n") return response.content
for line in lines[0:-1]:
yield line
partial = lines[-1]
if partial != "":
yield partial
# Non-generator versions def get(self, url, params = None):
def _doreq(self, url, params, retjson): """Simple GET (parameters in URL)"""
return self._req("GET", url, params, None)
def post(self, url, params = None):
"""Simple POST (parameters in body)"""
return self._req("POST", url, None, params)
def put(self, url, data, params = None):
"""Simple PUT (parameters in URL, data in body)"""
return self._req("PUT", url, params, data)
# Generator versions that return data one line at a time.
def _req_gen(self, method, url, query = None, body = None):
""" """
Perform a request, and return the body. Make a request and return a generator that gives back strings
or JSON decoded lines of the body data, or raise an error if
url: URL to request (relative to baseurl) it contained an eror.
params: dictionary of query parameters
retjson: expect JSON and return python objects instead of string
""" """
out = self._req(url, params) (response, isjson) = self._do_req(method, url, query, body, True)
if retjson: for line in response.iter_lines():
return json.loads(out) if isjson:
return out
def get(self, url, params = None, retjson = True):
"""Simple GET"""
self._check_busy_and_set_upload(0)
return self._doreq(url, params, retjson)
def put(self, url, postdata, params = None, retjson = True):
"""Simple PUT"""
self._check_busy_and_set_upload(1)
self._setup_url(url, params)
data = cStringIO.StringIO(postdata)
self.curl.setopt(pycurl.READFUNCTION, data.read)
return self._doreq(url, params, retjson)
# Generator versions
def _doreq_gen(self, url, params, retjson):
"""
Perform a request, and return lines of the body in a generator.
url: URL to request (relative to baseurl)
params: dictionary of query parameters
retjson: expect JSON and yield python objects instead of strings
"""
for line in self._iterate_lines(self._req_generator(url, params)):
if retjson:
yield json.loads(line) yield json.loads(line)
else: else:
yield line yield line
def get_gen(self, url, params = None, retjson = True): def get_gen(self, url, params = None):
"""Simple GET, returning a generator""" """Simple GET (parameters in URL) returning a generator"""
self._check_busy_and_set_upload(0) return self._req_gen("GET", url, params)
return self._doreq_gen(url, params, retjson)
def put_gen(self, url, postdata, params = None, retjson = True): # Not much use for a POST or PUT generator, since they don't
"""Simple PUT, returning a generator""" # return much data.
self._check_busy_and_set_upload(1)
self._setup_url(url, params)
data = cStringIO.StringIO(postdata)
self.curl.setopt(pycurl.READFUNCTION, data.read)
return self._doreq_gen(url, params, retjson)

View File

@@ -3,16 +3,17 @@
import nilmdb import nilmdb
from nilmdb.utils.printf import * from nilmdb.utils.printf import *
from nilmdb.utils import datetime_tz from nilmdb.utils import datetime_tz
import nilmdb.utils.time
import sys import sys
import re import os
import argparse import argparse
from argparse import ArgumentDefaultsHelpFormatter as def_form from argparse import ArgumentDefaultsHelpFormatter as def_form
# Valid subcommands. Defined in separate files just to break # Valid subcommands. Defined in separate files just to break
# things up -- they're still called with Cmdline as self. # things up -- they're still called with Cmdline as self.
subcommands = [ "info", "create", "list", "metadata", "insert", "extract", subcommands = [ "help", "info", "create", "list", "metadata",
"remove", "destroy" ] "insert", "extract", "remove", "destroy" ]
# Import the subcommand modules # Import the subcommand modules
subcmd_mods = {} subcmd_mods = {}
@@ -29,67 +30,17 @@ class Cmdline(object):
def __init__(self, argv = None): def __init__(self, argv = None):
self.argv = argv or sys.argv[1:] self.argv = argv or sys.argv[1:]
self.client = None self.client = None
self.def_url = os.environ.get("NILMDB_URL", "http://localhost:12380")
self.subcmd = {}
def arg_time(self, toparse): def arg_time(self, toparse):
"""Parse a time string argument""" """Parse a time string argument"""
try: try:
return self.parse_time(toparse).totimestamp() return nilmdb.utils.time.parse_time(toparse).totimestamp()
except ValueError as e: except ValueError as e:
raise argparse.ArgumentTypeError(sprintf("%s \"%s\"", raise argparse.ArgumentTypeError(sprintf("%s \"%s\"",
str(e), toparse)) str(e), toparse))
def parse_time(self, toparse):
"""
Parse a free-form time string and return a datetime_tz object.
If the string doesn't contain a timestamp, the current local
timezone is assumed (e.g. from the TZ env var).
"""
# If string isn't "now" and doesn't contain at least 4 digits,
# consider it invalid. smartparse might otherwise accept
# empty strings and strings with just separators.
if toparse != "now" and len(re.findall(r"\d", toparse)) < 4:
raise ValueError("not enough digits for a timestamp")
# Try to just parse the time as given
try:
return datetime_tz.datetime_tz.smartparse(toparse)
except ValueError:
pass
# Try to extract a substring in a condensed format that we expect
# to see in a filename or header comment
res = re.search(r"(^|[^\d])(" # non-numeric or SOL
r"(199\d|2\d\d\d)" # year
r"[-/]?" # separator
r"(0[1-9]|1[012])" # month
r"[-/]?" # separator
r"([012]\d|3[01])" # day
r"[-T ]?" # separator
r"([01]\d|2[0-3])" # hour
r"[:]?" # separator
r"([0-5]\d)" # minute
r"[:]?" # separator
r"([0-5]\d)?" # second
r"([-+]\d\d\d\d)?" # timezone
r")", toparse)
if res is not None:
try:
return datetime_tz.datetime_tz.smartparse(res.group(2))
except ValueError:
pass
# Could also try to successively parse substrings, but let's
# just give up for now.
raise ValueError("unable to parse timestamp")
def time_string(self, timestamp):
"""
Convert a Unix timestamp to a string for printing, using the
local timezone for display (e.g. from the TZ env var).
"""
dt = datetime_tz.datetime_tz.fromtimestamp(timestamp)
return dt.strftime("%a, %d %b %Y %H:%M:%S.%f %z")
def parser_setup(self): def parser_setup(self):
self.parser = JimArgumentParser(add_help = False, self.parser = JimArgumentParser(add_help = False,
formatter_class = def_form) formatter_class = def_form)
@@ -102,18 +53,17 @@ class Cmdline(object):
group = self.parser.add_argument_group("Server") group = self.parser.add_argument_group("Server")
group.add_argument("-u", "--url", action="store", group.add_argument("-u", "--url", action="store",
default="http://localhost:12380/", default=self.def_url,
help="NilmDB server URL (default: %(default)s)") help="NilmDB server URL (default: %(default)s)")
sub = self.parser.add_subparsers(title="Commands", sub = self.parser.add_subparsers(
dest="command", title="Commands", dest="command",
description="Specify --help after " description="Use 'help command' or 'command --help' for more "
"the command for command-specific " "details on a particular command.")
"options.")
# Set up subcommands (defined in separate files) # Set up subcommands (defined in separate files)
for cmd in subcommands: for cmd in subcommands:
subcmd_mods[cmd].setup(self, sub) self.subcmd[cmd] = subcmd_mods[cmd].setup(self, sub)
def die(self, formatstr, *args): def die(self, formatstr, *args):
fprintf(sys.stderr, formatstr + "\n", *args) fprintf(sys.stderr, formatstr + "\n", *args)
@@ -136,11 +86,13 @@ class Cmdline(object):
self.client = nilmdb.Client(self.args.url) self.client = nilmdb.Client(self.args.url)
# Make a test connection to make sure things work # Make a test connection to make sure things work,
try: # unless the particular command requests that we don't.
server_version = self.client.version() if "no_test_connect" not in self.args:
except nilmdb.client.Error as e: try:
self.die("error connecting to server: %s", str(e)) server_version = self.client.version()
except nilmdb.client.Error as e:
self.die("error connecting to server: %s", str(e))
# Now dispatch client request to appropriate function. Parser # Now dispatch client request to appropriate function. Parser
# should have ensured that we don't have any unknown commands # should have ensured that we don't have any unknown commands

View File

@@ -26,6 +26,7 @@ Layout types are of the format: type_count
help="Path (in database) of new stream, e.g. /foo/bar") help="Path (in database) of new stream, e.g. /foo/bar")
group.add_argument("layout", group.add_argument("layout",
help="Layout type for new stream, e.g. float32_8") help="Layout type for new stream, e.g. float32_8")
return cmd
def cmd_create(self): def cmd_create(self):
"""Create new stream""" """Create new stream"""

View File

@@ -16,6 +16,7 @@ def setup(self, sub):
group = cmd.add_argument_group("Required arguments") group = cmd.add_argument_group("Required arguments")
group.add_argument("path", group.add_argument("path",
help="Path of the stream to delete, e.g. /foo/bar") help="Path of the stream to delete, e.g. /foo/bar")
return cmd
def cmd_destroy(self): def cmd_destroy(self):
"""Destroy stream""" """Destroy stream"""

View File

@@ -30,6 +30,7 @@ def setup(self, sub):
help="Show raw timestamps in annotated information") help="Show raw timestamps in annotated information")
group.add_argument("-c", "--count", action="store_true", group.add_argument("-c", "--count", action="store_true",
help="Just output a count of matched data points") help="Just output a count of matched data points")
return cmd
def cmd_extract_verify(self): def cmd_extract_verify(self):
if self.args.start is not None and self.args.end is not None: if self.args.start is not None and self.args.end is not None:
@@ -45,7 +46,7 @@ def cmd_extract(self):
if self.args.timestamp_raw: if self.args.timestamp_raw:
time_string = repr time_string = repr
else: else:
time_string = self.time_string time_string = nilmdb.utils.time.format_time
if self.args.annotate: if self.args.annotate:
printf("# path: %s\n", self.args.path) printf("# path: %s\n", self.args.path)

26
nilmdb/cmdline/help.py Normal file
View File

@@ -0,0 +1,26 @@
from nilmdb.utils.printf import *
import argparse
import sys
def setup(self, sub):
cmd = sub.add_parser("help", help="Show detailed help for a command",
description="""
Show help for a command. 'help command' is
the same as 'command --help'.
""")
cmd.set_defaults(handler = cmd_help)
cmd.set_defaults(no_test_connect = True)
cmd.add_argument("command", nargs="?",
help="Command to get help about")
cmd.add_argument("rest", nargs=argparse.REMAINDER,
help=argparse.SUPPRESS)
return cmd
def cmd_help(self):
if self.args.command in self.subcmd:
self.subcmd[self.args.command].print_help()
else:
self.parser.print_help()
return

View File

@@ -12,6 +12,7 @@ def setup(self, sub):
version. version.
""") """)
cmd.set_defaults(handler = cmd_info) cmd.set_defaults(handler = cmd_info)
return cmd
def cmd_info(self): def cmd_info(self):
"""Print info about the server""" """Print info about the server"""

View File

@@ -2,6 +2,7 @@ from nilmdb.utils.printf import *
import nilmdb import nilmdb
import nilmdb.client import nilmdb.client
import nilmdb.utils.timestamper as timestamper import nilmdb.utils.timestamper as timestamper
import nilmdb.utils.time
import sys import sys
@@ -46,6 +47,7 @@ def setup(self, sub):
help="Path of stream, e.g. /foo/bar") help="Path of stream, e.g. /foo/bar")
group.add_argument("file", nargs="*", default=['-'], group.add_argument("file", nargs="*", default=['-'],
help="File(s) to insert (default: - (stdin))") help="File(s) to insert (default: - (stdin))")
return cmd
def cmd_insert(self): def cmd_insert(self):
# Find requested stream # Find requested stream
@@ -73,7 +75,7 @@ def cmd_insert(self):
start = self.args.start start = self.args.start
else: else:
try: try:
start = self.parse_time(filename) start = nilmdb.utils.time.parse_time(filename)
except ValueError: except ValueError:
self.die("error extracting time from filename '%s'", self.die("error extracting time from filename '%s'",
filename) filename)

View File

@@ -1,4 +1,5 @@
from nilmdb.utils.printf import * from nilmdb.utils.printf import *
import nilmdb.utils.time
import fnmatch import fnmatch
import argparse import argparse
@@ -23,11 +24,13 @@ def setup(self, sub):
group.add_argument("-l", "--layout", default="*", group.add_argument("-l", "--layout", default="*",
help="Match only this stream layout") help="Match only this stream layout")
group = cmd.add_argument_group("Interval extent")
group.add_argument("-E", "--extent", action="store_true",
help="Show min/max timestamps in this stream")
group = cmd.add_argument_group("Interval details") group = cmd.add_argument_group("Interval details")
group.add_argument("-d", "--detail", action="store_true", group.add_argument("-d", "--detail", action="store_true",
help="Show available data time intervals") help="Show available data time intervals")
group.add_argument("-T", "--timestamp-raw", action="store_true",
help="Show raw timestamps in time intervals")
group.add_argument("-s", "--start", group.add_argument("-s", "--start",
metavar="TIME", type=self.arg_time, metavar="TIME", type=self.arg_time,
help="Starting timestamp (free-form, inclusive)") help="Starting timestamp (free-form, inclusive)")
@@ -35,6 +38,12 @@ def setup(self, sub):
metavar="TIME", type=self.arg_time, metavar="TIME", type=self.arg_time,
help="Ending timestamp (free-form, noninclusive)") help="Ending timestamp (free-form, noninclusive)")
group = cmd.add_argument_group("Misc options")
group.add_argument("-T", "--timestamp-raw", action="store_true",
help="Show raw timestamps in time intervals or extents")
return cmd
def cmd_list_verify(self): def cmd_list_verify(self):
# A hidden "path_positional" argument lets the user leave off the # A hidden "path_positional" argument lets the user leave off the
# "-p" when specifying the path. Handle it here. # "-p" when specifying the path. Handle it here.
@@ -50,28 +59,38 @@ def cmd_list_verify(self):
if self.args.start >= self.args.end: if self.args.start >= self.args.end:
self.parser.error("start must precede end") self.parser.error("start must precede end")
if self.args.start is not None or self.args.end is not None:
if not self.args.detail:
self.parser.error("--start and --end only make sense with --detail")
def cmd_list(self): def cmd_list(self):
"""List available streams""" """List available streams"""
streams = self.client.stream_list() streams = self.client.stream_list(extent = True)
if self.args.timestamp_raw: if self.args.timestamp_raw:
time_string = repr time_string = repr
else: else:
time_string = self.time_string time_string = nilmdb.utils.time.format_time
for (path, layout) in streams: for (path, layout, extent_min, extent_max) in streams:
if not (fnmatch.fnmatch(path, self.args.path) and if not (fnmatch.fnmatch(path, self.args.path) and
fnmatch.fnmatch(layout, self.args.layout)): fnmatch.fnmatch(layout, self.args.layout)):
continue continue
printf("%s %s\n", path, layout) printf("%s %s\n", path, layout)
if not self.args.detail:
continue
printed = False if self.args.extent:
for (start, end) in self.client.stream_intervals(path, self.args.start, if extent_min is None or extent_max is None:
self.args.end): printf(" extent: (no data)\n")
printf(" [ %s -> %s ]\n", time_string(start), time_string(end)) else:
printed = True printf(" extent: %s -> %s\n",
if not printed: time_string(extent_min), time_string(extent_max))
printf(" (no intervals)\n")
if self.args.detail:
printed = False
for (start, end) in self.client.stream_intervals(
path, self.args.start, self.args.end):
printf(" [ %s -> %s ]\n", time_string(start), time_string(end))
printed = True
if not printed:
printf(" (no intervals)\n")

View File

@@ -26,6 +26,7 @@ def setup(self, sub):
exc.add_argument("-u", "--update", nargs="+", metavar="key=value", exc.add_argument("-u", "--update", nargs="+", metavar="key=value",
help="Update metadata using provided " help="Update metadata using provided "
"key=value pairs") "key=value pairs")
return cmd
def cmd_metadata(self): def cmd_metadata(self):
"""Manipulate metadata""" """Manipulate metadata"""

View File

@@ -23,6 +23,7 @@ def setup(self, sub):
group = cmd.add_argument_group("Output format") group = cmd.add_argument_group("Output format")
group.add_argument("-c", "--count", action="store_true", group.add_argument("-c", "--count", action="store_true",
help="Output number of data points removed") help="Output number of data points removed")
return cmd
def cmd_remove(self): def cmd_remove(self):
try: try:

View File

@@ -25,6 +25,9 @@ def main():
default = os.path.join(os.getcwd(), "db")) default = os.path.join(os.getcwd(), "db"))
group.add_argument('-q', '--quiet', help = 'Silence output', group.add_argument('-q', '--quiet', help = 'Silence output',
action = 'store_true') action = 'store_true')
group.add_argument('-n', '--nosync', help = 'Use asynchronous '
'commits for sqlite transactions',
action = 'store_true', default = False)
group = parser.add_argument_group("Debug options") group = parser.add_argument_group("Debug options")
group.add_argument('-y', '--yappi', help = 'Run under yappi profiler and ' group.add_argument('-y', '--yappi', help = 'Run under yappi profiler and '
@@ -33,8 +36,10 @@ def main():
args = parser.parse_args() args = parser.parse_args()
# Create database object # Create database object. Needs to be serialized before passing
db = nilmdb.server.NilmDB(args.database) # to the Server.
db = nilmdb.utils.serializer_proxy(nilmdb.NilmDB)(args.database,
sync = not args.nosync)
# Configure the server # Configure the server
if args.quiet: if args.quiet:
@@ -48,6 +53,7 @@ def main():
# Print info # Print info
if not args.quiet: if not args.quiet:
print "Version: %s" % nilmdb.__version__
print "Database: %s" % (os.path.realpath(args.database)) print "Database: %s" % (os.path.realpath(args.database))
if args.address == '0.0.0.0' or args.address == '::': if args.address == '0.0.0.0' or args.address == '::':
host = socket.getfqdn() host = socket.getfqdn()

View File

@@ -5,15 +5,15 @@ from __future__ import absolute_import
# Try to set up pyximport to automatically rebuild Cython modules. If # Try to set up pyximport to automatically rebuild Cython modules. If
# this doesn't work, it's OK, as long as the modules were built externally. # this doesn't work, it's OK, as long as the modules were built externally.
# (e.g. python setup.py build_ext --inplace) # (e.g. python setup.py build_ext --inplace)
try: try: # pragma: no cover
import Cython import Cython
import distutils.version import distutils.version
if (distutils.version.LooseVersion(Cython.__version__) < if (distutils.version.LooseVersion(Cython.__version__) <
distutils.version.LooseVersion("0.16")): # pragma: no cover distutils.version.LooseVersion("0.17")): # pragma: no cover
raise ImportError("Cython version too old") raise ImportError("Cython version too old")
import pyximport import pyximport
pyximport.install(inplace = True, build_in_temp = False) pyximport.install(inplace = True, build_in_temp = False)
except ImportError: # pragma: no cover except (ImportError, TypeError): # pragma: no cover
pass pass
import nilmdb.server.layout import nilmdb.server.layout

View File

@@ -28,7 +28,7 @@ except: # pragma: no cover
table_cache_size = 16 table_cache_size = 16
fd_cache_size = 16 fd_cache_size = 16
@nilmdb.utils.must_close(wrap_verify = True) @nilmdb.utils.must_close(wrap_verify = False)
class BulkData(object): class BulkData(object):
def __init__(self, basepath, **kwargs): def __init__(self, basepath, **kwargs):
self.basepath = basepath self.basepath = basepath
@@ -171,7 +171,7 @@ class BulkData(object):
ospath = os.path.join(self.root, *elements) ospath = os.path.join(self.root, *elements)
return Table(ospath) return Table(ospath)
@nilmdb.utils.must_close(wrap_verify = True) @nilmdb.utils.must_close(wrap_verify = False)
class File(object): class File(object):
"""Object representing a single file on disk. Data can be appended, """Object representing a single file on disk. Data can be appended,
or the self.mmap handle can be used for random reads.""" or the self.mmap handle can be used for random reads."""
@@ -210,14 +210,28 @@ class File(object):
self.mmap.close() self.mmap.close()
self._f.close() self._f.close()
def append(self, data): def append(self, data): # pragma: no cover (below version used instead)
# Write data, flush it, and resize our mmap accordingly # Write data, flush it, and resize our mmap accordingly
self._f.write(data) self._f.write(data)
self._f.flush() self._f.flush()
self.size += len(data) self.size += len(data)
self._mmap_reopen() self._mmap_reopen()
@nilmdb.utils.must_close(wrap_verify = True) def append_pack_iter(self, count, packer, dataiter):
# An optimized verison of append, to avoid flushing the file
# and resizing the mmap after each data point.
try:
rows = []
for i in xrange(count):
row = dataiter.next()
rows.append(packer(*row))
self._f.write("".join(rows))
finally:
self._f.flush()
self.size = self._f.tell()
self._mmap_reopen()
@nilmdb.utils.must_close(wrap_verify = False)
class Table(object): class Table(object):
"""Tools to help access a single table (data at a specific OS path).""" """Tools to help access a single table (data at a specific OS path)."""
# See design.md for design details # See design.md for design details
@@ -351,9 +365,7 @@ class Table(object):
f = self.file_open(subdir, fname) f = self.file_open(subdir, fname)
# Write the data # Write the data
for i in xrange(count): f.append_pack_iter(count, self.packer.pack, dataiter)
row = dataiter.next()
f.append(self.packer.pack(*row))
remaining -= count remaining -= count
self.nrows += count self.nrows += count

View File

@@ -4,7 +4,6 @@ import time
import sys import sys
import inspect import inspect
import cStringIO import cStringIO
import numpy as np
cdef enum: cdef enum:
max_value_count = 64 max_value_count = 64
@@ -42,10 +41,12 @@ class Layout:
if datatype == 'uint16': if datatype == 'uint16':
self.parse = self.parse_uint16 self.parse = self.parse_uint16
self.format = self.format_uint16 self.format_str = "%.6f" + " %d" * self.count
self.format = self.format_generic
elif datatype == 'float32' or datatype == 'float64': elif datatype == 'float32' or datatype == 'float64':
self.parse = self.parse_float64 self.parse = self.parse_float64
self.format = self.format_float64 self.format_str = "%.6f" + " %f" * self.count
self.format = self.format_generic
else: else:
raise KeyError("invalid type") raise KeyError("invalid type")
@@ -57,15 +58,15 @@ class Layout:
cdef double ts cdef double ts
# Return doubles even in float32 case, since they're going into # Return doubles even in float32 case, since they're going into
# a Python array which would upconvert to double anyway. # a Python array which would upconvert to double anyway.
result = [] result = [0] * (self.count + 1)
cdef char *end cdef char *end
ts = libc.stdlib.strtod(text, &end) ts = libc.stdlib.strtod(text, &end)
if end == text: if end == text:
raise ValueError("bad timestamp") raise ValueError("bad timestamp")
result.append(ts) result[0] = ts
for n in range(self.count): for n in range(self.count):
text = end text = end
result.append(libc.stdlib.strtod(text, &end)) result[n+1] = libc.stdlib.strtod(text, &end)
if end == text: if end == text:
raise ValueError("wrong number of values") raise ValueError("wrong number of values")
n = 0 n = 0
@@ -79,18 +80,18 @@ class Layout:
cdef int n cdef int n
cdef double ts cdef double ts
cdef int v cdef int v
result = []
cdef char *end cdef char *end
result = [0] * (self.count + 1)
ts = libc.stdlib.strtod(text, &end) ts = libc.stdlib.strtod(text, &end)
if end == text: if end == text:
raise ValueError("bad timestamp") raise ValueError("bad timestamp")
result.append(ts) result[0] = ts
for n in range(self.count): for n in range(self.count):
text = end text = end
v = libc.stdlib.strtol(text, &end, 10) v = libc.stdlib.strtol(text, &end, 10)
if v < 0 or v > 65535: if v < 0 or v > 65535:
raise ValueError("value out of range") raise ValueError("value out of range")
result.append(v) result[n+1] = v
if end == text: if end == text:
raise ValueError("wrong number of values") raise ValueError("wrong number of values")
n = 0 n = 0
@@ -101,25 +102,12 @@ class Layout:
return (ts, result) return (ts, result)
# Formatters # Formatters
def format_float64(self, d): def format_generic(self, d):
n = len(d) - 1 n = len(d) - 1
if n != self.count: if n != self.count:
raise ValueError("wrong number of values for layout type: " raise ValueError("wrong number of values for layout type: "
"got %d, wanted %d" % (n, self.count)) "got %d, wanted %d" % (n, self.count))
s = "%.6f" % d[0] return (self.format_str % tuple(d)) + "\n"
for i in range(n):
s += " %f" % d[i+1]
return s + "\n"
def format_uint16(self, d):
n = len(d) - 1
if n != self.count:
raise ValueError("wrong number of values for layout type: "
"got %d, wanted %d" % (n, self.count))
s = "%.6f" % d[0]
for i in range(n):
s += " %d" % d[i+1]
return s + "\n"
# Get a layout by name # Get a layout by name
def get_named(typestring): def get_named(typestring):
@@ -154,7 +142,7 @@ class Parser(object):
layout, into an internal data structure suitable for a layout, into an internal data structure suitable for a
pytables 'table.append(parser.data)'. pytables 'table.append(parser.data)'.
""" """
cdef double last_ts = 0, ts cdef double last_ts = -1e12, ts
cdef int n = 0, i cdef int n = 0, i
cdef char *line cdef char *line

View File

@@ -97,12 +97,7 @@ class NilmDB(object):
# SQLite database too # SQLite database too
sqlfilename = os.path.join(self.basepath, "data.sql") sqlfilename = os.path.join(self.basepath, "data.sql")
# We use check_same_thread = False, assuming that the rest self.con = sqlite3.connect(sqlfilename, check_same_thread = True)
# of the code (e.g. Server) will be smart and not access this
# database from multiple threads simultaneously. Otherwise
# false positives will occur when the database is only opened
# in one thread, and only accessed in another.
self.con = sqlite3.connect(sqlfilename, check_same_thread = False)
self._sql_schema_update() self._sql_schema_update()
# See big comment at top about the performance implications of this # See big comment at top about the performance implications of this
@@ -274,28 +269,39 @@ class NilmDB(object):
return return
def stream_list(self, path = None, layout = None): def stream_list(self, path = None, layout = None, extent = False):
"""Return list of [path, layout] lists of all streams """Return list of lists of all streams in the database.
in the database.
If path is specified, include only streams with a path that If path is specified, include only streams with a path that
matches the given string. matches the given string.
If layout is specified, include only streams with a layout If layout is specified, include only streams with a layout
that matches the given string. that matches the given string.
"""
where = "WHERE 1=1"
params = ()
if layout:
where += " AND layout=?"
params += (layout,)
if path:
where += " AND path=?"
params += (path,)
result = self.con.execute("SELECT path, layout "
"FROM streams " + where, params).fetchall()
return sorted(list(x) for x in result) If extent = False, returns a list of lists containing
the path and layout: [ path, layout ]
If extent = True, returns a list of lists containing the
path, layout, and min/max extent of the data:
[ path, layout, extent_min, extent_max ]
"""
params = ()
query = "SELECT streams.path, streams.layout"
if extent:
query += ", min(ranges.start_time), max(ranges.end_time)"
query += " FROM streams"
if extent:
query += " LEFT JOIN ranges ON streams.id = ranges.stream_id"
query += " WHERE 1=1"
if layout is not None:
query += " AND streams.layout=?"
params += (layout,)
if path is not None:
query += " AND streams.path=?"
params += (path,)
query += " GROUP BY streams.id ORDER BY streams.path"
result = self.con.execute(query, params).fetchall()
return [ list(x) for x in result ]
def stream_intervals(self, path, start = None, end = None): def stream_intervals(self, path, start = None, end = None):
""" """

View File

@@ -15,12 +15,6 @@ import decorator
import traceback import traceback
import psutil import psutil
try:
cherrypy.tools.json_out
except: # pragma: no cover
sys.stderr.write("Cherrypy 3.2+ required\n")
sys.exit(1)
class NilmApp(object): class NilmApp(object):
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
@@ -77,6 +71,52 @@ def exception_to_httperror(*expected):
# care of that. # care of that.
return decorator.decorator(wrapper) return decorator.decorator(wrapper)
# Custom CherryPy tools
def CORS_allow(methods):
"""This does several things:
Handles CORS preflight requests.
Adds Allow: header to all requests.
Raise 405 if request.method not in method.
It is similar to cherrypy.tools.allow, with the CORS stuff added.
"""
request = cherrypy.request.headers
response = cherrypy.response.headers
if not isinstance(methods, (tuple, list)): # pragma: no cover
methods = [ methods ]
methods = [ m.upper() for m in methods if m ]
if not methods: # pragma: no cover
methods = [ 'GET', 'HEAD' ]
elif 'GET' in methods and 'HEAD' not in methods: # pragma: no cover
methods.append('HEAD')
response['Allow'] = ', '.join(methods)
# Allow all origins
if 'Origin' in request:
response['Access-Control-Allow-Origin'] = request['Origin']
# If it's a CORS request, send response.
request_method = request.get("Access-Control-Request-Method", None)
request_headers = request.get("Access-Control-Request-Headers", None)
if (cherrypy.request.method == "OPTIONS" and
request_method and request_headers):
response['Access-Control-Allow-Headers'] = request_headers
response['Access-Control-Allow-Methods'] = ', '.join(methods)
# Try to stop further processing and return a 200 OK
cherrypy.response.status = "200 OK"
cherrypy.response.body = ""
cherrypy.request.handler = lambda: ""
return
# Reject methods that were not explicitly allowed
if cherrypy.request.method not in methods:
raise cherrypy.HTTPError(405)
cherrypy.tools.CORS_allow = cherrypy.Tool('on_start_resource', CORS_allow)
# CherryPy apps # CherryPy apps
class Root(NilmApp): class Root(NilmApp):
"""Root application for NILM database""" """Root application for NILM database"""
@@ -116,19 +156,28 @@ class Stream(NilmApp):
# /stream/list # /stream/list
# /stream/list?layout=PrepData # /stream/list?layout=PrepData
# /stream/list?path=/newton/prep # /stream/list?path=/newton/prep&extent=1
@cherrypy.expose @cherrypy.expose
@cherrypy.tools.json_out() @cherrypy.tools.json_out()
def list(self, path = None, layout = None): def list(self, path = None, layout = None, extent = None):
"""List all streams in the database. With optional path or """List all streams in the database. With optional path or
layout parameter, just list streams that match the given path layout parameter, just list streams that match the given path
or layout""" or layout.
return self.db.stream_list(path, layout)
If extent is not given, returns a list of lists containing
the path and layout: [ path, layout ]
If extent is provided, returns a list of lists containing the
path, layout, and min/max extent of the data:
[ path, layout, extent_min, extent_max ]
"""
return self.db.stream_list(path, layout, bool(extent))
# /stream/create?path=/newton/prep&layout=PrepData # /stream/create?path=/newton/prep&layout=PrepData
@cherrypy.expose @cherrypy.expose
@cherrypy.tools.json_out() @cherrypy.tools.json_out()
@exception_to_httperror(NilmDBError, ValueError) @exception_to_httperror(NilmDBError, ValueError)
@cherrypy.tools.CORS_allow(methods = ["POST"])
def create(self, path, layout): def create(self, path, layout):
"""Create a new stream in the database. Provide path """Create a new stream in the database. Provide path
and one of the nilmdb.layout.layouts keys. and one of the nilmdb.layout.layouts keys.
@@ -139,6 +188,7 @@ class Stream(NilmApp):
@cherrypy.expose @cherrypy.expose
@cherrypy.tools.json_out() @cherrypy.tools.json_out()
@exception_to_httperror(NilmDBError) @exception_to_httperror(NilmDBError)
@cherrypy.tools.CORS_allow(methods = ["POST"])
def destroy(self, path): def destroy(self, path):
"""Delete a stream and its associated data.""" """Delete a stream and its associated data."""
return self.db.stream_destroy(path) return self.db.stream_destroy(path)
@@ -171,6 +221,7 @@ class Stream(NilmApp):
@cherrypy.expose @cherrypy.expose
@cherrypy.tools.json_out() @cherrypy.tools.json_out()
@exception_to_httperror(NilmDBError, LookupError, TypeError) @exception_to_httperror(NilmDBError, LookupError, TypeError)
@cherrypy.tools.CORS_allow(methods = ["POST"])
def set_metadata(self, path, data): def set_metadata(self, path, data):
"""Set metadata for the named stream, replacing any """Set metadata for the named stream, replacing any
existing metadata. Data should be a json-encoded existing metadata. Data should be a json-encoded
@@ -182,6 +233,7 @@ class Stream(NilmApp):
@cherrypy.expose @cherrypy.expose
@cherrypy.tools.json_out() @cherrypy.tools.json_out()
@exception_to_httperror(NilmDBError, LookupError, TypeError) @exception_to_httperror(NilmDBError, LookupError, TypeError)
@cherrypy.tools.CORS_allow(methods = ["POST"])
def update_metadata(self, path, data): def update_metadata(self, path, data):
"""Update metadata for the named stream. Data """Update metadata for the named stream. Data
should be a json-encoded dictionary""" should be a json-encoded dictionary"""
@@ -191,7 +243,7 @@ class Stream(NilmApp):
# /stream/insert?path=/newton/prep # /stream/insert?path=/newton/prep
@cherrypy.expose @cherrypy.expose
@cherrypy.tools.json_out() @cherrypy.tools.json_out()
#@cherrypy.tools.disable_prb() @cherrypy.tools.CORS_allow(methods = ["PUT"])
def insert(self, path, start, end): def insert(self, path, start, end):
""" """
Insert new data into the database. Provide textual data Insert new data into the database. Provide textual data
@@ -199,12 +251,9 @@ class Stream(NilmApp):
""" """
# Important that we always read the input before throwing any # Important that we always read the input before throwing any
# errors, to keep lengths happy for persistent connections. # errors, to keep lengths happy for persistent connections.
# However, CherryPy 3.2.2 has a bug where this fails for GET # Note that CherryPy 3.2.2 has a bug where this fails for GET
# requests, so catch that. (issue #1134) # requests, if we ever want to handle those (issue #1134)
try: body = cherrypy.request.body.read()
body = cherrypy.request.body.read()
except TypeError:
raise cherrypy.HTTPError("400 Bad Request", "No request body")
# Check path and get layout # Check path and get layout
streams = self.db.stream_list(path = path) streams = self.db.stream_list(path = path)
@@ -250,6 +299,7 @@ class Stream(NilmApp):
@cherrypy.expose @cherrypy.expose
@cherrypy.tools.json_out() @cherrypy.tools.json_out()
@exception_to_httperror(NilmDBError) @exception_to_httperror(NilmDBError)
@cherrypy.tools.CORS_allow(methods = ["POST"])
def remove(self, path, start = None, end = None): def remove(self, path, start = None, end = None):
""" """
Remove data from the backend database. Removes all data in Remove data from the backend database. Removes all data in
@@ -270,17 +320,16 @@ class Stream(NilmApp):
# /stream/intervals?path=/newton/prep&start=1234567890.0&end=1234567899.0 # /stream/intervals?path=/newton/prep&start=1234567890.0&end=1234567899.0
@cherrypy.expose @cherrypy.expose
@chunked_response @chunked_response
@response_type("text/plain") @response_type("application/x-json-stream")
def intervals(self, path, start = None, end = None): def intervals(self, path, start = None, end = None):
""" """
Get intervals from backend database. Streams the resulting Get intervals from backend database. Streams the resulting
intervals as JSON strings separated by newlines. This may intervals as JSON strings separated by CR LF pairs. This may
make multiple requests to the nilmdb backend to avoid causing make multiple requests to the nilmdb backend to avoid causing
it to block for too long. it to block for too long.
Note that the response type is set to 'text/plain' even Note that the response type is the non-standard
though we're sending back JSON; this is because we're not 'application/x-json-stream' for lack of a better option.
really returning a single JSON object.
""" """
if start is not None: if start is not None:
start = float(start) start = float(start)
@@ -300,8 +349,8 @@ class Stream(NilmApp):
def content(start, end): def content(start, end):
# Note: disable chunked responses to see tracebacks from here. # Note: disable chunked responses to see tracebacks from here.
while True: while True:
(intervals, restart) = self.db.stream_intervals(path, start, end) (ints, restart) = self.db.stream_intervals(path, start, end)
response = ''.join([ json.dumps(i) + "\n" for i in intervals ]) response = ''.join([ json.dumps(i) + "\r\n" for i in ints ])
yield response yield response
if restart == 0: if restart == 0:
break break
@@ -381,17 +430,20 @@ class Server(object):
# Save server version, just for verification during tests # Save server version, just for verification during tests
self.version = nilmdb.__version__ self.version = nilmdb.__version__
# Need to wrap DB object in a serializer because we'll call
# into it from separate threads.
self.embedded = embedded self.embedded = embedded
self.db = nilmdb.utils.Serializer(db) self.db = db
if not getattr(db, "_thread_safe", None):
raise KeyError("Database object " + str(db) + " doesn't claim "
"to be thread safe. You should pass "
"nilmdb.utils.serializer_proxy(NilmDB)(args) "
"rather than NilmDB(args).")
# Build up global server configuration # Build up global server configuration
cherrypy.config.update({ cherrypy.config.update({
'server.socket_host': host, 'server.socket_host': host,
'server.socket_port': port, 'server.socket_port': port,
'engine.autoreload_on': False, 'engine.autoreload_on': False,
'server.max_request_body_size': 4*1024*1024, 'server.max_request_body_size': 8*1024*1024,
}) })
if self.embedded: if self.embedded:
cherrypy.config.update({ 'environment': 'embedded' }) cherrypy.config.update({ 'environment': 'embedded' })
@@ -402,11 +454,14 @@ class Server(object):
'error_page.default': self.json_error_page, 'error_page.default': self.json_error_page,
}) })
# Send a permissive Access-Control-Allow-Origin (CORS) header # Some default headers to just help identify that things are working
# with all responses so that browsers can send cross-domain app_config.update({ 'response.headers.X-Jim-Is-Awesome': 'yeah' })
# requests to this server.
app_config.update({ 'response.headers.Access-Control-Allow-Origin': # Set up Cross-Origin Resource Sharing (CORS) handler so we
'*' }) # can correctly respond to browsers' CORS preflight requests.
# This also limits verbs to GET and HEAD by default.
app_config.update({ 'tools.CORS_allow.on': True,
'tools.CORS_allow.methods': ['GET', 'HEAD'] })
# Send tracebacks in error responses. They're hidden by the # Send tracebacks in error responses. They're hidden by the
# error_page function for client errors (code 400-499). # error_page function for client errors (code 400-499).

View File

@@ -2,9 +2,9 @@
from nilmdb.utils.timer import Timer from nilmdb.utils.timer import Timer
from nilmdb.utils.iteratorizer import Iteratorizer from nilmdb.utils.iteratorizer import Iteratorizer
from nilmdb.utils.serializer import Serializer from nilmdb.utils.serializer import serializer_proxy
from nilmdb.utils.lrucache import lru_cache from nilmdb.utils.lrucache import lru_cache
from nilmdb.utils.diskusage import du, human_size from nilmdb.utils.diskusage import du, human_size
from nilmdb.utils.mustclose import must_close from nilmdb.utils.mustclose import must_close
from nilmdb.utils.urllib import urlencode
from nilmdb.utils import atomic from nilmdb.utils import atomic
import nilmdb.utils.threadsafety

View File

@@ -16,6 +16,7 @@ class IteratorizerThread(threading.Thread):
callback (provided by this class) as an argument callback (provided by this class) as an argument
""" """
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.name = "Iteratorizer-" + function.__name__ + "-" + self.name
self.function = function self.function = function
self.queue = queue self.queue = queue
self.die = False self.die = False

View File

@@ -12,15 +12,12 @@ def must_close(errorfile = sys.stderr, wrap_verify = False):
already been called.""" already been called."""
def class_decorator(cls): def class_decorator(cls):
# Helper to replace a class method with a wrapper function, def wrap_class_method(wrapper):
# while maintaining argument specs etc. try:
def wrap_class_method(wrapper_func): orig = getattr(cls, wrapper.__name__).im_func
method = wrapper_func.__name__ except:
if method in cls.__dict__: orig = lambda x: None
orig = getattr(cls, method).im_func setattr(cls, wrapper.__name__, decorator.decorator(wrapper, orig))
else:
orig = lambda self: None
setattr(cls, method, decorator.decorator(wrapper_func, orig))
@wrap_class_method @wrap_class_method
def __init__(orig, self, *args, **kwargs): def __init__(orig, self, *args, **kwargs):

View File

@@ -1,6 +1,10 @@
import Queue import Queue
import threading import threading
import sys import sys
import decorator
import inspect
import types
import functools
# This file provides a class that will wrap an object and serialize # This file provides a class that will wrap an object and serialize
# all calls to its methods. All calls to that object will be queued # all calls to its methods. All calls to that object will be queued
@@ -12,8 +16,9 @@ import sys
class SerializerThread(threading.Thread): class SerializerThread(threading.Thread):
"""Thread that retrieves call information from the queue, makes the """Thread that retrieves call information from the queue, makes the
call, and returns the results.""" call, and returns the results."""
def __init__(self, call_queue): def __init__(self, classname, call_queue):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.name = "Serializer-" + classname + "-" + self.name
self.call_queue = call_queue self.call_queue = call_queue
def run(self): def run(self):
@@ -22,51 +27,83 @@ class SerializerThread(threading.Thread):
# Terminate if result_queue is None # Terminate if result_queue is None
if result_queue is None: if result_queue is None:
return return
exception = None
result = None
try: try:
result = func(*args, **kwargs) # wrapped result = func(*args, **kwargs) # wrapped
except: except:
result_queue.put((sys.exc_info(), None)) exception = sys.exc_info()
# Ensure we delete these before returning a result, so
# we don't unncessarily hold onto a reference while
# we're waiting for the next call.
del func, args, kwargs
result_queue.put((exception, result))
del exception, result
def serializer_proxy(obj_or_type):
"""Wrap the given object or type in a SerializerObjectProxy.
Returns a SerializerObjectProxy object that proxies all method
calls to the object, as well as attribute retrievals.
The proxied requests, including instantiation, are performed in a
single thread and serialized between caller threads.
"""
class SerializerCallProxy(object):
def __init__(self, call_queue, func, objectproxy):
self.call_queue = call_queue
self.func = func
# Need to hold a reference to object proxy so it doesn't
# go away (and kill the thread) until after get called.
self.objectproxy = objectproxy
def __call__(self, *args, **kwargs):
result_queue = Queue.Queue()
self.call_queue.put((result_queue, self.func, args, kwargs))
( exc_info, result ) = result_queue.get()
if exc_info is None:
return result
else: else:
result_queue.put((None, result)) raise exc_info[0], exc_info[1], exc_info[2]
class WrapCall(object): class SerializerObjectProxy(object):
"""Wrap a callable using the given queues""" def __init__(self, obj_or_type, *args, **kwargs):
self.__object = obj_or_type
try:
if type(obj_or_type) in (types.TypeType, types.ClassType):
classname = obj_or_type.__name__
else:
classname = obj_or_type.__class__.__name__
except AttributeError: # pragma: no cover
classname = "???"
self.__call_queue = Queue.Queue()
self.__thread = SerializerThread(classname, self.__call_queue)
self.__thread.daemon = True
self.__thread.start()
self._thread_safe = True
def __init__(self, call_queue, result_queue, func): def __getattr__(self, key):
self.call_queue = call_queue if key.startswith("_SerializerObjectProxy__"): # pragma: no cover
self.result_queue = result_queue raise AttributeError
self.func = func attr = getattr(self.__object, key)
if not callable(attr):
getter = SerializerCallProxy(self.__call_queue, getattr, self)
return getter(self.__object, key)
r = SerializerCallProxy(self.__call_queue, attr, self)
return r
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
self.call_queue.put((self.result_queue, self.func, args, kwargs)) """Call this to instantiate the type, if a type was passed
( exc_info, result ) = self.result_queue.get() to serializer_proxy. Otherwise, pass the call through."""
if exc_info is None: ret = SerializerCallProxy(self.__call_queue,
return result self.__object, self)(*args, **kwargs)
else: if type(self.__object) in (types.TypeType, types.ClassType):
raise exc_info[0], exc_info[1], exc_info[2] # Instantiation
self.__object = ret
return self
return ret
class WrapObject(object): def __del__(self):
"""Wrap all calls to methods in a target object with WrapCall""" self.__call_queue.put((None, None, None, None))
self.__thread.join()
def __init__(self, target): return SerializerObjectProxy(obj_or_type)
self.__wrap_target = target
self.__wrap_call_queue = Queue.Queue()
self.__wrap_serializer = SerializerThread(self.__wrap_call_queue)
self.__wrap_serializer.daemon = True
self.__wrap_serializer.start()
def __getattr__(self, key):
"""Wrap methods of self.__wrap_target in a WrapCall instance"""
func = getattr(self.__wrap_target, key)
if not callable(func):
raise TypeError("Can't serialize attribute %r (type: %s)"
% (key, type(func)))
result_queue = Queue.Queue()
return WrapCall(self.__wrap_call_queue, result_queue, func)
def __del__(self):
self.__wrap_call_queue.put((None, None, None, None))
self.__wrap_serializer.join()
# Just an alias
Serializer = WrapObject

View File

@@ -0,0 +1,109 @@
from nilmdb.utils.printf import *
import threading
import warnings
import types
def verify_proxy(obj_or_type, exception = False, check_thread = True,
check_concurrent = True):
"""Wrap the given object or type in a VerifyObjectProxy.
Returns a VerifyObjectProxy that proxies all method calls to the
given object, as well as attribute retrievals.
When calling methods, the following checks are performed. If
exception is True, an exception is raised. Otherwise, a warning
is printed.
check_thread = True # Warn/fail if two different threads call methods.
check_concurrent = True # Warn/fail if two functions are concurrently
# run through this proxy
"""
class Namespace(object):
pass
class VerifyCallProxy(object):
def __init__(self, func, parent_namespace):
self.func = func
self.parent_namespace = parent_namespace
def __call__(self, *args, **kwargs):
p = self.parent_namespace
this = threading.current_thread()
try:
callee = self.func.__name__
except AttributeError:
callee = "???"
if p.thread is None:
p.thread = this
p.thread_callee = callee
if check_thread and p.thread != this:
err = sprintf("unsafe threading: %s called %s.%s,"
" but %s called %s.%s",
p.thread.name, p.classname, p.thread_callee,
this.name, p.classname, callee)
if exception:
raise AssertionError(err)
else: # pragma: no cover
warnings.warn(err)
need_concur_unlock = False
if check_concurrent:
if p.concur_lock.acquire(False) == False:
err = sprintf("unsafe concurrency: %s called %s.%s "
"while %s is still in %s.%s",
this.name, p.classname, callee,
p.concur_tname, p.classname, p.concur_callee)
if exception:
raise AssertionError(err)
else: # pragma: no cover
warnings.warn(err)
else:
p.concur_tname = this.name
p.concur_callee = callee
need_concur_unlock = True
try:
ret = self.func(*args, **kwargs)
finally:
if need_concur_unlock:
p.concur_lock.release()
return ret
class VerifyObjectProxy(object):
def __init__(self, obj_or_type, *args, **kwargs):
p = Namespace()
self.__ns = p
p.thread = None
p.thread_callee = None
p.concur_lock = threading.Lock()
p.concur_tname = None
p.concur_callee = None
self.__obj = obj_or_type
try:
if type(obj_or_type) in (types.TypeType, types.ClassType):
p.classname = self.__obj.__name__
else:
p.classname = self.__obj.__class__.__name__
except AttributeError: # pragma: no cover
p.classname = "???"
def __getattr__(self, key):
if key.startswith("_VerifyObjectProxy__"): # pragma: no cover
raise AttributeError
attr = getattr(self.__obj, key)
if not callable(attr):
return VerifyCallProxy(getattr, self.__ns)(self.__obj, key)
return VerifyCallProxy(attr, self.__ns)
def __call__(self, *args, **kwargs):
"""Call this to instantiate the type, if a type was passed
to verify_proxy. Otherwise, pass the call through."""
ret = VerifyCallProxy(self.__obj, self.__ns)(*args, **kwargs)
if type(self.__obj) in (types.TypeType, types.ClassType):
# Instantiation
self.__obj = ret
return self
return ret
return VerifyObjectProxy(obj_or_type)

54
nilmdb/utils/time.py Normal file
View File

@@ -0,0 +1,54 @@
from nilmdb.utils import datetime_tz
import re
def parse_time(toparse):
"""
Parse a free-form time string and return a datetime_tz object.
If the string doesn't contain a timestamp, the current local
timezone is assumed (e.g. from the TZ env var).
"""
# If string isn't "now" and doesn't contain at least 4 digits,
# consider it invalid. smartparse might otherwise accept
# empty strings and strings with just separators.
if toparse != "now" and len(re.findall(r"\d", toparse)) < 4:
raise ValueError("not enough digits for a timestamp")
# Try to just parse the time as given
try:
return datetime_tz.datetime_tz.smartparse(toparse)
except ValueError:
pass
# Try to extract a substring in a condensed format that we expect
# to see in a filename or header comment
res = re.search(r"(^|[^\d])(" # non-numeric or SOL
r"(199\d|2\d\d\d)" # year
r"[-/]?" # separator
r"(0[1-9]|1[012])" # month
r"[-/]?" # separator
r"([012]\d|3[01])" # day
r"[-T ]?" # separator
r"([01]\d|2[0-3])" # hour
r"[:]?" # separator
r"([0-5]\d)" # minute
r"[:]?" # separator
r"([0-5]\d)?" # second
r"([-+]\d\d\d\d)?" # timezone
r")", toparse)
if res is not None:
try:
return datetime_tz.datetime_tz.smartparse(res.group(2))
except ValueError:
pass
# Could also try to successively parse substrings, but let's
# just give up for now.
raise ValueError("unable to parse timestamp")
def format_time(timestamp):
"""
Convert a Unix timestamp to a string for printing, using the
local timezone for display (e.g. from the TZ env var).
"""
dt = datetime_tz.datetime_tz.fromtimestamp(timestamp)
return dt.strftime("%a, %d %b %Y %H:%M:%S.%f %z")

View File

@@ -6,6 +6,7 @@
# foo.flush() # foo.flush()
from __future__ import print_function from __future__ import print_function
from __future__ import absolute_import
import contextlib import contextlib
import time import time

View File

@@ -1,37 +0,0 @@
from __future__ import absolute_import
from urllib import quote_plus, _is_unicode
# urllib.urlencode insists on encoding Unicode as ASCII. This is based
# on that function, except we always encode it as UTF-8 instead.
def urlencode(query):
"""Encode a dictionary into a URL query string.
If any values in the query arg are sequences, each sequence
element is converted to a separate parameter.
"""
query = query.items()
l = []
for k, v in query:
k = quote_plus(str(k))
if isinstance(v, str):
v = quote_plus(v)
l.append(k + '=' + v)
elif _is_unicode(v):
v = quote_plus(v.encode("utf-8","strict"))
l.append(k + '=' + v)
else:
try:
# is this a sufficient test for sequence-ness?
len(v)
except TypeError:
# not a sequence
v = quote_plus(str(v))
l.append(k + '=' + v)
else:
# loop over the sequence
for elt in v:
l.append(k + '=' + quote_plus(str(elt)))
return '&'.join(l)

View File

@@ -20,6 +20,7 @@ cover-erase=1
stop=1 stop=1
verbosity=2 verbosity=2
tests=tests tests=tests
#tests=tests/test_threadsafety.py
#tests=tests/test_bulkdata.py #tests=tests/test_bulkdata.py
#tests=tests/test_mustclose.py #tests=tests/test_mustclose.py
#tests=tests/test_lrucache.py #tests=tests/test_lrucache.py

View File

@@ -115,6 +115,7 @@ setup(name='nilmdb',
'python-dateutil', 'python-dateutil',
'pytz', 'pytz',
'psutil >= 0.3.0', 'psutil >= 0.3.0',
'requests >= 1.1.0, < 2.0.0',
], ],
packages = [ 'nilmdb', packages = [ 'nilmdb',
'nilmdb.utils', 'nilmdb.utils',

View File

@@ -1,4 +1,5 @@
test_printf.py test_printf.py
test_threadsafety.py
test_lrucache.py test_lrucache.py
test_mustclose.py test_mustclose.py

View File

@@ -6,6 +6,7 @@ from nilmdb.utils import timestamper
from nilmdb.client import ClientError, ServerError from nilmdb.client import ClientError, ServerError
from nilmdb.utils import datetime_tz from nilmdb.utils import datetime_tz
from nose.plugins.skip import SkipTest
from nose.tools import * from nose.tools import *
from nose.tools import assert_raises from nose.tools import assert_raises
import itertools import itertools
@@ -19,11 +20,12 @@ import unittest
import warnings import warnings
import resource import resource
import time import time
import re
from testutil.helpers import * from testutil.helpers import *
testdb = "tests/client-testdb" testdb = "tests/client-testdb"
testurl = "http://localhost:12380/" testurl = "http://localhost:32180/"
def setup_module(): def setup_module():
global test_server, test_db global test_server, test_db
@@ -31,9 +33,9 @@ def setup_module():
recursive_unlink(testdb) recursive_unlink(testdb)
# Start web app on a custom port # Start web app on a custom port
test_db = nilmdb.NilmDB(testdb, sync = False) test_db = nilmdb.utils.serializer_proxy(nilmdb.NilmDB)(testdb, sync = False)
test_server = nilmdb.Server(test_db, host = "127.0.0.1", test_server = nilmdb.Server(test_db, host = "127.0.0.1",
port = 12380, stoppable = False, port = 32180, stoppable = False,
fast_shutdown = True, fast_shutdown = True,
force_traceback = False) force_traceback = False)
test_server.start(blocking = False) test_server.start(blocking = False)
@@ -53,20 +55,14 @@ class TestClient(object):
client.version() client.version()
client.close() client.close()
# Trigger same error with a PUT request
client = nilmdb.Client(url = "http://localhost:1/")
with assert_raises(nilmdb.client.ServerError):
client.version()
client.close()
# Then a fake URL on a real host # Then a fake URL on a real host
client = nilmdb.Client(url = "http://localhost:12380/fake/") client = nilmdb.Client(url = "http://localhost:32180/fake/")
with assert_raises(nilmdb.client.ClientError): with assert_raises(nilmdb.client.ClientError):
client.version() client.version()
client.close() client.close()
# Now a real URL with no http:// prefix # Now a real URL with no http:// prefix
client = nilmdb.Client(url = "localhost:12380") client = nilmdb.Client(url = "localhost:32180")
version = client.version() version = client.version()
client.close() client.close()
@@ -97,6 +93,15 @@ class TestClient(object):
with assert_raises(ClientError): with assert_raises(ClientError):
client.stream_create("/newton/prep", "NoSuchLayout") client.stream_create("/newton/prep", "NoSuchLayout")
# Bad method types
with assert_raises(ClientError):
client.http.put("/stream/list","")
# Try a bunch of times to make sure the request body is getting consumed
for x in range(10):
with assert_raises(ClientError):
client.http.post("/stream/list")
client = nilmdb.Client(url = testurl)
# Create three streams # Create three streams
client.stream_create("/newton/prep", "PrepData") client.stream_create("/newton/prep", "PrepData")
client.stream_create("/newton/raw", "RawData") client.stream_create("/newton/raw", "RawData")
@@ -162,6 +167,10 @@ class TestClient(object):
def test_client_04_insert(self): def test_client_04_insert(self):
client = nilmdb.Client(url = testurl) client = nilmdb.Client(url = testurl)
# Limit _max_data to 1 MB, since our test file is 1.5 MB
old_max_data = nilmdb.client.client.StreamInserter._max_data
nilmdb.client.client.StreamInserter._max_data = 1 * 1024 * 1024
datetime_tz.localtz_set("America/New_York") datetime_tz.localtz_set("America/New_York")
testfile = "tests/data/prep-20120323T1000" testfile = "tests/data/prep-20120323T1000"
@@ -234,8 +243,9 @@ class TestClient(object):
in_("400 Bad Request", str(e.exception)) in_("400 Bad Request", str(e.exception))
# Client chunks the input, so the exact timestamp here might change # Client chunks the input, so the exact timestamp here might change
# if the chunk positions change. # if the chunk positions change.
in_("Data timestamp 1332511271.016667 >= end time 1332511201.0", assert(re.search("Data timestamp 13325[0-9]+\.[0-9]+ "
str(e.exception)) ">= end time 1332511201.0", str(e.exception))
is not None)
# Now do the real load # Now do the real load
data = timestamper.TimestamperRate(testfile, start, 120) data = timestamper.TimestamperRate(testfile, start, 120)
@@ -254,6 +264,7 @@ class TestClient(object):
in_("400 Bad Request", str(e.exception)) in_("400 Bad Request", str(e.exception))
in_("verlap", str(e.exception)) in_("verlap", str(e.exception))
nilmdb.client.client.StreamInserter._max_data = old_max_data
client.close() client.close()
def test_client_05_extractremove(self): def test_client_05_extractremove(self):
@@ -266,12 +277,6 @@ class TestClient(object):
with assert_raises(ClientError) as e: with assert_raises(ClientError) as e:
client.stream_remove("/newton/prep", 123, 120) client.stream_remove("/newton/prep", 123, 120)
# Test the exception we get if we nest requests
with assert_raises(Exception) as e:
for data in client.stream_extract("/newton/prep"):
x = client.stream_intervals("/newton/prep")
in_("nesting calls is not supported", str(e.exception))
# Test count # Test count
eq_(client.stream_count("/newton/prep"), 14400) eq_(client.stream_count("/newton/prep"), 14400)
@@ -301,24 +306,6 @@ class TestClient(object):
with assert_raises(ServerError) as e: with assert_raises(ServerError) as e:
client.http.get_gen("http://nosuchurl/").next() client.http.get_gen("http://nosuchurl/").next()
# Check non-json version of string output
eq_(json.loads(client.http.get("/stream/list",retjson=False)),
client.http.get("/stream/list",retjson=True))
# Check non-json version of generator output
for (a, b) in itertools.izip(
client.http.get_gen("/stream/list",retjson=False),
client.http.get_gen("/stream/list",retjson=True)):
eq_(json.loads(a), b)
# Check PUT with generator out
with assert_raises(ClientError) as e:
client.http.put_gen("stream/insert", "",
{ "path": "/newton/prep",
"start": 0, "end": 0 }).next()
in_("400 Bad Request", str(e.exception))
in_("start must precede end", str(e.exception))
# Check 404 for missing streams # Check 404 for missing streams
for function in [ client.stream_intervals, client.stream_extract ]: for function in [ client.stream_intervals, client.stream_extract ]:
with assert_raises(ClientError) as e: with assert_raises(ClientError) as e:
@@ -337,35 +324,35 @@ class TestClient(object):
client = nilmdb.Client(url = testurl) client = nilmdb.Client(url = testurl)
http = client.http http = client.http
# Use a warning rather than returning a test failure, so that we can # Use a warning rather than returning a test failure for the
# still disable chunked responses for debugging. # transfer-encoding, so that we can still disable chunked
# responses for debugging.
def headers():
h = ""
for (k, v) in http._last_response.headers.items():
h += k + ": " + v + "\n"
return h.lower()
# Intervals # Intervals
x = http.get("stream/intervals", { "path": "/newton/prep" }, x = http.get("stream/intervals", { "path": "/newton/prep" })
retjson=False) if "transfer-encoding: chunked" not in headers():
lines_(x, 1)
if "Transfer-Encoding: chunked" not in http._headers:
warnings.warn("Non-chunked HTTP response for /stream/intervals") warnings.warn("Non-chunked HTTP response for /stream/intervals")
if "Content-Type: text/plain;charset=utf-8" not in http._headers: if "content-type: application/x-json-stream" not in headers():
raise AssertionError("/stream/intervals is not text/plain:\n" + raise AssertionError("/stream/intervals content type "
http._headers) "is not application/x-json-stream:\n" +
headers())
# Extract # Extract
x = http.get("stream/extract", x = http.get("stream/extract",
{ "path": "/newton/prep", { "path": "/newton/prep",
"start": "123", "start": "123",
"end": "124" }, retjson=False) "end": "124" })
if "Transfer-Encoding: chunked" not in http._headers: if "transfer-encoding: chunked" not in headers():
warnings.warn("Non-chunked HTTP response for /stream/extract") warnings.warn("Non-chunked HTTP response for /stream/extract")
if "Content-Type: text/plain;charset=utf-8" not in http._headers: if "content-type: text/plain;charset=utf-8" not in headers():
raise AssertionError("/stream/extract is not text/plain:\n" + raise AssertionError("/stream/extract is not text/plain:\n" +
http._headers) headers())
# Make sure Access-Control-Allow-Origin gets set
if "Access-Control-Allow-Origin: " not in http._headers:
raise AssertionError("No Access-Control-Allow-Origin (CORS) "
"header in /stream/extract response:\n" +
http._headers)
client.close() client.close()
@@ -576,3 +563,38 @@ class TestClient(object):
# Clean up # Clean up
client.stream_destroy("/empty/test") client.stream_destroy("/empty/test")
client.close() client.close()
def test_client_12_persistent(self):
# Check that connections are persistent when they should be.
# This is pretty hard to test; we have to poke deep into
# the Requests library.
with nilmdb.Client(url = testurl) as c:
def connections():
try:
poolmanager = c.http._last_response.connection.poolmanager
pool = poolmanager.pools[('http','localhost',32180)]
return (pool.num_connections, pool.num_requests)
except:
raise SkipTest("can't get connection info")
# First request makes a connection
c.stream_create("/persist/test", "uint16_1")
eq_(connections(), (1, 1))
# Non-generator
c.stream_list("/persist/test")
eq_(connections(), (1, 2))
c.stream_list("/persist/test")
eq_(connections(), (1, 3))
# Generators
for x in c.stream_intervals("/persist/test"):
pass
eq_(connections(), (1, 4))
for x in c.stream_intervals("/persist/test"):
pass
eq_(connections(), (1, 5))
# Clean up
c.stream_destroy("/persist/test")
eq_(connections(), (1, 6))

View File

@@ -11,12 +11,7 @@ from nose.tools import assert_raises
import itertools import itertools
import os import os
import re import re
import shutil
import sys import sys
import threading
import urllib2
from urllib2 import urlopen, HTTPError
import Queue
import StringIO import StringIO
import shlex import shlex
@@ -27,11 +22,12 @@ testdb = "tests/cmdline-testdb"
def server_start(max_results = None, bulkdata_args = {}): def server_start(max_results = None, bulkdata_args = {}):
global test_server, test_db global test_server, test_db
# Start web app on a custom port # Start web app on a custom port
test_db = nilmdb.NilmDB(testdb, sync = False, test_db = nilmdb.utils.serializer_proxy(nilmdb.NilmDB)(
max_results = max_results, testdb, sync = False,
bulkdata_args = bulkdata_args) max_results = max_results,
bulkdata_args = bulkdata_args)
test_server = nilmdb.Server(test_db, host = "127.0.0.1", test_server = nilmdb.Server(test_db, host = "127.0.0.1",
port = 12380, stoppable = False, port = 32180, stoppable = False,
fast_shutdown = True, fast_shutdown = True,
force_traceback = False) force_traceback = False)
test_server.start(blocking = False) test_server.start(blocking = False)
@@ -63,6 +59,7 @@ class TestCmdline(object):
passing the given input. Returns a tuple with the output and passing the given input. Returns a tuple with the output and
exit code""" exit code"""
# printf("TZ=UTC ./nilmtool.py %s\n", arg_string) # printf("TZ=UTC ./nilmtool.py %s\n", arg_string)
os.environ['NILMDB_URL'] = "http://localhost:32180/"
class stdio_wrapper: class stdio_wrapper:
def __init__(self, stdin, stdout, stderr): def __init__(self, stdin, stdout, stderr):
self.io = (stdin, stdout, stderr) self.io = (stdin, stdout, stderr)
@@ -162,18 +159,18 @@ class TestCmdline(object):
# try some URL constructions # try some URL constructions
self.fail("--url http://nosuchurl/ info") self.fail("--url http://nosuchurl/ info")
self.contain("Couldn't resolve host 'nosuchurl'") self.contain("error connecting to server")
self.fail("--url nosuchurl info") self.fail("--url nosuchurl info")
self.contain("Couldn't resolve host 'nosuchurl'") self.contain("error connecting to server")
self.fail("-u nosuchurl/foo info") self.fail("-u nosuchurl/foo info")
self.contain("Couldn't resolve host 'nosuchurl'") self.contain("error connecting to server")
self.fail("-u localhost:0 info") self.fail("-u localhost:1 info")
self.contain("couldn't connect to host") self.contain("error connecting to server")
self.ok("-u localhost:12380 info") self.ok("-u localhost:32180 info")
self.ok("info") self.ok("info")
# Duplicated arguments should fail, but this isn't implemented # Duplicated arguments should fail, but this isn't implemented
@@ -191,16 +188,46 @@ class TestCmdline(object):
self.fail("extract --start 2000-01-01 --start 2001-01-02") self.fail("extract --start 2000-01-01 --start 2001-01-02")
self.contain("duplicated argument") self.contain("duplicated argument")
def test_02_info(self): # Verify that "help command" and "command --help" are identical
# for all commands.
self.fail("")
m = re.search(r"{(.*)}", self.captured)
for command in [""] + m.group(1).split(','):
self.ok(command + " --help")
cap1 = self.captured
self.ok("help " + command)
cap2 = self.captured
self.ok("help " + command + " asdf --url --zxcv -")
cap3 = self.captured
eq_(cap1, cap2)
eq_(cap2, cap3)
def test_02_parsetime(self):
os.environ['TZ'] = "America/New_York"
test = datetime_tz.datetime_tz.now()
parse_time = nilmdb.utils.time.parse_time
eq_(parse_time(str(test)), test)
test = datetime_tz.datetime_tz.smartparse("20120405 1400-0400")
eq_(parse_time("hi there 20120405 1400-0400 testing! 123"), test)
eq_(parse_time("20120405 1800 UTC"), test)
eq_(parse_time("20120405 1400-0400 UTC"), test)
for badtime in [ "20120405 1400-9999", "hello", "-", "", "4:00" ]:
with assert_raises(ValueError):
x = parse_time(badtime)
x = parse_time("now")
eq_(parse_time("snapshot-20120405-140000.raw.gz"), test)
eq_(parse_time("prep-20120405T1400"), test)
def test_03_info(self):
self.ok("info") self.ok("info")
self.contain("Server URL: http://localhost:12380/") self.contain("Server URL: http://localhost:32180/")
self.contain("Client version: " + nilmdb.__version__) self.contain("Client version: " + nilmdb.__version__)
self.contain("Server version: " + test_server.version) self.contain("Server version: " + test_server.version)
self.contain("Server database path") self.contain("Server database path")
self.contain("Server database size") self.contain("Server database size")
self.contain("Server database free space") self.contain("Server database free space")
def test_03_createlist(self): def test_04_createlist(self):
# Basic stream tests, like those in test_client. # Basic stream tests, like those in test_client.
# No streams # No streams
@@ -276,7 +303,7 @@ class TestCmdline(object):
self.fail("list /newton/prep --start 2020-01-01 --end 2000-01-01") self.fail("list /newton/prep --start 2020-01-01 --end 2000-01-01")
self.contain("start must precede end") self.contain("start must precede end")
def test_04_metadata(self): def test_05_metadata(self):
# Set / get metadata # Set / get metadata
self.fail("metadata") self.fail("metadata")
self.fail("metadata --get") self.fail("metadata --get")
@@ -333,22 +360,6 @@ class TestCmdline(object):
self.fail("metadata /newton/nosuchpath") self.fail("metadata /newton/nosuchpath")
self.contain("No stream at path /newton/nosuchpath") self.contain("No stream at path /newton/nosuchpath")
def test_05_parsetime(self):
os.environ['TZ'] = "America/New_York"
cmd = nilmdb.cmdline.Cmdline(None)
test = datetime_tz.datetime_tz.now()
eq_(cmd.parse_time(str(test)), test)
test = datetime_tz.datetime_tz.smartparse("20120405 1400-0400")
eq_(cmd.parse_time("hi there 20120405 1400-0400 testing! 123"), test)
eq_(cmd.parse_time("20120405 1800 UTC"), test)
eq_(cmd.parse_time("20120405 1400-0400 UTC"), test)
for badtime in [ "20120405 1400-9999", "hello", "-", "", "4:00" ]:
with assert_raises(ValueError):
x = cmd.parse_time(badtime)
x = cmd.parse_time("now")
eq_(cmd.parse_time("snapshot-20120405-140000.raw.gz"), test)
eq_(cmd.parse_time("prep-20120405T1400"), test)
def test_06_insert(self): def test_06_insert(self):
self.ok("insert --help") self.ok("insert --help")
@@ -417,7 +428,7 @@ class TestCmdline(object):
# bad start time # bad start time
self.fail("insert --rate 120 --start 'whatever' /newton/prep /dev/null") self.fail("insert --rate 120 --start 'whatever' /newton/prep /dev/null")
def test_07_detail(self): def test_07_detail_extent(self):
# Just count the number of lines, it's probably fine # Just count the number of lines, it's probably fine
self.ok("list --detail") self.ok("list --detail")
lines_(self.captured, 8) lines_(self.captured, 8)
@@ -462,6 +473,18 @@ class TestCmdline(object):
lines_(self.captured, 2) lines_(self.captured, 2)
self.contain("[ 1332497115.612 -> 1332497159.991668 ]") self.contain("[ 1332497115.612 -> 1332497159.991668 ]")
# Check --extent output
self.ok("list --extent")
lines_(self.captured, 6)
self.ok("list -E -T")
self.contain(" extent: 1332496800 -> 1332497159.991668")
self.contain(" extent: (no data)")
# Misc
self.fail("list --extent --start='23 Mar 2012 10:05:15.50'")
self.contain("--start and --end only make sense with --detail")
def test_08_extract(self): def test_08_extract(self):
# nonexistent stream # nonexistent stream
self.fail("extract /no/such/foo --start 2000-01-01 --end 2020-01-01") self.fail("extract /no/such/foo --start 2000-01-01 --end 2020-01-01")

View File

@@ -9,14 +9,7 @@ from nose.tools import assert_raises
import distutils.version import distutils.version
import itertools import itertools
import os import os
import shutil
import sys import sys
import cherrypy
import threading
import urllib2
from urllib2 import urlopen, HTTPError
import Queue
import cStringIO
import random import random
import unittest import unittest
@@ -246,7 +239,7 @@ class TestLayoutSpeed:
parser = Parser(layout) parser = Parser(layout)
formatter = Formatter(layout) formatter = Formatter(layout)
parser.parse(data) parser.parse(data)
data = formatter.format(parser.data) formatter.format(parser.data)
elapsed = time.time() - start elapsed = time.time() - start
printf("roundtrip %s: %d ms, %.1f μs/row, %d rows/sec\n", printf("roundtrip %s: %d ms, %.1f μs/row, %d rows/sec\n",
layout, layout,
@@ -264,3 +257,8 @@ class TestLayoutSpeed:
return [ sprintf("%d", random.randint(0,65535)) return [ sprintf("%d", random.randint(0,65535))
for x in range(10) ] for x in range(10) ]
do_speedtest("uint16_10", datagen) do_speedtest("uint16_10", datagen)
def datagen():
return [ sprintf("%d", random.randint(0,65535))
for x in range(6) ]
do_speedtest("uint16_6", datagen)

View File

@@ -34,6 +34,10 @@ class Bar:
def __del__(self): def __del__(self):
fprintf(err, "Deleting\n") fprintf(err, "Deleting\n")
@classmethod
def baz(self):
fprintf(err, "Baz\n")
def close(self): def close(self):
fprintf(err, "Closing\n") fprintf(err, "Closing\n")

View File

@@ -6,15 +6,15 @@ import distutils.version
import simplejson as json import simplejson as json
import itertools import itertools
import os import os
import shutil
import sys import sys
import cherrypy
import threading import threading
import urllib2 import urllib2
from urllib2 import urlopen, HTTPError from urllib2 import urlopen, HTTPError
import Queue
import cStringIO import cStringIO
import time import time
import requests
from nilmdb.utils import serializer_proxy
testdb = "tests/testdb" testdb = "tests/testdb"
@@ -104,15 +104,20 @@ class Test00Nilmdb(object): # named 00 so it runs first
class TestBlockingServer(object): class TestBlockingServer(object):
def setUp(self): def setUp(self):
self.db = nilmdb.NilmDB(testdb, sync=False) self.db = serializer_proxy(nilmdb.NilmDB)(testdb, sync=False)
def tearDown(self): def tearDown(self):
self.db.close() self.db.close()
def test_blocking_server(self): def test_blocking_server(self):
# Server should fail if the database doesn't have a "_thread_safe"
# property.
with assert_raises(KeyError):
nilmdb.Server(object())
# Start web app on a custom port # Start web app on a custom port
self.server = nilmdb.Server(self.db, host = "127.0.0.1", self.server = nilmdb.Server(self.db, host = "127.0.0.1",
port = 12380, stoppable = True) port = 32180, stoppable = True)
# Run it # Run it
event = threading.Event() event = threading.Event()
@@ -124,13 +129,13 @@ class TestBlockingServer(object):
raise AssertionError("server didn't start in 10 seconds") raise AssertionError("server didn't start in 10 seconds")
# Send request to exit. # Send request to exit.
req = urlopen("http://127.0.0.1:12380/exit/", timeout = 1) req = urlopen("http://127.0.0.1:32180/exit/", timeout = 1)
# Wait for it # Wait for it
thread.join() thread.join()
def geturl(path): def geturl(path):
req = urlopen("http://127.0.0.1:12380" + path, timeout = 10) req = urlopen("http://127.0.0.1:32180" + path, timeout = 10)
return req.read() return req.read()
def getjson(path): def getjson(path):
@@ -140,9 +145,9 @@ class TestServer(object):
def setUp(self): def setUp(self):
# Start web app on a custom port # Start web app on a custom port
self.db = nilmdb.NilmDB(testdb, sync=False) self.db = serializer_proxy(nilmdb.NilmDB)(testdb, sync=False)
self.server = nilmdb.Server(self.db, host = "127.0.0.1", self.server = nilmdb.Server(self.db, host = "127.0.0.1",
port = 12380, stoppable = False) port = 32180, stoppable = False)
self.server.start(blocking = False) self.server.start(blocking = False)
def tearDown(self): def tearDown(self):
@@ -202,11 +207,33 @@ class TestServer(object):
"&key=foo") "&key=foo")
eq_(data, {'foo': None}) eq_(data, {'foo': None})
def test_cors_headers(self):
# Test that CORS headers are being set correctly
def test_insert(self): # Normal GET should send simple response
# GET instead of POST (no body) url = "http://127.0.0.1:32180/stream/list"
# (actual POST test is done by client code) r = requests.get(url, headers = { "Origin": "http://google.com/" })
with assert_raises(HTTPError) as e: eq_(r.status_code, 200)
getjson("/stream/insert?path=/newton/prep&start=0&end=0") if "access-control-allow-origin" not in r.headers:
eq_(e.exception.code, 400) raise AssertionError("No Access-Control-Allow-Origin (CORS) "
"header in response:\n", r.headers)
eq_(r.headers["access-control-allow-origin"], "http://google.com/")
# OPTIONS without CORS preflight headers should result in 405
r = requests.options(url, headers = {
"Origin": "http://google.com/",
})
eq_(r.status_code, 405)
# OPTIONS with preflight headers should give preflight response
r = requests.options(url, headers = {
"Origin": "http://google.com/",
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "X-Custom",
})
eq_(r.status_code, 200)
if "access-control-allow-origin" not in r.headers:
raise AssertionError("No Access-Control-Allow-Origin (CORS) "
"header in response:\n", r.headers)
eq_(r.headers["access-control-allow-methods"], "GET, HEAD")
eq_(r.headers["access-control-allow-headers"], "X-Custom")

View File

@@ -9,16 +9,28 @@ import time
from testutil.helpers import * from testutil.helpers import *
#raise nose.exc.SkipTest("Skip these")
class Foo(object): class Foo(object):
val = 0 val = 0
def __init__(self, asdf = "asdf"):
self.init_thread = threading.current_thread().name
@classmethod
def foo(self):
pass
def fail(self): def fail(self):
raise Exception("you asked me to do this") raise Exception("you asked me to do this")
def test(self, debug = False): def test(self, debug = False):
self.tester(debug)
def t(self):
pass
def tester(self, debug = False):
# purposely not thread-safe # purposely not thread-safe
self.test_thread = threading.current_thread().name
oldval = self.val oldval = self.val
newval = oldval + 1 newval = oldval + 1
time.sleep(0.05) time.sleep(0.05)
@@ -46,27 +58,29 @@ class Base(object):
t.join() t.join()
self.verify_result() self.verify_result()
def verify_result(self):
eq_(self.foo.val, 20)
eq_(self.foo.init_thread, self.foo.test_thread)
class TestUnserialized(Base): class TestUnserialized(Base):
def setUp(self): def setUp(self):
self.foo = Foo() self.foo = Foo()
def verify_result(self): def verify_result(self):
# This should have failed to increment properly # This should have failed to increment properly
assert(self.foo.val != 20) ne_(self.foo.val, 20)
# Init and tests ran in different threads
ne_(self.foo.init_thread, self.foo.test_thread)
class TestSerialized(Base): class TestSerializer(Base):
def setUp(self): def setUp(self):
self.realfoo = Foo() self.foo = nilmdb.utils.serializer_proxy(Foo)("qwer")
self.foo = nilmdb.utils.Serializer(self.realfoo)
def tearDown(self): def test_multi(self):
del self.foo sp = nilmdb.utils.serializer_proxy
sp(Foo("x")).t()
def verify_result(self): sp(sp(Foo)("x")).t()
# This should have worked sp(sp(Foo))("x").t()
eq_(self.realfoo.val, 20) sp(sp(Foo("x"))).t()
sp(sp(Foo)("x")).t()
def test_attribute(self): sp(sp(Foo))("x").t()
# Can't wrap attributes yet
with assert_raises(TypeError):
self.foo.val

View File

@@ -0,0 +1,96 @@
import nilmdb
from nilmdb.utils.printf import *
import nose
from nose.tools import *
from nose.tools import assert_raises
from testutil.helpers import *
import threading
class Thread(threading.Thread):
def __init__(self, target):
self.target = target
threading.Thread.__init__(self)
def run(self):
try:
self.target()
except AssertionError as e:
self.error = e
else:
self.error = None
class Test():
def __init__(self):
self.test = 1234
@classmethod
def asdf(cls):
pass
def foo(self, exception = False, reenter = False):
if exception:
raise Exception()
self.bar(reenter)
def bar(self, reenter):
if reenter:
self.foo()
return 123
def baz_threaded(self, target):
t = Thread(target)
t.start()
t.join()
return t
def baz(self, target):
target()
class TestThreadSafety(object):
def tryit(self, c, threading_ok, concurrent_ok):
eq_(c.test, 1234)
c.foo()
t = Thread(c.foo)
t.start()
t.join()
if threading_ok and t.error:
raise Exception("got unexpected error: " + str(t.error))
if not threading_ok and not t.error:
raise Exception("failed to get expected error")
try:
c.baz(c.foo)
except AssertionError as e:
if concurrent_ok:
raise Exception("got unexpected error: " + str(e))
else:
if not concurrent_ok:
raise Exception("failed to get expected error")
t = c.baz_threaded(c.foo)
if (concurrent_ok and threading_ok) and t.error:
raise Exception("got unexpected error: " + str(t.error))
if not (concurrent_ok and threading_ok) and not t.error:
raise Exception("failed to get expected error")
def test(self):
proxy = nilmdb.utils.threadsafety.verify_proxy
self.tryit(Test(), True, True)
self.tryit(proxy(Test(), True, True, True), False, False)
self.tryit(proxy(Test(), True, True, False), False, True)
self.tryit(proxy(Test(), True, False, True), True, False)
self.tryit(proxy(Test(), True, False, False), True, True)
self.tryit(proxy(Test, True, True, True)(), False, False)
self.tryit(proxy(Test, True, True, False)(), False, True)
self.tryit(proxy(Test, True, False, True)(), True, False)
self.tryit(proxy(Test, True, False, False)(), True, True)
proxy(proxy(proxy(Test))()).foo()
c = proxy(Test())
c.foo()
try:
c.foo(exception = True)
except Exception:
pass
c.foo()

View File

@@ -83,7 +83,7 @@ To use it:
import os, sys, re import os, sys, re
from distutils.core import Command from distutils.core import Command
from distutils.command.sdist import sdist as _sdist from distutils.command.sdist import sdist as _sdist
from distutils.command.build import build as _build from distutils.command.build_py import build_py as _build_py
versionfile_source = None versionfile_source = None
versionfile_build = None versionfile_build = None
@@ -578,11 +578,10 @@ class cmd_version(Command):
ver = get_version(verbose=True) ver = get_version(verbose=True)
print("Version is currently: %s" % ver) print("Version is currently: %s" % ver)
class cmd_build_py(_build_py):
class cmd_build(_build):
def run(self): def run(self):
versions = get_versions(verbose=True) versions = get_versions(verbose=True)
_build.run(self) _build_py.run(self)
# now locate _version.py in the new build/ directory and replace it # now locate _version.py in the new build/ directory and replace it
# with an updated value # with an updated value
target_versionfile = os.path.join(self.build_lib, versionfile_build) target_versionfile = os.path.join(self.build_lib, versionfile_build)
@@ -651,6 +650,6 @@ class cmd_update_files(Command):
def get_cmdclass(): def get_cmdclass():
return {'version': cmd_version, return {'version': cmd_version,
'update_files': cmd_update_files, 'update_files': cmd_update_files,
'build': cmd_build, 'build_py': cmd_build_py,
'sdist': cmd_sdist, 'sdist': cmd_sdist,
} }