add BartConfig.force_bos_token_to_be_generated (huggingface#6526)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2 people authored and fabiocapsouza committed Nov 15, 2020
1 parent dcc7379 commit 5a679c3
Showing 4 changed files with 12 additions and 15 deletions.
5 changes: 5 additions & 0 deletions src/transformers/configuration_bart.py
@@ -95,6 +95,8 @@
             for SequenceClassification
         is_encoder_decoder (:obj:`int`, optional, defaults to True):
             True
+        force_bos_token_to_be_generated (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to force the BOS token to be generated at step 1 (after ``decoder_start_token_id``); only true for `bart-large-cnn`.
     """
@@ -137,6 +139,7 @@ def __init__(
         normalize_embedding=True,
         static_position_embeddings=False,
         add_bias_logits=False,
+        force_bos_token_to_be_generated=False,
         **common_kwargs
     ):
         r"""
@@ -195,6 +198,8 @@ def __init__(
         # pos embedding offset
         self.extra_pos_embeddings = self.pad_token_id + 1
 
+        self.force_bos_token_to_be_generated = force_bos_token_to_be_generated
+
     @property
     def num_attention_heads(self) -> int:
         return self.encoder_attention_heads
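In practice the new flag is just a plain attribute on `BartConfig`, off by default so existing configs keep their behavior. A minimal sketch of how it surfaces (illustrative only, not part of the commit):

    from transformers import BartConfig

    # Default: the flag is off, so generation does not force BOS at step 1.
    config = BartConfig()
    print(config.force_bos_token_to_be_generated)  # False

    # Opting in, as the `bart-large-cnn` config is expected to do:
    config = BartConfig(force_bos_token_to_be_generated=True)
    print(config.force_bos_token_to_be_generated)  # True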
18 changes: 5 additions & 13 deletions src/transformers/modeling_bart.py
@@ -1073,23 +1073,15 @@ def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask,
         }
 
     def adjust_logits_during_generation(self, logits, cur_len, max_length):
-        if cur_len == 1:
+        if cur_len == 1 and self.config.force_bos_token_to_be_generated:
             self._force_token_ids_generation(logits, self.config.bos_token_id)
-        if cur_len == max_length - 1 and self.config.eos_token_id is not None:
+        elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
             self._force_token_ids_generation(logits, self.config.eos_token_id)
         return logits
 
-    def _force_token_ids_generation(self, scores, token_ids) -> None:
-        """force one of token_ids to be generated by setting prob of all other tokens to 0"""
-        if isinstance(token_ids, int):
-            token_ids = [token_ids]
-        all_but_token_ids_mask = torch.tensor(
-            [x for x in range(self.config.vocab_size) if x not in token_ids],
-            dtype=torch.long,
-            device=next(self.parameters()).device,
-        )
-        assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
-        scores[:, all_but_token_ids_mask] = -float("inf")
+    def _force_token_ids_generation(self, scores, token_id) -> None:
+        """force token_id to be generated by setting the log prob of all other tokens to -float("inf")"""
+        scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")
 
     @staticmethod
     def _reorder_cache(past, beam_idx):
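The simplified helper relies on a standard decoding trick: setting every other vocabulary position to -inf guarantees the forced token wins under both greedy argmax and sampling. A standalone sketch of the same idea (`force_token` is an illustrative name, not a library function):

    import torch

    def force_token(scores: torch.Tensor, token_id: int) -> torch.Tensor:
        """scores: [batch_size, vocab_size] log probabilities."""
        forced = torch.full_like(scores, -float("inf"))
        forced[:, token_id] = scores[:, token_id]  # keep only the forced token's score
        return forced

    scores = torch.randn(2, 10)
    print(force_token(scores, 3).argmax(dim=-1))  # tensor([3, 3])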
2 changes: 1 addition & 1 deletion src/transformers/modeling_marian.py
@@ -47,7 +47,7 @@ class MarianMTModel(BartForConditionalGeneration):
     """
 
     def adjust_logits_during_generation(self, logits, cur_len, max_length):
-        logits[:, self.config.pad_token_id] = float("-inf")
+        logits[:, self.config.pad_token_id] = float("-inf")  # never predict pad token.
         if cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(logits, self.config.eos_token_id)
         return logits
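Marian's line is the converse of forcing: it suppresses a single token while leaving the rest eligible. A small sketch of that masking, with an assumed `pad_token_id` value (Marian reads it from the config):

    import torch

    logits = torch.randn(2, 10)
    pad_token_id = 0  # illustrative value
    logits[:, pad_token_id] = float("-inf")
    assert (logits.argmax(dim=-1) != pad_token_id).all()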
2 changes: 1 addition & 1 deletion tests/test_modeling_bart.py
@@ -484,7 +484,7 @@ def test_xsum_summarization_same_as_fairseq(self):
         self.assertFalse(model.config.is_valid_mbart())
         tok = BartTokenizer.from_pretrained("facebook/bart-large")
 
-        EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
+        EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state."
         dct = tok.batch_encode_plus(
             [PGE_ARTICLE], max_length=1024, padding="max_length", truncation=True, return_tensors="pt",
         ).to(torch_device)
