Check the weights exist at the very beginning of training (facebookresearch#206)

Summary:
Pull Request resolved: facebookresearch#206

mbaroni shared feedback that in VISSL, if the data loading step takes a long time and only afterwards the weights file is found to be missing, it's quite frustrating (understandably). Instead, we should check that the weights file exists before starting any training or data loading.
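
A minimal sketch of the intended ordering (illustrative only; it assumes the PathManager import path and config keys used in the diff below, and the helper name check_weights_exist is hypothetical):

    from fvcore.common.file_io import PathManager  # assumed import path for VISSL's PathManager

    def check_weights_exist(cfg, checkpoint_path):
        # Run this before any data copying or training starts. Only validate
        # PARAMS_FILE when there is no checkpoint to resume from, so a resumed
        # run always takes precedence over the weight init.
        params_file = cfg["MODEL"]["WEIGHTS_INIT"]["PARAMS_FILE"]
        if checkpoint_path is None and params_file:
            assert PathManager.exists(params_file), "Specified PARAMS_FILE does NOT exist"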

Reviewed By: QuentinDuval

Differential Revision: D26726411

fbshipit-source-id: f0881da1ffe69e056395ca3508d153d40f85a87c
prigoyal authored and facebook-github-bot committed Mar 1, 2021
1 parent 962be15 commit 3e85241
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions vissl/utils/distributed_launcher.py
@@ -84,16 +84,19 @@ def launch_distributed(
node_id = get_node_id(node_id)
dist_run_id = get_dist_run_id(cfg, cfg.DISTRIBUTED.NUM_NODES)
world_size = cfg.DISTRIBUTED.NUM_NODES * cfg.DISTRIBUTED.NUM_PROC_PER_NODE

# set the environment variables including local rank, node id etc.
set_env_vars(local_rank=0, node_id=node_id, cfg=cfg)
_copy_to_local(cfg)

    # given the checkpoint folder, check whether a final checkpoint already exists;
    # if it does and the user is not overriding to ignore it, training is
    # considered finished
checkpoint_folder = get_checkpoint_folder(cfg)
if is_training_finished(cfg, checkpoint_folder=checkpoint_folder):
logging.info(f"Training already succeeded on node: {node_id}, exiting.")
return

# Get the checkpoint where to load from. The load_checkpoints function will
# Get the checkpoint where to resume from. The get_resume_checkpoint function will
# automatically take care of detecting whether it's a resume or not.
symlink_checkpoint_path = f"{checkpoint_folder}/checkpoint.torch"
if cfg.CHECKPOINT.USE_SYMLINK_CHECKPOINT_FOR_RESUME and PathManager.exists(
@@ -105,6 +108,18 @@ def launch_distributed(
cfg, checkpoint_folder=checkpoint_folder
)

# assert that if the user set the PARAMS_FILE, it must exist and be valid.
# we only use the PARAMS_FILE init if the checkpoint doesn't exist for the
# given training. This ensures that if the same training resumes, then it
# resumes from the checkpoint and not the weight init
if checkpoint_path is None and cfg["MODEL"]["WEIGHTS_INIT"]["PARAMS_FILE"]:
assert PathManager.exists(
cfg["MODEL"]["WEIGHTS_INIT"]["PARAMS_FILE"]
), "Specified PARAMS_FILE does NOT exist"

# copy the data to local if user wants. This can speed up dataloading.
_copy_to_local(cfg)

try:
if world_size > 1:
torch.multiprocessing.spawn(
