Skip to content

Commit

Permalink
Merge pull request grpc#8137 from kpayson64/python_server_args
Browse files Browse the repository at this point in the history
Add parameter for server options
  • Loading branch information
kpayson64 authored Sep 19, 2016
2 parents 9fec58f + 63d8af2 commit a6a6fa4
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 38 deletions.
12 changes: 8 additions & 4 deletions src/python/grpcio/grpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,7 +1189,7 @@ def insecure_channel(target, options=None):
A Channel to the target through which RPCs may be conducted.
"""
from grpc import _channel
return _channel.Channel(target, options, None)
return _channel.Channel(target, () if options is None else options, None)


def secure_channel(target, credentials, options=None):
Expand All @@ -1205,10 +1205,11 @@ def secure_channel(target, credentials, options=None):
A Channel to the target through which RPCs may be conducted.
"""
from grpc import _channel
return _channel.Channel(target, options, credentials._credentials)
return _channel.Channel(target, () if options is None else options,
credentials._credentials)


def server(thread_pool, handlers=None):
def server(thread_pool, handlers=None, options=None):
"""Creates a Server with which RPCs can be serviced.
Args:
Expand All @@ -1219,12 +1220,15 @@ def server(thread_pool, handlers=None):
only handlers the server will use to service RPCs; other handlers may
later be added by calling add_generic_rpc_handlers any time before the
returned Server is started.
options: A sequence of string-value pairs according to which to configure
the created server.
Returns:
A Server with which RPCs can be serviced.
"""
from grpc import _server
return _server.Server(thread_pool, () if handlers is None else handlers)
return _server.Server(thread_pool, () if handlers is None else handlers,
() if options is None else options)


################################### __all__ #################################
Expand Down
17 changes: 4 additions & 13 deletions src/python/grpcio/grpc/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,18 +842,8 @@ def _unsubscribe(state, callback):


def _options(options):
if options is None:
pairs = ((cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT),)
else:
pairs = list(options) + [
(cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT)]
encoded_pairs = [
(_common.encode(arg_name), arg_value) if isinstance(arg_value, int)
else (_common.encode(arg_name), _common.encode(arg_value))
for arg_name, arg_value in pairs]
return cygrpc.ChannelArgs([
cygrpc.ChannelArg(arg_name, arg_value)
for arg_name, arg_value in encoded_pairs])
return list(options) + [
(cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT)]


class Channel(grpc.Channel):
Expand All @@ -867,7 +857,8 @@ def __init__(self, target, options, credentials):
credentials: A cygrpc.ChannelCredentials or None.
"""
self._channel = cygrpc.Channel(
_common.encode(target), _options(options), credentials)
_common.encode(target), _common.channel_args(_options(options)),
credentials)
self._call_state = _ChannelCallState(self._channel)
self._connectivity_state = _ChannelConnectivityState(self._channel)

Expand Down
10 changes: 10 additions & 0 deletions src/python/grpcio/grpc/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ def decode(b):
return b.decode('latin1')


def channel_args(options):
channel_args = []
for key, value in options:
if isinstance(value, six.string_types):
channel_args.append(cygrpc.ChannelArg(encode(key), encode(value)))
else:
channel_args.append(cygrpc.ChannelArg(encode(key), value))
return cygrpc.ChannelArgs(channel_args)


def cygrpc_metadata(application_metadata):
return _EMPTY_METADATA if application_metadata is None else cygrpc.Metadata(
cygrpc.Metadatum(encode(key), encode(value))
Expand Down
5 changes: 3 additions & 2 deletions src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,16 @@ cimport cpython

cdef class Channel:

def __cinit__(self, bytes target, ChannelArgs arguments=None,
def __cinit__(self, bytes target, ChannelArgs arguments,
ChannelCredentials channel_credentials=None):
grpc_init()
cdef grpc_channel_args *c_arguments = NULL
cdef char *c_target = NULL
self.c_channel = NULL
self.references = []
if arguments is not None:
if len(arguments) > 0:
c_arguments = &arguments.c_args
self.references.append(arguments)
c_target = target
if channel_credentials is None:
with nogil:
Expand Down
4 changes: 2 additions & 2 deletions src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ import time

cdef class Server:

def __cinit__(self, ChannelArgs arguments=None):
def __cinit__(self, ChannelArgs arguments):
grpc_init()
cdef grpc_channel_args *c_arguments = NULL
self.references = []
self.registered_completion_queues = []
if arguments is not None:
if len(arguments) > 0:
c_arguments = &arguments.c_args
self.references.append(arguments)
with nogil:
Expand Down
5 changes: 2 additions & 3 deletions src/python/grpcio/grpc/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,12 +728,11 @@ def cleanup_server(timeout):
cleanup_server, target=_serve, args=(state,))
thread.start()


class Server(grpc.Server):

def __init__(self, thread_pool, generic_handlers):
def __init__(self, thread_pool, generic_handlers, options):
completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server()
server = cygrpc.Server(_common.channel_args(options))
server.register_completion_queue(completion_queue)
self._state = _ServerState(
completion_queue, server, generic_handlers, thread_pool)
Expand Down
1 change: 1 addition & 0 deletions src/python/grpcio_tests/tests/tests.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"_beta_features_test.BetaFeaturesTest",
"_beta_features_test.ContextManagementAndLifecycleTest",
"_cancel_many_calls_test.CancelManyCallsTest",
"_channel_args_test.ChannelArgsTest",
"_channel_connectivity_test.ChannelConnectivityTest",
"_channel_ready_future_test.ChannelReadyFutureTest",
"_channel_test.ChannelTest",
Expand Down
53 changes: 53 additions & 0 deletions src/python/grpcio_tests/tests/unit/_channel_args_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Tests of Channel Args on client/server side."""

import unittest

import grpc

TEST_CHANNEL_ARGS = (
('arg1', b'bytes_val'),
('arg2', 'str_val'),
('arg3', 1),
(b'arg4', 'str_val'),
)


class ChannelArgsTest(unittest.TestCase):

def test_client(self):
grpc.insecure_channel('localhost:8080', options=TEST_CHANNEL_ARGS)

def test_server(self):
grpc.server(None, options=TEST_CHANNEL_ARGS)

if __name__ == '__main__':
unittest.main(verbosity=2)
10 changes: 5 additions & 5 deletions src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class ChannelConnectivityTest(unittest.TestCase):
def test_lonely_channel_connectivity(self):
callback = _Callback()

channel = _channel.Channel('localhost:12345', None, None)
channel = _channel.Channel('localhost:12345', (), None)
channel.subscribe(callback.update, try_to_connect=False)
first_connectivities = callback.block_until_connectivities_satisfy(bool)
channel.subscribe(callback.update, try_to_connect=True)
Expand All @@ -105,13 +105,13 @@ def test_lonely_channel_connectivity(self):

def test_immediately_connectable_channel_connectivity(self):
thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
server = _server.Server(thread_pool, ())
server = _server.Server(thread_pool, (), ())
port = server.add_insecure_port('[::]:0')
server.start()
first_callback = _Callback()
second_callback = _Callback()

channel = _channel.Channel('localhost:{}'.format(port), None, None)
channel = _channel.Channel('localhost:{}'.format(port), (), None)
channel.subscribe(first_callback.update, try_to_connect=False)
first_connectivities = first_callback.block_until_connectivities_satisfy(
bool)
Expand Down Expand Up @@ -146,12 +146,12 @@ def test_immediately_connectable_channel_connectivity(self):

def test_reachable_then_unreachable_channel_connectivity(self):
thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
server = _server.Server(thread_pool, ())
server = _server.Server(thread_pool, (), ())
port = server.add_insecure_port('[::]:0')
server.start()
callback = _Callback()

channel = _channel.Channel('localhost:{}'.format(port), None, None)
channel = _channel.Channel('localhost:{}'.format(port), (), None)
channel.subscribe(callback.update, try_to_connect=True)
callback.block_until_connectivities_satisfy(_ready_in_connectivities)
# Now take down the server and confirm that channel readiness is repudiated.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_lonely_channel_connectivity(self):

def test_immediately_connectable_channel_connectivity(self):
thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
server = _server.Server(thread_pool, ())
server = _server.Server(thread_pool, (), ())
port = server.add_insecure_port('[::]:0')
server.start()
channel = grpc.insecure_channel('localhost:{}'.format(port))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,12 @@ def testCancelManyCalls(self):
server_thread_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)

server_completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server()
server = cygrpc.Server(cygrpc.ChannelArgs([]))
server.register_completion_queue(server_completion_queue)
port = server.add_http2_port(b'[::]:0')
server.start()
channel = cygrpc.Channel('localhost:{}'.format(port).encode())
channel = cygrpc.Channel('localhost:{}'.format(port).encode(),
cygrpc.ChannelArgs([]))

state = _State()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,12 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):

def testReadSomeButNotAllResponses(self):
server_completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server()
server = cygrpc.Server(cygrpc.ChannelArgs([]))
server.register_completion_queue(server_completion_queue)
port = server.add_http2_port(b'[::]:0')
server.start()
channel = cygrpc.Channel('localhost:{}'.format(port).encode())
channel = cygrpc.Channel('localhost:{}'.format(port).encode(),
cygrpc.ChannelArgs([]))

server_shutdown_tag = 'server_shutdown_tag'
server_driver = _ServerDriver(server_completion_queue, server_shutdown_tag)
Expand Down
9 changes: 5 additions & 4 deletions src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def testCallCredentialsFromPluginUpDown(self):
del call_credentials

def testServerStartNoExplicitShutdown(self):
server = cygrpc.Server()
server = cygrpc.Server(cygrpc.ChannelArgs([]))
completion_queue = cygrpc.CompletionQueue()
server.register_completion_queue(completion_queue)
port = server.add_http2_port(b'[::]:0')
Expand All @@ -131,7 +131,7 @@ def testServerStartNoExplicitShutdown(self):

def testServerStartShutdown(self):
completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server()
server = cygrpc.Server(cygrpc.ChannelArgs([]))
server.add_http2_port(b'[::]:0')
server.register_completion_queue(completion_queue)
server.start()
Expand All @@ -148,7 +148,7 @@ class ServerClientMixin(object):

def setUpMixin(self, server_credentials, client_credentials, host_override):
self.server_completion_queue = cygrpc.CompletionQueue()
self.server = cygrpc.Server()
self.server = cygrpc.Server(cygrpc.ChannelArgs([]))
self.server.register_completion_queue(self.server_completion_queue)
if server_credentials:
self.port = self.server.add_http2_port(b'[::]:0', server_credentials)
Expand All @@ -164,7 +164,8 @@ def setUpMixin(self, server_credentials, client_credentials, host_override):
'localhost:{}'.format(self.port).encode(), client_channel_arguments,
client_credentials)
else:
self.client_channel = cygrpc.Channel('localhost:{}'.format(self.port).encode())
self.client_channel = cygrpc.Channel(
'localhost:{}'.format(self.port).encode(), cygrpc.ChannelArgs([]))
if host_override:
self.host_argument = None # default host
self.expected_host = host_override
Expand Down

0 comments on commit a6a6fa4

Please sign in to comment.