Skip to content

Commit

Permalink
Merge pull request grpc#3147 from nathanielmanistaatgoogle/servicelin…
Browse files Browse the repository at this point in the history
…k-shut-down

Fix gRPC links lifecycle tracking
  • Loading branch information
soltanmm committed Aug 31, 2015
2 parents 8be7a04 + 154e762 commit 5684561
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 72 deletions.
115 changes: 72 additions & 43 deletions src/python/grpcio/grpc/_links/invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@
from grpc.framework.foundation import relay
from grpc.framework.interfaces.links import links

_STOP = _intermediary_low.Event.Kind.STOP
_WRITE = _intermediary_low.Event.Kind.WRITE_ACCEPTED
_COMPLETE = _intermediary_low.Event.Kind.COMPLETE_ACCEPTED
_READ = _intermediary_low.Event.Kind.READ_ACCEPTED
_METADATA = _intermediary_low.Event.Kind.METADATA_ACCEPTED
_FINISH = _intermediary_low.Event.Kind.FINISH


@enum.unique
class _Read(enum.Enum):
Expand All @@ -67,7 +74,7 @@ class _RPCState(object):

def __init__(
self, call, request_serializer, response_deserializer, sequence_number,
read, allowance, high_write, low_write):
read, allowance, high_write, low_write, due):
self.call = call
self.request_serializer = request_serializer
self.response_deserializer = response_deserializer
Expand All @@ -76,6 +83,13 @@ def __init__(
self.allowance = allowance
self.high_write = high_write
self.low_write = low_write
self.due = due


def _no_longer_due(kind, rpc_state, key, rpc_states):
rpc_state.due.remove(kind)
if not rpc_state.due:
del rpc_states[key]


class _Kernel(object):
Expand All @@ -91,12 +105,14 @@ def __init__(
self._relay = ticket_relay

self._completion_queue = None
self._rpc_states = None
self._rpc_states = {}
self._pool = None

def _on_write_event(self, operation_id, unused_event, rpc_state):
if rpc_state.high_write is _HighWrite.CLOSED:
rpc_state.call.complete(operation_id)
rpc_state.due.add(_COMPLETE)
rpc_state.due.remove(_WRITE)
rpc_state.low_write = _LowWrite.CLOSED
else:
ticket = links.Ticket(
Expand All @@ -105,16 +121,19 @@ def _on_write_event(self, operation_id, unused_event, rpc_state):
rpc_state.sequence_number += 1
self._relay.add_value(ticket)
rpc_state.low_write = _LowWrite.OPEN
_no_longer_due(_WRITE, rpc_state, operation_id, self._rpc_states)

def _on_read_event(self, operation_id, event, rpc_state):
if event.bytes is None:
if event.bytes is None or _FINISH not in rpc_state.due:
rpc_state.read = _Read.CLOSED
_no_longer_due(_READ, rpc_state, operation_id, self._rpc_states)
else:
if 0 < rpc_state.allowance:
rpc_state.allowance -= 1
rpc_state.call.read(operation_id)
else:
rpc_state.read = _Read.AWAITING_ALLOWANCE
_no_longer_due(_READ, rpc_state, operation_id, self._rpc_states)
ticket = links.Ticket(
operation_id, rpc_state.sequence_number, None, None, None, None, None,
None, rpc_state.response_deserializer(event.bytes), None, None, None,
Expand All @@ -123,18 +142,23 @@ def _on_read_event(self, operation_id, event, rpc_state):
self._relay.add_value(ticket)

def _on_metadata_event(self, operation_id, event, rpc_state):
rpc_state.allowance -= 1
rpc_state.call.read(operation_id)
rpc_state.read = _Read.READING
ticket = links.Ticket(
operation_id, rpc_state.sequence_number, None, None,
links.Ticket.Subscription.FULL, None, None, event.metadata, None, None,
None, None, None, None)
rpc_state.sequence_number += 1
self._relay.add_value(ticket)
if _FINISH in rpc_state.due:
rpc_state.allowance -= 1
rpc_state.call.read(operation_id)
rpc_state.read = _Read.READING
rpc_state.due.add(_READ)
rpc_state.due.remove(_METADATA)
ticket = links.Ticket(
operation_id, rpc_state.sequence_number, None, None,
links.Ticket.Subscription.FULL, None, None, event.metadata, None,
None, None, None, None, None)
rpc_state.sequence_number += 1
self._relay.add_value(ticket)
else:
_no_longer_due(_METADATA, rpc_state, operation_id, self._rpc_states)

def _on_finish_event(self, operation_id, event, rpc_state):
self._rpc_states.pop(operation_id, None)
_no_longer_due(_FINISH, rpc_state, operation_id, self._rpc_states)
if event.status.code is _intermediary_low.Code.OK:
termination = links.Ticket.Termination.COMPLETION
elif event.status.code is _intermediary_low.Code.CANCELLED:
Expand All @@ -155,26 +179,26 @@ def _on_finish_event(self, operation_id, event, rpc_state):
def _spin(self, completion_queue):
while True:
event = completion_queue.get(None)
if event.kind is _intermediary_low.Event.Kind.STOP:
return
operation_id = event.tag
with self._lock:
if self._completion_queue is None:
continue
rpc_state = self._rpc_states.get(operation_id)
if rpc_state is not None:
if event.kind is _intermediary_low.Event.Kind.WRITE_ACCEPTED:
self._on_write_event(operation_id, event, rpc_state)
elif event.kind is _intermediary_low.Event.Kind.METADATA_ACCEPTED:
self._on_metadata_event(operation_id, event, rpc_state)
elif event.kind is _intermediary_low.Event.Kind.READ_ACCEPTED:
self._on_read_event(operation_id, event, rpc_state)
elif event.kind is _intermediary_low.Event.Kind.FINISH:
self._on_finish_event(operation_id, event, rpc_state)
elif event.kind is _intermediary_low.Event.Kind.COMPLETE_ACCEPTED:
pass
else:
logging.error('Illegal RPC event! %s', (event,))
rpc_state = self._rpc_states.get(event.tag, None)
if event.kind is _STOP:
pass
elif event.kind is _WRITE:
self._on_write_event(event.tag, event, rpc_state)
elif event.kind is _METADATA:
self._on_metadata_event(event.tag, event, rpc_state)
elif event.kind is _READ:
self._on_read_event(event.tag, event, rpc_state)
elif event.kind is _FINISH:
self._on_finish_event(event.tag, event, rpc_state)
elif event.kind is _COMPLETE:
_no_longer_due(_COMPLETE, rpc_state, event.tag, self._rpc_states)
else:
logging.error('Illegal RPC event! %s', (event,))

if self._completion_queue is None and not self._rpc_states:
completion_queue.stop()
return

def _invoke(
self, operation_id, group, method, initial_metadata, payload, termination,
Expand Down Expand Up @@ -221,46 +245,53 @@ def _invoke(
if high_write is _HighWrite.CLOSED:
call.complete(operation_id)
low_write = _LowWrite.CLOSED
due = set((_METADATA, _COMPLETE, _FINISH,))
else:
low_write = _LowWrite.OPEN
due = set((_METADATA, _FINISH,))
else:
call.write(request_serializer(payload), operation_id)
low_write = _LowWrite.ACTIVE
due = set((_WRITE, _METADATA, _FINISH,))
self._rpc_states[operation_id] = _RPCState(
call, request_serializer, response_deserializer, 0,
_Read.AWAITING_METADATA, 1 if allowance is None else (1 + allowance),
high_write, low_write)
high_write, low_write, due)

def _advance(self, operation_id, rpc_state, payload, termination, allowance):
if payload is not None:
rpc_state.call.write(rpc_state.request_serializer(payload), operation_id)
rpc_state.low_write = _LowWrite.ACTIVE
rpc_state.due.add(_WRITE)

if allowance is not None:
if rpc_state.read is _Read.AWAITING_ALLOWANCE:
rpc_state.allowance += allowance - 1
rpc_state.call.read(operation_id)
rpc_state.read = _Read.READING
rpc_state.due.add(_READ)
else:
rpc_state.allowance += allowance

if termination is links.Ticket.Termination.COMPLETION:
rpc_state.high_write = _HighWrite.CLOSED
if rpc_state.low_write is _LowWrite.OPEN:
rpc_state.call.complete(operation_id)
rpc_state.due.add(_COMPLETE)
rpc_state.low_write = _LowWrite.CLOSED
elif termination is not None:
rpc_state.call.cancel()

def add_ticket(self, ticket):
with self._lock:
if self._completion_queue is None:
return
if ticket.sequence_number == 0:
self._invoke(
ticket.operation_id, ticket.group, ticket.method,
ticket.initial_metadata, ticket.payload, ticket.termination,
ticket.timeout, ticket.allowance)
if self._completion_queue is None:
logging.error('Received invocation ticket %s after stop!', ticket)
else:
self._invoke(
ticket.operation_id, ticket.group, ticket.method,
ticket.initial_metadata, ticket.payload, ticket.termination,
ticket.timeout, ticket.allowance)
else:
rpc_state = self._rpc_states.get(ticket.operation_id)
if rpc_state is not None:
Expand All @@ -276,7 +307,6 @@ def start(self):
"""
with self._lock:
self._completion_queue = _intermediary_low.CompletionQueue()
self._rpc_states = {}
self._pool = logging_pool.pool(1)
self._pool.submit(self._spin, self._completion_queue)

Expand All @@ -288,11 +318,10 @@ def stop(self):
has been called.
"""
with self._lock:
self._completion_queue.stop()
if not self._rpc_states:
self._completion_queue.stop()
self._completion_queue = None
pool = self._pool
self._pool = None
self._rpc_states = None
pool.shutdown(wait=True)


Expand Down
Loading

0 comments on commit 5684561

Please sign in to comment.