add longformer tokenizer #502

Merged · 48 commits · May 13, 2023
Changes from 1 commit
Commits (48)
098a0ed
Qiu Bohang: create files
Yicorner Mar 8, 2023
4a679fb
Qiu Bohang: create requirements file
Yicorner Mar 8, 2023
947c987
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 8, 2023
8dc7cd7
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 10, 2023
b119ecf
create longformer model
Yicorner Mar 10, 2023
4570cd2
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 12, 2023
a94338b
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 13, 2023
9220659
add longformer embedding
Yicorner Mar 13, 2023
a5f46e9
Add Longformer Embedding model
Yicorner Mar 13, 2023
fe5dd04
add longformer embedding
Yicorner Mar 13, 2023
26f2c25
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 14, 2023
ae55a61
add cumsum to _legacy/functional and longformer_embedding
Yicorner Mar 14, 2023
d7a2aa0
add cumsum to _legacy/functional and longformer_embedding
Yicorner Mar 14, 2023
90de573
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 14, 2023
96384c2
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 15, 2023
7e32218
add cumsum to functional && embedding class to longformer
Yicorner Mar 15, 2023
9cd93df
modify according to review
Yicorner Mar 16, 2023
eb04d89
modify according to review
Yicorner Mar 16, 2023
fe97d3f
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 21, 2023
2aacf1b
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 30, 2023
d86bbce
add longformer selfAttention class
Yicorner Mar 30, 2023
b41b35f
Merge branch 'mindspore-lab:master' into master
Yicorner Apr 5, 2023
d6b8d8a
selfoutput
Yicorner Mar 30, 2023
91c3d16
attention
Yicorner Mar 30, 2023
55b1227
Intermediate
Yicorner Mar 30, 2023
6489fd1
OutPut
Yicorner Mar 30, 2023
9842fb2
Layer
Yicorner Mar 31, 2023
0d14290
Encoder
Yicorner Mar 31, 2023
12e8f7c
Pooler
Yicorner Mar 31, 2023
2653594
LMHead
Yicorner Mar 31, 2023
690d020
LongformerModel
Yicorner Mar 31, 2023
d6aeb52
LongformerForMaskedLM
Yicorner Apr 1, 2023
de5c26b
LongformerForSequenceClassification
Yicorner Apr 1, 2023
276d29c
add test LongformerClassificationHead
Yicorner Apr 1, 2023
3181848
LongformerForQuestionAnswering
Yicorner Apr 1, 2023
297678c
LongformerForTokenClassification
Yicorner Apr 1, 2023
f60590b
LongformerForMultipleChoice
Yicorner Apr 1, 2023
efe69eb
pylint OK
Yicorner Apr 1, 2023
e8c6524
pylint OK
Yicorner Apr 1, 2023
7922953
Merge branch 'master' of https://github.com/Yicorner/mindnlp
Yicorner Apr 5, 2023
773a20f
Merge branch 'master' of https://github.com/Yicorner/mindnlp
Yicorner Apr 5, 2023
cdc9352
Merge branch 'master' of https://github.com/Yicorner/mindnlp
Yicorner Apr 5, 2023
6eba36c
Merge branch 'mindspore-lab:master' into master
Yicorner Apr 17, 2023
8dbe625
Merge branch 'mindspore-lab:master' into master
Yicorner May 13, 2023
20986eb
add tokenizer
Yicorner May 13, 2023
25d1130
add tokenizer 2
Yicorner May 13, 2023
063f08e
add longformer tokenizer
Yicorner May 13, 2023
eb6c494
add longformer tokenizer2
Yicorner May 13, 2023
add longformer embedding
Yicorner committed Mar 13, 2023
commit 922065948bcf7f8294c117e1b02bae7e7750163d
155 changes: 155 additions & 0 deletions mindnlp/models/longformer/longformer.py
@@ -0,0 +1,155 @@
# coding=utf-8
# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Longformer model."""
# pylint: disable=relative-beyond-top-level
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-locals
# pylint: disable=too-few-public-methods
# pylint: disable=too-many-arguments
# pylint: disable=line-too-long
# pylint: disable=invalid-name
# pylint:
import mindspore
from mindspore import nn

activation_map = {
    'relu': nn.ReLU(),
    'gelu': nn.GELU(False),
    'gelu_approximate': nn.GELU(),
    'swish': nn.SiLU()
}


def _get_question_end_index(input_ids, sep_token_id):
    """
    Computes the index of the first occurrence of `sep_token_id`.
    """
    sep_token_indices = (input_ids == sep_token_id).nonzero()
    batch_size = input_ids.shape[0]

    assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
    assert sep_token_indices.shape[0] == 3 * batch_size, (
        f"There should be exactly three separator tokens: {sep_token_id} in every sample "
        "for question answering. You might also consider setting `global_attention_mask` "
        "manually in the forward function to avoid this error."
    )
    return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]


def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=True):
    """
    Computes the global attention mask by putting attention on all tokens before `sep_token_id`
    if `before_sep_token` is True, else on all tokens after `sep_token_id`.
    """
    question_end_index = _get_question_end_index(input_ids, sep_token_id)
    question_end_index = question_end_index.unsqueeze(dim=1)  # size: batch_size x 1
    # bool attention mask with True in locations of global attention
    attention_mask = mindspore.numpy.arange(input_ids.shape[1])
    if before_sep_token is True:
        attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(mindspore.uint8)
    else:
        # the last token is a separator and should not be counted; in the middle there are two separator tokens
        attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(mindspore.uint8) * (
            attention_mask.expand_as(input_ids) < input_ids.shape[-1]
        ).to(mindspore.uint8)

    return attention_mask


def create_position_ids_from_input_ids(input_ids, padding_idx):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at `padding_idx + 1`.
    Padding symbols are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: mindspore.Tensor of token ids.
        padding_idx: id of the padding token.

    Returns: mindspore.Tensor of position ids.
    """
    # The series of casts and type conversions here is carefully balanced to work with both ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = mindspore.ops.cumsum(mask, axis=1, dtype=mask.dtype) * mask
    return incremental_indices.long() + padding_idx


class LongformerEmbeddings(nn.Cell):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with the TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(normalized_shape=(config.hidden_size,), epsilon=config.layer_norm_eps)
        self.dropout = nn.Dropout(keep_prob=1 - config.hidden_dropout_prob)

        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """forward"""
        if position_ids is None:
            if input_ids is not None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        if input_ids is not None:
            input_shape = input_ids.shape
        else:
            input_shape = inputs_embeds.shape[:-1]

        if token_type_ids is None:
            token_type_ids = mindspore.ops.zeros(input_shape, dtype=mindspore.int64)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which tokens are padded, so we just
        generate sequential position ids.

        Args:
            inputs_embeds: mindspore.Tensor of input embeddings.

        Returns: mindspore.Tensor of position ids.
        """
        input_shape = inputs_embeds.shape[:-1]
        sequence_length = input_shape[1]

        position_ids = mindspore.numpy.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=mindspore.int64
        )
        return position_ids.unsqueeze(0).expand(input_shape)
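For readers skimming the diff, the position-id logic is easy to misread. The sketch below is not part of this PR; it simply reuses the same MindSpore calls that `create_position_ids_from_input_ids` makes, to show that padded slots stay at `padding_idx` while real tokens count up from `padding_idx + 1`:

```python
# Sketch only: mirrors the body of create_position_ids_from_input_ids above.
import numpy as np
import mindspore
from mindspore import Tensor

padding_idx = 1  # LongformerConfig default pad_token_id
input_ids = Tensor(np.array([[0, 5, 6, 1, 1]]), mindspore.int64)  # last two tokens are padding

mask = input_ids.ne(padding_idx).int()                                     # [[1, 1, 1, 0, 0]]
incremental = mindspore.ops.cumsum(mask, axis=1, dtype=mask.dtype) * mask  # [[1, 2, 3, 0, 0]]
position_ids = incremental.long() + padding_idx                            # [[2, 3, 4, 1, 1]]
print(position_ids)
```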
87 changes: 87 additions & 0 deletions mindnlp/models/longformer/longformer_config.py
@@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=relative-beyond-top-level
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-locals
# pylint: disable=too-few-public-methods
# pylint: disable=too-many-arguments
""" Longformer configuration"""
from typing import List, Union
from ...abc.backbones.pretrained import PretrainedConfig

class LongformerConfig(PretrainedConfig):
    r"""
    Example:

    ```python
    >>> from transformers import LongformerConfig, LongformerModel

    >>> # Initializing a Longformer configuration
    >>> configuration = LongformerConfig()

    >>> # Initializing a model from the configuration
    >>> model = LongformerModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "longformer"

    def __init__(
        self,
        attention_window: Union[List[int], int] = 512,
        sep_token_id: int = 2,
        pad_token_id: int = 1,
        bos_token_id: int = 0,
        eos_token_id: int = 2,
        vocab_size: int = 30522,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_act: str = "gelu",
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 512,
        type_vocab_size: int = 2,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        position_embedding_type: str = "absolute",
        classifier_dropout: float = None,
        onnx_export: bool = False,
        **kwargs
    ):
        """Constructs LongformerConfig."""
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.attention_window = attention_window
        self.sep_token_id = sep_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.classifier_dropout = classifier_dropout
        self.onnx_export = onnx_export
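As a quick illustration of how the new config is meant to be used, here is a sketch (not part of the diff). It assumes the import path used by the unit test added in this PR; `attention_window=256` is just an example value, not a recommended setting:

```python
# Sketch only: construct a config with a smaller local attention window.
from mindnlp.models.longformer.longformer_config import LongformerConfig

config = LongformerConfig(attention_window=256)
print(config.attention_window)          # 256
print(config.hidden_size)               # 768 (default)
print(config.max_position_embeddings)   # 512 (default)
```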
Empty file.
42 changes: 42 additions & 0 deletions tests/ut/models/longformer/test_embeddings.py
@@ -0,0 +1,42 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test Longformer"""
import unittest
import numpy as np
from mindspore import Tensor
from mindnlp.models.longformer.longformer import LongformerEmbeddings
from mindnlp.models.longformer.longformer_config import LongformerConfig


class TestModelingLongformer(unittest.TestCase):
    r"""
    Test model Longformer
    """
    def setUp(self):
        """
        Set up.
        """
        self.input = None

    def test_modeling_longformer_embedding(self):
        r"""
        Test LongformerEmbeddings in pynative mode
        """
        ms_config = LongformerConfig()
        ms_model = LongformerEmbeddings(ms_config)
        ms_model.set_train(False)
        tensor = np.random.randint(1, 10, (2, 2))
        ms_input_ids = Tensor.from_numpy(tensor)
        ms_outputs = ms_model.forward(ms_input_ids)
        assert (2, 2, 768) == ms_outputs.shape
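The asserted shape follows directly from the embedding layout: the output is `(batch_size, seq_len, hidden_size)`, and `hidden_size` defaults to 768 in `LongformerConfig`. A self-contained variant of the same check with a different input shape (a sketch, reusing the imports from the test above) would be:

```python
# Sketch only: any (batch, seq_len) integer input maps to (batch, seq_len, hidden_size).
import numpy as np
from mindspore import Tensor
from mindnlp.models.longformer.longformer import LongformerEmbeddings
from mindnlp.models.longformer.longformer_config import LongformerConfig

model = LongformerEmbeddings(LongformerConfig())
model.set_train(False)
out = model.forward(Tensor.from_numpy(np.random.randint(1, 10, (3, 5))))
assert out.shape == (3, 5, 768)  # 768 == LongformerConfig().hidden_size
```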