Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve] allow gRPC deployment to use grpc context #41667

Merged
merged 15 commits into from
Dec 8, 2023
Prev Previous commit
Next Next commit
address comments
Signed-off-by: Gene Su <e870252314@gmail.com>
  • Loading branch information
GeneDer committed Dec 7, 2023
commit bab5dbce26763d113f21a2a1605dec96c99d51db
2 changes: 0 additions & 2 deletions python/ray/serve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
get_app_handle,
get_deployment,
get_deployment_handle,
get_grpc_context,
get_multiplexed_model_id,
get_replica_context,
ingress,
Expand Down Expand Up @@ -51,7 +50,6 @@
"Deployment",
"multiplexed",
"get_multiplexed_model_id",
"get_grpc_context",
"status",
"get_app_handle",
"get_deployment_handle",
Expand Down
77 changes: 3 additions & 74 deletions python/ray/serve/_private/grpc_util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import Sequence

import grpc
from grpc.aio._server import Server

GRPC_CONTEXT_ARG_NAME = "grpc_context"

GeneDer marked this conversation as resolved.
Show resolved Hide resolved

class gRPCServer(Server):
"""Custom gRPC server to override gRPC method methods.
Expand Down Expand Up @@ -79,76 +81,3 @@ class DummyServicer:
def __getattr__(self, attr):
# No-op pass through. Just need this to act as the callable.
pass


class RayServegRPCContext:
"""Context manager to set and get gRPC context.

This class implements most of the methods from ServicerContext
(see: https://grpc.github.io/grpc/python/grpc.html#grpc.ServicerContext) so it's
serializable and can pass with the request to be used on the deployment.
"""

def __init__(self, grpc_context: grpc._cython.cygrpc._ServicerContext):
self._auth_context = grpc_context.auth_context()
self._code = grpc_context.code()
self._details = grpc_context.details()
self._invocation_metadata = [
(key, value) for key, value in grpc_context.invocation_metadata()
]
self._peer = grpc_context.peer()
self._peer_identities = grpc_context.peer_identities()
self._peer_identity_key = grpc_context.peer_identity_key()
self._trailing_metadata = [
(key, value) for key, value in grpc_context.trailing_metadata()
]
self._compression = None

def auth_context(self) -> Dict[str, Any]:
return self._auth_context

def code(self) -> grpc.StatusCode:
return self._code

def details(self) -> str:
return self._details

def invocation_metadata(self) -> List[Tuple[str, str]]:
return self._invocation_metadata

def peer(self) -> str:
return self._peer

def peer_identities(self) -> Optional[bytes]:
return self._peer_identities

def peer_identity_key(self) -> Optional[str]:
return self._peer_identity_key

def trailing_metadata(self) -> List[Tuple[str, str]]:
return self._trailing_metadata

def set_code(self, code: grpc.StatusCode):
self._code = code

def set_compression(self, compression: grpc.Compression):
self._compression = compression

def set_details(self, details: str):
self._details = details

def set_trailing_metadata(self, trailing_metadata: List[Tuple[str, str]]):
self._trailing_metadata += trailing_metadata

def set_on_grpc_context(self, grpc_context: grpc._cython.cygrpc._ServicerContext):
if self._code:
grpc_context.set_code(self._code)

if self._compression:
grpc_context.set_compression(self._compression)

if self._details:
grpc_context.set_details(self._details)

if self._trailing_metadata:
grpc_context.set_trailing_metadata(self._trailing_metadata)
17 changes: 15 additions & 2 deletions python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
SERVE_NAMESPACE,
)
from ray.serve._private.deployment_info import CONTROL_PLANE_CONCURRENCY_GROUP
from ray.serve._private.grpc_util import GRPC_CONTEXT_ARG_NAME
from ray.serve._private.http_util import (
ASGIAppReplicaWrapper,
ASGIMessageQueue,
Expand Down Expand Up @@ -769,7 +770,13 @@ async def call_user_method_with_grpc_unary_stream(
async with self.wrap_user_method_call(request_metadata):
user_method = self.get_runner_method(request_metadata)
user_request = pickle.loads(request.grpc_user_request)
result_generator = user_method(user_request)
if GRPC_CONTEXT_ARG_NAME in inspect.signature(user_method).parameters:
result_generator = user_method(
user_request,
request_metadata.grpc_context,
)
else:
result_generator = user_method(user_request)
if inspect.iscoroutine(result_generator):
result_generator = await result_generator

Expand Down Expand Up @@ -808,7 +815,13 @@ async def call_user_method_grpc_unary(

method_to_call = sync_to_async(runner_method)

result = await method_to_call(user_request)
if GRPC_CONTEXT_ARG_NAME in inspect.signature(runner_method).parameters:
result = await method_to_call(
user_request,
request_metadata.grpc_context,
GeneDer marked this conversation as resolved.
Show resolved Hide resolved
)
else:
result = await method_to_call(user_request)
return request_metadata.grpc_context, result.SerializeToString()

async def call_user_method(
Expand Down
8 changes: 0 additions & 8 deletions python/ray/serve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
)
from ray.serve.deployment import Application, Deployment
from ray.serve.exceptions import RayServeException
from ray.serve.grpc_util import RayServegRPCContext
from ray.serve.handle import DeploymentHandle
from ray.serve.multiplex import _ModelMultiplexWrapper
from ray.serve.schema import LoggingConfig, ServeInstanceDetails, ServeStatus
Expand Down Expand Up @@ -735,13 +734,6 @@ def my_deployment_function(request):
return _request_context.multiplexed_model_id


@PublicAPI(stability="beta")
def get_grpc_context() -> RayServegRPCContext:
"""Get the grpc context for the current request."""
_request_context = ray.serve.context._serve_request_context.get()
return _request_context.grpc_context


@PublicAPI(stability="alpha")
def status() -> ServeStatus:
"""Get the status of Serve on the cluster.
Expand Down
10 changes: 9 additions & 1 deletion python/ray/serve/grpc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,16 @@ def set_compression(self, compression: grpc.Compression):
def set_details(self, details: str):
self._details = details

def _request_id_metadata(self) -> List[Tuple[str, str]]:
# Request id metadata should be carried over to the trailing metadata and passed
# back to the request client. This function helped to pick it out if existed.
GeneDer marked this conversation as resolved.
Show resolved Hide resolved
for key, value in self._trailing_metadata:
if key == "request_id":
return [(key, value)]
return []

def set_trailing_metadata(self, trailing_metadata: List[Tuple[str, str]]):
self._trailing_metadata += trailing_metadata
self._trailing_metadata = self._request_id_metadata() + trailing_metadata

def set_on_grpc_context(self, grpc_context: grpc._cython.cygrpc._ServicerContext):
if self._code:
Expand Down
8 changes: 4 additions & 4 deletions python/ray/serve/tests/test_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,15 +600,13 @@ def test_using_grpc_context(ray_instance, ray_shutdown, streaming: bool):

@serve.deployment()
class HelloModel:
def __call__(self, user_message):
grpc_context = serve.get_grpc_context()
def __call__(self, user_message, grpc_context):
GeneDer marked this conversation as resolved.
Show resolved Hide resolved
grpc_context.set_code(error_code)
grpc_context.set_details(error_message)
grpc_context.set_trailing_metadata([trailing_metadata])
return serve_pb2.UserDefinedResponse(greeting="hello")

def Streaming(self, user_message):
grpc_context = serve.get_grpc_context()
def Streaming(self, user_message, grpc_context):
grpc_context.set_code(error_code)
grpc_context.set_details(error_message)
grpc_context.set_trailing_metadata([trailing_metadata])
Expand All @@ -632,6 +630,8 @@ def Streaming(self, user_message):
assert rpc_error.code() == error_code
assert error_message == rpc_error.details()
assert trailing_metadata in rpc_error.trailing_metadata()
# request_id should always be set in the trailing metadata.
assert any([key == "request_id" for key, _ in rpc_error.trailing_metadata()])


if __name__ == "__main__":
Expand Down
Loading