Fix for dist not being initialized when constructing main config #3324

Merged · 7 commits · Apr 20, 2023
Changes from all commits
9 changes: 9 additions & 0 deletions deepspeed/__init__.py
@@ -15,6 +15,7 @@
from . import ops
from . import module_inject

from .accelerator import get_accelerator
from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from .runtime.hybrid_engine import DeepSpeedHybridEngine
@@ -50,6 +51,9 @@ def _parse_version(version_str):
__git_hash__ = git_hash
__git_branch__ = git_branch

# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
dist = None


def initialize(args=None,
               model: torch.nn.Module = None,
@@ -119,6 +123,11 @@ def initialize(args=None,

    assert model is not None, "deepspeed.initialize requires a model"

    global dist
    from deepspeed import comm as dist
    dist_backend = get_accelerator().communication_backend_name()
    dist.init_distributed(dist_backend=dist_backend, dist_init_required=dist_init_required)

    # Set config using config_params for backwards compat
    if config is None and config_params is not None:
        config = config_params
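With this change, deepspeed.initialize() sets up the module-level dist handle and calls dist.init_distributed() before the DeepSpeed config is constructed, so world-size-dependent checks (such as batch-size validation) see a live process group. A minimal usage sketch, assuming a typical launch via the deepspeed launcher; SimpleModel and the config values below are illustrative and not part of this PR:

import torch
import deepspeed

class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

# train_batch_size is validated against the world size, which is why the
# process group now has to exist before the config object is built.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

model = SimpleModel()
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)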
45 changes: 1 addition & 44 deletions deepspeed/comm/__init__.py
@@ -3,48 +3,5 @@

# DeepSpeed Team

import torch
from .utils import *
from deepspeed import utils

supported_torch_version = False

# See more details at: https://github.com/pytorch/pytorch/pull/48767
# The PG API in torch versions older than 1.8 is different, so it is
# non-trivial to support both in the same API. We will just use the
# DS comm. backend in deepspeed/comm/comm.py if the torch version is 1.8+.

if older_torch():
    # Add custom deepspeed torch comm functions here since we can't import deepspeed.comm
    # NOTE: We can't call torch.distributed directly here. Current hack is to import functions before calling them.
    supported_torch_version = False
    from torch.distributed import *

    def get_world_group():
        return group.WORLD

    def get_global_rank(group, group_rank):
        if hasattr(torch.distributed.distributed_c10d, "get_global_rank"):
            from torch.distributed.distributed_c10d import get_global_rank as _get_global_rank
        else:
            from torch.distributed.distributed_c10d import _get_global_rank
        return _get_global_rank(group, group_rank)

    def allgather_fn(output_tensor, input_tensor, group=None, async_op=False):
        from torch.distributed import all_gather, get_world_size
        from torch import chunk
        output_tensors = list(chunk(output_tensor, get_world_size(group)))
        return all_gather(output_tensors, input_tensor, group=group, async_op=async_op)

    def reduce_scatter_fn(output_tensor, input_tensor, group=None, async_op=False):
        from torch.distributed import reduce_scatter, get_world_size
        from torch import chunk
        input_tensor_lst = list(chunk(input_tensor, get_world_size(group)))
        return reduce_scatter(output_tensor, input_tensor_lst, group=group)

    def configure(deepspeed_config=None, enabled=None, prof_all=None, prof_ops=None, verbose=None):
        utils.logger.warn("Communication logging is not supported in torch versions older than 1.8")

else:
    supported_torch_version = True
    from .comm import *
from .comm import *
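Since the torch<1.8 shim above is gone, importing deepspeed.comm now always yields the wrappers defined in deepspeed/comm/comm.py. A small sketch of direct use, assuming the usual launcher-provided environment (RANK, WORLD_SIZE, etc.); the explicit "nccl" backend string here is illustrative, since deepspeed.initialize() normally picks the backend via get_accelerator().communication_backend_name():

import torch
import deepspeed.comm as dist

dist.init_distributed(dist_backend="nccl")   # no torch<1.8 fallback path anymore
rank = dist.get_rank()
world_size = dist.get_world_size()

t = torch.ones(1).cuda()
dist.all_reduce(t)                           # wrapper from deepspeed/comm/comm.py
print(f"rank {rank}/{world_size}: {t.item()}")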
1 change: 0 additions & 1 deletion deepspeed/comm/torch.py
@@ -21,7 +21,6 @@ class TorchBackend(Backend):

    def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
        super(TorchBackend, self).__init__()
        self.torch_version_before_18 = older_torch()
        self.all_gather_function = self.get_all_gather_function()
        self.reduce_scatter_function = self.get_reduce_scatter_function()
        self.initialized = True
15 changes: 0 additions & 15 deletions deepspeed/comm/utils.py
@@ -4,25 +4,10 @@
# DeepSpeed Team

import os
import torch
import inspect
from deepspeed.utils import get_caller_func


def older_torch():
    '''
    Helper to lookup torch version. For versions less than 1.8, torch.dist
    used torch.distributed.group.WORLD as the default group argument instead of None.
    See more details at: https://github.com/pytorch/pytorch/pull/48767
    '''
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
        return True
    else:
        return False


def get_local_rank_from_launcher():

    # DeepSpeed launcher will set it so get from there
22 changes: 1 addition & 21 deletions deepspeed/runtime/engine.py
@@ -21,6 +21,7 @@

import deepspeed

from deepspeed import comm as dist
from deepspeed.runtime.utils import see_memory_usage, DummyOptim
from .zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
@@ -96,9 +97,6 @@

from deepspeed.runtime.config import DtypeEnum

# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
dist = None

MEMORY_OPT_ALLREDUCE_SIZE = 500000000

DeepSpeedOptimizerCallable = \
@@ -232,8 +230,6 @@ def __init__(

        self.checkpoint_engine = None

        global dist
        from deepspeed import comm as dist
        self._is_gradient_accumulation_boundary = None
        self.scale_wrt_gas = None

@@ -243,22 +239,6 @@
        # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
        self.param_names = {param: name for name, param in model.named_parameters()}

        from deepspeed.comm import supported_torch_version
        # This supported_torch_version check is for torch1.2 compatibility only
        if supported_torch_version:
            dist.init_distributed(dist_backend=self.dist_backend, dist_init_required=dist_init_required)
        else:
            if dist_init_required is None:
                dist_init_required = not dist.is_initialized()

            if dist_init_required is False:
                assert (
                    dist.is_initialized() is True
                ), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
            else:
                if not dist.is_initialized():
                    dist.init_process_group(backend=self.dist_backend)

        self._do_args_sanity_check(args)
        self._configure_with_arguments(args, mpu)
        self._do_sanity_check()
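With the block above deleted, DeepSpeedEngine.__init__ no longer decides how to bring up the process group; it relies on deepspeed.initialize() (or the caller) having done so. Code that needs the group before initialize() can keep the existing pattern; a hedged sketch, assuming deepspeed.init_distributed() remains safe to call before the later call inside initialize():

import deepspeed
import deepspeed.comm as dist

# Bring up the process group early, e.g. to build rank-aware dataloaders.
deepspeed.init_distributed()
rank = dist.get_rank()

# ... construct rank-dependent data pipeline here ...

# deepspeed.initialize() will call dist.init_distributed() again internally,
# which should simply reuse the already-initialized group.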