Commit
[LED Test] fix common inputs pt for flaky pt-tf led test (huggingface#9459)

* fix common inputs pt flaky led

* fix other tests correspondingly

patrickvonplaten authored Jan 7, 2021
1 parent ae5a32b commit a400fe8
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions tests/test_modeling_led.py
@@ -110,7 +110,8 @@ def __init__(
         # because its local attention only attends to `self.attention_window + 1` locations
         # (assuming no token with global attention, otherwise the last dimension of attentions
         # is x + self.attention_window + 1, where x is the number of tokens with global attention)
-        self.encoder_key_length = self.attention_window + 1
+        # x is set to 1
+        self.encoder_key_length = self.attention_window + 2
 
         # because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
         # the `test_attention_outputs` and `test_hidden_states_output` tests
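
For readers checking the arithmetic in the comment above: with x tokens carrying global attention, the last attention dimension is x + self.attention_window + 1, and the updated common inputs (see the next hunk) always mark exactly one token as global. A minimal sketch, with a hypothetical attention_window value chosen only for illustration:

    attention_window = 4  # hypothetical value, for illustration only
    x = 1                 # one token with global attention, as set in the common inputs
    encoder_key_length = x + attention_window + 1
    assert encoder_key_length == attention_window + 2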
@@ -149,6 +150,10 @@ def prepare_config_and_inputs(self):
 
     def prepare_config_and_inputs_for_common(self):
         config, inputs_dict = self.prepare_config_and_inputs()
+        global_attention_mask = torch.zeros_like(inputs_dict["input_ids"])
+        global_attention_mask[:, -1] = 1
+        inputs_dict["global_attention_mask"] = global_attention_mask
+
         return config, inputs_dict
 
     def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
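
For context, a standalone sketch of the mask constructed above on a tiny dummy batch (the input_ids values are made up; only the construction mirrors the test):

    import torch

    input_ids = torch.tensor([[11, 12, 13, 14], [15, 16, 17, 18]])  # hypothetical dummy batch
    global_attention_mask = torch.zeros_like(input_ids)
    global_attention_mask[:, -1] = 1  # exactly one global token per sequence (the last position)
    print(global_attention_mask)
    # tensor([[0, 0, 0, 1],
    #         [0, 0, 0, 1]])

With every sequence guaranteed exactly one global token, the attention shapes the common tests check are fixed (see the encoder_key_length comment above) instead of depending on randomly drawn inputs.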
@@ -196,9 +201,11 @@ def check_encoder_decoder_model_standalone(self, config, inputs_dict):
             encoder.save_pretrained(tmpdirname)
             encoder = LEDEncoder.from_pretrained(tmpdirname).to(torch_device)
 
-        encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
-            0
-        ]
+        encoder_last_hidden_state_2 = encoder(
+            inputs_dict["input_ids"],
+            attention_mask=inputs_dict["attention_mask"],
+            global_attention_mask=inputs_dict["global_attention_mask"],
+        )[0]
 
         self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
 
@@ -390,7 +397,8 @@ def test_attention_outputs(self):
             )
             out_len = len(outputs)
 
-            correct_outlen = 5
+            # global attention outputs are added as well => so +1 here
+            correct_outlen = 6
 
             # loss is at first position
             if "labels" in inputs_dict: