[refactor] pipeline, put runtime schedule into engine. (hpcaitech#627)
YuliangLiu0306 authored Apr 3, 2022
1 parent e5d615a commit ade05a5
Showing 9 changed files with 68 additions and 49 deletions.
37 changes: 34 additions & 3 deletions colossalai/engine/_base_engine.py
@@ -2,14 +2,15 @@
# -*- encoding: utf-8 -*-

from asyncio.log import logger
from typing import List
from typing import List, Iterable
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from colossalai.logging import get_dist_logger
from torch import Tensor
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
from typing import Optional, Type
from colossalai.engine.gradient_handler import BaseGradientHandler
from colossalai.logging import get_dist_logger
@@ -27,6 +28,7 @@ class Engine:
clip_grad_norm (float, optional): The norm of gradient clipping.
ophook_list (list): List of ophook.
verbose (bool): whether to display log info.
schedule (``BaseSchedule``, optional): Runtime schedule.
Examples:
>>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
@@ -59,7 +61,8 @@ def __init__(self,
gradient_handlers: Optional[List[BaseGradientHandler]] = None,
clip_grad_norm: float = 0.0,
ophook_list: Optional[List[BaseOpHook]] = None,
verbose: bool = True):
verbose: bool = True,
schedule: Optional[BaseSchedule] = None):
self._model = model
self._optimizer = optimizer
self._criterion = criterion
@@ -80,6 +83,14 @@ def __init__(self,
self._ophook_list = []
else:
self._ophook_list = ophook_list

# build schedule
if schedule:
self._schedule = schedule
else:
self._schedule = NonPipelineSchedule()
if self.uses_pipeline:
self._schedule.pre_processing(self)
register_ophooks_recursively(self._model, self._ophook_list)

@property
@@ -102,6 +113,16 @@ def criterion(self):
"""Criterion attached to the engine"""
return self._criterion

@property
def schedule(self):
"""Schedule attached to the engine"""
return self._schedule

@property
def uses_pipeline(self):
"""show the pipeline parallel used or not"""
return isinstance(self._schedule, (PipelineSchedule, InterleavedPipelineSchedule))

def add_hook(self, ophook: Type[BaseOpHook]) -> None:
"""add necessary hook"""
# whether this hook exist
@@ -165,6 +186,16 @@ def _all_reduce_gradients(self):
"""
for handler in self._gradient_handlers:
handler.handle_gradient()

def execute_schedule(self, data_iter: Iterable, **kwargs):
"""Run the forward, loss computation, and backward for the model.
Returns a tuple of (output, label, loss).
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
"""
output, label, loss = self._schedule.forward_backward_step(self, data_iter, **kwargs)
return output, label, loss

def train(self):
"""Sets the model to training mode.
@@ -176,4 +207,4 @@ def eval(self):
"""Sets the model to evaluation mode.
"""
self.training = False
self._model.eval()
self._model.eval()
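Taken together, the engine now owns its runtime schedule and exposes execute_schedule() as the single entry point for a training or evaluation step. A minimal sketch of the resulting usage, assuming model, optimizer, criterion and train_dataloader are defined as in the docstring example above (illustrative only, not part of this commit):

from colossalai.engine import Engine
from colossalai.engine.schedule import PipelineSchedule

# Hypothetical setup: attach a pipeline schedule directly to the engine.
schedule = PipelineSchedule(num_microbatches=2)
engine = Engine(model=model,
                optimizer=optimizer,
                criterion=criterion,
                schedule=schedule)   # defaults to NonPipelineSchedule when omitted

engine.train()
data_iter = iter(train_dataloader)
# Forward, loss computation and backward are delegated to the attached schedule.
output, label, loss = engine.execute_schedule(data_iter,
                                              forward_only=False,
                                              return_loss=True)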
5 changes: 2 additions & 3 deletions colossalai/engine/schedule/_base_schedule.py
@@ -6,7 +6,6 @@
import torch

from typing import Iterable, Callable
from .._base_engine import Engine
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device

@@ -75,14 +74,14 @@ def load_batch(self, data_iter, to_gpu=True):
return self._move_to_device(data), self._move_to_device(label)
return data, label

def pre_processing(self, engine: Engine):
def pre_processing(self, engine):
"""To perform actions before running the schedule.
"""
pass

@abstractmethod
def forward_backward_step(self,
engine: Engine,
engine,
data_iter: Iterable,
forward_only: bool,
return_loss: bool = True,
3 changes: 1 addition & 2 deletions colossalai/engine/schedule/_non_pipeline_schedule.py
@@ -5,7 +5,6 @@

import torch

from colossalai.engine import Engine
from ._base_schedule import BaseSchedule
from colossalai.utils import conditional_context

@@ -22,7 +21,7 @@ class NonPipelineSchedule(BaseSchedule):
"""

def forward_backward_step(self,
engine: Engine,
engine,
data_iter: Iterable,
forward_only: bool = False,
return_loss: bool = True,
18 changes: 17 additions & 1 deletion colossalai/initialize.py
@@ -20,6 +20,7 @@
from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule

from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.engine import Engine
@@ -388,6 +389,20 @@ def initialize(model: nn.Module,
if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
model.module.sync_buffer = False

# initialize schedule for engine
if is_using_pp():
tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None)
use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
if use_interleaved:
schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=True)
else:
schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
tensor_shape=tensor_shape, scatter_gather_tensors=True)
else:
schedule = NonPipelineSchedule()


if gradient_handler_cfg is None:
gradient_handlers = None
if verbose and not isinstance(model, DDP):
@@ -418,6 +433,7 @@
criterion=criterion,
gradient_handlers=gradient_handlers,
clip_grad_norm=clip_grad_norm,
ophook_list=ophooks)
ophook_list=ophooks,
schedule=schedule)

return engine, train_dataloader, test_dataloader, lr_scheduler
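The schedule choice is therefore driven entirely by the config once pipeline parallelism is enabled. An illustrative config fragment (values are assumptions, not taken from the commit):

# Pipeline parallel config: initialize() now builds the runtime schedule itself.
NUM_MICRO_BATCHES = 4                     # required once pipeline parallelism is on
TENSOR_SHAPE = None                       # optional; forwarded to the schedule if set
parallel = dict(pipeline=2)
# If the model config also carried a num_chunks entry, e.g. model = dict(..., num_chunks=2),
# the InterleavedPipelineSchedule branch above would be taken instead.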
36 changes: 6 additions & 30 deletions colossalai/trainer/_trainer.py
@@ -9,7 +9,6 @@
from colossalai.core import global_context as gpc

from colossalai.engine import Engine
from colossalai.engine.schedule import NonPipelineSchedule, BaseSchedule
from colossalai.logging import DistributedLogger
from colossalai.utils import MultiTimer
from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
@@ -23,13 +22,9 @@ class Trainer:
Args:
engine (:class:`Engine`): Engine responsible for the process function.
schedule (:class:`BaseSchedule`, optional): Schedule responsible for forward and backward steps.
timer (:class:`MultiTimer`, optional): Timer used to monitor the whole training.
logger (:class:`colossalai.logging.DistributedLogger`, optional): Logger used to record the whole training log.
Note:
when `schedule` is None, the ``NonPipelineSchedule`` would be used. If you would like to use pipeline,
you should choose ``PipelineSchedule`` or ``InterleavedPipelineSchedule`` for the `schedule`
Examples:
>>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
@@ -42,7 +37,7 @@ class Trainer:
>>> # Beginning training progress
>>> timer = ...
>>> logger = ...
>>> trainer = Trainer(engine=engine, logger=logger, schedule=schedule, timer=timer)
>>> trainer = Trainer(engine=engine, logger=logger, timer=timer)
>>> # add hooks you would like to use here.
>>> hook_list = []
>>> trainer.fit(
@@ -61,7 +56,6 @@ class Trainer:
def __init__(
self,
engine: Engine,
schedule: BaseSchedule = None,
timer: MultiTimer = None,
logger: DistributedLogger = None,
):
@@ -86,17 +80,6 @@ def __init__(
# multi-timer for time benchmarking
self._timer = timer

# set schedule which specifies the training iteration for the engine
if schedule is None:
schedule = NonPipelineSchedule()
if (gpc.is_initialized(ParallelMode.PIPELINE)
and gpc.get_world_size(ParallelMode.PIPELINE) > 1):
assert not isinstance(
schedule, NonPipelineSchedule
), "NonPipelineSchedule cannot be used for pipeline parallel training, please use PipelineSchedule instead."
self._schedule = schedule
self._schedule.pre_processing(engine)

@property
def cur_epoch(self):
"""Returns the index of the current epoch."""
@@ -129,10 +112,6 @@ def steps_per_epoch(self):
def engine(self):
return self._engine

@property
def schedule(self):
return self._schedule

def _set_current_step(self, epoch: int):
"""Sets current step number.
@@ -203,8 +182,7 @@ def _train_epoch(

# run 1 training step
self.engine.zero_grad()
logits, label, loss = self.schedule.forward_backward_step(
self.engine,
logits, label, loss = self.engine.execute_schedule(
data_iter,
forward_only=False,
return_loss=True,
@@ -260,8 +238,7 @@ def _eval(
for _ in progress:
self._call_hooks("before_test_iter")
self._call_timer(action="start", item="Test-step")
logits, label, loss = self.schedule.forward_backward_step(
self.engine,
logits, label, loss = self.engine.execute_schedule(
data_iter,
forward_only=True,
return_loss=True,
@@ -449,8 +426,7 @@ def predict(self, data: Union[Tensor, List[Tensor]]):
# for compatibility with schedule
simple_dataloader = [(data, None)]
data_iter = iter(simple_dataloader)
output, _, _ = self.schedule.forward_backward_step(self.engine,
data_iter,
forward_only=True,
return_loss=False)
output, _, _ = self.engine.execute_schedule(data_iter,
forward_only=True,
return_loss=False)
return output
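On the trainer side the schedule argument disappears entirely; each step is routed through the engine. A condensed sketch of the new step logic, mirroring _train_epoch above (not verbatim from the commit):

# Build the trainer without a schedule; the engine already carries one.
trainer = Trainer(engine=engine, logger=logger, timer=timer)

# What a training step now boils down to inside the trainer:
engine.zero_grad()
logits, label, loss = engine.execute_schedule(data_iter,
                                              forward_only=False,
                                              return_loss=True)
engine.step()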
@@ -23,9 +23,9 @@
BATCH_SIZE = 4
NUM_EPOCHS = 60
WARMUP_EPOCHS = 5
CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
fp16=dict(mode=AMP_TYPE.NAIVE),
gradient_accumulation=2)
CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
fp16=dict(mode=AMP_TYPE.NAIVE),
gradient_accumulation=2)


def run_trainer(rank, world_size, port):
@@ -63,10 +63,9 @@ def run_trainer(rank, world_size, port):
train_dataloader,
lr_scheduler=lr_scheduler)

schedule = PipelineSchedule(num_microbatches=2)
logger = get_dist_logger()

trainer = Trainer(engine=engine, logger=logger, schedule=schedule)
trainer = Trainer(engine=engine, logger=logger)

hook_list = [
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
1 change: 1 addition & 0 deletions tests/test_trainer/test_pipeline/resnet_config.py
@@ -7,6 +7,7 @@
DIM = 768
NUM_CLASSES = 10
NUM_ATTN_HEADS = 12
NUM_MICRO_BATCHES = 2

# resnet 18
model = dict(type='VanillaResNet',
3 changes: 1 addition & 2 deletions tests/test_trainer/test_pipeline/test_pipeline_schedule.py
@@ -19,7 +19,6 @@
from torchvision.datasets import CIFAR10

BATCH_SIZE = 4
NUM_MICRO = 2

DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
@@ -57,7 +56,7 @@ def run_schedule(rank, world_size, port):
engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)

# build pipeline schedule
schedule = PipelineSchedule(num_microbatches=NUM_MICRO)
schedule = engine.schedule

# run schedule
data_iter = iter(train_dataloader)
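For tests or custom loops that still need direct access, the auto-built schedule can be read from the engine, as the updated test above does. A short sketch of that pattern (names follow the test; illustrative only):

# initialize() builds the schedule from the config and attaches it to the engine.
engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)
schedule = engine.schedule        # a PipelineSchedule here, since pipeline=2 is configured

engine.train()
data_iter = iter(train_dataloader)
output, label, loss = schedule.forward_backward_step(engine, data_iter,
                                                      forward_only=False,
                                                      return_loss=True)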
5 changes: 2 additions & 3 deletions tests/test_trainer/test_trainer_with_pipe_schedule.py
@@ -23,7 +23,7 @@
IMG_SIZE = 32
NUM_EPOCHS = 200

CONFIG = dict(parallel=dict(pipeline=2),)
CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2),)


def run_trainer_with_pipeline(rank, world_size, port):
Expand Down Expand Up @@ -69,9 +69,8 @@ def forward(self, x):

logger = get_dist_logger()
logger.info("engine is built", ranks=[0])
pipe_schedule = PipelineSchedule(num_microbatches=2)
timer = MultiTimer()
trainer = Trainer(engine=engine, schedule=pipe_schedule, logger=logger, timer=timer)
trainer = Trainer(engine=engine, logger=logger, timer=timer)
logger.info("trainer is built", ranks=[0])

logger.info("start training", ranks=[0])
