Skip to content

Commit

Permalink
ModernBERT bug fixes (#35404)
Browse files Browse the repository at this point in the history
* bug fixes

* organize imports

* wrap cpu warning in reference_compile

* Avoid needing repad_logits_with_grad, always repad with grads when training

I'm not 100% that the conditional with "or labels is None" makes sense though - not sure what the intention is there. Perhaps we can remove that?

* Revert "Avoid needing repad_logits_with_grad, always repad with grads when training"

This reverts commit cedcb4e.

* Fix grammar: keep -> keeps

* Propagate grammar fix with modular_model_converter

---------

Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 9, 2025
1 parent e97d7a5 commit 1e3ddcb
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 19 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@
- local: model_doc/mobilebert
title: MobileBERT
- local: model_doc/modernbert
title: ModernBert
title: ModernBERT
- local: model_doc/mpnet
title: MPNet
- local: model_doc/mpt
Expand Down
4 changes: 2 additions & 2 deletions docs/source/en/model_doc/modernbert.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ rendered properly in your Markdown viewer.
-->

# ModernBert
# ModernBERT

<div class="flex flex-wrap space-x-1">
<a href="https://huggingface.co/models?filter=modernbert">
Expand All @@ -27,7 +27,7 @@ rendered properly in your Markdown viewer.

## Overview

The ModernBert model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Galalgher, Raja Bisas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Grifin Adams, Jeremy Howard and Iacopo Poli.
The ModernBERT model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Galalgher, Raja Bisas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Grifin Adams, Jeremy Howard and Iacopo Poli.

It is a refresh of the traditional encoder architecture, as used in previous models such as [BERT](https://huggingface.co/docs/transformers/en/model_doc/bert) and [RoBERTa](https://huggingface.co/docs/transformers/en/model_doc/roberta).

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ class ModernBertConfig(PretrainedConfig):
the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
be faster in some scenarios.
repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
Examples:
Expand Down Expand Up @@ -164,6 +167,7 @@ def __init__(
sparse_prediction=False,
sparse_pred_ignore_index=-100,
reference_compile=None,
repad_logits_with_grad=False,
**kwargs,
):
super().__init__(
Expand Down Expand Up @@ -203,6 +207,7 @@ def __init__(
self.sparse_prediction = sparse_prediction
self.sparse_pred_ignore_index = sparse_pred_ignore_index
self.reference_compile = reference_compile
self.repad_logits_with_grad = repad_logits_with_grad

if self.classifier_pooling not in ["cls", "mean"]:
raise ValueError(
Expand Down
28 changes: 20 additions & 8 deletions src/transformers/models/modernbert/modeling_modernbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# limitations under the License.

import math
from contextlib import nullcontext
from typing import Dict, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -632,12 +633,14 @@ def _autoset_attn_implementation(
):
# If the user didn't specify anything, try to use flash_attention_2 if available.
# Otherwise we fall back to the default SDPA -> Eager from the super() method.
# ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
# need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
if config._attn_implementation_internal is None:
config._attn_implementation_internal = "flash_attention_2"
try:
return cls._check_and_enable_flash_attn_2(
config,
torch_dtype=torch_dtype,
torch_dtype=torch.float16,
device_map=device_map,
hard_check_only=False,
check_device_map=check_device_map,
Expand All @@ -647,7 +650,7 @@ def _autoset_attn_implementation(
return super()._autoset_attn_implementation(
config,
use_flash_attention_2=use_flash_attention_2,
torch_dtype=torch_dtype,
torch_dtype=torch.float16,
device_map=device_map,
check_device_map=check_device_map,
)
Expand All @@ -672,6 +675,14 @@ def _maybe_set_compile(self):
)
self.config.reference_compile = False

if self.device.type == "cpu":
if self.config.reference_compile:
logger.warning_once(
"Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
"Falling back to non-compiled mode."
)
self.config.reference_compile = False

if self.config.reference_compile is None:
self.config.reference_compile = is_triton_available()

Expand Down Expand Up @@ -763,8 +774,8 @@ def _pad_modernbert_output(
MODERNBERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices of input sequence tokens in the vocabulary. With Flash Attention 2.0, padding will be ignored
by default should you provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
Expand All @@ -790,7 +801,7 @@ def _pad_modernbert_output(
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
far-away tokens in the local attention layers.
far-away tokens in the local attention layers when not using Flash Attention.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
Expand All @@ -805,11 +816,11 @@ def _pad_modernbert_output(
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
max_seqlen (`int`, *optional*):
Maximum sequence length in the batch. Used to pad the output tensors.
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
batch_size (`int`, *optional*):
Batch size of the input sequences. Used to pad the output tensors.
seq_len (`int`, *optional*):
Sequence length of the input sequences. Used to pad the output tensors.
Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
Expand Down Expand Up @@ -1128,8 +1139,9 @@ def forward(
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)

if self.config._attn_implementation == "flash_attention_2":
with torch.no_grad():
with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)

if not return_dict:
output = (logits,)
return ((loss,) + output) if loss is not None else output
Expand Down
33 changes: 25 additions & 8 deletions src/transformers/models/modernbert/modular_modernbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import math
from contextlib import nullcontext
from typing import Dict, Literal, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -141,6 +142,9 @@ class ModernBertConfig(PretrainedConfig):
the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
be faster in some scenarios.
repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
Examples:
Expand Down Expand Up @@ -196,6 +200,7 @@ def __init__(
sparse_prediction=False,
sparse_pred_ignore_index=-100,
reference_compile=None,
repad_logits_with_grad=False,
**kwargs,
):
super().__init__(
Expand Down Expand Up @@ -235,6 +240,7 @@ def __init__(
self.sparse_prediction = sparse_prediction
self.sparse_pred_ignore_index = sparse_pred_ignore_index
self.reference_compile = reference_compile
self.repad_logits_with_grad = repad_logits_with_grad

if self.classifier_pooling not in ["cls", "mean"]:
raise ValueError(
Expand Down Expand Up @@ -857,12 +863,14 @@ def _autoset_attn_implementation(
):
# If the user didn't specify anything, try to use flash_attention_2 if available.
# Otherwise we fall back to the default SDPA -> Eager from the super() method.
# ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
# need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
if config._attn_implementation_internal is None:
config._attn_implementation_internal = "flash_attention_2"
try:
return cls._check_and_enable_flash_attn_2(
config,
torch_dtype=torch_dtype,
torch_dtype=torch.float16,
device_map=device_map,
hard_check_only=False,
check_device_map=check_device_map,
Expand All @@ -872,7 +880,7 @@ def _autoset_attn_implementation(
return super()._autoset_attn_implementation(
config,
use_flash_attention_2=use_flash_attention_2,
torch_dtype=torch_dtype,
torch_dtype=torch.float16,
device_map=device_map,
check_device_map=check_device_map,
)
Expand All @@ -897,6 +905,14 @@ def _maybe_set_compile(self):
)
self.config.reference_compile = False

if self.device.type == "cpu":
if self.config.reference_compile:
logger.warning_once(
"Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
"Falling back to non-compiled mode."
)
self.config.reference_compile = False

if self.config.reference_compile is None:
self.config.reference_compile = is_triton_available()

Expand All @@ -916,8 +932,8 @@ def resize_token_embeddings(self, *args, **kwargs):
MODERNBERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices of input sequence tokens in the vocabulary. With Flash Attention 2.0, padding will be ignored
by default should you provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
Expand All @@ -943,7 +959,7 @@ def resize_token_embeddings(self, *args, **kwargs):
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
far-away tokens in the local attention layers.
far-away tokens in the local attention layers when not using Flash Attention.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
Expand All @@ -958,11 +974,11 @@ def resize_token_embeddings(self, *args, **kwargs):
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
max_seqlen (`int`, *optional*):
Maximum sequence length in the batch. Used to pad the output tensors.
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
batch_size (`int`, *optional*):
Batch size of the input sequences. Used to pad the output tensors.
seq_len (`int`, *optional*):
Sequence length of the input sequences. Used to pad the output tensors.
Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
Expand Down Expand Up @@ -1281,8 +1297,9 @@ def forward(
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)

if self.config._attn_implementation == "flash_attention_2":
with torch.no_grad():
with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)

if not return_dict:
output = (logits,)
return ((loss,) + output) if loss is not None else output
Expand Down

0 comments on commit 1e3ddcb

Please sign in to comment.