Occasional segfaults may be the result of performing thread-unsafe operations. This class decorator verifies that all of its methods are called in a thread-safe manner. It can separately warn about: - two threads calling methods in a function (the kind of thing sqlite doesn't like) - recursion - concurrency (two different threads functions at the same time)tags/nilmdb-1.2
@@ -8,3 +8,4 @@ from nilmdb.utils.diskusage import du, human_size | |||
from nilmdb.utils.mustclose import must_close | |||
from nilmdb.utils.urllib import urlencode | |||
from nilmdb.utils import atomic | |||
import nilmdb.utils.threadsafety |
@@ -0,0 +1,91 @@ | |||
from nilmdb.utils.printf import * | |||
import threading | |||
import decorator | |||
import inspect | |||
def verify_thread_safety(check_thread = True, check_recursion = True, | |||
check_concurrent = True): | |||
"""Class decorator that raises an exception if the methods in the | |||
class are called from separate threads. | |||
check_thread = True # Fail if two different threads ever call methods. | |||
check_recursion = True # Fail if any functions is being run twice. | |||
check_concurrent = True # Fail if two different threads are calling any | |||
# functions concurrently. | |||
""" | |||
def class_decorator(cls): | |||
def wrap_class_method(wrapper): | |||
try: | |||
orig = getattr(cls, wrapper.__name__).im_func | |||
except: | |||
orig = lambda x: None | |||
setattr(cls, wrapper.__name__, decorator.decorator(wrapper, orig)) | |||
@wrap_class_method | |||
def __init__(orig, self, *args, **kwargs): | |||
self.__thread = threading.current_thread() | |||
self.__concur_lock = threading.Lock() | |||
self.__concur_name = None | |||
self.__concur_thread = None | |||
ret = orig(self, *args, **kwargs) | |||
return ret | |||
# Wrap all other functions with the verifier | |||
def verifier(orig, self, *args, **kwargs): | |||
this = threading.current_thread() | |||
if check_thread and self.__thread != this: | |||
err = sprintf("unsafe threading: %s called %s.__init__," | |||
" but %s called %s.%s", self.__thread.name, | |||
self.__class__.__name__, this.name, | |||
self.__class__.__name__, orig.__name__) | |||
raise AssertionError(err) | |||
if check_recursion and orig.__recurse_lock.acquire(False) == False: | |||
err = sprintf("unsafe recursion: %s called %s.%s " | |||
"but it's already being run by %s", | |||
this.name, | |||
self.__class__.__name__, orig.__name__, | |||
orig.__recurse_thread.name) | |||
raise AssertionError(err) | |||
need_concur_unlock = False | |||
if check_concurrent: | |||
if self.__concur_lock.acquire(False) == False: | |||
if this != self.__concur_thread: | |||
err = sprintf("unsafe concurrency: %s called %s.%s " | |||
"while %s is still in %s.%s", | |||
this.name, self.__class__.__name__, | |||
orig.__name__, self.__concur_thread.name, | |||
self.__class__.__name__, | |||
self.__concur_name) | |||
raise AssertionError(err) | |||
else: | |||
self.__concur_thread = this | |||
self.__concur_name = orig.__name__ | |||
need_concur_unlock = True | |||
orig.__recurse_thread = this | |||
try: | |||
ret = orig(self, *args, **kwargs) | |||
finally: | |||
if check_recursion: | |||
orig.__recurse_lock.release() | |||
if need_concur_unlock: | |||
self.__concur_lock.release() | |||
return ret | |||
for (name, method) in inspect.getmembers(cls, inspect.ismethod): | |||
# Skip class methods | |||
if method.__self__ is not None: | |||
continue | |||
# Skip some methods | |||
if name in [ "__del__", "__init__" ]: | |||
continue | |||
# Set up wrapper. Each function needs its own lock. | |||
method.im_func.__recurse_lock = threading.Lock() | |||
setattr(cls, name, decorator.decorator(verifier, method.im_func)) | |||
return cls | |||
return class_decorator |
@@ -20,6 +20,7 @@ cover-erase=1 | |||
stop=1 | |||
verbosity=2 | |||
tests=tests | |||
#tests=tests/test_threadsafety.py | |||
#tests=tests/test_bulkdata.py | |||
#tests=tests/test_mustclose.py | |||
#tests=tests/test_lrucache.py | |||
@@ -1,4 +1,5 @@ | |||
test_printf.py | |||
test_threadsafety.py | |||
test_lrucache.py | |||
test_mustclose.py | |||
@@ -0,0 +1,109 @@ | |||
import nilmdb | |||
from nilmdb.utils.printf import * | |||
import nose | |||
from nose.tools import * | |||
from nose.tools import assert_raises | |||
from testutil.helpers import * | |||
import threading | |||
class Thread(threading.Thread): | |||
def __init__(self, target): | |||
self.target = target | |||
threading.Thread.__init__(self) | |||
def run(self): | |||
try: | |||
self.target() | |||
except AssertionError as e: | |||
self.error = e | |||
else: | |||
self.error = None | |||
class Test(): | |||
def __init__(self): | |||
pass | |||
@classmethod | |||
def asdf(cls): | |||
pass | |||
def foo(self, exception = False, reenter = False): | |||
if exception: | |||
raise Exception() | |||
self.bar(reenter) | |||
def bar(self, reenter): | |||
if reenter: | |||
self.foo() | |||
def baz(self): | |||
t = Thread(self.foo) | |||
t.start() | |||
t.join() | |||
return t | |||
@nilmdb.utils.threadsafety.verify_thread_safety(False, False, False) | |||
class Test1(Test): | |||
pass | |||
@nilmdb.utils.threadsafety.verify_thread_safety(True, False, False) | |||
class Test2(Test): | |||
pass | |||
@nilmdb.utils.threadsafety.verify_thread_safety(False, True, False) | |||
class Test3(Test): | |||
pass | |||
@nilmdb.utils.threadsafety.verify_thread_safety(False, False, True) | |||
class Test4(Test): | |||
pass | |||
@nilmdb.utils.threadsafety.verify_thread_safety() | |||
class Test5(Test): | |||
pass | |||
@nilmdb.utils.threadsafety.verify_thread_safety() | |||
class Test6(object): | |||
pass | |||
class TestThreadSafety(object): | |||
def tryit(self, c, threading_ok, recursion_ok, concurrent_ok): | |||
t = Thread(c.foo) | |||
t.start() | |||
t.join() | |||
if threading_ok and t.error: | |||
raise Exception("got unexpected error: " + str(t.error)) | |||
if not threading_ok and not t.error: | |||
raise Exception("failed to get expected error") | |||
if recursion_ok: | |||
c.foo(reenter = True) | |||
else: | |||
with assert_raises(AssertionError) as e: | |||
c.foo(reenter = True) | |||
if not threading_ok: | |||
return | |||
t = c.baz() | |||
if concurrent_ok and t.error: | |||
raise Exception("got unexpected error: " + str(t.error)) | |||
if not concurrent_ok and not t.error: | |||
raise Exception("failed to get expected error") | |||
def test(self): | |||
self.tryit(Test(), True, True, True) | |||
self.tryit(Test1(), True, True, True) | |||
self.tryit(Test2(), False, True, True) | |||
self.tryit(Test3(), True, False, True) | |||
self.tryit(Test4(), True, True, False) | |||
self.tryit(Test5(), False, False, False) | |||
c = Test1() | |||
c.foo() | |||
try: | |||
c.foo(exception = True) | |||
except Exception: | |||
pass | |||
c.foo() | |||
d = Test6() |