add BartConfig.force_bos_token_to_be_generated #6526
Codecov Report
@@ Coverage Diff @@
## master #6526 +/- ##
==========================================
- Coverage 80.59% 78.19% -2.40%
==========================================
Files 156 156
Lines 28058 28055 -3
==========================================
- Hits 22612 21939 -673
- Misses 5446 6116 +670
)
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
scores[:, all_but_token_ids_mask] = -float("inf")
def _force_token_ids_generation(self, scores, token_id) -> None:
just cleanup
@@ -479,7 +479,7 @@ def test_xsum_summarization_same_as_fairseq(self):
self.assertFalse(model.config.is_valid_mbart())
tok = BartTokenizer.from_pretrained("facebook/bart-large")

EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state."
summaries change a bit for xsum, but ROUGE increases.
LGTM, thanks for cleaning this. Just a few nits on the doc.
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
LGTM!
)" This reverts commit 5a679c3.
This PR adds a config flag, force_bos_token_to_be_generated, that makes a generation hack in BART's adjust_logits_during_generation optional. Setting the flag to False improves metrics for bart-large-xsum and mbart-large-en-ro, but not for bart-large-cnn (hence the need for the flag). I remember @patrickvonplaten asked me about this in February; I ran on CNN and then decided the hack needed to stay. I could have been more careful.
cc @patil-suraj
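For readers unfamiliar with the hack, here is a minimal sketch of the idea, not the actual transformers code. It uses plain Python lists instead of tensors, and the names force_token_id and adjust_logits are hypothetical. It also assumes the forcing happens only at the first decoding step (cur_len == 1), which the description above does not state explicitly.

```python
import math


def force_token_id(scores, token_id):
    """Mask every vocabulary entry except `token_id` with -inf, in place.

    `scores` is a list of rows with shape [batch_size, vocab_size].
    """
    for row in scores:
        for i in range(len(row)):
            if i != token_id:
                row[i] = -math.inf


def adjust_logits(scores, cur_len, bos_token_id, force_bos):
    """Force BOS at the first step only when the (hypothetical) flag is set."""
    # Assumption: the hack applies at cur_len == 1 only.
    if cur_len == 1 and force_bos:
        force_token_id(scores, bos_token_id)
    return scores


# With the flag on, only the BOS column keeps a finite score.
forced = adjust_logits([[0.1, 0.5, 0.2]], cur_len=1, bos_token_id=0, force_bos=True)
# With the flag off, the scores pass through unchanged.
free = adjust_logits([[0.1, 0.5, 0.2]], cur_len=1, bos_token_id=0, force_bos=False)
```

With the flag off, the model is free to pick whatever token scores highest at the first step, which is the "noforce" behavior measured under Metrics.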
Todo
[x] test xsum
[x] test cnn
[x] test en-ro
[x] update cnn config
[x] test pegasus
[x] update distilbart configs
[ ] should max_length be changed in xsum config?
Metrics
All runs compare val.source against val.target. The first number is the runtime over the whole val set. "noforce" means this PR without any config change, so the BOS token is not forced to be generated.
bart-large-cnn
master: 86:23, {'rouge1': 44.79, 'rouge2': 21.64, 'rougeL': 31.18}
noforce: 87:40, {'rouge1': 44.26, 'rouge2': 21.22, 'rougeL': 30.72}
bart-large-xsum
master: 41:59, {'rouge1': 45.16, 'rouge2': 21.77, 'rougeL': 36.35}
noforce: 34:12, {'rouge1': 45.45, 'rouge2': 22.38, 'rougeL': 37.25}
mbart-large-en-ro
master: 04:58, BLEU=27.83
noforce: 04:42, BLEU=28.15
pegasus-xsum
master: 56:12, {'rouge1': 46.69, 'rouge2': 24.13, 'rougeL': 38.79}
noforce: 54:15, {'rouge1': 46.98, 'rouge2': 24.43, 'rougeL': 39.11}
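To make the comparison easier to read, the sketch below computes noforce minus master for each score. The numbers are copied from the list above; the dictionary layout and the deltas helper are just an illustration and not part of the PR.

```python
# Scores copied from the Metrics list above, stored as (master, noforce).
metrics = {
    "bart-large-cnn": {
        "rouge1": (44.79, 44.26), "rouge2": (21.64, 21.22), "rougeL": (31.18, 30.72),
    },
    "bart-large-xsum": {
        "rouge1": (45.16, 45.45), "rouge2": (21.77, 22.38), "rougeL": (36.35, 37.25),
    },
    "mbart-large-en-ro": {"BLEU": (27.83, 28.15)},
    "pegasus-xsum": {
        "rouge1": (46.69, 46.98), "rouge2": (24.13, 24.43), "rougeL": (38.79, 39.11),
    },
}


def deltas(table):
    """Return noforce - master for every score, rounded to two decimals."""
    return {
        model: {name: round(new - old, 2) for name, (old, new) in scores.items()}
        for model, scores in table.items()
    }


for model, diff in deltas(metrics).items():
    print(model, diff)
```

Every score drops for bart-large-cnn and rises for the other three models, which is why the flag stays on for the CNN config only.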
Commands