forked from pytorch/torchtune
-
Notifications
You must be signed in to change notification settings - Fork 0
/
download.py
155 lines (140 loc) · 6.25 KB
/
download.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
import textwrap
from pathlib import Path
from typing import Literal, Union
from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
from torchtune._cli.subcommand import Subcommand
class Download(Subcommand):
    """Holds all the logic for the `tune download` subcommand."""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self._parser = subparsers.add_parser(
            "download",
            prog="tune download",
            usage="tune download <repo-id> [OPTIONS]",
            help="Download a model from the Hugging Face Hub.",
            description="Download a model from the Hugging Face Hub.",
            epilog=textwrap.dedent(
                """\
            examples:
                # Download a model from the Hugging Face Hub with a Hugging Face API token
                $ tune download meta-llama/Llama-2-7b-hf --hf-token <TOKEN>
                Successfully downloaded model repo and wrote to the following locations:
                /tmp/Llama-2-7b-hf/config.json
                /tmp/Llama-2-7b-hf/README.md
                /tmp/Llama-2-7b-hf/consolidated.00.pth
                ...

                # Download an ungated model from the Hugging Face Hub
                $ tune download mistralai/Mistral-7B-Instruct-v0.2 --output-dir /tmp/model
                Successfully downloaded model repo and wrote to the following locations:
                /tmp/model/config.json
                /tmp/model/README.md
                /tmp/model/model-00001-of-00002.bin
                ...

            For a list of all models, visit the Hugging Face Hub https://huggingface.co/models.
            """
            ),
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self._parser.set_defaults(func=self._download_cmd)

    def _add_arguments(self) -> None:
        """Add arguments to the parser."""
        self._parser.add_argument(
            "repo_id",
            type=str,
            help="Name of the repository on Hugging Face Hub.",
        )
        self._parser.add_argument(
            "--output-dir",
            type=Path,
            required=False,
            default=None,
            help="Directory in which to save the model. Defaults to `/tmp/<model_name>`.",
        )
        # Accepted as a string (not a bool flag) because it is tri-state:
        # "auto" | "True" | "False". Validated in _download_cmd.
        self._parser.add_argument(
            "--output-dir-use-symlinks",
            type=str,
            required=False,
            default="auto",
            help=(
                "To be used with `output-dir`. If set to 'auto', the cache directory will be used and the file will be"
                " either duplicated or symlinked to the local directory depending on its size. It set to `True`, a"
                " symlink will be created, no matter the file size. If set to `False`, the file will either be"
                " duplicated from cache (if already exists) or downloaded from the Hub and not cached."
            ),
        )
        self._parser.add_argument(
            "--hf-token",
            type=str,
            required=False,
            # Fall back to the conventional HF_TOKEN environment variable.
            default=os.getenv("HF_TOKEN", None),
            help="Hugging Face API token. Needed for gated models like Llama2.",
        )
        self._parser.add_argument(
            "--ignore-patterns",
            type=str,
            required=False,
            default="*.safetensors",
            help="If provided, files matching any of the patterns are not downloaded. Defaults to ignoring "
            "safetensors files to avoid downloading duplicate weights.",
        )

    def _download_cmd(self, args: argparse.Namespace) -> None:
        """Downloads a model from the Hugging Face Hub.

        Resolves the output directory (defaulting to ``/tmp/<model_name>``),
        validates the tri-state ``--output-dir-use-symlinks`` value, then
        delegates to ``huggingface_hub.snapshot_download``. CLI-appropriate
        error messages are emitted via ``self._parser.error`` (which exits).
        """
        # Download the tokenizer and PyTorch model files
        # Default output_dir is `/tmp/<model_name>`
        output_dir = args.output_dir
        if output_dir is None:
            model_name = args.repo_id.split("/")[-1]
            output_dir = Path("/tmp") / model_name

        # Map the CLI string to the value snapshot_download expects:
        # a literal "auto" or a real bool. Raise if the value is invalid.
        output_dir_use_symlinks: Union[Literal["auto"], bool]
        use_symlinks_lowercase = args.output_dir_use_symlinks.lower()
        if use_symlinks_lowercase == "true":
            output_dir_use_symlinks = True
        elif use_symlinks_lowercase == "false":
            output_dir_use_symlinks = False
        elif use_symlinks_lowercase == "auto":
            output_dir_use_symlinks = "auto"
        else:
            self._parser.error(
                f"'{args.output_dir_use_symlinks}' is not a valid value for `--output-dir-use-symlinks`. It must be either"
                " 'auto', 'True' or 'False'."
            )

        print(f"Ignoring files matching the following patterns: {args.ignore_patterns}")
        try:
            true_output_dir = snapshot_download(
                args.repo_id,
                local_dir=output_dir,
                local_dir_use_symlinks=output_dir_use_symlinks,
                ignore_patterns=args.ignore_patterns,
                token=args.hf_token,
            )
        except GatedRepoError:
            # BUGFIX: the original adjacent string literals ran the two
            # sentences together ("...login`.You can find..."); a separating
            # space is required at the concatenation boundary.
            self._parser.error(
                "It looks like you are trying to access a gated repository. Please ensure you "
                "have access to the repository and have provided the proper Hugging Face API token "
                "using the option `--hf-token` or by running `huggingface-cli login`. "
                "You can find your token by visiting https://huggingface.co/settings/tokens"
            )
        except RepositoryNotFoundError:
            self._parser.error(
                f"Repository '{args.repo_id}' not found on the Hugging Face Hub."
            )
        except Exception as e:
            # Top-level CLI boundary: surface any unexpected failure with its
            # traceback rather than swallowing it silently.
            import traceback

            tb = traceback.format_exc()
            msg = f"Failed to download {args.repo_id} with error: '{e}' and traceback: {tb}"
            self._parser.error(msg)

        print(
            "Successfully downloaded model repo and wrote to the following locations:",
            *list(Path(true_output_dir).iterdir()),
            sep="\n",
        )