diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index 12f26b1927af..255dacdccf6e 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -15,6 +15,7 @@
 from . import ops
 from . import module_inject
 
+from .accelerator import get_accelerator
 from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
 from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
 from .runtime.hybrid_engine import DeepSpeedHybridEngine
@@ -50,6 +51,9 @@ def _parse_version(version_str):
 __git_hash__ = git_hash
 __git_branch__ = git_branch
 
+# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
+dist = None
+
 
 def initialize(args=None,
                model: torch.nn.Module = None,
@@ -119,6 +123,11 @@ def initialize(args=None,
 
     assert model is not None, "deepspeed.initialize requires a model"
 
+    global dist
+    from deepspeed import comm as dist
+    dist_backend = get_accelerator().communication_backend_name()
+    dist.init_distributed(dist_backend=dist_backend, dist_init_required=dist_init_required)
+
     # Set config using config_params for backwards compat
     if config is None and config_params is not None:
         config = config_params
diff --git a/deepspeed/comm/__init__.py b/deepspeed/comm/__init__.py
index 82c86113fd30..2c26ef491308 100644
--- a/deepspeed/comm/__init__.py
+++ b/deepspeed/comm/__init__.py
@@ -3,48 +3,5 @@
 
 # DeepSpeed Team
 
-import torch
 from .utils import *
-from deepspeed import utils
-
-supported_torch_version = False
-
-# See more details at: https://github.com/pytorch/pytorch/pull/48767
-# The PG API in torch versions lesser than 1.8 are different so it is
-# non-trivial to support both in the same API. We will just use the
-# DS comm. backend in deepspeed/comm/comm.py if torch version if 1.8+.
-
-if older_torch():
-    # Add custom deepspeed torch comm functions here since we can't import deepspeed.comm
-    # NOTE: We can't call torch.distributed directly here. Current hack is to import functions before calling them.
-    supported_torch_version = False
-    from torch.distributed import *
-
-    def get_world_group():
-        return group.WORLD
-
-    def get_global_rank(group, group_rank):
-        if hasattr(torch.distributed.distributed_c10d, "get_global_rank"):
-            from torch.distributed.distributed_c10d import get_global_rank as _get_global_rank
-        else:
-            from torch.distributed.distributed_c10d import _get_global_rank
-        return _get_global_rank(group, group_rank)
-
-    def allgather_fn(output_tensor, input_tensor, group=None, async_op=False):
-        from torch.distributed import all_gather, get_world_size
-        from torch import chunk
-        output_tensors = list(chunk(output_tensor, get_world_size(group)))
-        return all_gather(output_tensors, input_tensor, group=group, async_op=async_op)
-
-    def reduce_scatter_fn(output_tensor, input_tensor, group=None, async_op=False):
-        from torch.distributed import reduce_scatter, get_world_size
-        from torch import chunk
-        input_tensor_lst = list(chunk(input_tensor, get_world_size(group)))
-        return reduce_scatter(output_tensor, input_tensor_lst, group=group)
-
-    def configure(deepspeed_config=None, enabled=None, prof_all=None, prof_ops=None, verbose=None):
-        utils.logger.warn("Communication logging is not supported in torch versions older than 1.8")
-
-else:
-    supported_torch_version = True
-    from .comm import *
+from .comm import *
diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
index 18e18500a4e9..496fb0d9615e 100644
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -21,7 +21,6 @@ class TorchBackend(Backend):
 
     def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
         super(TorchBackend, self).__init__()
-        self.torch_version_before_18 = older_torch()
        self.all_gather_function = self.get_all_gather_function()
         self.reduce_scatter_function = self.get_reduce_scatter_function()
         self.initialized = True
diff --git a/deepspeed/comm/utils.py b/deepspeed/comm/utils.py
index 7978918e9d01..27a4d2c4a588 100644
--- a/deepspeed/comm/utils.py
+++ b/deepspeed/comm/utils.py
@@ -4,25 +4,10 @@
 # DeepSpeed Team
 
 import os
-import torch
 import inspect
 from deepspeed.utils import get_caller_func
 
 
-def older_torch():
-    '''
-    Helper to lookup torch version. For versions less than 1.8, torch.dist
-    used torch.distributed.group.WORLD as the default group argument instead of None.
-    See more details at: https://github.com/pytorch/pytorch/pull/48767
-    '''
-    TORCH_MAJOR = int(torch.__version__.split('.')[0])
-    TORCH_MINOR = int(torch.__version__.split('.')[1])
-    if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
-        return True
-    else:
-        return False
-
-
 def get_local_rank_from_launcher():
 
     # DeepSpeed launcher will set it so get from there
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index f290ab179caf..f073a0e61695 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -21,6 +21,7 @@
 
 import deepspeed
 
+from deepspeed import comm as dist
 from deepspeed.runtime.utils import see_memory_usage, DummyOptim
 from .zero.offload_config import OffloadDeviceEnum
 from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
@@ -96,9 +97,6 @@
 
 from deepspeed.runtime.config import DtypeEnum
 
-# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
-dist = None
-
 MEMORY_OPT_ALLREDUCE_SIZE = 500000000
 
 DeepSpeedOptimizerCallable = \
@@ -232,8 +230,6 @@ def __init__(
 
         self.checkpoint_engine = None
 
-        global dist
-        from deepspeed import comm as dist
         self._is_gradient_accumulation_boundary = None
         self.scale_wrt_gas = None
 
@@ -243,22 +239,6 @@
         # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
         self.param_names = {param: name for name, param in model.named_parameters()}
 
-        from deepspeed.comm import supported_torch_version
-        # This supported_torch_version check is for torch1.2 compatibility only
-        if supported_torch_version:
-            dist.init_distributed(dist_backend=self.dist_backend, dist_init_required=dist_init_required)
-        else:
-            if dist_init_required is None:
-                dist_init_required = not dist.is_initialized()
-
-            if dist_init_required is False:
-                assert (
-                    dist.is_initialized() is True
-                ), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
-            else:
-                if not dist.is_initialized():
-                    dist.init_process_group(backend=self.dist_backend)
-
         self._do_args_sanity_check(args)
         self._configure_with_arguments(args, mpu)
         self._do_sanity_check()
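
Reviewer note: a minimal caller-side sketch of the flow after this patch (the model, config values, and batch size below are illustrative placeholders, not part of the patch; the script is assumed to run under the deepspeed launcher). deepspeed.initialize() now brings up the communication backend itself via deepspeed.comm.init_distributed(), picking the backend from get_accelerator().communication_backend_name(), so callers no longer rely on the torch-version-dependent fallback that DeepSpeedEngine.__init__ used to carry:

    import torch
    import deepspeed

    model = torch.nn.Linear(8, 8)            # placeholder model
    ds_config = {"train_batch_size": 8}      # minimal placeholder config

    # No explicit torch.distributed.init_process_group() or deepspeed.init_distributed()
    # call is needed before initialize(); it initializes deepspeed.comm internally and
    # also rebinds the module-level deepspeed.dist handle added in deepspeed/__init__.py.
    engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                   model_parameters=model.parameters(),
                                                   config=ds_config)

    # deepspeed.comm collectives are usable immediately after initialize().
    from deepspeed import comm as dist
    print(f"rank {dist.get_rank()} of {dist.get_world_size()}")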