This repository contains the implementation of the paper "Learning to Win Lottery Tickets in BERT Transfer via Task-agnostic Mask Training" (accepted to NAACL 2022).
The code for task-agnostic mask training is based on huggingface/transformers, TinyBERT (for TAMT-KD) and maskbert.
The code for downstream fine-tuning and IMP is modified from BERT-Tickets.
TAMT learns the subnetwork structures on the pre-training dataset, using either the MLM loss or the KD loss. The identified subnetwork is then fine-tuned on a range of downstream tasks, in place of the original BERT model.
The pre-training performance of a BERT subnetwork correlates with its downstream transferability.
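In other words, a subnetwork is simply the original BERT weights multiplied element-wise by a binary pruning mask. A minimal sketch of applying such a mask is shown below; the mask file name and its layout (a dict of 0/1 tensors keyed by parameter name) are illustrative assumptions, not this repository's exact format.

# Apply a binary pruning mask to BERT before downstream fine-tuning.
# "subnetwork_mask.pt" and its layout are hypothetical; the repository's
# scripts load and apply the masks for you.
import torch
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
mask = torch.load("subnetwork_mask.pt")  # {parameter_name: 0/1 tensor}

with torch.no_grad():
    for name, param in model.named_parameters():
        if name in mask:
            param.mul_(mask[name])  # zero out pruned weights; kept weights are unchanged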
Python3
torch>1.4.0
Download the pre-training dataset WikiText-103 to mask_training/data/wikitext-103.
Download the GLUE datasets to imp_and_fine_tune/glue and the SQuAD v1.1 dataset to mask_training/data/squad.
Download bert-base-uncased and roberta-base from huggingface models.
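If you prefer to fetch the checkpoints programmatically, one option (assuming the transformers library is installed; the local save paths are arbitrary examples) is:

from transformers import AutoModel, AutoTokenizer

# Download and cache the two pre-trained checkpoints, then save local copies.
for name in ["bert-base-uncased", "roberta-base"]:
    AutoTokenizer.from_pretrained(name).save_pretrained(f"models/{name}")
    AutoModel.from_pretrained(name).save_pretrained(f"models/{name}")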
The following instructions use BERT-base as the example; the results for RoBERTa-base can be reproduced in a similar fashion.
The scripts for running TAMT-MLM are in the folder mask_training/shell_scripts/train_mlm.
To perform TAMT-MLM on BERT-base with 0.7 sparsity, run
bash mask_training/shell_scripts/train_mlm/bert/0.7.sh
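Conceptually, mask training keeps the pre-trained weights frozen and learns a real-valued score for every weight; the scores are binarized in the forward pass and updated with straight-through gradients from the MLM loss. The sketch below illustrates this idea in isolation; the shapes, the score initialization and the toy loss are illustrative only, not the repository's implementation.

import torch

class Binarize(torch.autograd.Function):
    # Forward: threshold the scores into a 0/1 mask.
    # Backward: pass gradients straight through to the scores.
    @staticmethod
    def forward(ctx, scores):
        return (scores > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

weight = torch.randn(768, 768)                   # frozen pre-trained weight matrix
scores = torch.randn(768, 768) * 0.01            # trainable mask scores (TAMT initializes
scores.requires_grad_()                          # these from the OMP mask by default, see below)

masked_weight = weight * Binarize.apply(scores)  # used in the transformer forward pass
loss = masked_weight.pow(2).mean()               # stand-in for the MLM loss
loss.backward()                                  # only `scores` receives gradients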
Before running TAMT-KD, we first need to prepare the training data by running the following command, which stores the training data at mask_training/data/wikitext-103-kd/epoch_0.json.
bash mask_training/shell_scripts/generate_kd_data_bert.sh
The scripts for running TAMT-KD are in the folder mask_training/shell_scripts/train_kd.
To perform TAMT-KD on BERT-base with 0.7 sparsity, run
bash mask_training/shell_scripts/train_kd/bert/0.7.sh
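TAMT-KD replaces the MLM objective with a distillation loss that pushes the subnetwork's (student's) representations towards those of the full BERT (teacher). Below is a sketch of one common form of such a loss, layer-wise MSE over hidden states; the exact loss terms used by TAMT-KD are defined in the scripts above.

import torch.nn.functional as F

def kd_loss(student_hidden_states, teacher_hidden_states):
    # Layer-wise MSE between student and teacher hidden states
    # (obtain them by running both models with output_hidden_states=True).
    return sum(F.mse_loss(s, t) for s, t in
               zip(student_hidden_states, teacher_hidden_states))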
The scripts for running TAMT-MLM+KD are in the folder mask_training/shell_scripts/train_mlm_kd.
To perform TAMT-MLM+KD on BERT-base with 0.7 sparsity, run
bash mask_training/shell_scripts/train_mlm_kd/bert/0.7.sh
By default, we initialize TAMT with the OMP mask. To perform TAMT with random mask initialization, run
bash mask_training/shell_scripts/$tamt_type/bert/rand_mask_init/0.7.sh
where $tamt_type=train_mlm or train_kd.
The scripts for IMP are in the folder imp_and_fine_tune/shell_scripts/imp.
To perform IMP on BERT-base with an interval of 2,792 training steps between pruning steps, run
bash imp_and_fine_tune/shell_scripts/imp/bert/prun_step2792/pretrain_imp_seed1.sh
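For reference, IMP alternates between training and pruning a small fraction of the surviving weights until the target sparsity is reached. A schematic loop using PyTorch's pruning utilities is shown below; the repository's implementation is adapted from BERT-Tickets, and run_mlm_steps is a placeholder for the pre-training performed between pruning steps.

import torch.nn as nn
import torch.nn.utils.prune as prune
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained("bert-base-uncased")
prunable = [(m, "weight") for m in model.modules() if isinstance(m, nn.Linear)]

for _ in range(10):                               # e.g. ten pruning rounds
    # run_mlm_steps(model, num_steps=2792)        # placeholder: train between prunes
    prune.global_unstructured(prunable,
                              pruning_method=prune.L1Unstructured,
                              amount=0.1)         # prune 10% of the surviving weights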
The scripts for OMP and random pruning are in the folder imp_and_fine_tune/shell_scripts/oneshot.
To perform OMP on BERT-base, run
bash imp_and_fine_tune/shell_scripts/oneshot/bert_mag.sh
To perform random pruning on BERT-base, run
bash imp_and_fine_tune/shell_scripts/oneshot/bert_rand.sh
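Both one-shot baselines build a mask directly from the pre-trained weights: OMP keeps the largest-magnitude weights, while random pruning keeps a random subset of the same size. A minimal sketch of global magnitude pruning follows; which matrices are actually prunable (and whether pruning is global or per matrix) is determined by the repository's scripts, and the dim() == 2 filter here is only an approximation.

import torch
from transformers import BertModel

def magnitude_masks(model, sparsity=0.7):
    # Keep the (1 - sparsity) largest-magnitude weights across all 2-D matrices.
    params = {n: p for n, p in model.named_parameters() if p.dim() == 2}
    scores = torch.cat([p.detach().abs().flatten() for p in params.values()])
    threshold = scores.kthvalue(int(sparsity * scores.numel())).values
    return {n: (p.detach().abs() > threshold).float() for n, p in params.items()}

model = BertModel.from_pretrained("bert-base-uncased")
omp_masks = magnitude_masks(model, sparsity=0.7)                                      # OMP
rand_masks = {n: (torch.rand_like(m) > 0.7).float() for n, m in omp_masks.items()}    # random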
To fine-tune the full BERT-base on task MNLI, run
bash imp_and_fine_tune/shell_scripts/run_glue/full_bert/mnli.sh
To fine-tune the TAMT-MLM, TAMT-KD or TAMT-MLM+KD BERT-base subnetwork (0.7 sparsity) on task MNLI, run
bash imp_and_fine_tune/shell_scripts/run_glue/$tamt_type/main_result/bert/mnli/0.7.sh
where $tamt_type=train_mlm, train_kd or train_mlm_kd.
To fine-tune the TAMT-MLM or TAMT-KD BERT-base subnetworks (0.7 sparsity) obtained from 1,000 steps of pre-training on task MNLI, run
bash imp_and_fine_tune/shell_scripts/run_glue/$tamt_type/steps/bert/mnli/0.7/step1000.sh
where $tamt_type=train_mlm or train_kd.
To fine-tune the IMP BERT-base subnetwork (0.7 sparsity) on task MNLI, run
bash imp_and_fine_tune/shell_scripts/run_glue/imp_pretrain/bert/mnli/0.7.sh
To fine-tune the IMP BERT-base subnetwork (0.7 sparsity) obtained from 200*6=1,200 steps of pre-training on task MNLI, run
bash imp_and_fine_tune/shell_scripts/run_glue/imp_pretrain/bert/steps/mnli/0.7/seed1/step200.sh
To fine-tune the OMP BERT-base subnetwork (0.7 sparsity) on task MNLI, run
bash imp_and_fine_tune/shell_scripts/run_glue/oneshot/bert/mnli/0.7.sh
To fine-tune the randomly pruned BERT-base subnetwork (0.7 sparsity) on task MNLI, run
bash imp_and_fine_tune/shell_scripts/run_glue/rand/bert/mnli/0.7.sh
To analyse the correlation between a subnetwork's pre-training and downstream performance, we need to calculate the MLM and KD losses.
Here we provide instructions on how to evaluate the MLM loss. The KD loss can be evaluated in the same way, using the scripts in mask_training/shell_scripts/eval_kd.
To evaluate the MLM loss of the models (including the original pre-trained BERT-base and the subnetworks) on the validation set of Wikitext, run
bash mask_training/shell_scripts/eval_mlm/$name.sh
where $name=bert, omp, imp, rand, tamt_mlm or tamt_kd. Note that we do not need to evaluate the KD loss for BERT-base itself.
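For concreteness, "evaluating the MLM loss" means running the (masked) model as a masked language model over the WikiText validation set and recording the loss. The toy sketch below measures this on a single sentence with a deterministic masking pattern; the scripts above handle the real data, the standard masking strategy and the subnetwork masks.

import torch
from transformers import BertForMaskedLM, BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased").eval()

text = "A single sentence stands in for the WikiText-103 validation set."
enc = tokenizer(text, return_tensors="pt")
labels = enc.input_ids.clone()

# Mask every fifth non-special token and score only those positions.
positions = torch.zeros_like(labels, dtype=torch.bool)
positions[0, 1:-1:5] = True
enc["input_ids"] = enc.input_ids.masked_fill(positions, tokenizer.mask_token_id)
labels[~positions] = -100

with torch.no_grad():
    mlm_loss = model(**enc, labels=labels).loss  # lower loss = better pre-training performance
print(float(mlm_loss))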
To compute the similarity between OMP, IMP, TAMT-MLM and TAMT-KD masks, run:
bash mask_training/shell_scripts/mask_sim.sh
To compute the distance from OMP mask, run:
bash mask_training/shell_scripts/mask_dist.sh
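Mask similarity here compares two binary pruning masks defined over the same weight matrices. A simple sketch of one such measure, the fraction of positions on which two masks agree, is given below; the exact similarity and distance definitions reported in the paper are implemented in the scripts above.

import torch

def mask_similarity(mask_a, mask_b):
    # mask_a, mask_b: {parameter_name: 0/1 tensor} with identical keys and shapes.
    agree = sum((mask_a[n] == mask_b[n]).sum().item() for n in mask_a)
    total = sum(mask_a[n].numel() for n in mask_a)
    return agree / total

# A distance can then be reported as 1 - mask_similarity(mask_a, mask_b).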
If you use this repository in published research, please cite our paper:
@inproceedings{Liu2022TAMT,
author = {Yuanxin Liu and Fandong Meng and Zheng Lin and Peng Fu and Yanan Cao and Weiping Wang and Jie Zhou},
title = {Learning to Win Lottery Tickets in BERT Transfer via Task-agnostic Mask Training},
booktitle = {NAACL 2022},
year = {2022}
}