[zero] initialize a stateful tensor manager #614

Merged · 6 commits · Apr 6, 2022
3 changes: 2 additions & 1 deletion colossalai/utils/memory_tracer/__init__.py
@@ -1,3 +1,4 @@
from .async_memtracer import AsyncMemoryMonitor
+from .memstats_collector import MemStatsCollector

-__all__ = ['AsyncMemoryMonitor']
+__all__ = ['AsyncMemoryMonitor', 'MemStatsCollector']
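With MemStatsCollector added to `__all__`, callers can import it straight from the tracer package; the new StatefulTensorMgr file below does exactly that:

    from colossalai.utils.memory_tracer import MemStatsCollector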
18 changes: 16 additions & 2 deletions colossalai/utils/memory_tracer/memstats_collector.py
@@ -11,15 +11,21 @@ class SamplingCounter:

    def __init__(self) -> None:
        self._samplint_cnt = 0
+        self._max_sampling_cnt = None

    def advance(self):
        self._samplint_cnt += 1

+    def next(self):
+        assert self._max_sampling_cnt is not None
+        return (self._samplint_cnt + 1) % self._max_sampling_cnt
+
    @property
    def sampling_cnt(self):
        return self._samplint_cnt

    def reset(self):
+        self._max_sampling_cnt = self._samplint_cnt
        self._samplint_cnt = 0

@@ -56,7 +62,7 @@ def overall_mem_stats(self, device_type: str):
        else:
            raise TypeError

-    def model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
+    def model_data_list(self, device_type: str, unit: str = 'B') -> List[int]:
        if unit == 'GB':
            scale = 1e9
        elif unit == 'MB':
@@ -75,7 +81,7 @@ def model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
        else:
            raise TypeError

-    def non_model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
+    def non_model_data_list(self, device_type: str, unit: str = 'B') -> List[int]:
        """Non-model data stats
        """
        if unit == 'GB':
@@ -96,6 +102,14 @@ def non_model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
        else:
            raise TypeError

+    def current_non_model_data(self, device_type: str) -> int:
+        """Get the non-model data usage at the current sampling moment.
+        """
+        return self.non_model_data_list(device_type)[self._sampling_cnter.sampling_cnt]
+
+    def next_non_model_data(self, device_type: str):
+        return self.non_model_data_list(device_type)[self._sampling_cnter.next()]
+
    @property
    def sampling_time(self):
        return [t - self._sampling_time[0] for t in self._sampling_time]
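The counter above follows a warmup-then-replay pattern: the first training iteration advances the counter once per sampling moment, reset() freezes that count as the period length, and next() wraps around to the first moment of the following iteration. A minimal sketch of that lifecycle (the driver loop is illustrative, not part of the PR):

    from colossalai.utils.memory_tracer.memstats_collector import SamplingCounter

    counter = SamplingCounter()
    for _ in range(5):      # warmup iteration with 5 sampling moments
        counter.advance()
    counter.reset()         # freezes _max_sampling_cnt = 5, count back to 0

    for _ in range(4):      # steady state: advance to the last sampling moment
        counter.advance()
    assert counter.sampling_cnt == 4
    assert counter.next() == 0   # wraps to the first moment of the next iteration

current_non_model_data('cuda') then indexes the recorded non-model footprint list at sampling_cnt, while next_non_model_data('cuda') peeks one moment ahead, wrapping at an iteration boundary.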
69 changes: 69 additions & 0 deletions colossalai/zero/shard_utils/stateful_tensor_mgr.py
@@ -0,0 +1,69 @@
import torch
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
from typing import Set
from colossalai.utils.memory_tracer import MemStatsCollector


class StatefulTensorMgr(SingletonMeta):
    _stateful_tensor_list: Set[StatefulTensor] = set()

    def register_param(self, param: ShardedParamV2) -> None:
        for t in param.get_payload_tensors():
            assert isinstance(t, StatefulTensor)
            self._stateful_tensor_list.add(t)

    def evict_tensors(self) -> None:
        pass

    def adjust_layout(self, mem_stats_collector: MemStatsCollector) -> None:
        """Adjust the layout of stateful tensors according to the information provided
        by mem_stats_collector, which should belong to a Sharded Model.

        Args:
            mem_stats_collector (MemStatsCollector): a collector, usually owned by a Sharded Model.
                It contains the non-model memory footprint of a DNN model.
        """
        # find stateful tensors in state COMPUTE
        move_to_cuda_tensor_list = []
        cuda_demand = 0
        used_cuda_model_data = 0
        hold_cuda_tensor_list = []
        for tensor in self._stateful_tensor_list:
            if tensor.state == TensorState.FREE:
                continue

            if tensor.device.type == 'cuda':
                used_cuda_model_data += colo_tensor_mem_usage(tensor.payload)[0]
                if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
                    hold_cuda_tensor_list.append(tensor)
            else:
                if tensor.state == TensorState.COMPUTE:
                    move_to_cuda_tensor_list.append(tensor)
                    cuda_demand += colo_tensor_mem_usage(tensor.payload)[0]

        # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
        max_cuda_non_model_data_per_period = max(mem_stats_collector.current_non_model_data('cuda'),
                                                 mem_stats_collector.next_non_model_data('cuda'))
        cuda_capacity = colo_cuda_memory_capacity()
        cuda_model_data_period = cuda_capacity - max_cuda_non_model_data_per_period
        if cuda_model_data_period < used_cuda_model_data + cuda_demand:
            # need to free at least used_cuda_model_data + cuda_demand - cuda_model_data_period bytes.
            # Here use a naive eviction strategy: move HOLD tensors to CPU until
            # at least `cuda_demand` bytes are released.
            acc_size = 0
            for t in hold_cuda_tensor_list:
                if acc_size > cuda_demand:
                    break
                # record the tensor's CUDA footprint before moving it off the device
                t_size = colo_tensor_mem_usage(t.payload)[0]
                colo_model_data_tensor_move_inline(t, torch.device('cpu'))
                acc_size += t_size
            if acc_size < cuda_demand:
                raise RuntimeError("Adjust layout failed! Not enough CUDA memory!")

        # move COMPUTE tensors to CUDA
        for t in move_to_cuda_tensor_list:
            colo_model_data_tensor_move_inline(t, get_current_device())
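The heart of adjust_layout is a budget check: the CUDA memory left after reserving for the worst-case non-model footprint must cover both the model data already on CUDA and the COMPUTE tensors that need to move there. A self-contained sketch of that arithmetic with made-up numbers (none of these values come from the PR):

    GB = 1024**3

    cuda_capacity = 16 * GB            # e.g. colo_cuda_memory_capacity() on a 16 GB card
    current_non_model = 2 * GB         # mem_stats_collector.current_non_model_data('cuda')
    next_non_model = 3 * GB            # mem_stats_collector.next_non_model_data('cuda')
    used_cuda_model_data = 10 * GB     # stateful tensors already resident on CUDA
    cuda_demand = 4 * GB               # COMPUTE tensors that must be moved to CUDA

    # budget for model data this period = capacity minus worst-case non-model data
    cuda_model_data_period = cuda_capacity - max(current_non_model, next_non_model)  # 13 GB

    if cuda_model_data_period < used_cuda_model_data + cuda_demand:                  # 13 < 14
        deficit = used_cuda_model_data + cuda_demand - cuda_model_data_period
        print(f"must evict at least {deficit / GB:.1f} GB of HOLD tensors")          # 1.0 GB

The naive strategy in the PR frees HOLD tensors until at least cuda_demand bytes are released, rather than computing this exact deficit.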
6 changes: 6 additions & 0 deletions colossalai/zero/sharded_param/sharded_param.py
@@ -3,6 +3,7 @@
from typing import Optional, Tuple
from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage
from .tensorful_state import StatefulTensor, TensorState
+from typing import List


class ShardedParamV2(object):
@@ -22,6 +23,11 @@ def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
        if rm_torch_payload:
            self.remove_torch_payload()

+    def get_payload_tensors(self) -> List[StatefulTensor]:
+        """Returns the stateful tensors kept by this class.
+        """
+        return [self._sharded_data_tensor, self.saved_grad]
+
    def remove_torch_payload(self):
        self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)

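End to end, a sharded model would register each parameter's stateful tensors (the data tensor and saved_grad exposed by get_payload_tensors) with the manager and let it re-balance before each compute phase. A rough wiring sketch, not part of the PR; `model` and `mem_stats_collector` are assumed to exist:

    from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr
    from colossalai.zero.sharded_param.sharded_param import ShardedParamV2

    mgr = StatefulTensorMgr()
    for param in model.parameters():        # `model` is an assumed torch.nn.Module
        sharded_param = ShardedParamV2(param)
        mgr.register_param(sharded_param)   # registers data tensor + saved_grad

    # before a compute phase, let the manager evict/fetch stateful tensors
    mgr.adjust_layout(mem_stats_collector)  # collector owned by the sharded model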