diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py
index e619e90f26dd..13a815592fbb 100644
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/engine/schedule/_pipeline_schedule.py
@@ -48,9 +48,13 @@ def load_batch(self, data_iter):
         # Pipeline schedule just puts data in memory
         self.batch_data, self.batch_label = super().load_batch(data_iter, to_gpu=False)
         self.microbatch_offset = 0
-        assert self.batch_size % self.num_microbatches == 0, \
+        if isinstance(self.batch_data, torch.Tensor):
+            batch_size = self.batch_data.size(0)
+        else:
+            batch_size = next(iter(self.batch_data.values())).size(0)
+        assert batch_size % self.num_microbatches == 0, \
             "Batch size should divided by the number of microbatches"
-        self.microbatch_size = self.batch_size // self.num_microbatches
+        self.microbatch_size = batch_size // self.num_microbatches
 
     def _get_data_slice(self, data, offset):
         if isinstance(data, torch.Tensor):
diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py
index 3a3fa6e00e0a..832d7d9df7c7 100644
--- a/colossalai/nn/layer/parallel_1d/layers.py
+++ b/colossalai/nn/layer/parallel_1d/layers.py
@@ -71,6 +71,7 @@ def forward(self, input_: Tensor) -> Tensor:
 @LAYERS.register_module
 class Classifier1D(ParallelLayer):
     """RowLinear with given weight"""
+
     def __init__(self,
                  in_features: int,
                  num_classes: int,
@@ -127,8 +128,8 @@ def forward(self, input_: Tensor) -> Tensor:
 
         output_parallel = F.linear(input_, self.weight)
         output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
-
-        output = output + self.bias
+        if self.bias is not None:
+            output = output + self.bias
         return output
 
 
@@ -152,6 +153,7 @@ class Linear1D_Col(ParallelLayer):
         which is :math:`Y_i = XA_i`, defaults to False
     :type gather_output: bool, optional
     """
+
     def __init__(self,
                  in_features: int,
                  out_features: int,
@@ -233,6 +235,7 @@ class Linear1D_Row(ParallelLayer):
     :param parallel_input: If set to ``True``, it's assumed that the input is splitted, defaults to False
     :type parallel_input: bool, optional
     """
+
     def __init__(self,
                  in_features: int,
                  out_features: int,
@@ -302,6 +305,7 @@ def forward(self, input_: Tensor) -> Tensor:
 class MixedFusedLayerNorm1D(torch.nn.Module):
     """ Experimental
     """
+
     def __init__(self, normalized_shape, eps=1e-5):
         super(MixedFusedLayerNorm1D, self).__init__()
 
diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/nn/layer/parallel_2d/_operation.py
index 9955bcefec90..ef899a5ece2e 100644
--- a/colossalai/nn/layer/parallel_2d/_operation.py
+++ b/colossalai/nn/layer/parallel_2d/_operation.py
@@ -121,9 +121,10 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A)
         B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode)
         B_grad = B_grad.reshape(ctx.B_shape)
-
-        bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
-        bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode)
+        bias_grad = None
+        if ctx.use_bias:
+            bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
+            bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode)
 
         return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None
 
@@ -174,9 +175,9 @@ def forward(
         col_group = gpc.get_group(col_parallel_mode)
 
         src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            pipeline_parallel_rank * tensor_parallel_size
         src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            pipeline_parallel_rank * tensor_parallel_size
 
         opa = [None] * 2
         opb = [None] * 2
@@ -279,9 +280,9 @@ def forward(
         col_group = gpc.get_group(col_parallel_mode)
 
         src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            pipeline_parallel_rank * tensor_parallel_size
         src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            pipeline_parallel_rank * tensor_parallel_size
 
         opb = [None] * 2
         opr = [None] * 2
@@ -393,9 +394,9 @@ def forward(
         col_group = gpc.get_group(col_parallel_mode)
 
         src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            pipeline_parallel_rank * tensor_parallel_size
         src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            pipeline_parallel_rank * tensor_parallel_size
 
         opa = [None] * 2
         opr = [None] * 2
diff --git a/colossalai/nn/layer/wrapper/pipeline_wrapper.py b/colossalai/nn/layer/wrapper/pipeline_wrapper.py
index 5bd4a83e3254..dd422a75f979 100644
--- a/colossalai/nn/layer/wrapper/pipeline_wrapper.py
+++ b/colossalai/nn/layer/wrapper/pipeline_wrapper.py
@@ -38,3 +38,9 @@ def register_module(self, module: nn.Module):
         for p in module.parameters():
             setattr(p, 'pipeline_shared_module_pg', self.group)
             dist.broadcast(p, src, group=self.group)
+
+    def register_parameter(self, param: nn.Parameter):
+        assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
+        src = self.ranks_in_group[self.pipeline_ranks[0]]
+        setattr(param, 'pipeline_shared_module_pg', self.group)
+        dist.broadcast(param, src, group=self.group)
diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py
index 3489298807d2..c4546b6f964b 100644
--- a/colossalai/trainer/hooks/_metric_hook.py
+++ b/colossalai/trainer/hooks/_metric_hook.py
@@ -25,6 +25,7 @@ class Metric(ABC):
     :param epoch_only: Whether the metric only read for the full epoch
     :type epoch_only: bool
     """
+
     def __init__(self, epoch_only: bool):
         # is the metric only read for the full epoch
         self._epoch_only = epoch_only
@@ -82,6 +83,7 @@ class LossMetric(Metric):
     :param epoch_only: Whether the metric only read for the full epoch
     :type epoch_only: bool
     """
+
     def __init__(self, epoch_only):
         super().__init__(epoch_only=epoch_only)
         self.last_step_loss = torch.zeros(1, device=get_current_device())
@@ -132,6 +134,7 @@ class LearningRateMetric(Metric):
     :param epoch_only: Whether the metric only read for the full epoch
     :type epoch_only: bool
    """
+
     def __init__(self, epoch_only: bool, initial_lr: float = 0.):
         super().__init__(epoch_only=epoch_only)
         self.lr = initial_lr
@@ -159,6 +162,7 @@ class AccuracyMetric(Metric):
     :param epoch_only: Whether the metric only read for the full epoch
     :type epoch_only: bool
     """
+
     def __init__(self, epoch_only: bool, accuracy_func: Callable):
         super().__init__(epoch_only=epoch_only)
         self.acc = accuracy_func
@@ -217,6 +221,7 @@ class MetricHook(BaseHook):
     :type trainer: Trainer
     :type priority: int
     """
+
     def __init__(
         self,
         priority: int,
@@ -238,6 +243,7 @@ class LossHook(MetricHook):
     :type trainer: Trainer
     :type priority: int, optional
     """
+
     def __init__(self, priority: int = 0):
         super().__init__(priority)
 
@@ -278,6 +284,7 @@ class AccuracyHook(MetricHook):
     :type trainer: Trainer
     :type priority: int
     """
+
     def __init__(self, accuracy_func: Callable, priority: int = 0):
         super().__init__(priority)
         self.accuracy_func = accuracy_func
@@ -351,13 +358,17 @@ def after_hook_is_attached(self, trainer):
             trainer.states['metrics']['test']['Throughput'] = self.metric
 
     def before_train_epoch(self, trainer):
-        self.metric.reset()
+        if self._is_stage_to_compute:
+            self.metric.reset()
 
     def after_train_iter(self, trainer, *args):
-        self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
+        if self._is_stage_to_compute:
+            self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
 
     def before_test(self, trainer):
-        self.metric.reset()
+        if self._is_stage_to_compute:
+            self.metric.reset()
 
     def after_test_iter(self, trainer, *args):
-        self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
+        if self._is_stage_to_compute:
+            self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
diff --git a/model_zoo/gpt/gpt.py b/model_zoo/gpt/gpt.py
index 99095f08cd44..bfa85813f39d 100644
--- a/model_zoo/gpt/gpt.py
+++ b/model_zoo/gpt/gpt.py
@@ -133,7 +133,7 @@ def __init__(self,
                  dtype: dtype = None,
                  bias: bool = True,
                  checkpoint: bool = False):
-        super().__init__()
+        super().__init__(checkpoint=checkpoint)
         self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
         self.attn = GPTSelfAttention(dim=dim,
                                      num_heads=num_heads,
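
Note on the _pipeline_schedule.py hunk: the batch size is now inferred from the loaded batch itself rather than read from self.batch_size, so dict-style batches (a mapping of field names to tensors) also work with pipeline microbatching. A minimal standalone sketch of that logic, using a hypothetical helper name (infer_microbatch_size is not part of the patch):

import torch

def infer_microbatch_size(batch_data, num_microbatches: int) -> int:
    # Derive the batch size from the data itself: a plain tensor uses its
    # first dimension, a dict of tensors uses the first dimension of any
    # value (all values are assumed to share the same batch dimension).
    if isinstance(batch_data, torch.Tensor):
        batch_size = batch_data.size(0)
    else:
        batch_size = next(iter(batch_data.values())).size(0)
    assert batch_size % num_microbatches == 0, \
        "Batch size should be divided by the number of microbatches"
    return batch_size // num_microbatches

# Example: a dict batch of size 8 split into 4 microbatches of 2 samples each.
batch = {'input_ids': torch.zeros(8, 128, dtype=torch.long),
         'attention_mask': torch.ones(8, 128, dtype=torch.long)}
assert infer_microbatch_size(batch, num_microbatches=4) == 2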
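
Note on the parallel_2d/_operation.py backward hunk: the bias gradient is the output gradient summed over every dimension except the last (feature) dimension, and it is now only computed (and all-reduced across the column-parallel group) when ctx.use_bias is set; otherwise None is returned for that slot. A small single-process sketch of the shape logic only, with the distributed all_reduce step omitted since it needs an initialized process group:

import torch

output_grad = torch.randn(4, 16, 32)  # e.g. (batch, seq, out_features)
use_bias = True

bias_grad = None
if use_bias:
    # Sum over all dims except the last, matching the bias's feature dimension.
    bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
    assert bias_grad.shape == (32,)
# When use_bias is False, the backward pass simply returns None for the bias slot.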