Avoid loading dataloaders if limit_batches=0 (Lightning-AI#11576)
rohitgr7 authored Feb 22, 2022
1 parent de1815f commit 5ea811b
Showing 7 changed files with 102 additions and 103 deletions.
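
For context, a minimal sketch (not part of this commit) of the user-facing behavior the change targets: when a `limit_*_batches` flag is 0, the corresponding dataloader hook should no longer be requested at all. The `DemoModel` and `RandomDataset` helpers below are hypothetical stand-ins, not code from the repository.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """Tiny synthetic dataset used only for this sketch."""

    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class DemoModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.val_dataloader_calls = 0

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=8)

    def val_dataloader(self):
        # Count how often the trainer actually asks for the val dataloader.
        self.val_dataloader_calls += 1
        return DataLoader(RandomDataset(32, 64), batch_size=8)


model = DemoModel()
trainer = Trainer(
    limit_val_batches=0,  # validation is disabled entirely
    num_sanity_val_steps=0,
    max_epochs=1,
    logger=False,
    enable_checkpointing=False,
)
trainer.fit(model)
# With this commit, the hook is expected never to be called.
assert model.val_dataloader_calls == 0
```
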
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -664,6 +664,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))


- Disabled loading of dataloaders if corresponding `limit_batches=0` ([#11576](https://github.com/PyTorchLightning/pytorch-lightning/pull/11576))


## [1.5.8] - 2022-01-05

### Fixed
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -421,21 +421,21 @@ def _reset_eval_dataloader(
limit_eval_batches = getattr(self.trainer, f"limit_{mode.dataloader_prefix}_batches")

# limit num batches either as a percent or num steps
if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
if isinstance(limit_eval_batches, int):
num_batches = min(num_batches, int(limit_eval_batches))
elif num_batches != float("inf"):
num_batches = int(num_batches * limit_eval_batches)
elif limit_eval_batches != 1.0:
raise MisconfigurationException(
f"When using an IterableDataset for `limit_{mode}_batches`,"
f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `0.0`, `1.0` or an int. An int k"
f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `1.0` or an int. An int k"
f" specifies `num_{mode.dataloader_prefix}_batches` to use."
)

if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
min_pct = 1.0 / len(dataloader)
raise MisconfigurationException(
f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
f"You requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
f" {limit_eval_batches} * {orig_num_batches} < 1. Please increase the"
f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
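
Read in isolation, the retained branch logic resolves an eval batch limit roughly as follows; this is a hypothetical standalone condensation for illustration, not the connector's actual code:

```python
def resolve_eval_batches(num_batches, limit):
    """Hypothetical condensation of the hunk above.

    `num_batches` is the dataloader length (possibly float("inf") for an
    IterableDataset); `limit` is the corresponding `limit_*_batches` flag.
    """
    if isinstance(limit, int):
        # An int caps the number of batches directly (0 now short-circuits earlier).
        return min(num_batches, limit)
    if num_batches != float("inf"):
        # A float is interpreted as a fraction of a sized dataloader.
        return int(num_batches * limit)
    if limit != 1.0:
        # Infinite dataloaders only accept 1.0 or an int after this change.
        raise ValueError("`limit_*_batches` must be `1.0` or an int for an IterableDataset")
    return num_batches
```
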
20 changes: 15 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -1810,6 +1810,13 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
Args:
model: The ``LightningModule`` if calling this outside of the trainer scope.
"""
source = self._data_connector._train_dataloader_source
pl_module = self.lightning_module or model
has_step = is_overridden("training_step", pl_module)
enable_training = self.limit_train_batches > 0
if not (source.is_defined() and has_step and enable_training):
return

self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING, model=model)

if self.overfit_batches > 0:
@@ -1849,14 +1856,14 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
else float("inf")
)

if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
if isinstance(self.limit_train_batches, int):
self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
elif self.num_training_batches != float("inf"):
self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
elif self.limit_train_batches != 1.0:
raise MisconfigurationException(
"When using an IterableDataset for `limit_train_batches`,"
" `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
" `Trainer(limit_train_batches)` must be `1.0` or an int. An int k specifies"
" `num_training_batches` to use."
)

@@ -1902,7 +1909,8 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) ->
source = self._data_connector._val_dataloader_source
pl_module = self.lightning_module or model
has_step = is_overridden("validation_step", pl_module)
if source.is_defined() and has_step:
enable_validation = self.limit_val_batches > 0
if source.is_defined() and has_step and enable_validation:
self.num_val_batches, self.val_dataloaders = self._data_connector._reset_eval_dataloader(
RunningStage.VALIDATING, model=pl_module
)
@@ -1919,7 +1927,8 @@ def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) ->
source = self._data_connector._test_dataloader_source
pl_module = self.lightning_module or model
has_step = is_overridden("test_step", pl_module)
if source.is_defined() and has_step:
enable_testing = self.limit_test_batches > 0
if source.is_defined() and has_step and enable_testing:
self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader(
RunningStage.TESTING, model=pl_module
)
@@ -1932,7 +1941,8 @@ def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None)
"""
source = self._data_connector._predict_dataloader_source
pl_module = self.lightning_module or model
if source.is_defined():
enable_prediction = self.limit_predict_batches > 0
if source.is_defined() and enable_prediction:
self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader(
RunningStage.PREDICTING, model=pl_module
)
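
The guard added to each `reset_*_dataloader` method above follows the same shape; roughly (hypothetical condensation, not library code):

```python
def should_request_dataloader(source_is_defined, step_is_overridden, limit_batches):
    # A dataloader is only requested when its source is defined, its step hook
    # is implemented (reset_predict_dataloader skips that check), and its
    # corresponding limit_*_batches flag is non-zero.
    return source_is_defined and step_is_overridden and limit_batches > 0
```
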
2 changes: 0 additions & 2 deletions tests/loops/test_loops.py
@@ -565,8 +565,6 @@ def configure_optimizers_multiple(self):
trainer.fit_loop.epoch_loop.reset()
trainer.fit_loop.epoch_loop.batch_loop.reset()
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.reset()
trainer.fit_loop.epoch_loop.val_loop.reset()
trainer.fit_loop.epoch_loop.val_loop.epoch_loop.reset()

epoch_progress = trainer.fit_loop.epoch_progress
assert epoch_progress.current.ready == stop_epoch
7 changes: 2 additions & 5 deletions tests/models/test_hooks.py
@@ -638,9 +638,6 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
dict(name="train", args=(True,)),
dict(name="on_train_dataloader"),
dict(name="train_dataloader"),
# even though no validation runs, we initialize the val dataloader for properties like `num_val_batches`
dict(name="on_val_dataloader"),
dict(name="val_dataloader"),
dict(name="Callback.on_train_start", args=(trainer, model)),
dict(name="on_train_start"),
dict(name="Callback.on_epoch_start", args=(trainer, model)),
@@ -689,6 +686,8 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
fn = getattr(trainer, verb)
fn(model, verbose=False)
hooks = [
dict(name=f"on_{dataloader}_dataloader"),
dict(name=f"{dataloader}_dataloader"),
dict(name="train", args=(False,)),
dict(name=f"on_{noun}_model_eval"),
dict(name="zero_grad"),
@@ -710,8 +709,6 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
dict(name="setup", kwargs=dict(stage=verb)),
dict(name="configure_sharded_model"),
dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),
dict(name=f"on_{dataloader}_dataloader"),
dict(name=f"{dataloader}_dataloader"),
*(hooks if batches else []),
dict(name="Callback.teardown", args=(trainer, model), kwargs=dict(stage=verb)),
dict(name="teardown", kwargs=dict(stage=verb)),
146 changes: 59 additions & 87 deletions tests/trainer/test_dataloaders.py
@@ -234,22 +234,18 @@ def on_test_epoch_start(self, trainer, pl_module):
self.test_epoch_count += 1


@pytest.mark.parametrize(
["limit_train_batches", "limit_val_batches", "limit_test_batches"], [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]
)
def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
def test_inf_dataloaders_with_limit_percent_batches(tmpdir):
"""Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent."""

ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False)
epoch_cb = Counter()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
max_epochs=1,
callbacks=[epoch_cb, ckpt_callback],
limit_train_batches=limit_train_batches,
limit_val_batches=limit_val_batches,
limit_test_batches=limit_test_batches,
callbacks=[epoch_cb],
limit_train_batches=1.0,
limit_val_batches=1.0,
limit_test_batches=1.0,
)
model = DummyModel()

@@ -268,39 +264,34 @@ def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches,
trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
assert trainer.state.finished, f"Training failed with {trainer.state}"
assert trainer.num_training_batches == float("inf")
assert epoch_cb.train_epoch_count == int(limit_train_batches > 0)
# when limit_val_batches = 0, num_val_batches is empty as no data is loaded
if limit_val_batches != 0.0:
assert trainer.num_val_batches[0] == float("inf")
assert epoch_cb.val_epoch_count == int(limit_val_batches > 0)
assert epoch_cb.train_epoch_count == 1

assert trainer.num_val_batches[0] == float("inf")
assert epoch_cb.val_epoch_count == 1

trainer.test(model, dataloaders=test_dl)
assert trainer.num_test_batches[0] == (0 if limit_test_batches == 0.0 else float("inf"))
assert epoch_cb.test_epoch_count == int(limit_test_batches > 0)
assert trainer.num_test_batches[0] == float("inf")
assert epoch_cb.test_epoch_count == 1


@pytest.mark.parametrize(
["dataset", "limit_train_batches"],
[
(RandomDataset(32, 128), 0),
(RandomDataset(32, 128), 10),
(RandomIterableDataset(32, 128), 0),
(RandomIterableDataset(32, 128), 10),
(RandomIterableDatasetWithLen(32, 128), 0),
(RandomIterableDatasetWithLen(32, 128), 10),
],
)
def test_dataloaders_with_limit_train_batches(tmpdir, dataset, limit_train_batches):
"""Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number."""

ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False)
epoch_cb = Counter()
epochs = 2
max_epochs = 2
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
max_epochs=epochs,
callbacks=[epoch_cb, ckpt_callback],
max_epochs=max_epochs,
callbacks=[epoch_cb],
limit_train_batches=limit_train_batches,
)
model = DummyModel()
@@ -311,37 +302,32 @@ def test_dataloaders_with_limit_train_batches(tmpdir, dataset, limit_train_batch

trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
assert trainer.state.finished, f"Training failed with {trainer.state}"
assert trainer.num_training_batches == (limit_train_batches if limit_train_batches != 0.0 else float("inf"))
assert epoch_cb.train_epoch_count == (epochs if limit_train_batches > 0 else 0)
assert epoch_cb.train_batches_seen == limit_train_batches * epochs
assert trainer.num_training_batches == limit_train_batches
assert epoch_cb.train_epoch_count == max_epochs
assert epoch_cb.train_batches_seen == limit_train_batches * max_epochs


@pytest.mark.parametrize(
["dataset", "limit_val_batches"],
"dataset",
[
(RandomDataset(32, 128), 0),
(RandomDataset(32, 128), 10),
(RandomIterableDataset(32, 128), 0),
(RandomIterableDataset(32, 128), 10),
(RandomIterableDatasetWithLen(32, 128), 0),
(RandomIterableDatasetWithLen(32, 128), 10),
RandomDataset(32, 128),
RandomIterableDataset(32, 128),
RandomIterableDatasetWithLen(32, 128),
],
)
def test_dataloaders_with_limit_val_batches(tmpdir, dataset, limit_val_batches):
def test_dataloaders_with_limit_val_batches(tmpdir, dataset):
"""Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number."""

epoch_cb = Counter()
callbacks = [epoch_cb]
enable_checkpointing = False
if limit_val_batches > 0:
callbacks.append(ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False))
enable_checkpointing = True

epochs = 2
max_epochs = 2
limit_val_batches = 10
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
max_epochs=epochs,
max_epochs=max_epochs,
callbacks=callbacks,
limit_val_batches=limit_val_batches,
enable_checkpointing=enable_checkpointing,
@@ -355,37 +341,33 @@ def test_dataloaders_with_limit_val_batches(tmpdir, dataset, limit_val_batches):
trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
assert trainer.state.finished, f"Training failed with {trainer.state}"
assert trainer.num_val_batches[0] == limit_val_batches
assert epoch_cb.val_epoch_count == (epochs if limit_val_batches > 0 else 0)
assert epoch_cb.val_batches_seen == limit_val_batches * epochs
assert epoch_cb.val_epoch_count == max_epochs
assert epoch_cb.val_batches_seen == limit_val_batches * max_epochs


@pytest.mark.skip()
@pytest.mark.parametrize(
["dataset", "limit_train_batches", "limit_val_batches", "limit_test_batches"],
"dataset",
[
(RandomDataset(32, 128), 0, 0, 0),
(RandomDataset(32, 128), 10, 10, 10),
(RandomIterableDataset(32, 128), 0, 0, 0),
(RandomIterableDataset(32, 128), 10, 10, 10),
(RandomIterableDatasetWithLen(32, 128), 0, 0, 0),
(RandomIterableDatasetWithLen(32, 128), 10, 10, 10),
RandomDataset(32, 128),
RandomIterableDataset(32, 128),
RandomIterableDatasetWithLen(32, 128),
],
)
def test_datasets_dataloaders_with_limit_num_batches(
tmpdir, dataset, limit_train_batches, limit_val_batches, limit_test_batches
):
def test_datasets_dataloaders_with_limit_num_batches(tmpdir, dataset):
"""Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number."""

ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False)
epoch_cb = Counter()
epochs = 2
max_epochs = 2
limit_batches = 10
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
max_epochs=epochs,
callbacks=[epoch_cb, ckpt_callback],
limit_train_batches=limit_train_batches,
limit_val_batches=limit_val_batches,
limit_test_batches=limit_test_batches,
max_epochs=max_epochs,
callbacks=[epoch_cb],
limit_train_batches=limit_batches,
limit_val_batches=limit_batches,
limit_test_batches=limit_batches,
)
model = DummyModel()

@@ -396,24 +378,22 @@ def test_datasets_dataloaders_with_limit_num_batches(

trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
assert trainer.state.finished, f"Training failed with {trainer.state}"
assert trainer.num_training_batches == (limit_train_batches if limit_train_batches > 0.0 else float("inf"))
if limit_val_batches != 0.0:
assert trainer.num_val_batches[0] == limit_val_batches
else:
assert trainer.num_val_batches == []
assert epoch_cb.train_epoch_count == (epochs if limit_train_batches > 0 else 0)
assert epoch_cb.train_batches_seen == limit_train_batches * epochs
assert epoch_cb.val_epoch_count == (epochs if limit_val_batches > 0 else 0)
assert epoch_cb.val_batches_seen == limit_val_batches * epochs
assert trainer.num_training_batches == limit_batches
assert trainer.num_val_batches[0] == limit_batches
assert epoch_cb.train_epoch_count == max_epochs
assert epoch_cb.train_batches_seen == limit_batches * max_epochs
assert epoch_cb.val_epoch_count == max_epochs
assert epoch_cb.val_batches_seen == limit_batches * max_epochs

trainer.test(model, dataloaders=test_dl)
assert trainer.num_test_batches[0] == limit_test_batches
assert epoch_cb.test_epoch_count == int(limit_test_batches > 0)
assert trainer.num_test_batches[0] == limit_batches
assert epoch_cb.test_epoch_count == 1


@pytest.mark.skip()
@pytest.mark.parametrize(
["limit_train_batches", "limit_val_batches", "limit_test_batches"],
[(0.0, 0.0, 0.0), (0, 0, 0.5), (1.0, 1.0, 1.0), (0.2, 0.4, 0.4)],
[(1.0, 1.0, 1.0), (0.2, 0.4, 0.4)],
)
def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify num_batches for train, val & test dataloaders passed with batch limit in percent."""
@@ -427,22 +407,17 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
limit_test_batches=limit_test_batches,
)
trainer.fit(model)
if limit_train_batches != 0.0:
expected_train_batches = int(len(trainer.train_dataloader) * limit_train_batches)
expected_val_batches = [int(len(dataloader) * limit_val_batches) for dataloader in trainer.val_dataloaders]
assert trainer.num_training_batches == expected_train_batches
assert trainer.num_val_batches == expected_val_batches
else:
assert trainer.train_dataloader is None
expected_train_batches = int(len(trainer.train_dataloader) * limit_train_batches)
expected_val_batches = [int(len(dataloader) * limit_val_batches) for dataloader in trainer.val_dataloaders]
assert trainer.num_training_batches == expected_train_batches
assert trainer.num_val_batches == expected_val_batches

trainer.test(model)
expected_test_batches = [int(len(dataloader) * limit_test_batches) for dataloader in trainer.test_dataloaders]
assert trainer.num_test_batches == expected_test_batches


@pytest.mark.parametrize(
["limit_train_batches", "limit_val_batches", "limit_test_batches"], [(0, 0, 0), (1, 2, 3), (1, 2, 1e50)]
)
@pytest.mark.parametrize(["limit_train_batches", "limit_val_batches", "limit_test_batches"], [(1, 2, 3), (1, 2, 1e50)])
def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify num_batches for train, val & test dataloaders passed with batch limit as number."""

@@ -464,12 +439,9 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v
wraps=trainer.fit_loop.epoch_loop.val_loop.epoch_loop._evaluation_step,
) as mocked:
trainer.fit(model)
assert trainer.num_training_batches == (limit_train_batches if limit_train_batches != 0.0 else float("inf"))
if limit_train_batches != 0.0:
assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders)
assert mocked.call_count == limit_val_batches * len(trainer.val_dataloaders)
else:
assert trainer.val_dataloaders is None
assert trainer.num_training_batches == limit_train_batches
assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders)
assert mocked.call_count == limit_val_batches * len(trainer.val_dataloaders)

with patch.object(
trainer.test_loop.epoch_loop,
@@ -994,7 +966,7 @@ def test_inf_dataloader_raise_error_with_partial_batch_limits(tmpdir, stage, dat
trainer = Trainer(**trainer_kwargs)
trainer_fn = "fit" if stage == RunningStage.TRAINING else stage.value

with pytest.raises(MisconfigurationException, match=r"using an IterableDataset .* must be `0.0`, `1.0`"):
with pytest.raises(MisconfigurationException, match=r"using an IterableDataset .* must be `1.0` or an int"):
getattr(trainer, trainer_fn)(model)

