Skip to content

Commit

Permalink
[Serve] allow gRPC deployment to use grpc context (ray-project#41667)
Browse files Browse the repository at this point in the history
This PR passes a `grpc_context` to a deployment when the deployment's method declares that argument. The deployment can use it to read gRPC request-related info and to set the code, details, trailing metadata, and compression. The original `grpc._cython.cygrpc._ServicerContext` type is not serializable, so we created `RayServegRPCContext`, which can be serialized and passed to the deployment. A doc change will follow.

---------

Signed-off-by: Gene Su <e870252314@gmail.com>
  • Loading branch information
GeneDer authored Dec 8, 2023
1 parent aa86ef6 commit 2487553
Show file tree
Hide file tree
Showing 13 changed files with 448 additions and 29 deletions.
3 changes: 3 additions & 0 deletions python/ray/serve/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,6 @@
# precision up to 0.0001.
# This limitation should be lifted in the long term.
MAX_REPLICAS_PER_NODE_MAX_VALUE = 100

# Argument name for passing in the gRPC context into a replica.
GRPC_CONTEXT_ARG_NAME = "grpc_context"
25 changes: 20 additions & 5 deletions python/ray/serve/_private/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,20 @@ async def health_response(self, proxy_request: ProxyRequest) -> ResponseGenerato
)

def service_handler_factory(self, service_method: str, stream: bool) -> Callable:
def set_grpc_code_and_details(
    context: grpc._cython.cygrpc._ServicerContext, status: ResponseStatus
):
    """Apply Serve's default status code and details unless already set.

    Only the latest code and details take effect on the gRPC context, so a
    truthy value that is already present (i.e. set by the user) is left alone.
    By default, if nothing is set, the code is 0 and the details is "", which
    are both falsy; explicitly falsy values such as None are treated the same
    way. In those cases Serve fills in `status.code` / `status.message`.

    Args:
        context: The gRPC servicer context for the current call.
        status: Serve's response status carrying the default code and message.
    """
    existing_code = context.code()
    if not existing_code:
        context.set_code(status.code)

    existing_details = context.details()
    if not existing_details:
        context.set_details(status.message)

async def unary_unary(
request_proto: Any, context: grpc._cython.cygrpc._ServicerContext
) -> bytes:
Expand All @@ -640,8 +654,8 @@ async def unary_unary(
else:
response = message

context.set_code(status.code)
context.set_details(status.message)
set_grpc_code_and_details(context, status)

return response

async def unary_stream(
Expand All @@ -668,8 +682,7 @@ async def unary_stream(
else:
yield message

context.set_code(status.code)
context.set_details(status.message)
set_grpc_code_and_details(context, status)

return unary_stream if stream else unary_unary

Expand Down Expand Up @@ -702,6 +715,7 @@ def setup_request_context_and_handle(
"request_id": request_id,
"app_name": app_name,
"multiplexed_model_id": multiplexed_model_id,
"grpc_context": proxy_request.ray_serve_grpc_context,
}
ray.serve.context._serve_request_context.set(
ray.serve.context._RequestContext(**request_context_info)
Expand All @@ -723,7 +737,8 @@ async def send_request_to_replica(
)

try:
async for result in response_generator:
async for context, result in response_generator:
context.set_on_grpc_context(proxy_request.context)
yield result

yield ResponseStatus(code=grpc.StatusCode.OK)
Expand Down
9 changes: 8 additions & 1 deletion python/ray/serve/_private/proxy_request_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ray.serve._private.common import StreamingHTTPRequest, gRPCRequest
from ray.serve._private.constants import SERVE_LOGGER_NAME
from ray.serve._private.utils import DEFAULT
from ray.serve.grpc_util import RayServegRPCContext

logger = logging.getLogger(SERVE_LOGGER_NAME)

Expand Down Expand Up @@ -119,6 +120,9 @@ def __init__(
self.request_id = None
self.method_name = "__call__"
self.multiplexed_model_id = DEFAULT.VALUE
# ray_serve_grpc_context is a class implemented by us to be able to serialize
# the object and pass it into the deployment.
self.ray_serve_grpc_context = RayServegRPCContext(context)
self.setup_variables()

def setup_variables(self):
Expand Down Expand Up @@ -159,7 +163,10 @@ def user_request(self) -> bytes:
return self.request

def send_request_id(self, request_id: str):
self.context.set_trailing_metadata([("request_id", request_id)])
# Setting the trailing metadata on the ray_serve_grpc_context object, so it's
# not overriding the ones set from the user and will be sent back to the
# client altogether.
self.ray_serve_grpc_context.set_trailing_metadata([("request_id", request_id)])

def request_object(self, proxy_handle: ActorHandle) -> gRPCRequest:
return gRPCRequest(
Expand Down
29 changes: 22 additions & 7 deletions python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ray.serve._private.config import DeploymentConfig
from ray.serve._private.constants import (
DEFAULT_LATENCY_BUCKET_MS,
GRPC_CONTEXT_ARG_NAME,
HEALTH_CHECK_METHOD,
RAY_SERVE_GAUGE_METRIC_SET_PERIOD_S,
RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_PERIOD_S,
Expand Down Expand Up @@ -64,6 +65,7 @@
from ray.serve._private.version import DeploymentVersion
from ray.serve.deployment import Deployment
from ray.serve.exceptions import RayServeException
from ray.serve.grpc_util import RayServegRPCContext
from ray.serve.schema import LoggingConfig

logger = logging.getLogger(SERVE_LOGGER_NAME)
Expand Down Expand Up @@ -737,6 +739,7 @@ async def wrap_user_method_call(self, request_metadata: RequestMetadata):
request_metadata.request_id,
self.deployment_id.app,
request_metadata.multiplexed_model_id,
request_metadata.grpc_context,
)
)

Expand Down Expand Up @@ -782,7 +785,7 @@ async def wrap_user_method_call(self, request_metadata: RequestMetadata):

async def call_user_method_with_grpc_unary_stream(
self, request_metadata: RequestMetadata, request: gRPCRequest
) -> AsyncGenerator[bytes, None]:
) -> AsyncGenerator[Tuple[RayServegRPCContext, bytes], None]:
"""Call a user method that is expected to be a generator.
Deserializes gRPC request into protobuf object and pass into replica's runner
Expand All @@ -791,16 +794,22 @@ async def call_user_method_with_grpc_unary_stream(
async with self.wrap_user_method_call(request_metadata):
user_method = self.get_runner_method(request_metadata)
user_request = pickle.loads(request.grpc_user_request)
result_generator = user_method(user_request)
if GRPC_CONTEXT_ARG_NAME in inspect.signature(user_method).parameters:
result_generator = user_method(
user_request,
grpc_context=request_metadata.grpc_context,
)
else:
result_generator = user_method(user_request)
if inspect.iscoroutine(result_generator):
result_generator = await result_generator

if inspect.isgenerator(result_generator):
for result in result_generator:
yield result.SerializeToString()
yield request_metadata.grpc_context, result.SerializeToString()
elif inspect.isasyncgen(result_generator):
async for result in result_generator:
yield result.SerializeToString()
yield request_metadata.grpc_context, result.SerializeToString()
else:
raise TypeError(
"When using `stream=True`, the called method must be a generator "
Expand All @@ -809,7 +818,7 @@ async def call_user_method_with_grpc_unary_stream(

async def call_user_method_grpc_unary(
self, request_metadata: RequestMetadata, request: gRPCRequest
) -> bytes:
) -> Tuple[RayServegRPCContext, bytes]:
"""Call a user method that is *not* expected to be a generator.
Deserializes gRPC request into protobuf object and pass into replica's runner
Expand All @@ -830,8 +839,14 @@ async def call_user_method_grpc_unary(

method_to_call = sync_to_async(runner_method)

result = await method_to_call(user_request)
return result.SerializeToString()
if GRPC_CONTEXT_ARG_NAME in inspect.signature(runner_method).parameters:
result = await method_to_call(
user_request,
grpc_context=request_metadata.grpc_context,
)
else:
result = await method_to_call(user_request)
return request_metadata.grpc_context, result.SerializeToString()

async def call_user_method(
self,
Expand Down
4 changes: 4 additions & 0 deletions python/ray/serve/_private/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from ray.serve._private.utils import JavaActorHandleProxy, MetricsPusher
from ray.serve.generated.serve_pb2 import DeploymentRoute
from ray.serve.generated.serve_pb2 import RequestMetadata as RequestMetadataProto
from ray.serve.grpc_util import RayServegRPCContext
from ray.util import metrics

logger = logging.getLogger(SERVE_LOGGER_NAME)
Expand All @@ -64,6 +65,9 @@ class RequestMetadata:
# The protocol to serve this request
_request_protocol: RequestProtocol = RequestProtocol.UNDEFINED

# Serve's gRPC context associated with this request for getting and setting metadata
grpc_context: Optional[RayServegRPCContext] = None

@property
def is_http_request(self) -> bool:
    """Return True when this request's protocol is `RequestProtocol.HTTP`."""
    protocol = self._request_protocol
    return protocol == RequestProtocol.HTTP
Expand Down
36 changes: 31 additions & 5 deletions python/ray/serve/_private/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,16 +236,42 @@ async def send_signal_on_cancellation(signal_actor: ActorHandle):

class FakeGrpcContext:
def __init__(self):
self.code = None
self.details = None
self._trailing_metadata = None
self._auth_context = {"key": "value"}
self._invocation_metadata = [("key", "value")]
self._peer = "peer"
self._peer_identities = b"peer_identities"
self._peer_identity_key = "peer_identity_key"
self._code = None
self._details = None
self._trailing_metadata = []
self._invocation_metadata = []

def auth_context(self):
return self._auth_context

def code(self):
return self._code

def details(self):
return self._details

def peer(self):
return self._peer

def peer_identities(self):
return self._peer_identities

def peer_identity_key(self):
return self._peer_identity_key

def trailing_metadata(self):
return self._trailing_metadata

def set_code(self, code):
self.code = code
self._code = code

def set_details(self, details):
self.details = details
self._details = details

def set_trailing_metadata(self, trailing_metadata):
self._trailing_metadata = trailing_metadata
Expand Down
2 changes: 2 additions & 0 deletions python/ray/serve/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ray.serve._private.common import ReplicaTag
from ray.serve._private.constants import SERVE_CONTROLLER_NAME, SERVE_NAMESPACE
from ray.serve.exceptions import RayServeException
from ray.serve.grpc_util import RayServegRPCContext
from ray.util.annotations import DeveloperAPI

logger = logging.getLogger(__file__)
Expand Down Expand Up @@ -166,6 +167,7 @@ class _RequestContext:
request_id: str = ""
app_name: str = ""
multiplexed_model_id: str = ""
grpc_context: Optional[RayServegRPCContext] = None


_serve_request_context = contextvars.ContextVar(
Expand Down
87 changes: 87 additions & 0 deletions python/ray/serve/grpc_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Any, Dict, List, Optional, Tuple

import grpc

from ray.util.annotations import PublicAPI


@PublicAPI(stability="beta")
class RayServegRPCContext:
    """Serializable stand-in for a gRPC servicer context.

    This class implements most of the methods from ServicerContext
    (see: https://grpc.github.io/grpc/python/grpc.html#grpc.ServicerContext). It
    copies the readable state out of the real context at construction time, so
    it is serializable and can be passed with the request to be used on the
    deployment. Values set on this object are stored locally;
    `set_on_grpc_context` copies them onto a real gRPC context.
    """

    def __init__(self, grpc_context: grpc._cython.cygrpc._ServicerContext):
        """Snapshot the state of `grpc_context` into plain attributes.

        Args:
            grpc_context: The gRPC servicer context to copy values from.
        """
        self._auth_context = grpc_context.auth_context()
        self._code = grpc_context.code()
        self._details = grpc_context.details()
        # Rebuild every entry as a plain (key, value) tuple in a plain list.
        # `or ()` keeps construction working if the context reports no metadata
        # as None instead of an empty sequence.
        self._invocation_metadata = [
            (key, value) for key, value in (grpc_context.invocation_metadata() or ())
        ]
        self._peer = grpc_context.peer()
        self._peer_identities = grpc_context.peer_identities()
        self._peer_identity_key = grpc_context.peer_identity_key()
        self._trailing_metadata = [
            (key, value) for key, value in (grpc_context.trailing_metadata() or ())
        ]
        self._compression = None

    def auth_context(self) -> Dict[str, Any]:
        """Return the auth context captured at construction."""
        return self._auth_context

    def code(self) -> grpc.StatusCode:
        """Return the status code captured at construction or set afterwards."""
        return self._code

    def details(self) -> str:
        """Return the details captured at construction or set afterwards."""
        return self._details

    def invocation_metadata(self) -> List[Tuple[str, str]]:
        """Return the invocation metadata captured at construction."""
        return self._invocation_metadata

    def peer(self) -> str:
        """Return the peer value captured at construction."""
        return self._peer

    def peer_identities(self) -> Optional[bytes]:
        """Return the peer identities captured at construction."""
        return self._peer_identities

    def peer_identity_key(self) -> Optional[str]:
        """Return the peer identity key captured at construction."""
        return self._peer_identity_key

    def trailing_metadata(self) -> List[Tuple[str, str]]:
        """Return the trailing metadata currently stored on this object."""
        return self._trailing_metadata

    def set_code(self, code: grpc.StatusCode):
        """Store the status code to apply in `set_on_grpc_context`."""
        self._code = code

    def set_compression(self, compression: grpc.Compression):
        """Store the compression to apply in `set_on_grpc_context`."""
        self._compression = compression

    def set_details(self, details: str):
        """Store the details to apply in `set_on_grpc_context`."""
        self._details = details

    def _request_id_metadata(self) -> List[Tuple[str, str]]:
        """Return the stored `request_id` entry as a one-item list, or `[]`."""
        # Request id metadata should be carried over to the trailing metadata and
        # passed back to the request client. This function helps pick it out if it
        # exists.
        for key, value in self._trailing_metadata:
            if key == "request_id":
                return [(key, value)]
        return []

    def set_trailing_metadata(self, trailing_metadata: List[Tuple[str, str]]):
        """Replace the trailing metadata, keeping any stored `request_id` entry.

        Args:
            trailing_metadata: (key, value) pairs. Any iterable of pairs is
                accepted (for example a tuple of tuples); it is copied into a
                list so it can be concatenated with the preserved entry.
        """
        # `list(...)` avoids a TypeError from `list + tuple` when callers pass the
        # metadata as a tuple, and also avoids aliasing the caller's list.
        self._trailing_metadata = self._request_id_metadata() + list(
            trailing_metadata
        )

    def set_on_grpc_context(self, grpc_context: grpc._cython.cygrpc._ServicerContext):
        """Copy the truthy values stored on this object onto `grpc_context`.

        Falsy values (None, "", empty metadata) are skipped so that whatever is
        already on `grpc_context` is left untouched.

        Args:
            grpc_context: The gRPC servicer context to write to.
        """
        if self._code:
            grpc_context.set_code(self._code)

        if self._compression:
            grpc_context.set_compression(self._compression)

        if self._details:
            grpc_context.set_details(self._details)

        if self._trailing_metadata:
            grpc_context.set_trailing_metadata(self._trailing_metadata)
1 change: 1 addition & 0 deletions python/ray/serve/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def _remote(
multiplexed_model_id=self.handle_options.multiplexed_model_id,
is_streaming=self.handle_options.stream,
_request_protocol=self.handle_options._request_protocol,
grpc_context=_request_context.grpc_context,
)
self.request_counter.inc(
tags={
Expand Down
Loading

0 comments on commit 2487553

Please sign in to comment.