Rename SingleTPUPlugin to SingleTPUStrategy (Lightning-AI#11182)
akihironitta authored Dec 21, 2021
1 parent 1c5a5c3 commit c3cd4d0
Showing 9 changed files with 16 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -137,6 +137,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


 - Renamed the `TrainingTypePlugin` to `Strategy` ([#11120](https://github.com/PyTorchLightning/pytorch-lightning/pull/11120))
+    * Renamed `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
     * Renamed the `DDPPlugin` to `DDPStrategy` ([#11142](https://github.com/PyTorchLightning/pytorch-lightning/pull/11142))
     * Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194))
     * Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193))
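For downstream code, the rename is a one-line migration: the class keeps its module path and constructor, and only the public name changes. A minimal sketch of the before/after (assuming pytorch-lightning at this commit is installed):

# Old name (before this commit):
# from pytorch_lightning.plugins import SingleTPUPlugin

# New name (after this commit); the import path is otherwise unchanged:
from pytorch_lightning.plugins import SingleTPUStrategy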
2 changes: 1 addition & 1 deletion docs/source/api_references.rst
@@ -157,7 +157,7 @@ Training Type Plugins
     DDPSpawnPlugin
     DeepSpeedStrategy
     HorovodStrategy
-    SingleTPUPlugin
+    SingleTPUStrategy
     TPUSpawnStrategy

Precision Plugins
2 changes: 1 addition & 1 deletion docs/source/extensions/plugins.rst
@@ -116,7 +116,7 @@ Training Type Plugins
     DDPSpawnPlugin
     DeepSpeedStrategy
     HorovodStrategy
-    SingleTPUPlugin
+    SingleTPUStrategy
     TPUSpawnStrategy


4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/__init__.py
@@ -31,7 +31,7 @@
 from pytorch_lightning.plugins.training_type.sharded import DDPShardedStrategy
 from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
-from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
+from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUStrategy
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnStrategy
 from pytorch_lightning.plugins.training_type.training_type_plugin import Strategy

@@ -59,7 +59,7 @@
"ShardedNativeMixedPrecisionPlugin",
"FullyShardedNativeMixedPrecisionPlugin",
"SingleDevicePlugin",
"SingleTPUPlugin",
"SingleTPUStrategy",
"TPUPrecisionPlugin",
"TPUBf16PrecisionPlugin",
"TPUSpawnStrategy",
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/__init__.py
@@ -9,6 +9,6 @@
 from pytorch_lightning.plugins.training_type.sharded import DDPShardedStrategy  # noqa: F401
 from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin  # noqa: F401
-from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUStrategy  # noqa: F401
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnStrategy  # noqa: F401
 from pytorch_lightning.plugins.training_type.training_type_plugin import Strategy  # noqa: F401
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -26,8 +26,8 @@
 import torch_xla.core.xla_model as xm


-class SingleTPUPlugin(SingleDevicePlugin):
-    """Plugin for training on a single TPU device."""
+class SingleTPUStrategy(SingleDevicePlugin):
+    """Strategy for training on a single TPU device."""

     def __init__(
         self,
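The accelerator connector hunk below shows this class invoked as `SingleTPUStrategy(self.tpu_id)`, i.e. the first positional argument is the TPU core index, so it can also be constructed by hand and passed to the Trainer. A minimal sketch (assumes a TPU host with torch_xla installed; the core index 1 is illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import SingleTPUStrategy

# Pin training to one TPU core, mirroring `SingleTPUStrategy(self.tpu_id)`
# in the accelerator connector below.
trainer = Trainer(strategy=SingleTPUStrategy(1))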
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -45,7 +45,7 @@
     PrecisionPlugin,
     ShardedNativeMixedPrecisionPlugin,
     SingleDevicePlugin,
-    SingleTPUPlugin,
+    SingleTPUStrategy,
     Strategy,
     TPUBf16PrecisionPlugin,
     TPUPrecisionPlugin,
@@ -749,7 +749,7 @@ def select_training_type_plugin(self) -> Strategy:
         elif self.use_horovod:
             plugin = HorovodStrategy(parallel_devices=self.parallel_devices)
         elif self.use_tpu and isinstance(self.tpu_cores, list):
-            plugin = SingleTPUPlugin(self.tpu_id)
+            plugin = SingleTPUStrategy(self.tpu_id)
         elif self.use_ipu:
             plugin = IPUStrategy(parallel_devices=self.parallel_devices)
         else:
@@ -1026,8 +1026,8 @@ def _check_plugin_compatibility(self) -> None:
f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`,"
f" found: {self.training_type_plugin.precision_plugin}."
)
if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnStrategy)):
if not isinstance(self.training_type_plugin, (SingleTPUStrategy, TPUSpawnStrategy)):
raise ValueError(
"The `TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnStrategy`,"
"The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy`,"
f" found {self.training_type_plugin}."
)
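The two hunks above also show when this strategy is chosen implicitly: `select_training_type_plugin` picks it whenever TPUs are in use and `tpu_cores` is a list naming a specific core, while `_check_plugin_compatibility` rejects any other strategy paired with a `TPUAccelerator`. A minimal sketch of the implicit path (assumes a TPU host):

from pytorch_lightning import Trainer

# A one-element list pins training to that core, so the connector selects
# SingleTPUStrategy under the hood; an integer count such as tpu_cores=8
# goes through TPUSpawnStrategy instead.
trainer = Trainer(tpu_cores=[1])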
4 changes: 2 additions & 2 deletions tests/accelerators/test_accelerator_connector.py
@@ -879,11 +879,11 @@ def test_unsupported_tpu_choice(monkeypatch):
     with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"):
         Trainer(accelerator="tpu", precision=64)

-    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin`"):
+    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"):
         with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but native AMP is not supported"):
             Trainer(accelerator="tpu", precision=16)

-    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin`"):
+    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"):
         with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but apex AMP is not supported"):
             Trainer(accelerator="tpu", precision=16, amp_backend="apex")

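One detail worth noting in the updated tests: `pytest.raises(..., match=...)` applies the pattern with `re.search`, which is why literal parentheses in the expected messages are escaped (`precision=64\)`) while the backtick-quoted class names need no escaping. A self-contained illustration:

import pytest

# `match` is a regular expression, so metacharacters in the expected
# message must be escaped to match literally.
with pytest.raises(ValueError, match=r"used with a `SingleTPUStrategy` \(only\)"):
    raise ValueError("can only be used with a `SingleTPUStrategy` (only)")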
4 changes: 2 additions & 2 deletions tests/accelerators/test_tpu.py
@@ -293,7 +293,7 @@ def test_tpu_invalid_raises():
         Trainer(strategy=training_type_plugin)

     training_type_plugin = DDPStrategy(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin())
-    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin`"):
+    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"):
         Trainer(strategy=training_type_plugin)


@@ -306,7 +306,7 @@ def test_tpu_invalid_raises_set_precision_with_strategy():
     accelerator = TPUAccelerator()
     training_type_plugin = DDPStrategy(accelerator=accelerator, precision_plugin=TPUPrecisionPlugin())
     with pytest.raises(
-        ValueError, match="The `TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnStrategy"
+        ValueError, match="The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy"
     ):
         Trainer(strategy=training_type_plugin)

