Based on the script run_language_modeling.py
.
Fine-tuning (or training from scratch) the library models for language modeling on a text dataset for GPT, GPT-2, BERT, DistilBERT and RoBERTa. GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT, DistilBERT and RoBERTa are fine-tuned using a masked language modeling (MLM) loss.
Before running the following example, you should get a file that contains text on which the language model will be trained or fine-tuned. A good example of such text is the WikiText-2 dataset.
We will refer to two different files: $TRAIN_FILE
, which contains text for training, and $TEST_FILE
, which contains
text that will be used for evaluation.
The following example fine-tunes GPT-2 on WikiText-2. We're using the raw WikiText-2 (no tokens were replaced before the tokenization). The loss here is that of causal language modeling.
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export TEST_FILE=/path/to/dataset/wiki.test.raw
python run_language_modeling.py \
--output_dir=output \
--model_type=gpt2 \
--model_name_or_path=gpt2 \
--do_train \
--train_data_file=$TRAIN_FILE \
--do_eval \
--eval_data_file=$TEST_FILE
This takes about half an hour to train on a single K80 GPU and about one minute for the evaluation to run. It reaches a score of ~20 perplexity once fine-tuned on the dataset.
The following example fine-tunes RoBERTa on WikiText-2. Here too, we're using the raw WikiText-2. The loss is different as BERT/RoBERTa have a bidirectional mechanism; we're therefore using the same loss that was used during their pre-training: masked language modeling.
In accordance to the RoBERTa paper, we use dynamic masking rather than static masking. The model may, therefore, converge slightly slower (over-fitting takes more epochs).
We use the --mlm
flag so that the script may change its loss function.
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export TEST_FILE=/path/to/dataset/wiki.test.raw
python run_language_modeling.py \
--output_dir=output \
--model_type=roberta \
--model_name_or_path=roberta-base \
--do_train \
--train_data_file=$TRAIN_FILE \
--do_eval \
--eval_data_file=$TEST_FILE \
--mlm