Skip to content

Commit

Permalink
Implement fp16 finetune
Browse files Browse the repository at this point in the history
  • Loading branch information
duzx16 committed May 5, 2021
1 parent 35b1625 commit 01a311f
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 58,058 deletions.
28,996 changes: 0 additions & 28,996 deletions .pytorch_pretrained_bert/bert-base-cased-vocab.txt

This file was deleted.

28,996 changes: 0 additions & 28,996 deletions .pytorch_pretrained_bert/bert-large-cased-vocab.txt

This file was deleted.

2 changes: 1 addition & 1 deletion config_tasks/model_blocklm_roberta_1.25.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ MODEL_ARGS="--block-lm \
--max-position-embeddings 1024 \
--tokenizer-model-type roberta \
--tokenizer-type GPT2BPETokenizer \
--load-pretrained /dataset/fd5061f6/english_data/checkpoints/blocklm-roberta-1.25-blank04-22-14-01"
--load-pretrained /dataset/c07bd62b/checkpoints/blocklm-roberta-1.25-blank04-22-14-01"
4 changes: 2 additions & 2 deletions finetune_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@

from utils import print_rank_0
from utils import Timers
from train_utils import setup_model_and_optimizer, train_step
from utils import load_checkpoint, save_checkpoint, load_pretrained
from train_utils import setup_model_and_optimizer, train_step, load_pretrained
from utils import load_checkpoint, save_checkpoint
from pretrain_gpt2 import report_iteration_metrics
from pretrain_gpt2 import evaluate_and_print_results
from pretrain_gpt2 import initialize_distributed
Expand Down
4 changes: 0 additions & 4 deletions model/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,6 @@ def forward(self, *inputs, **kwargs):
def state_dict(self, destination=None, prefix='', keep_vars=False):
    """Return the state dict of the wrapped module.

    Delegating to ``self.module`` keeps checkpoint keys free of any
    DDP-wrapper prefix, so checkpoints saved here load into an
    unwrapped model.
    """
    return self.module.state_dict(destination, prefix, keep_vars)

def load_state_dict(self, state_dict, strict=True):
Expand Down
17 changes: 11 additions & 6 deletions scripts/finetune_superglue.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
DATA_ROOT=/dataset/fd5061f6/english_data/superglue
DATA_ROOT=/root/data/superglue
source config_tasks/model_blocklm_roberta_1.25.sh
source $1

CHECKPOINT_PATH="/dataset/fd5061f6/english_data/finetune_checkpoints"
CHECKPOINT_PATH="/root/data/finetune_checkpoints"

if [ -z $N_GPU ];then
N_GPU=2
fi
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
DISTRIBUTED_ARGS="--nproc_per_node 2 --nnodes 1 --node_rank 0 --master_addr localhost --master_port $MASTER_PORT"
DATESTR=$(date +"%m-%d-%H-%M")
DISTRIBUTED_ARGS="--nproc_per_node ${N_GPU} --nnodes 1 --node_rank 0 --master_addr localhost --master_port $MASTER_PORT"

PER_GPU_BS=$(($BATCH_SIZE/$N_GPU))

mkdir logs
python -m torch.distributed.launch $DISTRIBUTED_ARGS finetune_gpt2.py \
Expand All @@ -17,12 +21,13 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS finetune_gpt2.py \
--save ${CHECKPOINT_PATH} \
--seq-length ${MAX_SEQ_LEN} \
--checkpoint-activations \
--batch-size 8 \
--eval-batch-size 16 \
--save-epoch 5 \
--save-epoch 1000 \
$MODEL_ARGS \
$TRAIN_ARGS \
$COMMON_ARGS \
--fp16 \
--batch-size ${PER_GPU_BS} \
--epochs ${EPOCH_SINGLE} \
--lr ${LR_SINGLE} \
--overwrite \
Expand Down
13 changes: 9 additions & 4 deletions scripts/finetune_superglue_grid.sh
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
DATA_ROOT=/dataset/fd5061f6/english_data/superglue
DATA_ROOT=/dataset/c07bd62b/superglue
source config_tasks/model_blocklm_roberta_1.25.sh
source $1

CHECKPOINT_PATH="/dataset/fd5061f6/english_data/finetune_checkpoints"
CHECKPOINT_PATH="/dataset/c07bd62b/finetune_checkpoints"

if [ -z $N_GPU ];then
N_GPU=2
fi
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
DISTRIBUTED_ARGS="--nproc_per_node ${N_GPU} --nnodes 1 --node_rank 0 --master_addr localhost --master_port $MASTER_PORT"

DATESTR=$(date +"%m-%d-%H-%M")
GRID_LOG=logs/grid_${EXPERIMENT_NAME}_${DATESTR}.txt

mkdir logs
for lr in 6e-6 1e-5 2e-5
do
for seed in 1234 5678 3456
Expand All @@ -27,10 +30,12 @@ do
--seq-length ${MAX_SEQ_LEN} \
--checkpoint-activations \
--eval-batch-size 16 \
--save-epoch 5 \
--save-epoch 1000 \
$MODEL_ARGS \
$TRAIN_ARGS \
$COMMON_ARGS \
--fp16 \
--attention-scale 8.0 \
--batch-size ${PER_GPU_BS} \
--epochs ${EPOCH_SINGLE} \
--lr-decay-style linear \
Expand Down
46 changes: 45 additions & 1 deletion train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,51 @@
PyTorchDistributedDataParallel as TorchDDP, \
DistributedDataParallel as LocalDDP, gpt2_get_params_for_weight_decay_optimization
from model.modeling import BertForMultipleChoice, BertForSequenceClassification
from utils import print_rank_0
from utils import print_rank_0, get_checkpoint_name, get_checkpoint_iteration


def load_pretrained(model, checkpoint_path, args):
    """Load pretrained weights from ``checkpoint_path`` into ``model`` for finetuning.

    Resolves the concrete checkpoint file, unwraps distributed/fp16 wrappers
    down to the bare backbone, optionally extends (block) position-embedding
    tables when ``args.max_position_embeddings`` exceeds the checkpoint's
    table size, then loads the state dict non-strictly.

    Args:
        model: possibly wrapped model (DeepSpeed engine / TorchDDP /
            FP16_Module, and/or a finetuning head exposing ``.model``).
        checkpoint_path: checkpoint directory or file understood by
            ``get_checkpoint_iteration``.
        args: parsed command-line arguments; this function reads
            ``deepspeed``, ``block_lm`` and ``max_position_embeddings``.
    """
    # Resolve the concrete checkpoint file from path/iteration metadata.
    load_dir, tag, release, success = get_checkpoint_iteration(checkpoint_path)
    checkpoint_name = get_checkpoint_name(load_dir, tag, release)
    # Print once per data-parallel group to avoid duplicated log lines.
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading pretrained model {}'.format(
            torch.distributed.get_rank(), checkpoint_name))
    # Load the checkpoint on CPU so large models do not spike GPU memory.
    sd = torch.load(checkpoint_name, map_location='cpu')
    # Peel wrapper layers in order: DeepSpeed engine -> TorchDDP -> FP16_Module.
    if args.deepspeed:
        model = model.module
    if isinstance(model, TorchDDP):
        model = model.module
    if isinstance(model, FP16_Module):
        model = model.module
    # Finetuning heads keep the backbone in a ``model`` attribute.
    if hasattr(model, "model"):
        model = model.model

    # Model.
    def extend_embedding_weights(state_weights, model_weights):
        # Copy checkpoint rows into a clone of the (larger) model table;
        # rows beyond the checkpoint length keep the model's initialization.
        original_length = state_weights.shape[0]
        assert original_length <= args.max_position_embeddings + 1
        new_weights = model_weights.clone()
        new_weights[:original_length] = state_weights
        return new_weights

    if args.block_lm:
        # NOTE(review): both branches are gated on the *block* position
        # embedding key — presumably only block-LM checkpoints should be
        # extended; confirm the first guard is intentional and not a typo
        # for "transformer.position_embeddings.weight".
        if "transformer.block_position_embeddings.weight" in sd["module"]:
            position_weights = sd['module']["transformer.position_embeddings.weight"]
            if args.max_position_embeddings + 1 > position_weights.shape[0]:
                sd['module']["transformer.position_embeddings.weight"] = extend_embedding_weights(
                    position_weights, model.state_dict()["transformer.position_embeddings.weight"].data)
                print_rank_0(f"Extend position embedding to {args.max_position_embeddings + 1}")
        if "transformer.block_position_embeddings.weight" in sd["module"]:
            block_position_weights = sd['module']["transformer.block_position_embeddings.weight"]
            if args.max_position_embeddings + 1 > block_position_weights.shape[0]:
                sd['module']["transformer.block_position_embeddings.weight"] = extend_embedding_weights(
                    block_position_weights,
                    model.state_dict()["transformer.block_position_embeddings.weight"].data)
                print_rank_0(f"Extend block position embedding to {args.max_position_embeddings + 1}")
    # Non-strict load: mismatched keys are reported, not fatal.
    missing_keys, unexpected_keys = model.load_state_dict(sd['module'], strict=False)
    if missing_keys or unexpected_keys:
        print_rank_0(f"Missing keys {missing_keys}, unexpected keys {unexpected_keys}")


def get_model(args, model_type=None, multi_token=True, num_labels=None, spell_length=None):
Expand Down
48 changes: 0 additions & 48 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import json
import subprocess

from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from fp16 import FP16_Optimizer
import mpu
from tensorboardX import SummaryWriter
Expand Down Expand Up @@ -227,8 +226,6 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, args, tag=None, b
save_ds_checkpoint(iteration, model, lr_scheduler, args, tag=tag)
else:
# Only rank zero of the data-parallel group writes to the disk.
if isinstance(model, torchDDP):
model = model.module

if mpu.get_data_parallel_rank() == 0:
checkpoint_name = get_checkpoint_name(args.save, tag)
Expand Down Expand Up @@ -318,48 +315,6 @@ def get_checkpoint_iteration(load_path):
return load_path, iteration, release, True


def load_pretrained(model, checkpoint_path, args):
    """Load pretrained weights from ``checkpoint_path`` into ``model``.

    Resolves the concrete checkpoint file, unwraps DeepSpeed/DDP wrappers,
    optionally extends (block) position-embedding tables when
    ``args.max_position_embeddings`` exceeds the checkpoint's table size,
    then loads the state dict non-strictly.

    Args:
        model: possibly wrapped model (DeepSpeed engine / torchDDP, and/or
            a finetuning head exposing ``.model``).
        checkpoint_path: checkpoint directory or file understood by
            ``get_checkpoint_iteration``.
        args: parsed command-line arguments; this function reads
            ``deepspeed``, ``block_lm`` and ``max_position_embeddings``.
    """
    # Resolve the concrete checkpoint file from path/iteration metadata.
    load_dir, tag, release, success = get_checkpoint_iteration(checkpoint_path)
    checkpoint_name = get_checkpoint_name(load_dir, tag, release)
    # Print once per data-parallel group to avoid duplicated log lines.
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading pretrained model {}'.format(
            torch.distributed.get_rank(), checkpoint_name))
    # Load the checkpoint on CPU so large models do not spike GPU memory.
    sd = torch.load(checkpoint_name, map_location='cpu')
    # Peel wrapper layers: DeepSpeed engine, then torch DDP.
    if args.deepspeed:
        model = model.module
    if isinstance(model, torchDDP):
        model = model.module
    # Finetuning heads keep the backbone in a ``model`` attribute.
    if hasattr(model, "model"):
        model = model.model

    # Model.
    def extend_embedding_weights(state_weights, model_weights):
        # Copy checkpoint rows into a clone of the (larger) model table;
        # rows beyond the checkpoint length keep the model's initialization.
        original_length = state_weights.shape[0]
        assert original_length <= args.max_position_embeddings + 1
        new_weights = model_weights.clone()
        new_weights[:original_length] = state_weights
        return new_weights

    if args.block_lm:
        # NOTE(review): both branches are gated on the *block* position
        # embedding key — confirm the first guard is intentional.
        if "transformer.block_position_embeddings.weight" in sd["module"]:
            position_weights = sd['module']["transformer.position_embeddings.weight"]
            if args.max_position_embeddings + 1 > position_weights.shape[0]:
                sd['module']["transformer.position_embeddings.weight"] = extend_embedding_weights(
                    position_weights, model.state_dict()["transformer.position_embeddings.weight"].data)
                print_rank_0(f"Extend position embedding to {args.max_position_embeddings + 1}")
        if "transformer.block_position_embeddings.weight" in sd["module"]:
            block_position_weights = sd['module']["transformer.block_position_embeddings.weight"]
            if args.max_position_embeddings + 1 > block_position_weights.shape[0]:
                sd['module']["transformer.block_position_embeddings.weight"] = extend_embedding_weights(
                    block_position_weights,
                    model.state_dict()["transformer.block_position_embeddings.weight"].data)
                print_rank_0(f"Extend block position embedding to {args.max_position_embeddings + 1}")
    # Non-strict load: mismatched keys are reported, not fatal.
    missing_keys, unexpected_keys = model.load_state_dict(sd['module'], strict=False)
    if missing_keys or unexpected_keys:
        print_rank_0(f"Missing keys {missing_keys}, unexpected keys {unexpected_keys}")


def load_checkpoint(model, optimizer, lr_scheduler, args):
"""Load a model checkpoint."""

Expand Down Expand Up @@ -392,9 +347,6 @@ def load_checkpoint(model, optimizer, lr_scheduler, args):
# Load the checkpoint.
sd = torch.load(checkpoint_name, map_location='cpu')

if isinstance(model, torchDDP):
model = model.module

# Model.
missing_keys, unexpected_keys = model.load_state_dict(sd['module'], strict=False)
if missing_keys or unexpected_keys:
Expand Down

0 comments on commit 01a311f

Please sign in to comment.