Hotfix/Colossalai layers #92

Merged
merged 3 commits on Dec 29, 2021
Changes from all commits

benchmark/cifar/configs/vit_1d.py (2 changes: 1 addition & 1 deletion)

@@ -2,7 +2,7 @@
 LEARNING_RATE = 2e-3
 WEIGHT_DECAY = 3e-2

-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
 TENSOR_PARALLEL_MODE = '1d'

 NUM_EPOCHS = 200
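
For context, TENSOR_PARALLEL_SIZE is the number of GPUs in each tensor-parallel group, so halving it lets the 1D benchmark run with 2-GPU groups. A minimal sketch of how these two constants typically reach the parallel context, assuming the usual ColossalAI config layout (the parallel dict below is an assumption, not part of this diff):

# sketch of the assumed wiring in the same config file
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'

parallel = dict(
    tensor=dict(size=TENSOR_PARALLEL_SIZE, mode=TENSOR_PARALLEL_MODE),
)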

benchmark/cifar/train.py (8 changes: 3 additions & 5 deletions)

@@ -72,13 +72,11 @@ def train_cifar():
         os.mkdir(log_path)
     logger.log_to_file(log_path)

-    tp = gpc.config.parallel.tensor.mode
-
-    model = vit_lite_depth7_patch4_32(tensor_parallel=tp)
+    model = vit_lite_depth7_patch4_32()

     train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)

-    criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+    criterion = CrossEntropyLoss(label_smoothing=0.1)

     optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

@@ -107,7 +105,7 @@ def train_cifar():
         LogMetricByStepHook(),
         # LogTimingByEpochHook(timer=timer, logger=logger),
         # LogMemoryByEpochHook(logger=logger),
-        AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+        AccuracyHook(accuracy_func=Accuracy()),
         LossHook(),
         ThroughputHook(),
         LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
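
The user-visible effect of the hotfix is what this file shows: model, loss, and metric constructors no longer take a tensor_parallel= argument, because each of them now resolves the mode from the global parallel context. A minimal sketch of the resulting calling pattern; the launch helper and import paths are assumptions about the ColossalAI API of that period, not taken from this diff:

import colossalai
from colossalai.nn import CrossEntropyLoss   # assumed import path
from colossalai.nn.metric import Accuracy    # assumed import path

# the config decides the mode ('1d', '2d', '2.5d', '3d'); the training script stays the same
colossalai.launch_from_torch(config='./configs/vit_1d.py')

criterion = CrossEntropyLoss(label_smoothing=0.1)   # no tensor_parallel= argument any more
metric = Accuracy()                                  # mode resolved internally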

benchmark/imagenet100/configs/vit_1d.py (2 changes: 1 addition & 1 deletion)

@@ -4,7 +4,7 @@
 LEARNING_RATE = 3e-3
 WEIGHT_DECAY = 0.3

-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
 TENSOR_PARALLEL_MODE = '1d'

 NUM_EPOCHS = 300

benchmark/imagenet100/train.py (8 changes: 3 additions & 5 deletions)

@@ -159,14 +159,12 @@ def train_imagenet():
         os.mkdir(log_path)
     logger.log_to_file(log_path)

-    tp = gpc.config.parallel.tensor.mode
-
-    model = vit_small_patch16_224(tensor_parallel=tp, num_classes=100, init_method='jax')
+    model = vit_small_patch16_224(num_classes=100, init_method='jax')

     train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
     test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)

-    criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+    criterion = CrossEntropyLoss(label_smoothing=0.1)

     optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

@@ -192,7 +190,7 @@ def train_imagenet():
         LogMetricByStepHook(),
         # LogTimingByEpochHook(timer=timer, logger=logger),
         # LogMemoryByEpochHook(logger=logger),
-        AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+        AccuracyHook(accuracy_func=Accuracy()),
         LossHook(),
         ThroughputHook(),
         LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)

benchmark/imagenet1k/configs/vit_1d.py (2 changes: 1 addition & 1 deletion)

@@ -4,7 +4,7 @@
 LEARNING_RATE = 3e-3
 WEIGHT_DECAY = 0.3

-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
 TENSOR_PARALLEL_MODE = '1d'

 NUM_EPOCHS = 300

benchmark/imagenet1k/train.py (8 changes: 3 additions & 5 deletions)

@@ -159,14 +159,12 @@ def train_imagenet():
         os.mkdir(log_path)
     logger.log_to_file(log_path)

-    tp = gpc.config.parallel.tensor.mode
-
-    model = vit_small_patch16_224(tensor_parallel=tp, num_classes=1000, init_method='jax')
+    model = vit_small_patch16_224(num_classes=1000, init_method='jax')

     train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
     test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)

-    criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+    criterion = CrossEntropyLoss(label_smoothing=0.1)

     optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

@@ -192,7 +190,7 @@ def train_imagenet():
         LogMetricByStepHook(),
         # LogTimingByEpochHook(timer=timer, logger=logger),
         # LogMemoryByEpochHook(logger=logger),
-        AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+        AccuracyHook(accuracy_func=Accuracy()),
         LossHook(),
         ThroughputHook(),
         LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)

colossalai/constants.py (4 changes: 4 additions & 0 deletions)

@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-

 ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence']
+TENSOR_PARALLEL_MODE = 'tensor_parallel_mode'

 # intializer
 INITIALIZER_MAPPING = {

@@ -16,6 +17,9 @@
     'sequence': 'Initializer_Sequence'
 }

+# 1D parallel
+PARALLEL_INPUT_1D = 'parallel_input_1d'
+
 # 2D paralllel
 SUMMA_DIM = 'SUMMA_DIM'

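
TENSOR_PARALLEL_MODE and PARALLEL_INPUT_1D are environment-variable keys rather than configuration values: the parallel context exports the chosen mode under the former (see parallel_context.py below), and mode-agnostic layers can read it back to pick an implementation. A minimal sketch of that lookup-and-dispatch pattern; the registry and factory below are hypothetical stand-ins, not ColossalAI code:

import os

import torch.nn as nn

from colossalai.constants import TENSOR_PARALLEL_MODE

# hypothetical registry of per-mode implementations; ColossalAI keeps one layer
# class per mode (1D, 2D, 2.5D, 3D), here nn.Linear stands in for all of them
_LINEAR_CLASSES = {'None': nn.Linear, '1d': nn.Linear, '2d': nn.Linear}

def build_linear(in_features: int, out_features: int, **kwargs) -> nn.Module:
    # pick the implementation from the mode exported by the parallel context
    mode = os.environ.get(TENSOR_PARALLEL_MODE, 'None')
    return _LINEAR_CLASSES.get(mode, nn.Linear)(in_features, out_features, **kwargs)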

colossalai/context/parallel_context.py (6 changes: 4 additions & 2 deletions)

@@ -1,17 +1,18 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+import os
 import random
 from typing import Union

 import numpy as np
 import torch
 import torch.distributed as dist

-from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
+from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING, TENSOR_PARALLEL_MODE
 from colossalai.context.config import Config
 from colossalai.logging import get_dist_logger
 from colossalai.registry import DIST_GROUP_INITIALIZER

 from .parallel_mode import ParallelMode
 from .random import add_seed, get_seeds, set_mode

@@ -386,6 +387,7 @@ def init_parallel_groups(self):
         if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
             tensor_parallel_mode = parallel_config['tensor']['mode']
             assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
+            os.environ[TENSOR_PARALLEL_MODE] = str(tensor_parallel_mode)
         self.check_sanity()

         pg_init = []
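
Exporting the mode here is what allows the benchmark scripts above to drop their tensor_parallel= arguments: anything constructed after launch can query the environment instead of carrying the config object around. A small sketch of the round trip (the read-back helper name is hypothetical; note that a config without tensor parallelism exports the literal string 'None' because of the str() call):

import os

from colossalai.constants import TENSOR_PARALLEL_MODE

# what init_parallel_groups() now does with the value taken from the user config
os.environ[TENSOR_PARALLEL_MODE] = str('1d')

# what any layer, loss, or metric can do afterwards (hypothetical helper name)
def get_tensor_parallel_mode():
    mode = os.environ.get(TENSOR_PARALLEL_MODE, 'None')
    return None if mode == 'None' else mode

assert get_tensor_parallel_mode() == '1d'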

colossalai/context/process_group_initializer/initializer_1d.py

@@ -1,12 +1,13 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+import os
 import torch.distributed as dist

 from colossalai.context import Config
 from colossalai.registry import DIST_GROUP_INITIALIZER
 from .process_group_initializer import ProcessGroupInitializer
 from ..parallel_mode import ParallelMode
+from colossalai.constants import PARALLEL_INPUT_1D
-

 @DIST_GROUP_INITIALIZER.register_module

@@ -29,6 +30,7 @@ def init_dist_group(self):
         process_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_1D
+        os.environ[PARALLEL_INPUT_1D] = ''

         for i in range(self.num_group):
             ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
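
Initialising PARALLEL_INPUT_1D to an empty string gives the 1D layers a process-local flag they can later set and query, for example to record whether the activation entering a layer is already partitioned. A minimal sketch of helpers built on that flag; the helper names and the truthiness encoding are assumptions, only the constant and its empty-string default come from this diff:

import os

from colossalai.constants import PARALLEL_INPUT_1D

def set_parallel_input(parallel_input: bool) -> None:
    # an empty string reads back as False, any non-empty marker as True
    os.environ[PARALLEL_INPUT_1D] = 'true' if parallel_input else ''

def get_parallel_input() -> bool:
    return bool(os.environ.get(PARALLEL_INPUT_1D, ''))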

colossalai/engine/schedule/_base_schedule.py (8 changes: 6 additions & 2 deletions)

@@ -10,7 +10,7 @@
 from .._base_engine import Engine
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
-
+from colossalai.nn.layer import split_batch

 class BaseSchedule(ABC):
     """A basic helper class to control the process of training or evaluation.

@@ -59,7 +59,11 @@ def load_batch(self, data_iter):
         else:
             data, label = batch_data

-        data, label = self._to_list(data), self._to_list(label)
+        if isinstance(label, (tuple, list)):
+            self.batch_size = label[0].size(0)
+        else:
+            self.batch_size = label.size(0)
+        data, label = self._to_list(split_batch(data)), self._to_list(split_batch(label))
         return self._move_to_device(data), self._move_to_device(label)

     def pre_processing(self, engine: Engine):
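
load_batch() now records the batch size from the label and routes both data and label through split_batch() before moving them to the device. The implementation of split_batch is not part of this diff; a plausible simplified version, assuming it gives each rank of a tensor-parallel group its own slice of the batch dimension and passes everything through untouched when no split is needed:

import torch
import torch.distributed as dist

def split_batch(batch, dim: int = 0):
    # simplified sketch: recurse into containers, slice tensors along the batch dim
    if isinstance(batch, (tuple, list)):
        return type(batch)(split_batch(x, dim) for x in batch)
    if not torch.is_tensor(batch) or not dist.is_initialized():
        return batch
    world_size = dist.get_world_size()   # stand-in for the tensor-parallel group size
    rank = dist.get_rank()
    if world_size == 1 or batch.size(dim) % world_size != 0:
        return batch
    return batch.chunk(world_size, dim=dim)[rank].contiguous()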

colossalai/nn/layer/__init__.py (8 changes: 7 additions & 1 deletion)

@@ -1,3 +1,9 @@
 from .colossalai_layer import *
-from .fused_bias_gelu import bias_gelu_impl
+from .parallel_1d import *
+from .parallel_2d import *
+from .parallel_2p5d import *
+from .parallel_3d import *
+from .parallel_sequence import *
+from .utils import *
+from .vanilla import *
 from .wrapper import *