moved env variables to global variables; (hpcaitech#215)
added branch context;
added vocab parallel layers;
moved split_batch from load_batch to tensor parallel embedding layers;
updated gpt model;
updated unit test cases;
fixed a few collective communicator bugs
kurisusnowdeng authored and FrankLeeeee committed Feb 15, 2022
1 parent b82d60b commit 9ee197d
Showing 63 changed files with 4,184 additions and 920 deletions.
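
Most of the source changes below replace string-typed reads and writes of os.environ with attribute access on a single module-level object, colossalai.global_variables.tensor_parallel_env (imported as env throughout the diffs). A minimal, hypothetical sketch of that pattern follows; the class name and defaults are assumptions, and only the attribute names that actually appear in this diff are shown:

# Hypothetical sketch of the global-variables pattern adopted by this commit;
# the real colossalai/global_variables.py may differ.
class TensorParallelEnv:
    def __init__(self):
        self.mode = None                # tensor parallel mode, e.g. '1d', '2d', '2.5d', '3d'
        self.parallel_input_1d = False  # used by the 1D initializer
        self.summa_dim = None           # 2D (SUMMA) grid dimension
        self.tesseract_dim = None       # 2.5D grid dimension
        self.tesseract_dep = None       # 2.5D depth

tensor_parallel_env = TensorParallelEnv()

# Call sites then use typed values directly instead of casting through strings:
# from colossalai.global_variables import tensor_parallel_env as env
# env.mode = '2d'; env.summa_dim = 2

Unlike environment variables, these attributes keep their Python types, so the str() and int() round-trips seen in the deleted lines below disappear.
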
4 changes: 0 additions & 4 deletions .gitignore
@@ -137,8 +137,4 @@ dmypy.json
.DS_Store
#data/

-# launcher setting
-tests/launcher/log
-tests/launcher/personal
-
docs/.build
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -5,7 +5,7 @@ repos:
- id: yapf
args: ['--style=google', '--parallel', '--in-place']
- repo: https://github.com/pycqa/flake8
-    rev: ''
+    rev: '4.0.1'
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-clang-format
5 changes: 3 additions & 2 deletions colossalai/amp/apex_amp/apex_amp.py
@@ -4,8 +4,9 @@
import torch.nn as nn
try:
import apex.amp as apex_amp
-except:
-    pass
+except ImportError:
+    raise ImportError('Cannot import apex.amp correctly.')

from torch import Tensor

from colossalai.nn.optimizer import ColossalaiOptimizer
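
The new guard fails loudly at import time instead of silently swallowing a missing dependency. For contrast, a common alternative (a sketch, not what this commit does) defers the error until apex AMP is actually requested; convert_with_apex_amp and its arguments are illustrative only:

# Hedged sketch of a lazy import guard; not the approach taken in apex_amp.py above.
try:
    import apex.amp as apex_amp
except ImportError:
    apex_amp = None   # defer the failure until apex AMP is actually used

def convert_with_apex_amp(model, optimizer, amp_config):
    if apex_amp is None:
        raise ImportError('apex is not installed; install NVIDIA apex to use apex AMP')
    return apex_amp.initialize(model, optimizer, **amp_config)

The eager ImportError chosen in the diff keeps the failure close to its cause, at the cost of making apex a hard requirement for importing this module.
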
26 changes: 16 additions & 10 deletions colossalai/communication/collective.py
@@ -30,7 +30,7 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
"""
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
-        out = [tensor]
+        out = tensor
work = None
else:
shape = list(tensor.shape)
@@ -96,34 +96,40 @@ def all_reduce(tensor: Tensor,
async_op: bool = False) -> Tensor:
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
out = tensor
work = None
else:
-        work = dist.all_reduce(tensor.contiguous(), op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
+        out = tensor.contiguous()
+        work = dist.all_reduce(out, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
if async_op:
-        return tensor, work
+        return out, work
else:
-        return tensor
+        return out


def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
out = tensor
work = None
else:
-        work = dist.broadcast(tensor.contiguous(), src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
+        out = tensor.contiguous()
+        work = dist.broadcast(out, src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
if async_op:
-        return tensor, work
+        return out, work
else:
-        return tensor
+        return out


def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
out = tensor
work = None
else:
-        work = dist.reduce(tensor.contiguous(), dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
+        out = tensor.contiguous()
+        work = dist.reduce(out, dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
if async_op:
-        return tensor, work
+        return out, work
else:
-        return tensor
+        return out
21 changes: 5 additions & 16 deletions colossalai/constants.py
@@ -19,23 +19,12 @@
'moe': 'Initializer_Moe'
}

-# 1D parallel
-PARALLEL_INPUT_1D = 'parallel_input_1d'
+# 3D parallelism groups
+INPUT_GROUP_3D = 'input_group_3d'
+WEIGHT_GROUP_3D = 'weight_group_3d'
+OUTPUT_GROUP_3D = 'output_group_3d'

-# 2D paralllel
-SUMMA_DIM = 'SUMMA_DIM'
-
-# 2.5D paralllel
-TESSERACT_DIM = 'TESSERACT_DIM'
-TESSERACT_DEP = 'TESSERACT_DEP'
-
-# 3D parallel
-DEPTH_3D = 'DEPTH_3D'
-INPUT_GROUP_3D = 'PARALLEL_3D_INPUT'
-WEIGHT_GROUP_3D = 'PARALLEL_3D_WEIGHT'
-OUTPUT_GROUP_3D = 'PARALLEL_3D_OUTPUT'
-
-# Tensor parallel attributes
+# Attributes of tensor parallel parameters
IS_TENSOR_PARALLEL = 'is_tensor_parallel'
NUM_PARTITIONS = 'num_partitions'
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]
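
The constants kept here (IS_TENSOR_PARALLEL, NUM_PARTITIONS) are attribute names rather than environment-variable keys. A hedged sketch of how such attribute-name constants are typically consumed; mark_tensor_parallel is an illustrative helper, not an API from this repository:

import torch

IS_TENSOR_PARALLEL = 'is_tensor_parallel'
NUM_PARTITIONS = 'num_partitions'

def mark_tensor_parallel(param: torch.nn.Parameter, num_partitions: int) -> None:
    # Tag a parameter shard so later code (e.g. gradient handling, checkpointing)
    # can detect that it was split across tensor-parallel ranks.
    setattr(param, IS_TENSOR_PARALLEL, True)
    setattr(param, NUM_PARTITIONS, num_partitions)

weight = torch.nn.Parameter(torch.empty(1024 // 4, 1024))   # one shard of a 4-way split
mark_tensor_parallel(weight, num_partitions=4)
assert getattr(weight, IS_TENSOR_PARALLEL) and getattr(weight, NUM_PARTITIONS) == 4
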
9 changes: 5 additions & 4 deletions colossalai/context/parallel_context.py
@@ -8,14 +8,15 @@
import numpy as np
import torch
import torch.distributed as dist
-from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING, TENSOR_PARALLEL_MODE
+from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
from colossalai.context.config import Config
+from colossalai.global_variables import moe_env
+from colossalai.global_variables import tensor_parallel_env as env
from colossalai.logging import get_dist_logger
from colossalai.registry import DIST_GROUP_INITIALIZER

from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
-from colossalai.global_variables import moe_env


class ParallelContext:
@@ -307,7 +308,6 @@ def init_global_dist(self,
port: int
):
"""Initializes the global distributed environment
:param rank: rank for the default process group
:type rank: int
:param world_size: world size of the default process group
@@ -389,7 +389,8 @@ def init_parallel_groups(self):
if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
tensor_parallel_mode = parallel_config['tensor']['mode']
assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
-            os.environ[TENSOR_PARALLEL_MODE] = str(tensor_parallel_mode)
+            env.mode = tensor_parallel_mode

self.check_sanity()

pg_init = []
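
A hedged, self-contained sketch of the control flow changed in init_parallel_groups above, with types.SimpleNamespace standing in for colossalai.global_variables.tensor_parallel_env and an assumed set of ALLOWED_MODES:

from types import SimpleNamespace

env = SimpleNamespace(mode=None)                   # stand-in for tensor_parallel_env
ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d']   # assumed values, for illustration only

parallel_config = {'tensor': {'size': 4, 'mode': '2d'}}   # example user config

if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
    tensor_parallel_mode = parallel_config['tensor']['mode']
    assert tensor_parallel_mode in ALLOWED_MODES, \
        f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
    env.mode = tensor_parallel_mode   # was: os.environ[TENSOR_PARALLEL_MODE] = str(...)

print(env.mode)   # '2d', kept as a plain Python attribute rather than an env-var string
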
20 changes: 8 additions & 12 deletions colossalai/context/process_group_initializer/initializer_1d.py
@@ -1,30 +1,26 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-import os
-import torch.distributed as dist

-from colossalai.context import Config
+import torch.distributed as dist
+from colossalai.global_variables import tensor_parallel_env as env
from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer

from ..parallel_mode import ParallelMode
-from colossalai.constants import PARALLEL_INPUT_1D
+from .process_group_initializer import ProcessGroupInitializer


@DIST_GROUP_INITIALIZER.register_module
class Initializer_1D(ProcessGroupInitializer):
"""A ProcessGroupInitializer for 1d tensor parallelism.
:param args: Args used to initialize ProcessGroupInitializer
:param kwargs: Kwargs used to initialize ProcessGroupInitializer
"""
'''A ProcessGroupInitializer for 1d tensor parallelism.
'''

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_group = self.world_size // self.tensor_parallel_size

def init_dist_group(self):
"""Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu.
:return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
:rtype: Tuple
"""
@@ -33,7 +29,7 @@ def init_dist_group(self):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_1D
-        os.environ[PARALLEL_INPUT_1D] = ''
+        env.parallel_input_1d = False

for i in range(self.num_group):
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
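
A worked example of the grouping loop visible in init_dist_group above: with a world size of 8 and a tensor parallel size of 2, the ranks are carved into num_group = 4 contiguous 1D groups.

world_size = 8
tensor_parallel_size = 2
num_group = world_size // tensor_parallel_size

groups = [[i * tensor_parallel_size + j for j in range(tensor_parallel_size)]
          for i in range(num_group)]
print(groups)   # [[0, 1], [2, 3], [4, 5], [6, 7]]
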
11 changes: 4 additions & 7 deletions colossalai/context/process_group_initializer/initializer_2d.py
@@ -1,34 +1,31 @@
import math
-import os

import torch.distributed as dist

-from colossalai.constants import SUMMA_DIM
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
+from colossalai.global_variables import tensor_parallel_env as env


def _check_summa_env_var(summa_dim):
# check environment variable for SUMMA
-    env_summa_dim = os.environ.get(SUMMA_DIM, None)
+    env_summa_dim = env.summa_dim

if env_summa_dim:
assert int(env_summa_dim) == summa_dim, \
'SUMMA_DIM has been set in the current environment and ' \
'does not match with the value passed to this initialized'
else:
-        os.environ[SUMMA_DIM] = str(summa_dim)
+        env.summa_dim = summa_dim


class Initializer_2D_Row(ProcessGroupInitializer):
"""2d tensor parallel initialization among rows.
:param num_group: The number of all tensor groups
:param summa_dim: The dimension of SUMMA
:param args: Args used to initialize base class
:param kwargs: Kwargs used to initialize base class
:type num_group: int
:type summa_dim: int
"""
Expand Down Expand Up @@ -132,7 +129,7 @@ def __init__(self, *args, **kwargs):

def init_dist_group(self):
"""Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu.
:return: 2D tensor parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
"""
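
The check above assumes a square process grid: SUMMA lays the tensor-parallel ranks out as summa_dim x summa_dim. A hedged sketch of how summa_dim is derived from the tensor parallel size; the helper name is illustrative, and the actual derivation lives in parts of the file not shown here:

import math

def infer_summa_dim(tensor_parallel_size: int) -> int:
    # 2D (SUMMA) tensor parallelism lays ranks out on a square grid.
    summa_dim = int(math.sqrt(tensor_parallel_size))
    assert summa_dim ** 2 == tensor_parallel_size, \
        '2D tensor parallelism requires a square number of tensor-parallel ranks'
    return summa_dim

print(infer_summa_dim(4))   # 2 -> a 2x2 SUMMA grid
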
20 changes: 9 additions & 11 deletions colossalai/context/process_group_initializer/initializer_2p5d.py
@@ -2,22 +2,21 @@
# -*- encoding: utf-8 -*-

import math
-import os

import torch.distributed as dist

-from colossalai.constants import TESSERACT_DIM, TESSERACT_DEP
from colossalai.context import Config
+from colossalai.global_variables import tensor_parallel_env as env
from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer

from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer


def _check_tesseract_env_var(tesseract_dim: int,
tesseract_dep: int):
-    # check environment variable for TESSERACT
-    env_tesseract_dim = os.environ.get(TESSERACT_DIM, None)
-    env_tesseract_dep = os.environ.get(TESSERACT_DEP, None)
+    # check global variable for TESSERACT
+    env_tesseract_dim = env.tesseract_dim
+    env_tesseract_dep = env.tesseract_dep

if env_tesseract_dim and env_tesseract_dep:
assert int(env_tesseract_dim) == tesseract_dim, \
@@ -27,8 +26,8 @@ def _check_tesseract_env_var(tesseract_dim: int,
'TESSERACT_DEP has been set in the current environment and ' \
'does not match with the value passed to this initialized'
else:
-        os.environ[TESSERACT_DIM] = str(tesseract_dim)
-        os.environ[TESSERACT_DEP] = str(tesseract_dep)
+        env.tesseract_dim = tesseract_dim
+        env.tesseract_dep = tesseract_dep


# i row j col k dep
@@ -245,7 +244,6 @@ class Initializer_2p5D(ProcessGroupInitializer):
:param pipeline_parallel_size: Size of pipeline parallel
:param tensor_parallel_size: Size of tensor parallel
:param depth: The depth of 2p5d parallel
:type rank: int
:type world_size: int
:type config: Config
Expand Down Expand Up @@ -281,7 +279,7 @@ def __init__(self,

def init_dist_group(self):
"""Initialize 2p5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu.
:return: Whole 2p5D tensor parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
"""
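
By analogy with the 2D case, the 2.5D layout adds a depth dimension: the docstring above takes a depth argument, and the usual sizing rule is tensor_parallel_size = tesseract_dim ** 2 * tesseract_dep. A hedged sketch with an illustrative helper, not code from this file:

import math

def infer_tesseract_dim(tensor_parallel_size: int, depth: int) -> int:
    # 2.5D tensor parallelism replicates a dim x dim grid `depth` times.
    tesseract_dim = int(math.sqrt(tensor_parallel_size // depth))
    assert tesseract_dim ** 2 * depth == tensor_parallel_size, \
        '2.5D tensor parallelism requires tensor parallel size = dim * dim * depth'
    return tesseract_dim

print(infer_tesseract_dim(tensor_parallel_size=8, depth=2))   # 2 -> two 2x2 grids
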
