Internal cleanup of P2P code #8907

Merged 7 commits on Oct 25, 2024
Changes from all commits
24 changes: 19 additions & 5 deletions distributed/shuffle/_core.py
@@ -357,6 +357,7 @@ def add_partition(
         if self.transferred:
             raise RuntimeError(f"Cannot add more partitions to {self}")
         # Log metrics both in the "execute" and in the "p2p" contexts
+        self.validate_data(data)
         with self._capture_metrics("foreground"):
             with (
                 context_meter.meter("p2p-shard-partition-noncpu"),
@@ -402,6 +403,9 @@ def read(self, path: Path) -> tuple[Any, int]:
     def deserialize(self, buffer: Any) -> Any:
         """Deserialize shards"""
 
+    def validate_data(self, data: Any) -> None:
+        """Validate payload data before shuffling"""
+
 
 def get_worker_plugin() -> ShuffleWorkerPlugin:
     from distributed import get_worker
@@ -475,9 +479,6 @@ def create_new_run(
             participating_workers=set(worker_for.values()),
         )
 
-    def validate_data(self, data: Any) -> None:
-        """Validate payload data before shuffling"""
-
     @abc.abstractmethod
     def create_run_on_worker(
         self,
@@ -522,7 +523,7 @@ def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
     except P2POutOfDiskError:
         raise
     except Exception as e:
-        raise RuntimeError(f"P2P shuffling {id} failed during transfer phase") from e
+        raise RuntimeError(f"P2P {id} failed during transfer phase") from e
 
 
 @contextlib.contextmanager
@@ -538,7 +539,7 @@ def handle_unpack_errors(id: ShuffleId) -> Iterator[None]:
     except P2POutOfDiskError:
         raise
     except Exception as e:
-        raise RuntimeError(f"P2P shuffling {id} failed during unpack phase") from e
+        raise RuntimeError(f"P2P {id} failed during unpack phase") from e
 
 
 def _handle_datetime(buf: Any) -> Any:
@@ -561,3 +562,16 @@ def _mean_shard_size(shards: Iterable) -> int:
             if count == 10:
                 break
     return size // count if count else 0
+
+
+def p2p_barrier(id: ShuffleId, run_ids: list[int]) -> int:
+    try:
+        return get_worker_plugin().barrier(id, run_ids)
+    except Reschedule as e:
+        raise e
+    except P2PConsistencyError:
+        raise
+    except P2POutOfDiskError:
+        raise
+    except Exception as e:
+        raise RuntimeError(f"P2P {id} failed during barrier phase") from e
13 changes: 9 additions & 4 deletions distributed/shuffle/_merge.py
@@ -11,8 +11,13 @@
 from dask.tokenize import tokenize
 
 from distributed.shuffle._arrow import check_minimal_arrow_version
-from distributed.shuffle._core import ShuffleId, barrier_key, get_worker_plugin
-from distributed.shuffle._shuffle import shuffle_barrier, shuffle_transfer
+from distributed.shuffle._core import (
+    ShuffleId,
+    barrier_key,
+    get_worker_plugin,
+    p2p_barrier,
+)
+from distributed.shuffle._shuffle import shuffle_transfer
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -411,8 +416,8 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
 
         _barrier_key_left = barrier_key(ShuffleId(token_left))
         _barrier_key_right = barrier_key(ShuffleId(token_right))
-        dsk[_barrier_key_left] = (shuffle_barrier, token_left, transfer_keys_left)
-        dsk[_barrier_key_right] = (shuffle_barrier, token_right, transfer_keys_right)
+        dsk[_barrier_key_left] = (p2p_barrier, token_left, transfer_keys_left)
+        dsk[_barrier_key_right] = (p2p_barrier, token_right, transfer_keys_right)
 
         name = self.name
         for part_out in self.parts_out:
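
Note: `_construct_graph` here still emits classic dask graph tuples, where `(p2p_barrier, token, transfer_keys)` is a deferred call and the list of transfer keys is resolved to the transfer results before the barrier runs. A self-contained sketch of that mechanism with stand-in callables (all names here are illustrative):

```python
from dask.threaded import get


def transfer(i: int) -> int:
    return i  # stand-in for shuffle_transfer, which returns a run_id


def barrier(token: str, run_ids: list[int]) -> str:
    return token  # stand-in for p2p_barrier


dsk = {
    "transfer-0": (transfer, 0),
    "transfer-1": (transfer, 1),
    # Keys nested inside the list are substituted with their results,
    # so the barrier task depends on both transfer tasks.
    "barrier-token": (barrier, "token", ["transfer-0", "transfer-1"]),
}
assert get(dsk, "barrier-token") == "token"
```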
5 changes: 3 additions & 2 deletions distributed/shuffle/_rechunk.py
@@ -138,10 +138,11 @@
     get_worker_plugin,
     handle_transfer_errors,
     handle_unpack_errors,
+    p2p_barrier,
 )
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.shuffle._pickle import unpickle_bytestream
-from distributed.shuffle._shuffle import barrier_key, shuffle_barrier
+from distributed.shuffle._shuffle import barrier_key
 from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
 from distributed.sizeof import sizeof
 
@@ -823,7 +824,7 @@ def partial_rechunk(
         transfer_keys.append(t.ref())
 
     dsk[_barrier_key] = barrier = Task(
-        _barrier_key, shuffle_barrier, partial_token, transfer_keys
+        _barrier_key, p2p_barrier, partial_token, transfer_keys
    )
 
     new_partial_offset = tuple(axis.start for axis in ndpartial.new)
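
Note: unlike `_merge.py`, the rechunk path already uses dask's newer `Task` spec, where dependencies are expressed through `.ref()` handles instead of raw keys. A rough sketch of how the barrier node is wired, with stand-in callables; this assumes the internal `dask._task_spec` API, which may change between dask versions:

```python
from dask._task_spec import Task


def transfer(i: int) -> int:
    return i  # stand-in for the per-slice transfer tasks


def barrier(token: str, run_ids: list[int]) -> str:
    return token  # stand-in for p2p_barrier


t0 = Task("transfer-0", transfer, 0)
t1 = Task("transfer-1", transfer, 1)
transfer_keys = [t0.ref(), t1.ref()]
# The TaskRefs inside the list register as dependencies, mirroring
# dsk[_barrier_key] = Task(_barrier_key, p2p_barrier, partial_token, transfer_keys)
barrier_task = Task("barrier", barrier, "token", transfer_keys)
print(barrier_task.dependencies)  # expected: {"transfer-0", "transfer-1"}
```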
35 changes: 9 additions & 26 deletions distributed/shuffle/_shuffle.py
@@ -29,7 +29,6 @@
 from dask.utils import is_dataframe_like
 
 from distributed.core import PooledRPCCall
-from distributed.exceptions import Reschedule
 from distributed.metrics import context_meter
 from distributed.shuffle._arrow import (
     buffers_to_table,
@@ -49,12 +48,9 @@
     get_worker_plugin,
     handle_transfer_errors,
     handle_unpack_errors,
+    p2p_barrier,
 )
-from distributed.shuffle._exceptions import (
-    DataUnavailable,
-    P2PConsistencyError,
-    P2POutOfDiskError,
-)
+from distributed.shuffle._exceptions import DataUnavailable
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
 from distributed.sizeof import sizeof
Expand Down Expand Up @@ -106,19 +102,6 @@ def shuffle_unpack(
)


def shuffle_barrier(id: ShuffleId, run_ids: list[int]) -> int:
try:
return get_worker_plugin().barrier(id, run_ids)
except Reschedule as e:
raise e
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
raise RuntimeError(f"shuffle_barrier failed during shuffle {id}") from e


def rearrange_by_column_p2p(
df: DataFrame,
column: str,
@@ -306,7 +289,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
             dsk[t.key] = t
             transfer_keys.append(t.ref())
 
-        barrier = Task(_barrier_key, shuffle_barrier, token, transfer_keys)
+        barrier = Task(_barrier_key, p2p_barrier, token, transfer_keys)
         dsk[barrier.key] = barrier
 
         name = self.name
@@ -570,6 +553,12 @@ def read(self, path: Path) -> tuple[pa.Table, int]:
     def deserialize(self, buffer: Any) -> Any:
         return deserialize_table(buffer)
 
+    def validate_data(self, data: pd.DataFrame) -> None:
+        if not is_dataframe_like(data):
+            raise TypeError(f"Expected {data=} to be a DataFrame, got {type(data)}.")
+        if set(data.columns) != set(self.meta.columns):
+            raise ValueError(f"Expected {self.meta.columns=} to match {data.columns=}.")
+
 
 @dataclass(frozen=True)
 class DataFrameShuffleSpec(ShuffleSpec[int]):
@@ -586,12 +575,6 @@ def output_partitions(self) -> Generator[int]:
     def pick_worker(self, partition: int, workers: Sequence[str]) -> str:
         return _get_worker_for_range_sharding(self.npartitions, partition, workers)
 
-    def validate_data(self, data: pd.DataFrame) -> None:
-        if not is_dataframe_like(data):
-            raise TypeError(f"Expected {data=} to be a DataFrame, got {type(data)}.")
-        if set(data.columns) != set(self.meta.columns):
-            raise ValueError(f"Expected {self.meta.columns=} to match {data.columns=}.")
-
     def create_run_on_worker(
         self,
         run_id: int,
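
Note: the `validate_data` body is unchanged; it only moves from `DataFrameShuffleSpec` to the `ShuffleRun` subclass so the check runs inside `add_partition` (see `_core.py` above) on every code path. A standalone illustration of the column check, using hypothetical example frames:

```python
import pandas as pd

# `meta` plays the role of self.meta: an empty frame describing the schema
meta = pd.DataFrame({"a": pd.Series(dtype="int64"), "b": pd.Series(dtype="float64")})
good = pd.DataFrame({"a": [1, 2], "b": [0.5, 1.5]})
bad = pd.DataFrame({"a": [1, 2], "c": [0.5, 1.5]})  # "c" is not in meta

assert set(good.columns) == set(meta.columns)
if set(bad.columns) != set(meta.columns):
    # This is the condition validate_data raises ValueError on
    print(f"rejected: {sorted(bad.columns)} != {sorted(meta.columns)}")
```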
8 changes: 0 additions & 8 deletions distributed/shuffle/_worker_plugin.py
@@ -358,7 +358,6 @@ def add_partition(
         spec: ShuffleSpec,
         **kwargs: Any,
     ) -> int:
-        spec.validate_data(data)
         shuffle_run = self.get_or_create_shuffle(spec)
         return shuffle_run.add_partition(
             data=data,
@@ -387,13 +386,6 @@ async def _get_shuffle_run(
             shuffle_id=shuffle_id, run_id=run_id
         )
 
-    async def _get_or_create_shuffle(
-        self,
-        spec: ShuffleSpec,
-        key: Key,
-    ) -> ShuffleRun:
-        return await self.shuffle_runs.get_or_create(spec=spec, key=key)
-
     async def teardown(self, worker: Worker) -> None:
         assert not self.closed
 
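
Note: with `spec.validate_data(data)` removed here and `self.validate_data(data)` added in `ShuffleRun.add_partition`, validation now lives in the run rather than the plugin. A schematic sketch of the resulting call order (toy classes, not the real plugin API):

```python
from typing import Any


class ToyRun:
    """Stands in for a ShuffleRun subclass."""

    def validate_data(self, data: Any) -> None:
        if not isinstance(data, dict):  # toy schema check
            raise TypeError(f"Expected dict, got {type(data)}")

    def add_partition(self, data: Any, partition_id: int) -> int:
        self.validate_data(data)  # validation happens inside the run now
        return partition_id


class ToyPlugin:
    """Stands in for ShuffleWorkerPlugin; it no longer validates."""

    def add_partition(self, data: Any, partition_id: int) -> int:
        run = ToyRun()  # stands in for self.get_or_create_shuffle(spec)
        return run.add_partition(data, partition_id)


assert ToyPlugin().add_partition({"a": 1}, 0) == 0
```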
6 changes: 3 additions & 3 deletions distributed/shuffle/tests/test_shuffle.py
@@ -2546,7 +2546,7 @@ def make_partition(i):
 
     with raises_with_cause(
         RuntimeError,
-        r"(shuffling \w*|shuffle_barrier) failed",
+        r"P2P \w* failed",
         pa.ArrowTypeError,
         "incompatible types",
     ):
@@ -2744,7 +2744,7 @@ async def test_flaky_connect_fails_without_retry(c, s, a, b):
     with mock.patch.object(a, "rpc", rpc):
         with raises_with_cause(
             expected_exception=RuntimeError,
-            match="P2P shuffling.*transfer",
+            match="P2P.*transfer",
             expected_cause=OSError,
             match_cause=None,
         ):
Expand Down Expand Up @@ -2899,7 +2899,7 @@ def data_gen():

with raises_with_cause(
RuntimeError,
r"shuffling \w* failed",
r"P2P \w* failed",
ValueError,
"meta",
):
Expand Down
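
Note: the test updates track the unified message format introduced above; every phase now raises `RuntimeError(f"P2P {id} failed during <phase> phase")`, so a single pattern covers transfer, unpack, and barrier failures. A quick check with a hypothetical shuffle id:

```python
import re

pattern = re.compile(r"P2P \w* failed")
for phase in ("transfer", "unpack", "barrier"):
    msg = f"P2P 8a3bc6b1 failed during {phase} phase"
    assert pattern.search(msg), msg
```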