Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support Adafactor Optimizer #1361

Merged
merged 3 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions docs/en/common_usage/better_optimizers.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,35 @@ runner = Runner(
)
runner.train()
```

## transformers

[transformers](https://github.com/huggingface/transformers) provides `Adafactor` optimzier。

```{note}
If you use the optimizer provided by transformers, you need to upgrade mmengine to `0.8.5`.
```

- Installation

```bash
pip install transformers
```

- Usage

Take the `Adafactor` as an example.

```python
runner = Runner(
model=ResNet18(),
work_dir='./work_dir',
train_dataloader=train_dataloader_cfg,
# To view the input parameters for AdamW8bit, you can refer to
# https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/optim/adamw.py
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
optim_wrapper=dict(optimizer=dict(type='Adafactor', lr=1e-5,
weight_decay=1e-2, scale_parameter=False, relative_step=False)),
train_cfg=dict(by_epoch=True, max_epochs=3),
)
runner.train()
```
15 changes: 15 additions & 0 deletions mmengine/optim/optimizer/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,21 @@ def register_bitsandbytes_optimizers() -> List[str]:
BITSANDBYTES_OPTIMIZERS = register_bitsandbytes_optimizers()


def register_transformers_optimizers():
transformer_optimizers = []
try:
from transformers import Adafactor
except ImportError:
pass
else:
OPTIMIZERS.register_module(name='Adafactor')(Adafactor)
transformer_optimizers.append('Adafactor')
okotaku marked this conversation as resolved.
Show resolved Hide resolved
return transformer_optimizers


TRANSFORMERS_OPTIMIZERS = register_transformers_optimizers()


def build_optim_wrapper(model: nn.Module,
cfg: Union[dict, Config, ConfigDict]) -> OptimWrapper:
"""Build function of OptimWrapper.
Expand Down
1 change: 1 addition & 0 deletions requirements/tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ neptune
parameterized
pydantic==1.10.9
pytest
transformers
19 changes: 17 additions & 2 deletions tests/test_optim/test_optimizer/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from mmengine.optim.optimizer.builder import (BITSANDBYTES_OPTIMIZERS,
DADAPTATION_OPTIMIZERS,
LION_OPTIMIZERS,
TORCH_OPTIMIZERS)
TORCH_OPTIMIZERS,
TRANSFORMERS_OPTIMIZERS)
from mmengine.registry import DefaultScope, Registry, build_from_cfg
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available
Expand Down Expand Up @@ -53,6 +54,14 @@ def has_bitsandbytes() -> bool:
return False


def has_transformers() -> bool:
try:
import transformers # noqa: F401
return True
except ImportError:
return False


class ExampleModel(nn.Module):

def __init__(self):
Expand Down Expand Up @@ -244,7 +253,7 @@ def test_dadaptation_optimizers(self):
def test_lion_optimizers(self):
assert 'Lion' in LION_OPTIMIZERS

@unittest.skipIf(not has_bitsandbytes(), 'dadaptation is not installed')
@unittest.skipIf(not has_bitsandbytes(), 'bitsandbytes is not installed')
def test_bitsandbytes_optimizers(self):
bitsandbytes_optimizers = [
'AdamW8bit', 'Adam8bit', 'Adagrad8bit', 'PagedAdam8bit',
Expand All @@ -254,6 +263,12 @@ def test_bitsandbytes_optimizers(self):
assert set(bitsandbytes_optimizers).issubset(
set(BITSANDBYTES_OPTIMIZERS))

@unittest.skipIf(not has_transformers(), 'transformers is not installed')
def test_transformers_optimizers(self):
transformers_optimizers = ['Adafactor']
assert set(transformers_optimizers).issubset(
set(TRANSFORMERS_OPTIMIZERS))

def test_build_optimizer(self):
# test build function without ``constructor`` and ``paramwise_cfg``
optim_wrapper_cfg = dict(
Expand Down