Fix Ernie's config and model for "from_pretrained" #520

Merged: 1 commit, May 21, 2023
1 change: 1 addition & 0 deletions mindnlp/configs.py
@@ -27,3 +27,4 @@
MINDNLP_CONFIG_URL_BASE = "https://download.mindspore.cn/toolkits/mindnlp/models/{}/{}/config.json"
MINDNLP_MODEL_URL_BASE = "https://download.mindspore.cn/toolkits/mindnlp/models/{}/{}/mindspore.ckpt"
MINDNLP_TOKENIZER_CONFIG_URL_BASE = "https://download.mindspore.cn/toolkits/mindnlp/models/{}/{}/tokenizer.json"
MINDNLP_VOCAB_URL_BASE = "https://download.mindspore.cn/toolkits/mindnlp/models/{}/{}/vocab.txt"
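
For context, the new MINDNLP_VOCAB_URL_BASE is filled in the same way as the other URL bases: the first placeholder is the architecture prefix (the part of the model name before the first "-") and the second is the full model name, mirroring how ernie.py builds PRETRAINED_MODEL_ARCHIVE_MAP below. A minimal sketch, assuming "ernie-3.0-base-zh" is an illustrative name of the kind listed in ERNIE_SUPPORT_LIST:

# Sketch, not part of this PR's diff. The model name below is an assumed example;
# real names come from ERNIE_SUPPORT_LIST in ernie_config.py.
import re

MINDNLP_VOCAB_URL_BASE = "https://download.mindspore.cn/toolkits/mindnlp/models/{}/{}/vocab.txt"

model = "ernie-3.0-base-zh"
arch = re.search(r"^[^-]*", model).group()          # text before the first "-" -> "ernie"
vocab_url = MINDNLP_VOCAB_URL_BASE.format(arch, model)
# -> "https://download.mindspore.cn/toolkits/mindnlp/models/ernie/ernie-3.0-base-zh/vocab.txt"
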
102 changes: 69 additions & 33 deletions mindnlp/models/ernie/ernie.py
@@ -17,16 +17,24 @@
Ernie Models
"""

import re
from typing import Optional, Tuple

import mindspore
from mindspore import nn, ops, Tensor
from mindspore.common.initializer import TruncatedNormal
from mindnlp.abc import PreTrainedModel
from .ernie_config import ErnieConfig, ERNIE_PRETRAINED_INIT_CONFIGURATION, ERNIE_PRETRAINED_RESOURCE_FILES_MAP
from mindnlp.configs import MINDNLP_MODEL_URL_BASE
from .ernie_config import ErnieConfig, ERNIE_SUPPORT_LIST


__all__ = ['ErnieEmbeddings', 'ErnieModel', 'ErniePooler', "UIE"]
PRETRAINED_MODEL_ARCHIVE_MAP = {
model: MINDNLP_MODEL_URL_BASE.format(re.search(r"^[^-]*", model).group(), model)
for model in ERNIE_SUPPORT_LIST
}

__all__ = ["ErnieEmbeddings", "ErnieModel", "ErniePooler", "UIE"]


class ErnieEmbeddings(nn.Cell):
"""
@@ -37,21 +45,30 @@ def __init__(self, config: ErnieConfig, embedding_table):
super().__init__()

self.word_embeddings = nn.Embedding(
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id, embedding_table=embedding_table
config.vocab_size,
config.hidden_size,
padding_idx=config.pad_token_id,
embedding_table=embedding_table,
)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size, embedding_table=embedding_table
config.max_position_embeddings,
config.hidden_size,
embedding_table=embedding_table,
)
self.type_vocab_size = config.type_vocab_size
if self.type_vocab_size > 0:
self.token_type_embeddings = nn.Embedding(
config.type_vocab_size, config.hidden_size, embedding_table=embedding_table
config.type_vocab_size,
config.hidden_size,
embedding_table=embedding_table,
)
self.use_task_id = config.use_task_id
self.task_id = config.task_id
if self.use_task_id:
self.task_type_embeddings = nn.Embedding(
config.task_type_vocab_size, config.hidden_size, embedding_table=embedding_table
config.task_type_vocab_size,
config.hidden_size,
embedding_table=embedding_table,
)
self.layer_norm = nn.LayerNorm([config.hidden_size])
self.dropout = nn.Dropout(config.hidden_dropout_prob, p=0.5)
@@ -65,7 +82,6 @@ def construct(
inputs_embeds: Optional[Tensor] = None,
past_key_values_length: int = 0,
):

if input_ids is not None:
inputs_embeds = self.word_embeddings(input_ids)

@@ -92,8 +108,7 @@ def construct(

if self.use_task_id:
if task_type_ids is None:
task_type_ids = ops.ones(
input_shape, mindspore.int64) * self.task_id
task_type_ids = ops.ones(input_shape, mindspore.int64) * self.task_id
task_type_embeddings = self.task_type_embeddings(task_type_ids)
embeddings = embeddings + task_type_embeddings
embeddings = self.layer_norm(embeddings)
@@ -106,22 +121,22 @@ class ErniePretrainedModel(PreTrainedModel):
Ernie Pretrained Model.
"""

pretrained_init_configuration = ERNIE_PRETRAINED_INIT_CONFIGURATION
pretrained_resource_files_map = ERNIE_PRETRAINED_RESOURCE_FILES_MAP
config_class = ErnieConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP

# TODO
def get_input_embeddings(self):
pass

#TODO
# TODO
def get_position_embeddings(self):
pass

#TODO
# TODO
def resize_position_embeddings(self):
pass

#TODO
# TODO
def set_input_embeddings(self):
pass

@@ -153,10 +168,12 @@ class ErniePooler(nn.Cell):
"""
Ernie Pooler.
"""

def __init__(self, config: ErnieConfig, weight_init):
super().__init__()
self.dense = nn.Dense(config.hidden_size,
config.hidden_size, weight_init=weight_init)
self.dense = nn.Dense(
config.hidden_size, config.hidden_size, weight_init=weight_init
)
self.activation = nn.Tanh()

def construct(self, hidden_states):
@@ -180,18 +197,17 @@ def __init__(self, config: ErnieConfig):
self.nheads = config.num_attention_heads
embedding_table = TruncatedNormal(sigma=self.initializer_range)
self.embeddings = ErnieEmbeddings(
config=config, embedding_table=embedding_table)
config=config, embedding_table=embedding_table
)
encoder_layer = nn.TransformerEncoderLayer(
config.hidden_size,
config.num_attention_heads,
config.intermediate_size,
dropout=config.hidden_dropout_prob,
activation=config.hidden_act,
batch_first=True
)
self.encoder = nn.TransformerEncoder(
encoder_layer, config.num_hidden_layers
batch_first=True,
)
self.encoder = nn.TransformerEncoder(encoder_layer, config.num_hidden_layers)
self.pooler = ErniePooler(config, weight_init=embedding_table)
self.apply(self.init_weights)

@@ -217,39 +233,57 @@ def construct(
):
batch_size, seq_length = input_ids.shape

return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time.")
"You cannot specify both input_ids and inputs_embeds at the same time."
)

# init the default bool value
output_attentions = output_attentions if output_attentions is not None else False
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
output_attentions = (
output_attentions if output_attentions is not None else False
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else False
)
return_dict = return_dict if return_dict is not None else False
# use_cache = use_cache if use_cache is not None else False
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]

if attention_mask is None:
attention_mask = ((input_ids == self.pad_token_id).astype(
self.pooler.dense.weight.dtype) * -1e4).unsqueeze(1).unsqueeze(2)
attention_mask = (
(
(input_ids == self.pad_token_id).astype(
self.pooler.dense.weight.dtype
)
* -1e4
)
.unsqueeze(1)
.unsqueeze(2)
)

if past_key_values is not None:
batch_size = past_key_values[0][0].shape[0]
past_mask = ops.zeros(
[batch_size, 1, 1, past_key_values_length], dtype=attention_mask.dtype)
attention_mask = ops.concat(
[past_mask, attention_mask], axis=-1)
[batch_size, 1, 1, past_key_values_length],
dtype=attention_mask.dtype,
)
attention_mask = ops.concat([past_mask, attention_mask], axis=-1)

attention_mask = ops.tile(
attention_mask, (1, self.nheads, seq_length, 1)).reshape(-1, seq_length, seq_length)
attention_mask, (1, self.nheads, seq_length, 1)
).reshape(-1, seq_length, seq_length)
# For 2D attention_mask from tokenizer
elif attention_mask.ndim == 2:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = (1.0 - attention_mask) * -1e4
attention_mask = ops.tile(
attention_mask, (1, self.nheads, seq_length, 1)).reshape(-1, seq_length, seq_length)
attention_mask, (1, self.nheads, seq_length, 1)
).reshape(-1, seq_length, seq_length)

attention_mask.stop_gradient = True

@@ -298,7 +332,9 @@ def construct(
inputs_embeds: Optional[Tensor] = None,
return_dict: Optional[Tensor] = None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
sequence_output, _ = self.ernie(
input_ids=input_ids,
token_type_ids=token_type_ids,
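
A rough usage sketch of what this fix targets, not taken from the PR itself: with config_class and pretrained_model_archive_map now set on ErniePretrainedModel, from_pretrained (inherited from mindnlp.abc.PreTrainedModel) should be able to resolve a supported model name to the config and checkpoint URLs defined in mindnlp/configs.py. The model name and the tuple return below are assumptions based on ERNIE_SUPPORT_LIST and the UIE branch of the diff:

# Hedged sketch, not part of the diff. Assumes "ernie-3.0-base-zh" is a name in
# ERNIE_SUPPORT_LIST and that from_pretrained downloads config.json and
# mindspore.ckpt via the URL bases in mindnlp/configs.py.
import mindspore
from mindspore import Tensor
from mindnlp.models.ernie.ernie import ErnieModel

model = ErnieModel.from_pretrained("ernie-3.0-base-zh")
input_ids = Tensor([[1, 2, 3, 4, 5]], mindspore.int64)   # toy token ids
sequence_output, pooled_output = model(input_ids)        # default tuple return, as used by UIE
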