[Core] Fix check failure when sync and async tasks are mixed up (ray-project#41724)

When sync and async tasks are mixed up, there are cases where the TaskID is not correctly set, which causes the check failure. This is a case where the check failure correctly found a bug.

I fixed the issue by passing the TaskID correctly. I also improved the get_current_task_id API to always return the correct task ID when used within _raylet.pyx.
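
For illustration, here is a minimal, standalone Python sketch of the contextvar-fallback pattern described above. The names async_task_id and core_worker_current_task_id are hypothetical stand-ins; the actual Cython change is shown in the _raylet.pyx diff below.

import contextvars
from typing import Optional

# Hypothetical stand-in for the contextvar that the async actor event loop
# sets for each running asyncio task.
async_task_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "async_task_id", default=None
)

def core_worker_current_task_id() -> str:
    # Placeholder for CCoreWorkerProcess.GetCoreWorker().GetCurrentTaskId();
    # the core worker only knows the task ID recorded for the main thread.
    return "main-thread-task-id"

def get_current_task_id() -> str:
    # Prefer the asyncio-task-scoped contextvar: the core worker has no
    # asyncio context, so its answer is wrong inside an async actor.
    task_id = async_task_id.get()
    if task_id is None:
        # Not inside an asyncio task; fall back to the core worker's value.
        task_id = core_worker_current_task_id()
    return task_id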
rkooo567 authored Dec 9, 2023
1 parent 1dffb4d commit 2f29500
Showing 3 changed files with 83 additions and 26 deletions.
47 changes: 23 additions & 24 deletions python/ray/_raylet.pyx
@@ -3295,9 +3295,26 @@ cdef class CoreWorker:
with nogil:
CCoreWorkerProcess.GetCoreWorker().Exit(c_exit_type, detail, null_ptr)

def get_current_task_id(self):
return TaskID(
CCoreWorkerProcess.GetCoreWorker().GetCurrentTaskId().Binary())
def get_current_task_id(self) -> TaskID:
"""Return the current task ID.

If it is a normal task, it returns the TaskID from the main thread.
If it is a threaded actor, it returns the TaskID for the current thread.
If it is an async actor, it returns the TaskID stored in the contextvar for
the current asyncio task.
"""
# We can only obtain the correct task ID within an asyncio task
# via the async_task_id contextvar, so we try that first.
# It is needed because the core worker's GetCurrentTaskId API
# doesn't have asyncio context, thus it cannot return the
# correct TaskID.
task_id = async_task_id.get()
if task_id is None:
# If it is not within an asyncio context, fall back to the TaskID
# obtained from the core worker.
task_id = TaskID(
CCoreWorkerProcess.GetCoreWorker().GetCurrentTaskId().Binary())
return task_id

def get_current_task_attempt_number(self):
return CCoreWorkerProcess.GetCoreWorker().GetCurrentTaskAttemptNumber()
@@ -3749,9 +3766,6 @@ cdef class CoreWorker:
c_vector[CObjectID] incremented_put_arg_ids
c_string serialized_retry_exception_allowlist
CTaskID current_c_task_id
TaskID task_id_in_async_context = async_task_id.get()
# This task id is incorrect if async task is used.
# In this case, we should use task_id_in_async_context
TaskID current_task = self.get_current_task_id()

self.python_scheduling_strategy_to_c(
@@ -3775,13 +3789,7 @@
generator_backpressure_num_objects,
serialized_runtime_env_info)

# We are in the async context. We have to obtain
# the task id from this context var. get_current_task_id()
# doesn't contain the correct id for asyncio tasks.
if task_id_in_async_context is not None:
current_c_task_id = task_id_in_async_context.native()
else:
current_c_task_id = current_task.native()
current_c_task_id = current_task.native()

with nogil:
return_refs = CCoreWorkerProcess.GetCoreWorker().SubmitTask(
@@ -3971,9 +3979,6 @@ cdef class CoreWorker:
c_vector[CObjectReference] return_refs
c_vector[CObjectID] incremented_put_arg_ids
CTaskID current_c_task_id = CTaskID.Nil()
TaskID task_id_in_async_context = async_task_id.get()
# This task id is incorrect if async task is used.
# In this case, we should use task_id_in_async_context
TaskID current_task = self.get_current_task_id()
c_string serialized_retry_exception_allowlist

@@ -3990,13 +3995,7 @@
self, language, args, &args_vector, function_descriptor,
&incremented_put_arg_ids)

# We are in the async context. We have to obtain
# the task id from this context var. get_current_task_id()
# doesn't contain the correct id for asyncio tasks.
if task_id_in_async_context is not None:
current_c_task_id = task_id_in_async_context.native()
else:
current_c_task_id = current_task.native()
current_c_task_id = current_task.native()

with nogil:
status = CCoreWorkerProcess.GetCoreWorker().SubmitActorTask(
@@ -4814,7 +4813,7 @@ cdef class CoreWorker:
else:
return CCoreWorkerProcess.GetCoreWorker().AllocateDynamicReturnId(
owner_address,
CTaskID.Nil(),
task_id,
make_optional[ObjectIDIndexType](
<int>1 + <int>return_size + <int>generator_index))

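For context (not part of this commit), a hedged usage sketch of how the per-asyncio-task ID surfaces from application code; it assumes ray.get_runtime_context().get_task_id() is the public accessor for the current task ID and that the actor below runs its methods on an asyncio event loop.

import ray

@ray.remote
class AsyncActor:
    async def who_am_i(self) -> str:
        # With the fix, each concurrently running asyncio task should see its
        # own task ID rather than the ID recorded on the main thread.
        return ray.get_runtime_context().get_task_id()

ray.init()
actor = AsyncActor.remote()
print(ray.get([actor.who_am_i.remote() for _ in range(3)]))
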
57 changes: 57 additions & 0 deletions python/ray/tests/test_streaming_generator_4.py
@@ -4,6 +4,9 @@
import time
import gc
import random
import asyncio
from typing import Optional
from pydantic import BaseModel

import ray

@@ -134,6 +137,60 @@ def f():
ray.get(gen)


def test_sync_async_mix_regression_test(shutdown_only):
"""Verify when sync and async tasks are mixed up
it doesn't raise a segfault
https://github.com/ray-project/ray/issues/41346
"""

class PayloadPydantic(BaseModel):
class Error(BaseModel):
msg: str
code: int
type: str

text: Optional[str] = None
ts: Optional[float] = None
reason: Optional[str] = None
error: Optional[Error] = None

ray.init()

@ray.remote
class B:
def __init__(self, a):
self.a = a

async def stream(self):
async for ref in self.a.stream.remote(1):
print("stream")
await ref

async def start(self):
await asyncio.gather(*[self.stream() for _ in range(2)])

@ray.remote
class A:
def stream(self, i):
payload = PayloadPydantic(
text="Test output",
ts=time.time(),
reason="Success!",
)

for _ in range(10):
yield payload

async def aio_stream(self):
for _ in range(10):
yield 1

a = A.remote()
b = B.remote(a)
ray.get(b.start.remote())


if __name__ == "__main__":
import os

5 changes: 3 additions & 2 deletions src/ray/core_worker/core_worker.cc
@@ -2991,10 +2991,11 @@ Status CoreWorker::ReportGeneratorItemReturns(
const Status &status, const rpc::ReportGeneratorItemReturnsReply &reply) {
RAY_LOG(DEBUG) << "ReportGeneratorItemReturns replied. " << generator_id
<< "index: " << item_index
<< ". Total object consumed: " << waiter->TotalObjectConsumed()
<< ". Total object generated: " << waiter->TotalObjectGenerated()
<< ". total_consumed_reported: "
<< reply.total_num_object_consumed();
RAY_CHECK(waiter != nullptr);
RAY_LOG(DEBUG) << "Total object consumed: " << waiter->TotalObjectConsumed()
<< ". Total object generated: " << waiter->TotalObjectGenerated();
if (status.ok()) {
/// Since unary gRPC requests are not ordered, it is possible the stale
/// total value can be replied. Since total object consumed only can
