Layer integration (hpcaitech#83)
* integrated parallel layers for ease of building models

* integrated 2.5d layers

* cleaned up code and unit tests

* added log metric by step hook; updated imagenet benchmark; fixed some bugs

* reworked initialization; cleaned up code

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
kurisusnowdeng and BoxiangW authored Dec 27, 2021
1 parent 5c3843d commit 0fedef4
Showing 118 changed files with 4,965 additions and 8,140 deletions.
66 changes: 66 additions & 0 deletions benchmark/README.md
@@ -0,0 +1,66 @@
# Benchmark for Tuning Accuracy and Efficiency

## Overview

The benchmark collects our efforts in training different tasks with Colossal-AI to achieve SOTA results.
We are interested in both validation accuracy and training speed, and prefer larger batch sizes to take advantage of more GPU devices.
For example, we trained vision transformers with batch size 512 on CIFAR10 and 4096 on ImageNet1k, batch sizes that are rarely used in existing works.
Some of the benchmark results, trained on 8x A100 GPUs, are shown below.

| Task | Model | Training Time | Top-1 Accuracy |
| ---------- | ------------ | ------------- | -------------- |
| CIFAR10 | [ViT-Lite-7/4](https://arxiv.org/pdf/2104.05704.pdf) | ~ 16 min | ~ 90.5% |
| ImageNet1k | ViT-S/16 | ~ 16.5 h | ~ 74.5% |

The `train.py` script in each task directory runs training with a specific configuration script from `configs/`, one per parallelism.
Supported parallelisms include data parallel only (config file names end with `vanilla`), 1D (`1d`), 2D (`2d`), 2.5D (`2p5d`), and 3D (`3d`); the file layout is sketched below.
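
For reference, a sketch of the benchmark file layout (abridged; the ImageNet task mirrors the CIFAR10 one):
```
benchmark/
├── README.md
├── cifar/
│   ├── configs/
│   │   ├── vit_vanilla.py
│   │   ├── vit_1d.py
│   │   ├── vit_2d.py
│   │   ├── vit_2p5d.py
│   │   └── vit_3d.py
│   └── train.py
└── imagenet100/
    └── configs/
        ├── vit_1d.py
        ├── vit_2d.py
        ├── vit_2p5d.py
        └── vit_3d.py
```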

Each configuration script includes the following elements, taking the ImageNet1k task as an example:
```
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32
# data parallel only
TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = None
# parallelism setting
parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )  # automatic mixed precision (AMP) setting
gradient_accumulation = 2  # accumulate gradients over 2 steps per update
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation  # actual batch size for the dataloader
clip_grad_norm = 1.0  # clip gradients to norm 1.0
```
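As a quick sanity check of how the batch-size values interact (a sketch; the per-GPU division mirrors the `gpc.config.BATCH_SIZE // gpc.data_parallel_size` call in `train.py`, and the 8-GPU figure is only an illustrative assumption):
```
TOTAL_BATCH_SIZE = 4096                                  # effective samples per optimizer update
gradient_accumulation = 2
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation   # 2048 samples per accumulation step
data_parallel_size = 8                                   # assumption: 8 GPUs, no tensor parallelism
per_gpu_batch = BATCH_SIZE // data_parallel_size         # 256 samples per GPU per step
```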
Upper-case elements are the hyperparameters that `train.py` consumes, while lower-case elements are read by Colossal-AI to initialize training.
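For example, `train.py` accesses these values through Colossal-AI's global context once the configuration is loaded; a minimal sketch matching the optimizer setup in `benchmark/cifar/train.py` (assuming `model` is already built):
```
import torch
from colossalai.core import global_context as gpc

optimizer = torch.optim.AdamW(model.parameters(),
                              lr=gpc.config.LEARNING_RATE,
                              weight_decay=gpc.config.WEIGHT_DECAY)
```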

## Usage

To start training, use the following command to run each worker:
```
$ DATA=/path/to/dataset python train.py --world_size=WORLD_SIZE \
      --rank=RANK \
      --local_rank=LOCAL_RANK \
      --host=MASTER_IP_ADDRESS \
      --port=MASTER_PORT \
      --config=CONFIG_FILE
```
Alternatively, it is recommended to start training with `torchrun`:
```
$ DATA=/path/to/dataset torchrun --nproc_per_node=NUM_GPUS_PER_NODE \
      --nnodes=NUM_NODES \
      --node_rank=NODE_RANK \
      --master_addr=MASTER_IP_ADDRESS \
      --master_port=MASTER_PORT \
      train.py --config=CONFIG_FILE
```
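For instance, a single-node run of the CIFAR10 task on 8 GPUs might look like the following (illustrative values only — adjust the dataset path, port, and configuration file for your setup):
```
$ DATA=/path/to/cifar10 torchrun --nproc_per_node=8 \
      --nnodes=1 \
      --node_rank=0 \
      --master_addr=localhost \
      --master_port=29500 \
      train.py --config=configs/vit_2d.py
```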
18 changes: 18 additions & 0 deletions benchmark/cifar/configs/vit_1d.py
@@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '1d'

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

seed = 42

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"
18 changes: 18 additions & 0 deletions benchmark/cifar/configs/vit_2d.py
@@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '2d'

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

seed = 42

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"
19 changes: 19 additions & 0 deletions benchmark/cifar/configs/vit_2p5d.py
@@ -0,0 +1,19 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

TENSOR_PARALLEL_SIZE = 4
DEPTH = 1
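# note: with TENSOR_PARALLEL_SIZE = 4 and DEPTH = 1, the 2.5-D mesh is 2 x 2 x 1, so this setting reduces to the 2-D scheme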
TENSOR_PARALLEL_MODE = '2.5d'

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH),
)

seed = 42

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"
18 changes: 18 additions & 0 deletions benchmark/cifar/configs/vit_3d.py
@@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

TENSOR_PARALLEL_SIZE = 8
TENSOR_PARALLEL_MODE = '3d'

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

seed = 42

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"
18 changes: 18 additions & 0 deletions benchmark/cifar/configs/vit_vanilla.py
@@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = None

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

seed = 42

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"
126 changes: 126 additions & 0 deletions benchmark/cifar/train.py
@@ -0,0 +1,126 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os

import colossalai
import torch
import torchvision
from colossalai.builder import *
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook,
                                      LogMetricByEpochHook,
                                      LogMetricByStepHook,
                                      LogTimingByEpochHook, LossHook,
                                      LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer, get_dataloader
from model_zoo.vit import vit_lite_depth7_patch4_32
from torchvision import transforms

DATASET_PATH = str(os.environ['DATA'])


def build_cifar(batch_size):
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH,
                                                 train=True,
                                                 download=True,
                                                 transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, transform=transform_test)
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=batch_size,
                                      num_workers=4,
                                      pin_memory=True)
    test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, num_workers=4, pin_memory=True)
    return train_dataloader, test_dataloader


def train_cifar():
    args = colossalai.get_default_parser().parse_args()
    # standard launch
    # colossalai.launch(config=args.config,
    #                   rank=args.rank,
    #                   world_size=args.world_size,
    #                   local_rank=args.local_rank,
    #                   host=args.host,
    #                   port=args.port)

    # launch from torchrun
    colossalai.launch_from_torch(config=args.config)

    logger = get_dist_logger()
    if hasattr(gpc.config, 'LOG_PATH'):
        if gpc.get_global_rank() == 0:
            log_path = gpc.config.LOG_PATH
            if not os.path.exists(log_path):
                os.mkdir(log_path)
            logger.log_to_file(log_path)

    tp = gpc.config.parallel.tensor.mode  # tensor parallel mode set in the config

    model = vit_lite_depth7_patch4_32(tensor_parallel=tp)

    # each data-parallel rank loads its share of the global batch
    train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)

    criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)

    optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

    steps_per_epoch = len(train_dataloader)

    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                           total_steps=gpc.config.NUM_EPOCHS * steps_per_epoch,
                                           warmup_steps=gpc.config.WARMUP_EPOCHS * steps_per_epoch)

    engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model=model,
                                                                                    optimizer=optimizer,
                                                                                    criterion=criterion,
                                                                                    train_dataloader=train_dataloader,
                                                                                    test_dataloader=test_dataloader,
                                                                                    lr_scheduler=lr_scheduler)

    logger.info("Engine is built", ranks=[0])

    timer = MultiTimer()

    trainer = Trainer(engine=engine, logger=logger, timer=timer)
    logger.info("Trainer is built", ranks=[0])

    hooks = [
        LogMetricByEpochHook(logger=logger),
        LogMetricByStepHook(),
        # LogTimingByEpochHook(timer=timer, logger=logger),
        # LogMemoryByEpochHook(logger=logger),
        AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
        LossHook(),
        ThroughputHook(),
        LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
    ]

    logger.info("Train start", ranks=[0])
    trainer.fit(train_dataloader=train_dataloader,
                test_dataloader=test_dataloader,
                epochs=gpc.config.NUM_EPOCHS,
                hooks=hooks,
                display_progress=True,
                test_interval=1)


if __name__ == '__main__':
    train_cifar()
26 changes: 26 additions & 0 deletions benchmark/imagenet100/configs/vit_1d.py
@@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '1d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"
26 changes: 26 additions & 0 deletions benchmark/imagenet100/configs/vit_2d.py
@@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '2d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"
27 changes: 27 additions & 0 deletions benchmark/imagenet100/configs/vit_2p5d.py
@@ -0,0 +1,27 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 4
DEPTH = 1
TENSOR_PARALLEL_MODE = '2.5d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"
26 changes: 26 additions & 0 deletions benchmark/imagenet100/configs/vit_3d.py
@@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 8
TENSOR_PARALLEL_MODE = '3d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"
