
nezha from pretrained #556

Merged · 1 commit · Jun 5, 2023
2 changes: 2 additions & 0 deletions mindnlp/models/nezha/__init__.py
@@ -16,9 +16,11 @@
model nezha init
"""

from . import nezha
from . import nezha_config
from .nezha import *
from .nezha_config import *

__all__ = []
__all__.extend(nezha_config.__all__)
__all__.extend(nezha.__all__)
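With both submodules imported and their __all__ lists merged, the Nezha classes become importable straight from the subpackage. A minimal sketch of the resulting import surface (illustrative, not part of the diff):

from mindnlp.models.nezha import NezhaConfig, NezhaModel

config = NezhaConfig()      # defaults from nezha_config.py, e.g. vocab_size=21128
model = NezhaModel(config)  # construct the backbone from a locally built config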
37 changes: 35 additions & 2 deletions mindnlp/models/nezha/nezha.py
@@ -13,19 +13,33 @@
# limitations under the License.
# ============================================================================
# pylint: disable=C0103
# pylint: disable=E0401

"""nezha model"""
import math

import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter
from mindspore import log as logger
from mindspore.common.initializer import initializer, Normal
from mindnlp.configs import MINDNLP_MODEL_URL_BASE
from .nezha_config import NezhaConfig, NEZHA_SUPPORT_LIST
from ...abc import PreTrainedModel
from ..utils.utils import prune_linear_layer, find_pruneable_heads_and_indices, apply_chunking_to_forward
from ..utils.activations import ACT2FN

PRETRAINED_MODEL_ARCHIVE_MAP = {
model: MINDNLP_MODEL_URL_BASE.format('nezha', model) for model in NEZHA_SUPPORT_LIST
}

__all__ = ['NezhaRelativePositionsEncoding', 'NezhaEmbeddings', 'NezhaSelfAttention',
'NezhaSelfOutput', 'NezhaAttention', 'NezhaIntermediate', 'NezhaOutput',
'NezhaLayer', 'NezhaEncoder', 'NezhaPooler', 'NezhaPredictionHeadTransform',
'NezhaLMPredictionHead', 'NezhaOnlyMLMHead', 'NezhaOnlyNSPHead', 'NezhaModel',
'NezhaPreTrainingHeads', 'NezhaForPreTraining', 'NezhaForMaskedLM',
'NezhaForNextSentencePrediction', 'NezhaForSequenceClassification',
'NezhaForMultipleChoice', 'NezhaForTokenClassification', 'NezhaForQuestionAnswering']

class NezhaRelativePositionsEncoding(nn.Cell):
"""Implement the Functional Relative Position Encoding"""
@@ -588,8 +602,27 @@ class NezhaPreTrainedModel(PreTrainedModel):
config_class = NezhaConfig
base_model_prefix = "nezha"
supports_gradient_checkpointing = True
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
_keys_to_ignore_on_load_missing = [r"positions_encoding"]

def _init_weights(self, cell):
"""Initialize the weights"""
if isinstance(cell, nn.Dense):
cell.weight.set_data(initializer(Normal(self.config.initializer_range),
cell.weight.shape, cell.weight.dtype))
if cell.has_bias:
cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
elif isinstance(cell, nn.Embedding):
embedding_table = initializer(Normal(self.config.initializer_range),
cell.embedding_table.shape,
cell.embedding_table.dtype)
if cell.padding_idx is not None:
embedding_table[cell.padding_idx] = 0
cell.embedding_table.set_data(embedding_table)
elif isinstance(cell, nn.LayerNorm):
cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.gamma.dtype))
cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype))

# TODO
def get_input_embeddings(self):
pass
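The _init_weights hook above applies the usual BERT-style policy: Normal(initializer_range) for Dense and Embedding weights, zeros for biases and LayerNorm beta, ones for LayerNorm gamma. A standalone sketch of the same policy on a single cell, assuming the common 0.02 default for initializer_range:

from mindspore import nn
from mindspore.common.initializer import initializer, Normal

# Minimal sketch, not part of this PR: the init policy applied by hand.
dense = nn.Dense(8, 8)
dense.weight.set_data(initializer(Normal(0.02), dense.weight.shape, dense.weight.dtype))
if dense.has_bias:
    dense.bias.set_data(initializer('zeros', dense.bias.shape, dense.bias.dtype))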
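The new pretrained_model_archive_map is what lets from_pretrained resolve a short model name to a downloadable checkpoint. Assuming MINDNLP_MODEL_URL_BASE is a two-slot format string (its real value lives in mindnlp.configs; the URL below is hypothetical), the map behaves like this:

# Hypothetical base URL for illustration only.
MINDNLP_MODEL_URL_BASE = "https://example.com/models/{}/{}.ckpt"
NEZHA_SUPPORT_LIST = ["nezha-cn-base", "nezha-cn-large"]

PRETRAINED_MODEL_ARCHIVE_MAP = {
    model: MINDNLP_MODEL_URL_BASE.format('nezha', model) for model in NEZHA_SUPPORT_LIST
}
# PRETRAINED_MODEL_ARCHIVE_MAP["nezha-cn-base"]
#   -> "https://example.com/models/nezha/nezha-cn-base.ckpt"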
11 changes: 7 additions & 4 deletions mindnlp/models/nezha/nezha_config.py
@@ -21,10 +21,10 @@
__all__ = ["NezhaConfig"]

NEZHA_SUPPORT_LIST = [
"sijunhe/nezha-cn-base",
"sijunhe/nezha-cn-large",
"sijunhe/nezha-base-wwm",
"sijunhe/nezha-large-wwm"
"nezha-cn-base",
"nezha-cn-large",
"nezha-base-wwm",
"nezha-large-wwm"
]

CONFIG_ARCHIVE_MAP = {
@@ -35,6 +35,9 @@ class NezhaConfig(PreTrainedConfig):
"""
Configuration for Nezha
"""

pretrained_config_archive_map = CONFIG_ARCHIVE_MAP

def __init__(
self,
vocab_size=21128,
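With pretrained_config_archive_map set on the class, a config can be fetched by the same short names. A hedged usage sketch (the from_pretrained entry point is assumed from the PreTrainedConfig base class, not shown in this diff):

from mindnlp.models.nezha import NezhaConfig

# Resolves 'nezha-cn-base' through CONFIG_ARCHIVE_MAP and loads the config.
config = NezhaConfig.from_pretrained('nezha-cn-base')
print(config.vocab_size)  # 21128 for the Chinese checkpoints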
10 changes: 5 additions & 5 deletions mindnlp/transforms/tokenizers/nezha_tokenizer.py
@@ -24,14 +24,14 @@
from mindnlp.configs import HF_VOCAB_URL_BASE

PRETRAINED_VOCAB_MAP = {
model: HF_VOCAB_URL_BASE.format("sijunhe/" + model) for model in NEZHA_SUPPORT_LIST
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"sijunhe/nezha-cn-base": 512,
"sijunhe/nezha-cn-large": 512,
"sijunhe/nezha-base-wwm": 512,
"sijunhe/nezha-large-wwm": 512
"nezha-cn-base": 512,
"nezha-cn-large": 512,
"nezha-base-wwm": 512,
"nezha-large-wwm": 512
}

class NezhaTokenizer(PreTrainedTokenizer):
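Note the asymmetry this file now carries: the support list stores bare names, so the Hugging Face vocab URL re-adds the sijunhe/ namespace at format time. Assuming HF_VOCAB_URL_BASE takes a single repo-id slot (the URL below is hypothetical), the map resolves like this:

# Hypothetical base URL for illustration; the real value is in mindnlp.configs.
HF_VOCAB_URL_BASE = "https://huggingface.co/{}/resolve/main/vocab.txt"
NEZHA_SUPPORT_LIST = ["nezha-cn-base"]

PRETRAINED_VOCAB_MAP = {
    model: HF_VOCAB_URL_BASE.format("sijunhe/" + model) for model in NEZHA_SUPPORT_LIST
}
# PRETRAINED_VOCAB_MAP["nezha-cn-base"]
#   -> "https://huggingface.co/sijunhe/nezha-cn-base/resolve/main/vocab.txt"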
4 changes: 2 additions & 2 deletions tests/ut/transforms/test_nezha_tokenizer.py
@@ -24,7 +24,7 @@ def test_nezha_tokenizer_from_pretrained():
texts = ['i make a small mistake when i\'m working! 床前明月光']
test_dataset = GeneratorDataset(texts, 'text')

bert_tokenizer = NezhaTokenizer.from_pretrained('nezha-cn-base', return_token=True)
test_dataset = test_dataset.map(operations=bert_tokenizer)
dataset_after = next(test_dataset.create_tuple_iterator())[0]

@@ -33,7 +33,7 @@ def test_nezha_tokenizer_add_special_tokens():

def test_nezha_tokenizer_add_special_tokens():
"""test add special tokens."""
nezha_tokenizer = NezhaTokenizer.from_pretrained('nezha-cn-base')
cls_id = nezha_tokenizer.token_to_id("[CLS]")

assert cls_id is not None
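Taken together, the tests exercise the workflow this PR enables: loading by short name instead of the sijunhe/ repo id. An end-to-end sketch, assuming the standard from_pretrained entry points on both classes:

from mindnlp.models.nezha import NezhaModel
from mindnlp.transforms.tokenizers.nezha_tokenizer import NezhaTokenizer

# Both calls resolve 'nezha-cn-base' through the maps added in this PR.
tokenizer = NezhaTokenizer.from_pretrained('nezha-cn-base')
model = NezhaModel.from_pretrained('nezha-cn-base')

cls_id = tokenizer.token_to_id("[CLS]")  # same call the test above relies on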