Description
System Info
- transformers version: 4.35.2
- Platform: Linux-6.2.0-37-generic-x86_64-with-glibc2.35
- Python version: 3.11.7
- Huggingface_hub version: 0.19.4
- Safetensors version: 0.4.0
- Accelerate version: 0.24.1
- Accelerate config: not found
- PyTorch version (GPU?): 2.1.1+cu121 (True)
- Tensorflow version (GPU?): 2.15.0 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: Yes.
- Using distributed or parallel set-up in script?: No.
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
# Imports
from transformers import AutoTokenizer, AutoModelForCausalLM
text = "Hello World! This is a test string."
# Gemma
gemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
gemma_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
gemma_tokens = gemma_tokenizer(text, return_tensors="pt")
gemma_output_dict = gemma_model(**gemma_tokens, output_attentions=True, output_hidden_states=True)
first_gemma_attn_pattern = gemma_output_dict['attentions'][0][0, 0, 0]
print(first_gemma_attn_pattern)
# GPT2
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt2_model = AutoModelForCausalLM.from_pretrained("gpt2")
gpt2_tokens = gpt2_tokenizer(text, return_tensors="pt")
gpt2_output_dict = gpt2_model(**gpt2_tokens, output_attentions=True, output_hidden_states=True)
first_gpt2_attn_pattern = gpt2_output_dict['attentions'][0][0, 0, 0]
print(first_gpt2_attn_pattern)
first_gemma_attn_pattern outputs:
tensor([0.1393, 0.1916, 0.0398, 0.1050, 0.0786, 0.0850, 0.0610, 0.1118, 0.1076,
0.0803], grad_fn=<SelectBackward0>)
first_gpt2_attn_pattern outputs:
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<SelectBackward0>)
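For a broader check than the single rows printed above, the sketch below reuses gemma_output_dict and gpt2_output_dict from the script and reports, per layer, the largest attention weight strictly above the diagonal, which should be exactly zero if the causal mask is applied before the softmax. The helper name check_causal_mask is mine and only for illustration, not part of transformers.

import torch

def check_causal_mask(attentions, model_name):
    # Each element of `attentions` has shape (batch, num_heads, query_len, key_len).
    # With causal masking, every entry strictly above the diagonal should be 0.
    for layer_idx, attn in enumerate(attentions):
        leaked = torch.triu(attn, diagonal=1).abs().max().item()
        print(f"{model_name} layer {layer_idx}: max attention above diagonal = {leaked:.6f}")

check_causal_mask(gemma_output_dict["attentions"], "gemma-2b")
check_causal_mask(gpt2_output_dict["attentions"], "gpt2")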
Expected behavior
Due to causal masking, I would expect the first row of the attention pattern for both models to be a one-hot vector, i.e. tensor([1., 0., 0., ..., 0.], grad_fn=<SelectBackward0>), since the first query position can only attend to the first token. That is what GPT-2 returns, but not Gemma, where the attentions returned with output_attentions=True do not appear to have the causal mask applied.
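As a sanity check on the expectation itself (not on the transformers internals): when the causal mask is added to the raw scores before the softmax, the first query position only sees its own key, so the softmax collapses to 1.0 regardless of the score values. A minimal illustration with made-up scores:

import torch

scores = torch.tensor([2.3, -0.7, 1.1])                          # hypothetical raw attention scores for query position 0
causal_mask = torch.tensor([0.0, float("-inf"), float("-inf")])  # only key 0 is visible to query 0
print(torch.softmax(scores + causal_mask, dim=-1))               # tensor([1., 0., 0.])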