Browse Source

Improve serializer_proxy and verify_thread_proxy

These functions can now take an object or a type (class).

If given an object, they will wrap subsequent calls to that object.
If given a type, they will return an object that can be instantiated
to create a new object, and all calls including __init__ will be
covered by the serialization or thread verification.
tags/nilmdb-1.2
Jim Paris 8 years ago
parent
commit
317c53ab6f
8 changed files with 141 additions and 83 deletions
  1. +1
    -1
      nilmdb/scripts/nilmdb_server.py
  2. +50
    -34
      nilmdb/utils/serializer.py
  3. +54
    -30
      nilmdb/utils/threadsafety.py
  4. +1
    -1
      tests/test_client.py
  5. +4
    -4
      tests/test_cmdline.py
  6. +4
    -4
      tests/test_nilmdb.py
  7. +19
    -9
      tests/test_serializer.py
  8. +8
    -0
      tests/test_threadsafety.py

+ 1
- 1
nilmdb/scripts/nilmdb_server.py View File

@@ -35,7 +35,7 @@ def main():

# Create database object. Needs to be serialized before passing
# to the Server.
db = nilmdb.utils.serializer_proxy(NilmDB, args.database)
db = nilmdb.utils.serializer_proxy(NilmDB)(args.database)

# Configure the server
if args.quiet:


+ 50
- 34
nilmdb/utils/serializer.py View File

@@ -16,7 +16,7 @@ import functools
class SerializerThread(threading.Thread):
"""Thread that retrieves call information from the queue, makes the
call, and returns the results."""
def __init__(self, call_queue):
def __init__(self, classname, call_queue):
threading.Thread.__init__(self)
self.call_queue = call_queue

@@ -39,54 +39,70 @@ class SerializerThread(threading.Thread):
result_queue.put((exception, result))
del exception, result

def _call_in_thread(__func, __queue, *args, **kwargs):
"""Make a call by putting it in the serialization thread's
queue and waiting for a response"""
result_queue = Queue.Queue()
__queue.put((result_queue, __func, args, kwargs))
( exc_info, result ) = result_queue.get()
if exc_info is None:
return result
else:
raise exc_info[0], exc_info[1], exc_info[2]
def serializer_proxy(obj_or_type):
"""Wrap the given object or type in a SerializerObjectProxy.

def serializer_proxy(__cls, *args, **kwargs):
"""Instantiates the given class with the given arguments,
and returns a SerializerObjectProxy object that proxies all method
calls to the new object, as well as attribute retrievals.
Returns a SerializerObjectProxy object that proxies all method
calls to the object, as well as attribute retrievals.

The proxied requests, including instantiation, are performed in a
single thread and serialized between caller threads.
"""
class SerializerCallProxy(object):
def __init__(self, call_queue, func):
def __init__(self, call_queue, func, objectproxy):
self.call_queue = call_queue
self.result_queue = Queue.Queue()
self.func = func
# Need to hold a reference to object proxy so it doesn't
# go away (and kill the thread) until after get called.
self.objectproxy = objectproxy
def __call__(self, *args, **kwargs):
return _call_in_thread(self.func, self.call_queue,
*args, **kwargs)
result_queue = Queue.Queue()
self.call_queue.put((result_queue, self.func, args, kwargs))
( exc_info, result ) = result_queue.get()
if exc_info is None:
return result
else:
raise exc_info[0], exc_info[1], exc_info[2]

class SerializerObjectProxy(object):
def __init__(self, __cls, *args, **kwargs):
self.__s_call_queue = Queue.Queue()
self.__s_thread = SerializerThread(
self.__s_call_queue)
self.__s_thread.daemon = True
self.__s_thread.start()
self.__s_object = _call_in_thread(__cls, self.__s_call_queue,
*args, **kwargs)
def __init__(self, obj_or_type, *args, **kwargs):
self.__object = obj_or_type
try:
if type(obj_or_type) in (types.TypeType, types.ClassType):
classname = obj_or_type.__name__
else:
classname = obj_or_type.__class__.__name__
except AttributeError: # pragma: no cover
classname = "???"
self.__call_queue = Queue.Queue()
self.__thread = SerializerThread(classname, self.__call_queue)
self.__thread.daemon = True
self.__thread.start()
self._thread_safe = True

def __getattr__(self, key):
attr = getattr(self.__s_object, key)
if key.startswith("_SerializerObjectProxy__"): # pragma: no cover
raise AttributeError
attr = getattr(self.__object, key)
if not callable(attr):
getter = SerializerCallProxy(self.__s_call_queue, getattr)
return getter(self.__s_object, key)
return SerializerCallProxy(self.__s_call_queue, attr)
getter = SerializerCallProxy(self.__call_queue, getattr, self)
return getter(self.__object, key)
r = SerializerCallProxy(self.__call_queue, attr, self)
return r

def __call__(self, *args, **kwargs):
"""Call this to instantiate the type, if a type was passed
to serializer_proxy. Otherwise, pass the call through."""
ret = SerializerCallProxy(self.__call_queue,
self.__object, self)(*args, **kwargs)
if type(self.__object) in (types.TypeType, types.ClassType):
# Instantiation
self.__object = ret
return self
return ret

def __del__(self):
self.__s_call_queue.put((None, None, None, None))
self.__s_thread.join()
self.__call_queue.put((None, None, None, None))
self.__thread.join()

return SerializerObjectProxy(__cls, *args, **kwargs)
return SerializerObjectProxy(obj_or_type)

+ 54
- 30
nilmdb/utils/threadsafety.py View File

@@ -1,34 +1,37 @@
from nilmdb.utils.printf import *
import threading
import warnings
import types

def verify_proxy(obj, exception = False,
check_thread = True, check_concurrent = True):
"""Return a VerifyObjectProxy that proxies all method calls
to the given object, as well as attribute retrievals.
def verify_proxy(obj_or_type, exception = False, check_thread = True,
check_concurrent = True):
"""Wrap the given object or type in a VerifyObjectProxy.

When calling methods, the following checks are performed.
If exception is True, an exception is raised. Otherwise, a warning
Returns a VerifyObjectProxy that proxies all method calls to the
given object, as well as attribute retrievals.

When calling methods, the following checks are performed. If
exception is True, an exception is raised. Otherwise, a warning
is printed.

check_thread = True # Warn/fail if two different threads call methods.
check_concurrent = True # Warn/fail if two functions are concurrently
# run through this proxy

Note that the __init__ call is not included in the checks, since
this wrapper is passed an already instantiated object.
"""

class Namespace(object):
pass
class VerifyCallProxy(object):
def __init__(self, func, parent):
def __init__(self, func, parent_namespace):
self.func = func
self.parent = parent
self.parent_namespace = parent_namespace

def __call__(self, *args, **kwargs):
p = self.parent
p = self.parent_namespace
this = threading.current_thread()
callee = self.func.__name__
classname = p.__class__.__name__
try:
callee = self.func.__name__
except AttributeError:
callee = "???"

if p.thread is None:
p.thread = this
@@ -37,8 +40,8 @@ def verify_proxy(obj, exception = False,
if check_thread and p.thread != this:
err = sprintf("unsafe threading: %s called %s.%s,"
" but %s called %s.%s",
p.thread.name, classname, p.thread_callee,
this.name, classname, callee)
p.thread.name, p.classname, p.thread_callee,
this.name, p.classname, callee)
if exception:
raise AssertionError(err)
else: # pragma: no cover
@@ -49,8 +52,8 @@ def verify_proxy(obj, exception = False,
if p.concur_lock.acquire(False) == False:
err = sprintf("unsafe concurrency: %s called %s.%s "
"while %s is still in %s.%s",
this.name, classname, callee,
p.concur_tname, classname, p.concur_callee)
this.name, p.classname, callee,
p.concur_tname, p.classname, p.concur_callee)
if exception:
raise AssertionError(err)
else: # pragma: no cover
@@ -68,18 +71,39 @@ def verify_proxy(obj, exception = False,
return ret

class VerifyObjectProxy(object):
def __init__(self, obj):
self.obj = obj
self.thread = None
self.thread_callee = None
self.concur_lock = threading.Lock()
self.concur_tname = None
self.concur_callee = None
def __init__(self, obj_or_type, *args, **kwargs):
p = Namespace()
self.__ns = p
p.thread = None
p.thread_callee = None
p.concur_lock = threading.Lock()
p.concur_tname = None
p.concur_callee = None
self.__obj = obj_or_type
try:
if type(obj_or_type) in (types.TypeType, types.ClassType):
p.classname = self.__obj.__name__
else:
p.classname = self.__obj.__class__.__name__
except AttributeError: # pragma: no cover
p.classname = "???"

def __getattr__(self, key):
attr = getattr(self.obj, key)
if key.startswith("_VerifyObjectProxy__"): # pragma: no cover
raise AttributeError
attr = getattr(self.__obj, key)
if not callable(attr):
return VerifyCallProxy(getattr, self)(self.obj, key)
return VerifyCallProxy(attr, self)
return VerifyCallProxy(getattr, self.__ns)(self.__obj, key)
return VerifyCallProxy(attr, self.__ns)

def __call__(self, *args, **kwargs):
"""Call this to instantiate the type, if a type was passed
to verify_proxy. Otherwise, pass the call through."""
ret = VerifyCallProxy(self.__obj, self.__ns)(*args, **kwargs)
if type(self.__obj) in (types.TypeType, types.ClassType):
# Instantiation
self.__obj = ret
return self
return ret

return VerifyObjectProxy(obj)
return VerifyObjectProxy(obj_or_type)

+ 1
- 1
tests/test_client.py View File

@@ -31,7 +31,7 @@ def setup_module():
recursive_unlink(testdb)

# Start web app on a custom port
test_db = nilmdb.utils.serializer_proxy(nilmdb.NilmDB, testdb, sync = False)
test_db = nilmdb.utils.serializer_proxy(nilmdb.NilmDB)(testdb, sync = False)
test_server = nilmdb.Server(test_db, host = "127.0.0.1",
port = 12380, stoppable = False,
fast_shutdown = True,


+ 4
- 4
tests/test_cmdline.py View File

@@ -27,10 +27,10 @@ testdb = "tests/cmdline-testdb"
def server_start(max_results = None, bulkdata_args = {}):
global test_server, test_db
# Start web app on a custom port
test_db = nilmdb.utils.serializer_proxy(nilmdb.NilmDB,
testdb, sync = False,
max_results = max_results,
bulkdata_args = bulkdata_args)
test_db = nilmdb.utils.serializer_proxy(nilmdb.NilmDB)(
testdb, sync = False,
max_results = max_results,
bulkdata_args = bulkdata_args)
test_server = nilmdb.Server(test_db, host = "127.0.0.1",
port = 12380, stoppable = False,
fast_shutdown = True,


+ 4
- 4
tests/test_nilmdb.py View File

@@ -16,6 +16,8 @@ import Queue
import cStringIO
import time

from nilmdb.utils import serializer_proxy

testdb = "tests/testdb"

#@atexit.register
@@ -104,8 +106,7 @@ class Test00Nilmdb(object): # named 00 so it runs first

class TestBlockingServer(object):
def setUp(self):
self.db = nilmdb.utils.serializer_proxy(nilmdb.NilmDB,
testdb, sync=False)
self.db = serializer_proxy(nilmdb.NilmDB)(testdb, sync=False)

def tearDown(self):
self.db.close()
@@ -146,8 +147,7 @@ class TestServer(object):

def setUp(self):
# Start web app on a custom port
self.db = nilmdb.utils.serializer_proxy(
nilmdb.NilmDB, testdb, sync=False)
self.db = serializer_proxy(nilmdb.NilmDB)(testdb, sync=False)
self.server = nilmdb.Server(self.db, host = "127.0.0.1",
port = 12380, stoppable = False)
self.server.start(blocking = False)


+ 19
- 9
tests/test_serializer.py View File

@@ -25,6 +25,9 @@ class Foo(object):
def test(self, debug = False):
self.tester(debug)

def t(self):
pass

def tester(self, debug = False):
# purposely not thread-safe
self.test_thread = threading.current_thread().name
@@ -38,12 +41,12 @@ class Foo(object):

class Base(object):

def test_1_wrapping(self):
def test_wrapping(self):
self.foo.test()
with assert_raises(Exception):
self.foo.fail()

def test_2_threaded(self):
def test_threaded(self):
def func(foo):
foo.test()
threads = []
@@ -55,6 +58,10 @@ class Base(object):
t.join()
self.verify_result()

def verify_result(self):
eq_(self.foo.val, 20)
eq_(self.foo.init_thread, self.foo.test_thread)

class TestUnserialized(Base):
def setUp(self):
self.foo = Foo()
@@ -62,15 +69,18 @@ class TestUnserialized(Base):
def verify_result(self):
# This should have failed to increment properly
ne_(self.foo.val, 20)

# Init and tests ran in different threads
ne_(self.foo.init_thread, self.foo.test_thread)

class TestSerializer(Base):
# Now test the SerializerProxy version
def setUp(self):
self.foo = nilmdb.utils.serializer_proxy(Foo, "qwer")

def verify_result(self):
eq_(self.foo.val, 20)
eq_(self.foo.init_thread, self.foo.test_thread)
self.foo = nilmdb.utils.serializer_proxy(Foo)("qwer")

def test_multi(self):
sp = nilmdb.utils.serializer_proxy
sp(Foo("x")).t()
sp(sp(Foo)("x")).t()
sp(sp(Foo))("x").t()
sp(sp(Foo("x"))).t()
sp(sp(Foo)("x")).t()
sp(sp(Foo))("x").t()

+ 8
- 0
tests/test_threadsafety.py View File

@@ -37,6 +37,7 @@ class Test():
def bar(self, reenter):
if reenter:
self.foo()
return 123

def baz_threaded(self, target):
t = Thread(target)
@@ -74,10 +75,17 @@ class TestThreadSafety(object):

def test(self):
proxy = nilmdb.utils.threadsafety.verify_proxy
self.tryit(Test(), True, True)
self.tryit(proxy(Test(), True, True, True), False, False)
self.tryit(proxy(Test(), True, True, False), False, True)
self.tryit(proxy(Test(), True, False, True), True, False)
self.tryit(proxy(Test(), True, False, False), True, True)
self.tryit(proxy(Test, True, True, True)(), False, False)
self.tryit(proxy(Test, True, True, False)(), False, True)
self.tryit(proxy(Test, True, False, True)(), True, False)
self.tryit(proxy(Test, True, False, False)(), True, True)

proxy(proxy(proxy(Test))()).foo()

c = proxy(Test())
c.foo()


Loading…
Cancel
Save