add longformer tokenizer #502

Merged
merged 48 commits into from May 13, 2023
Changes from 1 commit
Commits (48)
098a0ed
Qiu Bohang: create files
Yicorner Mar 8, 2023
4a679fb
Qiu Bohang: create requirements file
Yicorner Mar 8, 2023
947c987
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 8, 2023
8dc7cd7
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 10, 2023
b119ecf
create longformer model
Yicorner Mar 10, 2023
4570cd2
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 12, 2023
a94338b
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 13, 2023
9220659
add longformer embedding
Yicorner Mar 13, 2023
a5f46e9
Add Longformer Embedding model
Yicorner Mar 13, 2023
fe5dd04
add longformer embedding
Yicorner Mar 13, 2023
26f2c25
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 14, 2023
ae55a61
add cumsum to _legacy/functional and longformer_embedding
Yicorner Mar 14, 2023
d7a2aa0
add cumsum to _legacy/functional and longformer_embedding
Yicorner Mar 14, 2023
90de573
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 14, 2023
96384c2
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 15, 2023
7e32218
add cumsum to functional && embedding class to longformer
Yicorner Mar 15, 2023
9cd93df
modify according to review
Yicorner Mar 16, 2023
eb04d89
modify according to review
Yicorner Mar 16, 2023
fe97d3f
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 21, 2023
2aacf1b
Merge branch 'mindspore-lab:master' into master
Yicorner Mar 30, 2023
d86bbce
add longformer selfAttention class
Yicorner Mar 30, 2023
b41b35f
Merge branch 'mindspore-lab:master' into master
Yicorner Apr 5, 2023
d6b8d8a
selfoutput
Yicorner Mar 30, 2023
91c3d16
attention
Yicorner Mar 30, 2023
55b1227
Intermediate
Yicorner Mar 30, 2023
6489fd1
OutPut
Yicorner Mar 30, 2023
9842fb2
Layer
Yicorner Mar 31, 2023
0d14290
Encoder
Yicorner Mar 31, 2023
12e8f7c
Pooler
Yicorner Mar 31, 2023
2653594
LMHead
Yicorner Mar 31, 2023
690d020
LongformerModel
Yicorner Mar 31, 2023
d6aeb52
LongformerForMaskedLM
Yicorner Apr 1, 2023
de5c26b
LongformerForSequenceClassification
Yicorner Apr 1, 2023
276d29c
add test LongformerClassificationHead
Yicorner Apr 1, 2023
3181848
LongformerForQuestionAnswering
Yicorner Apr 1, 2023
297678c
LongformerForTokenClassification
Yicorner Apr 1, 2023
f60590b
LongformerForMultipleChoice
Yicorner Apr 1, 2023
efe69eb
pylint OK
Yicorner Apr 1, 2023
e8c6524
pylint OK
Yicorner Apr 1, 2023
7922953
Merge branch 'master' of https://github.com/Yicorner/mindnlp
Yicorner Apr 5, 2023
773a20f
Merge branch 'master' of https://github.com/Yicorner/mindnlp
Yicorner Apr 5, 2023
cdc9352
Merge branch 'master' of https://github.com/Yicorner/mindnlp
Yicorner Apr 5, 2023
6eba36c
Merge branch 'mindspore-lab:master' into master
Yicorner Apr 17, 2023
8dbe625
Merge branch 'mindspore-lab:master' into master
Yicorner May 13, 2023
20986eb
add tokenizer
Yicorner May 13, 2023
25d1130
add tokenizer 2
Yicorner May 13, 2023
063f08e
add longformer tokenizer
Yicorner May 13, 2023
eb6c494
add longformer tokenizer2
Yicorner May 13, 2023
LongformerModel
Yicorner committed Apr 5, 2023
commit 690d0203f98976993fcb7f1f530dc1062e8b5cdf
271 changes: 268 additions & 3 deletions mindnlp/models/longformer/longformer.py
@@ -21,15 +21,19 @@
# pylint: disable=invalid-name
import inspect
import math
from typing import List, Set, Tuple, Callable
from typing import List, Set, Tuple, Callable, Optional

import numpy as np
import mindspore
from mindspore import nn
from mindspore import Tensor
from mindspore import ops
from ..utils.activations import ACT2FN

from .longformer_config import LongformerConfig
from ...abc.backbones.pretrained import PretrainedModel
from ..utils import logging
from ..utils.mixin import CellUtilMixin
logger = logging.get_logger(__name__)

def apply_chunking_to_forward(
forward_fn: Callable[..., mindspore.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
@@ -1301,4 +1305,265 @@ def construct(self, features, **kwargs):
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
# For accelerate compatibility and to not break backward compatibility
self.bias = self.decoder.bias # qbh delete if device


class LongformerPreTrainedModel(PretrainedModel, CellUtilMixin):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""

config_class = LongformerConfig
base_model_prefix = "longformer"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_unexpected = [r"position_ids"]
_no_split_modules = ["LongformerSelfAttention"]

def post_init(self):
pass

def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Dense):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LongformerEncoder):
module.gradient_checkpointing = value


class LongformerModel(LongformerPreTrainedModel):
"""
This class copied code from [`RobertaModel`] and overwrote standard self-attention with longformer self-attention
to provide the ability to process long sequences following the self-attention approach described in [Longformer:
the Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, and Arman Cohan.
Longformer self-attention combines a local (sliding window) and global attention to extend to long documents
without the O(n^2) increase in memory and compute.

The self-attention module `LongformerSelfAttention` implemented here supports the combination of local and global
attention, but it lacks support for autoregressive attention and dilated attention. Autoregressive and dilated
attention are more relevant for autoregressive language modeling than for finetuning on downstream tasks. A future
release will add support for autoregressive attention; support for dilated attention requires a custom CUDA
kernel to be memory- and compute-efficient.

"""

def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config

if isinstance(config.attention_window, int):
assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value"
assert config.attention_window > 0, "`config.attention_window` has to be positive"
config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer
else:
assert len(config.attention_window) == config.num_hidden_layers, (
"`len(config.attention_window)` should equal `config.num_hidden_layers`. "
f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}"
)

self.embeddings = LongformerEmbeddings(config)
self.encoder = LongformerEncoder(config)
self.pooler = LongformerPooler(config) if add_pooling_layer else None

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embeddings.word_embeddings

def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value

def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

def _pad_to_window_size(
self,
input_ids: Tensor,
attention_mask: Tensor,
token_type_ids: Tensor,
position_ids: Tensor,
inputs_embeds: Tensor,
pad_token_id: int,
):
"""A helper function to pad tokens and mask to work with implementation of Longformer self-attention."""
# padding
attention_window = (
self.config.attention_window
if isinstance(self.config.attention_window, int)
else max(self.config.attention_window)
)

assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"
input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape
batch_size, seq_len = input_shape[:2]

padding_len = (attention_window - seq_len % attention_window) % attention_window

# this path should be recorded in the ONNX export, it is fine with padding_len == 0 as well
if padding_len > 0:
logger.info(
f"Input ids are automatically padded from {seq_len} to {seq_len + padding_len} to be a multiple of "
f"`config.attention_window`: {attention_window}"
)
if input_ids is not None:
input_ids = ops.pad(input_ids, (0, padding_len), value=pad_token_id)
if position_ids is not None:
# pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings
position_ids = ops.pad(position_ids, (0, padding_len), value=pad_token_id)
if inputs_embeds is not None:
input_ids_padding = inputs_embeds.new_full(
(batch_size, padding_len),
self.config.pad_token_id,
dtype=mindspore.int64,
)
inputs_embeds_padding = self.embeddings(input_ids_padding)
inputs_embeds = ops.cat([inputs_embeds, inputs_embeds_padding], axis=-2)

attention_mask = ops.pad(
attention_mask, (0, padding_len), value=0
) # no attention on the padding tokens
token_type_ids = ops.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0

return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds

def _merge_to_attention_mask(self, attention_mask: Tensor, global_attention_mask: Tensor):
# longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
# (global_attention_mask + 1) => 1 for local attention, 2 for global attention
# => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention
if attention_mask is not None:
attention_mask = attention_mask * (global_attention_mask + 1)
else:
# simply use `global_attention_mask` as `attention_mask`
# if no `attention_mask` is given
attention_mask = global_attention_mask + 1
return attention_mask

def construct(
self,
input_ids: Optional[Tensor] = None,
attention_mask: Optional[Tensor] = None,
global_attention_mask: Optional[Tensor] = None,
head_mask: Optional[Tensor] = None,
token_type_ids: Optional[Tensor] = None,
position_ids: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Tuple:
r"""

Returns:

Examples:

```python
>>> import torch
>>> from transformers import LongformerModel, AutoTokenizer

>>> model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
>>> tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")

>>> SAMPLE_TEXT = " ".join(["Hello world! "] * 1000) # long input document
>>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1

>>> attention_mask = torch.ones(
... input_ids.shape, dtype=torch.long, device=input_ids.device
... ) # initialize to local attention
>>> global_attention_mask = torch.zeros(
... input_ids.shape, dtype=torch.long, device=input_ids.device
... ) # initialize global attention as deactivated for all tokens
>>> global_attention_mask[
... :,
... [
... 1,
... 4,
... 21,
... ],
... ] = 1 # Set global attention to random tokens for the sake of this example
>>> # Usually, set global attention based on the task. For example,
>>> # classification: the <s> token
>>> # QA: question tokens
>>> # LM: potentially on the beginning of sentences and paragraphs
>>> outputs = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)
>>> sequence_output = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output
```"""

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.shape
elif inputs_embeds is not None:
input_shape = inputs_embeds.shape[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")


if attention_mask is None:
attention_mask = ops.ones(input_shape)
if token_type_ids is None:
token_type_ids = ops.zeros(input_shape, dtype=mindspore.int64)

# merge `global_attention_mask` and `attention_mask`
if global_attention_mask is not None:
attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)

padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
pad_token_id=self.config.pad_token_id,
)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: Tensor = self.get_extended_attention_mask(attention_mask, input_shape)[
:, 0, 0, :
]

embedding_output = self.embeddings(
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)

encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
padding_len=padding_len,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

# `return_dict`-style model outputs are not implemented here, so a plain tuple is always returned
return (sequence_output, pooled_output) + encoder_outputs[1:]
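
The two helpers above carry most of the Longformer-specific plumbing: `_pad_to_window_size` rounds the sequence length up to the next multiple of the (largest) attention window, and `_merge_to_attention_mask` folds a 0/1 global-attention mask into the regular attention mask so that 0 means no attention, 1 local attention, and 2 global attention. Below is a minimal, framework-free sketch of that arithmetic; the function names are illustrative only and not part of the mindnlp API.

```python
def pad_to_window_size(seq_len: int, attention_window: int) -> int:
    """How many pad tokens make seq_len a multiple of attention_window
    (mirrors the computation in `_pad_to_window_size`)."""
    assert attention_window > 0 and attention_window % 2 == 0
    return (attention_window - seq_len % attention_window) % attention_window


def merge_to_attention_mask(attention_mask, global_attention_mask):
    """Element-wise merge as in `_merge_to_attention_mask`:
    0 = no attention, 1 = local attention, 2 = global attention."""
    return [a * (g + 1) for a, g in zip(attention_mask, global_attention_mask)]


if __name__ == "__main__":
    print(pad_to_window_size(1000, 512))                         # 24 -> padded length 1024
    print(merge_to_attention_mask([0, 1, 1, 1], [0, 0, 1, 0]))   # [0, 1, 2, 1]
```

Pad positions keep mask value 0, so they are ignored by both local and global attention.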
39 changes: 38 additions & 1 deletion tests/ut/models/longformer/test_embeddings.py
@@ -29,6 +29,7 @@
from mindnlp.models.longformer.longformer import LongformerEncoder
from mindnlp.models.longformer.longformer import LongformerPooler
from mindnlp.models.longformer.longformer import LongformerLMHead
from mindnlp.models.longformer.longformer import LongformerModel


class TestModelingEmbeddings(unittest.TestCase):
@@ -314,4 +315,40 @@ def test_modeling_longformer_embedding(self):
ms_outputs = ms_model(
features=ms_features,
)
assert (1, 8, 30522) == ms_outputs.shape


class TestModelingLongformerModel(unittest.TestCase):
r"""
Test LongformerModel
"""
def setUp(self):
"""
Set up.
"""
self.input = None

def test_modeling_longformer_model(self):
r"""
Test LongformerModel with pynative mode
"""
ms_config = LongformerConfig(
attention_window=[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
max_position_embeddings=40,
)
ms_model = LongformerModel(ms_config)
ms_model.set_train(False)
input_ids = np.random.randint(1, 10, (1, 10))
attention_mask = np.random.randint(0, 2, (1, 10))
global_attention_mask = np.random.randint(0, 2, (1, 10))

ms_input_ids = mindspore.Tensor(input_ids, dtype=mindspore.int32)
ms_attention_mask = mindspore.Tensor(attention_mask, dtype=mindspore.int32)
ms_global_attention_mask = mindspore.Tensor(global_attention_mask, dtype=mindspore.int32)
ms_outputs = ms_model(
input_ids=ms_input_ids,
attention_mask=ms_attention_mask,
global_attention_mask=ms_global_attention_mask,
)
assert (1, 10, 768) == ms_outputs[0].shape
assert (1, 768) == ms_outputs[1].shape
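
For reference, a minimal end-to-end sketch that mirrors the unit test above: the model is built directly from a `LongformerConfig` rather than loaded with `from_pretrained`, the first token is marked for global attention, and the forward pass returns a `(sequence_output, pooled_output, ...)` tuple. The `LongformerConfig` import path is assumed from the `longformer_config` module referenced in the diff; everything else follows the test.

```python
import numpy as np
import mindspore
from mindnlp.models.longformer.longformer import LongformerModel
from mindnlp.models.longformer.longformer_config import LongformerConfig  # assumed path

# Small config matching the unit test: 12 layers, an even attention window of 8 per layer.
config = LongformerConfig(
    attention_window=[8] * 12,
    max_position_embeddings=40,
)
model = LongformerModel(config)
model.set_train(False)  # inference mode

input_ids = mindspore.Tensor(np.random.randint(1, 10, (1, 10)), dtype=mindspore.int32)
attention_mask = mindspore.Tensor(np.ones((1, 10), dtype=np.int32))
global_mask = np.zeros((1, 10), dtype=np.int32)
global_mask[:, 0] = 1  # e.g. global attention on the classification token
global_attention_mask = mindspore.Tensor(global_mask)

outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    global_attention_mask=global_attention_mask,
)
print(outputs[0].shape)  # (1, 10, 768) with the default hidden size
print(outputs[1].shape)  # (1, 768) pooled output
```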