Skip to content

Commit

Permalink
StreamManager: Add mechanism to close the request iterator (#6263)
Browse files Browse the repository at this point in the history
* Add a signal to stop the request iterator

* Make request_queue local to asyncio coroutines

* Added missing raises docstring

* Addressed maffoo's comments

* Addressed maffoo's nits

* Fix failing stream_manager_test after merging

* Fix format
  • Loading branch information
verult authored Sep 11, 2023
1 parent deedb45 commit 6c14cfa
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 34 deletions.
71 changes: 48 additions & 23 deletions cirq-google/cirq_google/engine/stream_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ class StreamManager:

def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient):
self._grpc_client = grpc_client
# TODO(#5996) Make this local to the asyncio thread.
self._request_queue: Optional[asyncio.Queue] = None
# Used to determine whether the stream coroutine is actively running, and provides a way to
# cancel it.
self._manage_stream_loop_future: Optional[duet.AwaitableFuture[None]] = None
Expand All @@ -121,6 +119,16 @@ def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient):
# interface.
self._response_demux = ResponseDemux()
self._next_available_message_id = 0
# Construct queue in AsyncioExecutor to ensure it binds to the correct event loop, since it
# is used by asyncio coroutines.
self._request_queue = self._executor.submit(self._make_request_queue).result()

async def _make_request_queue(self) -> asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]]:
"""Returns a queue used to back the request iterator passed to the stream.
If `None` is put into the queue, the request iterator will stop.
"""
return asyncio.Queue()

def submit(
self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob
Expand Down Expand Up @@ -153,8 +161,12 @@ def submit(
raise ValueError('Program name must be set.')

if self._manage_stream_loop_future is None or self._manage_stream_loop_future.done():
self._manage_stream_loop_future = self._executor.submit(self._manage_stream)
return self._executor.submit(self._manage_execution, project_name, program, job)
self._manage_stream_loop_future = self._executor.submit(
self._manage_stream, self._request_queue
)
return self._executor.submit(
self._manage_execution, self._request_queue, project_name, program, job
)

def stop(self) -> None:
"""Closes the open stream and resets all management resources."""
Expand All @@ -168,17 +180,19 @@ def stop(self) -> None:

def _reset(self):
"""Resets the manager state."""
self._request_queue = None
self._manage_stream_loop_future = None
self._response_demux = ResponseDemux()
self._request_queue = self._executor.submit(self._make_request_queue).result()

@property
def _executor(self) -> AsyncioExecutor:
# We must re-use a single Executor due to multi-threading issues in gRPC
# clients: https://github.com/grpc/grpc/issues/25364.
return AsyncioExecutor.instance()

async def _manage_stream(self) -> None:
async def _manage_stream(
self, request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]]
) -> None:
"""The stream coroutine, an asyncio coroutine to manage QuantumRunStream.
This coroutine reads responses from the stream and forwards them to the ResponseDemux, where
Expand All @@ -187,25 +201,32 @@ async def _manage_stream(self) -> None:
When the stream breaks, the stream is reopened, and all execution coroutines are notified.
There is at most a single instance of this coroutine running.
Args:
request_queue: The queue holding requests from the execution coroutine.
"""
self._request_queue = asyncio.Queue()
while True:
try:
# The default gRPC client timeout is used.
response_iterable = await self._grpc_client.quantum_run_stream(
_request_iterator(self._request_queue)
_request_iterator(request_queue)
)
async for response in response_iterable:
self._response_demux.publish(response)
except asyncio.CancelledError:
await request_queue.put(None)
break
except BaseException as e:
# TODO(#5996) Close the request iterator to close the existing stream.
# Note: the message ID counter is not reset upon a new stream.
await request_queue.put(None)
self._response_demux.publish_exception(e) # Raise to all request tasks

async def _manage_execution(
self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob
self,
request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]],
project_name: str,
program: quantum.QuantumProgram,
job: quantum.QuantumJob,
) -> Union[quantum.QuantumResult, quantum.QuantumJob]:
"""The execution coroutine, an asyncio coroutine to manage the lifecycle of a job execution.
Expand All @@ -216,28 +237,33 @@ async def _manage_execution(
error by sending another request. The exact request type depends on the error.
There is one execution coroutine per running job submission.
Args:
request_queue: The queue used to send requests to the stream coroutine.
project_name: The full project ID resource path associated with the job.
program: The Quantum Engine program representing the circuit to be executed.
job: The Quantum Engine job to be executed.
Raises:
concurrent.futures.CancelledError: if either the request is cancelled or the stream
coroutine is cancelled.
google.api_core.exceptions.GoogleAPICallError: if the stream breaks with a non-retryable
error.
ValueError: if the response is of a type which is not recognized by this client.
"""
# Construct requests ahead of time to be reused for retries.
create_program_and_job_request = quantum.QuantumRunStreamRequest(
parent=project_name,
create_quantum_program_and_job=quantum.CreateQuantumProgramAndJobRequest(
parent=project_name, quantum_program=program, quantum_job=job
),
)

while self._request_queue is None:
# Wait for the stream coroutine to start.
# Ignoring coverage since this is rarely triggered.
# TODO(#5996) Consider awaiting for the queue to become available, once it is changed
# to be local to the asyncio thread.
await asyncio.sleep(1) # pragma: no cover

current_request = create_program_and_job_request
while True:
try:
current_request.message_id = self._generate_message_id()
response_future = self._response_demux.subscribe(current_request.message_id)
await self._request_queue.put(current_request)
await request_queue.put(current_request)
response = await response_future

# Broken stream
Expand Down Expand Up @@ -325,16 +351,15 @@ def _is_retryable_error(e: google_exceptions.GoogleAPICallError) -> bool:
return any(isinstance(e, exception_type) for exception_type in RETRYABLE_GOOGLE_API_EXCEPTIONS)


# TODO(#5996) Add stop signal to the request iterator.
async def _request_iterator(
request_queue: asyncio.Queue,
request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]],
) -> AsyncIterator[quantum.QuantumRunStreamRequest]:
"""The request iterator for Quantum Engine client RPC quantum_run_stream().
Every call to this method generates a new iterator.
"""
while True:
yield await request_queue.get()
while request := await request_queue.get():
yield request


def _to_create_job_request(
Expand Down
127 changes: 116 additions & 11 deletions cirq-google/cirq_google/engine/stream_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,26 @@ def setup(client_constructor):
class FakeQuantumRunStream:
"""A fake Quantum Engine client which supports QuantumRunStream and CancelQuantumJob."""

_REQUEST_STOPPED = 'REQUEST_STOPPED'

def __init__(self):
self.all_stream_requests: List[quantum.QuantumRunStreamRequest] = []
self.all_cancel_requests: List[quantum.CancelQuantumJobRequest] = []
self._executor = AsyncioExecutor.instance()
self._request_buffer = duet.AsyncCollector[quantum.QuantumRunStreamRequest]()
self._request_iterator_stopped = duet.AwaitableFuture()
# asyncio.Queue needs to be initialized inside the asyncio thread because all callers need
# to use the same event loop.
self._responses_and_exceptions_future = duet.AwaitableFuture[asyncio.Queue]()
self._responses_and_exceptions_future: duet.AwaitableFuture[
asyncio.Queue[Union[quantum.QuantumRunStreamResponse, BaseException]]
] = duet.AwaitableFuture()

async def quantum_run_stream(
self, requests: AsyncIterator[quantum.QuantumRunStreamRequest], **kwargs
) -> Awaitable[AsyncIterable[quantum.QuantumRunStreamResponse]]:
"""Fakes the QuantumRunStream RPC.
Once a request is received, it is appended to `stream_requests`, and the test calling
Once a request is received, it is appended to `all_stream_requests`, and the test calling
`wait_for_requests()` is notified.
The response is sent when a test calls `reply()` with a `QuantumRunStreamResponse`. If a
Expand All @@ -91,25 +96,29 @@ async def quantum_run_stream(
This is called from the asyncio thread.
"""
responses_and_exceptions: asyncio.Queue = asyncio.Queue()
responses_and_exceptions: asyncio.Queue[
Union[quantum.QuantumRunStreamResponse, BaseException]
] = asyncio.Queue()
self._responses_and_exceptions_future.try_set_result(responses_and_exceptions)

async def read_requests():
async for request in requests:
self.all_stream_requests.append(request)
self._request_buffer.add(request)
await responses_and_exceptions.put(FakeQuantumRunStream._REQUEST_STOPPED)
self._request_iterator_stopped.try_set_result(None)

async def response_iterator():
asyncio.create_task(read_requests())
while True:
response_or_exception = await responses_and_exceptions.get()
if isinstance(response_or_exception, quantum.QuantumRunStreamResponse):
yield response_or_exception
else: # isinstance(response_or_exception, BaseException)
self._responses_and_exceptions_future = duet.AwaitableFuture[asyncio.Queue]()
raise response_or_exception
while (
message := await responses_and_exceptions.get()
) != FakeQuantumRunStream._REQUEST_STOPPED:
if isinstance(message, quantum.QuantumRunStreamResponse):
yield message
else: # isinstance(message, BaseException)
self._responses_and_exceptions_future = duet.AwaitableFuture()
raise message

await asyncio.sleep(0)
return response_iterator()

async def cancel_quantum_job(self, request: quantum.CancelQuantumJobRequest) -> None:
Expand Down Expand Up @@ -158,6 +167,14 @@ async def send():

await self._executor.submit(send)

async def wait_for_request_iterator_stop(self):
"""Wait for the request iterator to stop.
This must be called from a duet thread.
"""
await self._request_iterator_stopped
self._request_iterator_stopped = duet.AwaitableFuture()


class TestResponseDemux:
@pytest.fixture
Expand Down Expand Up @@ -704,3 +721,91 @@ def test_get_retry_request_or_raise_expects_stream_error(
create_quantum_program_and_job_request,
create_quantum_job_request,
)

@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_broken_stream_stops_request_iterator(self, client_constructor):
expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0')
fake_client, manager = setup(client_constructor)

async def test():
async with duet.timeout_scope(5):
actual_result_future = manager.submit(
REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0
)
await fake_client.wait_for_requests()
await fake_client.reply(
quantum.QuantumRunStreamResponse(
message_id=fake_client.all_stream_requests[0].message_id,
result=expected_result,
)
)
await actual_result_future
await fake_client.reply(google_exceptions.ServiceUnavailable('service unavailable'))
await fake_client.wait_for_request_iterator_stop()
manager.stop()

duet.run(test)

@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_stop_stops_request_iterator(self, client_constructor):
expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0')
fake_client, manager = setup(client_constructor)

async def test():
async with duet.timeout_scope(5):
actual_result_future = manager.submit(
REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0
)
await fake_client.wait_for_requests()
await fake_client.reply(
quantum.QuantumRunStreamResponse(
message_id=fake_client.all_stream_requests[0].message_id,
result=expected_result,
)
)
await actual_result_future
manager.stop()
await fake_client.wait_for_request_iterator_stop()

duet.run(test)

@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_submit_after_stream_breakage(self, client_constructor):
expected_result0 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0')
expected_result1 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job1')
fake_client, manager = setup(client_constructor)

async def test():
async with duet.timeout_scope(5):
actual_result0_future = manager.submit(
REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0
)
await fake_client.wait_for_requests()
await fake_client.reply(
quantum.QuantumRunStreamResponse(
message_id=fake_client.all_stream_requests[0].message_id,
result=expected_result0,
)
)
actual_result0 = await actual_result0_future
await fake_client.reply(google_exceptions.ServiceUnavailable('service unavailable'))
actual_result1_future = manager.submit(
REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0
)
await fake_client.wait_for_requests()
await fake_client.reply(
quantum.QuantumRunStreamResponse(
message_id=fake_client.all_stream_requests[1].message_id,
result=expected_result1,
)
)
actual_result1 = await actual_result1_future
manager.stop()

assert len(fake_client.all_stream_requests) == 2
assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0]
assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[1]
assert actual_result0 == expected_result0
assert actual_result1 == expected_result1

duet.run(test)

0 comments on commit 6c14cfa

Please sign in to comment.