[pipeline] refactor pipeline (hpcaitech#679)
* refactor pipeline: put the runtime schedule into the engine

* add the type hint Optional[BaseSchedule] for schedule

* preprocess the schedule during engine initialization

* infer pipeline schedule parameters from the config
YuliangLiu0306 authored Apr 7, 2022
1 parent eace693 commit 0ed7042
Showing 4 changed files with 36 additions and 9 deletions.
4 changes: 2 additions & 2 deletions colossalai/engine/schedule/__init__.py
@@ -1,5 +1,5 @@
 from ._base_schedule import BaseSchedule
-from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule
+from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 from ._non_pipeline_schedule import NonPipelineSchedule
 
-__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule']
+__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape']
23 changes: 23 additions & 0 deletions colossalai/engine/schedule/_pipeline_schedule.py
@@ -16,6 +16,29 @@
 
 from ._base_schedule import BaseSchedule
 
+def get_tensor_shape():
+    if hasattr(gpc.config, 'TENSOR_SHAPE'):
+        return gpc.config.TENSOR_SHAPE
+
+    if not gpc.is_initialized(ParallelMode.PIPELINE):
+        return None
+
+    if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'NUM_MICRO_BATCHES') and hasattr(gpc.config, 'HIDDEN_SIZE'):
+        if gpc.is_initialized(ParallelMode.DATA):
+            dp_size = gpc.get_world_size(ParallelMode.DATA)
+        else:
+            dp_size = 1
+        if gpc.is_initialized(ParallelMode.SEQUENCE):
+            seq_size = gpc.get_world_size(ParallelMode.SEQUENCE)
+        else:
+            seq_size = 1
+
+        tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
+                        gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES,
+                        gpc.config.HIDDEN_SIZE)
+        return tensor_shape
+    else:
+        return None
 
 def pack_return_tensors(return_tensors):
     output, label = tuple(zip(*return_tensors))
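For readers who want to check the arithmetic, below is a minimal standalone sketch of the inference that the new get_tensor_shape performs. It mirrors the formula in the added function; the config values and the data/sequence parallel world sizes are illustrative assumptions, not taken from the commit.

    # Standalone sketch of the shape inference in get_tensor_shape (illustrative values only).
    SEQ_LENGTH = 1024         # stands in for gpc.config.SEQ_LENGTH
    GLOBAL_BATCH_SIZE = 64    # stands in for gpc.config.GLOBAL_BATCH_SIZE
    NUM_MICRO_BATCHES = 4     # stands in for gpc.config.NUM_MICRO_BATCHES
    HIDDEN_SIZE = 768         # stands in for gpc.config.HIDDEN_SIZE
    dp_size = 2               # assumed world size of the DATA parallel group
    seq_size = 1              # assumed world size of the SEQUENCE parallel group

    # Shape of the activation tensor exchanged between pipeline stages per micro-batch:
    # (sequence length per sequence-parallel rank, micro-batch size per data-parallel rank, hidden size)
    tensor_shape = (SEQ_LENGTH // seq_size,
                    GLOBAL_BATCH_SIZE // dp_size // NUM_MICRO_BATCHES,
                    HIDDEN_SIZE)
    print(tensor_shape)       # (1024, 8, 768)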
12 changes: 8 additions & 4 deletions colossalai/initialize.py
@@ -20,7 +20,7 @@
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
+from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
 
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.engine import Engine
@@ -391,14 +391,18 @@ def initialize(model: nn.Module,
 
     # initialize schedule for engine
     if is_using_pp():
-        tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None)
+        tensor_shape = get_tensor_shape()
         use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
+        if gpc.is_initialized(ParallelMode.PARALLEL_1D):
+            scatter_gather = True
+        else:
+            scatter_gather = False
         if use_interleaved:
             schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                                   gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=True)
+                                                   gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
         else:
             schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                        tensor_shape=tensor_shape, scatter_gather_tensors=True)
+                                        tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
     else:
         schedule = NonPipelineSchedule()
 
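One practical consequence of the initialize change above: a config no longer has to spell out TENSOR_SHAPE by hand, since get_tensor_shape can derive it from the sizing keys, and scatter_gather_tensors is now decided by whether 1D tensor parallelism is initialized rather than being hard-coded to True. The sketch below is a hypothetical config using only key names that appear in this diff; the values and the dict layout of the model entry are assumptions for illustration.

    # Hypothetical config.py sketch (illustrative values, not part of the commit)
    NUM_MICRO_BATCHES = 4
    SEQ_LENGTH = 1024
    GLOBAL_BATCH_SIZE = 64
    HIDDEN_SIZE = 768

    # Setting TENSOR_SHAPE explicitly still works: get_tensor_shape() returns it
    # unchanged and skips the inference.
    # TENSOR_SHAPE = (1024, 8, 768)

    # A model entry with num_chunks selects the InterleavedPipelineSchedule branch
    # in initialize(); the dict form shown here is an assumed layout.
    model = dict(num_chunks=2)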
6 changes: 3 additions & 3 deletions colossalai/trainer/hooks/_metric_hook.py
@@ -312,7 +312,7 @@ def before_test(self, trainer):
 
     def after_test_iter(self, trainer, logits, targets, *args):
         if self._is_stage_to_compute:
-            batch_size = trainer.schedule.batch_size
+            batch_size = trainer.engine.schedule.batch_size
             self.metric.update(logits, targets, batch_size)
 
 
@@ -392,12 +392,12 @@ def before_train_epoch(self, trainer):
 
     def after_train_iter(self, trainer, *args):
         if self._is_stage_to_compute:
-            self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
+            self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
 
     def before_test(self, trainer):
         if self._is_stage_to_compute:
             self.metric.reset()
 
     def after_test_iter(self, trainer, *args):
         if self._is_stage_to_compute:
-            self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
+            self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
 
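For downstream code, the hook changes boil down to a new access path: the schedule now lives on the engine, so it is reached via trainer.engine.schedule instead of trainer.schedule. A minimal hypothetical hook sketch follows; only the attribute path comes from the diff, while the class name and the logging are illustrative.

    # Hypothetical hook sketch showing the new access path (not part of the commit).
    class StepLoggingHook:
        def after_train_iter(self, trainer, *args):
            # Before this commit: trainer.schedule.batch_size
            # After this commit:  trainer.engine.schedule.batch_size
            batch_size = trainer.engine.schedule.batch_size
            print(f'processed {batch_size} samples in this training step')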