Browse Source

Implement verify_thread_safety to check for unsafe access patterns

Occasional segfaults may be the result of performing thread-unsafe
operations.  This class decorator verifies that all of its methods
are called in a thread-safe manner.

It can separately warn about:
- two threads calling methods in a function (the kind of thing sqlite
  doesn't like)
- recursion
- concurrency (two different threads functions at the same time)
tags/nilmdb-1.2
Jim Paris 10 years ago
parent
commit
965537d8cb
5 changed files with 203 additions and 0 deletions
  1. +1
    -0
      nilmdb/utils/__init__.py
  2. +91
    -0
      nilmdb/utils/threadsafety.py
  3. +1
    -0
      setup.cfg
  4. +1
    -0
      tests/test.order
  5. +109
    -0
      tests/test_threadsafety.py

+ 1
- 0
nilmdb/utils/__init__.py View File

@@ -8,3 +8,4 @@ from nilmdb.utils.diskusage import du, human_size
from nilmdb.utils.mustclose import must_close
from nilmdb.utils.urllib import urlencode
from nilmdb.utils import atomic
import nilmdb.utils.threadsafety

+ 91
- 0
nilmdb/utils/threadsafety.py View File

@@ -0,0 +1,91 @@
from nilmdb.utils.printf import *
import threading
import decorator
import inspect

def verify_thread_safety(check_thread = True, check_recursion = True,
check_concurrent = True):
"""Class decorator that raises an exception if the methods in the
class are called from separate threads.

check_thread = True # Fail if two different threads ever call methods.
check_recursion = True # Fail if any functions is being run twice.
check_concurrent = True # Fail if two different threads are calling any
# functions concurrently.
"""
def class_decorator(cls):

def wrap_class_method(wrapper):
try:
orig = getattr(cls, wrapper.__name__).im_func
except:
orig = lambda x: None
setattr(cls, wrapper.__name__, decorator.decorator(wrapper, orig))

@wrap_class_method
def __init__(orig, self, *args, **kwargs):
self.__thread = threading.current_thread()
self.__concur_lock = threading.Lock()
self.__concur_name = None
self.__concur_thread = None
ret = orig(self, *args, **kwargs)
return ret

# Wrap all other functions with the verifier
def verifier(orig, self, *args, **kwargs):
this = threading.current_thread()

if check_thread and self.__thread != this:
err = sprintf("unsafe threading: %s called %s.__init__,"
" but %s called %s.%s", self.__thread.name,
self.__class__.__name__, this.name,
self.__class__.__name__, orig.__name__)
raise AssertionError(err)

if check_recursion and orig.__recurse_lock.acquire(False) == False:
err = sprintf("unsafe recursion: %s called %s.%s "
"but it's already being run by %s",
this.name,
self.__class__.__name__, orig.__name__,
orig.__recurse_thread.name)
raise AssertionError(err)

need_concur_unlock = False
if check_concurrent:
if self.__concur_lock.acquire(False) == False:
if this != self.__concur_thread:
err = sprintf("unsafe concurrency: %s called %s.%s "
"while %s is still in %s.%s",
this.name, self.__class__.__name__,
orig.__name__, self.__concur_thread.name,
self.__class__.__name__,
self.__concur_name)
raise AssertionError(err)
else:
self.__concur_thread = this
self.__concur_name = orig.__name__
need_concur_unlock = True

orig.__recurse_thread = this
try:
ret = orig(self, *args, **kwargs)
finally:
if check_recursion:
orig.__recurse_lock.release()
if need_concur_unlock:
self.__concur_lock.release()
return ret

for (name, method) in inspect.getmembers(cls, inspect.ismethod):
# Skip class methods
if method.__self__ is not None:
continue
# Skip some methods
if name in [ "__del__", "__init__" ]:
continue
# Set up wrapper. Each function needs its own lock.
method.im_func.__recurse_lock = threading.Lock()
setattr(cls, name, decorator.decorator(verifier, method.im_func))

return cls
return class_decorator

+ 1
- 0
setup.cfg View File

@@ -20,6 +20,7 @@ cover-erase=1
stop=1
verbosity=2
tests=tests
#tests=tests/test_threadsafety.py
#tests=tests/test_bulkdata.py
#tests=tests/test_mustclose.py
#tests=tests/test_lrucache.py


+ 1
- 0
tests/test.order View File

@@ -1,4 +1,5 @@
test_printf.py
test_threadsafety.py
test_lrucache.py
test_mustclose.py



+ 109
- 0
tests/test_threadsafety.py View File

@@ -0,0 +1,109 @@
import nilmdb
from nilmdb.utils.printf import *

import nose
from nose.tools import *
from nose.tools import assert_raises

from testutil.helpers import *
import threading

class Thread(threading.Thread):
def __init__(self, target):
self.target = target
threading.Thread.__init__(self)

def run(self):
try:
self.target()
except AssertionError as e:
self.error = e
else:
self.error = None

class Test():
def __init__(self):
pass

@classmethod
def asdf(cls):
pass

def foo(self, exception = False, reenter = False):
if exception:
raise Exception()
self.bar(reenter)

def bar(self, reenter):
if reenter:
self.foo()

def baz(self):
t = Thread(self.foo)
t.start()
t.join()
return t

@nilmdb.utils.threadsafety.verify_thread_safety(False, False, False)
class Test1(Test):
pass

@nilmdb.utils.threadsafety.verify_thread_safety(True, False, False)
class Test2(Test):
pass

@nilmdb.utils.threadsafety.verify_thread_safety(False, True, False)
class Test3(Test):
pass

@nilmdb.utils.threadsafety.verify_thread_safety(False, False, True)
class Test4(Test):
pass

@nilmdb.utils.threadsafety.verify_thread_safety()
class Test5(Test):
pass

@nilmdb.utils.threadsafety.verify_thread_safety()
class Test6(object):
pass

class TestThreadSafety(object):
def tryit(self, c, threading_ok, recursion_ok, concurrent_ok):
t = Thread(c.foo)
t.start()
t.join()
if threading_ok and t.error:
raise Exception("got unexpected error: " + str(t.error))
if not threading_ok and not t.error:
raise Exception("failed to get expected error")
if recursion_ok:
c.foo(reenter = True)
else:
with assert_raises(AssertionError) as e:
c.foo(reenter = True)
if not threading_ok:
return
t = c.baz()
if concurrent_ok and t.error:
raise Exception("got unexpected error: " + str(t.error))
if not concurrent_ok and not t.error:
raise Exception("failed to get expected error")

def test(self):
self.tryit(Test(), True, True, True)
self.tryit(Test1(), True, True, True)
self.tryit(Test2(), False, True, True)
self.tryit(Test3(), True, False, True)
self.tryit(Test4(), True, True, False)
self.tryit(Test5(), False, False, False)

c = Test1()
c.foo()
try:
c.foo(exception = True)
except Exception:
pass
c.foo()

d = Test6()

Loading…
Cancel
Save