
Commit
Refactor distributed_launcher logic and simplify fb build (facebookresearch#185)

Summary:
Pull Request resolved: facebookresearch#185

We are bringing submitit into VISSL (thanks to QuentinDuval), which required moving `launch_distributed` into the core `vissl` library. Some refactoring was needed to make the fblearner workflow adapt accordingly.
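
For orientation, here is a minimal sketch of what calling the relocated launcher looks like, based on the call sites in this diff; the override list and its values are placeholders, not part of this commit:

from hydra.experimental import compose, initialize_config_module
from vissl.hooks import default_hook_generator
from vissl.utils.distributed_launcher import launch_distributed
from vissl.utils.hydra_config import convert_to_attrdict

# Placeholder overrides; any valid VISSL config overrides work here.
overrides = ["config.SLURM.USE_SLURM=false"]
with initialize_config_module(config_module="vissl.config"):
    cfg = compose("defaults", overrides=overrides)
args, config = convert_to_attrdict(cfg)

# Launch training (or feature extraction) on the current node.
launch_distributed(
    cfg=config,
    node_id=args.node_id,
    engine_name=args.engine_name,
    hook_generator=default_hook_generator,
)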

Reviewed By: QuentinDuval

Differential Revision: D26340992

fbshipit-source-id: ca5477f488317fe79bdcb8b224cc9900acf30f49
prigoyal authored and facebook-github-bot committed Feb 10, 2021
1 parent 8e20679 commit e94858c
Showing 10 changed files with 124 additions and 99 deletions.
2 changes: 1 addition & 1 deletion dev/launch_slurm.sh
@@ -41,6 +41,6 @@ export PYTHONPATH="$EXP_ROOT_DIR/:$PYTHONPATH"
python -u "$EXP_ROOT_DIR/tools/run_distributed_engines.py" \
"${CFG[@]}" \
hydra.run.dir="$EXP_ROOT_DIR" \
config.SLURM.ENABLED=true \
config.SLURM.USE_SLURM=true \
config.SLURM.LOG_FOLDER="$EXP_ROOT_DIR" \
config.CHECKPOINT.DIR="$CHECKPOINT_DIR"
2 changes: 2 additions & 0 deletions docs/source/train_resource_setup.rst
@@ -49,6 +49,8 @@ While the more SLURM specific options are located in the "SLURM" configuration b
.. code-block:: yaml
SLURM:
# set to True to use SLURM
USE_SLURM: true
# Name of the job on SLURM
NAME: "vissl"
# Comment of the job on SLURM
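
For illustration, a rough sketch of enabling the renamed flag programmatically, mirroring hydra_main in tools/run_distributed_engines.py further down in this diff; the SLURM job name and log folder values are placeholders, not part of this commit:

from hydra.experimental import compose, initialize_config_module
from vissl.utils.distributed_launcher import launch_distributed_on_slurm
from vissl.utils.hydra_config import convert_to_attrdict
from vissl.utils.slurm import is_submitit_available

overrides = [
    "config.SLURM.USE_SLURM=true",           # renamed from config.SLURM.ENABLED
    "config.SLURM.NAME=vissl",               # placeholder job name
    "config.SLURM.LOG_FOLDER=/tmp/vissl",    # placeholder log folder
]
with initialize_config_module(config_module="vissl.config"):
    cfg = compose("defaults", overrides=overrides)
args, config = convert_to_attrdict(cfg)

# submitit is required to schedule the job on SLURM.
assert is_submitit_available(), "Please 'pip install submitit' to schedule jobs on SLURM"
launch_distributed_on_slurm(cfg=config, engine_name=args.engine_name)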
2 changes: 1 addition & 1 deletion tools/cluster_features_and_label.py
@@ -13,10 +13,10 @@
import faiss
import numpy as np
from hydra.experimental import compose, initialize_config_module
from run_distributed_engines import launch_distributed
from vissl.data import build_dataset
from vissl.hooks import default_hook_generator
from vissl.utils.checkpoint import get_checkpoint_folder
from vissl.utils.distributed_launcher import launch_distributed
from vissl.utils.env import set_env_vars
from vissl.utils.hydra_config import AttrDict, convert_to_attrdict, is_hydra_available
from vissl.utils.io import save_file
2 changes: 1 addition & 1 deletion tools/nearest_neighbor_test.py
@@ -7,11 +7,11 @@

import torch
from hydra.experimental import compose, initialize_config_module
from run_distributed_engines import launch_distributed
from torch import nn
from vissl.hooks import default_hook_generator
from vissl.models.model_helpers import get_trunk_output_feature_names
from vissl.utils.checkpoint import get_checkpoint_folder
from vissl.utils.distributed_launcher import launch_distributed
from vissl.utils.env import set_env_vars
from vissl.utils.hydra_config import (
AttrDict,
35 changes: 27 additions & 8 deletions tools/run_distributed_engines.py
@@ -2,29 +2,48 @@

"""
Wrapper to call torch.distributed.launch to run multi-gpu trainings.
Supports two engines: train and extract_features
Supports SLURM as an option
Supports two engines: train and extract_features.
Supports SLURM as an option. Set config.SLURM.USE_SLURM=true to use slurm.
"""

import sys
from typing import List, Any

from hydra.experimental import initialize_config_module, compose

from vissl.utils.distributed_training import is_submitit_available, launch_on_local_node, launch_on_slurm
from vissl.utils.distributed_launcher import (
launch_distributed,
launch_distributed_on_slurm,
)
from vissl.utils.hydra_config import is_hydra_available, convert_to_attrdict
from vissl.utils.slurm import is_submitit_available


def hydra_main(overrides: List[Any]):
######################################################################################
# DO NOT MOVE THIS IMPORT TO TOP LEVEL: submitit processes will not be initialized
# correctly (MKL_THREADING_LAYER will be set to INTEL instead of GNU)
######################################################################################
from vissl.hooks import default_hook_generator

######################################################################################

print(f"####### overrides: {overrides}")
with initialize_config_module(config_module="vissl.config"):
cfg = compose("defaults", overrides=overrides)
args, config = convert_to_attrdict(cfg)
if config.SLURM.ENABLED:
assert is_submitit_available(), "Please 'pip install submitit' to schedule jobs on SLURM"
launch_on_slurm(engine_name=args.engine_name, config=config)

if config.SLURM.USE_SLURM:
assert (
is_submitit_available()
), "Please 'pip install submitit' to schedule jobs on SLURM"
launch_distributed_on_slurm(engine_name=args.engine_name, cfg=config)
else:
launch_on_local_node(node_id=args.node_id, engine_name=args.engine_name, config=config)
launch_distributed(
cfg=config,
node_id=args.node_id,
engine_name=args.engine_name,
hook_generator=default_hook_generator,
)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion tools/train_svm.py
@@ -8,10 +8,10 @@

import numpy as np
from hydra.experimental import compose, initialize_config_module
from run_distributed_engines import launch_distributed
from vissl.hooks import default_hook_generator
from vissl.models.model_helpers import get_trunk_output_feature_names
from vissl.utils.checkpoint import get_checkpoint_folder
from vissl.utils.distributed_launcher import launch_distributed
from vissl.utils.env import set_env_vars
from vissl.utils.hydra_config import (
AttrDict,
2 changes: 1 addition & 1 deletion tools/train_svm_low_shot.py
@@ -11,11 +11,11 @@
generate_places_low_shot_samples,
)
from hydra.experimental import compose, initialize_config_module
from run_distributed_engines import launch_distributed
from vissl.data import dataset_catalog
from vissl.hooks import default_hook_generator
from vissl.models.model_helpers import get_trunk_output_feature_names
from vissl.utils.checkpoint import get_checkpoint_folder
from vissl.utils.distributed_launcher import launch_distributed
from vissl.utils.env import set_env_vars
from vissl.utils.hydra_config import (
AttrDict,
2 changes: 1 addition & 1 deletion vissl/config/defaults.yaml
@@ -851,7 +851,7 @@ config:
# ----------------------------------------------------------------------------------- #
SLURM:
# Whether or not to run the job on SLURM
ENABLED: false
USE_SLURM: false
# Name of the job on SLURM
NAME: "vissl"
# Comment of the job on SLURM
vissl/utils/distributed_launcher.py
@@ -11,7 +11,6 @@

import torch
from fvcore.common.file_io import PathManager

from vissl.data.dataset_catalog import get_data_files
from vissl.engines.extract_features import extract_main
from vissl.engines.train import train_main
@@ -23,7 +22,7 @@
)
from vissl.utils.env import set_env_vars
from vissl.utils.hydra_config import AttrDict
from vissl.utils.io import cleanup_dir, copy_data_to_local
from vissl.utils.io import cleanup_dir, copy_data_to_local, makedir
from vissl.utils.logger import setup_logging, shutdown_logging
from vissl.utils.misc import get_dist_run_id
from vissl.utils.slurm import get_node_id
@@ -55,53 +54,33 @@ def _cleanup_local_dir(cfg: AttrDict):
cleanup_dir(dest_dir)


def _distributed_worker(
local_rank: int,
def launch_distributed(
cfg: AttrDict,
node_id: int,
dist_run_id: str,
engine_name: str,
checkpoint_path: str,
checkpoint_folder: str,
hook_generator: Callable[[Any], List[ClassyHook]],
):
dist_rank = cfg.DISTRIBUTED.NUM_PROC_PER_NODE * node_id + local_rank
if engine_name == "extract_features":
process_main = extract_main
else:
def process_main(cfg, dist_run_id, local_rank, node_id):
train_main(
cfg,
dist_run_id,
checkpoint_path,
checkpoint_folder,
local_rank=local_rank,
node_id=node_id,
hook_generator=hook_generator,
)
"""
Launch the distributed training across gpus of the current node according to the cfg.
logging.info(
f"Spawning process for node_id: {node_id}, local_rank: {local_rank}, "
f"dist_rank: {dist_rank}, dist_run_id: {dist_run_id}"
)
process_main(cfg, dist_run_id, local_rank=local_rank, node_id=node_id)
If more than 1 node is needed for training, this function should be called on each
of the different nodes, each time with a unique node_id in the range [0..N-1], where N
is the total number of nodes taking part in training.
Alternatively, you can use SLURM or any cluster management system to run this function
for you.
def _launch_distributed(
cfg: AttrDict,
node_id: int,
engine_name: str,
hook_generator: Callable[[Any], List[ClassyHook]],
):
"""
Launch the distributed training across gpus, according to the cfg
Configure the node_id and dist_run_id, and set up the environment variables
Args:
cfg -- VISSL yaml configuration
node_id -- node_id for this node
engine_name -- what engine to run: train or extract_features
hook_generator -- Callback to generate all the ClassyVision hooks for this engine
cfg (AttrDict): VISSL yaml configuration
node_id (int): node_id for this node
engine_name (str): what engine to run: train or extract_features
hook_generator (Callable): Callback to generate all the ClassyVision hooks
for this engine
"""

setup_logging(__name__)
node_id = get_node_id(node_id)
dist_run_id = get_dist_run_id(cfg, cfg.DISTRIBUTED.NUM_NODES)
world_size = cfg.DISTRIBUTED.NUM_NODES * cfg.DISTRIBUTED.NUM_PROC_PER_NODE
@@ -162,25 +141,40 @@ def _launch_distributed(
_cleanup_local_dir(cfg)

logging.info("All Done!")
shutdown_logging()


def launch_on_local_node(node_id: int, engine_name: str, config: AttrDict):
"""
Launch the distributed training on the current node.
def _distributed_worker(
local_rank: int,
cfg: AttrDict,
node_id: int,
dist_run_id: str,
engine_name: str,
checkpoint_path: str,
checkpoint_folder: str,
hook_generator: Callable[[Any], List[ClassyHook]],
):
dist_rank = cfg.DISTRIBUTED.NUM_PROC_PER_NODE * node_id + local_rank
if engine_name == "extract_features":
process_main = extract_main
else:

If more than 1 nodes are needed for training, this function should be called on each of the different nodes, each
time with an unique node_id in the range [0..N-1] if N is the total number of nodes to take part in training.
def process_main(cfg, dist_run_id, local_rank, node_id):
train_main(
cfg,
dist_run_id,
checkpoint_path,
checkpoint_folder,
local_rank=local_rank,
node_id=node_id,
hook_generator=hook_generator,
)

Alternatively, you can use SLURM or any cluster management system to run this function for you.
"""
setup_logging(__name__)
_launch_distributed(
config,
node_id=node_id,
engine_name=engine_name,
hook_generator=default_hook_generator,
logging.info(
f"Spawning process for node_id: {node_id}, local_rank: {local_rank}, "
f"dist_rank: {dist_rank}, dist_run_id: {dist_run_id}"
)
shutdown_logging()
process_main(cfg, dist_run_id, local_rank=local_rank, node_id=node_id)


class _ResumableSlurmJob:
@@ -190,63 +184,60 @@ def __init__(self, engine_name: str, config: AttrDict):

def __call__(self):
import submitit

environment = submitit.JobEnvironment()
node_id = environment.global_rank
master_ip = environment.hostnames[0]
master_port = self.config.SLURM.PORT_ID
self.config.DISTRIBUTED.INIT_METHOD = "tcp"
self.config.DISTRIBUTED.RUN_ID = f"{master_ip}:{master_port}"
launch_on_local_node(
launch_distributed(
cfg=self.config,
node_id=node_id,
engine_name=self.engine_name,
config=self.config,
hook_generator=default_hook_generator,
)

def checkpoint(self):
import submitit

trainer = _ResumableSlurmJob(engine_name=self.engine_name, config=self.config)
return submitit.helpers.DelayedSubmission(trainer)


def launch_on_slurm(engine_name: str, config: AttrDict):
def launch_distributed_on_slurm(cfg: AttrDict, engine_name: str):
"""
Launch a distributed training on SLURM, allocating the nodes and GPUs as described in the configuration, and calls
the function "launch_on_local_node" appropriately on each of the nodes.
Launch a distributed training on SLURM, allocating the nodes and GPUs as described in
the configuration, and call the function "launch_distributed" appropriately on each
of the nodes.
:param engine_name: the name of the engine to run (train or extract_features)
:param config: the configuration of the experiment
Args:
cfg (AttrDict): the configuration of the experiment
engine_name (str): the name of the engine to run (train or extract_features)
"""

# DO NOT REMOVE: submitit processes will not be initialized correctly if numpy is not imported first
import numpy
print(numpy.__version__)

import submitit
log_folder = config.SLURM.LOG_FOLDER

# setup the log folder
log_folder = cfg.SLURM.LOG_FOLDER
makedir(log_folder)
assert PathManager.exists(
log_folder
), f"Specified config.SLURM.LOG_FOLDER={log_folder} doesn't exist"

executor = submitit.AutoExecutor(folder=log_folder)
executor.update_parameters(
name=config.SLURM.NAME,
slurm_comment=config.SLURM.COMMENT,
slurm_partition=config.SLURM.PARTITION,
slurm_constraint=config.SLURM.CONSTRAINT,
timeout_min=config.SLURM.TIME_HOURS * 60,
nodes=config.DISTRIBUTED.NUM_NODES,
cpus_per_task=8 * config.DISTRIBUTED.NUM_PROC_PER_NODE,
name=cfg.SLURM.NAME,
slurm_comment=cfg.SLURM.COMMENT,
slurm_partition=cfg.SLURM.PARTITION,
slurm_constraint=cfg.SLURM.CONSTRAINT,
timeout_min=cfg.SLURM.TIME_HOURS * 60,
nodes=cfg.DISTRIBUTED.NUM_NODES,
cpus_per_task=8 * cfg.DISTRIBUTED.NUM_PROC_PER_NODE,
tasks_per_node=1,
gpus_per_node=config.DISTRIBUTED.NUM_PROC_PER_NODE,
mem_gb=config.SLURM.MEM_GB,
gpus_per_node=cfg.DISTRIBUTED.NUM_PROC_PER_NODE,
mem_gb=cfg.SLURM.MEM_GB,
)
trainer = _ResumableSlurmJob(engine_name=engine_name, config=config)
trainer = _ResumableSlurmJob(engine_name=engine_name, config=cfg)
job = executor.submit(trainer)
print(f"SUBMITTED: {job.job_id}")


def is_submitit_available() -> bool:
"""
Indicates if submitit, the library around SLURM used to run distributed training, is available.
"""
try:
import submitit # NOQA
return True
except ImportError:
return False
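
To make the multi-node contract described in the launch_distributed docstring above concrete, here is a rough two-node sketch; every override value and the NODE_ID environment variable are placeholders, not part of this commit. Each node composes the same configuration and calls launch_distributed with its own node_id:

import os

from hydra.experimental import compose, initialize_config_module
from vissl.hooks import default_hook_generator
from vissl.utils.distributed_launcher import launch_distributed
from vissl.utils.hydra_config import convert_to_attrdict

# Placeholder 2-node setup; the same script runs on every node, only node_id differs.
overrides = [
    "config.DISTRIBUTED.NUM_NODES=2",
    "config.DISTRIBUTED.NUM_PROC_PER_NODE=8",
    "config.DISTRIBUTED.INIT_METHOD=tcp",
    "config.DISTRIBUTED.RUN_ID=node0-hostname:40050",  # placeholder master address:port
]
with initialize_config_module(config_module="vissl.config"):
    cfg = compose("defaults", overrides=overrides)
_, config = convert_to_attrdict(cfg)

# NODE_ID is a placeholder environment variable; under SLURM, get_node_id() resolves it.
node_id = int(os.environ.get("NODE_ID", "0"))
launch_distributed(
    cfg=config,
    node_id=node_id,
    engine_name="train",
    hook_generator=default_hook_generator,
)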
13 changes: 13 additions & 0 deletions vissl/utils/slurm.py
@@ -33,3 +33,16 @@ def get_slurm_dir(input_dir: str):
if "SLURM_JOBID" in os.environ:
output_dir = f"{input_dir}/{os.environ['SLURM_JOBID']}"
return output_dir


def is_submitit_available() -> bool:
"""
Indicates if submitit, the library around SLURM used to run distributed training, is
available.
"""
try:
import submitit # NOQA

return True
except ImportError:
return False
