support llama2 (mindspore-lab#756)
lvyufeng authored Nov 23, 2023
1 parent 6743e64 commit 1dcd282
Showing 29 changed files with 3,410 additions and 2,333 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -100,7 +100,6 @@ The table below represents the current support in the library for each of those
| BLOOM |||
| CLIP |||
| CodeGen |||
| CodeLlama |||
| ConvBERT | TODO ||
| CPM |||
| CPM-Ant |||
@@ -113,8 +112,9 @@ The table below represents the current support in the library for each of those
| GPT NeoX | TODO ||
| GPTBigCode |||
| Graphormer | TODO ||
| LLaMA |||
| Llama2 | TODO ||
| Llama |||
| Llama2 |||
| CodeLlama |||
| Longformer |||
| LongT5 | TODO ||
| LUKE |||
2 changes: 2 additions & 0 deletions mindnlp/_legacy/functional.py
@@ -1463,6 +1463,8 @@ def sumproduct_pair(left_, right_, sum_dims_, keep_dim_):
ELLIPSIS = 52

def einsum(equation, *operands):
    if mindspore.get_context('device_target') == 'GPU':
        return _get_cache_prim(ops.Einsum)(equation)(operands)
    assert operands, "einsum(): must provide at least one operand"

    arrow_pos = equation.find("->")
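
A minimal usage sketch of the patched function (import path taken from this diff; the GPU branch added above dispatches to the `ops.Einsum` primitive, while other backends fall through to the pure-Python parser):

```python
import numpy as np
from mindspore import Tensor
from mindnlp._legacy.functional import einsum

# Matrix multiplication written as an einsum. On a GPU device target the call
# returns ops.Einsum(equation)((a, b)); elsewhere the Python fallback runs.
a = Tensor(np.random.randn(2, 3).astype(np.float32))
b = Tensor(np.random.randn(3, 4).astype(np.float32))
c = einsum("ij,jk->ik", a, b)
print(c.shape)  # (2, 4)
```
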
9 changes: 9 additions & 0 deletions mindnlp/injection.py
@@ -204,6 +204,15 @@ def _std(self, axis=None, ddof=0, keepdims=False):
Tensor.std = _std
StubTensor.std = _std

# Tensor.__contains__
def _contains(self, key):
    eq_res = ops.equal(self, key)
    res = ops.any(eq_res)
    return bool(res)

Tensor.__contains__ = _contains
StubTensor.__contains__ = _contains

if DEVICE_TARGET == 'Ascend':
# cumsum
ops.cumsum = int32_patch_decorator(ops.cumsum)
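
With the `__contains__` patch above, Python's `in` operator works element-wise on tensors. A small sketch, assuming the injection patches are applied when `mindnlp` is imported (which is how `mindnlp/injection.py` is intended to be loaded):

```python
from mindspore import Tensor
import mindnlp  # assumed: importing the package applies the injection patches

ids = Tensor([101, 7592, 2088, 102])
print(102 in ids)   # True  -> ops.equal(ids, 102) followed by ops.any(...)
print(999 in ids)   # False
```
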
1 change: 1 addition & 0 deletions mindnlp/transformers/generation/utils.py
@@ -2331,6 +2331,7 @@ def greedy_search(

# argmax
next_tokens = ops.argmax(next_tokens_scores, dim=-1)

# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
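
For context, the branch guarded by `eos_token_id` keeps finished sequences emitting padding tokens. A hedged sketch of that step, with `unfinished_sequences` assumed to be the 0/1 tracker used by HF-style generation loops:

```python
# A sketch, not the library code: finished rows keep producing pad_token_id,
# active rows keep the argmax token selected above.
def mask_finished_tokens(next_tokens, unfinished_sequences, pad_token_id):
    return next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
```
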
6 changes: 3 additions & 3 deletions mindnlp/transformers/modeling_attn_mask_utils.py
@@ -125,7 +125,7 @@ def _make_causal_mask(
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = ops.full((tgt_len, tgt_len), np.finfo(mindspore.dtype_to_nptype(dtype)).min)
mask = ops.full((tgt_len, tgt_len), mindspore.tensor(np.finfo(mindspore.dtype_to_nptype(dtype)).min))
mask_cond = ops.arange(mask.shape[-1])
mask = mask.masked_fill(mask_cond < (mask_cond + 1).view(mask.shape[-1], 1), 0)

@@ -139,7 +139,7 @@ def _make_causal_mask(
diagonal = past_key_values_length - sliding_window + 1

context_mask = 1 - ops.triu(ops.ones_like(mask, dtype=mindspore.int32), diagonal=diagonal)
mask = mask.masked_fill(context_mask.bool(), np.finfo(mindspore.dtype_to_nptype(dtype)).min)
mask = mask.masked_fill(context_mask.bool(), mindspore.tensor(np.finfo(mindspore.dtype_to_nptype(dtype)).min))

return mask[None, None, :, :].broadcast_to((bsz, 1, tgt_len, tgt_len + past_key_values_length))

@@ -155,7 +155,7 @@ def _expand_mask(mask: mindspore.Tensor, dtype, tgt_len: Optional[int] = None):

inverted_mask = 1.0 - expanded_mask

return inverted_mask.masked_fill(inverted_mask.to(mindspore.bool_), np.finfo(mindspore.dtype_to_nptype(dtype)).min)
return inverted_mask.masked_fill(inverted_mask.to(mindspore.bool_), mindspore.tensor(np.finfo(mindspore.dtype_to_nptype(dtype)).min))


def _prepare_4d_causal_attention_mask(
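
All three changes wrap the `np.finfo(...).min` scalar in a MindSpore tensor before using it as a fill value. A standalone sketch of the causal-mask construction for float32, simplified from `_make_causal_mask` above (sliding-window and past-key handling omitted):

```python
import numpy as np
import mindspore
from mindspore import ops

def causal_mask_sketch(tgt_len, dtype=mindspore.float32):
    # Additive mask: 0 where a position may attend, dtype-min where it may not.
    min_value = mindspore.tensor(np.finfo(mindspore.dtype_to_nptype(dtype)).min)
    mask = ops.full((tgt_len, tgt_len), min_value)
    mask_cond = ops.arange(tgt_len)
    # Zero out the lower triangle: each token sees itself and earlier tokens.
    mask = mask.masked_fill(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
    return mask.astype(dtype)
```
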
53 changes: 37 additions & 16 deletions mindnlp/transformers/modeling_utils.py
@@ -670,19 +670,27 @@ def from_pretrained(
if from_pt and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, PT_WEIGHTS_NAME)
):
# Load from a TF 2.0 checkpoint in priority if from_tf
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, PT_WEIGHTS_NAME)
elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
):
# Load from a PyTorch checkpoint
# Load from a MindSpore checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
)
elif from_pt and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(PT_WEIGHTS_INDEX_NAME, variant))
):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(PT_WEIGHTS_INDEX_NAME, variant)
)
is_sharded = True
elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
):
# Load from a sharded PyTorch checkpoint
# Load from a sharded MindSpore checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
)
@@ -819,6 +827,7 @@ def load_ckpt(resolved_archive_file):
keys_missing = list(model.parameters_dict().keys())
param_id_set = set()


def load_param_into_net(model: nn.Cell, param_dict: dict, prefix: str):
keys_unexpected = list(param_dict.keys())

@@ -850,10 +859,9 @@ def load_param_into_net(model: nn.Cell, param_dict: dict, prefix: str):
f'\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.')
logger.warning(f'The shape of parameter `{param.name}` is {param.shape}, but the checkpoint '
f'parameter `{param_name}` has shape {new_param.shape}.')
param = Parameter(new_param, param.name)
param = Parameter(new_param.data, param.name)
else:
param.set_dtype(new_param.dtype)
param.assign_value(new_param)
param.set_data(new_param)
keys_unexpected.remove(param_name)
keys_missing.remove(param.name)
param_id_set.add(id(param))
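
A hedged, simplified sketch of what the loop above does once prefixes are resolved: copy checkpoint data into matching parameters with `set_data`, and re-create the `Parameter` only when shapes differ and mismatches are allowed:

```python
from mindspore import Parameter

def load_params_sketch(model_params, ckpt, ignore_mismatched_sizes=False):
    # model_params: dict name -> Parameter, ckpt: dict name -> Tensor (prefix handling omitted).
    missing = set(model_params) - set(ckpt)
    unexpected = set(ckpt) - set(model_params)
    for name in set(model_params) & set(ckpt):
        param, new_value = model_params[name], ckpt[name]
        if param.shape != new_value.shape:
            if not ignore_mismatched_sizes:
                raise RuntimeError(f"shape mismatch for {name}: {param.shape} vs {new_value.shape}")
            model_params[name] = Parameter(new_value, name=name)  # replace, as the code above does
        else:
            param.set_data(new_value)                             # in-place copy via set_data
    return sorted(missing), sorted(unexpected)
```
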
@@ -1308,10 +1316,16 @@ def convert_torch_to_mindspore(pth_file):
"`pip install torch` or instructions from 'https://pytorch.org'") \
from exc

ms_ckpt_path = pth_file.replace('pytorch_model', 'mindspore')
ms_ckpt_path = ms_ckpt_path.replace('.bin', '.ckpt')
if os.path.exists(ms_ckpt_path):
return ms_ckpt_path

logger.info('Starting checkpoint conversion.')
ms_ckpt = []
state_dict = torch.load(pth_file, map_location=torch.device('cpu'))
state_dict = torch.load(pth_file, map_location='cpu')

has_bf16 = False
for key, value in state_dict.items():
if 'LayerNorm' in key or 'layer_norm' in key or 'ln' in key:
if '.weight' in key:
@@ -1323,16 +1337,23 @@ def convert_torch_to_mindspore(pth_file):
'embed_' in key or '_embed' in key and \
'embedding_hidden_mapping_in' not in key: # for albert
key = key.replace('weight', 'embedding_table')
ms_ckpt.append({'name': key, 'data': Tensor(value.numpy())})

ms_ckpt_path = pth_file.replace('pytorch_model', 'mindspore')
ms_ckpt_path = ms_ckpt_path.replace('.bin', '.ckpt')
if not os.path.exists(ms_ckpt_path):
try:
save_checkpoint(ms_ckpt, ms_ckpt_path)
except Exception as exc:
raise RuntimeError(f'Save checkpoint to {ms_ckpt_path} failed, '
f'please checkout the path.') from exc
if value.dtype == torch.bfloat16:
data = Tensor(value.to(torch.float).numpy())
if not has_bf16:
has_bf16 = True
else:
data = Tensor(value.numpy())
ms_ckpt.append({'name': key, 'data': data})

if has_bf16:
logger.warning("MindSpore do not support bfloat16 dtype, we will automaticlly convert to float16")

try:
save_checkpoint(ms_ckpt, ms_ckpt_path)
except Exception as exc:
raise RuntimeError(f'Save checkpoint to {ms_ckpt_path} failed, '
f'please check the path.') from exc

return ms_ckpt_path
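
The new branch handles bfloat16 weights, which have no numpy equivalent. A short sketch of the per-tensor conversion it performs (upcast to float32, then wrap as a MindSpore tensor):

```python
import torch
from mindspore import Tensor

def to_ms_tensor(value: torch.Tensor) -> Tensor:
    # numpy cannot represent bfloat16, so upcast before .numpy(), as in the loop above.
    if value.dtype == torch.bfloat16:
        value = value.to(torch.float)  # float32
    return Tensor(value.numpy())
```
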

2 changes: 1 addition & 1 deletion mindnlp/transformers/models/albert/modeling_albert.py
@@ -522,7 +522,7 @@ def construct(

extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * np.finfo(mindspore.dtype_to_nptype(self.dtype)).min
extended_attention_mask = (1.0 - extended_attention_mask) * mindspore.tensor(np.finfo(mindspore.dtype_to_nptype(self.dtype)).min)
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

embedding_output = self.embeddings(
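
Same pattern as in `modeling_attn_mask_utils.py`: the 0/1 padding mask becomes an additive bias (0 for tokens to attend to, dtype-min for padding). A small numeric sketch, assuming float32:

```python
import numpy as np
import mindspore
from mindspore import Tensor

attention_mask = Tensor([[1, 1, 1, 0]], mindspore.float32)  # 1 = real token, 0 = padding
extended = attention_mask.unsqueeze(1).unsqueeze(2)         # -> (batch, 1, 1, seq_len)
min_value = mindspore.tensor(np.finfo(np.float32).min)
bias = (1.0 - extended) * min_value                         # 0.0 for real tokens, ~-3.4e38 for padding
```
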
20 changes: 1 addition & 19 deletions mindnlp/transformers/models/baichuan/baichuan.py
@@ -80,7 +80,7 @@ def _expand_mask(mask: Tensor, dtype: mstype, tgt_len: Optional[int] = None):

return inverted_mask.masked_fill(
inverted_mask.to(mindspore.bool_),
np.finfo(mindspore.dtype_to_nptype(dtype)).min,
mindspore.tensor(np.finfo(mindspore.dtype_to_nptype(dtype)).min),
)


@@ -117,7 +117,6 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000):
super().__init__()
self.inv_freq = 1.0 / (base ** (ops.arange(0, dim, 2).float() / dim))

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = ops.arange(self.max_seq_len_cached, dtype=self.inv_freq.dtype)
freqs = ops.einsum("i,j->ij", t, self.inv_freq)
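
The rotary embedding caches an outer product of positions and inverse frequencies. A hedged standalone sketch of the cached tables; the cos/sin concatenation mirrors the usual rotary-embedding layout and is an assumption, since that part is not shown in this hunk:

```python
from mindspore import ops

def rotary_tables_sketch(dim, max_seq_len=2048, base=10000):
    # inv_freq[j] = base^(-2j/dim); freqs[i, j] = position_i * inv_freq[j]
    inv_freq = 1.0 / (base ** (ops.arange(0, dim, 2).float() / dim))
    t = ops.arange(max_seq_len, dtype=inv_freq.dtype)
    freqs = ops.einsum("i,j->ij", t, inv_freq)  # (max_seq_len, dim // 2)
    emb = ops.cat((freqs, freqs), axis=-1)      # (max_seq_len, dim), assumed layout
    return ops.cos(emb), ops.sin(emb)
```
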
@@ -507,23 +506,6 @@ def construct(
past_key_value = past_key_values[idx] if past_key_values is not None else None

# TODO: how checkpoint
# if self.gradient_checkpointing and self.training:

# def create_custom_forward(module):
# def custom_forward(*inputs):
# # None for past_key_value
# return module(*inputs, output_attentions, None)

# return custom_forward

# layer_outputs = torch.utils.checkpoint.checkpoint(
# create_custom_forward(decoder_layer),
# hidden_states,
# attention_mask,
# position_ids,
# None,
# )
# else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
2 changes: 1 addition & 1 deletion mindnlp/transformers/models/bart/bart.py
@@ -100,7 +100,7 @@ def _expand_mask(mask: Tensor, dtype: mstype, tgt_len: Optional[int] = None):

return inverted_mask.masked_fill(
inverted_mask.to(mindspore.bool_),
np.finfo(mindspore.dtype_to_nptype(dtype)).min,
Tensor(np.finfo(mindspore.dtype_to_nptype(dtype)).min),
)


5 changes: 0 additions & 5 deletions mindnlp/transformers/models/codegen/codegen.py
@@ -27,11 +27,6 @@
from .codegen_config import CodeGenConfig
from ...activations import ACT2FN


_CHECKPOINT_FOR_DOC = "Salesforce/codegen-2B-mono"
_CONFIG_FOR_DOC = "CodeGenConfig"


#
def fixed_pos_embedding(tensor, seq_dim=1, seq_len=None):
"""
21 changes: 0 additions & 21 deletions mindnlp/transformers/models/codellama/__init__.py

This file was deleted.
