Implement model parallel for finetune

duzx16 committed May 21, 2021
1 parent 55821c8 commit 019cc41
Showing 9 changed files with 212 additions and 34 deletions.
change_mp.py (155 additions, 0 deletions)
@@ -0,0 +1,155 @@
import sys
import os
import torch
import copy

checkpoint = sys.argv[1]
target_mp = int(sys.argv[2])

assert os.path.isdir(checkpoint)
iteration_file = os.path.join(checkpoint, 'latest_checkpointed_iteration.txt')
if os.path.exists(iteration_file):
    with open(iteration_file) as fin:
        iteration = int(fin.read().strip())
    checkpoint = os.path.join(checkpoint, str(iteration))
else:
    iteration = None

filenames = os.listdir(checkpoint)
filenames = [filename for filename in filenames if filename.startswith("mp_rank_")]
filenames = sorted(filenames,
                   key=lambda x: int(x.split('_')[2]))
filenames = [os.path.join(checkpoint, x) for x in filenames]

if target_mp == len(filenames):
    print("MP size keeps the same.")
    exit(0)

if sys.argv[1][-1] == '/':
    new_checkpoint = sys.argv[1][:-1] + '_MP' + sys.argv[2]
else:
    new_checkpoint = sys.argv[1] + '_MP' + sys.argv[2]
if not os.path.exists(new_checkpoint):
    os.mkdir(new_checkpoint)
if iteration is not None:
    with open(os.path.join(new_checkpoint, 'latest_checkpointed_iteration.txt'), 'w') as fout:
        fout.write("{}\n".format(iteration))
    new_checkpoint = os.path.join(new_checkpoint, str(iteration))
    if not os.path.exists(new_checkpoint):
        os.mkdir(new_checkpoint)

preserve_keys = [
    "lr_scheduler",
    "skipped_steps",
    "global_steps",
    "global_samples",
    "dp_world_size",
    "iteration",
    "client_lr_scheduler",
    "np_rng_state",
    "random_rng_state",
    "torch_rng_state",
    "cuda_rng_state",
    "rng_tracker_states",
]

if target_mp < len(filenames):
    print("Decrease MP size.")
    assert len(filenames) % target_mp == 0
    ratio = len(filenames) // target_mp
    for i in range(target_mp):
        start = ratio * i
        end = ratio * (i + 1)
        d = torch.load(filenames[start], map_location='cpu')
        for k in d.keys():
            if k != 'module':
                if k in preserve_keys:
                    pass
                elif k == "mp_world_size":
                    d[k] = target_mp
                else:
                    d[k] = None
        for j in range(start + 1, end):
            d_new = torch.load(filenames[j], map_location='cpu')
            for k, v in d_new['module'].items():
                assert len(v.shape) < 3
                if len(v.shape) == 2 and 'position' not in k:
                    if 'query' in k:
                        size_1 = d['module'][k].shape[0] // 3
                        size_2 = v.shape[0] // 3
                        target = d['module'][k]
                        d['module'][k] = torch.cat([
                            target[:size_1, :], v[:size_2, :],
                            target[size_1:size_1 * 2, :], v[size_2:size_2 * 2, :],
                            target[size_1 * 2:, :], v[size_2 * 2:, :]], 0)
                    elif 'word' in k or 'h_to_4h' in k or 'relative' in k or "r_w_bias" in k or "r_r_bias" in k:
                        d['module'][k] = torch.cat([d['module'][k], v], 0)
                    else:
                        d['module'][k] = torch.cat([d['module'][k], v], 1)
                elif len(v.shape) == 1 and 'query_key_value' in k:
                    size_1 = d['module'][k].shape[0] // 3
                    size_2 = v.shape[0] // 3
                    target = d['module'][k]
                    d['module'][k] = torch.cat([
                        target[:size_1], v[:size_2],
                        target[size_1:size_1 * 2], v[size_2:size_2 * 2],
                        target[size_1 * 2:], v[size_2 * 2:]], 0)
                elif len(v.shape) == 1 and ('dense_h_to_4h' in k or "attention.relative" in k):
                    d['module'][k] = torch.cat([d['module'][k], v], 0)
        filename = os.path.join(new_checkpoint, "mp_rank_{:02d}_model_states.pt".format(i))
        torch.save(d, filename)

if target_mp > len(filenames):
    print("Increase MP size.")
    assert target_mp % len(filenames) == 0
    ratio = target_mp // len(filenames)
    for i in range(len(filenames)):
        start = ratio * i
        end = ratio * (i + 1)
        d = torch.load(filenames[i], map_location='cpu')
        for j in range(start, end):
            d_new = {}
            shift = j - start
            for k, v in d.items():
                if k != 'module':
                    if k in preserve_keys:
                        d_new[k] = copy.deepcopy(d[k])
                    elif k == "mp_world_size":
                        d_new[k] = target_mp
                    else:
                        d_new[k] = None
            d_new['module'] = {}
            with torch.no_grad():
                for k, v in d['module'].items():
                    assert len(v.shape) < 3
                    if len(v.shape) == 2 and 'position' not in k:
                        if 'query' in k:
                            part = v.shape[0] // ratio // 3
                            d_new['module'][k] = torch.cat(
                                [v[shift * part:(shift + 1) * part, :].clone(),
                                 v[(shift + ratio) * part:(shift + 1 + ratio) * part, :].clone(),
                                 v[(shift + 2 * ratio) * part:(shift + 1 + 2 * ratio) * part, :].clone()], 0)
                        elif 'word' in k or 'h_to_4h' in k or 'relative' in k or "r_w_bias" in k or "r_r_bias" in k:
                            part = v.shape[0] // ratio
                            d_new['module'][k] = v[shift * part:(shift + 1) * part, :].clone()
                        else:
                            part = v.shape[1] // ratio
                            d_new['module'][k] = v[:, shift * part:(shift + 1) * part].clone()
                    elif len(v.shape) == 1 and ('dense_h_to_4h' in k or "attention.relative" in k):
                        part = v.shape[0] // ratio
                        d_new['module'][k] = v[shift * part:(shift + 1) * part].clone()
                    elif len(v.shape) == 1 and 'query_key_value' in k:
                        part = v.shape[0] // ratio // 3
                        d_new['module'][k] = torch.cat(
                            [v[shift * part:(shift + 1) * part].clone(),
                             v[(shift + ratio) * part:(shift + 1 + ratio) * part].clone(),
                             v[(shift + 2 * ratio) * part:(shift + 1 + 2 * ratio) * part].clone()], 0)
                    else:
                        d_new['module'][k] = v.clone()
            filename = os.path.join(new_checkpoint, "mp_rank_{:02d}_model_states.pt".format(j))
            torch.save(d_new, filename)
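
Editor's note: invoked as python change_mp.py <checkpoint_dir> <target_mp>, the script writes the resharded checkpoint to <checkpoint_dir>_MP<target_mp>. The one non-obvious step is the fused query_key_value handling: each rank stores that weight as [Q; K; V] stacked along dim 0, so merging two ranks cannot be a plain torch.cat; the Q slices of both ranks must come first, then the K slices, then the V slices. A runnable toy illustration of the merge, with made-up sizes:

import torch

hidden = 4                                # illustrative size only
w0 = torch.arange(0, 3 * hidden)          # rank 0 holds [Q0 | K0 | V0]
w1 = torch.arange(100, 100 + 3 * hidden)  # rank 1 holds [Q1 | K1 | V1]

third = hidden  # each of Q/K/V occupies one third of the fused dim
merged = torch.cat([w0[:third], w1[:third],                    # Q0, Q1
                    w0[third:2 * third], w1[third:2 * third],  # K0, K1
                    w0[2 * third:], w1[2 * third:]], 0)        # V0, V1
print(merged)  # 0..3, 100..103, 4..7, 104..107, 8..11, 108..111

Splitting for a larger MP size inverts this: each new rank takes its shift-th slice out of each third, which is what the (shift + k * ratio) * part indexing in the increase branch computes.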
configure_data.py (1 addition, 2 deletions)
@@ -65,8 +65,7 @@ def prepare_tokenizer(args):
     assert eod_token == tokenizer.get_command('pad').Id
     before = num_tokens
     after = before
-    multiple = args.make_vocab_size_divisible_by * \
-               mpu.get_model_parallel_world_size()
+    multiple = args.make_vocab_size_divisible_by
     while (after % multiple) != 0:
         after += 1
     print_rank_0('> padded vocab (size: {}) with {} dummy '
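
Editor's note: a plausible motivation for dropping the model-parallel factor is that change_mp.py now reshards checkpoints between MP sizes, so the padded vocabulary (and hence the word-embedding shape) must not depend on the MP world size, or the resharded embedding rows would not line up. A small sketch of how the old multiplier made the padding MP-dependent, with a hypothetical vocabulary of 50200 tokens:

def padded_vocab(num_tokens, multiple):
    # mirrors the padding loop in prepare_tokenizer
    after = num_tokens
    while after % multiple != 0:
        after += 1
    return after

print(padded_vocab(50200, 128))      # 50304, now used at every MP size
print(padded_vocab(50200, 128 * 2))  # 50432, old behaviour at MP size 2
print(padded_vocab(50200, 128 * 4))  # 50688, old behaviour at MP size 4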
finetune_gpt2.py (27 additions, 19 deletions)
@@ -7,6 +7,7 @@
 from arguments import get_args
 from filelock import FileLock
 import pathlib
+import mpu

 # coding=utf-8
 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
@@ -41,24 +42,32 @@

 def process_batch(batch, args):
     """Process batch and produce inputs for the model."""
-    new_batch = {'text': batch['text'].long().cuda().contiguous(), 'label': batch['label'].long().cuda().contiguous()}
-    for key in ['text', 'label', 'types', 'target', 'logit_mask', 'position', 'segment_id', 'prompt_pos',
-                'dec_text', 'dec_position', 'dec_mask', 'dec_target', 'dec_logit_mask']:
-        if key in batch:
-            new_batch[key] = batch[key].long().cuda().contiguous()
+    keys = ["text", "label"]
+    if args.pretrained_bert:
+        keys += ["padding_mask", "types"]
+    else:
+        keys += ["mask", "position"]
+    if args.cloze_eval:
+        if args.fast_decode:
+            keys += ["dec_text", "dec_position", "dec_mask", "dec_target", "dec_logit_mask"]
+        else:
+            keys += ["target", "logit_mask"]
+    if args.segment_length > 0:
+        keys += ["segment_id"]
+    if args.continuous_prompt:
+        keys += ["prompt_pos"]
+    if args.variable_num_choices or "loss_mask" in batch:
+        keys.append("loss_mask")
+    # Broadcast data.
+    datatype = torch.int64
+    data_b = mpu.broadcast_data(keys, batch, datatype)

     if "padding_mask" in batch:
-        attention_mask = batch['padding_mask'].float().cuda().contiguous()
+        attention_mask = data_b['padding_mask'].float().cuda().contiguous()
         if args.fp16:
             attention_mask = attention_mask.half()
-        new_batch["attention_mask"] = attention_mask
-    elif "mask" in batch:
-        attention_mask = batch['mask'].long().cuda().contiguous()
-        new_batch["attention_mask"] = attention_mask
-    if "loss_mask" in batch:
-        new_batch["loss_mask"] = batch["loss_mask"].float().cuda().contiguous()
-        if args.fp16:
-            new_batch['loss_mask'] = new_batch['loss_mask'].half()
-    return new_batch
+        data_b["padding_mask"] = attention_mask
+    return data_b


 tokenizer = None
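
Editor's note: mpu.broadcast_data(keys, data, datatype) is the Megatron-style helper: all ranks of a model-parallel group call it with the same key list, and the named tensors from the group's source rank are broadcast so every rank consumes an identical batch. A condensed sketch of that contract, not the repo's implementation, and assuming each rank already knows the tensor shapes (the real helper broadcasts the sizes first):

import torch
import mpu

def broadcast_batch(keys, batch, sizes):
    # `sizes` maps key -> shape; illustrative stand-in for the size handshake
    src = mpu.get_model_parallel_src_rank()
    group = mpu.get_model_parallel_group()
    out = {}
    for key in keys:
        if mpu.get_model_parallel_rank() == 0:
            t = batch[key].long().cuda().contiguous()
        else:
            t = torch.empty(sizes[key], dtype=torch.int64, device='cuda')
        torch.distributed.broadcast(t, src, group=group)
        out[key] = t
    return out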
@@ -78,11 +87,11 @@ def finetune_forward_step(batch, model, args, timers, mems):

     # Forward model.
     if args.pretrained_bert:
-        tokens, types, labels, attention_mask = data['text'], data['types'], data['label'], data['attention_mask']
+        tokens, types, labels, attention_mask = data['text'], data['types'], data['label'], data['padding_mask']
         logits = model(tokens, token_type_ids=types, attention_mask=attention_mask, checkpoint_activations=True)
     elif args.cloze_eval:
         tokens, labels, position_ids = data['text'], data['label'], data['position']
-        attention_mask = data['attention_mask']
+        attention_mask = data['mask']

         def print_masked_text(batch_id):
             output_tokens = []
@@ -121,8 +130,7 @@ def print_masked_text(batch_id):
             logits, *mems = model(tokens, position_ids, attention_mask, dec_input_ids, dec_position_ids,
                                   dec_attention_mask, dec_target_ids, dec_logit_mask)
     else:
-        tokens, labels, position_ids, attention_mask = data['text'], data['label'], data['position'], data[
-            'attention_mask']
+        tokens, labels, position_ids, attention_mask = data['text'], data['label'], data['position'], data['mask']
         logits, *mems = model(tokens, position_ids, attention_mask)

     if "segment_id" in data:
generate_samples.py (3 additions, 1 deletion)
@@ -245,7 +245,7 @@ def read_context(tokenizer, args, output):
     terminate_runs = terminate_runs_tensor[0].item()

     if terminate_runs == 1:
-        return terminate_runs, raw_text, None, None
+        return terminate_runs, None, None, None

     context_length_tensor = torch.cuda.LongTensor([context_length])

@@ -258,6 +258,8 @@ def read_context(tokenizer, args, output):
         context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
     torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
                                 group=mpu.get_model_parallel_group())
+    if mpu.get_model_parallel_rank() != 0:
+        raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
     return terminate_runs, raw_text, context_tokens_tensor, context_length
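
Editor's note: both hunks are one fix. Under model parallelism only the source rank reads the prompt, so raw_text never existed on the other ranks; the first hunk stops returning it on early termination, and the second has every non-source rank rebuild it by decoding the broadcast token ids. Condensed, with the tokenizer calls assumed from the surrounding file and context_length already broadcast:

if mpu.get_model_parallel_rank() == 0:
    raw_text = input("Context prompt >>> ")
    context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
else:
    context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
                            group=mpu.get_model_parallel_group())
if mpu.get_model_parallel_rank() != 0:
    raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())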
scripts/ds_finetune_superglue.sh (4 additions, 2 deletions)
@@ -1,3 +1,4 @@
+MP_SIZE=4
 DATA_ROOT=/dataset/c07bd62b/superglue
 GLUE_DATA_ROOT=/dataset/c07bd62b/glue_data
 source config_tasks/model_blocklm_10B.sh

@@ -9,7 +10,7 @@ CHECKPOINT_PATH="/dataset/c07bd62b/finetune_checkpoints"
 MASTER_PORT=$(shuf -n 1 -i 10000-65535)

 OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
-DISTRIBUTED_ARGS="${OPTIONS_NCCL} deepspeed --include localhost:2,3 --master_port $MASTER_PORT"
+DISTRIBUTED_ARGS="${OPTIONS_NCCL} deepspeed --num_gpus 4 --num_nodes 1 --master_port $MASTER_PORT"
 DATESTR=$(date +"%m-%d-%H-%M")

 mkdir logs

@@ -25,13 +26,14 @@ run_cmd="${DISTRIBUTED_ARGS} finetune_gpt2.py \
       --checkpoint-activations \
       --eval-batch-size 16 \
       --save-epoch 100 \
-      --num-workers 0 \
+      --num-workers 1 \
       --no-load-optim \
       --no-load-lr-scheduler \
       --fp16 \
       $MODEL_ARGS \
       $TRAIN_ARGS \
       $COMMON_ARGS \
+      --model-parallel-size ${MP_SIZE} \
       --epochs ${EPOCH_SINGLE} \
       --lr ${LR_SINGLE} \
       --overwrite \
scripts/generate_block.sh (1 addition, 2 deletions)
@@ -16,10 +16,9 @@ script_dir=$(dirname $script_path)

 config_json="$script_dir/ds_config.json"

-MASTER_PORT=${MASTER_PORT} python generate_samples.py \
+python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTER_PORT generate_samples.py \
        --DDP-impl none \
        --model-parallel-size $MPSIZE \
-       --deepspeed_config ${config_json} \
        $MODEL_ARGS \
        --fp16 \
        --cache-dir cache \
tasks/eval_utils.py (7 additions, 7 deletions)
@@ -117,7 +117,7 @@ def multichoice_evaluate(model, dataloader, example_dict, args):
     model.eval()
     port = get_spare_port()
     store = torch.distributed.TCPStore(args.master_ip, port,
-                                       mpu.get_data_parallel_world_size(),
+                                       torch.distributed.get_world_size(),
                                        torch.distributed.get_rank() == 0, datetime.timedelta(seconds=30))
     with torch.no_grad():
         # For all the batches in the dataset.

@@ -126,12 +126,11 @@ def multichoice_evaluate(model, dataloader, example_dict, args):
             data = process_batch(batch, args)
             if args.pretrained_bert:
                 tokens, types, labels_, attention_mask = data['text'], data['types'], data['label'], data[
-                    'attention_mask']
+                    'padding_mask']
                 inputs = [tokens, types, attention_mask]
             elif args.cloze_eval:
                 tokens, labels_, position_ids = data['text'], data['label'], data['position']
-                attention_mask, target_ids, logit_mask = data['attention_mask'], data.get('target'), data.get(
-                    'logit_mask')
+                attention_mask, target_ids, logit_mask = data['mask'], data['target'], data['logit_mask']
                 if not args.fast_decode:
                     inputs = [tokens, position_ids, attention_mask, target_ids, logit_mask]
                     if args.continuous_prompt:

@@ -145,7 +144,7 @@ def multichoice_evaluate(model, dataloader, example_dict, args):
                                   dec_target_ids, dec_logit_mask]
             else:
                 tokens, labels_, position_ids, attention_mask = data['text'], data['label'], data['position'], data[
-                    'attention_mask']
+                    'mask']
                 inputs = [tokens, position_ids, attention_mask]
             if len(inputs[0].shape) == 3 and inputs[0].size(1) > segment_length:
                 logit_list = []

@@ -186,8 +185,9 @@ def multichoice_evaluate(model, dataloader, example_dict, args):
             labels = labels_.tolist()
             if args.task.lower() == 'wsc':
                 predicted = [1 if pred == 0 else 0 for pred in predicted]
-            for uid, prediction, label in zip(uid_list, predicted, labels):
-                store.set(uid, str((prediction, label)))
+            if mpu.get_model_parallel_rank() == 0:
+                for uid, prediction, label in zip(uid_list, predicted, labels):
+                    store.set(uid, str((prediction, label)))
     model.train()
     torch.distributed.barrier()
     predictions, labels, examples = [], [], []
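
Editor's note: these hunks are two halves of one fix. Every rank in the job now connects to the TCPStore, so it is sized with torch.distributed.get_world_size() rather than the data-parallel world size, while only model-parallel rank 0 writes, so each uid is recorded exactly once instead of once per model-parallel rank. A minimal single-process illustration of the store API (hypothetical host and port; world size 1 so the constructor does not wait for peers):

import datetime
import torch

store = torch.distributed.TCPStore("127.0.0.1", 29571, 1, True,
                                   datetime.timedelta(seconds=30))
store.set("uid-0", str((1, 0)))   # any connected rank may set a key
print(store.get("uid-0"))         # b'(1, 0)'; values come back as bytes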
tasks/superglue/dataset.py (13 additions, 0 deletions)
@@ -56,6 +56,7 @@ def __init__(self, args, split, tokenizer, for_train=False):
         task_name = args.task.lower()
         data_dir = args.data_dir
         self.processor = PROCESSORS[task_name](args)
+        args.variable_num_choices = self.processor.variable_num_choices
         print_rank_0(
             f"Creating {task_name} dataset from file at {data_dir} (split={split})"
         )

@@ -128,6 +129,10 @@ def output_prediction(self, predictions, examples, output_file):
                 data = {"idx": example.idx, "label": prediction}
                 output.write(json.dumps(data) + "\n")

+    @property
+    def variable_num_choices(self):
+        return False
+
     @abstractmethod
     def get_train_examples(self, data_dir) -> List[InputExample]:
         """Get a collection of `InputExample`s for the train set."""

@@ -299,6 +304,10 @@ def get_classifier_input(self, example: InputExample, tokenizer):
 class WscProcessor(DataProcessor):
     """Processor for the WSC data set."""

+    @property
+    def variable_num_choices(self):
+        return self.args.wsc_negative
+
     def get_train_examples(self, data_dir, cloze_eval=True):
         return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train", cloze_eval=cloze_eval)

@@ -622,6 +631,10 @@ def get_classifier_input(self, example: InputExample, tokenizer):
 class RecordProcessor(DataProcessor):
     """Processor for the ReCoRD data set."""

+    @property
+    def variable_num_choices(self):
+        return True
+
     def get_train_examples(self, data_dir):
         return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")
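
Editor's note: this wiring feeds the broadcast path added in finetune_gpt2.py. Each processor advertises whether its task has a variable number of choices (always for ReCoRD, only with the wsc_negative option for WSC), the dataset constructor copies that onto args, and process_batch then includes "loss_mask" in the broadcast keys, presumably so the key list agrees across ranks even when a rank's local batch lacks that key. In miniature (illustrative stand-ins, not the repo classes):

class DataProcessor:
    @property
    def variable_num_choices(self):
        return False          # fixed choice count by default

class RecordProcessor(DataProcessor):
    @property
    def variable_num_choices(self):
        return True           # ReCoRD candidates vary per example

variable_num_choices = RecordProcessor().variable_num_choices
assert variable_num_choices  # so process_batch broadcasts "loss_mask"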