Transformers 4.44 support (#1996)
* test

* fix conll2003 dataset with remote code

* sdpa for new bloom attention block

* style

* fix bloom modeling

* better version ranges to reflect max and min transformers support

* pin right version

* use input dims
IlyasMoutawwakil authored Sep 2, 2024
1 parent 3b55875 commit 7cc57e4
Showing 8 changed files with 210 additions and 124 deletions.
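All of the diffs below hinge on one version check: optimum's check_if_transformers_greater helper selects between the new Cache-based BLOOM path (transformers >= 4.44) and the legacy tuple-based path. A minimal sketch of that gating pattern, assuming the helper is importable from optimum.utils as the relative import in attention.py suggests:

# Sketch of the version gate used throughout this commit. Assumption: the
# check_if_transformers_greater helper imported in attention.py via
# `from ...utils import check_if_transformers_greater` is also reachable
# from optimum.utils and returns a bool for the installed transformers.
from optimum.utils import check_if_transformers_greater

if check_if_transformers_greater("4.44"):
    print("transformers >= 4.44: standard Cache API and new BLOOM QKV layout")
else:
    print("transformers < 4.44: legacy tuple-based past_key_values path")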
218 changes: 146 additions & 72 deletions optimum/bettertransformer/models/attention.py
@@ -15,6 +15,9 @@
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from ...utils import check_if_transformers_greater


# TODO (CRITICAL): Layer-wise attention scaling is broken for several archs.
@@ -23,7 +26,7 @@
def raise_on_head_mask(head_mask: Optional[torch.Tensor]):
    if head_mask is not None:
        raise ValueError(
            "layer_head_mask (or head_mask) different than None is unsupported for now with BetterTransformer, please"
            "open a PR or an issue at https://github.com/huggingface/optimum."
        )

@@ -534,88 +537,159 @@ def bart_forward(
    return attn_output, None, past_key_value


if check_if_transformers_greater("4.44"):
    from transformers.cache_utils import Cache
    from transformers.models.bloom.modeling_bloom import dropout_add

    # Adapted from transformers.models.bloom.modeling_bloom.BloomAttention.forward
    def bloom_forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Cache] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        raise_on_head_mask(head_mask)

        if output_attentions is True:
            raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

        batch_size, q_length, _ = hidden_states.shape
        # [batch_size, seq_length, 3 x hidden_size]
        fused_qkv = self.query_key_value(hidden_states)
        # 3 x [batch_size, num_heads, seq_length, head_dim]
        query_layer, key_layer, value_layer = self._reshape(fused_qkv)

        if layer_past is not None:
            cache_kwargs = {"cache_position": cache_position}
            key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)

        alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])

        if attention_mask is not None:  # no matter the length, we just slice it
            kv_length = cache_position[-1] + 1  # cache position is 0-indexed while length should start from 1
            causal_mask = attention_mask[:, :, :, :kv_length]
            alibi = torch.masked_fill(alibi, causal_mask.bool(), torch.finfo(alibi.dtype).min)

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=alibi,
            dropout_p=self.dropout_prob_attn if self.training else 0.0,
        )

        # Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim]
        context_layer = context_layer.transpose(1, 2)
        context_layer = context_layer.reshape(batch_size, q_length, self.hidden_size)

        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        if self.pretraining_tp > 1 and self.slow_but_exact:
            slices = self.hidden_size / self.pretraining_tp
            output_tensor = torch.zeros_like(context_layer)
            for i in range(self.pretraining_tp):
                output_tensor = output_tensor + F.linear(
                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
                )
        else:
            output_tensor = self.dense(context_layer)

        output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

        outputs = (output_tensor, layer_past)

        return outputs

else:
    # Adapted from transformers.models.bloom.modeling_bloom.BloomAttention.forward
    def bloom_forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        **kwargs,
    ):
        raise_on_head_mask(head_mask)

        if output_attentions is True:
            raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

        # [batch_size, seq_length, 3 x hidden_size]
        fused_qkv = self.query_key_value(hidden_states)

        # 3 x [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

        batch_size, q_length, _, _ = query_layer.shape

        # Permute to [batch_size, num_heads, seq_length, head_dim]
        query_layer = query_layer.transpose(1, 2)

        if layer_past is not None:
            past_key, past_value = layer_past
            past_key = past_key.transpose(1, 2)

            key_layer = key_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
            value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)

            # concatenate along seq_length dimension
            key_layer = torch.cat((past_key, key_layer), dim=1)
            value_layer = torch.cat((past_value, value_layer), dim=1)

            # untangle batch_size from self.num_heads
            key_layer = key_layer.reshape(batch_size, self.num_heads, *key_layer.shape[1:])
            value_layer = value_layer.reshape(batch_size, self.num_heads, *value_layer.shape[1:])
        else:
            key_layer = key_layer.transpose(1, 2)
            value_layer = value_layer.transpose(1, 2)

        alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
        alibi = torch.masked_fill(alibi, attention_mask, torch.finfo(alibi.dtype).min)

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=alibi,
            dropout_p=self.dropout_prob_attn if self.training else 0.0,
        )

        # Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim]
        context_layer = context_layer.transpose(1, 2)
        context_layer = context_layer.reshape(*context_layer.shape[:2], -1)

        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        if self.pretraining_tp > 1 and self.slow_but_exact:
            slices = self.hidden_size / self.pretraining_tp
            output_tensor = torch.zeros_like(context_layer)
            for i in range(self.pretraining_tp):
                output_tensor = output_tensor + torch.nn.functional.linear(
                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
                )
        else:
            output_tensor = self.dense(context_layer)

        output_tensor = torch.nn.functional.dropout(output_tensor, p=self.hidden_dropout, training=self.training)
        output_tensor = residual + output_tensor

        if use_cache is True:
            present = (
                key_layer.reshape(-1, *key_layer.shape[2:]).transpose(1, 2),
                value_layer.reshape(-1, *value_layer.shape[2:]),
            )
        else:
            present = None

        return (output_tensor, present)
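Why passing the masked ALiBi tensor as attn_mask to scaled_dot_product_attention matches the eager computation: a floating-point attn_mask is added to QK^T / sqrt(head_dim) before the softmax, which is exactly how the additive ALiBi bias (with masked positions filled with the dtype minimum) enters the score matrix above. A small self-contained sketch, with illustrative tensor names and shapes that are not part of the diff:

import math

import torch
import torch.nn.functional as F

# Illustrative shapes: [batch_size, num_heads, seq_length, head_dim]
batch_size, num_heads, seq_length, head_dim = 2, 4, 8, 16
query = torch.randn(batch_size, num_heads, seq_length, head_dim)
key = torch.randn(batch_size, num_heads, seq_length, head_dim)
value = torch.randn(batch_size, num_heads, seq_length, head_dim)

# Additive bias standing in for the masked ALiBi tensor.
bias = torch.randn(batch_size, num_heads, seq_length, seq_length)

# Eager attention: softmax(QK^T / sqrt(head_dim) + bias) @ V
scores = query @ key.transpose(-2, -1) / math.sqrt(head_dim) + bias
eager = torch.softmax(scores, dim=-1) @ value

# SDPA adds a float attn_mask to the scaled scores before the softmax.
sdpa = F.scaled_dot_product_attention(query, key, value, attn_mask=bias)

print(torch.allclose(eager, sdpa, atol=1e-5))  # expected: True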
2 changes: 2 additions & 0 deletions optimum/bettertransformer/models/decoder_models.py
@@ -216,6 +216,8 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
        self.dropout_prob_attn = config.attention_dropout

        self.module_mapping = None
        self.layer_idx = getattr(layer, "layer_idx", None)

        submodules = ["query_key_value", "dense", "attention_dropout"]
        for attr in submodules:
            setattr(self, attr, getattr(layer, attr))
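The layer_idx recorded here is what the transformers >= 4.44 Cache API needs: Cache.update appends the new key/value states to the per-layer cache identified by that index. A rough illustration using DynamicCache, with illustrative shapes that are not code from this commit:

import torch
from transformers.cache_utils import DynamicCache  # available in recent transformers releases

cache = DynamicCache()
layer_idx = 0  # the index the BetterTransformer layer copies from the original BLOOM layer

# Standard layout expected by the Cache API: [batch_size, num_heads, seq_length, head_dim]
key_states = torch.randn(1, 4, 3, 16)
value_states = torch.randn(1, 4, 3, 16)

# update() stores the states for this layer and returns the full key/value
# tensors accumulated so far, which is what bloom_forward feeds to SDPA.
key_cache, value_cache = cache.update(key_states, value_states, layer_idx)
print(key_cache.shape)  # torch.Size([1, 4, 3, 16])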
38 changes: 21 additions & 17 deletions optimum/exporters/onnx/model_configs.py
@@ -338,27 +338,31 @@ class BloomOnnxConfig(TextDecoderOnnxConfig):
    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
    DUMMY_PKV_GENERATOR_CLASS = BloomDummyPastKeyValuesGenerator
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")
    DEFAULT_ONNX_OPSET = 14  # Bloom uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention

    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
        if check_if_transformers_greater("4.44"):
            super().add_past_key_values(inputs_or_outputs, direction)
        else:
            if direction not in ["inputs", "outputs"]:
                raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

            if direction == "inputs":
                decoder_sequence_name = "past_sequence_length"
                name = "past_key_values"
            else:
                decoder_sequence_name = "past_sequence_length + 1"
                name = "present"

            for i in range(self._normalized_config.num_layers):
                inputs_or_outputs[f"{name}.{i}.key"] = {
                    0: "batch_size x num_heads",
                    2: decoder_sequence_name,
                }
                inputs_or_outputs[f"{name}.{i}.value"] = {
                    0: "batch_size x num_heads",
                    1: decoder_sequence_name,
                }


class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
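For orientation, a sketch that is not part of the diff: with transformers >= 4.44 the parent class's standard dynamic axes apply, since BLOOM's cache now uses the regular [batch_size, num_heads, seq, head_dim] layout, whereas the else branch above keeps the legacy fused layout with transposed keys. Roughly, per layer:

# Dynamic axes for the layer-0 past key/value inputs (sketch; the standard
# layout is assumed to be what TextDecoderOnnxConfig.add_past_key_values emits).

# transformers >= 4.44: key and value are both [batch_size, num_heads, past_seq, head_dim]
standard_axes = {
    "past_key_values.0.key": {0: "batch_size", 2: "past_sequence_length"},
    "past_key_values.0.value": {0: "batch_size", 2: "past_sequence_length"},
}

# transformers < 4.44 (legacy BLOOM export): fused batch/head axis, transposed key
#   key:   [batch_size * num_heads, head_dim, past_seq]
#   value: [batch_size * num_heads, past_seq, head_dim]
legacy_axes = {
    "past_key_values.0.key": {0: "batch_size x num_heads", 2: "past_sequence_length"},
    "past_key_values.0.value": {0: "batch_size x num_heads", 1: "past_sequence_length"},
}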
21 changes: 13 additions & 8 deletions optimum/onnxruntime/modeling_decoder.py
@@ -336,8 +336,7 @@ def prepare_past_key_values(
        dtype = constructor.float16 if self.use_fp16 else constructor.float32

        # TODO: find a way to better handle this controlflow, this is EXTREMELY UGLY.
        if self.__class__.__name__ == "ORTBloomForCausalLM":
            shape_value = (batch_size * num_attention_heads, 0, embed_size_per_head)
            shape_key = (batch_size * num_attention_heads, embed_size_per_head, 0)
            key = constructor.zeros(shape_key, dtype=dtype)
@@ -354,9 +353,9 @@
            for name, value in zip(self.key_value_output_names, past_key_values):
                shape = [*value.shape]
                index = 1 if "value" in name else 2
                shape[index] += sequence_length
                pkv_output_shape[name] = shape

        elif self.model_type == "gpt_bigcode":
            # GPT BigCode uses multi-query attention, and has the specificity of putting both key and value in the same cache tensor.
            shape_key_and_value = (batch_size, 0, embed_size_per_head * 2)
@@ -371,9 +370,9 @@
                shape = [*value.shape]
                shape[1] += sequence_length
                pkv_output_shape[name] = shape

        else:
            num_key_value_heads = self.num_key_value_heads if self.model_type == "falcon" else num_attention_heads
            shape = (batch_size, num_key_value_heads, 0, embed_size_per_head)
            key_or_value = constructor.zeros(shape, dtype=dtype)
@@ -534,9 +533,9 @@ def _from_pretrained(

        # Since https://github.com/huggingface/optimum/pull/871/
        # changed axis notation/naming during export, we need to update the dims
        for input_name in input_dims.keys():
            if "past" in input_name and input_dims[input_name][2] == "past_sequence_length + sequence_length":
                input_dims[input_name][2] = "past_sequence_length"
                override_dims = True

        if override_dims:
@@ -559,6 +558,12 @@
                size_threshold=0,
            )

        # Since transformers 4.44, the bloom model has been updated to use the standard cache format
        use_old_bloom_modeling = not check_if_transformers_greater("4.44")
        for input_name in input_dims.keys():
            if input_dims[input_name][0] == "batch_size x num_heads":
                use_old_bloom_modeling = True

        del onnx_model

        model = ORTModel.load_model(
@@ -568,7 +573,7 @@
            provider_options=provider_options,
        )

        if config.model_type == "bloom" and use_old_bloom_modeling:
            init_cls = ORTBloomForCausalLM
        elif config.model_type == "falcon":
            init_cls = ORTFalconForCausalLM
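The net effect of the two modeling_decoder.py changes above, as a standalone sketch with a hypothetical helper name: an exported BLOOM model is routed to the legacy ORTBloomForCausalLM wrapper only when it was produced with the old cache layout, which is detectable either from the installed transformers version or from the fused axis name in the ONNX inputs.

from optimum.utils import check_if_transformers_greater  # assumption: public import path


def needs_legacy_bloom_wrapper(input_dims: dict) -> bool:
    """Hypothetical helper mirroring the checks added in _from_pretrained above.

    `input_dims` maps each ONNX input name to its list of dynamic axis names.
    """
    # Running with transformers < 4.44 always means the old BLOOM cache layout.
    legacy = not check_if_transformers_greater("4.44")
    # Models exported before this commit expose the fused batch/head axis name.
    for input_name in input_dims:
        if input_dims[input_name][0] == "batch_size x num_heads":
            legacy = True
    return legacy


# e.g. init_cls = ORTBloomForCausalLM if config.model_type == "bloom"
#      and needs_legacy_bloom_wrapper(input_dims) else ORTModelForCausalLM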
