Beef up Llama tests (huggingface#22314)
* tmp commit

* beef up llama tests
gante authored Mar 22, 2023
1 parent 12febc2 commit fd3eb3e
Showing 2 changed files with 15 additions and 18 deletions.
tests/generation/test_utils.py (1 addition, 1 deletion)
```diff
@@ -1463,10 +1463,10 @@ def test_generate_with_head_masking(self):
         attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-            model = model_class(config).to(torch_device)
             # We want to test only encoder-decoder models
             if not config.is_encoder_decoder:
                 continue
+            model = model_class(config).to(torch_device)

             head_masking = {
                 "head_mask": torch.zeros(config.encoder_layers, config.encoder_attention_heads, device=torch_device),
```
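The one-line move above matters because this loop also visits decoder-only models (and the commit registers Llama's generative classes below): each one was previously instantiated and moved to `torch_device` only to be skipped on the very next line. For context, the test drives generation with every attention head masked out. A minimal sketch of that pattern, assuming `t5-small` as a stand-in checkpoint (the test itself runs on tiny randomly initialized models):

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Assumed stand-in checkpoint; the test builds tiny random models instead.
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: Hello there", return_tensors="pt")

# One row per encoder layer, one column per head; 0.0 silences a head.
head_mask = torch.zeros(model.config.num_layers, model.config.num_heads)

# `generate` forwards `head_mask` to the model; with every head masked,
# the encoder attention weights the test inspects should come back as zeros.
outputs = model.generate(**inputs, head_mask=head_mask, max_new_tokens=5)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```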
tests/models/llama/test_modeling_llama.py (14 additions, 17 deletions)
```diff
@@ -20,8 +20,10 @@
 from transformers import LlamaConfig, is_torch_available
 from transformers.testing_utils import require_torch, torch_device

+from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin


 if is_torch_available():
```
```diff
@@ -254,10 +256,21 @@ def prepare_config_and_inputs_for_common(self):


 @require_torch
-class LlamaModelTest(ModelTesterMixin, unittest.TestCase):
+class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification) if is_torch_available() else ()
+    all_generative_model_classes = (LlamaForCausalLM,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": LlamaModel,
+            "text-classification": LlamaForSequenceClassification,
+            "text-generation": LlamaForCausalLM,
+            "zero-shot": LlamaForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
+    )
     test_headmasking = False
     test_pruning = False

     def setUp(self):
         self.model_tester = LlamaModelTester(self)
```
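Mixing in `GenerationTesterMixin` points the shared generation suite (greedy search, sampling, beam search, and the rest) at `all_generative_model_classes`, while `pipeline_model_mapping` tells the shared pipeline tests which Llama class backs each task. As a rough user-level illustration of the `"text-generation"` entry, assuming a placeholder model id rather than a real checkpoint:

```python
from transformers import pipeline

# Placeholder model id; substitute any Llama checkpoint you can load.
generator = pipeline("text-generation", model="path/to/llama-checkpoint")
print(generator("The capital of France is", max_new_tokens=5)[0]["generated_text"])
```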
```diff
@@ -316,22 +329,6 @@ def test_llama_sequence_classification_model_for_multi_label(self):
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

-    @unittest.skip("LLaMA does not support head pruning.")
-    def test_head_pruning(self):
-        pass
-
-    @unittest.skip("LLaMA does not support head pruning.")
-    def test_head_pruning_integration(self):
-        pass
-
-    @unittest.skip("LLaMA does not support head pruning.")
-    def test_head_pruning_save_load_from_config_init(self):
-        pass
-
-    @unittest.skip("LLaMA does not support head pruning.")
-    def test_head_pruning_save_load_from_pretrained(self):
-        pass
-
     @unittest.skip("LLaMA buffers include complex numbers, which breaks this test")
     def test_save_load_fast_init_from_base(self):
         pass
```
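Dropping the four head-pruning skips above is safe because `test_pruning = False` in the class body already turns the shared pruning tests off; the explicit skips were dead code. What the class gains instead is generation coverage along these lines, sketched here with illustrative config values rather than the tester's actual ones:

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny illustrative config; GenerationTesterMixin uses its own small values.
config = LlamaConfig(
    vocab_size=99,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    max_position_embeddings=64,
)
model = LlamaForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 5))
with torch.no_grad():
    output = model.generate(input_ids, do_sample=False, max_new_tokens=5)

# Greedy decoding appends up to max_new_tokens tokens (fewer only if the
# randomly initialized model happens to emit EOS early).
assert output.shape[0] == 1 and 5 < output.shape[1] <= 10
```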
