Update type check workflow to PyTorch 2.4 (Lightning-AI#20155)
awaelchli authored Aug 3, 2024
1 parent b5695c5 commit 1bc2aad
Showing 19 changed files with 52 additions and 47 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/code-checks.yml
@@ -31,7 +31,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v5
         with:
-          python-version: "3.10.6"
+          python-version: "3.11"

       - name: Mypy cache
         uses: actions/cache@v4

1 change: 1 addition & 0 deletions docs/source-fabric/conf.py
@@ -286,6 +286,7 @@
     # These seem to be missing in reference generated API
     ("py:class", "torch.distributed.fsdp.wrap.ModuleWrapPolicy"),
     ("py:class", "torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler"),
+    ("py:class", "torch.amp.grad_scaler.GradScaler"),
     # Mocked optional packages
     ("py:class", "deepspeed.*"),
     ("py:.*", "torch_xla.*"),

1 change: 1 addition & 0 deletions docs/source-pytorch/conf.py
@@ -488,6 +488,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
     ("py:class", "torch.distributed.fsdp.fully_sharded_data_parallel.MixedPrecision"),
     ("py:class", "torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy"),
     ("py:class", "torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler"),
+    ("py:class", "torch.amp.grad_scaler.GradScaler"),
     ("py:class", "torch.distributed.fsdp.wrap.ModuleWrapPolicy"),
     ("py:func", "torch.inference_mode"),
     ("py:meth", "torch.mean"),

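For context, a minimal sketch of how these tuples are consumed, assuming they feed Sphinx's nitpick_ignore list (the variable name is an assumption; it is not shown in the excerpt):

    # docs/conf.py sketch: each (domain:role, target) pair tells Sphinx's
    # nitpicky mode to stop warning about a cross-reference it cannot resolve,
    # such as the new torch.amp.grad_scaler.GradScaler class.
    nitpick_ignore = [
        ("py:class", "torch.amp.grad_scaler.GradScaler"),
        ("py:class", "torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler"),
    ]
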
4 changes: 2 additions & 2 deletions requirements/typing.txt
@@ -1,5 +1,5 @@
-mypy==1.8.0
-torch==2.2.0
+mypy==1.11.0
+torch==2.4.0

 types-Markdown
 types-PyYAML

4 changes: 2 additions & 2 deletions src/lightning/fabric/plugins/collectives/torch_collective.py
@@ -95,11 +95,11 @@ def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[T

     @override
     def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None:
-        dist.send(tensor, dst, tag=tag, group=self.group)
+        dist.send(tensor, dst, tag=tag, group=self.group)  # type: ignore[arg-type]

     @override
     def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor:
-        dist.recv(tensor, src, tag=tag, group=self.group)
+        dist.recv(tensor, src, tag=tag, group=self.group)  # type: ignore[arg-type]
         return tensor

     def all_gather_object(self, object_list: List[Any], obj: Any) -> List[Any]:

2 changes: 1 addition & 1 deletion src/lightning/fabric/plugins/precision/amp.py
@@ -40,7 +40,7 @@ def __init__(
         self,
         precision: Literal["16-mixed", "bf16-mixed"],
         device: str,
-        scaler: Optional["torch.cuda.amp.GradScaler"] = None,
+        scaler: Optional["torch.amp.GradScaler"] = None,
     ) -> None:
         if precision not in ("16-mixed", "bf16-mixed"):
             raise ValueError(

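The annotation change tracks the move away from the CUDA-specific scaler. A minimal sketch of the replacement API, assuming a PyTorch release that ships the device-agnostic class (2.3 or newer):

    import torch

    # torch.amp.GradScaler takes the device type as its first argument; the old
    # torch.cuda.amp.GradScaler still works but emits a deprecation warning in
    # recent releases.
    scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())

    # typical mixed-precision step, sketched for illustration only:
    # scaler.scale(loss).backward()
    # scaler.step(optimizer)
    # scaler.update()
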
4 changes: 2 additions & 2 deletions src/lightning/fabric/strategies/fsdp.py
@@ -864,7 +864,7 @@ def _distributed_checkpoint_save(converted_state: Dict[str, Any], path: Path) ->

         # let torch automatically infer the writer to use. This might also support fsspec paths in the future
         # https://github.com/pytorch/pytorch/issues/118036
-        save(converted_state, checkpoint_id=path)  # type: ignore[call-arg]
+        save(converted_state, checkpoint_id=path)
     else:  # deprecated
         from torch.distributed.checkpoint import FileSystemWriter

@@ -883,7 +883,7 @@ def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> No

         # let torch automatically infer the reader to use. This might also support fsspec paths in the future
         # https://github.com/pytorch/pytorch/issues/118036
-        load(module_state, checkpoint_id=path)  # type: ignore[call-arg]
+        load(module_state, checkpoint_id=path)
     else:  # deprecated
         from torch.distributed.checkpoint import FileSystemReader

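The removed ignores suggest the checkpoint_id keyword is covered by the torch 2.4 type stubs. A minimal sketch of the call pattern, assuming a single process (torch.distributed.checkpoint is expected to fall back to non-distributed saving when no process group is initialized, which is an assumption worth verifying for your setup):

    from pathlib import Path

    import torch
    import torch.distributed.checkpoint as dcp

    state = {"model": torch.nn.Linear(2, 2).state_dict()}
    ckpt_dir = Path("checkpoint")  # hypothetical path

    # checkpoint_id lets DCP infer the storage writer/reader from the path itself.
    dcp.save(state, checkpoint_id=ckpt_dir)
    dcp.load(state, checkpoint_id=ckpt_dir)  # loads in place into the same dict
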
8 changes: 4 additions & 4 deletions src/lightning/fabric/strategies/model_parallel.py
@@ -172,7 +172,7 @@ def setup_environment(self) -> None:
         )

     @override
-    def setup_module(self, module: TModel) -> TModel:
+    def setup_module(self, module: Module) -> Module:
         from torch.distributed.fsdp import FullyShardedDataParallel

         if any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()):
@@ -181,7 +181,7 @@ def setup_module(self, module: TModel) -> TModel:
                 f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.4."
             )

-        module = self._parallelize_fn(module, self.device_mesh)
+        module = self._parallelize_fn(module, self.device_mesh)  # type: ignore[arg-type]
         if not isinstance(module, Module):
             raise TypeError(
                 f"The `parallelize_fn` must return a `nn.Module` instance, but got: {type(module).__name__}"
@@ -462,7 +462,7 @@ def _load_checkpoint(
         _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)

         state_dict_options = StateDictOptions(
-            broadcast_from_rank0=True,  # type: ignore[call-arg]
+            broadcast_from_rank0=True,
             full_state_dict=True,
             strict=strict,
         )
@@ -546,7 +546,7 @@ def _load_raw_module_state(
         from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

         state_dict_options = StateDictOptions(
-            broadcast_from_rank0=True,  # type: ignore[call-arg]
+            broadcast_from_rank0=True,
             full_state_dict=True,
             # must be set False to allow loading each param separately below
             strict=False,

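For reference, a minimal sketch of the options object whose broadcast_from_rank0 field no longer needs a call-arg ignore against the 2.4 stubs (the exact version that introduced the field is an assumption, not stated in the commit):

    from torch.distributed.checkpoint.state_dict import StateDictOptions

    # full_state_dict requests a complete, unsharded state dict;
    # broadcast_from_rank0 loads it on rank 0 only and broadcasts the tensors
    # to the other ranks.
    options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True, strict=False)
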
2 changes: 1 addition & 1 deletion src/lightning/fabric/utilities/logger.py
@@ -95,7 +95,7 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent
     result: Dict[str, Any] = {}
     for k, v in params.items():
         new_key = parent_key + delimiter + str(k) if parent_key else str(k)
-        if is_dataclass(v):
+        if is_dataclass(v) and not isinstance(v, type):
             v = asdict(v)
         elif isinstance(v, Namespace):
             v = vars(v)

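The extra isinstance check matters because is_dataclass accepts both dataclass types and their instances, while asdict only accepts instances. A minimal illustration (Config is a made-up name):

    from dataclasses import asdict, dataclass, is_dataclass

    @dataclass
    class Config:
        lr: float = 0.1

    print(is_dataclass(Config))    # True: the class itself qualifies
    print(is_dataclass(Config()))  # True: so does an instance
    print(asdict(Config()))        # {'lr': 0.1}
    # asdict(Config) raises TypeError, which is what the added
    # `not isinstance(v, type)` guard protects against when a hyperparameter
    # dict happens to contain the class rather than an instance.
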
27 changes: 12 additions & 15 deletions src/lightning/pytorch/callbacks/progress/rich_progress.py
@@ -187,9 +187,6 @@ def _generate_metrics_texts(self) -> Generator[str, None, None]:
                 value = f"{value:{self._metrics_format}}"
             yield f"{name}: {value}"

-else:
-    Task, Style = Any, Any  # type: ignore[assignment, misc]
-

 @dataclass
 class RichProgressBarTheme:
@@ -209,14 +206,14 @@ class RichProgressBarTheme:
     """

-    description: Union[str, Style] = "white"
-    progress_bar: Union[str, Style] = "#6206E0"
-    progress_bar_finished: Union[str, Style] = "#6206E0"
-    progress_bar_pulse: Union[str, Style] = "#6206E0"
-    batch_progress: Union[str, Style] = "white"
-    time: Union[str, Style] = "grey54"
-    processing_speed: Union[str, Style] = "grey70"
-    metrics: Union[str, Style] = "white"
+    description: Union[str, "Style"] = "white"
+    progress_bar: Union[str, "Style"] = "#6206E0"
+    progress_bar_finished: Union[str, "Style"] = "#6206E0"
+    progress_bar_pulse: Union[str, "Style"] = "#6206E0"
+    batch_progress: Union[str, "Style"] = "white"
+    time: Union[str, "Style"] = "grey54"
+    processing_speed: Union[str, "Style"] = "grey70"
+    metrics: Union[str, "Style"] = "white"
     metrics_text_delimiter: str = " "
     metrics_format: str = ".3f"

@@ -298,25 +295,25 @@ def is_disabled(self) -> bool:
         return not self.is_enabled

     @property
-    def train_progress_bar(self) -> Task:
+    def train_progress_bar(self) -> "Task":
         assert self.progress is not None
         assert self.train_progress_bar_id is not None
         return self.progress.tasks[self.train_progress_bar_id]

     @property
-    def val_sanity_check_bar(self) -> Task:
+    def val_sanity_check_bar(self) -> "Task":
         assert self.progress is not None
         assert self.val_sanity_progress_bar_id is not None
         return self.progress.tasks[self.val_sanity_progress_bar_id]

     @property
-    def val_progress_bar(self) -> Task:
+    def val_progress_bar(self) -> "Task":
         assert self.progress is not None
         assert self.val_progress_bar_id is not None
         return self.progress.tasks[self.val_progress_bar_id]

     @property
-    def test_progress_bar(self) -> Task:
+    def test_progress_bar(self) -> "Task":
         assert self.progress is not None
         assert self.test_progress_bar_id is not None
         return self.progress.tasks[self.test_progress_bar_id]

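Quoting Style and Task turns the annotations into forward references, so the names only have to resolve for static analysis and the runtime fallback assignments can be dropped. A minimal sketch of the pattern, assuming rich is installed for type checking (the actual import layout in rich_progress.py may differ):

    from dataclasses import dataclass
    from typing import TYPE_CHECKING, Union

    if TYPE_CHECKING:
        from rich.style import Style  # resolved by mypy only, never at runtime

    @dataclass
    class Theme:
        # the string annotation keeps this module importable when rich is
        # missing, while mypy still checks it against the real rich.style.Style
        description: Union[str, "Style"] = "white"
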
6 changes: 3 additions & 3 deletions src/lightning/pytorch/callbacks/pruning.py
@@ -211,8 +211,8 @@ def __init__(
         # need to ignore typing here since pytorch base class does not define the PRUNING_TYPE attribute
         if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured":  # type: ignore
             raise MisconfigurationException(
-                'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.'  # type: ignore
-                f" Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. "
+                'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.'
+                f" Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. "  # type: ignore[union-attr]
             )

         self.pruning_fn = pruning_fn
@@ -308,7 +308,7 @@ def apply_lottery_ticket_hypothesis(self) -> None:

     def _apply_local_pruning(self, amount: float) -> None:
         for module, name in self._parameters_to_prune:
-            self.pruning_fn(module, name=name, amount=amount)
+            self.pruning_fn(module, name=name, amount=amount)  # type: ignore[call-arg]

     def _resolve_global_kwargs(self, amount: float) -> Dict[str, Any]:
         self._global_kwargs["amount"] = amount

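For context on the PRUNING_TYPE check above, a small sketch using torch.nn.utils.prune directly (the layer is illustrative):

    from torch import nn
    from torch.nn.utils import prune

    layer = nn.Linear(8, 8)

    # l1_unstructured has PRUNING_TYPE == "unstructured", so it is compatible
    # with use_global_unstructured=True; structured methods would trip the
    # check above.
    print(prune.L1Unstructured.PRUNING_TYPE)  # "unstructured"
    prune.l1_unstructured(layer, name="weight", amount=0.5)  # adds weight_orig/weight_mask
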
8 changes: 4 additions & 4 deletions src/lightning/pytorch/callbacks/stochastic_weight_avg.py
@@ -17,7 +17,7 @@
 """

 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Union, cast
+from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast

 import torch
 from torch import Tensor, nn
@@ -42,7 +42,7 @@ def __init__(
         swa_lrs: Union[float, List[float]],
         swa_epoch_start: Union[int, float] = 0.8,
         annealing_epochs: int = 10,
-        annealing_strategy: str = "cos",
+        annealing_strategy: Literal["cos", "linear"] = "cos",
         avg_fn: Optional[_AVG_FN] = None,
         device: Optional[Union[torch.device, str]] = torch.device("cpu"),
     ):
@@ -303,14 +303,14 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> No
                 dtype=module.running_var.dtype,
             )
             self.momenta[module] = module.momentum
-            module.momentum = None  # type: ignore[assignment]
+            module.momentum = None
             assert module.num_batches_tracked is not None
             module.num_batches_tracked *= 0

     def reset_momenta(self) -> None:
         """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
         for bn_module in self.momenta:
-            bn_module.momentum = self.momenta[bn_module]  # type: ignore[assignment]
+            bn_module.momentum = self.momenta[bn_module]

     @staticmethod
     def update_parameters(

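Narrowing annealing_strategy to a Literal lets mypy reject unsupported values at the call site instead of deferring to a runtime error. A minimal sketch (the function name is illustrative):

    from typing import Literal

    def configure_swa(annealing_strategy: Literal["cos", "linear"] = "cos") -> str:
        return annealing_strategy

    configure_swa("linear")  # accepted
    configure_swa("step")    # flagged by mypy as incompatible with Literal["cos", "linear"]
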
2 changes: 1 addition & 1 deletion src/lightning/pytorch/core/datamodule.py
@@ -235,7 +235,7 @@ def load_from_checkpoint(
         """
         loaded = _load_from_checkpoint(
-            cls,  # type: ignore[arg-type]
+            cls,
             checkpoint_path,
             map_location=map_location,
             hparams_file=hparams_file,

2 changes: 1 addition & 1 deletion src/lightning/pytorch/core/module.py
@@ -1580,7 +1580,7 @@ def load_from_checkpoint(
         """
         loaded = _load_from_checkpoint(
-            cls,  # type: ignore[arg-type]
+            cls,
             checkpoint_path,
             map_location,
             hparams_file,

2 changes: 1 addition & 1 deletion src/lightning/pytorch/plugins/precision/amp.py
@@ -40,7 +40,7 @@ def __init__(
         self,
         precision: Literal["16-mixed", "bf16-mixed"],
         device: str,
-        scaler: Optional["torch.cuda.amp.GradScaler"] = None,
+        scaler: Optional["torch.amp.GradScaler"] = None,
     ) -> None:
         if precision not in ("16-mixed", "bf16-mixed"):
             raise ValueError(

2 changes: 1 addition & 1 deletion src/lightning/pytorch/plugins/precision/precision.py
@@ -95,7 +95,7 @@ def _after_closure(self, model: "pl.LightningModule", optimizer: Steppable) -> N
     def _wrap_closure(
         self,
         model: "pl.LightningModule",
-        optimizer: Optimizer,
+        optimizer: Steppable,
         closure: Callable[[], Any],
     ) -> Any:
         """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``

@@ -91,7 +91,7 @@ def _generate_sync_fn(self) -> None:
         fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
         # save the function as `_fn` as the meta are being re-created and the object references need to match.
         # ignore typing, bad support for `partial`: mypy/issues/1484
-        self._fn: Callable = partial(fn, reduce_op=self.op, group=self.group)  # type: ignore [arg-type]
+        self._fn: Callable = partial(fn, reduce_op=self.op, group=self.group)  # type: ignore[arg-type,operator,misc]

     @property
     def __call__(self) -> Any:

11 changes: 7 additions & 4 deletions src/lightning/pytorch/tuner/lr_finder.py
@@ -132,7 +132,9 @@ def _exchange_scheduler(self, trainer: "pl.Trainer") -> None:
         trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step")]
         _validate_optimizers_attached(trainer.optimizers, trainer.lr_scheduler_configs)

-    def plot(self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] = None) -> Optional["plt.Figure"]:
+    def plot(
+        self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] = None
+    ) -> Optional[Union["plt.Figure", "plt.SubFigure"]]:
         """Plot results from lr_find run
         Args:
             suggest: if True, will mark suggested lr to use with a red point
@@ -151,10 +153,11 @@ def plot(self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] =
         lrs = self.results["lr"]
         losses = self.results["loss"]

+        fig: Optional[Union[plt.Figure, plt.SubFigure]]
         if ax is None:
             fig, ax = plt.subplots()
         else:
-            fig = ax.figure  # type: ignore[assignment]
+            fig = ax.figure

         # Plot loss as a function of the learning rate
         ax.plot(lrs, losses)
@@ -460,7 +463,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
         super().__init__(optimizer, last_epoch)

     @override
-    def get_lr(self) -> List[float]:  # type: ignore[override]
+    def get_lr(self) -> List[float]:
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter

@@ -497,7 +500,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
         super().__init__(optimizer, last_epoch)

     @override
-    def get_lr(self) -> List[float]:  # type: ignore[override]
+    def get_lr(self) -> List[float]:
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter

9 changes: 6 additions & 3 deletions src/lightning/pytorch/utilities/model_helpers.py
@@ -127,6 +127,9 @@ def wrapper(*args: Any, **kwargs: Any) -> _R_co:
         return wrapper


-# trick static type checkers into thinking it's a @classmethod
-# https://github.com/microsoft/pyright/issues/5865
-_restricted_classmethod = classmethod if TYPE_CHECKING else _restricted_classmethod_impl
+if TYPE_CHECKING:
+    # trick static type checkers into thinking it's a @classmethod
+    # https://github.com/microsoft/pyright/issues/5865
+    _restricted_classmethod = classmethod
+else:
+    _restricted_classmethod = _restricted_classmethod_impl

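The same trick, reduced to a toy example: type checkers see a plain classmethod, so cls binding and signatures are verified, while the runtime object can be a custom descriptor. The decorator below is a simplified stand-in, not the real _restricted_classmethod_impl:

    from typing import TYPE_CHECKING, Any, Callable

    def _custom_classmethod_impl(method: Callable[..., Any]) -> Any:
        # toy stand-in: the real implementation additionally restricts access
        return classmethod(method)

    if TYPE_CHECKING:
        _restricted_classmethod = classmethod
    else:
        _restricted_classmethod = _custom_classmethod_impl

    class Model:
        @_restricted_classmethod
        def load(cls, path: str) -> "Model":
            return cls()

    print(Model.load("weights.ckpt"))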