forked from hpcaitech/ColossalAI
* integrated parallel layers for ease of building models
* integrated 2.5d layers
* cleaned codes and unit tests
* added log metric by step hook; updated imagenet benchmark; fixed some bugs
* reworked initialization; cleaned codes

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>

Commit 0fedef4 · 1 parent 5c3843d · 118 changed files with 4,965 additions and 8,140 deletions.

@@ -0,0 +1,66 @@
# Benchmark for Tuning Accuracy and Efficiency

## Overview

The benchmark collects our efforts in using Colossal-AI to train different tasks to SOTA results. We are interested in both validation accuracy and training speed, and prefer larger batch sizes to take advantage of more GPU devices. For example, we trained vision transformers with batch size 512 on CIFAR10 and 4096 on ImageNet1k, batch sizes that are rarely used in existing works. Some of the benchmark results, trained on 8x A100 GPUs, are shown below.

| Task       | Model        | Training Time | Top-1 Accuracy |
| ---------- | ------------ | ------------- | -------------- |
| CIFAR10    | [ViT-Lite-7/4](https://arxiv.org/pdf/2104.05704.pdf) | ~ 16 min | ~ 90.5% |
| ImageNet1k | ViT-S/16     | ~ 16.5 h      | ~ 74.5%        |

The `train.py` script in each task directory runs training with a specific configuration script from `configs/`, one per parallelism. Supported parallelisms include data parallel only (config names ending with `vanilla`), 1D (`1d`), 2D (`2d`), 2.5D (`2p5d`), and 3D (`3d`) tensor parallelism.

Each configuration script basically includes the following elements, taking the ImageNet1k task as an example:
```
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

# data parallel only
TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = None

# parallelism setting
parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH)  # amp setting

gradient_accumulation = 2  # accumulate gradients over 2 steps before each update

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation  # actual batch size for the dataloader

clip_grad_norm = 1.0  # clip gradients with norm 1.0
```
Upper-case elements are the hyperparameters that `train.py` reads, while lower-case elements are the settings that Colossal-AI reads to initialize the training (parallelism, AMP, gradient accumulation, and gradient clipping).

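As a rough sketch (not the benchmark's actual code), `train.py` reads the upper-case values through the global context after launch, while Colossal-AI itself picks up the lower-case settings during initialization. With `TOTAL_BATCH_SIZE = 4096` and `gradient_accumulation = 2`, the dataloader batch size is 2048, and with 8 GPUs in pure data parallel each rank then loads 256 samples per step. The config path below is only a placeholder name.
```
# Minimal sketch of how the config values are consumed; run under torchrun (see Usage below).
import colossalai
from colossalai.core import global_context as gpc

colossalai.launch_from_torch(config='configs/vit_vanilla.py')    # placeholder config path

# Upper-case elements are plain hyperparameters for the training script:
lr = gpc.config.LEARNING_RATE                                    # 3e-3
epochs = gpc.config.NUM_EPOCHS                                   # 300
per_rank_bs = gpc.config.BATCH_SIZE // gpc.data_parallel_size    # 4096 // 2 = 2048 per step; 2048 // 8 = 256 per rank

# Lower-case elements (parallel, fp16, gradient_accumulation, clip_grad_norm) are read by
# colossalai.initialize(...) to set up the engine; the script does not use them directly.
```
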
## Usage

To start training, use the following command to run each worker:
```
$ DATA=/path/to/dataset python train.py --world_size=WORLD_SIZE \
                                        --rank=RANK \
                                        --local_rank=LOCAL_RANK \
                                        --host=MASTER_IP_ADDRESS \
                                        --port=MASTER_PORT \
                                        --config=CONFIG_FILE
```
It is also recommended to start training with `torchrun` as:
```
$ DATA=/path/to/dataset torchrun --nproc_per_node=NUM_GPUS_PER_NODE \
                                 --nnodes=NUM_NODES \
                                 --node_rank=NODE_RANK \
                                 --master_addr=MASTER_IP_ADDRESS \
                                 --master_port=MASTER_PORT \
                                 train.py --config=CONFIG_FILE
```
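For example, assuming 8 GPUs on a single node and a CIFAR10 config saved as `configs/vit_1d.py` (the filename is only an assumed example; substitute the config you want to run), the `torchrun` launch could look like:
```
$ DATA=/path/to/cifar10 torchrun --nproc_per_node=8 \
                                 --nnodes=1 \
                                 --node_rank=0 \
                                 --master_addr=localhost \
                                 --master_port=29500 \
                                 train.py --config=configs/vit_1d.py
```
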
@@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '1d'

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

seed = 42

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"

@@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '2d'

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

seed = 42

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"

@@ -0,0 +1,19 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

TENSOR_PARALLEL_SIZE = 4
DEPTH = 1
TENSOR_PARALLEL_MODE = '2.5d'

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH),
)

seed = 42

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"

@@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

TENSOR_PARALLEL_SIZE = 8
TENSOR_PARALLEL_MODE = '3d'

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

seed = 42

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"

@@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = None

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

seed = 42

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"

@@ -0,0 +1,126 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os

import colossalai
import torch
import torchvision
from colossalai.builder import *
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook,
                                      LogMetricByEpochHook,
                                      LogMetricByStepHook,
                                      LogTimingByEpochHook, LossHook,
                                      LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer, get_dataloader
from model_zoo.vit import vit_lite_depth7_patch4_32
from torchvision import transforms

DATASET_PATH = str(os.environ['DATA'])


def build_cifar(batch_size):
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH,
                                                 train=True,
                                                 download=True,
                                                 transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, transform=transform_test)
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=batch_size,
                                      num_workers=4,
                                      pin_memory=True)
    test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, num_workers=4, pin_memory=True)
    return train_dataloader, test_dataloader


def train_cifar():
    args = colossalai.get_default_parser().parse_args()
    # standard launch
    # colossalai.launch(config=args.config,
    #                   rank=args.rank,
    #                   world_size=args.world_size,
    #                   local_rank=args.local_rank,
    #                   host=args.host,
    #                   port=args.port)

    # launch from torchrun
    colossalai.launch_from_torch(config=args.config)

    logger = get_dist_logger()
    if hasattr(gpc.config, 'LOG_PATH'):
        if gpc.get_global_rank() == 0:
            log_path = gpc.config.LOG_PATH
            if not os.path.exists(log_path):
                os.mkdir(log_path)
            logger.log_to_file(log_path)

    tp = gpc.config.parallel.tensor.mode

    model = vit_lite_depth7_patch4_32(tensor_parallel=tp)

    train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)

    criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)

    optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

    steps_per_epoch = len(train_dataloader)

    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                           total_steps=gpc.config.NUM_EPOCHS * steps_per_epoch,
                                           warmup_steps=gpc.config.WARMUP_EPOCHS * steps_per_epoch)

    engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model=model,
                                                                                    optimizer=optimizer,
                                                                                    criterion=criterion,
                                                                                    train_dataloader=train_dataloader,
                                                                                    test_dataloader=test_dataloader,
                                                                                    lr_scheduler=lr_scheduler)

    logger.info("Engine is built", ranks=[0])

    timer = MultiTimer()

    trainer = Trainer(engine=engine, logger=logger, timer=timer)
    logger.info("Trainer is built", ranks=[0])

    hooks = [
        LogMetricByEpochHook(logger=logger),
        LogMetricByStepHook(),
        # LogTimingByEpochHook(timer=timer, logger=logger),
        # LogMemoryByEpochHook(logger=logger),
        AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
        LossHook(),
        ThroughputHook(),
        LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
    ]

    logger.info("Train start", ranks=[0])
    trainer.fit(train_dataloader=train_dataloader,
                test_dataloader=test_dataloader,
                epochs=gpc.config.NUM_EPOCHS,
                hooks=hooks,
                display_progress=True,
                test_interval=1)


if __name__ == '__main__':
    train_cifar()

@@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '1d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH)

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"

@@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '2d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH)

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"

@@ -0,0 +1,27 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 4
DEPTH = 1
TENSOR_PARALLEL_MODE = '2.5d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH),
)

fp16 = dict(mode=AMP_TYPE.TORCH)

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"

@@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 8
TENSOR_PARALLEL_MODE = '3d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH)

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"
