update baichuan, support 7b & 13b #766

Merged
merged 1 commit into from Dec 1, 2023
update baichuan, support 7b & 13b
lvyufeng committed Dec 1, 2023
commit 798ab48132df4efabc17f2d1918de427414db505
5 changes: 3 additions & 2 deletions README.md
@@ -93,14 +93,15 @@ The table below represents the current support in the library for each of those
| Model | Pynative support | Graph Support |
|-------------------------------|------------------|---------------|
| ALBERT | ✅ | ✅ |
| Autoformer | TODO | ❌ |
| BaiChuan | ✅ | ❌ |
| Bark | TODO | ❌ |
| BART | ✅ | ❌ |
| BERT | ✅ | ✅ |
| BLOOM | ✅ | ❌ |
| CLIP | ✅ | ❌ |
| CodeGen | ✅ | ❌ |
| ConvBERT | TODO | ❌ |
| CPM | ✅ | ❌ |
| CPM-Ant | ✅ | ❌ |
| CPM-Bee | ✅ | ❌ |
88 changes: 65 additions & 23 deletions mindnlp/injection.py
@@ -18,8 +18,9 @@
"""
Injection mindspore.nn for MindNLP
"""
import operator
from functools import reduce, partial
import math
from functools import partial
from packaging import version
import mindspore
import mindspore.common.dtype as mstype
@@ -84,6 +85,34 @@ def wrapper(*args, **kwargs):

return wrapper

def _get_unflatten_size(input_shape, dim, sizes):
    """Compute the output shape for `Tensor.unflatten`."""
    input_rank = len(input_shape)
    if not isinstance(sizes, (tuple, list)):
        raise TypeError(f"Type of `sizes` should be `Tuple` or `List`, but got {type(sizes)}")

    if len(sizes) == 0:
        raise ValueError("`sizes` must be non-empty")

    if isinstance(dim, str):
        raise TypeError("`unflatten` does not yet support `dim` of type str")

    _dim = dim
    if _dim < 0:
        _dim += input_rank

    if _dim < 0 or _dim >= input_rank:
        raise ValueError(f"`dim` should be in range [{-input_rank}, {input_rank}), but got {dim}")

    _sizes_mul = reduce(operator.mul, list(sizes))
    if -1 not in sizes and _sizes_mul != input_shape[_dim]:
        raise ValueError(f"unflatten: Provided `sizes` {sizes} don't multiply up to the "
                         f"size of dim {dim} ({input_shape[_dim]}) in the input tensor")

    out_shape = input_shape[:_dim] + tuple(sizes) + input_shape[_dim + 1:]
    return out_shape

# For all backends
# For functional API
# matmul
origin_matmul = ops.matmul
ops.matmul = fp16_patch_decorator(origin_matmul)
@@ -119,6 +148,7 @@ def dense(input, weight, bias=None):
# conv1d
ops.conv1d = fp16_patch_decorator(ops.conv1d)

# for Tensor
# unfold
def _get_unfold_indices(input_shape, dimension, size, step):
    if dimension < 0:
Expand Down Expand Up @@ -219,6 +249,38 @@ def _contains(self, key):
Tensor.__contains__ = _contains
StubTensor.__contains__ = _contains
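
With `__contains__` patched, Python's `in` operator works directly on tensors. A minimal usage sketch (values are illustrative; it assumes `_contains`, whose body lies outside this hunk, implements element membership):

import numpy as np
import mindspore
from mindspore import Tensor

t = Tensor(np.array([1, 2, 3]), mindspore.int32)
print(2 in t)  # True: 2 appears in the tensor
print(5 in t)  # False: 5 does not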

def unflatten(self, dim, sizes):
    """Tensor.unflatten"""
    out_shape = _get_unflatten_size(self.shape, dim, sizes)
    return self.reshape(out_shape)

Tensor.unflatten = unflatten
StubTensor.unflatten = unflatten
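
A minimal usage sketch of the new `Tensor.unflatten` (shapes are illustrative, assuming a standard MindSpore install):

import numpy as np
import mindspore
from mindspore import Tensor

x = Tensor(np.arange(12).reshape(2, 6), mindspore.float32)
y = x.unflatten(1, (3, 2))    # split axis 1 of size 6 into (3, 2)
print(y.shape)                # (2, 3, 2)
z = x.unflatten(-1, (2, -1))  # a -1 entry is resolved by reshape
print(z.shape)                # (2, 2, 3)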

if version.parse(mindspore.__version__) < version.parse('2.2.0'):
    def eq(self, other):
        """patched eq"""
        return ops.equal(self, other)
    Tensor.eq = eq
    StubTensor.eq = eq
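
On MindSpore versions below 2.2.0 this backfills `Tensor.eq` as an elementwise comparison; newer versions already provide it. A small sketch (values illustrative):

import numpy as np
import mindspore
from mindspore import Tensor

a = Tensor(np.array([1, 2, 3]), mindspore.int32)
b = Tensor(np.array([1, 0, 3]), mindspore.int32)
print(a.eq(b))  # elementwise result: [True, False, True]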


def _eq(self, other):
    if not isinstance(other, (int, float, Tensor)):
        return False
    if isinstance(other, Tensor) and self.shape != other.shape:
        return False
    if id(self) == id(other):
        return True
    # bool type is not supported for `Equal` operator in backend.
    if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
        self = self.to(mstype.int32)
        other = other.to(mstype.int32)
    return ops.eq(self, other)

Parameter.__eq__ = _eq
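
`Parameter.__eq__` is overridden so that comparing a parameter with an unsupported type returns a plain `False` instead of raising, while comparisons against a compatible tensor stay elementwise. A small sketch of the resulting behaviour (values illustrative):

import numpy as np
import mindspore
from mindspore import Parameter, Tensor

p = Parameter(Tensor(np.array([1.0, 2.0]), mindspore.float32), name='p')
print(p == 'not a tensor')                                   # False: unsupported type
print(p == Tensor(np.array([1.0, 0.0]), mindspore.float32))  # elementwise: [True, False]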

# Ascend only
if DEVICE_TARGET == 'Ascend':
    # cumsum
    ops.cumsum = int32_patch_decorator(ops.cumsum)
@@ -266,6 +328,7 @@ def bitwise_and(self, other):
ops.cat = bool_patch_decorator(ops.cat)
ops.concat = bool_patch_decorator(ops.concat)

# GPU only
def custom_multinomial(probabilities, num_samples, replacement=True):
"""custom multinomial"""
if replacement:
@@ -290,29 +353,8 @@ def custom_multinomial(probabilities, num_samples, replacement=True):
if DEVICE_TARGET == 'GPU':
    ops.multinomial = custom_multinomial
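
On GPU, `ops.multinomial` is replaced by `custom_multinomial`, which keeps the same call signature. A minimal usage sketch (the probabilities are illustrative):

import numpy as np
import mindspore
from mindspore import ops, Tensor

probs = Tensor(np.array([0.1, 0.2, 0.7]), mindspore.float32)
samples = ops.multinomial(probs, 2)  # draw 2 indices, with replacement by default
print(samples.shape)                 # expected: (2,)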

# For Cells
class Dense(nn.Cell):
"""patched Dense"""
def __init__(self,