Add batch size, storage constraint/reservation, pooling factors, and output type to stats table and expose stats table (pytorch#462)

Summary:
Pull Request resolved: pytorch#462

Added
- batch size
- longest critical path
- peak memory pressure
- usable memory
- KJT storage
- dense storage
- pooling factors
- number of features
- output type (pooled vs seq)

Set debug mode to default to True.

Made the stats table an attribute so it can be exposed and persisted when saving a sharding plan.

Reviewed By: dstaay-fb

Differential Revision: D37318699

fbshipit-source-id: fa2b28d313d5319eee13ef9db6953bff72bc2b1f
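
As a rough sketch (not part of this commit) of how the newly exposed stats table could be consumed, assuming the usual torchrec planner entry points (EmbeddingShardingPlanner, Topology) and placeholder values:

from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
from torchrec.distributed.planner.stats import EmbeddingStats
from torchrec.distributed.planner.types import Topology

# Hypothetical setup; world_size and compute_device are placeholder values.
topology = Topology(world_size=2, compute_device="cuda")
stats = EmbeddingStats()
planner = EmbeddingShardingPlanner(topology=topology, stats=stats)

# sharding_plan = planner.plan(module, sharders)  # model and sharders elided here

# After plan() runs, the formatted rows live on the stats object and can be
# persisted alongside the sharding plan:
# for row in stats._stats_table:
#     print(row)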
joshuadeng authored and facebook-github-bot committed Jun 22, 2022
1 parent e8e21dc commit 9faf356
Showing 5 changed files with 183 additions and 44 deletions.
4 changes: 3 additions & 1 deletion torchrec/distributed/planner/parallelized_planners.py
@@ -120,7 +120,7 @@ def __init__(
performance_model: Optional[PerfModel] = None,
stats: Optional[Stats] = None,
constraints: Optional[Dict[str, ParameterConstraints]] = None,
debug: bool = False,
debug: bool = True,
) -> None:
self._topology = topology
self._constraints = constraints
@@ -278,6 +278,8 @@ def get_best_plan(
self._stats.log(
sharding_plan=sharding_plan,
topology=self._topology,
storage_constraint=storage_constraint,
storage_reservation=self._storage_reservation,
num_proposals=self._num_proposals,
num_plans=self._num_plans,
run_time=end_time - start_time,
4 changes: 3 additions & 1 deletion torchrec/distributed/planner/planners.py
@@ -112,7 +112,7 @@ def __init__(
performance_model: Optional[PerfModel] = None,
stats: Optional[Stats] = None,
constraints: Optional[Dict[str, ParameterConstraints]] = None,
debug: bool = False,
debug: bool = True,
) -> None:
self._topology = topology
self._constraints = constraints
@@ -261,6 +261,8 @@ def plan(
self._stats.log(
sharding_plan=sharding_plan,
topology=self._topology,
storage_constraint=storage_constraint,
storage_reservation=self._storage_reservation,
num_proposals=self._num_proposals,
num_plans=self._num_plans,
run_time=end_time - start_time,
194 changes: 160 additions & 34 deletions torchrec/distributed/planner/stats.py
@@ -10,7 +10,17 @@
from typing import Any, cast, Dict, List, Tuple, Union

from torchrec.distributed.planner.constants import BIGINT_DTYPE
from torchrec.distributed.planner.types import ShardingOption, Stats, Storage, Topology
from torchrec.distributed.planner.storage_reservations import (
FixedPercentageReservation,
HeuristicalStorageReservation,
)
from torchrec.distributed.planner.types import (
ShardingOption,
Stats,
Storage,
StorageReservation,
Topology,
)
from torchrec.distributed.planner.utils import bytes_to_gb, bytes_to_mb
from torchrec.distributed.types import ParameterSharding, ShardingPlan, ShardingType

@@ -26,25 +36,34 @@ class EmbeddingStats(Stats):
Stats for a sharding planner execution.
"""

def __init__(self) -> None:
self._width: int = MIN_WIDTH
self._stats_table: List[str] = []

def log(
self,
sharding_plan: ShardingPlan,
topology: Topology,
storage_constraint: Topology,
storage_reservation: StorageReservation,
num_proposals: int,
num_plans: int,
run_time: float,
best_plan: List[ShardingOption],
debug: bool = False,
debug: bool = True,
) -> None:
"""
Logs stats for a given sharding plan to stdout.
Logs stats for a given sharding plan.
Provides a tabular view of stats for the given sharding plan with per device
storage usage (HBM and DDR), perf, input, output, and number/type of shards.
Args:
sharding_plan (ShardingPlan): sharding plan chosen by the planner.
topology (Topology): device topology.
storage_constraint (Topology): available storage after storage reservation.
storage_reservation (StorageReservation): reserves storage for unsharded
parts of the model.
num_proposals (int): number of proposals evaluated.
num_plans (int): number of proposals successfully partitioned.
run_time (float): time taken to find plan (in seconds).
@@ -120,7 +139,7 @@ def log(
],
]

for rank, device in enumerate(topology.devices):
for rank, device in enumerate(storage_constraint.devices):
used_hbm_gb = bytes_to_gb(used_hbm[rank])
used_hbm_ratio = (
used_hbm[rank] / device.storage.hbm
@@ -154,79 +173,99 @@
]
)
formatted_table = _format_table(table)
width = max(MIN_WIDTH, len(formatted_table[0]) + 8)
self._width = max(self._width, len(formatted_table[0]) + 8)

if debug:
param_table: List[List[Union[str, int]]] = [
["FQN", "Sharding", "Compute Kernel", "Perf (ms)", "Ranks"],
[
"FQN",
"Sharding",
"Compute Kernel",
"Perf (ms)",
"Pooling Factor",
"Output",
"Features",
"Ranks",
],
[
"-----",
"----------",
"----------------",
"-----------",
"----------------",
"--------",
"----------",
"-------",
],
]
for so in best_plan:
# pyre-ignore[6]
ranks = sorted([shard.rank for shard in so.shards])
if len(ranks) > 1 and ranks == list(range(min(ranks), max(ranks) + 1)):
ranks = [f"{min(ranks)}-{max(ranks)}"]
ranks = sorted([cast(int, shard.rank) for shard in so.shards])
ranks = _collapse_consecutive_ranks(ranks)
shard_perfs = str(
round(sum([cast(float, shard.perf) for shard in so.shards]), 3)
)
pooling_factor = str(round(sum(so.input_lengths), 3))
output = "pooled" if so.is_pooled else "sequence"
num_features = len(so.input_lengths)
param_table.append(
[
so.fqn,
_get_sharding_type_abbr(so.sharding_type),
so.compute_kernel,
shard_perfs,
",".join([str(rank) for rank in ranks]),
pooling_factor,
output,
num_features,
",".join(ranks),
]
)
formatted_param_table = _format_table(param_table)
width = max(width, len(formatted_param_table[0]) + 6)
self._width = max(self._width, len(formatted_param_table[0]) + 6)

logger.info("#" * width)
self._stats_table.clear()
self._stats_table.append("#" * self._width)
header_text = "--- Planner Statistics ---"
logger.info(f"#{header_text: ^{width-2}}#")
self._stats_table.append(f"#{header_text: ^{self._width-2}}#")

iter_text = (
f"--- Evalulated {num_proposals} proposal(s), "
f"found {num_plans} possible plan(s), "
f"ran for {run_time:.2f}s ---"
)
logger.info(f"#{iter_text: ^{width-2}}#")
self._stats_table.append(f"#{iter_text: ^{self._width-2}}#")

divider = "-" * (width - 4)
logger.info(f"#{divider: ^{width-2}}#")
divider = "-" * (self._width - 4)
self._stats_table.append(f"#{divider: ^{self._width-2}}#")

for row in formatted_table:
logger.info(f"# {row: <{width-3}}#")
self._stats_table.append(f"# {row: <{self._width-3}}#")

logger.info(f"#{'' : ^{width-2}}#")
legend = "Input: MB/iteration, Output: MB/iteration, Shards: number of tables"
logger.info(f"# {legend: <{width-3}}#")
hbm_info = "HBM: est. peak memory usage for shards - parameter, comms, optimizer, and gradients"
logger.info(f"# {hbm_info: <{width-3}}#")
logger.info(f"#{'' : ^{width-2}}#")
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(f"# {legend: <{self._width-3}}#")
self._stats_table.append(f"# {hbm_info: <{self._width-3}}#")

compute_kernels_count = [
f"{compute_kernel}: {count}"
for compute_kernel, count in sorted(compute_kernels_to_count.items())
]
logger.info(f"# {'Compute Kernels:' : <{width-3}}#")
for compute_kernel_count in compute_kernels_count:
logger.info(f"# {compute_kernel_count : <{width-5}}#")
if debug:
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(f"# {'Parameter Info:' : <{self._width-3}}#")
for row in formatted_param_table:
self._stats_table.append(f"# {row: <{self._width-3}}#")

batch_size = f"Batch Size: {topology.batch_size}"
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(f"# {batch_size : <{self._width-3}}#")

self._log_compute_kernel_stats(compute_kernels_to_count)

if debug:
logger.info(f"#{'' : ^{width-2}}#")
logger.info(f"# {'Parameter Info:' : <{width-3}}#")
self._log_max_perf_and_max_hbm(perf, used_hbm)
self._log_storage_reservation_stats(storage_reservation, topology)

for row in formatted_param_table:
logger.info(f"# {row: <{width-3}}#")
self._stats_table.append("#" * self._width)

logger.info("#" * width)
for row in self._stats_table:
logger.info(row)

def _get_shard_stats(
self,
@@ -296,6 +335,86 @@ def _get_shard_stats(
)
return ranks, input_sizes, output_sizes

def _log_max_perf_and_max_hbm(self, perf: List[float], used_hbm: List[int]) -> None:
max_perf = max(perf)
max_perf_indices = [i for i in range(len(perf)) if perf[i] == max_perf]
rank_text = "ranks" if len(max_perf_indices) > 1 else "rank"
max_perf_indices = _collapse_consecutive_ranks(max_perf_indices)
max_perf_ranks = f"{rank_text} {','.join(max_perf_indices)}"
longest_critical_path = (
f"Longest Critical Path: {round(max_perf, 3)} ms on {max_perf_ranks}"
)
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(f"# {longest_critical_path : <{self._width-3}}#")

max_hbm = max(used_hbm)
max_hbm_indices = [i for i in range(len(used_hbm)) if used_hbm[i] == max_hbm]
rank_text = "ranks" if len(max_hbm_indices) > 1 else "rank"
max_hbm_indices = _collapse_consecutive_ranks(max_hbm_indices)
max_hbm_ranks = f"{rank_text} {','.join(max_hbm_indices)}"
peak_memory_pressure = f"Peak Memory Pressure: {round(bytes_to_gb(max_hbm), 3)} GB on {max_hbm_ranks}"
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(f"# {peak_memory_pressure : <{self._width-3}}#")

def _log_storage_reservation_stats(
self,
storage_reservation: StorageReservation,
topology: Topology,
) -> None:
if isinstance(
storage_reservation,
(FixedPercentageReservation, HeuristicalStorageReservation),
):
percentage = storage_reservation._percentage
device_storage = topology.devices[0].storage
usable_hbm = round(
bytes_to_gb(int((1 - percentage) * device_storage.hbm)), 3
)
usable_ddr = round(
bytes_to_gb(int((1 - percentage) * device_storage.ddr)), 3
)
usable_memory = f"HBM: {usable_hbm} GB, DDR: {usable_ddr} GB"
usable_percentage = f"Percent of Total: {(1 - percentage):.0%}"
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(f"# {'Usable Memory:' : <{self._width-3}}#")
self._stats_table.append(f"# {usable_memory : <{self._width-6}}#")
self._stats_table.append(f"# {usable_percentage : <{self._width-6}}#")

if isinstance(storage_reservation, HeuristicalStorageReservation):
assert storage_reservation._dense_storage is not None
dense_storage = storage_reservation._dense_storage
dense_hbm = round(bytes_to_gb(dense_storage.hbm), 3)
dense_ddr = round(bytes_to_gb(dense_storage.ddr), 3)
dense_storage = f"HBM: {dense_hbm} GB, DDR: {dense_ddr} GB"
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(
f"# {'Dense Storage (per rank): ' : <{self._width-3}}#"
)
self._stats_table.append(f"# {dense_storage : <{self._width-6}}#")

assert storage_reservation._kjt_storage is not None
kjt_storage = storage_reservation._kjt_storage
kjt_hbm = round(bytes_to_gb(kjt_storage.hbm), 3)
kjt_ddr = round(bytes_to_gb(kjt_storage.ddr), 3)
kjt_storage = f"HBM: {kjt_hbm} GB, DDR: {kjt_ddr} GB"
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(
f"# {'KJT Storage (per rank): ' : <{self._width-3}}#"
)
self._stats_table.append(f"# {kjt_storage : <{self._width-6}}#")

def _log_compute_kernel_stats(
self, compute_kernels_to_count: Dict[str, int]
) -> None:
compute_kernels_count = [
f"{compute_kernel}: {count}"
for compute_kernel, count in sorted(compute_kernels_to_count.items())
]
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(f"# {'Compute Kernels:' : <{self._width-3}}#")
for compute_kernel_count in compute_kernels_count:
self._stats_table.append(f"# {compute_kernel_count : <{self._width-6}}#")


def _get_sharding_type_abbr(sharding_type: str) -> str:
if sharding_type == ShardingType.DATA_PARALLEL.value:
@@ -324,3 +443,10 @@ def _format_table(table: List[List[Union[str, int]]]) -> List[str]:
["{:>" + str(longest_col) + "}" for longest_col in longest_cols]
)
return [row_format.format(*row) for row in table]


def _collapse_consecutive_ranks(ranks: List[int]) -> List[str]:
if len(ranks) > 1 and ranks == list(range(min(ranks), max(ranks) + 1)):
return [f"{min(ranks)}-{max(ranks)}"]
else:
return [str(rank) for rank in ranks]
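
For reference (illustrative only, not part of the diff), the new _collapse_consecutive_ranks helper collapses a sorted rank list into a range label only when the ranks are contiguous:

# _collapse_consecutive_ranks([0, 1, 2, 3]) -> ["0-3"]
# _collapse_consecutive_ranks([0, 2, 5])    -> ["0", "2", "5"]
# _collapse_consecutive_ranks([7])          -> ["7"]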
23 changes: 15 additions & 8 deletions torchrec/distributed/planner/storage_reservations.py
@@ -52,6 +52,8 @@ class HeuristicalStorageReservation(StorageReservation):
def __init__(self, percentage: float) -> None:
assert percentage >= 0 and percentage <= 1
self._percentage: float = percentage
self._dense_storage: Optional[Storage] = None
self._kjt_storage: Optional[Storage] = None

def reserve(
self,
@@ -90,18 +92,20 @@ def reserve(

_reserve_storage_percentage(reserved_topology, self._percentage)

_reserve_unshardable_tensors_storage(
self._dense_storage = _reserve_dense_storage(
reserved_topology, module, shardable_parameters
)

_reserve_kjt_storage(reserved_topology, all_input_lengths, BIGINT_DTYPE)
self._kjt_storage = _reserve_kjt_storage(
reserved_topology, all_input_lengths, BIGINT_DTYPE
)

return reserved_topology


def _reserve_unshardable_tensors_storage(
def _reserve_dense_storage(
topology: Topology, module: nn.Module, shardable_parameters: Set[nn.Parameter]
) -> None:
) -> Storage:
unshardable_parameters = set(module.parameters()) - shardable_parameters

unshardable_parameters_size = sum(
@@ -130,18 +134,19 @@ def _reserve_unshardable_tensors_storage(
for device in topology.devices:
device.storage -= unshardable_tensors_storage

return unshardable_tensors_storage


def _reserve_kjt_storage(
topology: Topology,
all_input_lengths: List[float],
input_data_type_size: int,
) -> None:
) -> Storage:
kjt_size = (
math.ceil(
topology.batch_size
# pyre-ignore[58]
float(topology.batch_size)
* sum(all_input_lengths)
* input_data_type_size
* float(input_data_type_size)
)
* 20 # 2 pipelined batches each with 10 internal copies
)
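# Illustrative sizing (hypothetical numbers, not from this change): with
# batch_size = 512, sum(all_input_lengths) = 30.0, and an 8-byte input dtype
# (BIGINT_DTYPE), kjt_size = ceil(512 * 30.0 * 8) * 20 = 2,457,600 bytes,
# i.e. roughly 2.3 MB reserved per rank for pipelined KJT copies.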
@@ -154,6 +159,8 @@
for device in topology.devices:
device.storage -= kjt_storage

return kjt_storage


def _reserve_storage_percentage(topology: Topology, percent: float) -> None:
for device in topology.devices: