Remove unused disallow_batch_sampler argument (Lightning-AI#18401)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
awaelchli and pre-commit-ci[bot] authored Aug 26, 2023
1 parent 60339f1 commit 722fdea
Showing 4 changed files with 110 additions and 173 deletions.
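For reference, a minimal sketch (not part of this commit; the dataset and sampler choices below are illustrative) of the call surface that remains after the removal: `_update_dataloader` is an internal helper that rebuilds a dataloader around a new sampler, and it no longer accepts or forwards a `disallow_batch_sampler` flag.

    import torch
    from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

    from lightning.fabric.utilities.data import _update_dataloader

    dataset = TensorDataset(torch.randn(100, 5))
    dataloader = DataLoader(dataset, batch_size=10)

    # Rebuild the dataloader around a different sampler; the former
    # `disallow_batch_sampler=True` escape hatch no longer exists.
    new_dataloader = _update_dataloader(dataloader, SequentialSampler(dataset))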
110 changes: 48 additions & 62 deletions src/lightning/fabric/utilities/data.py
@@ -79,7 +79,6 @@ def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]
 def _get_dataloader_init_args_and_kwargs(
     dataloader: DataLoader,
     sampler: Union[Sampler, Iterable],
-    disallow_batch_sampler: bool = False,
 ) -> Tuple[Tuple[Any], Dict[str, Any]]:
     if not isinstance(dataloader, DataLoader):
         raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")
@@ -131,7 +130,7 @@ def _get_dataloader_init_args_and_kwargs(
         dl_kwargs["batch_sampler"] = None
         dl_kwargs["sampler"] = None
     else:
-        dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, disallow_batch_sampler))
+        dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler))
 
     required_args = {
         p.name
@@ -173,73 +172,60 @@ def _get_dataloader_init_args_and_kwargs(
 def _dataloader_init_kwargs_resolve_sampler(
     dataloader: DataLoader,
     sampler: Union[Sampler, Iterable],
-    disallow_batch_sampler: bool = False,
 ) -> Dict[str, Any]:
     """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re-
     instantiation."""
     batch_sampler = getattr(dataloader, "batch_sampler")
 
-    if batch_sampler is not None:
-        if disallow_batch_sampler:
-            # Check that we don't have a PyTorch default batch sampler that was instantiated in DataLoader __init__
-            if not (
-                type(batch_sampler) is BatchSampler
-                and batch_sampler.sampler == sampler
-                and dataloader.batch_size == batch_sampler.batch_size
-            ):
-                raise MisconfigurationException(
-                    "It is not possible to have a batch sampler in your dataloader, "
-                    "when running on multiple IPU devices."
-                )
-        elif type(batch_sampler) is not BatchSampler:
-            batch_sampler_cls = type(batch_sampler)
-            if hasattr(batch_sampler, "__pl_saved_args"):
-                args = batch_sampler.__pl_saved_args
-                kwargs = batch_sampler.__pl_saved_kwargs
-                default_kwargs = batch_sampler.__pl_saved_default_kwargs
-                arg_names = batch_sampler.__pl_saved_arg_names
-
-                success, args, kwargs = _replace_value_in_saved_args(
-                    "sampler", sampler, args, kwargs, default_kwargs, arg_names
-                )
-                if not success:
-                    raise TypeError(
-                        "Trying to inject a modified sampler into the batch sampler; however, it seems the class "
-                        f"`{batch_sampler_cls.__qualname__}` does not have an argument called `sampler.` To mitigate "
-                        "this, expose an argument `sampler` in the `__init__` method of your custom class."
-                    )
-
-                batch_sampler = _reinstantiate_wrapped_cls(batch_sampler, *args, **kwargs)
-            else:
-                try:
-                    batch_sampler = batch_sampler_cls(
-                        sampler,
-                        batch_size=batch_sampler.batch_size,
-                        drop_last=batch_sampler.drop_last,
-                    )
-                except TypeError as ex:
-                    import re
-
-                    match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(ex))
-                    if not match:
-                        # an unexpected `TypeError`, continue failure
-                        raise
-
-                    # There could either be too few or too many arguments. Customizing the message based on this doesn't
-                    # make much sense since our MisconfigurationException is going to be raised from the original one.
-                    raise TypeError(
-                        "We tried to re-instantiate your custom batch sampler and failed. "
-                        "To mitigate this, either follow the API of `BatchSampler` or instantiate "
-                        "your custom batch sampler inside `*_dataloader` hooks of your module."
-                    ) from ex
-
-            return {
-                "sampler": None,
-                "shuffle": False,
-                "batch_sampler": batch_sampler,
-                "batch_size": 1,
-                "drop_last": False,
-            }
+    if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
+        batch_sampler_cls = type(batch_sampler)
+        if hasattr(batch_sampler, "__pl_saved_args"):
+            args = batch_sampler.__pl_saved_args
+            kwargs = batch_sampler.__pl_saved_kwargs
+            default_kwargs = batch_sampler.__pl_saved_default_kwargs
+            arg_names = batch_sampler.__pl_saved_arg_names
+
+            success, args, kwargs = _replace_value_in_saved_args(
+                "sampler", sampler, args, kwargs, default_kwargs, arg_names
+            )
+            if not success:
+                raise TypeError(
+                    "Trying to inject a modified sampler into the batch sampler; however, it seems the class "
+                    f"`{batch_sampler_cls.__qualname__}` does not have an argument called `sampler.` To mitigate "
+                    "this, expose an argument `sampler` in the `__init__` method of your custom class."
+                )
+
+            batch_sampler = _reinstantiate_wrapped_cls(batch_sampler, *args, **kwargs)
+        else:
+            try:
+                batch_sampler = batch_sampler_cls(
+                    sampler,
+                    batch_size=batch_sampler.batch_size,
+                    drop_last=batch_sampler.drop_last,
+                )
+            except TypeError as ex:
+                import re
+
+                match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(ex))
+                if not match:
+                    # an unexpected `TypeError`, continue failure
+                    raise
+
+                # There could either be too few or too many arguments. Customizing the message based on this doesn't
+                # make much sense since our MisconfigurationException is going to be raised from the original one.
+                raise TypeError(
+                    "We tried to re-instantiate your custom batch sampler and failed. "
+                    "To mitigate this, either follow the API of `BatchSampler` or instantiate "
+                    "your custom batch sampler inside `*_dataloader` hooks of your module."
+                ) from ex
+
+        return {
+            "sampler": None,
+            "shuffle": False,
+            "batch_sampler": batch_sampler,
+            "batch_size": 1,
+            "drop_last": False,
+        }
 
     return {"sampler": sampler, "shuffle": False, "batch_sampler": None}

133 changes: 58 additions & 75 deletions src/lightning/pytorch/utilities/data.py
@@ -137,7 +137,6 @@ def _get_dataloader_init_args_and_kwargs(
     dataloader: DataLoader,
     sampler: Union[Sampler, Iterable],
     mode: Optional[RunningStage] = None,
-    disallow_batch_sampler: bool = False,
 ) -> Tuple[Tuple[Any], Dict[str, Any]]:
     if not isinstance(dataloader, DataLoader):
         raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")
@@ -189,7 +188,7 @@ def _get_dataloader_init_args_and_kwargs(
         dl_kwargs["batch_sampler"] = None
         dl_kwargs["sampler"] = None
     else:
-        dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode, disallow_batch_sampler))
+        dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode))
 
     required_args = {
         p.name
@@ -232,98 +231,82 @@ def _dataloader_init_kwargs_resolve_sampler(
     dataloader: DataLoader,
     sampler: Union[Sampler, Iterable],
     mode: Optional[RunningStage] = None,
-    disallow_batch_sampler: bool = False,
 ) -> Dict[str, Any]:
     """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re-
     instantiation.
 
     If the dataloader is being used for prediction, the sampler will be wrapped into an `_IndexBatchSamplerWrapper`, so
     Lightning can keep track of its indices.
 
-    If there are multiple devices in IPU mode, it is necessary to disallow BatchSampler that isn't instantiated
-    automatically, since `poptorch.DataLoader` will try to increase the batch_size
-
     """
     is_predicting = mode == RunningStage.PREDICTING
     batch_sampler = getattr(dataloader, "batch_sampler")
     batch_sampler_cls = type(batch_sampler)
 
-    if batch_sampler is not None:
-        if disallow_batch_sampler:
-            # Check that we don't have a PyTorch default batch sampler that was instantiated in DataLoader __init__
-            if not (
-                batch_sampler_cls is BatchSampler
-                and batch_sampler.sampler == sampler
-                and dataloader.batch_size == batch_sampler.batch_size
-            ):
-                raise MisconfigurationException(
-                    "It is not possible to have a batch sampler in your dataloader, "
-                    "when running on multiple IPU devices."
-                )
-        elif batch_sampler_cls is not BatchSampler or is_predicting:
-            if hasattr(batch_sampler, "__pl_saved_args"):
-                args = batch_sampler.__pl_saved_args
-                kwargs = batch_sampler.__pl_saved_kwargs
-                default_kwargs = batch_sampler.__pl_saved_default_kwargs
-                arg_names = batch_sampler.__pl_saved_arg_names
-
-                if is_predicting:
-                    success, args, kwargs = _replace_value_in_saved_args(
-                        "drop_last", False, args, kwargs, default_kwargs, arg_names
-                    )
-                    if not success:
-                        rank_zero_warn(
-                            f"Trying to inject `drop_last=False` into batch sampler since you are predicting, however "
-                            f"it seems the class `{batch_sampler_cls.__qualname__}` does not support it. "
-                            "Your predictions might be incomplete. To mitigate this, expose `drop_last` in "
-                            "the `__init__` method of your custom class."
-                        )
-
-                success, args, kwargs = _replace_value_in_saved_args(
-                    "sampler", sampler, args, kwargs, default_kwargs, arg_names
-                )
-                if not success:
-                    raise TypeError(
-                        "Trying to inject a modified sampler into the batch sampler; however, it seems the class "
-                        f"`{batch_sampler_cls.__qualname__}` does not have an argument called `sampler.` To mitigate "
-                        "this, expose an argument `sampler` in the `__init__` method of your custom class."
-                    )
-
-                batch_sampler = _reinstantiate_wrapped_cls(batch_sampler, *args, **kwargs)
-            else:
-                try:
-                    batch_sampler = batch_sampler_cls(
-                        sampler,
-                        batch_size=batch_sampler.batch_size,
-                        drop_last=(False if is_predicting else batch_sampler.drop_last),
-                    )
-                except TypeError as ex:
-                    import re
-
-                    match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(ex))
-                    if not match:
-                        # an unexpected `TypeError`, continue failure
-                        raise
-
-                    # There could either be too few or too many arguments. Customizing the message based on this doesn't
-                    # make much sense since our MisconfigurationException is going to be raised from the original one.
-                    raise MisconfigurationException(
-                        "We tried to re-instantiate your custom batch sampler and failed. "
-                        "To mitigate this, either follow the API of `BatchSampler` or instantiate "
-                        "your custom batch sampler inside `*_dataloader` hooks of your module."
-                    ) from ex
-
-            if is_predicting:
-                batch_sampler = _IndexBatchSamplerWrapper(batch_sampler)
-
-            # batch_sampler option is mutually exclusive with batch_size, shuffle, sampler, and drop_last
-            return {
-                "sampler": None,
-                "shuffle": False,
-                "batch_sampler": batch_sampler,
-                "batch_size": 1,
-                "drop_last": False,
-            }
+    if batch_sampler is not None and (batch_sampler_cls is not BatchSampler or is_predicting):
+        if hasattr(batch_sampler, "__pl_saved_args"):
+            args = batch_sampler.__pl_saved_args
+            kwargs = batch_sampler.__pl_saved_kwargs
+            default_kwargs = batch_sampler.__pl_saved_default_kwargs
+            arg_names = batch_sampler.__pl_saved_arg_names
+
+            if is_predicting:
+                success, args, kwargs = _replace_value_in_saved_args(
+                    "drop_last", False, args, kwargs, default_kwargs, arg_names
+                )
+                if not success:
+                    rank_zero_warn(
+                        f"Trying to inject `drop_last=False` into batch sampler since you are predicting, however "
+                        f"it seems the class `{batch_sampler_cls.__qualname__}` does not support it. "
+                        "Your predictions might be incomplete. To mitigate this, expose `drop_last` in "
+                        "the `__init__` method of your custom class."
+                    )
+
+            success, args, kwargs = _replace_value_in_saved_args(
+                "sampler", sampler, args, kwargs, default_kwargs, arg_names
+            )
+            if not success:
+                raise TypeError(
+                    "Trying to inject a modified sampler into the batch sampler; however, it seems the class "
+                    f"`{batch_sampler_cls.__qualname__}` does not have an argument called `sampler.` To mitigate "
+                    "this, expose an argument `sampler` in the `__init__` method of your custom class."
+                )
+
+            batch_sampler = _reinstantiate_wrapped_cls(batch_sampler, *args, **kwargs)
+        else:
+            try:
+                batch_sampler = batch_sampler_cls(
+                    sampler,
+                    batch_size=batch_sampler.batch_size,
+                    drop_last=(False if is_predicting else batch_sampler.drop_last),
+                )
+            except TypeError as ex:
+                import re
+
+                match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(ex))
+                if not match:
+                    # an unexpected `TypeError`, continue failure
+                    raise
+
+                # There could either be too few or too many arguments. Customizing the message based on this doesn't
+                # make much sense since our MisconfigurationException is going to be raised from the original one.
+                raise MisconfigurationException(
+                    "We tried to re-instantiate your custom batch sampler and failed. "
+                    "To mitigate this, either follow the API of `BatchSampler` or instantiate "
+                    "your custom batch sampler inside `*_dataloader` hooks of your module."
+                ) from ex
+
+        if is_predicting:
+            batch_sampler = _IndexBatchSamplerWrapper(batch_sampler)
+
+        # batch_sampler option is mutually exclusive with batch_size, shuffle, sampler, and drop_last
+        return {
+            "sampler": None,
+            "shuffle": False,
+            "batch_sampler": batch_sampler,
+            "batch_size": 1,
+            "drop_last": False,
+        }
 
     return {"sampler": sampler, "shuffle": False, "batch_sampler": None}

22 changes: 2 additions & 20 deletions tests/tests_fabric/utilities/test_data.py
@@ -6,10 +6,9 @@
 import pytest
 import torch
 from torch import Tensor
-from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler
 
 from lightning.fabric.utilities.data import (
-    _dataloader_init_kwargs_resolve_sampler,
     _get_dataloader_init_args_and_kwargs,
     _replace_dunder_methods,
     _replace_value_in_saved_args,
@@ -479,24 +478,7 @@ def __init__(self, extra_arg):
 
     # Assert that error is raised
     with pytest.raises(TypeError, match="sampler into the batch sampler"):
-        dataloader = _update_dataloader(dataloader, dataloader.sampler)
-
-
-def test_dataloader_disallow_batch_sampler():
-    dataset = RandomDataset(5, 100)
-    dataloader = DataLoader(dataset, batch_size=10)
-
-    # This should not raise
-    _dataloader_init_kwargs_resolve_sampler(dataloader, dataloader.sampler, disallow_batch_sampler=True)
-
-    dataset = RandomDataset(5, 100)
-    sampler = SequentialSampler(dataset)
-    batch_sampler = BatchSampler(sampler, batch_size=10, drop_last=False)
-    dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
-
-    # this should raise - using batch sampler, that was not automatically instantiated by DataLoader
-    with pytest.raises(MisconfigurationException, match="when running on multiple IPU devices"):
-        _dataloader_init_kwargs_resolve_sampler(dataloader, dataloader.sampler, disallow_batch_sampler=True)
+        _ = _update_dataloader(dataloader, dataloader.sampler)
 
 
 def test_dataloader_kwargs_replacement_with_iterable_dataset():
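The assertion kept above belongs to a test whose setup sits outside the visible hunk; as a hedged illustration (class name and setup are assumptions based on the truncated `__init__(self, extra_arg)` context, not copied from the repository), a batch sampler that does not expose `sampler` in its `__init__` makes `_update_dataloader` fail with the matched `TypeError`:

    import pytest
    from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

    from lightning.fabric.utilities.data import _replace_dunder_methods, _update_dataloader


    class CustomBatchSampler(BatchSampler):
        # `sampler` is not an __init__ argument, so Lightning cannot inject a replacement.
        def __init__(self, extra_arg):
            self.extra_arg = extra_arg
            super().__init__(SequentialSampler(range(10)), batch_size=2, drop_last=False)


    # Instantiate under `_replace_dunder_methods` so the constructor arguments are recorded,
    # mirroring what Lightning does for dataloaders built inside `*_dataloader` hooks.
    with _replace_dunder_methods(BatchSampler):
        batch_sampler = CustomBatchSampler(extra_arg="foo")

    dataloader = DataLoader(range(10), batch_sampler=batch_sampler)
    with pytest.raises(TypeError, match="sampler into the batch sampler"):
        _ = _update_dataloader(dataloader, dataloader.sampler)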
18 changes: 2 additions & 16 deletions tests/tests_pytorch/utilities/test_data.py
@@ -5,15 +5,14 @@
 import torch
 from lightning_utilities.test.warning import no_warning_call
 from torch import Tensor
-from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler
 
 from lightning.fabric.utilities.data import _replace_dunder_methods
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import RandomDataset, RandomIterableDataset
 from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
 from lightning.pytorch.trainer.states import RunningStage
 from lightning.pytorch.utilities.data import (
-    _dataloader_init_kwargs_resolve_sampler,
     _get_dataloader_init_args_and_kwargs,
     _update_dataloader,
     extract_batch_size,
@@ -239,20 +238,7 @@ def __init__(self, extra_arg):
 
     # Assert that error is raised
     with pytest.raises(TypeError, match="sampler into the batch sampler"):
-        dataloader = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING)
-
-
-def test_dataloader_disallow_batch_sampler():
-    dataset = RandomDataset(5, 100)
-    dataloader = DataLoader(dataset, batch_size=10)
-
-    # This should not raise
-    _dataloader_init_kwargs_resolve_sampler(dataloader, dataloader.sampler, disallow_batch_sampler=True)
-
-    dataset = RandomDataset(5, 100)
-    sampler = SequentialSampler(dataset)
-    batch_sampler = BatchSampler(sampler, batch_size=10, drop_last=False)
-    dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
+        _ = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING)
 
 
 @pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
