From 27896f5b273774ee49109ed4120c836b7219b6fa Mon Sep 17 00:00:00 2001
From: jiaruifang
Date: Tue, 19 Apr 2022 14:39:42 +0800
Subject: [PATCH 1/8] Revert "[zero] add ZeroTensorShardStrategy (#793)"

This reverts commit 88759e289efd0a7b5e0d7bf8e01dbe29db85cf71.
---
 .../kernel/cuda_native/csrc/zero_comm.cpp     | 109 ------------------
 colossalai/zero/comm/__init__.py              |   1 -
 colossalai/zero/comm/zero_comm.py             |  46 --------
 colossalai/zero/init_ctx/init_context.py      |   2 -
 colossalai/zero/shard_utils/__init__.py       |   3 +-
 .../shard_utils/zero_tensor_shard_strategy.py |  38 ------
 setup.py                                      |   6 -
 tests/test_zero/test_found_inf.py             |   4 +-
 tests/test_zero/test_init_context.py          |   4 +-
 tests/test_zero/test_mem_collector.py         |   4 +-
 tests/test_zero/test_shard_model_v2.py        |   4 +-
 tests/test_zero/test_state_dict.py            |   4 +-
 12 files changed, 11 insertions(+), 214 deletions(-)
 delete mode 100644 colossalai/kernel/cuda_native/csrc/zero_comm.cpp
 delete mode 100644 colossalai/zero/comm/__init__.py
 delete mode 100644 colossalai/zero/comm/zero_comm.py
 delete mode 100644 colossalai/zero/shard_utils/zero_tensor_shard_strategy.py

diff --git a/colossalai/kernel/cuda_native/csrc/zero_comm.cpp b/colossalai/kernel/cuda_native/csrc/zero_comm.cpp
deleted file mode 100644
index e07d6f504cc8..000000000000
--- a/colossalai/kernel/cuda_native/csrc/zero_comm.cpp
+++ /dev/null
@@ -1,109 +0,0 @@
-#include <torch/extension.h>
-#include <cuda_runtime.h>
-#include <nccl.h>
-
-#define CHECK_CUDA(x) \
-  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) \
-  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) \
-  CHECK_CUDA(x);       \
-  CHECK_CONTIGUOUS(x)
-
-#define CUDACHECK(cmd) \
-  do { \
-    cudaError_t e = cmd; \
-    if (e != cudaSuccess) { \
-      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
-             cudaGetErrorString(e)); \
-      exit(EXIT_FAILURE); \
-    } \
-  } while (0)
-
-#define NCCLCHECK(cmd) \
-  do { \
-    ncclResult_t r = cmd; \
-    if (r != ncclSuccess) { \
-      printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
-             ncclGetErrorString(r)); \
-      exit(EXIT_FAILURE); \
-    } \
-  } while (0)
-
-class ZeroCommMgr {
- public:
-  cudaStream_t cuda_stream;
-  ncclComm_t nccl_comm;
-
-  ZeroCommMgr(const ncclComm_t &comm_) {
-    CUDACHECK(cudaStreamCreate(&cuda_stream));
-    nccl_comm = comm_;
-  }
-};
-
-ZeroCommMgr *GMGR = nullptr;
-
-#ifdef USE_C10D_NCCL
-#include <c10d/ProcessGroupNCCL.hpp>
-
-class HackNCCLGroup : public c10d::ProcessGroupNCCL {
- public:
-  ncclComm_t getcomm(at::Device dev) {
-    ncclUniqueId ncclID;
-    int rank = getRank();
-    if (rank == 0) {
-      ncclGetUniqueId(&ncclID);
-    }
-
-    broadcastUniqueNCCLID(&ncclID, c10d::OpType::SEND, "colossal_zero_comm",
-                          rank);
-
-    ncclComm_t comm;
-    NCCLCHECK(ncclCommInitRank(&comm, getSize(), ncclID, rank));
-    return comm;
-  }
-};
-
-int create_zero_comm(c10d::ProcessGroupNCCL &pg, at::Device dev) {
-  auto *hack_group = reinterpret_cast<HackNCCLGroup *>(&pg);
-  GMGR = new ZeroCommMgr(hack_group->getcomm(dev));
-  assert(GMGR->nccl_comm != 0);
-  return 0;
-}
-#endif
-
-template <typename scalar_t>
-void colo_all_gather_impl(scalar_t *recvbuff, int rank, int sendcount,
-                          ncclDataType_t data_type) {
-  scalar_t *sendbuff = recvbuff + (rank * sendcount);
-  NCCLCHECK(ncclAllGather(sendbuff, recvbuff, sendcount, data_type,
-                          GMGR->nccl_comm, GMGR->cuda_stream));
-  CUDACHECK(cudaStreamSynchronize(GMGR->cuda_stream));
-}
-
-int colo_all_gather(torch::Tensor &input_tensor, int rank, int world_size) {
-  CHECK_INPUT(input_tensor);
-
-  auto total_size = input_tensor.numel();
-  assert(total_size % world_size == 0);
-  auto sendcount = total_size / world_size;
-
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      input_tensor.scalar_type(), "colo_all_gather", ([&] {
-        colo_all_gather_impl<scalar_t>(
-            input_tensor.data_ptr<scalar_t>(), rank, sendcount,
-            input_tensor.scalar_type() == at::ScalarType::Half ? ncclHalf
-                                                                : ncclFloat);
-      }));
-
-  return 0;
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-#ifdef USE_C10D_NCCL
-  m.def("create_comm", &create_zero_comm,
-        "Create the communication environment for Colossal Zero");
-#endif
-  m.def("inplace_all_gather", &colo_all_gather,
-        "All gather operation used in Colossal Zero");
-}
diff --git a/colossalai/zero/comm/__init__.py b/colossalai/zero/comm/__init__.py
deleted file mode 100644
index 16b2d3e022e7..000000000000
--- a/colossalai/zero/comm/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .zero_comm import ZeroDist
diff --git a/colossalai/zero/comm/zero_comm.py b/colossalai/zero/comm/zero_comm.py
deleted file mode 100644
index a2d54a0150af..000000000000
--- a/colossalai/zero/comm/zero_comm.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import torch
-import torch.distributed as dist
-from torch.distributed import ProcessGroup
-from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.utils import get_current_device
-from typing import Optional
-
-ZERO_USE_NCCL = False
-try:
-    import colossal_zero_comm
-    ZERO_USE_NCCL = True
-except ImportError:
-    print("Please pip reinstall Colossalai.")
-
-
-class ZeroCommWorld(metaclass=SingletonMeta):
-    """Zero communicator, used for communications in zero parallel.
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.zero_pg: Optional[ProcessGroup] = None
-
-    @property
-    def is_initialized(self):
-        return self.zero_pg is not None
-
-    def zero_comm_init(self, comm_group: ProcessGroup):
-        if not ZERO_USE_NCCL:
-            return
-
-        if self.is_initialized:
-            assert self.zero_pg == comm_group, "Cant not initialize zero group twice"
-            return
-
-        self.zero_pg = comm_group
-        colossal_zero_comm.create_comm(self.zero_pg, get_current_device())
-
-    def zero_all_gather(self, input_tensor: torch.Tensor):
-        assert self.zero_pg is not None, "Please initialize zero communication world first"
-        rank = dist.get_rank(self.zero_pg)
-        world_size = self.zero_pg.size()
-        colossal_zero_comm.inplace_all_gather(input_tensor, rank, world_size)
-
-
-ZeroDist = ZeroCommWorld()
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 7a5bf6f4cd4d..c27d7a577733 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -12,7 +12,6 @@
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.zero.comm import ZeroDist
 from contextlib import AbstractContextManager

@@ -192,7 +191,6 @@ def _pre_context_exec(self):
         The Callback function when entering the context
         """
         self.logger = get_dist_logger("ZeroInitContext")
-        ZeroDist.zero_comm_init(self.dp_process_group)    # initialize zero communication world

         # substitute fan-in and fan-out calculation
         self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
diff --git a/colossalai/zero/shard_utils/__init__.py b/colossalai/zero/shard_utils/__init__.py
index 6a0e0c59b33f..5e5d63a7e768 100644
--- a/colossalai/zero/shard_utils/__init__.py
+++ b/colossalai/zero/shard_utils/__init__.py
@@ -1,6 +1,5 @@
 from .base_shard_strategy import BaseShardStrategy
 from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
 from .tensor_shard_strategy 
import TensorShardStrategy -from .zero_tensor_shard_strategy import ZeroTensorShardStrategy -__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'ZeroTensorShardStrategy'] +__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy'] diff --git a/colossalai/zero/shard_utils/zero_tensor_shard_strategy.py b/colossalai/zero/shard_utils/zero_tensor_shard_strategy.py deleted file mode 100644 index afd2f619c193..000000000000 --- a/colossalai/zero/shard_utils/zero_tensor_shard_strategy.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Optional - -import torch -import torch.distributed as dist -from colossalai.utils import get_current_device -from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline -from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.zero.comm import ZeroDist - -from .tensor_shard_strategy import TensorShardStrategy - - -class ZeroTensorShardStrategy(TensorShardStrategy): - """Use the same shard scheme as `TensorShardStrategy`'s. - But its all-gather operation is in-place, meaning that no extra buffer is created. - Extra buffer is created when using `torch.distributed.all_gather`. - This can reduce peak memory used in zero-offload. - You should notice that this strategy is highly coupled with zero. - You can not change its communication group and must use ZeroContext to create your model. - """ - - def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None): - if not t.is_sharded: - return - target_device = t.device - payload_numel = t.payload.numel() - world_size = dist.get_world_size(process_group) - rank = dist.get_rank(process_group) - - buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device()) - buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0)) - buffer_list[rank].copy_(t.payload) - - ZeroDist.zero_all_gather(buffer) # notice: process_group is useless here - gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape) - t.reset_payload(gathered_payload) - colo_model_data_tensor_move_inline(t, target_device) - t.is_sharded = False diff --git a/setup.py b/setup.py index c8817c82ab8c..12b12c31dd49 100644 --- a/setup.py +++ b/setup.py @@ -134,12 +134,6 @@ def cuda_ext_helper(name, sources, extra_cuda_flags, extra_cxx_flags=[]): 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags) }) - ext_modules.append( - cuda_ext_helper(name='colossal_zero_comm', - sources=['zero_comm.cpp'], - extra_cuda_flags=['-DUSE_C10D_NCCL'], - extra_cxx_flags=['-DUSE_C10D_NCCL'])) - ext_modules.append( cuda_ext_helper('colossal_C', [ 'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu', diff --git a/tests/test_zero/test_found_inf.py b/tests/test_zero/test_found_inf.py index af1b2e6700b8..34283f5015e1 100644 --- a/tests/test_zero/test_found_inf.py +++ b/tests/test_zero/test_found_inf.py @@ -9,7 +9,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import ZeroTensorShardStrategy +from colossalai.zero.shard_utils import BucketTensorShardStrategy from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_optim import ShardedOptimizerV2 from colossalai.zero.sharded_optim._utils import has_inf_or_nan @@ -20,7 
+20,7 @@ @parameterize("cpu_offload", [True, False]) -@parameterize("shard_strategy_class", [ZeroTensorShardStrategy]) +@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) @parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio): test_models = ['repeated_computed_layers'] diff --git a/tests/test_zero/test_init_context.py b/tests/test_zero/test_init_context.py index 61bec973c1a5..b955e4852a40 100644 --- a/tests/test_zero/test_init_context.py +++ b/tests/test_zero/test_init_context.py @@ -15,14 +15,14 @@ colo_model_mem_usage from colossalai.utils.memory import colo_device_memory_used from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy, ZeroTensorShardStrategy) +from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) from tests.components_to_test.registry import non_distributed_component_funcs from common import CONFIG @parameterize("init_device_type", ['cpu', 'cuda']) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy, ZeroTensorShardStrategy]) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) def run_model_test(init_device_type, shard_strategy_class): logger = get_dist_logger("test_zero_init") diff --git a/tests/test_zero/test_mem_collector.py b/tests/test_zero/test_mem_collector.py index d311e0f378db..bea971935d36 100644 --- a/tests/test_zero/test_mem_collector.py +++ b/tests/test_zero/test_mem_collector.py @@ -8,7 +8,7 @@ from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction from colossalai.zero.init_ctx import ZeroInitContext from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.shard_utils import ZeroTensorShardStrategy +from colossalai.zero.shard_utils import BucketTensorShardStrategy from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use from functools import partial @@ -35,7 +35,7 @@ def run_mem_collector_testing(): fraction = (50 * 1024**2) / cuda_capacity # limit max memory to 50MB colo_set_process_memory_fraction(fraction) - shard_strategy = ZeroTensorShardStrategy() + shard_strategy = BucketTensorShardStrategy() with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True): model = MyTestModel() diff --git a/tests/test_zero/test_shard_model_v2.py b/tests/test_zero/test_shard_model_v2.py index 62b860c236ac..654c82a4671e 100644 --- a/tests/test_zero/test_shard_model_v2.py +++ b/tests/test_zero/test_shard_model_v2.py @@ -10,7 +10,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, ZeroTensorShardStrategy) +from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 from colossalai.zero.sharded_model.utils import col_model_deepcopy @@ -21,7 +21,7 @@ @parameterize("enable_autocast", [True]) -@parameterize("shard_strategy_class", [ZeroTensorShardStrategy, BucketTensorShardStrategy]) +@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) def run_model_test(enable_autocast, shard_strategy_class): 
test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module'] shard_strategy = shard_strategy_class() diff --git a/tests/test_zero/test_state_dict.py b/tests/test_zero/test_state_dict.py index 18d4c983a683..188bc5968da5 100644 --- a/tests/test_zero/test_state_dict.py +++ b/tests/test_zero/test_state_dict.py @@ -11,7 +11,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy, ZeroTensorShardStrategy) +from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model.utils import col_model_deepcopy from tests.components_to_test.registry import non_distributed_component_funcs @@ -19,7 +19,7 @@ from common import CONFIG -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy, ZeroTensorShardStrategy]) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) def run_zero_state_dict(shard_strategy_class): test_models = ['repeated_computed_layers', 'resnet18'] shard_strategy = shard_strategy_class() From 7e5716fbf4cb2f757c311df658cc23d395e619c5 Mon Sep 17 00:00:00 2001 From: jiaruifang Date: Tue, 19 Apr 2022 15:48:19 +0800 Subject: [PATCH 2/8] [gemini] set cpu memory capacity --- colossalai/initialize.py | 3 --- colossalai/utils/__init__.py | 5 +++-- colossalai/utils/memory.py | 28 ++++++++++++++++++++++++++-- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/colossalai/initialize.py b/colossalai/initialize.py index b806356e4ae2..08bd43f62fb1 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -275,9 +275,6 @@ def initialize(model: nn.Module, optimizer_config=optimizer_config) logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0]) - # FIXME() throw a warning if using zero with MP - if gpc.get_world_size(ParallelMode.MODEL) > 1: - logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0]) else: if isinstance(model, nn.Module): # first sync model across dp ranks diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index fa69a33f65e0..6e1720b3d30f 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -7,7 +7,8 @@ param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank, sync_model_param, disposable) from .data_sampler import DataParallelSampler, get_dataloader -from .memory import report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction, colo_device_memory_capacity +from .memory import (report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction, + colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity) from .timer import MultiTimer, Timer from .tensor_detector import TensorDetector @@ -19,5 +20,5 @@ 'report_memory_usage', 'colo_device_memory_capacity', 'colo_device_memory_used', 'colo_set_process_memory_fraction', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'DataParallelSampler', 'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint', - 'ensure_path_exists', 'disposable' + 'ensure_path_exists', 'disposable', 'colo_set_cpu_memory_capacity', 'colo_get_cpu_memory_capacity' ] diff --git a/colossalai/utils/memory.py 
b/colossalai/utils/memory.py index 12c23c3a0299..434e90edd3b9 100644 --- a/colossalai/utils/memory.py +++ b/colossalai/utils/memory.py @@ -11,6 +11,7 @@ from packaging import version _GLOBAL_CUDA_MEM_FRACTION = 1.0 +_GLOBAL_CPU_MEM_CAPACITY = -1 def _bytes_to_MB(val, decimal=2): @@ -106,9 +107,8 @@ def colo_device_memory_capacity(device: torch.device) -> int: """ assert isinstance(device, torch.device) if device.type == 'cpu': - mem_info = _get_cpu_memory_info() # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory. - return mem_info.total / gpc.num_processes_on_current_node + return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node if device.type == 'cuda': return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION @@ -152,3 +152,27 @@ def colo_set_process_memory_fraction(ratio: float) -> None: global _GLOBAL_CUDA_MEM_FRACTION _GLOBAL_CUDA_MEM_FRACTION = ratio torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device()) + + +def colo_set_cpu_memory_capacity(size: int) -> None: + global _GLOBAL_CPU_MEM_CAPACITY + mem_info = _get_cpu_memory_info() + total_size = mem_info.total + if size <= total_size: + _GLOBAL_CPU_MEM_CAPACITY = size + else: + _GLOBAL_CPU_MEM_CAPACITY = total_size + + +def colo_get_cpu_memory_capacity() -> int: + """ + Get the cpu memory capacity. We may not use all of it. + Returns: + int: _description_ + """ + global _GLOBAL_CPU_MEM_CAPACITY + if _GLOBAL_CPU_MEM_CAPACITY == -1: + mem_info = _get_cpu_memory_info() + return mem_info.total + else: + return _GLOBAL_CPU_MEM_CAPACITY From 643bc453eeac4040410e193d11d6c82fdc4f5137 Mon Sep 17 00:00:00 2001 From: jiaruifang Date: Wed, 20 Apr 2022 09:36:51 +0800 Subject: [PATCH 3/8] [log] local throughput collecting --- colossalai/gemini/tensor_placement_policy.py | 4 ++-- colossalai/trainer/hooks/_metric_hook.py | 17 +++++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/colossalai/gemini/tensor_placement_policy.py b/colossalai/gemini/tensor_placement_policy.py index 3e0851c3e6f6..3770fb988fe0 100644 --- a/colossalai/gemini/tensor_placement_policy.py +++ b/colossalai/gemini/tensor_placement_policy.py @@ -48,8 +48,8 @@ def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> N super().__init__(None, mem_stats_collector=mem_stats_collector) # model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase # TODO(ver217): make these args configurable - self._warmup_non_model_data_ratio: float = 0.8 - self._steady_cuda_cap_ratio: float = 0.8 + self._warmup_non_model_data_ratio: float = 0.2 + self._steady_cuda_cap_ratio: float = 0.7 def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py index 001557a35a55..44993bc4d4eb 100644 --- a/colossalai/trainer/hooks/_metric_hook.py +++ b/colossalai/trainer/hooks/_metric_hook.py @@ -124,7 +124,7 @@ def get_accumulated_value(self): def get_last_step_value(self) -> str: """Returns :attr:`last_step_loss`. 
""" - return str(self.last_step_loss) + return str(self.last_step_loss.cpu().item()) @staticmethod def is_better(a, b): @@ -207,7 +207,7 @@ def update(self, logits, targets, batch_size) -> None: def get_last_step_value(self) -> str: self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA) self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA) - return str(_format_number((self.last_step_correct / self.last_step_sum).item())) + return str(_format_number((self.last_step_correct / self.last_step_sum).cpu().item())) def get_accumulated_value(self): self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA) @@ -324,7 +324,7 @@ class ThroughputMetric(Metric): epoch_only (bool): Whether the metric only read for the full epoch. """ - def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflop_per_step: int = 0): + def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflop_per_step: int = 0, use_local: bool = False): super().__init__(epoch_only=epoch_only) self.ignored_steps = ignored_steps self.cur_steps = 0 @@ -333,6 +333,7 @@ def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflop_per_step: int self.last_step_num_samples = torch.zeros(1, device=get_current_device()) self.last_step_used_time = torch.zeros(1, device=get_current_device()) self._tflop_per_step = tflop_per_step + self._use_local = use_local def reset(self) -> None: # self.cur_steps = 0 @@ -350,9 +351,13 @@ def update(self, num_samples, time) -> None: self.accumulated_used_time += self.last_step_used_time def get_last_step_value(self) -> str: - self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \ - gpc.get_world_size(ParallelMode.DATA) - self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA) + if self._use_local: + self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA) + else: + self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \ + gpc.get_world_size(ParallelMode.DATA) + self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA) + sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item()) if self._tflop_per_step > 0: tflops = _format_number(self._tflop_per_step / (self.last_step_used_time.item() + 1e-12)) From 9e754d21154e72558a84010f1999408b2e7846dd Mon Sep 17 00:00:00 2001 From: jiaruifang Date: Wed, 20 Apr 2022 09:41:32 +0800 Subject: [PATCH 4/8] polish --- colossalai/gemini/tensor_placement_policy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/gemini/tensor_placement_policy.py b/colossalai/gemini/tensor_placement_policy.py index 3770fb988fe0..3e0851c3e6f6 100644 --- a/colossalai/gemini/tensor_placement_policy.py +++ b/colossalai/gemini/tensor_placement_policy.py @@ -48,8 +48,8 @@ def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> N super().__init__(None, mem_stats_collector=mem_stats_collector) # model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase # TODO(ver217): make these args configurable - self._warmup_non_model_data_ratio: float = 0.2 - self._steady_cuda_cap_ratio: float = 0.7 + self._warmup_non_model_data_ratio: float = 0.8 + self._steady_cuda_cap_ratio: float = 0.8 def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], From 60029814c331dc784262b4228a6e7fddfe22600d Mon Sep 17 00:00:00 2001 From: jiaruifang Date: Wed, 20 Apr 2022 09:48:32 +0800 Subject: 
[PATCH 5/8] polish --- colossalai/zero/sharded_optim/sharded_optim_v2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py index b5d980e047e1..9f6ee7e03233 100644 --- a/colossalai/zero/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py @@ -122,7 +122,8 @@ def __init__(self, self._register_master_weight() if self.gpu_margin_mem_ratio != 0.0 and not isinstance(sharded_model._tensor_placement_policy, AutoTensorPlacementPolicy): - self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"') + self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"', + ranks=[0]) if self._verbose: self._logger.debug( From 43e4b8e7f7e6d4c3483b5564eb88279921d2f27c Mon Sep 17 00:00:00 2001 From: jiaruifang Date: Wed, 20 Apr 2022 09:54:03 +0800 Subject: [PATCH 6/8] polish --- colossalai/trainer/hooks/_metric_hook.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py index 44993bc4d4eb..dae0841706dc 100644 --- a/colossalai/trainer/hooks/_metric_hook.py +++ b/colossalai/trainer/hooks/_metric_hook.py @@ -387,17 +387,19 @@ class ThroughputHook(MetricHook): depend on the hooks order in the hook list. """ - def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: int = 0): + def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: int = 0, use_local=False): super().__init__(priority) self.ignored_steps = ignored_steps self._tflop_per_step = tflop_per_step + self._use_local = use_local def after_hook_is_attached(self, trainer): self._check_metric_states_initialization(trainer) if self._is_stage_to_compute: self.metric = ThroughputMetric(epoch_only=True, ignored_steps=self.ignored_steps, - tflop_per_step=self._tflop_per_step) + tflop_per_step=self._tflop_per_step, + use_local=self._use_local) # register the metric trainer.states['metrics']['train']['Throughput'] = self.metric From 073e153ef1c2f24916e951dcd46c5a4d6b88bc1f Mon Sep 17 00:00:00 2001 From: jiaruifang Date: Wed, 20 Apr 2022 09:55:32 +0800 Subject: [PATCH 7/8] polish code --- colossalai/trainer/hooks/_metric_hook.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py index dae0841706dc..8aa8dbc278d3 100644 --- a/colossalai/trainer/hooks/_metric_hook.py +++ b/colossalai/trainer/hooks/_metric_hook.py @@ -385,6 +385,8 @@ class ThroughputHook(MetricHook): priority (int, optional): Priority in the printing, hooks with small priority will be printed in front defaults to 10. If different hooks share same priority, the order of printing would depend on the hooks order in the hook list. + tflop_per_step(int, optional): tera floating point operations per step. + use_local (bool, optional): Whether to use local time. 
""" def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: int = 0, use_local=False): From 58ea5d0b37d0fd663fbbaff13c90dc412f90afb6 Mon Sep 17 00:00:00 2001 From: jiaruifang Date: Wed, 20 Apr 2022 10:00:39 +0800 Subject: [PATCH 8/8] polish --- colossalai/trainer/hooks/_metric_hook.py | 2 +- colossalai/zero/__init__.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py index 8aa8dbc278d3..dbca20169e7e 100644 --- a/colossalai/trainer/hooks/_metric_hook.py +++ b/colossalai/trainer/hooks/_metric_hook.py @@ -386,7 +386,7 @@ class ThroughputHook(MetricHook): defaults to 10. If different hooks share same priority, the order of printing would depend on the hooks order in the hook list. tflop_per_step(int, optional): tera floating point operations per step. - use_local (bool, optional): Whether to use local time. + use_local (bool, optional): Whether to use local time for throughput calculation. """ def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: int = 0, use_local=False): diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index 714474ea5286..1ea7c73e36c9 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -23,10 +23,10 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model logger = get_dist_logger('convert_to_zero_v2') - logger.info(f'optimizer_config is {optimizer_config}') + logger.info(f'optimizer_config is {optimizer_config}', ranks=[0]) if optimizer_config is None: optimizer_config = dict() - logger.info(f'model_config is {model_config}') + logger.info(f'model_config is {model_config}', ranks=[0]) if model_config is None: model_config = dict()