Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StreamManager: Add mechanism to close the request iterator #6263

Merged
merged 7 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Addressed maffoo's comments
  • Loading branch information
verult committed Sep 8, 2023
commit 0bd13bfdcb0bfb0a9b22cbb557e9820c718c644c
25 changes: 15 additions & 10 deletions cirq-google/cirq_google/engine/stream_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class StreamManager:

"""

_STOP_SIGNAL = 'stop_signal'
_STOP_SIGNAL = None

def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient):
self._grpc_client = grpc_client
Expand All @@ -121,11 +121,14 @@ def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient):
# interface.
self._response_demux = ResponseDemux()
self._next_available_message_id = 0
self._executor.submit(self._init_request_queue).result()
# Construct queue in AsyncioExecutor to ensure it binds to the correct event loop, since it
# is used by asyncio coroutines.
self._request_queue: asyncio.Queue[
Optional[quantum.QuantumRunStreamRequest]
] = self._executor.submit(self._make_request_queue).result()
verult marked this conversation as resolved.
Show resolved Hide resolved

async def _init_request_queue(self) -> None:
await asyncio.sleep(0)
self._request_queue: asyncio.Queue = asyncio.Queue()
async def _make_request_queue(self) -> asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]]:
return asyncio.Queue()

def submit(
self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob
Expand Down Expand Up @@ -179,15 +182,17 @@ def _reset(self):
"""Resets the manager state."""
self._manage_stream_loop_future = None
self._response_demux = ResponseDemux()
self._executor.submit(self._init_request_queue).result()
self._request_queue = self._executor.submit(self._make_request_queue).result()

@property
def _executor(self) -> AsyncioExecutor:
# We must re-use a single Executor due to multi-threading issues in gRPC
# clients: https://github.com/grpc/grpc/issues/25364.
return AsyncioExecutor.instance()

async def _manage_stream(self, request_queue: asyncio.Queue) -> None:
async def _manage_stream(
self, request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]]
) -> None:
"""The stream coroutine, an asyncio coroutine to manage QuantumRunStream.

This coroutine reads responses from the stream and forwards them to the ResponseDemux, where
Expand Down Expand Up @@ -218,7 +223,7 @@ async def _manage_stream(self, request_queue: asyncio.Queue) -> None:

async def _manage_execution(
self,
request_queue: asyncio.Queue,
request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]],
project_name: str,
program: quantum.QuantumProgram,
job: quantum.QuantumJob,
Expand Down Expand Up @@ -347,13 +352,13 @@ def _is_retryable_error(e: google_exceptions.GoogleAPICallError) -> bool:


async def _request_iterator(
request_queue: asyncio.Queue,
request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]],
) -> AsyncIterator[quantum.QuantumRunStreamRequest]:
"""The request iterator for Quantum Engine client RPC quantum_run_stream().

Every call to this method generates a new iterator.
"""
while (request := await request_queue.get()) != StreamManager._STOP_SIGNAL:
while request := await request_queue.get():
yield request


Expand Down
11 changes: 7 additions & 4 deletions cirq-google/cirq_google/engine/stream_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def __init__(self):
self._request_iterator_stopped = duet.AwaitableFuture()
# asyncio.Queue needs to be initialized inside the asyncio thread because all callers need
# to use the same event loop.
self._responses_and_exceptions_future = duet.AwaitableFuture[asyncio.Queue]()
self._responses_and_exceptions_future: duet.AwaitableFuture[
asyncio.Queue[Union[quantum.QuantumRunStreamResponse, BaseException]]
] = duet.AwaitableFuture()

async def quantum_run_stream(
self, requests: AsyncIterator[quantum.QuantumRunStreamRequest], **kwargs
Expand All @@ -94,7 +96,9 @@ async def quantum_run_stream(

This is called from the asyncio thread.
"""
responses_and_exceptions: asyncio.Queue = asyncio.Queue()
responses_and_exceptions: asyncio.Queue[
Union[quantum.QuantumRunStreamResponse, BaseException]
] = asyncio.Queue()
self._responses_and_exceptions_future.try_set_result(responses_and_exceptions)

async def read_requests():
Expand All @@ -112,10 +116,9 @@ async def response_iterator():
if isinstance(message, quantum.QuantumRunStreamResponse):
yield message
else: # isinstance(message, BaseException)
self._responses_and_exceptions_future = duet.AwaitableFuture[asyncio.Queue]()
self._responses_and_exceptions_future = duet.AwaitableFuture()
raise message

await asyncio.sleep(0)
return response_iterator()

async def cancel_quantum_job(self, request: quantum.CancelQuantumJobRequest) -> None:
Expand Down