Migrate Trainer from Repository to upload_folder (huggingface#25095)
* First draft

* Deal with progress bars

* Update src/transformers/utils/hub.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Address review comments

* Forgot one

* Pin hf_hub

* Add argument for push all and fix tests

* Fix tests

* Address review comments

---------

Co-authored-by: Lucain <lucainp@gmail.com>
sgugger and Wauplin authored Aug 7, 2023
1 parent c177606 commit baf1daa
Showing 7 changed files with 171 additions and 91 deletions.
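In short, this change replaces the git-clone based `Repository` workflow with direct HTTP uploads from `huggingface_hub`. A minimal sketch of the core call the new Trainer code is built around follows; the repo id and folder path are placeholders, not values taken from this diff, and a Hub token is assumed to be configured via `huggingface-cli login`:

```python
from huggingface_hub import upload_folder

# Instead of committing and pushing through a local git clone, the folder is
# uploaded over HTTP. With run_as_future=True the call returns a
# concurrent.futures.Future so the training loop is not blocked.
future = upload_folder(
    repo_id="username/my-finetuned-model",  # placeholder for args.hub_model_id
    folder_path="outputs",                  # placeholder for args.output_dir
    commit_message="Training in progress",
    run_as_future=True,
)

future.result()  # block only when the upload actually has to be finished
```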
2 changes: 1 addition & 1 deletion setup.py
@@ -120,7 +120,7 @@
"fugashi>=1.0",
"GitPython<3.1.19",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.14.1,<1.0",
"huggingface-hub>=0.15.1,<1.0",
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"isort>=5.5.4",
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
@@ -25,7 +25,7 @@
"fugashi": "fugashi>=1.0",
"GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.14.1,<1.0",
"huggingface-hub": "huggingface-hub>=0.15.1,<1.0",
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4",
182 changes: 116 additions & 66 deletions src/transformers/trainer.py
@@ -44,10 +44,11 @@

# isort: on

import huggingface_hub.utils as hf_hub_utils
import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import Repository, create_repo
from huggingface_hub import Repository, create_repo, upload_folder
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
@@ -127,6 +128,7 @@
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
PushInProgress,
can_return_loss,
find_labels,
is_accelerate_available,
@@ -548,15 +550,10 @@ def __init__(
# Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
self._loggers_initialized = False

# Create clone of distant repo and output directory if needed
# Create distant repo and output directory if needed
self.hub_model_id = None
if self.args.push_to_hub:
self.init_git_repo(at_init=True)
# In case of pull, we need to make sure every process has the latest.
if is_torch_tpu_available():
xm.rendezvous("init git repo")
elif args.parallel_mode == ParallelMode.DISTRIBUTED:
dist.barrier()

self.init_hf_repo()
if self.args.should_save:
os.makedirs(self.args.output_dir, exist_ok=True)
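The TPU rendezvous and distributed barrier removed here are no longer needed: the new `init_hf_repo` only calls `create_repo` with `exist_ok=True`, an idempotent HTTP request that the main process can perform on its own, with no local clone for other processes to wait on. A hedged illustration, with a placeholder repo name:

```python
from huggingface_hub import create_repo

# Safe to call repeatedly: if the repo already exists, exist_ok=True returns
# its URL instead of raising, so no cross-process synchronization is required.
repo_url = create_repo("my-finetuned-model", private=False, exist_ok=True)
print(repo_url.repo_id)  # e.g. "username/my-finetuned-model"
```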

@@ -1531,12 +1528,25 @@ def train(
inner_training_loop = find_executable_batch_size(
self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
)
return inner_training_loop(
args=args,
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
)
if args.push_to_hub:
try:
# Disable progress bars when uploading models during checkpoints to avoid polluting stdout
hf_hub_utils.disable_progress_bars()
return inner_training_loop(
args=args,
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
)
finally:
hf_hub_utils.enable_progress_bars()
else:
return inner_training_loop(
args=args,
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
)

def _inner_training_loop(
self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
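The try/finally added above wraps the whole training loop so that background checkpoint uploads do not pollute stdout with progress bars. The same pattern can be factored into a small helper; a sketch using the `huggingface_hub.utils` functions imported at the top of the file:

```python
import huggingface_hub.utils as hf_hub_utils

def run_without_hub_progress_bars(fn, *args, **kwargs):
    # Hypothetical helper: silence huggingface_hub progress bars while `fn`
    # runs, then restore them even if `fn` raises.
    hf_hub_utils.disable_progress_bars()
    try:
        return fn(*args, **kwargs)
    finally:
        hf_hub_utils.enable_progress_bars()
```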
@@ -1968,6 +1978,9 @@ def _inner_training_loop(

self.control = self.callback_handler.on_train_end(args, self.state, self.control)

# Wait for the checkpoint to be uploaded.
self._finish_current_push()

return TrainOutput(self.state.global_step, train_loss, metrics)

def _get_output_dir(self, trial):
@@ -3386,16 +3399,43 @@ def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
else:
return 0

def init_hf_repo(self):
"""
Initializes a repo on the Hugging Face Hub in `self.args.hub_model_id`.
"""
# Only on process zero
if not self.is_world_process_zero():
return

if self.args.hub_model_id is None:
repo_name = Path(self.args.output_dir).absolute().name
else:
repo_name = self.args.hub_model_id

repo_url = create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True)
self.hub_model_id = repo_url.repo_id
self.push_in_progress = None

def init_git_repo(self, at_init: bool = False):
"""
Initializes a git repo in `self.args.hub_model_id`.
<Tip warning={true}>
This function is deprecated and will be removed in v4.34.0 of Transformers.
</Tip>
Args:
at_init (`bool`, *optional*, defaults to `False`):
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
`True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
out.
"""
warnings.warn(
"`Trainer.init_git_repo` is deprecated and will be removed in v4.34.0 of Transformers. Use "
"`Trainer.init_hf_repo` instead."
)
if not self.is_world_process_zero():
return

@@ -3493,8 +3533,8 @@ def _push_from_checkpoint(self, checkpoint_folder):
# Only push from one node.
if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
return
# If we haven't finished the last push, we don't do this one.
if self.push_in_progress is not None and not self.push_in_progress.is_done:
# If we haven't finished the last push, we don't do this one unless args.hub_always_push=True.
if not self.args.hub_always_push and self.push_in_progress is not None and not self.push_in_progress.is_done():
return

output_dir = self.args.output_dir
@@ -3511,34 +3551,51 @@ def _push_from_checkpoint(self, checkpoint_folder):
# Same for the training arguments
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

try:
if self.args.hub_strategy == HubStrategy.CHECKPOINT:
# Temporarily move the checkpoint just saved for the push
tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
# We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
# subfolder.
if os.path.isdir(tmp_checkpoint):
shutil.rmtree(tmp_checkpoint)
shutil.move(checkpoint_folder, tmp_checkpoint)

if self.args.save_strategy == IntervalStrategy.STEPS:
commit_message = f"Training in progress, step {self.state.global_step}"
else:
commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
push_work = self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)
# Return type of `Repository.push_to_hub` is either None or a tuple.
if push_work is not None:
self.push_in_progress = push_work[1]
except Exception as e:
logger.error(f"Error when pushing to hub: {e}")
finally:
if self.args.hub_strategy == HubStrategy.CHECKPOINT:
# Move back the checkpoint to its place
shutil.move(tmp_checkpoint, checkpoint_folder)
if self.args.save_strategy == IntervalStrategy.STEPS:
commit_message = f"Training in progress, step {self.state.global_step}"
else:
commit_message = f"Training in progress, epoch {int(self.state.epoch)}"

model_push_job = upload_folder(
repo_id=self.hub_model_id,
folder_path=output_dir,
commit_message=commit_message,
token=self.args.hub_token,
run_as_future=True,
ignore_patterns=["_*", "**/*"],
)

push_jobs = [model_push_job]

if self.args.hub_strategy in [HubStrategy.CHECKPOINT, HubStrategy.ALL_CHECKPOINTS]:
path_in_repo = (
"last-checkpoint" if self.args.hub_strategy == HubStrategy.CHECKPOINT else Path(checkpoint_folder).name
)
checkpoint_push = upload_folder(
repo_id=self.hub_model_id,
folder_path=checkpoint_folder,
path_in_repo=path_in_repo,
commit_message=commit_message + ", checkpoint",
token=self.args.hub_token,
run_as_future=True,
)
push_jobs.append(checkpoint_push)

if self.push_in_progress is None or self.push_in_progress.is_done():
self.push_in_progress = PushInProgress(push_jobs)
else:
self.push_in_progress.jobs.extend(push_jobs)

def _finish_current_push(self):
if not hasattr(self, "push_in_progress"):
return
if self.push_in_progress is not None and not self.push_in_progress.is_done():
logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.")
self.push_in_progress.wait_until_done()

def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
Parameters:
commit_message (`str`, *optional*, defaults to `"End of training"`):
@@ -3549,21 +3606,20 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
Additional keyword arguments passed along to [`~Trainer.create_model_card`].
Returns:
The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of
the commit and an object to track the progress of the commit if `blocking=True`
The URL of the repository where the model was pushed if `blocking=True`, or a `Future` object tracking the
progress of the commit if `blocking=False`.
"""
# If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
# it might fail.
if not hasattr(self, "repo"):
self.init_git_repo()

model_name = kwargs.pop("model_name", None)
if model_name is None and self.args.should_save:
if self.args.hub_model_id is None:
model_name = Path(self.args.output_dir).name
else:
model_name = self.args.hub_model_id.split("/")[-1]

# In case the user calls this method with args.push_to_hub = False
if self.hub_model_id is None:
self.init_hf_repo()

# Needs to be executed on all processes for TPU training, but will only save on the process determined by
# self.args.should_save.
self.save_model(_internal_call=True)
@@ -3572,25 +3628,19 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
if not self.is_world_process_zero():
return

# Cancel any async push in progress if blocking=True. The commits will all be pushed together.
if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
self.push_in_progress._process.kill()
self.push_in_progress = None
self.create_model_card(model_name=model_name, **kwargs)

git_head_commit_url = self.repo.push_to_hub(
commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
)
# push separately the model card to be independant from the rest of the model
if self.args.should_save:
self.create_model_card(model_name=model_name, **kwargs)
try:
self.repo.push_to_hub(
commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
)
except EnvironmentError as exc:
logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
# Wait for the current upload to be finished.
self._finish_current_push()

return git_head_commit_url
return upload_folder(
repo_id=self.hub_model_id,
folder_path=self.args.output_dir,
commit_message=commit_message,
token=self.args.hub_token,
run_as_future=not blocking,
ignore_patterns=["_*", "**/*"],
)

#
# Deprecated code
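Taken together, `_push_from_checkpoint` now issues one or two non-blocking `upload_folder` calls per save and keeps their futures until the end of training. A condensed sketch of that flow; the repo id, folder names, and step number are placeholders:

```python
from huggingface_hub import upload_folder

commit_message = "Training in progress, step 500"

# Push the top-level output files only: ignore_patterns skips every subfolder
# ("**/*") and anything starting with an underscore ("_*").
model_push = upload_folder(
    repo_id="username/my-model",
    folder_path="outputs",
    commit_message=commit_message,
    run_as_future=True,
    ignore_patterns=["_*", "**/*"],
)

# With hub_strategy="checkpoint", the latest checkpoint folder is also pushed
# under a fixed "last-checkpoint" path so training can be resumed from the Hub.
checkpoint_push = upload_folder(
    repo_id="username/my-model",
    folder_path="outputs/checkpoint-500",
    path_in_repo="last-checkpoint",
    commit_message=commit_message + ", checkpoint",
    run_as_future=True,
)

# The Trainer groups these futures in a PushInProgress object and waits on them
# in _finish_current_push() before train() returns.
for job in (model_push, checkpoint_push):
    job.result()
```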
11 changes: 11 additions & 0 deletions src/transformers/training_args.py
@@ -568,6 +568,8 @@ class TrainingArguments:
`huggingface-cli login`.
hub_private_repo (`bool`, *optional*, defaults to `False`):
If True, the Hub repo will be set to private.
hub_always_push (`bool`, *optional*, defaults to `False`):
Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.
gradient_checkpointing (`bool`, *optional*, defaults to `False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
@@ -1110,6 +1112,10 @@ class TrainingArguments:
)
hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
hub_private_repo: bool = field(default=False, metadata={"help": "Whether the model repository is private or not."})
hub_always_push: bool = field(
default=False,
metadata={"help": "Unless `True`, the Trainer will skip pushes if the previous one wasn't finished yet."},
)
gradient_checkpointing: bool = field(
default=False,
metadata={
@@ -2367,6 +2373,7 @@ def set_push_to_hub(
strategy: Union[str, HubStrategy] = "every_save",
token: Optional[str] = None,
private_repo: bool = False,
always_push: bool = False,
):
"""
A method that regroups all arguments linked to synchronizing checkpoints with the Hub.
@@ -2407,6 +2414,9 @@
with `huggingface-cli login`.
private_repo (`bool`, *optional*, defaults to `False`):
If True, the Hub repo will be set to private.
always_push (`bool`, *optional*, defaults to `False`):
Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not
finished.
Example:
@@ -2424,6 +2434,7 @@
self.hub_strategy = HubStrategy(strategy)
self.hub_token = token
self.hub_private_repo = private_repo
self.hub_always_push = always_push
return self

def set_optimizer(
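A usage sketch of the new `hub_always_push` option, both through the constructor and through the `set_push_to_hub` setter modified above; the output directory and repo id are placeholders:

```python
from transformers import TrainingArguments

# Constructor form: push every saved checkpoint, even if the previous
# background push has not finished yet.
args = TrainingArguments(
    output_dir="outputs",
    push_to_hub=True,
    hub_strategy="checkpoint",
    hub_always_push=True,
)

# Setter form, equivalent for the Hub-related fields.
args = TrainingArguments(output_dir="outputs").set_push_to_hub(
    "username/my-model", strategy="checkpoint", always_push=True
)
```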
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -70,6 +70,7 @@
TRANSFORMERS_CACHE,
TRANSFORMERS_DYNAMIC_MODULE_NAME,
EntryNotFoundError,
PushInProgress,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
24 changes: 24 additions & 0 deletions src/transformers/utils/hub.py
@@ -22,6 +22,7 @@
import tempfile
import traceback
import warnings
from concurrent import futures
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
@@ -1177,6 +1178,29 @@ def move_cache(cache_dir=None, new_cache_dir=None, token=None):
)


class PushInProgress:
"""
Internal class to keep track of a push in progress (which might contain multiple `Future` jobs).
"""

def __init__(self, jobs: Optional[futures.Future] = None) -> None:
self.jobs = [] if jobs is None else jobs

def is_done(self):
return all(job.done() for job in self.jobs)

def wait_until_done(self):
futures.wait(self.jobs)

def cancel(self) -> None:
self.jobs = [
job
for job in self.jobs
# Cancel the job if it wasn't started yet and remove cancelled/done jobs from the list
if not (job.cancel() or job.done())
]


cache_version_file = os.path.join(TRANSFORMERS_CACHE, "version.txt")
if not os.path.isfile(cache_version_file):
cache_version = 0
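`PushInProgress` is a thin wrapper around a list of `concurrent.futures.Future` objects. A quick illustration with plain thread-pool futures standing in for real uploads, assuming the `PushInProgress` export added to `utils/__init__.py` above:

```python
import time
from concurrent.futures import ThreadPoolExecutor

from transformers.utils import PushInProgress

executor = ThreadPoolExecutor(max_workers=2)
jobs = [executor.submit(time.sleep, 1), executor.submit(time.sleep, 2)]

push = PushInProgress(jobs)
print(push.is_done())   # most likely False right after submission
push.wait_until_done()  # blocks until every tracked future has completed
print(push.is_done())   # True
push.cancel()           # finished jobs are simply dropped from the list
```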