Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BartConfig.force_bos_token_to_be_generated #6526

Merged
merged 4 commits into from
Aug 18, 2020

Conversation

sshleifer
Copy link
Contributor

@sshleifer sshleifer commented Aug 16, 2020

This PR adds a config flag that makes a generation hack from bart's adjust_logits_during_generation, optional. This change (setting the flag to False) improves metrics for bart-large-xsum and mbart-large-en-ro, but not bart-large-cnn (hence the need for the flag).

I remember @patrickvonplaten asked me about this in February, and I ran on CNN and then decided the hack needed to stay, could have been more careful.
cc @patil-suraj

Todo

[x] test xsum
[x] test cnn
[x] test en-ro
[x] update cnn config
[x] test pegasus
[x] update distilbart configs
[] should max_length be changed in xsum config?

TODO:

  • check pegasus
  • should max_length be 1 lower for configs where hack is removed? Yes.

Metrics

(all on val.source vs. val.target) the first number is runtime over the whole val set. "noforce" means this PR without any config change. (so BOS token is not forced to be generated.)

bart-large-cnn
master: 86:23 {"rouge1": 44.79, "rouge2": 21.64, "rougeL": 31.18}
noforce: 87:40 {'rouge1': 44.26, 'rouge2': 21.22, 'rougeL': 30.72}

bart-large-xsum
master: 41:59, {'rouge1': 45.16, 'rouge2': 21.77, 'rougeL': 36.35}
noforce: 34:12, {'rouge1': 45.45, 'rouge2': 22.38, 'rougeL': 37.25}

mbart-large-en-ro
master: 04:58 BLEU=27.83
noforce: 04:42, BLEU=28.15

pegasus-xsum
master: 56:12 {'rouge1': 46.69, 'rouge2': 24.13, 'rougeL': 38.79}
noforce: 54:15 {'rouge1': 46.98, 'rouge2': 24.43, 'rougeL': 39.11}

Commands

export DATA_DIR=wmt_en_ro
python run_eval.py facebook/mbart-large-en-ro \
    $DATA_DIR/val.source gens/mbart-enro-branch-gens.txt \
    --reference_path $DATA_DIR/val.target \
    --score_path gens/mbart-enro-master-bleu.json \
    --task translation_en_to_ro \
    --device cuda \
    --fp16 \
    --bs 32

export DATA_DIR=$CNN_DIR
python run_eval.py facebook/bart-large-cnn \
    $DATA_DIR/val.source gens/cnn_val_generations_no_force.txt \
    --reference_path $DATA_DIR/val.target \
    --score_path gens/cnn_gen_no_force_rouge.txt \
    --device cuda \
    --fp16 \
    --bs 32


export DATA_DIR=$CNN_DIR
python run_eval.py facebook/bart-large-cnn \
    $DATA_DIR/val.source gens/cnn_val_generations_master.txt \
    --reference_path $DATA_DIR/val.target \
    --score_path gens/cnn_gen_master_rouge.txt \
    --device cuda \
    --fp16 \
    --bs 32

@sshleifer sshleifer changed the title WIP: dont force bos token WIP: (bart, mbart): dont force bos token Aug 16, 2020
@sshleifer sshleifer changed the title WIP: (bart, mbart): dont force bos token WIP: (bart, mbart): don't force bos token at step 1 Aug 16, 2020
@codecov
Copy link

codecov bot commented Aug 16, 2020

Codecov Report

Merging #6526 into master will decrease coverage by 2.39%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #6526      +/-   ##
==========================================
- Coverage   80.59%   78.19%   -2.40%     
==========================================
  Files         156      156              
  Lines       28058    28055       -3     
==========================================
- Hits        22612    21939     -673     
- Misses       5446     6116     +670     
Impacted Files Coverage Δ
src/transformers/configuration_bart.py 94.00% <100.00%> (+0.12%) ⬆️
src/transformers/modeling_bart.py 95.56% <100.00%> (-0.20%) ⬇️
src/transformers/modeling_marian.py 90.00% <100.00%> (ø)
src/transformers/optimization.py 25.55% <0.00%> (-70.00%) ⬇️
src/transformers/pipelines.py 26.26% <0.00%> (-53.43%) ⬇️
src/transformers/modeling_tf_bert.py 66.00% <0.00%> (-32.38%) ⬇️
src/transformers/optimization_tf.py 33.33% <0.00%> (-24.33%) ⬇️
src/transformers/modeling_tf_auto.py 48.79% <0.00%> (-18.08%) ⬇️
src/transformers/data/processors/squad.py 13.76% <0.00%> (-14.38%) ⬇️
src/transformers/modeling_auto.py 64.36% <0.00%> (-14.37%) ⬇️
... and 12 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 2060181...5e951e1. Read the comment docs.

@sshleifer sshleifer changed the title WIP: (bart, mbart): don't force bos token at step 1 BartConfig.force_bos_token_to_be_generated to make hack optional Aug 16, 2020
@sshleifer sshleifer changed the title BartConfig.force_bos_token_to_be_generated to make hack optional add BartConfig.force_bos_token_to_be_generated Aug 16, 2020
)
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
scores[:, all_but_token_ids_mask] = -float("inf")
def _force_token_ids_generation(self, scores, token_id) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just cleanup

@@ -479,7 +479,7 @@ def test_xsum_summarization_same_as_fairseq(self):
self.assertFalse(model.config.is_valid_mbart())
tok = BartTokenizer.from_pretrained("facebook/bart-large")

EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state."
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

summaries change a bit for xsum, but ROUGE increases.

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for cleaning this. Just a few nits on the doc.

src/transformers/configuration_bart.py Outdated Show resolved Hide resolved
src/transformers/configuration_bart.py Outdated Show resolved Hide resolved
sshleifer and others added 2 commits August 17, 2020 09:01
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@sshleifer sshleifer merged commit 1529bf9 into huggingface:master Aug 18, 2020
@sshleifer sshleifer deleted the dont-force-bos branch August 18, 2020 23:15
@sshleifer sshleifer mentioned this pull request Sep 8, 2020
3 tasks
Zigur pushed a commit to Zigur/transformers that referenced this pull request Oct 26, 2020
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
fabiocapsouza pushed a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
fabiocapsouza added a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants