These functions can now take an object or a type (class). If given an object, they will wrap subsequent calls to that object. If given a type, they will return an object that can be instantiated to create a new object, and all calls including __init__ will be covered by the serialization or thread verification.tags/nilmdb-1.2
@@ -35,7 +35,7 @@ def main(): | |||
# Create database object. Needs to be serialized before passing | |||
# to the Server. | |||
db = nilmdb.utils.serializer_proxy(NilmDB, args.database) | |||
db = nilmdb.utils.serializer_proxy(NilmDB)(args.database) | |||
# Configure the server | |||
if args.quiet: | |||
@@ -16,7 +16,7 @@ import functools | |||
class SerializerThread(threading.Thread): | |||
"""Thread that retrieves call information from the queue, makes the | |||
call, and returns the results.""" | |||
def __init__(self, call_queue): | |||
def __init__(self, classname, call_queue): | |||
threading.Thread.__init__(self) | |||
self.call_queue = call_queue | |||
@@ -39,54 +39,70 @@ class SerializerThread(threading.Thread): | |||
result_queue.put((exception, result)) | |||
del exception, result | |||
def _call_in_thread(__func, __queue, *args, **kwargs): | |||
"""Make a call by putting it in the serialization thread's | |||
queue and waiting for a response""" | |||
result_queue = Queue.Queue() | |||
__queue.put((result_queue, __func, args, kwargs)) | |||
( exc_info, result ) = result_queue.get() | |||
if exc_info is None: | |||
return result | |||
else: | |||
raise exc_info[0], exc_info[1], exc_info[2] | |||
def serializer_proxy(obj_or_type): | |||
"""Wrap the given object or type in a SerializerObjectProxy. | |||
def serializer_proxy(__cls, *args, **kwargs): | |||
"""Instantiates the given class with the given arguments, | |||
and returns a SerializerObjectProxy object that proxies all method | |||
calls to the new object, as well as attribute retrievals. | |||
Returns a SerializerObjectProxy object that proxies all method | |||
calls to the object, as well as attribute retrievals. | |||
The proxied requests, including instantiation, are performed in a | |||
single thread and serialized between caller threads. | |||
""" | |||
class SerializerCallProxy(object): | |||
def __init__(self, call_queue, func): | |||
def __init__(self, call_queue, func, objectproxy): | |||
self.call_queue = call_queue | |||
self.result_queue = Queue.Queue() | |||
self.func = func | |||
# Need to hold a reference to object proxy so it doesn't | |||
# go away (and kill the thread) until after get called. | |||
self.objectproxy = objectproxy | |||
def __call__(self, *args, **kwargs): | |||
return _call_in_thread(self.func, self.call_queue, | |||
*args, **kwargs) | |||
result_queue = Queue.Queue() | |||
self.call_queue.put((result_queue, self.func, args, kwargs)) | |||
( exc_info, result ) = result_queue.get() | |||
if exc_info is None: | |||
return result | |||
else: | |||
raise exc_info[0], exc_info[1], exc_info[2] | |||
class SerializerObjectProxy(object): | |||
def __init__(self, __cls, *args, **kwargs): | |||
self.__s_call_queue = Queue.Queue() | |||
self.__s_thread = SerializerThread( | |||
self.__s_call_queue) | |||
self.__s_thread.daemon = True | |||
self.__s_thread.start() | |||
self.__s_object = _call_in_thread(__cls, self.__s_call_queue, | |||
*args, **kwargs) | |||
def __init__(self, obj_or_type, *args, **kwargs): | |||
self.__object = obj_or_type | |||
try: | |||
if type(obj_or_type) in (types.TypeType, types.ClassType): | |||
classname = obj_or_type.__name__ | |||
else: | |||
classname = obj_or_type.__class__.__name__ | |||
except AttributeError: # pragma: no cover | |||
classname = "???" | |||
self.__call_queue = Queue.Queue() | |||
self.__thread = SerializerThread(classname, self.__call_queue) | |||
self.__thread.daemon = True | |||
self.__thread.start() | |||
self._thread_safe = True | |||
def __getattr__(self, key): | |||
attr = getattr(self.__s_object, key) | |||
if key.startswith("_SerializerObjectProxy__"): # pragma: no cover | |||
raise AttributeError | |||
attr = getattr(self.__object, key) | |||
if not callable(attr): | |||
getter = SerializerCallProxy(self.__s_call_queue, getattr) | |||
return getter(self.__s_object, key) | |||
return SerializerCallProxy(self.__s_call_queue, attr) | |||
getter = SerializerCallProxy(self.__call_queue, getattr, self) | |||
return getter(self.__object, key) | |||
r = SerializerCallProxy(self.__call_queue, attr, self) | |||
return r | |||
def __call__(self, *args, **kwargs): | |||
"""Call this to instantiate the type, if a type was passed | |||
to serializer_proxy. Otherwise, pass the call through.""" | |||
ret = SerializerCallProxy(self.__call_queue, | |||
self.__object, self)(*args, **kwargs) | |||
if type(self.__object) in (types.TypeType, types.ClassType): | |||
# Instantiation | |||
self.__object = ret | |||
return self | |||
return ret | |||
def __del__(self): | |||
self.__s_call_queue.put((None, None, None, None)) | |||
self.__s_thread.join() | |||
self.__call_queue.put((None, None, None, None)) | |||
self.__thread.join() | |||
return SerializerObjectProxy(__cls, *args, **kwargs) | |||
return SerializerObjectProxy(obj_or_type) |
@@ -1,34 +1,37 @@ | |||
from nilmdb.utils.printf import * | |||
import threading | |||
import warnings | |||
import types | |||
def verify_proxy(obj, exception = False, | |||
check_thread = True, check_concurrent = True): | |||
"""Return a VerifyObjectProxy that proxies all method calls | |||
to the given object, as well as attribute retrievals. | |||
def verify_proxy(obj_or_type, exception = False, check_thread = True, | |||
check_concurrent = True): | |||
"""Wrap the given object or type in a VerifyObjectProxy. | |||
When calling methods, the following checks are performed. | |||
If exception is True, an exception is raised. Otherwise, a warning | |||
Returns a VerifyObjectProxy that proxies all method calls to the | |||
given object, as well as attribute retrievals. | |||
When calling methods, the following checks are performed. If | |||
exception is True, an exception is raised. Otherwise, a warning | |||
is printed. | |||
check_thread = True # Warn/fail if two different threads call methods. | |||
check_concurrent = True # Warn/fail if two functions are concurrently | |||
# run through this proxy | |||
Note that the __init__ call is not included in the checks, since | |||
this wrapper is passed an already instantiated object. | |||
""" | |||
class Namespace(object): | |||
pass | |||
class VerifyCallProxy(object): | |||
def __init__(self, func, parent): | |||
def __init__(self, func, parent_namespace): | |||
self.func = func | |||
self.parent = parent | |||
self.parent_namespace = parent_namespace | |||
def __call__(self, *args, **kwargs): | |||
p = self.parent | |||
p = self.parent_namespace | |||
this = threading.current_thread() | |||
callee = self.func.__name__ | |||
classname = p.__class__.__name__ | |||
try: | |||
callee = self.func.__name__ | |||
except AttributeError: | |||
callee = "???" | |||
if p.thread is None: | |||
p.thread = this | |||
@@ -37,8 +40,8 @@ def verify_proxy(obj, exception = False, | |||
if check_thread and p.thread != this: | |||
err = sprintf("unsafe threading: %s called %s.%s," | |||
" but %s called %s.%s", | |||
p.thread.name, classname, p.thread_callee, | |||
this.name, classname, callee) | |||
p.thread.name, p.classname, p.thread_callee, | |||
this.name, p.classname, callee) | |||
if exception: | |||
raise AssertionError(err) | |||
else: # pragma: no cover | |||
@@ -49,8 +52,8 @@ def verify_proxy(obj, exception = False, | |||
if p.concur_lock.acquire(False) == False: | |||
err = sprintf("unsafe concurrency: %s called %s.%s " | |||
"while %s is still in %s.%s", | |||
this.name, classname, callee, | |||
p.concur_tname, classname, p.concur_callee) | |||
this.name, p.classname, callee, | |||
p.concur_tname, p.classname, p.concur_callee) | |||
if exception: | |||
raise AssertionError(err) | |||
else: # pragma: no cover | |||
@@ -68,18 +71,39 @@ def verify_proxy(obj, exception = False, | |||
return ret | |||
class VerifyObjectProxy(object): | |||
def __init__(self, obj): | |||
self.obj = obj | |||
self.thread = None | |||
self.thread_callee = None | |||
self.concur_lock = threading.Lock() | |||
self.concur_tname = None | |||
self.concur_callee = None | |||
def __init__(self, obj_or_type, *args, **kwargs): | |||
p = Namespace() | |||
self.__ns = p | |||
p.thread = None | |||
p.thread_callee = None | |||
p.concur_lock = threading.Lock() | |||
p.concur_tname = None | |||
p.concur_callee = None | |||
self.__obj = obj_or_type | |||
try: | |||
if type(obj_or_type) in (types.TypeType, types.ClassType): | |||
p.classname = self.__obj.__name__ | |||
else: | |||
p.classname = self.__obj.__class__.__name__ | |||
except AttributeError: # pragma: no cover | |||
p.classname = "???" | |||
def __getattr__(self, key): | |||
attr = getattr(self.obj, key) | |||
if key.startswith("_VerifyObjectProxy__"): # pragma: no cover | |||
raise AttributeError | |||
attr = getattr(self.__obj, key) | |||
if not callable(attr): | |||
return VerifyCallProxy(getattr, self)(self.obj, key) | |||
return VerifyCallProxy(attr, self) | |||
return VerifyCallProxy(getattr, self.__ns)(self.__obj, key) | |||
return VerifyCallProxy(attr, self.__ns) | |||
def __call__(self, *args, **kwargs): | |||
"""Call this to instantiate the type, if a type was passed | |||
to verify_proxy. Otherwise, pass the call through.""" | |||
ret = VerifyCallProxy(self.__obj, self.__ns)(*args, **kwargs) | |||
if type(self.__obj) in (types.TypeType, types.ClassType): | |||
# Instantiation | |||
self.__obj = ret | |||
return self | |||
return ret | |||
return VerifyObjectProxy(obj) | |||
return VerifyObjectProxy(obj_or_type) |
@@ -31,7 +31,7 @@ def setup_module(): | |||
recursive_unlink(testdb) | |||
# Start web app on a custom port | |||
test_db = nilmdb.utils.serializer_proxy(nilmdb.NilmDB, testdb, sync = False) | |||
test_db = nilmdb.utils.serializer_proxy(nilmdb.NilmDB)(testdb, sync = False) | |||
test_server = nilmdb.Server(test_db, host = "127.0.0.1", | |||
port = 12380, stoppable = False, | |||
fast_shutdown = True, | |||
@@ -27,10 +27,10 @@ testdb = "tests/cmdline-testdb" | |||
def server_start(max_results = None, bulkdata_args = {}): | |||
global test_server, test_db | |||
# Start web app on a custom port | |||
test_db = nilmdb.utils.serializer_proxy(nilmdb.NilmDB, | |||
testdb, sync = False, | |||
max_results = max_results, | |||
bulkdata_args = bulkdata_args) | |||
test_db = nilmdb.utils.serializer_proxy(nilmdb.NilmDB)( | |||
testdb, sync = False, | |||
max_results = max_results, | |||
bulkdata_args = bulkdata_args) | |||
test_server = nilmdb.Server(test_db, host = "127.0.0.1", | |||
port = 12380, stoppable = False, | |||
fast_shutdown = True, | |||
@@ -16,6 +16,8 @@ import Queue | |||
import cStringIO | |||
import time | |||
from nilmdb.utils import serializer_proxy | |||
testdb = "tests/testdb" | |||
#@atexit.register | |||
@@ -104,8 +106,7 @@ class Test00Nilmdb(object): # named 00 so it runs first | |||
class TestBlockingServer(object): | |||
def setUp(self): | |||
self.db = nilmdb.utils.serializer_proxy(nilmdb.NilmDB, | |||
testdb, sync=False) | |||
self.db = serializer_proxy(nilmdb.NilmDB)(testdb, sync=False) | |||
def tearDown(self): | |||
self.db.close() | |||
@@ -146,8 +147,7 @@ class TestServer(object): | |||
def setUp(self): | |||
# Start web app on a custom port | |||
self.db = nilmdb.utils.serializer_proxy( | |||
nilmdb.NilmDB, testdb, sync=False) | |||
self.db = serializer_proxy(nilmdb.NilmDB)(testdb, sync=False) | |||
self.server = nilmdb.Server(self.db, host = "127.0.0.1", | |||
port = 12380, stoppable = False) | |||
self.server.start(blocking = False) | |||
@@ -25,6 +25,9 @@ class Foo(object): | |||
def test(self, debug = False): | |||
self.tester(debug) | |||
def t(self): | |||
pass | |||
def tester(self, debug = False): | |||
# purposely not thread-safe | |||
self.test_thread = threading.current_thread().name | |||
@@ -38,12 +41,12 @@ class Foo(object): | |||
class Base(object): | |||
def test_1_wrapping(self): | |||
def test_wrapping(self): | |||
self.foo.test() | |||
with assert_raises(Exception): | |||
self.foo.fail() | |||
def test_2_threaded(self): | |||
def test_threaded(self): | |||
def func(foo): | |||
foo.test() | |||
threads = [] | |||
@@ -55,6 +58,10 @@ class Base(object): | |||
t.join() | |||
self.verify_result() | |||
def verify_result(self): | |||
eq_(self.foo.val, 20) | |||
eq_(self.foo.init_thread, self.foo.test_thread) | |||
class TestUnserialized(Base): | |||
def setUp(self): | |||
self.foo = Foo() | |||
@@ -62,15 +69,18 @@ class TestUnserialized(Base): | |||
def verify_result(self): | |||
# This should have failed to increment properly | |||
ne_(self.foo.val, 20) | |||
# Init and tests ran in different threads | |||
ne_(self.foo.init_thread, self.foo.test_thread) | |||
class TestSerializer(Base): | |||
# Now test the SerializerProxy version | |||
def setUp(self): | |||
self.foo = nilmdb.utils.serializer_proxy(Foo, "qwer") | |||
def verify_result(self): | |||
eq_(self.foo.val, 20) | |||
eq_(self.foo.init_thread, self.foo.test_thread) | |||
self.foo = nilmdb.utils.serializer_proxy(Foo)("qwer") | |||
def test_multi(self): | |||
sp = nilmdb.utils.serializer_proxy | |||
sp(Foo("x")).t() | |||
sp(sp(Foo)("x")).t() | |||
sp(sp(Foo))("x").t() | |||
sp(sp(Foo("x"))).t() | |||
sp(sp(Foo)("x")).t() | |||
sp(sp(Foo))("x").t() |
@@ -37,6 +37,7 @@ class Test(): | |||
def bar(self, reenter): | |||
if reenter: | |||
self.foo() | |||
return 123 | |||
def baz_threaded(self, target): | |||
t = Thread(target) | |||
@@ -74,10 +75,17 @@ class TestThreadSafety(object): | |||
def test(self): | |||
proxy = nilmdb.utils.threadsafety.verify_proxy | |||
self.tryit(Test(), True, True) | |||
self.tryit(proxy(Test(), True, True, True), False, False) | |||
self.tryit(proxy(Test(), True, True, False), False, True) | |||
self.tryit(proxy(Test(), True, False, True), True, False) | |||
self.tryit(proxy(Test(), True, False, False), True, True) | |||
self.tryit(proxy(Test, True, True, True)(), False, False) | |||
self.tryit(proxy(Test, True, True, False)(), False, True) | |||
self.tryit(proxy(Test, True, False, True)(), True, False) | |||
self.tryit(proxy(Test, True, False, False)(), True, True) | |||
proxy(proxy(proxy(Test))()).foo() | |||
c = proxy(Test()) | |||
c.foo() | |||