forked from hpcaitech/ColossalAI
Support TP-compatible Torch AMP and Update trainer API (hpcaitech#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (hpcaitech#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
  This reverts commit 2e0b0b7.
* improved consistency between trainer, engine and schedule (hpcaitech#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
1 parent 2b05de4 · commit 3defa32
Showing 80 changed files with 2,184 additions and 1,574 deletions.
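The features this commit wires together (Torch AMP, gradient accumulation, and gradient clipping) follow a well-known PyTorch pattern. The sketch below is purely illustrative and uses stock `torch.cuda.amp` rather than the repository's own wrappers; `model`, `optimizer`, `criterion`, `train_loader`, `accum_size`, and `max_norm` are hypothetical placeholders, not objects defined by this commit.

```python
# Illustrative sketch only: AMP + gradient accumulation + gradient clipping
# with stock PyTorch. None of these objects come from this commit.
import torch

def train_epoch(model, optimizer, criterion, train_loader,
                accum_size=4, max_norm=1.0):
    scaler = torch.cuda.amp.GradScaler()
    model.train()
    optimizer.zero_grad()
    for i, (data, label) in enumerate(train_loader):
        with torch.cuda.amp.autocast():                    # mixed-precision forward
            loss = criterion(model(data.cuda()), label.cuda()) / accum_size
        scaler.scale(loss).backward()                      # accumulate scaled grads
        if (i + 1) % accum_size == 0:
            scaler.unscale_(optimizer)                     # unscale before clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)                         # skipped on inf/nan grads
            scaler.update()
            optimizer.zero_grad()
```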
@@ -1,2 +1,10 @@
from .builder import *
from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_optimizer_wrapper,
                      build_layer, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
                      build_gradient_handler)
from .pipeline import ModelInitializer

__all__ = [
    'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_optimizer_wrapper',
    'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
    'build_gradient_handler', 'ModelInitializer'
]
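These explicit re-exports make the registry-style `build_*` helpers importable directly from `colossalai.builder`. As a rough, hypothetical usage sketch (the config-dict convention is inferred from the `build_gradient_handler(cfg, model, optimizer)` call in the engine diff further down; `model` and `optimizer` are placeholder objects built elsewhere):

```python
# Hypothetical sketch; `model` and `optimizer` are placeholders built elsewhere.
from colossalai.builder import build_gradient_handler

handler_cfg = dict(type='DataParallelGradientHandler')
handler = build_gradient_handler(handler_cfg, model, optimizer)
handler.handle_gradient()  # all-reduce gradients across the data-parallel group
```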
@@ -1,7 +1,7 @@
from .amp_type import AMP_TYPE
from ._base_engine import Engine
from .gradient_handler import *
from .schedule import *
from .amp import *


__all__ = ['Engine']
@@ -1,170 +1,176 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Optional
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from colossalai.builder import build_gradient_handler
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                           ZeroRedundancyOptimizer_Level_3)
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from .schedule import BaseSchedule, NoPipelineSchedule
from .schedule import BaseSchedule


class Engine:
    """Basic engine class for training and evaluation. It runs a specific process method
    :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
    It controls an iteration in training.
    :param train_dataloader: Dataloader in training
    :param test_dataloader: Dataloader in evaluation
    :param model: The neural network model
    :param criterion: Criterion for calculating loss
    :param optimizer: Optimizer for updating the parameters
    :param lr_scheduler: Learning rate scheduler adjusting learning rate during the training or evaluation
    :param schedule: Running schedule in :meth:`step`
    :type train_dataloader: DataLoader, optional
    :type test_dataloader: DataLoader, optional
    :param step_schedule: Running schedule in :meth:`step`
    :param gradient_accumulation: Steps of gradient accumulation
    :param gradient_clipping: The norm of gradient clipping
    :type model: Module
    :type criterion: _Loss, optional
    :type optimizer: Optimizer, optional
    :type lr_scheduler: _LRScheduler, optional
    :type schedule: BaseSchedule, optional
    :type optimizer: Optimizer
    :type step_schedule: BaseSchedule, optional
    :type gradient_accumulation: int, optional
    :type gradient_clipping: float, optional
    """

    def __init__(self,
                 train_dataloader: Optional[DataLoader] = None,
                 test_dataloader: Optional[DataLoader] = None,
                 model: Module = None,
                 criterion: _Loss = None,
                 optimizer: Optimizer = None,
                 lr_scheduler: Optional[_LRScheduler] = None,
                 schedule: BaseSchedule = None):
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        assert model is not None, "Engine requires a model"
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.schedule = schedule if schedule is not None \
            else NoPipelineSchedule()
                 model: Module,
                 optimizer: Optimizer,
                 criterion: _Loss,
                 step_schedule: BaseSchedule,
                 gradient_handlers: list = None,
                 gradient_accumulation: int = 1,
                 gradient_clipping: float = 0.0,
                 ):
        self._model = model
        self._optimizer = optimizer
        self._criterion = criterion
        self._schedule = step_schedule

        # schedule initialize
        self._schedule.initialize(model, optimizer)

        # state
        self.training = True  # default

        # gradient accumulation
        assert gradient_accumulation > 0, 'gradient accumulation size must be larger than 0'
        self._grad_accum_size = gradient_accumulation
        self._grad_clip = gradient_clipping
        self._logger = get_global_dist_logger()

        # build gradient handler
        self._gradient_handlers = []
        gradient_handler_cfg = []

        if hasattr(gpc.config, 'gradient_handler'):
            assert isinstance(gpc.config.gradient_handler, list), \
        if gradient_handlers is not None:
            assert isinstance(gradient_handlers, list), \
                f'argument gradient_handler_cfg expected type list, ' \
                f'but got type {type(gpc.config.gradient_handler)}'
            gradient_handler_cfg = gpc.config.gradient_handler
        elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
                                         ZeroRedundancyOptimizer_Level_3)):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
                f'but got type {type(gradient_handlers)}'
        elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
                                    ZeroRedundancyOptimizer_Level_3)):
            gradient_handlers = [dict(type='ZeROGradientHandler')]
            self._logger.info(
                "Training with zero is detected, ZeROGradientHandler is automatically "
                "added even though not specified in the configuration",
                ranks=[0])
        elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
                ParallelMode.DATA) > 1:
            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
            gradient_handlers = [dict(type='DataParallelGradientHandler')]
            self._logger.info(
                "Data parallel training is detected, DataParallelGradientHandler is automatically "
                "added even though not specified in the configuration",
                ranks=[0])
        if len(gradient_handler_cfg) == 0:

        if gradient_handlers is None:
            self._logger.warning(
                "No gradient handler is set up, please make sure you do not need "
                "to all-reduce the gradients after a training step.",
                ranks=[0])
        for cfg in gradient_handler_cfg:
            handler = build_gradient_handler(cfg, self.model, self.optimizer)
            self._gradient_handlers.append(handler)
        else:
            for cfg in gradient_handlers:
                handler = build_gradient_handler(cfg, model, optimizer)
                self._gradient_handlers.append(handler)

        self.schedule.initialize(self.train_dataloader, self.model,
                                 self.criterion, self.optimizer,
                                 self.lr_scheduler)
        self.forward_only = False
    @property
    def model(self):
        return self._model

    def handle_gradient(self):
        """Handles all-reduce operations of gradients across different parallel groups.
        """
        for handler in self._gradient_handlers:
            handler.handle_gradient()
    @property
    def optimizer(self):
        return self._optimizer

    def set_dataloader(self, data: DataLoader, train: bool = True):
        """Sets dataloader in training or evaluation.
    @property
    def criterion(self):
        return self._criterion

        :param data: Dataloader to be set
        :param train: Set training dataloader if True, otherwise evaluation dataloader
        :type data: DataLoader
        :type train: bool
        """
        if train:
            self.train_dataloader = data
        else:
            self.test_dataloader = data
    @property
    def schedule(self):
        return self._schedule

    def get_model(self):
        """Returns the neural network model in the engine.
        """
        return self.model
    def get_optimizer(self):
        """Returns the optimizer in the engine.
        """
        return self.optimizer
    @property
    def gradient_accumulation(self):
        return self._grad_accum_size

    def get_lr_scheduler(self):
        """Returns the learning rate scheduler in the engine.
    def handle_gradient(self):
        """Handles all-reduce operations of gradients across different parallel groups.
        """
        return self.lr_scheduler
        for handler in self._gradient_handlers:
            handler.handle_gradient()

    def train(self):
        """Sets the model to training mode.
        """
        self.forward_only = False
        self.schedule.train(dataloader=self.train_dataloader, mode=True)
        self.training = True
        self._model.train()

    def eval(self):
        """Sets the model to evaluation mode.
        """
        self.forward_only = True
        self.schedule.train(dataloader=self.test_dataloader, mode=False)
        self.training = False
        self._model.eval()

    def is_train(self):
        """Returns True if it is in training, otherwise False.
        """
        return not self.forward_only

    def get_lr(self):
        """Gets current learning rate.
        """
        return self.schedule.get_lr()

    def step(self, return_loss=True):
    def step(self,
             data_iter,
             is_last_iteration: bool = False,
             return_loss=True):
        """A running step based on the schedule. Usually, it runs a training or
        evaluation over a batch of dataset.
        :param data_iter: Data iterator of the dataset
        :param is_last_iteration: If True, this iteration is the last iteration in the epoch
        :param return_loss: loss will be returned if True
        :type return_loss: bool
        :type data_iter: Iterator
        :type is_last_iteration: bool, optional
        :type return_loss: bool, optional
        :return: (output, label, loss)
        """
        self.schedule.zero_grad(forward_only=self.forward_only)

        output, label, loss = self.schedule.forward_backward_step(
            forward_only=self.forward_only, return_loss=return_loss)

        if not self.forward_only:
            # all reduce gradients
            self.handle_gradient()

        self.schedule.step()
        if self.training:
            self._optimizer.zero_grad()

        # differentiate training and eval with grad accum
        if self.training:
            for i in range(self._grad_accum_size):
                output, label, loss = self._schedule.forward_backward_step(
                    data_iter, self._model, self._criterion, self._optimizer,
                    forward_only=False,
                    grad_accum_size=self._grad_accum_size,
                    return_loss=return_loss)

                if i == self._grad_accum_size - 1:
                    # all reduce gradients
                    self.handle_gradient()
                    self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip)
        else:
            output, label, loss = self._schedule.forward_backward_step(
                data_iter, self._model, self._criterion, self._optimizer,
                forward_only=True,
                grad_accum_size=1,
                return_loss=return_loss)

        # consume the remaining dataset left out due to gradient accumulation
        if is_last_iteration:
            while True:
                try:
                    _ = next(data_iter)
                except StopIteration:
                    break

        return output, label, loss
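Taken together, the reworked `Engine` now owns the model, optimizer, criterion, and a step schedule, and `step()` drives gradient accumulation internally. The following is a rough usage sketch inferred only from the signatures in the diff above; `model`, `optimizer`, `criterion`, `schedule`, and `train_dataloader` are placeholders assumed to be built elsewhere (for example via the colossalai builders), and the epoch-length arithmetic is an assumption consistent with the "consume the remaining dataset" logic in `step()`.

```python
# Usage sketch inferred from the new Engine API above; all inputs are placeholders.
engine = Engine(model=model,
                optimizer=optimizer,
                criterion=criterion,
                step_schedule=schedule,
                gradient_accumulation=4,   # accumulate over 4 micro-batches per step
                gradient_clipping=1.0)     # clip accumulated gradients to norm 1.0

engine.train()                             # switch engine and model to training mode
data_iter = iter(train_dataloader)
num_steps = len(train_dataloader) // engine.gradient_accumulation
for step_idx in range(num_steps):
    output, label, loss = engine.step(
        data_iter, is_last_iteration=(step_idx == num_steps - 1))
```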
@@ -0,0 +1,2 @@
from .grad_scaler import GradScaler
from .amp_type import AMP_TYPE
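The `amp` package now exposes its own `GradScaler`. This hunk does not show the implementation, but the usual reason a custom scaler is needed under tensor parallelism is that the overflow check must be agreed on by every rank holding a gradient shard; a hedged sketch of that idea with vanilla `torch.distributed` (where `tp_group` is a hypothetical, already-initialised process group) looks like this:

```python
# Hedged sketch: make the AMP overflow decision consistent across a
# tensor-parallel group so all ranks skip or take the optimizer step together.
import torch
import torch.distributed as dist

def found_inf_across_group(local_found_inf: bool, tp_group) -> bool:
    flag = torch.tensor([1.0 if local_found_inf else 0.0], device='cuda')
    # MAX reduction: if any rank saw inf/nan gradients, every rank sees 1.0.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=tp_group)
    return bool(flag.item())
```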
File renamed without changes.