Skip to content

Commit

Permalink
Gracefully handle errors from callbacks.
Browse files Browse the repository at this point in the history
In grpc#19910, it was pointed out that
raising an exception from a Future callback would cause the channel spin
thread to terminate. If there are outstanding events on the channel,
this will cause calls to Channel.close() to block indefinitely.

This commit ensures that the channel spin thread does not die. Instead,
exceptions will be logged at ERROR level.
  • Loading branch information
gnossen committed Aug 20, 2019
1 parent 073b234 commit 09a270d
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 10 deletions.
3 changes: 3 additions & 0 deletions src/python/grpcio/grpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ def add_done_callback(self, fn):
If the computation has already completed, the callback will be called
immediately.
Exceptions raised in the callback will be logged at ERROR level, but
will not terminate any threads of execution.
Args:
fn: A callable taking this Future object as its single parameter.
"""
Expand Down
11 changes: 9 additions & 2 deletions src/python/grpcio/grpc/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,14 @@ def handle_event(event):
state.condition.notify_all()
done = not state.due
for callback in callbacks:
callback()
# TODO(gnossen): Are these *only* user callbacks?
try:
callback()
except Exception as e: # pylint: disable=broad-except
# NOTE(rbellevi): We suppress but log errors here so as not to
# kill the channel spin thread.
logging.error('Exception in callback %s: %s', repr(
callback.func), repr(e))
return done and state.fork_epoch >= cygrpc.get_fork_epoch()

return handle_event
Expand Down Expand Up @@ -338,7 +345,7 @@ def traceback(self, timeout=None):
def add_done_callback(self, fn):
    """Arrange for ``fn(self)`` to be called once this call has completed.

    While ``self._state.code`` is still ``None`` the callback is queued on
    ``self._state.callbacks``; otherwise it is invoked immediately on the
    calling thread.

    Args:
      fn: A callable taking this object as its single parameter.
    """
    with self._state.condition:
        if self._state.code is None:
            # Register exactly once, as a functools.partial rather than a
            # lambda: the partial exposes the user's callable as ``.func``,
            # which the event handler reads when it logs an exception
            # raised by a callback.
            self._state.callbacks.append(functools.partial(fn, self))
            return

    # Already completed: call outside the condition so user code never
    # runs while the state lock is held.
    fn(self)
Expand Down
67 changes: 59 additions & 8 deletions src/python/grpcio_tests/tests/unit/_channel_close_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@
_SOME_TIME = 5
_MORE_TIME = 10

_STREAM_URI = 'Meffod'
_UNARY_URI = 'MeffodMan'

class _MethodHandler(grpc.RpcMethodHandler):

class _StreamingMethodHandler(grpc.RpcMethodHandler):

request_streaming = True
response_streaming = True
Expand All @@ -40,13 +43,28 @@ def stream_stream(self, request_iterator, servicer_context):
yield request * 2


_METHOD_HANDLER = _MethodHandler()
class _UnaryMethodHandler(grpc.RpcMethodHandler):
    """Unary-unary method handler whose response is the request doubled."""

    # No (de)serializers are configured for this handler.
    request_deserializer = None
    response_serializer = None

    # Neither side of the RPC is streaming.
    request_streaming = False
    response_streaming = False

    def unary_unary(self, request, servicer_context):
        """Return ``request * 2``; ``servicer_context`` is not used."""
        doubled = request * 2
        return doubled


_STREAMING_METHOD_HANDLER = _StreamingMethodHandler()
_UNARY_METHOD_HANDLER = _UnaryMethodHandler()


class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler that picks a method handler by RPC method name."""

    def service(self, handler_call_details):
        """Return the handler for the method named in ``handler_call_details``.

        The streaming handler is returned when the method equals
        ``_STREAM_URI``; every other method name gets the unary handler.
        """
        if handler_call_details.method == _STREAM_URI:
            return _STREAMING_METHOD_HANDLER
        return _UNARY_METHOD_HANDLER


_GENERIC_HANDLER = _GenericHandler()
Expand Down Expand Up @@ -94,6 +112,24 @@ def __exit__(self, type, value, traceback):
self.close()


class EndlessIterator(object):
    """Iterator that produces the same message on every step, forever.

    It never raises ``StopIteration``, so anything consuming it keeps
    receiving ``msg`` for as long as it keeps asking.
    """

    def __init__(self, msg):
        # The single value handed out on every iteration step.
        self._msg = msg

    def __iter__(self):
        # The object is its own iterator.
        return self

    def _next(self):
        """Return the stored message."""
        return self._msg

    def __next__(self):
        """Python 3 iterator protocol entry point."""
        message = self._next()
        return message

    # Python 2 spells the iterator protocol method ``next``.
    next = __next__


class ChannelCloseTest(unittest.TestCase):

def setUp(self):
Expand All @@ -108,7 +144,7 @@ def tearDown(self):

def test_close_immediately_after_call_invocation(self):
channel = grpc.insecure_channel('localhost:{}'.format(self._port))
multi_callable = channel.stream_stream('Meffod')
multi_callable = channel.stream_stream(_STREAM_URI)
request_iterator = _Pipe(())
response_iterator = multi_callable(request_iterator)
channel.close()
Expand All @@ -118,7 +154,7 @@ def test_close_immediately_after_call_invocation(self):

def test_close_while_call_active(self):
channel = grpc.insecure_channel('localhost:{}'.format(self._port))
multi_callable = channel.stream_stream('Meffod')
multi_callable = channel.stream_stream(_STREAM_URI)
request_iterator = _Pipe((b'abc',))
response_iterator = multi_callable(request_iterator)
next(response_iterator)
Expand All @@ -130,7 +166,7 @@ def test_close_while_call_active(self):
def test_context_manager_close_while_call_active(self):
with grpc.insecure_channel('localhost:{}'.format(
self._port)) as channel: # pylint: disable=bad-continuation
multi_callable = channel.stream_stream('Meffod')
multi_callable = channel.stream_stream(_STREAM_URI)
request_iterator = _Pipe((b'abc',))
response_iterator = multi_callable(request_iterator)
next(response_iterator)
Expand All @@ -141,7 +177,7 @@ def test_context_manager_close_while_call_active(self):
def test_context_manager_close_while_many_calls_active(self):
with grpc.insecure_channel('localhost:{}'.format(
self._port)) as channel: # pylint: disable=bad-continuation
multi_callable = channel.stream_stream('Meffod')
multi_callable = channel.stream_stream(_STREAM_URI)
request_iterators = tuple(
_Pipe((b'abc',))
for _ in range(test_constants.THREAD_CONCURRENCY))
Expand All @@ -158,7 +194,7 @@ def test_context_manager_close_while_many_calls_active(self):

def test_many_concurrent_closes(self):
channel = grpc.insecure_channel('localhost:{}'.format(self._port))
multi_callable = channel.stream_stream('Meffod')
multi_callable = channel.stream_stream(_STREAM_URI)
request_iterator = _Pipe((b'abc',))
response_iterator = multi_callable(request_iterator)
next(response_iterator)
Expand All @@ -181,6 +217,21 @@ def sleep_some_time_then_close():

self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

def test_exception_in_callback(self):
    """A done-callback that raises must not prevent the channel closing.

    A streaming call fed by an endless request iterator is started first,
    then a unary future whose done-callback raises. The test passes if
    ``future.result()`` returns and leaving the ``with`` block (which
    closes the channel) does not block indefinitely.
    """
    with grpc.insecure_channel('localhost:{}'.format(
            self._port)) as channel:
        stream_multi_callable = channel.stream_stream(_STREAM_URI)
        endless_iterator = EndlessIterator(b'abc')
        # The responses are never read; the binding is kept so the call
        # object stays referenced until the channel is closed.
        stream_response_iterator = stream_multi_callable(endless_iterator)
        future = channel.unary_unary(_UNARY_URI).future(b'abc')

        def on_done_callback(unused_future):
            raise Exception("This should not cause a deadlock.")

        future.add_done_callback(on_done_callback)
        future.result()


if __name__ == '__main__':
logging.basicConfig()
Expand Down

0 comments on commit 09a270d

Please sign in to comment.