Browse Source

Replace threadsafety class decorator version, add explicit proxy version

Like the serializer changes, the class decorator was too fragile.
tags/nilmdb-1.2
Jim Paris 9 years ago
parent
commit
422317850e
2 changed files with 86 additions and 113 deletions
  1. +62
    -68
      nilmdb/utils/threadsafety.py
  2. +24
    -45
      tests/test_threadsafety.py

+ 62
- 68
nilmdb/utils/threadsafety.py View File

@@ -1,91 +1,85 @@
from nilmdb.utils.printf import *
import threading
import decorator
import inspect
import warnings

def verify_thread_safety(check_thread = True, check_recursion = True,
check_concurrent = True):
"""Class decorator that raises an exception if the methods in the
class are called from separate threads.
def verify_proxy(obj, exception = False,
check_thread = True, check_concurrent = True):
"""Return a VerifyObjectProxy that proxies all method calls
to the given object, as well as attribute retrievals.

check_thread = True # Fail if two different threads ever call methods.
check_recursion = True # Fail if any functions is being run twice.
check_concurrent = True # Fail if two different threads are calling any
# functions concurrently.
"""
def class_decorator(cls):
When calling methods, the following checks are performed.
If exception is True, an exception is raised. Otherwise, a warning
is printed.

def wrap_class_method(wrapper):
try:
orig = getattr(cls, wrapper.__name__).im_func
except:
orig = lambda x: None
setattr(cls, wrapper.__name__, decorator.decorator(wrapper, orig))
check_thread = True # Warn/fail if two different threads call methods.
check_concurrent = True # Warn/fail if two functions are concurrently
# run through this proxy

@wrap_class_method
def __init__(orig, self, *args, **kwargs):
self.__thread = threading.current_thread()
self.__concur_lock = threading.Lock()
self.__concur_name = None
self.__concur_thread = None
ret = orig(self, *args, **kwargs)
return ret
Note that the __init__ call is not included in the checks, since
this wrapper is passed an already instantiated object.
"""
class VerifyCallProxy(object):
def __init__(self, func, parent):
self.func = func
self.parent = parent

# Wrap all other functions with the verifier
def verifier(orig, self, *args, **kwargs):
def __call__(self, *args, **kwargs):
p = self.parent
this = threading.current_thread()
callee = self.func.__name__
classname = p.__class__.__name__

if check_thread and self.__thread != this:
err = sprintf("unsafe threading: %s called %s.__init__,"
" but %s called %s.%s", self.__thread.name,
self.__class__.__name__, this.name,
self.__class__.__name__, orig.__name__)
raise AssertionError(err)
if p.thread is None:
p.thread = this
p.thread_callee = callee

if check_recursion and orig.__recurse_lock.acquire(False) == False:
err = sprintf("unsafe recursion: %s called %s.%s "
"but it's already being run by %s",
this.name,
self.__class__.__name__, orig.__name__,
orig.__recurse_thread.name)
raise AssertionError(err)
if check_thread and p.thread != this:
err = sprintf("unsafe threading: %s called %s.%s,"
" but %s called %s.%s",
p.thread.name, classname, p.thread_callee,
this.name, classname, callee)
if exception:
raise AssertionError(err)
else: # pragma: no cover
warnings.warn(err)

need_concur_unlock = False
if check_concurrent:
if self.__concur_lock.acquire(False) == False:
if this != self.__concur_thread:
err = sprintf("unsafe concurrency: %s called %s.%s "
"while %s is still in %s.%s",
this.name, self.__class__.__name__,
orig.__name__, self.__concur_thread.name,
self.__class__.__name__,
self.__concur_name)
if p.concur_lock.acquire(False) == False:
err = sprintf("unsafe concurrency: %s called %s.%s "
"while %s is still in %s.%s",
this.name, classname, callee,
p.concur_tname, classname, p.concur_callee)
if exception:
raise AssertionError(err)
else: # pragma: no cover
warnings.warn(err)
else:
self.__concur_thread = this
self.__concur_name = orig.__name__
p.concur_tname = this.name
p.concur_callee = callee
need_concur_unlock = True

orig.__recurse_thread = this
try:
ret = orig(self, *args, **kwargs)
ret = self.func(*args, **kwargs)
finally:
if check_recursion:
orig.__recurse_lock.release()
if need_concur_unlock:
self.__concur_lock.release()
p.concur_lock.release()
return ret

for (name, method) in inspect.getmembers(cls, inspect.ismethod):
# Skip class methods
if method.__self__ is not None:
continue
# Skip some methods
if name in [ "__del__", "__init__" ]:
continue
# Set up wrapper. Each function needs its own lock.
method.im_func.__recurse_lock = threading.Lock()
setattr(cls, name, decorator.decorator(verifier, method.im_func))
class VerifyObjectProxy(object):
def __init__(self, obj):
self.obj = obj
self.thread = None
self.thread_callee = None
self.concur_lock = threading.Lock()
self.concur_tname = None
self.concur_callee = None

def __getattr__(self, key):
attr = getattr(self.obj, key)
if not callable(attr):
return VerifyCallProxy(getattr, self)(self.obj, key)
return VerifyCallProxy(attr, self)

return cls
return class_decorator
return VerifyObjectProxy(obj)

+ 24
- 45
tests/test_threadsafety.py View File

@@ -23,7 +23,7 @@ class Thread(threading.Thread):

class Test():
def __init__(self):
pass
self.test = 1234

@classmethod
def asdf(cls):
@@ -38,38 +38,19 @@ class Test():
if reenter:
self.foo()

def baz(self):
t = Thread(self.foo)
def baz_threaded(self, target):
t = Thread(target)
t.start()
t.join()
return t

@nilmdb.utils.threadsafety.verify_thread_safety(False, False, False)
class Test1(Test):
pass

@nilmdb.utils.threadsafety.verify_thread_safety(True, False, False)
class Test2(Test):
pass

@nilmdb.utils.threadsafety.verify_thread_safety(False, True, False)
class Test3(Test):
pass

@nilmdb.utils.threadsafety.verify_thread_safety(False, False, True)
class Test4(Test):
pass

@nilmdb.utils.threadsafety.verify_thread_safety()
class Test5(Test):
pass

@nilmdb.utils.threadsafety.verify_thread_safety()
class Test6(object):
pass
def baz(self, target):
target()

class TestThreadSafety(object):
def tryit(self, c, threading_ok, recursion_ok, concurrent_ok):
def tryit(self, c, threading_ok, concurrent_ok):
eq_(c.test, 1234)
c.foo()
t = Thread(c.foo)
t.start()
t.join()
@@ -77,33 +58,31 @@ class TestThreadSafety(object):
raise Exception("got unexpected error: " + str(t.error))
if not threading_ok and not t.error:
raise Exception("failed to get expected error")
if recursion_ok:
c.foo(reenter = True)
try:
c.baz(c.foo)
except AssertionError as e:
if concurrent_ok:
raise Exception("got unexpected error: " + str(e))
else:
with assert_raises(AssertionError) as e:
c.foo(reenter = True)
if not threading_ok:
return
t = c.baz()
if concurrent_ok and t.error:
if not concurrent_ok:
raise Exception("failed to get expected error")
t = c.baz_threaded(c.foo)
if (concurrent_ok and threading_ok) and t.error:
raise Exception("got unexpected error: " + str(t.error))
if not concurrent_ok and not t.error:
if not (concurrent_ok and threading_ok) and not t.error:
raise Exception("failed to get expected error")

def test(self):
self.tryit(Test(), True, True, True)
self.tryit(Test1(), True, True, True)
self.tryit(Test2(), False, True, True)
self.tryit(Test3(), True, False, True)
self.tryit(Test4(), True, True, False)
self.tryit(Test5(), False, False, False)
proxy = nilmdb.utils.threadsafety.verify_proxy
self.tryit(proxy(Test(), True, True, True), False, False)
self.tryit(proxy(Test(), True, True, False), False, True)
self.tryit(proxy(Test(), True, False, True), True, False)
self.tryit(proxy(Test(), True, False, False), True, True)

c = Test1()
c = proxy(Test())
c.foo()
try:
c.foo(exception = True)
except Exception:
pass
c.foo()

d = Test6()

Loading…
Cancel
Save