Commit

Hotfix/Colossalai layers (hpcaitech#92)
* Optimized the 1D layer APIs; reorganized the nn.layer modules; fixed tests

* Fixed a 2.5D runtime issue

* Reworked batch splitting: split_batch is now called in trainer.schedule.load_batch

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
kurisusnowdeng and BoxiangW authored Dec 29, 2021
1 parent 0fedef4 commit 01a80cd
Showing 71 changed files with 1,031 additions and 771 deletions.
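The net effect of the API change shows up across all three benchmark scripts below: models, losses, and metrics no longer take a tensor_parallel= argument, because the tensor-parallel mode is now published once by the parallel context. A minimal before/after sketch under that reading (the two library import paths are assumptions based on the repository layout of this era, not quoted from the commit):

    from colossalai.core import global_context as gpc     # assumed path
    from colossalai.nn import CrossEntropyLoss            # assumed path
    from model_zoo.vit import vit_lite_depth7_patch4_32   # assumed path

    # Before this commit: every call site threaded the mode through by hand.
    tp = gpc.config.parallel.tensor.mode
    model = vit_lite_depth7_patch4_32(tensor_parallel=tp)
    criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)

    # After this commit: the mode is exported once during context setup
    # (see colossalai/context/parallel_context.py below), so call sites
    # simply drop the argument.
    model = vit_lite_depth7_patch4_32()
    criterion = CrossEntropyLoss(label_smoothing=0.1)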
2 changes: 1 addition & 1 deletion benchmark/cifar/configs/vit_1d.py
@@ -2,7 +2,7 @@
 LEARNING_RATE = 2e-3
 WEIGHT_DECAY = 3e-2

-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
 TENSOR_PARALLEL_MODE = '1d'

 NUM_EPOCHS = 200
8 changes: 3 additions & 5 deletions benchmark/cifar/train.py
@@ -72,13 +72,11 @@ def train_cifar():
         os.mkdir(log_path)
     logger.log_to_file(log_path)

-    tp = gpc.config.parallel.tensor.mode
-
-    model = vit_lite_depth7_patch4_32(tensor_parallel=tp)
+    model = vit_lite_depth7_patch4_32()

     train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)

-    criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+    criterion = CrossEntropyLoss(label_smoothing=0.1)

     optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

@@ -107,7 +105,7 @@ def train_cifar():
         LogMetricByStepHook(),
         # LogTimingByEpochHook(timer=timer, logger=logger),
         # LogMemoryByEpochHook(logger=logger),
-        AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+        AccuracyHook(accuracy_func=Accuracy()),
         LossHook(),
         ThroughputHook(),
         LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
2 changes: 1 addition & 1 deletion benchmark/imagenet100/configs/vit_1d.py
@@ -4,7 +4,7 @@
 LEARNING_RATE = 3e-3
 WEIGHT_DECAY = 0.3

-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
 TENSOR_PARALLEL_MODE = '1d'

 NUM_EPOCHS = 300
8 changes: 3 additions & 5 deletions benchmark/imagenet100/train.py
@@ -159,14 +159,12 @@ def train_imagenet():
         os.mkdir(log_path)
     logger.log_to_file(log_path)

-    tp = gpc.config.parallel.tensor.mode
-
-    model = vit_small_patch16_224(tensor_parallel=tp, num_classes=100, init_method='jax')
+    model = vit_small_patch16_224(num_classes=100, init_method='jax')

     train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
     test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)

-    criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+    criterion = CrossEntropyLoss(label_smoothing=0.1)

     optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

@@ -192,7 +190,7 @@ def train_imagenet():
         LogMetricByStepHook(),
         # LogTimingByEpochHook(timer=timer, logger=logger),
         # LogMemoryByEpochHook(logger=logger),
-        AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+        AccuracyHook(accuracy_func=Accuracy()),
         LossHook(),
         ThroughputHook(),
         LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
2 changes: 1 addition & 1 deletion benchmark/imagenet1k/configs/vit_1d.py
@@ -4,7 +4,7 @@
 LEARNING_RATE = 3e-3
 WEIGHT_DECAY = 0.3

-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
 TENSOR_PARALLEL_MODE = '1d'

 NUM_EPOCHS = 300
8 changes: 3 additions & 5 deletions benchmark/imagenet1k/train.py
@@ -159,14 +159,12 @@ def train_imagenet():
         os.mkdir(log_path)
     logger.log_to_file(log_path)

-    tp = gpc.config.parallel.tensor.mode
-
-    model = vit_small_patch16_224(tensor_parallel=tp, num_classes=1000, init_method='jax')
+    model = vit_small_patch16_224(num_classes=1000, init_method='jax')

     train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
     test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)

-    criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+    criterion = CrossEntropyLoss(label_smoothing=0.1)

     optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

@@ -192,7 +190,7 @@ def train_imagenet():
         LogMetricByStepHook(),
         # LogTimingByEpochHook(timer=timer, logger=logger),
         # LogMemoryByEpochHook(logger=logger),
-        AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+        AccuracyHook(accuracy_func=Accuracy()),
         LossHook(),
         ThroughputHook(),
         LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
4 changes: 4 additions & 0 deletions colossalai/constants.py
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-

 ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence']
+TENSOR_PARALLEL_MODE = 'tensor_parallel_mode'

 # intializer
 INITIALIZER_MAPPING = {
@@ -16,6 +17,9 @@
     'sequence': 'Initializer_Sequence'
 }

+# 1D parallel
+PARALLEL_INPUT_1D = 'parallel_input_1d'
+
 # 2D paralllel
 SUMMA_DIM = 'SUMMA_DIM'

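Note that TENSOR_PARALLEL_MODE and PARALLEL_INPUT_1D are environment-variable keys, not values: the context writes the chosen mode under these keys, and layers read it back instead of receiving a tensor_parallel= argument. A hedged sketch of the read side (get_tensor_parallel_mode here is an illustrative helper, not necessarily the repository's own):

    import os

    from colossalai.constants import TENSOR_PARALLEL_MODE

    def get_tensor_parallel_mode() -> str:
        # Illustrative helper: read back the mode exported by
        # ParallelContext.init_parallel_groups (see the diff below), so a
        # layer can choose its 1d/2d/2.5d/3d implementation on its own.
        # 'None' matches str(None) written when no mode is configured.
        return os.environ.get(TENSOR_PARALLEL_MODE, 'None')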
6 changes: 4 additions & 2 deletions colossalai/context/parallel_context.py
@@ -1,17 +1,18 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+import os
 import random
 from typing import Union

 import numpy as np
 import torch
 import torch.distributed as dist

-from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
+from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING, TENSOR_PARALLEL_MODE
 from colossalai.context.config import Config
 from colossalai.logging import get_dist_logger
 from colossalai.registry import DIST_GROUP_INITIALIZER

 from .parallel_mode import ParallelMode
 from .random import add_seed, get_seeds, set_mode
@@ -386,6 +387,7 @@ def init_parallel_groups(self):
         if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
             tensor_parallel_mode = parallel_config['tensor']['mode']
             assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
+            os.environ[TENSOR_PARALLEL_MODE] = str(tensor_parallel_mode)
         self.check_sanity()

         pg_init = []
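For reference, the new export fires for any config whose parallel.tensor entry carries a mode; a minimal illustrative config (sizes are arbitrary):

    # Illustrative colossalai config file. After init_parallel_groups runs
    # with this config, os.environ['tensor_parallel_mode'] == '1d'.
    parallel = dict(
        pipeline=1,
        tensor=dict(size=2, mode='1d'),
    )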
colossalai/context/process_group_initializer/initializer_1d.py
@@ -1,12 +1,13 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+import os
 import torch.distributed as dist

 from colossalai.context import Config
 from colossalai.registry import DIST_GROUP_INITIALIZER
 from .process_group_initializer import ProcessGroupInitializer
 from ..parallel_mode import ParallelMode
+from colossalai.constants import PARALLEL_INPUT_1D


 @DIST_GROUP_INITIALIZER.register_module
@@ -29,6 +30,7 @@ def init_dist_group(self):
         process_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_1D
+        os.environ[PARALLEL_INPUT_1D] = ''

         for i in range(self.num_group):
             ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
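Seeding os.environ[PARALLEL_INPUT_1D] with an empty string gives the 1D layers a process-wide flag they can flip to record whether activations are already split, so consecutive 1D layers can skip redundant gather/scatter communication. A hedged sketch of how such a flag could be used (both helper names are illustrative, not quoted from the commit):

    import os

    from colossalai.constants import PARALLEL_INPUT_1D

    def set_parallel_input(parallel_input: bool) -> None:
        # Illustrative: record whether the next 1D layer will receive an
        # already-split input tensor.
        os.environ[PARALLEL_INPUT_1D] = 'true' if parallel_input else ''

    def get_parallel_input() -> bool:
        # The initializer's default (empty string) is falsy: input not split.
        return bool(os.environ.get(PARALLEL_INPUT_1D, ''))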
8 changes: 6 additions & 2 deletions colossalai/engine/schedule/_base_schedule.py
@@ -10,7 +10,7 @@
 from .._base_engine import Engine
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
-
+from colossalai.nn.layer import split_batch

 class BaseSchedule(ABC):
     """A basic helper class to control the process of training or evaluation.
@@ -59,7 +59,11 @@ def load_batch(self, data_iter):
         else:
             data, label = batch_data

-        data, label = self._to_list(data), self._to_list(label)
+        if isinstance(label, (tuple, list)):
+            self.batch_size = label[0].size(0)
+        else:
+            self.batch_size = label.size(0)
+        data, label = self._to_list(split_batch(data)), self._to_list(split_batch(label))
         return self._move_to_device(data), self._move_to_device(label)

     def pre_processing(self, engine: Engine):
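Two things happen in the reworked load_batch: the batch size is measured from the labels before any splitting, and split_batch shards the batch across tensor-parallel ranks for the modes that partition along the batch axis. A rough sketch of the splitting behavior under that assumption (not the repository's exact implementation):

    import os

    import torch

    def split_batch_sketch(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        # Assumption: modes such as '2d', '2.5d' and '3d' shard inputs along
        # the batch axis, each rank keeping one contiguous chunk; '1d' and
        # non-parallel runs pass the batch through untouched.
        if os.environ.get('tensor_parallel_mode') in ('2d', '2.5d', '3d'):
            return tensor.chunk(world_size, dim=0)[rank]
        return tensor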
8 changes: 7 additions & 1 deletion colossalai/nn/layer/__init__.py
@@ -1,3 +1,9 @@
+from .colossalai_layer import *
-from .fused_bias_gelu import bias_gelu_impl
+from .parallel_1d import *
+from .parallel_2d import *
+from .parallel_2p5d import *
+from .parallel_3d import *
 from .parallel_sequence import *
+from .utils import *
+from .vanilla import *
 from .wrapper import *
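After the reorganization, the mode-specific packages sit side by side under colossalai.nn.layer, and mode-agnostic helpers are re-exported from the package root; this is exactly what the schedule relies on:

    # The import used by _base_schedule.py above resolves from the package
    # root rather than a mode-specific submodule.
    from colossalai.nn.layer import split_batch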
(Diffs for the remaining changed files are not shown.)
