Describe the behavior with limit_*_batches=1|1.0 (Lightning-AI#11950)
carmocca authored Feb 21, 2022
1 parent 08384c4 commit 3579a30
Showing 5 changed files with 81 additions and 14 deletions.
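
In short, the commit changes the defaults of the `limit_train_batches`, `limit_val_batches`, `limit_test_batches`, `limit_predict_batches`, and `val_check_interval` Trainer arguments from `1.0` to `None`, so the Trainer can tell whether the user explicitly passed `1` or `1.0` and, if so, log an info message describing which of the two behaviours was chosen (the message also covers `overfit_batches`). A rough usage sketch of the resulting behaviour, with the message wording taken from the diff below; depending on your logging setup you may need to enable INFO-level output for the `pytorch_lightning` logger:

import logging

from pytorch_lightning import Trainer

# Make INFO-level messages from Lightning visible (this may already be the default).
logging.getLogger("pytorch_lightning").setLevel(logging.INFO)

# Integer 1: only a single batch per epoch is used.
Trainer(limit_train_batches=1)
# info: `Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.

# Float 1.0: 100% of the batches are used (same behaviour as the old default).
Trainer(limit_train_batches=1.0)
# info: `Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used.

# Argument left unset: the new `None` default behaves like 1.0 and logs nothing.
Trainer()
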
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -69,6 +69,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a `LOGGER_REGISTRY` instance to register custom loggers to the `LightningCLI` ([#11533](https://github.com/PyTorchLightning/pytorch-lightning/pull/11533))


- Added info message when the `Trainer` arguments `limit_*_batches`, `overfit_batches`, or `val_check_interval` are set to `1` or `1.0` ([#11950](https://github.com/PyTorchLightning/pytorch-lightning/pull/11950))

- Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))


1 change: 1 addition & 0 deletions pytorch_lightning/loops/fit_loop.py
@@ -266,6 +266,7 @@ def advance(self) -> None: # type: ignore[override]
    log.detail(f"{self.__class__.__name__}: advancing loop")
    assert self.trainer.train_dataloader is not None
    dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader)
    assert self._data_fetcher is not None
    self._data_fetcher.setup(
        dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0)
    )
53 changes: 39 additions & 14 deletions pytorch_lightning/trainer/trainer.py
@@ -156,11 +156,11 @@ def __init__(
    max_steps: int = -1,
    min_steps: Optional[int] = None,
    max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
    limit_train_batches: Union[int, float] = 1.0,
    limit_val_batches: Union[int, float] = 1.0,
    limit_test_batches: Union[int, float] = 1.0,
    limit_predict_batches: Union[int, float] = 1.0,
    val_check_interval: Union[int, float] = 1.0,
    limit_train_batches: Optional[Union[int, float]] = None,
    limit_val_batches: Optional[Union[int, float]] = None,
    limit_test_batches: Optional[Union[int, float]] = None,
    limit_predict_batches: Optional[Union[int, float]] = None,
    val_check_interval: Optional[Union[int, float]] = None,
    flush_logs_every_n_steps: Optional[int] = None,
    log_every_n_steps: int = 50,
    accelerator: Optional[Union[str, Accelerator]] = None,
@@ -511,6 +511,7 @@ def __init__(
    self._call_callback_hooks("on_init_start")

    # init data flags
    self.check_val_every_n_epoch: int
    self._data_connector.on_trainer_init(
        check_val_every_n_epoch,
        reload_dataloaders_every_n_epochs,
@@ -568,6 +569,7 @@ def __init__(
    self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)

    # init debugging flags
    self.val_check_interval: Union[int, float]
    self._init_debugging_flags(
        limit_train_batches,
        limit_val_batches,
@@ -583,14 +585,14 @@

def _init_debugging_flags(
    self,
    limit_train_batches,
    limit_val_batches,
    limit_test_batches,
    limit_predict_batches,
    val_check_interval,
    overfit_batches,
    fast_dev_run,
):
    limit_train_batches: Optional[Union[int, float]],
    limit_val_batches: Optional[Union[int, float]],
    limit_test_batches: Optional[Union[int, float]],
    limit_predict_batches: Optional[Union[int, float]],
    val_check_interval: Optional[Union[int, float]],
    overfit_batches: Union[int, float],
    fast_dev_run: Union[int, bool],
) -> None:
    if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
        raise MisconfigurationException(
            f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be >= 0."
@@ -2628,7 +2630,30 @@ def terminate_on_nan(self, val: bool) -> None:
    self._terminate_on_nan = val # : 212


def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]:
    if batches is None:
        # batches is optional so that we know whether the user passed a value explicitly; the info messages below
        # are only shown to users who set a value
        return 1.0

    # differentiating based on the type can be error-prone for users, so show a message describing the chosen behaviour
    if isinstance(batches, int) and batches == 1:
        if name == "limit_train_batches":
            message = "1 batch per epoch will be used."
        elif name == "val_check_interval":
            message = "validation will run after every batch."
        else:
            message = "1 batch will be used."
        rank_zero_info(f"`Trainer({name}=1)` was configured so {message}")
    elif isinstance(batches, float) and batches == 1.0:
        if name == "limit_train_batches":
            message = "100% of the batches per epoch will be used."
        elif name == "val_check_interval":
            message = "validation will run at the end of the training epoch."
        else:
            message = "100% of the batches will be used."
        rank_zero_info(f"`Trainer({name}=1.0)` was configured so {message}")

    if 0 <= batches <= 1:
        return batches
    if batches > 1 and batches % 1.0 == 0:
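
For illustration, a rough sketch of what `_determine_batch_limits` returns for a few inputs, based only on the portion of the function shown above (the truncated tail validates values greater than 1). It is a private, module-level helper, so calling it directly is for demonstration only:

from pytorch_lightning.trainer.trainer import _determine_batch_limits

_determine_batch_limits(None, "limit_val_batches")  # -> 1.0, no message: the user did not set the flag
_determine_batch_limits(1, "limit_val_batches")     # -> 1, logs "... so 1 batch will be used."
_determine_batch_limits(1.0, "limit_val_batches")   # -> 1.0, logs "... so 100% of the batches will be used."
_determine_batch_limits(0.25, "limit_val_batches")  # -> 0.25, fractions between 0 and 1 are returned unchanged
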
21 changes: 21 additions & 0 deletions tests/trainer/flags/test_limit_batches.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import pytest

@@ -66,3 +67,23 @@ def test_eval_limit_batches(stage, mode, limit_batches):
    expected_batches = int(limit_batches * len(eval_loader)) if isinstance(limit_batches, float) else limit_batches
    assert loader_num_batches[0] == expected_batches
    assert len(dataloaders[0]) == len(eval_loader)


@pytest.mark.parametrize(
    "argument",
    ("limit_train_batches", "limit_val_batches", "limit_test_batches", "limit_predict_batches", "overfit_batches"),
)
@pytest.mark.parametrize("value", (1, 1.0))
def test_limit_batches_info_message(caplog, argument, value):
    with caplog.at_level(logging.INFO):
        Trainer(**{argument: value})
    assert f"`Trainer({argument}={value})` was configured" in caplog.text
    message = f"configured so {'1' if isinstance(value, int) else '100%'}"
    assert message in caplog.text

    caplog.clear()

    # the message should not appear by default
    with caplog.at_level(logging.INFO):
        Trainer()
    assert message not in caplog.text
18 changes: 18 additions & 0 deletions tests/trainer/flags/test_val_check_interval.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import pytest

from pytorch_lightning.trainer import Trainer
@@ -39,3 +41,19 @@ def on_validation_epoch_start(self) -> None:

    assert model.train_epoch_calls == max_epochs
    assert model.val_epoch_calls == max_epochs * denominator


@pytest.mark.parametrize("value", (1, 1.0))
def test_val_check_interval_info_message(caplog, value):
with caplog.at_level(logging.INFO):
Trainer(val_check_interval=value)
assert f"`Trainer(val_check_interval={value})` was configured" in caplog.text
message = "configured so validation will run"
assert message in caplog.text

caplog.clear()

# the message should not appear by default
with caplog.at_level(logging.INFO):
Trainer()
assert message not in caplog.text
