Skip to content

Commit

Permalink
Merge pull request grpc#6254 from grpc/python_per_rpc_interop
Browse files Browse the repository at this point in the history
Added google call creds/per_rpc interop tests
  • Loading branch information
jtattermusch committed Jun 3, 2016
2 parents 6c81c25 + 60a83c7 commit 84f7193
Show file tree
Hide file tree
Showing 10 changed files with 268 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ cdef void plugin_get_metadata(
void *state, grpc_auth_metadata_context context,
grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil

cdef void plugin_destroy_c_plugin_state(void *state)
cdef void plugin_destroy_c_plugin_state(void *state) with gil
2 changes: 1 addition & 1 deletion src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ cdef void plugin_get_metadata(
cy_context.context = context
self.plugin_callback(cy_context, python_callback)

cdef void plugin_destroy_c_plugin_state(void *state):
cdef void plugin_destroy_c_plugin_state(void *state) with gil:
cpython.Py_DECREF(<CredentialsMetadataPlugin>state)

def channel_credentials_google_default():
Expand Down
73 changes: 73 additions & 0 deletions src/python/grpcio/grpc/beta/_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""GRPCAuthMetadataPlugins for standard authentication."""

from concurrent import futures

from grpc.beta import interfaces


def _sign_request(callback, token, error):
metadata = (('authorization', 'Bearer {}'.format(token)),)
callback(metadata, error)


class GoogleCallCredentials(interfaces.GRPCAuthMetadataPlugin):
"""Metadata wrapper for GoogleCredentials from the oauth2client library."""

def __init__(self, credentials):
self._credentials = credentials
self._pool = futures.ThreadPoolExecutor(max_workers=1)

def __call__(self, context, callback):
# MetadataPlugins cannot block (see grpc.beta.interfaces.py)
future = self._pool.submit(self._credentials.get_access_token)
future.add_done_callback(lambda x: self._get_token_callback(callback, x))

def _get_token_callback(self, callback, future):
try:
access_token = future.result().access_token
except Exception as e:
_sign_request(callback, None, e)
else:
_sign_request(callback, access_token, None)

def __del__(self):
self._pool.shutdown(wait=False)


class AccessTokenCallCredentials(interfaces.GRPCAuthMetadataPlugin):
"""Metadata wrapper for raw access token credentials."""

def __init__(self, access_token):
self._access_token = access_token

def __call__(self, context, callback):
_sign_request(callback, self._access_token, None)
33 changes: 32 additions & 1 deletion src/python/grpcio/grpc/beta/implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from grpc._adapter import _intermediary_low
from grpc._adapter import _low
from grpc._adapter import _types
from grpc.beta import _auth
from grpc.beta import _connectivity_channel
from grpc.beta import _server
from grpc.beta import _stub
Expand Down Expand Up @@ -105,10 +106,40 @@ def metadata_call_credentials(metadata_plugin, name=None):
A CallCredentials object for use in a GRPCCallOptions object.
"""
if name is None:
name = metadata_plugin.__name__
try:
name = metadata_plugin.__name__
except AttributeError:
name = metadata_plugin.__class__.__name__
return CallCredentials(
_low.call_credentials_metadata_plugin(metadata_plugin, name))


def google_call_credentials(credentials):
"""Construct CallCredentials from GoogleCredentials.
Args:
credentials: A GoogleCredentials object from the oauth2client library.
Returns:
A CallCredentials object for use in a GRPCCallOptions object.
"""
return metadata_call_credentials(_auth.GoogleCallCredentials(credentials))


def access_token_call_credentials(access_token):
"""Construct CallCredentials from an access token.
Args:
access_token: A string to place directly in the http request
authorization header, ie "Authorization: Bearer <access_token>".
Returns:
A CallCredentials object for use in a GRPCCallOptions object.
"""
return metadata_call_credentials(
_auth.AccessTokenCallCredentials(access_token))


def composite_call_credentials(call_credentials, additional_call_credentials):
"""Compose two CallCredentials to make a new one.
Expand Down
39 changes: 17 additions & 22 deletions src/python/grpcio/tests/interop/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,39 +65,34 @@ def _args():
help='email address of the default service account', type=str)
return parser.parse_args()

def _oauth_access_token(args):
credentials = oauth2client_client.GoogleCredentials.get_application_default()
scoped_credentials = credentials.create_scoped([args.oauth_scope])
return scoped_credentials.get_access_token().access_token

def _stub(args):
if args.oauth_scope:
if args.test_case == 'oauth2_auth_token':
# TODO(jtattermusch): This testcase sets the auth metadata key-value
# manually, which also means that the user would need to do the same
# thing every time he/she would like to use and out of band oauth token.
# The transformer function that produces the metadata key-value from
# the access token should be provided by gRPC auth library.
access_token = _oauth_access_token(args)
metadata_transformer = lambda x: [
('authorization', 'Bearer %s' % access_token)]
else:
metadata_transformer = lambda x: [
('authorization', 'Bearer %s' % _oauth_access_token(args))]
if args.test_case == 'oauth2_auth_token':
creds = oauth2client_client.GoogleCredentials.get_application_default()
scoped_creds = creds.create_scoped([args.oauth_scope])
access_token = scoped_creds.get_access_token().access_token
call_creds = implementations.access_token_call_credentials(access_token)
elif args.test_case == 'compute_engine_creds':
creds = oauth2client_client.GoogleCredentials.get_application_default()
scoped_creds = creds.create_scoped([args.oauth_scope])
call_creds = implementations.google_call_credentials(scoped_creds)
else:
metadata_transformer = lambda x: []
call_creds = None
if args.use_tls:
if args.use_test_ca:
root_certificates = resources.test_root_certificates()
else:
root_certificates = None # will load default roots.

channel_creds = implementations.ssl_channel_credentials(root_certificates)
if call_creds is not None:
channel_creds = implementations.composite_channel_credentials(
channel_creds, call_creds)

channel = test_utilities.not_really_secure_channel(
args.server_host, args.server_port,
implementations.ssl_channel_credentials(root_certificates),
args.server_host, args.server_port, channel_creds,
args.server_host_override)
stub = test_pb2.beta_create_TestService_stub(
channel, metadata_transformer=metadata_transformer)
stub = test_pb2.beta_create_TestService_stub(channel)
else:
channel = implementations.insecure_channel(
args.server_host, args.server_port)
Expand Down
30 changes: 27 additions & 3 deletions src/python/grpcio/tests/interop/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@

from oauth2client import client as oauth2client_client

from grpc.beta import implementations
from grpc.beta import interfaces
from grpc.framework.common import cardinality
from grpc.framework.interfaces.face import face

Expand Down Expand Up @@ -88,13 +90,15 @@ def HalfDuplexCall(self, request_iterator, context):
return self.FullDuplexCall(request_iterator, context)


def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope):
def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope,
protocol_options=None):
with stub:
request = messages_pb2.SimpleRequest(
response_type=messages_pb2.COMPRESSABLE, response_size=314159,
payload=messages_pb2.Payload(body=b'\x00' * 271828),
fill_username=fill_username, fill_oauth_scope=fill_oauth_scope)
response_future = stub.UnaryCall.future(request, _TIMEOUT)
response_future = stub.UnaryCall.future(request, _TIMEOUT,
protocol_options=protocol_options)
response = response_future.result()
if response.payload.type is not messages_pb2.COMPRESSABLE:
raise ValueError(
Expand Down Expand Up @@ -303,7 +307,24 @@ def _oauth2_auth_token(stub, args):
if args.oauth_scope.find(response.oauth_scope) == -1:
raise ValueError(
'expected to find oauth scope "%s" in received "%s"' %
(response.oauth_scope, args.oauth_scope))
(response.oauth_scope, args.oauth_scope))


def _per_rpc_creds(stub, args):
json_key_filename = os.environ[
oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
credentials = oauth2client_client.GoogleCredentials.get_application_default()
scoped_credentials = credentials.create_scoped([args.oauth_scope])
call_creds = implementations.google_call_credentials(scoped_credentials)
options = interfaces.grpc_call_options(disable_compression=False,
credentials=call_creds)
response = _large_unary_common_behavior(stub, True, False,
protocol_options=options)
if wanted_email != response.username:
raise ValueError(
'expected username %s, got %s' % (wanted_email, response.username))


@enum.unique
class TestCase(enum.Enum):
Expand All @@ -317,6 +338,7 @@ class TestCase(enum.Enum):
EMPTY_STREAM = 'empty_stream'
COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
PER_RPC_CREDS = 'per_rpc_creds'
TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'

def test_interoperability(self, stub, args):
Expand All @@ -342,5 +364,7 @@ def test_interoperability(self, stub, args):
_compute_engine_creds(stub, args)
elif self is TestCase.OAUTH2_AUTH_TOKEN:
_oauth2_auth_token(stub, args)
elif self is TestCase.PER_RPC_CREDS:
_per_rpc_creds(stub, args)
else:
raise NotImplementedError('Test case "%s" not implemented!' % self.name)
3 changes: 3 additions & 0 deletions src/python/grpcio/tests/tests.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
[
"_auth_test.AccessTokenCallCredentialsTest",
"_auth_test.GoogleCallCredentialsTest",
"_base_interface_test.AsyncEasyTest",
"_base_interface_test.AsyncPeasyTest",
"_base_interface_test.SyncEasyTest",
Expand Down Expand Up @@ -33,6 +35,7 @@
"_face_interface_test.MultiCallableInvokerBlockingInvocationInlineServiceTest",
"_face_interface_test.MultiCallableInvokerFutureInvocationAsynchronousEventServiceTest",
"_health_servicer_test.HealthServicerTest",
"_implementations_test.CallCredentialsTest",
"_implementations_test.ChannelCredentialsTest",
"_insecure_interop_test.InsecureInteropTest",
"_intermediary_low_test.CancellationTest",
Expand Down
96 changes: 96 additions & 0 deletions src/python/grpcio/tests/unit/beta/_auth_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Tests of standard AuthMetadataPlugins."""

import collections
import threading
import unittest

from grpc.beta import _auth


class MockGoogleCreds(object):

def get_access_token(self):
token = collections.namedtuple('MockAccessTokenInfo',
('access_token', 'expires_in'))
token.access_token = 'token'
return token


class MockExceptionGoogleCreds(object):

def get_access_token(self):
raise Exception()


class GoogleCallCredentialsTest(unittest.TestCase):

def test_google_call_credentials_success(self):
callback_event = threading.Event()

def mock_callback(metadata, error):
self.assertEqual(metadata, (('authorization', 'Bearer token'),))
self.assertIsNone(error)
callback_event.set()

call_creds = _auth.GoogleCallCredentials(MockGoogleCreds())
call_creds(None, mock_callback)
self.assertTrue(callback_event.wait(1.0))

def test_google_call_credentials_error(self):
callback_event = threading.Event()

def mock_callback(metadata, error):
self.assertIsNotNone(error)
callback_event.set()

call_creds = _auth.GoogleCallCredentials(MockExceptionGoogleCreds())
call_creds(None, mock_callback)
self.assertTrue(callback_event.wait(1.0))


class AccessTokenCallCredentialsTest(unittest.TestCase):

def test_google_call_credentials_success(self):
callback_event = threading.Event()

def mock_callback(metadata, error):
self.assertEqual(metadata, (('authorization', 'Bearer token'),))
self.assertIsNone(error)
callback_event.set()

call_creds = _auth.AccessTokenCallCredentials('token')
call_creds(None, mock_callback)
self.assertTrue(callback_event.wait(1.0))


if __name__ == '__main__':
unittest.main(verbosity=2)
Loading

0 comments on commit 84f7193

Please sign in to comment.