add phi_2 model & speed up ckpt load (mindspore-lab#833)
lvyufeng authored Jan 19, 2024
1 parent 02e1a85 commit 6d6698f
Showing 26 changed files with 1,929 additions and 295 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -111,7 +111,8 @@ The table below represents the current support in the library for each of those
| CPM-Ant |||
| CPM-Bee |||
| EnCodec |||
| ERNIE |||
| ERNIE |||
| ERNIEM |||
| Falcon |||
| GLM |||
| OpenAI GPT |||
84 changes: 34 additions & 50 deletions mindnlp/injection.py
@@ -43,6 +43,14 @@
if DEVICE_TARGET == 'Ascend':
    GLOBAL_FP16_PATCH = True

old_set_context = mindspore.set_context
def _set_context(**kwargs):
    if 'device_target' in kwargs and kwargs['device_target'] != 'Ascend':
        set_global_fp16(False)
    old_set_context(**kwargs)

mindspore.set_context = _set_context

def set_global_fp16(mode: bool):
    """set global fp16"""
    global GLOBAL_FP16_PATCH
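For context, a minimal usage sketch of the set_context patch above; it is not part of the diff, the device target and mode are illustrative, and the patch takes effect once mindnlp.injection has been imported.

import mindspore
import mindnlp.injection  # importing the module replaces mindspore.set_context with _set_context

# On a non-Ascend target the wrapper first disables the global fp16 patch,
# then forwards every keyword argument to the original set_context.
mindspore.set_context(device_target='GPU', mode=mindspore.PYNATIVE_MODE)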
@@ -52,12 +60,14 @@ def fp16_patch_decorator(func):
"""fp16 patch on ascend"""
def wrapper(*args, **kwargs):
if GLOBAL_FP16_PATCH:
args = [arg.astype(mstype.float16) if arg is not None and isinstance(arg, Tensor) \
else arg for arg in args]
has_fp32 = any(bool(isinstance(arg, Tensor) and arg.dtype == mstype.float32) for arg in args)
args = (arg.astype(mstype.float16) if arg is not None and isinstance(arg, Tensor) \
else arg for arg in args)
kwargs = {k: (v.astype(mstype.float16) if v is not None and isinstance(v, Tensor) else v) \
for k, v in kwargs.items()}
result = func(*args, **kwargs)
result = result.astype(mstype.float32)
if has_fp32:
result = result.astype(mstype.float32)
return result
return func(*args, **kwargs)

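For illustration only, a hypothetical use of fp16_patch_decorator as a plain decorator; the wrapped function, shapes, and values are invented. With the change above, the output is cast back to float32 only when at least one positional Tensor argument was float32, so calls made entirely in float16 no longer pay an extra cast.

import numpy as np
from mindspore import Tensor, ops
from mindspore.common import dtype as mstype
from mindnlp.injection import fp16_patch_decorator

@fp16_patch_decorator
def matmul_patched(a, b):  # hypothetical op; any Tensor-in/Tensor-out function works the same way
    return ops.matmul(a, b)

a = Tensor(np.random.randn(4, 8), mstype.float32)
b = Tensor(np.random.randn(8, 2), mstype.float32)
out = matmul_patched(a, b)  # computed in float16 when the global patch is active,
                            # returned as float32 because the inputs were float32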
@@ -512,8 +522,6 @@ def __init__(self,
                 in_channels,
                 out_channels,
                 has_bias=True,
                 weight_init='zeros',
                 bias_init='zeros',
                 dtype=mstype.float32):
        """Initialize Dense."""
        super().__init__()
@@ -525,21 +533,14 @@ def __init__(self,
has_bias, "has_bias", self.cls_name)

self.weight = Parameter(initializer(
weight_init, [out_channels, in_channels], dtype=dtype), name="weight")
HeUniform(math.sqrt(5)), [out_channels, in_channels], dtype=dtype), name="weight")

self.bias = None
if self.has_bias:
self.bias = Parameter(initializer(
bias_init, [out_channels], dtype=dtype), name="bias")
self.reset_parameters()

def reset_parameters(self):
"""reset_embedding_params"""
self.weight.set_data(initializer(HeUniform(math.sqrt(5)), self.weight.shape, self.weight.dtype))
if self.has_bias:
fan_in, _ = _calculate_fan_in_and_fan_out(self.weight.shape)
bound = 1 / math.sqrt(fan_in)
self.bias.set_data(initializer(Uniform(bound), self.bias.shape, self.bias.dtype))
self.bias = Parameter(initializer(
Uniform(bound), [out_channels], dtype=dtype), name="bias")

    def construct(self, x):
        if LESS_MS_2_2:
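A rough sketch of the initialization pattern the patched Dense now follows; LazyDense is a hypothetical name, and the real class also validates its arguments and keeps the LESS_MS_2_2 path. Handing initializer objects straight to Parameter keeps the weights lazily initialized, so construction does no numeric work and a loaded checkpoint can overwrite them directly, which is presumably the "speed up ckpt load" part of this commit. The diff likely derives the bias bound from _calculate_fan_in_and_fan_out; the sketch computes it directly, since the fan-in of a dense weight is in_channels.

import math
from mindspore import nn, Parameter
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer, HeUniform, Uniform

class LazyDense(nn.Cell):
    """Simplified stand-in for the patched Dense: no eager reset_parameters() pass."""
    def __init__(self, in_channels, out_channels, has_bias=True, dtype=mstype.float32):
        super().__init__()
        # The HeUniform initializer is stored with the Parameter and only materialized
        # when the value is first needed or replaced by a loaded checkpoint.
        self.weight = Parameter(initializer(
            HeUniform(math.sqrt(5)), [out_channels, in_channels], dtype=dtype), name="weight")
        self.bias = None
        if has_bias:
            bound = 1 / math.sqrt(in_channels)  # fan-in of a dense weight is in_channels
            self.bias = Parameter(initializer(
                Uniform(bound), [out_channels], dtype=dtype), name="bias")
    # construct() is omitted; only the initialization pattern matters here.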
@@ -557,7 +558,7 @@ def construct(self, x):

class Embedding(nn.Cell):
    """patched Embedding"""
    def __init__(self, vocab_size, embedding_size, padding_idx=None, use_one_hot=False, dtype=mstype.float32, weight_init='zeros'):
    def __init__(self, vocab_size, embedding_size, padding_idx=None, use_one_hot=False, dtype=mstype.float32):
        """Initialize Embedding."""
        super().__init__()
        self.vocab_size = Validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
@@ -567,18 +568,9 @@ def __init__(self, vocab_size, embedding_size, padding_idx=None, use_one_hot=Fal
        self.use_one_hot = use_one_hot
        self.dtype = dtype
        self.padding_idx = padding_idx
        self.weight = Parameter(initializer(weight_init, [vocab_size, embedding_size]), name='weight')
        self.reset_parameters()

    def reset_parameters(self):
        """reset_embedding_params"""
        init_tensor = initializer(Normal(1.0), self.weight.shape, self.weight.dtype)
        init_tensor = init_tensor.init_data()
        if self.padding_idx:
            init_tensor = init_tensor.asnumpy()
            init_tensor[self.padding_idx] = 0
            init_tensor = Tensor(init_tensor)
        self.weight.assign_value(init_tensor)
        self.weight = Parameter(initializer(Normal(1.0), [vocab_size, embedding_size]), name='weight')
        if self.padding_idx and self.weight.init_flag:
            self.weight[self.padding_idx] = 0

    def construct(self, ids):
        out_shape = ids.shape + (self.embedding_size,)
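The patched Embedding applies the same idea: Normal(1.0) stays a lazy initializer, and the padding row is zeroed in place under the weight.init_flag guard instead of rebuilding the whole table in reset_parameters(). A hypothetical usage of the class defined above, with illustrative sizes:

emb = Embedding(vocab_size=100, embedding_size=16, padding_idx=3)
# Row 3 of the embedding table is set to zero during __init__; note that the
# truthiness check `if self.padding_idx` skips a padding_idx of 0.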
@@ -630,7 +622,6 @@ def __init__(self,
        stride = (stride,)

        dilation = (dilation,)

        super().__init__(
            in_channels,
            out_channels,
@@ -641,22 +632,14 @@
            dilation,
            group,
            has_bias,
            'zeros',
            'zeros')
            None,
            None)
        self.padding = padding
        self.reset_parameters()

    def construct(self, x):
        return ops.conv1d(x, self.weight, self.bias, stride=self.stride, pad_mode=self.pad_mode,
                          padding=self.padding, dilation=self.dilation, groups=self.group)

    def reset_parameters(self):
        """reset_embedding_params"""
        self.weight.set_data(initializer(HeUniform(math.sqrt(5)), self.weight.shape, self.weight.dtype))
        if self.has_bias:
            fan_in, _ = _calculate_fan_in_and_fan_out(self.weight.shape)
            bound = 1 / math.sqrt(fan_in)
            self.bias.set_data(initializer(Uniform(bound), self.bias.shape, self.bias.dtype))

class Conv1dTranspose(_Conv):
    """patched Conv1dTranspose"""
@@ -699,8 +682,8 @@ def __init__(self,
            dilation,
            group,
            has_bias,
            weight_init,
            bias_init,
            None,
            None,
            transposed=True,
            dtype=dtype)
        self.kernel_size = kernel_size
@@ -720,7 +703,6 @@
            stride=stride,
            dilation=dilation,
            group=group)
        self.reset_parameters()

    def construct(self, x):
        x = x.expand_dims(2)
@@ -736,13 +718,6 @@ def construct(self, x):
        output = output.squeeze(2)
        return output

    def reset_parameters(self):
        """reset_embedding_params"""
        self.weight.set_data(initializer(HeUniform(math.sqrt(5)), self.weight.shape, self.weight.dtype))
        if self.has_bias:
            fan_in, _ = _calculate_fan_in_and_fan_out(self.weight.shape)
            bound = 1 / math.sqrt(fan_in)
            self.bias.set_data(initializer(Uniform(bound), self.bias.shape, self.bias.dtype))

class LayerNorm(nn.Cell):
    r"""
@@ -875,7 +850,7 @@ def extend_repr(self):

def half(self):
    """patched nn.Cell.half"""
    self.to_float(mindspore.float16)
    # self.to_float(mindspore.float16)
    return self

nn.Cell.half = half
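With the patched method above, half() returns the cell without calling to_float(float16), so inputs are no longer force-cast to float16 at the cell boundary; presumably precision handling is left to the fp16 patching earlier in this file. A trivial illustration:

from mindspore import nn

model = nn.Dense(4, 4)
model = model.half()  # returns the same cell unchanged; no to_float(float16) is applied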
@@ -885,6 +860,15 @@ def _check_cell_flags_in_pynative(self):

nn.Cell._check_cell_flags_in_pynative = _check_cell_flags_in_pynative

def _update_parameters_name(self, prefix='', recurse=True):
    for name, param in self.parameters_and_names(expand=recurse):
        if prefix != '':
            param.is_init = False
        if param.name in name: # for tied weight
            param.name = prefix + name

nn.Cell.update_parameters_name = _update_parameters_name
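The patched update_parameters_name renames a parameter only while its stored name still matches the traversal name, and it clears is_init whenever a prefix is applied; the "# for tied weight" comment suggests the guard is there so a shared parameter is not prefixed twice. A hypothetical illustration with an invented prefix:

from mindspore import nn

net = nn.Dense(4, 10)                        # parameter names: 'weight', 'bias'
net.update_parameters_name(prefix='model.')  # -> 'model.weight', 'model.bias'
# A second call is a no-op for these parameters: 'model.weight' is no longer
# contained in the traversal name 'weight', so the guard skips the rename.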

nn.LayerNorm = LayerNorm
nn.Conv1d = Conv1d
nn.Conv1dTranspose = Conv1dTranspose