
Commit

add download url (mindspore-lab#560)
lvyufeng authored Jun 6, 2023
1 parent 05dcdc1 commit 158ed0f
Showing 5 changed files with 22 additions and 25 deletions.
22 changes: 7 additions & 15 deletions mindnlp/abc/models/pretrained_model.py
@@ -366,7 +366,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

if resolved_archive_file is None:
base_url = '/'.join(archive_file.split('/')[:-1])
- archive_file = base_url + '/' + HF_WEIGHTS_INDEX_NAME
+ archive_file = base_url + '/' + HF_WEIGHTS_INDEX_NAME if from_pt else \
+     base_url + '/' + WEIGHTS_INDEX_NAME

resolved_archive_file = str(cached_path(
archive_file,
@@ -385,22 +386,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
subfolder=folder_name
)
is_sharded = True
+ else:
+     raise EnvironmentError(f"Couldn't reach server at '{archive_file}' to download pretrained weights.")

except EnvironmentError as exc:
-     if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
-         msg = f"Couldn't reach server at '{archive_file}' to download pretrained weights."
-     else:
-         format1 = ", ".join(
-             cls.pretrained_model_archive_map.keys())
-         format2 = ["mindspore.ckpt"]
-         msg = (
-             f"Model name '{pretrained_model_name_or_path}' "
-             f"was not found in model name list ({format1}). "
-             f"We assumed '{archive_file}' "
-             f"was a path or url to model weight files named one of {format2} but "
-             f"couldn't find any such file at this path or url."
-         )
-     raise EnvironmentError(msg) from exc
+     raise exc

if resolved_archive_file == archive_file:
logger.info("loading weights file %s", archive_file)
@@ -424,6 +414,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
else:
resolved_archive_file = cls.convert_torch_to_mindspore(
str(resolved_archive_file), prefix=cls.base_model_prefix)
+ else:
+     converted_filenames = cached_filenames

def load_ckpt(resolved_archive_file):
try:
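Note: the fallback to a sharded-weights index now respects the checkpoint format, and a failed download is re-raised as-is instead of being wrapped in a rebuilt message. A minimal sketch of the index-name selection, with the two constants assumed here for illustration (the real constants are imported elsewhere in mindnlp and may differ):

```python
# Sketch only; the actual index-file names are assumptions, not mindnlp's real values.
HF_WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"   # assumed PyTorch index name
WEIGHTS_INDEX_NAME = "mindspore.ckpt.index.json"         # assumed MindSpore index name

def sharded_index_url(archive_file: str, from_pt: bool) -> str:
    """Build the URL of the sharded-weights index that sits next to the archive file."""
    base_url = "/".join(archive_file.split("/")[:-1])
    index_name = HF_WEIGHTS_INDEX_NAME if from_pt else WEIGHTS_INDEX_NAME
    return base_url + "/" + index_name

# A MindSpore checkpoint URL now falls back to its MindSpore index, not the HF one.
print(sharded_index_url("https://example.org/models/glm/chatglm-6b/mindspore.ckpt", from_pt=False))
# https://example.org/models/glm/chatglm-6b/mindspore.ckpt.index.json
```
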
11 changes: 5 additions & 6 deletions mindnlp/models/glm/chatglm.py
@@ -40,13 +40,12 @@
from mindnlp.generation.stopping_criteria import StoppingCriteriaList
from mindnlp.abc import GenerationConfig
from mindnlp.modules import functional as F
+ from mindnlp.configs import MINDNLP_MODEL_URL_BASE
from .chatglm_config import ChatGLMConfig


- CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
-     "THUDM/chatglm-6b",
-     # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm
- ]
+ PRETRAINED_MODEL_ARCHIVE_MAP = {
+     'chatglm-6b': MINDNLP_MODEL_URL_BASE.format('glm', 'chatglm-6b')
+ }


def torch_to_mindspore(pth_file, **kwargs):
@@ -590,7 +589,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
config_class = ChatGLMConfig
base_model_prefix = "transformer"
_no_split_modules = ["GLMBlock"]
-     pretrained_model_archive_map = CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST
+     pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
convert_torch_to_mindspore = torch_to_mindspore

def _init_weights(self, cell: nn.Cell):
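Note: the ChatGLM weights are now resolved from the short name 'chatglm-6b' through a MindNLP download URL rather than a Hugging Face repo id. A rough sketch of what the new map yields, assuming MINDNLP_MODEL_URL_BASE is a two-slot format string over the MindSpore download host (the exact template lives in mindnlp/configs.py and may differ):

```python
# Assumed template; only the host/path prefix is confirmed by the tokenizer URL in this commit.
MINDNLP_MODEL_URL_BASE = "https://download.mindspore.cn/toolkits/mindnlp/models/{}/{}/mindspore.ckpt"

PRETRAINED_MODEL_ARCHIVE_MAP = {
    "chatglm-6b": MINDNLP_MODEL_URL_BASE.format("glm", "chatglm-6b")
}

print(PRETRAINED_MODEL_ARCHIVE_MAP["chatglm-6b"])
# https://download.mindspore.cn/toolkits/mindnlp/models/glm/chatglm-6b/mindspore.ckpt
```
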
5 changes: 5 additions & 0 deletions mindnlp/models/glm/chatglm_config.py
@@ -15,7 +15,11 @@
""" ChatGLM model configuration """

from mindnlp.abc import PreTrainedConfig
+ from mindnlp.configs import MINDNLP_CONFIG_URL_BASE

+ CONFIG_ARCHIVE_MAP = {
+     'chatglm-6b': MINDNLP_CONFIG_URL_BASE.format('glm', 'chatglm-6b')
+ }

class ChatGLMConfig(PreTrainedConfig):
r"""
@@ -66,6 +70,7 @@ class ChatGLMConfig(PreTrainedConfig):
```
"""
model_type = "chatglm"
+ pretrained_config_archive_map = CONFIG_ARCHIVE_MAP

def __init__(
self,
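Note: with pretrained_config_archive_map populated, the configuration can also be fetched by the short name. A hedged usage sketch, assuming ChatGLMConfig.from_pretrained resolves names through this map the same way the model class does (import path inferred from this repository layout, not confirmed by the diff):

```python
from mindnlp.models.glm import ChatGLMConfig  # import path assumed

# 'chatglm-6b' should now map to the MindSpore-hosted config URL instead of
# requiring a local config file or a Hugging Face repo id.
config = ChatGLMConfig.from_pretrained("chatglm-6b")
print(config.model_type)  # "chatglm"
```
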
5 changes: 3 additions & 2 deletions mindnlp/transforms/tokenizers/chatglm_tokenizer.py
@@ -25,11 +25,12 @@
from mindnlp.utils.generic import PaddingStrategy

PRETRAINED_VOCAB_MAP = {
"THUDM/chatglm-6b": "https://huggingface.co/THUDM/chatglm-6b/resolve/main/ice_text.model"
'chatglm-6b': 'https://download.mindspore.cn/toolkits/mindnlp/models/glm/chatglm-6b/ice_text.model'
}


PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"THUDM/chatglm-6b": 2048,
"chatglm-6b": 2048,
}

class TextTokenizer:
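Note: the tokenizer's vocab entry now keys on 'chatglm-6b' and points at the MindSpore download mirror instead of the Hugging Face resolve URL. A small sketch of fetching and inspecting that file directly, assuming ice_text.model is a standard SentencePiece model as the TextTokenizer wrapper in this file suggests (mindnlp's own caching utility would normally handle the download):

```python
import urllib.request
import sentencepiece as spm

PRETRAINED_VOCAB_MAP = {
    "chatglm-6b": "https://download.mindspore.cn/toolkits/mindnlp/models/glm/chatglm-6b/ice_text.model"
}

# Fetch the vocab file by hand; ChatGLMTokenizer.from_pretrained("chatglm-6b")
# resolves the same URL via this map.
vocab_path, _ = urllib.request.urlretrieve(PRETRAINED_VOCAB_MAP["chatglm-6b"], "ice_text.model")

sp = spm.SentencePieceProcessor(model_file=vocab_path)
print(sp.vocab_size())                   # vocabulary size of the ChatGLM-6B tokenizer
print(sp.encode("你好", out_type=str))   # example tokenization
```
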
4 changes: 2 additions & 2 deletions tests/ut/models/glm/test_modeling_chatglm.py
@@ -53,8 +53,8 @@ def ids_tensor(shape, vocab_size):

def get_model_and_tokenizer():
"""get model and tokenizer"""
-     model = ChatGLMForConditionalGeneration.from_pretrained("THUDM/chatglm-6b", from_pt=True)
-     tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b")
+     model = ChatGLMForConditionalGeneration.from_pretrained("chatglm-6b")
+     tokenizer = ChatGLMTokenizer.from_pretrained("chatglm-6b")
return model, tokenizer


