diff --git a/nilmdb/httpclient.py b/nilmdb/httpclient.py index e183cda..ca3b735 100644 --- a/nilmdb/httpclient.py +++ b/nilmdb/httpclient.py @@ -119,7 +119,7 @@ class HTTPClient(object): self.curl.setopt(pycurl.WRITEFUNCTION, callback) self.curl.perform() try: - with nilmdb.utils.Iteratorizer(func) as it: + with nilmdb.utils.Iteratorizer(func, curl_hack = True) as it: for i in it: if self._status == 200: # If we had a 200 response, yield the data to caller. diff --git a/nilmdb/utils/iteratorizer.py b/nilmdb/utils/iteratorizer.py index 7ca1167..190e6dc 100644 --- a/nilmdb/utils/iteratorizer.py +++ b/nilmdb/utils/iteratorizer.py @@ -10,7 +10,7 @@ import contextlib # Based partially on http://stackoverflow.com/questions/9968592/ class IteratorizerThread(threading.Thread): - def __init__(self, queue, function): + def __init__(self, queue, function, curl_hack): """ function: function to execute, which takes the callback (provided by this class) as an argument @@ -19,11 +19,24 @@ class IteratorizerThread(threading.Thread): self.function = function self.queue = queue self.die = False + self.curl_hack = curl_hack def callback(self, data): - if self.die: - raise Exception() # trigger termination - self.queue.put((1, data)) + try: + if self.die: + raise Exception() # trigger termination + self.queue.put((1, data)) + except: + if self.curl_hack: + # We can't raise exceptions, because the pycurl + # extension module will unconditionally print the + # exception itself, and not pass it up to the caller. + # Instead, just return a value that tells curl to + # abort. (-1 would be best, in case we were given 0 + # bytes, but the extension doesn't support that). + self.queue.put((2, sys.exc_info())) + return 0 + raise def run(self): try: @@ -34,7 +47,7 @@ class IteratorizerThread(threading.Thread): self.queue.put((0, result)) @contextlib.contextmanager -def Iteratorizer(function): +def Iteratorizer(function, curl_hack = False): """ Context manager that takes a function expecting a callback, and provides an iterable that yields the values passed to that @@ -49,7 +62,7 @@ def Iteratorizer(function): print 'function returned:', it.retval """ queue = Queue.Queue(maxsize = 1) - thread = IteratorizerThread(queue, function) + thread = IteratorizerThread(queue, function, curl_hack) thread.daemon = True thread.start() diff --git a/tests/test_iteratorizer.py b/tests/test_iteratorizer.py index 7e628b1..9b65ad6 100644 --- a/tests/test_iteratorizer.py +++ b/tests/test_iteratorizer.py @@ -52,3 +52,10 @@ class TestIteratorizer(object): it.next() foo() eq_(it.retval, None) + + # Do the same thing when the curl hack is applied + def foo(): + with nilmdb.utils.Iteratorizer(f, curl_hack = True) as it: + it.next() + foo() + eq_(it.retval, None)