From 66b83e1368af3a59b111ed8843a14499b880671f Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 6 Sep 2023 18:58:19 +0200
Subject: [PATCH 1/5] refactor assertions that use walrus
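
Python strips `assert` statements when run with `-O` (optimized mode), so
an assignment hidden inside an assertion via the walrus operator never
executes, and the later use of the name raises `UnboundLocalError`:

    assert (data_fetcher := self._data_fetcher) is not None
    data_fetcher.setup(...)  # UnboundLocalError under `python -O`

Pull the assignment out of the assertion so it always runs; the assert
then serves purely as a `None` check:

    data_fetcher = self._data_fetcher
    assert data_fetcher is not None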
---
src/lightning/fabric/strategies/xla_fsdp.py | 4 ++--
src/lightning/pytorch/loops/evaluation_loop.py | 6 ++++--
src/lightning/pytorch/loops/fetchers.py | 4 ++--
src/lightning/pytorch/loops/fit_loop.py | 6 ++++--
src/lightning/pytorch/loops/prediction_loop.py | 6 ++++--
5 files changed, 16 insertions(+), 10 deletions(-)
diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py
index a4504950f4c95..4b08c880a436b 100644
--- a/src/lightning/fabric/strategies/xla_fsdp.py
+++ b/src/lightning/fabric/strategies/xla_fsdp.py
@@ -413,9 +413,10 @@ def save_checkpoint(
# ensure model parameters are updated
xm.mark_step()
 
+ parallel_devices = self.parallel_devices
+ assert parallel_devices is not None
if self._sequential_save:
# each host runs this in parallel, but the ranks in the host run it sequentially
- assert (parallel_devices := self.parallel_devices) is not None
for rank in range(len(parallel_devices)):
if rank == self.local_rank:
self._save_checkpoint_shard(path, state, storage_options, filter)
@@ -426,7 +427,6 @@ def save_checkpoint(
if self._state_dict_type == "full":
ckpt_prefix = str(path / "checkpoint")
ckpt_suffix = "_rank-*-of-*.pth"
- assert (parallel_devices := self.parallel_devices) is not None
if len(parallel_devices) != self.world_size: # multihost
raise OSError(
"Multihost setups do not have a shared filesystem, so the checkpoint shards cannot be consolidated"
diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py
index 19c353b49333d..2f8b2cf54de89 100644
--- a/src/lightning/pytorch/loops/evaluation_loop.py
+++ b/src/lightning/pytorch/loops/evaluation_loop.py
@@ -122,7 +122,8 @@ def run(self) -> List[_OUT_DICT]:
return []
self.reset()
self.on_run_start()
- assert (data_fetcher := self._data_fetcher) is not None
+ data_fetcher = self._data_fetcher
+ assert data_fetcher is not None
previous_dataloader_idx = 0
while True:
try:
@@ -374,7 +375,8 @@ def _evaluation_step(
"""
trainer = self.trainer
- assert (data_fetcher := self._data_fetcher) is not None
+ data_fetcher = self._data_fetcher
+ assert data_fetcher is not None
if not (using_dataloader_iter := isinstance(data_fetcher, _DataLoaderIterDataFetcher)):
batch = trainer.precision_plugin.convert_input(batch)
diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py
index 4638dac8a2db0..806ca58541e06 100644
--- a/src/lightning/pytorch/loops/fetchers.py
+++ b/src/lightning/pytorch/loops/fetchers.py
@@ -52,10 +52,10 @@ def __iter__(self) -> "_DataFetcher":
return self
 
def __next__(self) -> _ITERATOR_RETURN:
- assert (iterator := self.iterator) is not None
+ assert self.iterator is not None
self._start_profiler()
try:
- batch = next(iterator)
+ batch = next(self.iterator)
except StopIteration:
self.done = True
raise
diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index 6ac00b1d16011..1a188e8d76e22 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -337,13 +337,15 @@ def advance(self) -> None:
"""Runs one whole epoch."""
log.debug(f"{type(self).__name__}: advancing loop")
 
- assert (combined_loader := self._combined_loader) is not None
+ combined_loader = self._combined_loader
+ assert combined_loader is not None
if combined_loader._mode == "sequential":
raise ValueError(
f'`{type(self).__name__}` does not support the `CombinedLoader(mode="sequential")` mode.'
f" The available modes are: {[m for m in _SUPPORTED_MODES if m != 'sequential']}"
)
- assert (data_fetcher := self._data_fetcher) is not None
+ data_fetcher = self._data_fetcher
+ assert self._data_fetcher is not None
data_fetcher.setup(combined_loader)
with self.trainer.profiler.profile("run_training_epoch"):
self.epoch_loop.run(data_fetcher)
diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py
index 9c13426060cac..e2956f7a5f4b2 100644
--- a/src/lightning/pytorch/loops/prediction_loop.py
+++ b/src/lightning/pytorch/loops/prediction_loop.py
@@ -103,7 +103,8 @@ def run(self) -> Optional[_PREDICT_OUTPUT]:
return None
self.reset()
self.on_run_start()
- assert (data_fetcher := self._data_fetcher) is not None
+ data_fetcher = self._data_fetcher
+ assert data_fetcher is not None
while True:
try:
if isinstance(data_fetcher, _DataLoaderIterDataFetcher):
@@ -218,7 +219,8 @@ def _predict_step(
"""
trainer = self.trainer
- assert (data_fetcher := self._data_fetcher) is not None
+ data_fetcher = self._data_fetcher
+ assert data_fetcher is not None
if not (using_dataloader_iter := isinstance(data_fetcher, _DataLoaderIterDataFetcher)):
batch = trainer.precision_plugin.convert_input(batch)
From 48b393f71b2d49005695a9ae79151837deffcbdc Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 6 Sep 2023 23:19:49 +0200
Subject: [PATCH 2/5] add a simple test
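
Assertions are compiled out under `python -O`, so a regression like the
walrus-in-assert pattern is invisible to an unoptimized test run:

    $ python -c "assert False"     # AssertionError
    $ python -O -c "assert False"  # no error: the assert is stripped

Run the bug report model once in optimized mode as a cheap smoke test
for this class of bug.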
---
tests/tests_pytorch/run_standalone_tasks.sh | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/tests/tests_pytorch/run_standalone_tasks.sh b/tests/tests_pytorch/run_standalone_tasks.sh
index 46f1a4279f579..7648adfb45b27 100644
--- a/tests/tests_pytorch/run_standalone_tasks.sh
+++ b/tests/tests_pytorch/run_standalone_tasks.sh
@@ -34,3 +34,7 @@ MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=0 python strategies/scripts/
# test that ddp can be launched as a module (-m option)
echo "Running ddp example as module"
python -m strategies.scripts.cli_script ${args}
+
+# test that `python -O` (optimized mode) works
+echo "Running bug report model with 'python -O'"
+python -O ../../examples/pytorch/bug_report/bug_report_model.py
From e4d19cbbd9b83038bfa3d80fabe84c0201249045 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 6 Sep 2023 23:20:45 +0200
Subject: [PATCH 3/5] mypy
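
The previous commit asserted on the attribute rather than on the local
variable, which mypy does not use to narrow the local's type:

    data_fetcher = self._data_fetcher
    assert self._data_fetcher is not None
    data_fetcher.setup(combined_loader)  # mypy: still Optional here

Assert on the local instead so that `data_fetcher` itself is narrowed.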
---
src/lightning/pytorch/loops/fit_loop.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index 1a188e8d76e22..8f828a4b59f60 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -345,7 +345,7 @@ def advance(self) -> None:
f" The available modes are: {[m for m in _SUPPORTED_MODES if m != 'sequential']}"
)
data_fetcher = self._data_fetcher
- assert self._data_fetcher is not None
+ assert data_fetcher is not None
data_fetcher.setup(combined_loader)
with self.trainer.profiler.profile("run_training_epoch"):
self.epoch_loop.run(data_fetcher)
From 6b71bbc33c9155c2fa6c1c3e6a266b3345cb8fdc Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 6 Sep 2023 23:37:00 +0200
Subject: [PATCH 4/5] chlog
---
src/lightning/pytorch/CHANGELOG.md | 2 ++
1 file changed, 2 insertions(+)
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index bb8729aa72e13..2d515ec35a123 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -241,6 +241,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the display of `v_num` in the progress bar when running with `Trainer(fast_dev_run=True)` ([#18491](https://github.com/Lightning-AI/lightning/pull/18491))
+- Fixed `UnboundLocalError` when running with `python -O` ([#18496](https://github.com/Lightning-AI/lightning/pull/18496))
+
## [2.0.7] - 2023-08-14
From 83c434d3830fd788504aedd0a712a58f5ddaf48c Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 7 Sep 2023 16:50:11 +0200
Subject: [PATCH 5/5] simplify
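
The temporary local in `FitLoop.advance` only existed to help the type
checker. mypy also narrows `self._data_fetcher` itself after an
`is not None` assert, so the check can sit right where the fetcher is
used and the extra variable can go.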
---
src/lightning/pytorch/loops/fit_loop.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index 2659ee244d27a..4d02f8408c218 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -358,10 +358,9 @@ def advance(self) -> None:
f'`{type(self).__name__}` does not support the `CombinedLoader(mode="sequential")` mode.'
f" The available modes are: {[m for m in _SUPPORTED_MODES if m != 'sequential']}"
)
- data_fetcher = self._data_fetcher
- assert data_fetcher is not None
with self.trainer.profiler.profile("run_training_epoch"):
- self.epoch_loop.run(data_fetcher)
+ assert self._data_fetcher is not None
+ self.epoch_loop.run(self._data_fetcher)
 
def on_advance_end(self) -> None:
trainer = self.trainer