Skip to content

Commit

Permalink
Merge pull request grpc#3269 from nathanielmanistaatgoogle/moar-proto…
Browse files Browse the repository at this point in the history
…col-objects

The gRPC protocol objects
  • Loading branch information
soltanmm committed Sep 8, 2015
2 parents a4836ad + 41abb05 commit 8905be0
Show file tree
Hide file tree
Showing 8 changed files with 389 additions and 24 deletions.
10 changes: 7 additions & 3 deletions src/python/grpcio/grpc/_adapter/_intermediary_low.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@

_IGNORE_ME_TAG = object()
Code = _types.StatusCode
WriteFlags = _types.OpWriteFlags


class Status(collections.namedtuple('Status', ['code', 'details'])):
Expand Down Expand Up @@ -125,9 +126,9 @@ def invoke(self, completion_queue, metadata_tag, finish_tag):
], _TagAdapter(finish_tag, Event.Kind.FINISH))
return err0 if err0 != _types.CallError.OK else err1 if err1 != _types.CallError.OK else err2 if err2 != _types.CallError.OK else _types.CallError.OK

def write(self, message, tag):
def write(self, message, tag, flags):
return self._internal.start_batch([
_types.OpArgs.send_message(message, 0)
_types.OpArgs.send_message(message, flags)
], _TagAdapter(tag, Event.Kind.WRITE_ACCEPTED))

def complete(self, tag):
Expand Down Expand Up @@ -163,8 +164,11 @@ def status(self, status, tag):
def cancel(self):
return self._internal.cancel()

def peer(self):
return self._internal.peer()

def set_credentials(self, creds):
return self._internal.set_credentials(creds)
return self._internal.set_credentials(creds._internal)


class Channel(object):
Expand Down
4 changes: 2 additions & 2 deletions src/python/grpcio/grpc/_adapter/fore.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class _LowWrite(enum.Enum):
def _write(call, rpc_state, payload):
serialized_payload = rpc_state.serializer(payload)
if rpc_state.write.low is _LowWrite.OPEN:
call.write(serialized_payload, call)
call.write(serialized_payload, call, 0)
rpc_state.write.low = _LowWrite.ACTIVE
else:
rpc_state.write.pending.append(serialized_payload)
Expand Down Expand Up @@ -164,7 +164,7 @@ def _on_write_event(self, event):

if rpc_state.write.pending:
serialized_payload = rpc_state.write.pending.pop(0)
call.write(serialized_payload, call)
call.write(serialized_payload, call, 0)
elif rpc_state.write.high is _common.HighWrite.CLOSED:
_status(call, rpc_state)
else:
Expand Down
6 changes: 3 additions & 3 deletions src/python/grpcio/grpc/_adapter/rear.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, call, outstanding, active, common):

def _write(operation_id, call, outstanding, write_state, serialized_payload):
if write_state.low is _LowWrite.OPEN:
call.write(serialized_payload, operation_id)
call.write(serialized_payload, operation_id, 0)
outstanding.add(_low.Event.Kind.WRITE_ACCEPTED)
write_state.low = _LowWrite.ACTIVE
elif write_state.low is _LowWrite.ACTIVE:
Expand Down Expand Up @@ -144,7 +144,7 @@ def _on_write_event(self, operation_id, event, rpc_state):
if event.write_accepted:
if rpc_state.common.write.pending:
rpc_state.call.write(
rpc_state.common.write.pending.pop(0), operation_id)
rpc_state.common.write.pending.pop(0), operation_id, 0)
rpc_state.outstanding.add(_low.Event.Kind.WRITE_ACCEPTED)
elif rpc_state.common.write.high is _common.HighWrite.CLOSED:
rpc_state.call.complete(operation_id)
Expand Down Expand Up @@ -263,7 +263,7 @@ def _invoke(self, operation_id, name, high_state, payload, timeout):
low_state = _LowWrite.OPEN
else:
serialized_payload = request_serializer(payload)
call.write(serialized_payload, operation_id)
call.write(serialized_payload, operation_id, 0)
outstanding.add(_low.Event.Kind.WRITE_ACCEPTED)
low_state = _LowWrite.ACTIVE

Expand Down
57 changes: 50 additions & 7 deletions src/python/grpcio/grpc/_links/invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from grpc._adapter import _intermediary_low
from grpc._links import _constants
from grpc.beta import interfaces as beta_interfaces
from grpc.framework.foundation import activated
from grpc.framework.foundation import logging_pool
from grpc.framework.foundation import relay
Expand Down Expand Up @@ -73,11 +74,28 @@ class _LowWrite(enum.Enum):
CLOSED = 'CLOSED'


class _Context(beta_interfaces.GRPCInvocationContext):

def __init__(self):
self._lock = threading.Lock()
self._disable_next_compression = False

def disable_next_request_compression(self):
with self._lock:
self._disable_next_compression = True

def next_compression_disabled(self):
with self._lock:
disabled = self._disable_next_compression
self._disable_next_compression = False
return disabled


class _RPCState(object):

def __init__(
self, call, request_serializer, response_deserializer, sequence_number,
read, allowance, high_write, low_write, due):
read, allowance, high_write, low_write, due, context):
self.call = call
self.request_serializer = request_serializer
self.response_deserializer = response_deserializer
Expand All @@ -87,6 +105,7 @@ def __init__(
self.high_write = high_write
self.low_write = low_write
self.due = due
self.context = context


def _no_longer_due(kind, rpc_state, key, rpc_states):
Expand Down Expand Up @@ -209,7 +228,7 @@ def _spin(self, completion_queue):

def _invoke(
self, operation_id, group, method, initial_metadata, payload, termination,
timeout, allowance):
timeout, allowance, options):
"""Invoke an RPC.
Args:
Expand All @@ -224,6 +243,7 @@ def _invoke(
timeout: A duration of time in seconds to allow for the RPC.
allowance: The number of payloads (beyond the free first one) that the
local ticket exchange mate has granted permission to be read.
options: A beta_interfaces.GRPCCallOptions value or None.
"""
if termination is links.Ticket.Termination.COMPLETION:
high_write = _HighWrite.CLOSED
Expand All @@ -241,6 +261,8 @@ def _invoke(
call = _intermediary_low.Call(
self._channel, self._completion_queue, '/%s/%s' % (group, method),
self._host, time.time() + timeout)
if options is not None and options.credentials is not None:
call.set_credentials(options.credentials._intermediary_low_credentials)
if transformed_initial_metadata is not None:
for metadata_key, metadata_value in transformed_initial_metadata:
call.add_metadata(metadata_key, metadata_value)
Expand All @@ -254,17 +276,33 @@ def _invoke(
low_write = _LowWrite.OPEN
due = set((_METADATA, _FINISH,))
else:
call.write(request_serializer(payload), operation_id)
if options is not None and options.disable_compression:
flags = _intermediary_low.WriteFlags.WRITE_NO_COMPRESS
else:
flags = 0
call.write(request_serializer(payload), operation_id, flags)
low_write = _LowWrite.ACTIVE
due = set((_WRITE, _METADATA, _FINISH,))
context = _Context()
self._rpc_states[operation_id] = _RPCState(
call, request_serializer, response_deserializer, 0,
call, request_serializer, response_deserializer, 1,
_Read.AWAITING_METADATA, 1 if allowance is None else (1 + allowance),
high_write, low_write, due)
high_write, low_write, due, context)
protocol = links.Protocol(links.Protocol.Kind.INVOCATION_CONTEXT, context)
ticket = links.Ticket(
operation_id, 0, None, None, None, None, None, None, None, None, None,
None, None, protocol)
self._relay.add_value(ticket)

def _advance(self, operation_id, rpc_state, payload, termination, allowance):
if payload is not None:
rpc_state.call.write(rpc_state.request_serializer(payload), operation_id)
disable_compression = rpc_state.context.next_compression_disabled()
if disable_compression:
flags = _intermediary_low.WriteFlags.WRITE_NO_COMPRESS
else:
flags = 0
rpc_state.call.write(
rpc_state.request_serializer(payload), operation_id, flags)
rpc_state.low_write = _LowWrite.ACTIVE
rpc_state.due.add(_WRITE)

Expand Down Expand Up @@ -292,10 +330,15 @@ def add_ticket(self, ticket):
if self._completion_queue is None:
logging.error('Received invocation ticket %s after stop!', ticket)
else:
if (ticket.protocol is not None and
ticket.protocol.kind is links.Protocol.Kind.CALL_OPTION):
grpc_call_options = ticket.protocol.value
else:
grpc_call_options = None
self._invoke(
ticket.operation_id, ticket.group, ticket.method,
ticket.initial_metadata, ticket.payload, ticket.termination,
ticket.timeout, ticket.allowance)
ticket.timeout, ticket.allowance, grpc_call_options)
else:
rpc_state = self._rpc_states.get(ticket.operation_id)
if rpc_state is not None:
Expand Down
39 changes: 34 additions & 5 deletions src/python/grpcio/grpc/_links/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from grpc._adapter import _intermediary_low
from grpc._links import _constants
from grpc.beta import interfaces as beta_interfaces
from grpc.framework.foundation import logging_pool
from grpc.framework.foundation import relay
from grpc.framework.interfaces.links import links
Expand Down Expand Up @@ -89,12 +90,34 @@ class _LowWrite(enum.Enum):
CLOSED = 'CLOSED'


class _Context(beta_interfaces.GRPCServicerContext):

def __init__(self, call):
self._lock = threading.Lock()
self._call = call
self._disable_next_compression = False

def peer(self):
with self._lock:
return self._call.peer()

def disable_next_response_compression(self):
with self._lock:
self._disable_next_compression = True

def next_compression_disabled(self):
with self._lock:
disabled = self._disable_next_compression
self._disable_next_compression = False
return disabled


class _RPCState(object):

def __init__(
self, request_deserializer, response_serializer, sequence_number, read,
early_read, allowance, high_write, low_write, premetadataed,
terminal_metadata, code, message, due):
terminal_metadata, code, message, due, context):
self.request_deserializer = request_deserializer
self.response_serializer = response_serializer
self.sequence_number = sequence_number
Expand All @@ -110,6 +133,7 @@ def __init__(
self.code = code
self.message = message
self.due = due
self.context = context


def _no_longer_due(kind, rpc_state, key, rpc_states):
Expand Down Expand Up @@ -163,12 +187,12 @@ def _on_service_acceptance_event(self, event, server):
(group, method), _IDENTITY)

call.read(call)
context = _Context(call)
self._rpc_states[call] = _RPCState(
request_deserializer, response_serializer, 1, _Read.READING, None, 1,
_HighWrite.OPEN, _LowWrite.OPEN, False, None, None, None,
set((_READ, _FINISH,)))
protocol = links.Protocol(
links.Protocol.Kind.SERVICER_CONTEXT, 'TODO: Service Context Object!')
set((_READ, _FINISH,)), context)
protocol = links.Protocol(links.Protocol.Kind.SERVICER_CONTEXT, context)
ticket = links.Ticket(
call, 0, group, method, links.Ticket.Subscription.FULL,
service_acceptance.deadline - time.time(), None, event.metadata, None,
Expand Down Expand Up @@ -313,7 +337,12 @@ def add_ticket(self, ticket):
self._relay.add_value(early_read_ticket)

if ticket.payload is not None:
call.write(rpc_state.response_serializer(ticket.payload), call)
disable_compression = rpc_state.context.next_compression_disabled()
if disable_compression:
flags = _intermediary_low.WriteFlags.WRITE_NO_COMPRESS
else:
flags = 0
call.write(rpc_state.response_serializer(ticket.payload), call, flags)
rpc_state.due.add(_WRITE)
rpc_state.low_write = _LowWrite.ACTIVE

Expand Down
58 changes: 58 additions & 0 deletions src/python/grpcio/grpc/beta/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

"""Constants and interfaces of the Beta API of gRPC Python."""

import abc
import enum


Expand All @@ -52,3 +53,60 @@ class StatusCode(enum.Enum):
UNAVAILABLE = 14
DATA_LOSS = 15
UNAUTHENTICATED = 16


class GRPCCallOptions(object):
"""A value encapsulating gRPC-specific options passed on RPC invocation.
This class and its instances have no supported interface - it exists to
define the type of its instances and its instances exist to be passed to
other functions.
"""

def __init__(self, disable_compression, subcall_of, credentials):
self.disable_compression = disable_compression
self.subcall_of = subcall_of
self.credentials = credentials


def grpc_call_options(disable_compression=False, credentials=None):
"""Creates a GRPCCallOptions value to be passed at RPC invocation.
All parameters are optional and should always be passed by keyword.
Args:
disable_compression: A boolean indicating whether or not compression should
be disabled for the request object of the RPC. Only valid for
request-unary RPCs.
credentials: A ClientCredentials object to use for the invoked RPC.
"""
return GRPCCallOptions(disable_compression, None, credentials)


class GRPCServicerContext(object):
"""Exposes gRPC-specific options and behaviors to code servicing RPCs."""
__metaclass__ = abc.ABCMeta

@abc.abstractmethod
def peer(self):
"""Identifies the peer that invoked the RPC being serviced.
Returns:
A string identifying the peer that invoked the RPC being serviced.
"""
raise NotImplementedError()

@abc.abstractmethod
def disable_next_response_compression(self):
"""Disables compression of the next response passed by the application."""
raise NotImplementedError()


class GRPCInvocationContext(object):
"""Exposes gRPC-specific options and behaviors to code invoking RPCs."""
__metaclass__ = abc.ABCMeta

@abc.abstractmethod
def disable_next_request_compression(self):
"""Disables compression of the next request passed by the application."""
raise NotImplementedError()
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _perform_echo_test(self, test_data):
metadata[server_leading_binary_metadata_key])

for datum in test_data:
client_call.write(datum, write_tag)
client_call.write(datum, write_tag, _low.WriteFlags.WRITE_NO_COMPRESS)
write_accepted = self.client_events.get()
self.assertIsNotNone(write_accepted)
self.assertIs(write_accepted.kind, _low.Event.Kind.WRITE_ACCEPTED)
Expand All @@ -206,7 +206,7 @@ def _perform_echo_test(self, test_data):
self.assertIsNotNone(read_accepted.bytes)
server_data.append(read_accepted.bytes)

server_call.write(read_accepted.bytes, write_tag)
server_call.write(read_accepted.bytes, write_tag, 0)
write_accepted = self.server_events.get()
self.assertIsNotNone(write_accepted)
self.assertEqual(_low.Event.Kind.WRITE_ACCEPTED, write_accepted.kind)
Expand Down Expand Up @@ -370,14 +370,14 @@ def testCancellation(self):
self.assertIsNotNone(metadata_accepted)

for datum in test_data:
client_call.write(datum, write_tag)
client_call.write(datum, write_tag, 0)
write_accepted = self.client_events.get()

server_call.read(read_tag)
read_accepted = self.server_events.get()
server_data.append(read_accepted.bytes)

server_call.write(read_accepted.bytes, write_tag)
server_call.write(read_accepted.bytes, write_tag, 0)
write_accepted = self.server_events.get()
self.assertIsNotNone(write_accepted)

Expand Down
Loading

0 comments on commit 8905be0

Please sign in to comment.