[zero] stateful tensor manager (hpcaitech#687)
* [WIP] stateful tensor manager

* add eviction strategy

* polish code

* polish code

* polish comment

* add unit test

* fix sampler bug

* polish code

* fix max sampling cnt resetting bug

* fix sampler bug

* polish code

* fix bug

* fix unit test

Co-authored-by: jiaruifang <fangjiarui123@gmail.com>
ver217 and feifeibear authored Apr 8, 2022
1 parent 70e8dd4 commit 3c9cd5b
Showing 8 changed files with 271 additions and 73 deletions.
40 changes: 30 additions & 10 deletions colossalai/engine/ophooks/zero_hook.py
@@ -7,6 +7,7 @@
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr

from ._base_ophook import BaseOpHook

@@ -21,31 +22,41 @@ class ZeroHook(BaseOpHook):

def __init__(self,
shard_strategy: BaseShardStrategy,
memstarts_collector: Optional[MemStatsCollector],
memstarts_collector: Optional[MemStatsCollector] = None,
stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
process_group: Optional[dist.ProcessGroup] = None):
super().__init__()
self.shard_strategy = shard_strategy
self.process_group = process_group

# NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
self.computing_device = torch.device(f'cuda:{get_current_device()}')

self._memstarts_collector = memstarts_collector
self._stateful_tensor_mgr = stateful_tensor_mgr

def pre_fwd_exec(self, module: torch.nn.Module, *args):
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

if self._stateful_tensor_mgr:
self._stateful_tensor_mgr.adjust_layout()
else:
for param in module.parameters(recurse=False):
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)

tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)
for param in module.parameters(recurse=False):
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
param.data = param.colo_attr.sharded_data_tensor.payload

if self._memstarts_collector:
self._memstarts_collector.sample_memstats()

for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
param.data = param.colo_attr.sharded_data_tensor.payload
assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"

def post_fwd_exec(self, module: torch.nn.Module, *args):
for param in module.parameters(recurse=False):
@@ -60,19 +71,27 @@ def post_fwd_exec(self, module: torch.nn.Module, *args):
param.colo_attr.remove_torch_payload()

def pre_bwd_exec(self, module: torch.nn.Module, input, output):
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

if self._stateful_tensor_mgr:
self._stateful_tensor_mgr.adjust_layout()
else:
for param in module.parameters(recurse=False):
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)

tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)
for param in module.parameters(recurse=False):
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
param.data = param.colo_attr.sharded_data_tensor.payload

if self._memstarts_collector:
self._memstarts_collector.sample_memstats()

for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
param.data = param.colo_attr.sharded_data_tensor.payload
assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"

def post_bwd_exec(self, module: torch.nn.Module, input):
for param in module.parameters(recurse=False):
@@ -91,4 +110,5 @@ def pre_iter(self):
pass

def post_iter(self):
pass
if self._stateful_tensor_mgr:
self._stateful_tensor_mgr.reset()
32 changes: 20 additions & 12 deletions colossalai/utils/memory_tracer/memstats_collector.py
@@ -20,10 +20,12 @@ def next(self):
assert self._max_sampling_cnt is not None
return (self._samplint_cnt + 1) % self._max_sampling_cnt

@property
def sampling_cnt(self):
def current(self):
return self._samplint_cnt

def max(self):
return self._max_sampling_cnt

def reset(self):
self._max_sampling_cnt = self._samplint_cnt
self._samplint_cnt = 0
@@ -37,7 +39,7 @@ class MemStatsCollector:
The first iteration of DNN training.
Phase 2. Runtime Phase: use the read-only collected stats
The rest iterations of DNN training.
It has a sampling counter which is reset after each DNN training iteration.
"""

@@ -50,6 +52,8 @@ def __init__(self) -> None:
self._model_data_cpu_list = []
self._overall_cpu_list = []

self._non_model_data_cuda_list = []
self._non_model_data_cpu_list = []
self._sampling_time = []

self._start_flag = False
@@ -96,18 +100,20 @@ def non_model_data_list(self, device_type: str, unit: str = 'B') -> List[int]:
raise TypeError

if device_type == 'cuda':
return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cuda_list, self._model_data_cuda_list)]
return [elem / scale for elem in self._non_model_data_cuda_list]
elif device_type == 'cpu':
return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cpu_list, self._model_data_cpu_list)]
return [elem / scale for elem in self._non_model_data_cpu_list]
else:
raise TypeError

def current_non_model_data(self, device_type: str) -> int:
"""get the non model data of current sampling moment
"""get the non model data of the current sampling moment
"""
return self.non_model_data_list(device_type)[self._sampling_cnter.sampling_cnt]
return self.non_model_data_list(device_type)[self._sampling_cnter.current()]

def next_non_model_data(self, device_type: str):
"""get the non model data of the next sampling moment
"""
return self.non_model_data_list(device_type)[self._sampling_cnter.next()]

@property
@@ -128,18 +134,20 @@ def sample_memstats(self) -> None:
Advance the sampling cnter.
"""
if self._start_flag:
sampling_cnt = self._sampling_cnter.sampling_cnt
sampling_cnt = self._sampling_cnter.current()
assert sampling_cnt == len(self._overall_cuda_list)
self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
self._overall_cuda_list.append(self._mem_monitor.finish())
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])

self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)

# FIXME() cpu sys used should also return from self._mem_monitor()
# FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
self._overall_cpu_list.append(colo_device_memory_used(torch.device(f'cpu')))

self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
self._sampling_time.append(time.time())
self._mem_monitor.start()
# TODO(ver217): refactor sampler
# print(f'{self._sampling_cnter.current()} / {self._sampling_cnter.max()}, len = {len(self._sampling_time)}')
self._sampling_cnter.advance()

def reset_sampling_cnter(self) -> None:
Expand All @@ -155,4 +163,4 @@ def clear(self) -> None:

self._start_flag = False
self._sampling_cnter.reset()
self._mem_monitor.finish()
self._mem_monitor.finish()
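As a reading aid for the sampler changes above, the counter semantics implied by current(), next(), max() and reset() can be sketched standalone as follows. This is an illustration of the intended behavior, not the repository code; the real class also has an advance() method, assumed here to be a plain increment.

# Standalone sketch of the sampling-counter semantics inferred from this diff.
class SamplingCounterSketch:

    def __init__(self):
        self._cnt = 0          # samples taken in the current iteration
        self._max_cnt = None   # unknown until the warmup iteration finishes

    def advance(self):
        self._cnt += 1

    def current(self):
        return self._cnt

    def max(self):
        return self._max_cnt

    def next(self):
        # index of the next sampling moment, wrapping at the end of an iteration
        assert self._max_cnt is not None
        return (self._cnt + 1) % self._max_cnt

    def reset(self):
        # the number of samples observed in the first iteration becomes the max
        self._max_cnt = self._cnt
        self._cnt = 0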
3 changes: 2 additions & 1 deletion colossalai/zero/shard_utils/__init__.py
@@ -1,5 +1,6 @@
from .base_shard_strategy import BaseShardStrategy
from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
from .tensor_shard_strategy import TensorShardStrategy
from .stateful_tensor_mgr import StatefulTensorMgr

__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'StatefulTensorMgr']
108 changes: 79 additions & 29 deletions colossalai/zero/shard_utils/stateful_tensor_mgr.py
@@ -1,26 +1,43 @@
import functools
import torch
from colossalai.context.singleton_meta import SingletonMeta
import types
from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
from typing import Set
from typing import Dict, List
from colossalai.utils.memory_tracer import MemStatsCollector
from colossalai.logging import get_dist_logger


class StatefulTensorMgr(SingletonMeta):
_stateful_tensor_list: Set[ShardedParamV2] = set()
class StatefulTensorMgr(object):
"""
Stateful Tensor Manager, inspired from PatrickStar
def register_param(self, param: ShardedParamV2) -> None:
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
https://arxiv.org/abs/2108.05818
"""

def __init__(self, mem_stats_collector: MemStatsCollector) -> None:
self._stateful_tensor_list: List[StatefulTensor] = []
self._mem_stats_collector = mem_stats_collector
self._logger = get_dist_logger("StatefulTensorMgr")

self._warmup = True
self._warmup_cuda_available_ratio = 0.2

self._compute_list: List[StatefulTensor] = []
self._compute_idx: int = -1

def register_stateful_param(self, param: ShardedParamV2) -> None:
assert isinstance(param, ShardedParamV2)
for t in param.get_payload_tensors():
assert isinstance(t, StatefulTensor)
self._stateful_tensor_list.add(t)
self._stateful_tensor_list.append(t)
t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)

def evict_tensors(self) -> None:
pass

def adjust_layout(self, mem_stats_collector: MemStatsCollector) -> None:
def adjust_layout(self) -> None:
""" Adjust the layout of statefuil tensor according to the information provided
by mem_stats_collector, which should belongs to a Sharded Model.
@@ -41,29 +58,62 @@ def adjust_layout(self, mem_stats_collector: MemStatsCollector) -> None:
used_cuda_model_data += colo_tensor_mem_usage(tensor.payload)[0]
if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
hold_cuda_tensor_list.append(tensor)
else:
elif tensor.device.type == 'cpu':
if tensor.state == TensorState.COMPUTE:
move_to_cuda_tensor_list.append(tensor)
cuda_demand += colo_tensor_mem_usage(tensor.payload)[0]

# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
max_cuda_non_model_data_per_period = max(mem_stats_collector.current_non_model_data('cuda'),
mem_stats_collector.next_non_model_data('cuda'))
cuda_demand += colo_tensor_mem_usage(tensor.payload)[1]
else:
raise RuntimeError
cuda_capacity = colo_cuda_memory_capacity()
cuda_model_data_period = cuda_capacity - max_cuda_non_model_data_per_period
if cuda_model_data_period < used_cuda_model_data + cuda_demand:
# move cuda_model_data_period - cuda_demand - used_cuda_model_data volume of tensor
# Here use a naive eviction strategy.
acc_size = 0
for t in hold_cuda_tensor_list:
if acc_size > cuda_demand:
break
colo_model_data_tensor_move_inline(t, torch.device('cpu'))
t_size = colo_tensor_mem_usage(t)
acc_size += t_size
if acc_size < cuda_demand:
raise RuntimeError("Adjust layout failed! No enough CUDA memory!")

if self._warmup:
# We designate a part of CUDA memory for model data in warmup iterations.
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio
else:
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
max_cuda_non_model_data_per_period = max(self._mem_stats_collector.current_non_model_data('cuda'),
self._mem_stats_collector.next_non_model_data('cuda'))

total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data

if avail_cuda_model_data < cuda_demand:
# Move cuda_demand - avail_cuda_model_data volume of tensors
# to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
self.evict_tensors(hold_cuda_tensor_list, cuda_demand - avail_cuda_model_data)
# move COMPUTE tensors to CUDA
for t in move_to_cuda_tensor_list:
colo_model_data_tensor_move_inline(t, get_current_device())

def reset(self):
"""This function must be called when each iteration finishes
"""
self._warmup = False
self._compute_idx = -1

def evict_tensors(self, hold_cuda_tensor_list, to_free_cuda_model_data):
freed_cuda_model_data = 0
to_free_tensor_list = hold_cuda_tensor_list
if not self._warmup:
next_compute_idx: Dict[StatefulTensor, int] = {t: len(self._compute_list) for t in hold_cuda_tensor_list}
for i in range(len(self._compute_list) - 1, self._compute_idx, -1):
if self._compute_list[i] in next_compute_idx:
next_compute_idx[self._compute_list[i]] = i
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
to_free_tensor_list = [t for (t, idx) in next_compute_idx]
for t in to_free_tensor_list:
if freed_cuda_model_data > to_free_cuda_model_data:
break
freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
colo_model_data_tensor_move_inline(t, torch.device('cpu'))
if freed_cuda_model_data < to_free_cuda_model_data:
raise RuntimeError(
f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
)

def _trans_state(self, trans_state_func, stateful_tensor, state):
trans_state_func(state)
if state == TensorState.COMPUTE:
self._compute_idx += 1
if self._warmup:
self._compute_list.append(stateful_tensor)
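The eviction order computed in evict_tensors is a Belady-style policy: during warmup a fixed share of CUDA capacity (_warmup_cuda_available_ratio, 20%) is reserved for non-model data, and afterwards held CUDA tensors whose next appearance in _compute_list is furthest away are moved to CPU first. A small standalone illustration of that ordering, with made-up tensor names and indices, not the repository code:

# Illustration of the eviction ordering only; data and names are hypothetical.
compute_list = ['a', 'b', 'c', 'a', 'd']   # recorded order of COMPUTE transitions (warmup)
compute_idx = 1                            # current position: tensor 'b' is computing
held_on_cuda = ['a', 'c', 'd']             # HOLD-state tensors currently on CUDA

# For each held tensor, find the index of its next use after compute_idx.
next_use = {t: len(compute_list) for t in held_on_cuda}
for i in range(len(compute_list) - 1, compute_idx, -1):
    if compute_list[i] in next_use:
        next_use[compute_list[i]] = i

# Evict tensors whose next use is furthest in the future first.
eviction_order = [t for t, _ in sorted(next_use.items(), key=lambda p: p[1], reverse=True)]
print(eviction_order)   # ['d', 'a', 'c']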