Implement COPA continuous prompt
duzx16 committed Mar 10, 2021
1 parent 5432d1a commit 59cbee5
Showing 10 changed files with 95 additions and 56 deletions.
arguments.py: 1 addition & 0 deletions
@@ -335,6 +335,7 @@ def add_finetune_config_args(parser):
group.add_argument('--load-pretrained', type=str, help="Load pretrained model", default=None)
group.add_argument('--pool-token', type=str, choices=['start', 'pad', 'cls'],
help='The token to pool the sequence representation', default='cls')
+ group.add_argument('--continuous-prompt', action='store_true', help="Use continuous prompt for PET")
group.add_argument('--cloze-eval', action='store_true', help='Evaluation dataset with cloze task')
group.add_argument('--multi-token', action='store_true', help='Use multi token for cloze evaluation')
group.add_argument('--segment-length', type=int, default=0, help="The maximum segment length for cloze evaluation")
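
For reference, a minimal standalone sketch of how a store_true flag like the one added above behaves (hypothetical parser; the real flag is registered in add_finetune_config_args):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--continuous-prompt', action='store_true', help="Use continuous prompt for PET")
args = parser.parse_args([])  # flag omitted -> False
assert args.continuous_prompt is False
args = parser.parse_args(['--continuous-prompt'])  # flag given -> True (note dash becomes underscore)
assert args.continuous_prompt is True
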
finetune_gpt2.py: 12 additions & 4 deletions
@@ -65,6 +65,8 @@ def process_batch(batch, args):
new_batch['loss_mask'] = new_batch['loss_mask'].half()
if "segment_id" in batch:
new_batch["segment_id"] = batch["segment_id"].long().cuda().contiguous()
if "prompt_pos" in batch:
new_batch["prompt_pos"] = batch["prompt_pos"].long().cuda().contiguous()
return new_batch
# if args.fp16:
# attention_mask = attention_mask.half()
@@ -114,13 +116,19 @@ def print_masked_text(batch_id):
print(tokenizer.DecodeIds(tokens[batch_id][target_positions].tolist()))
print(tokenizer.DecodeIds(target_ids[batch_id][target_positions].tolist()))
print(position_ids[batch_id][:, target_positions])

if not args.fast_decode:
- logits, *mems = model(tokens, position_ids, attention_mask, target_ids, logit_mask)
+ if args.continuous_prompt:
+ prompt_pos = data["prompt_pos"]
+ logits, *mems = model(tokens, position_ids, attention_mask, target_ids, logit_mask,
+ prompt_pos=prompt_pos)
+ else:
+ logits, *mems = model(tokens, position_ids, attention_mask, target_ids, logit_mask)
else:
- dec_input_ids, dec_position_ids, dec_attention_mask = data['dec_text'], data['dec_position'], data['dec_mask']
+ dec_input_ids, dec_position_ids, dec_attention_mask = data['dec_text'], data['dec_position'], data[
+ 'dec_mask']
dec_target_ids, dec_logit_mask = data['dec_target'], data['dec_logit_mask']
- logits, *mems = model(tokens, position_ids, attention_mask, dec_input_ids, dec_position_ids, dec_attention_mask, dec_target_ids, dec_logit_mask)
+ logits, *mems = model(tokens, position_ids, attention_mask, dec_input_ids, dec_position_ids,
+ dec_attention_mask, dec_target_ids, dec_logit_mask)
else:
tokens, labels, position_ids, attention_mask = data['text'], data['label'], data['position'], data[
'attention_mask']
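
For orientation, a toy sketch of the batch dict as process_batch sees it when continuous prompts are enabled (shape and the CUDA move are assumptions; only the new key is illustrated):

import torch

batch = {
    # (batch, num_choices, spell_length): position of each continuous-prompt
    # slot inside the padded token sequence; the shape is an assumption here.
    "prompt_pos": torch.zeros(8, 2, 3, dtype=torch.long),
}
new_batch = {}
if "prompt_pos" in batch and torch.cuda.is_available():
    # Mirrors the added lines: cast to long, move to the GPU, make contiguous.
    new_batch["prompt_pos"] = batch["prompt_pos"].long().cuda().contiguous()
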
model/downstream.py: 10 additions & 7 deletions
@@ -21,13 +21,13 @@


class ClozeModel(torch.nn.Module):
- def __init__(self, language_model, take_softmax=True, length_penalty=0.0):
+ def __init__(self, language_model: GPT2Model, take_softmax=True, length_penalty=0.0):
super(ClozeModel, self).__init__()
self.model = language_model
self.take_softmax = take_softmax
self.length_penalty = length_penalty

- def forward(self, input_ids, position_ids, attention_mask, target_ids=None, logit_mask=None):
+ def forward(self, input_ids, position_ids, attention_mask, target_ids=None, logit_mask=None, prompt_pos=None):
if target_ids == None:
outputs, *mems = self.model(input_ids, position_ids, attention_mask)
return (outputs, *mems)
@@ -39,7 +39,9 @@ def forward(self, input_ids, position_ids, attention_mask, target_ids=None, logit_mask=None):
position_ids = position_ids.reshape(-1, *position_ids.size()[2:])
target_ids = target_ids.reshape(-1, target_ids.size(-1))
logit_mask = logit_mask.reshape(-1, logit_mask.size(-1))
- outputs, *mems = self.model(input_ids, position_ids, attention_mask)
+ if prompt_pos is not None:
+ prompt_pos = prompt_pos.reshape(-1, prompt_pos.size(-1))
+ outputs, *mems = self.model(input_ids, position_ids, attention_mask, prompt_pos=prompt_pos)
if self.take_softmax:
outputs = torch.nn.functional.log_softmax(outputs, dim=-1)
batch_ids = torch.arange(target_ids.size(0), dtype=torch.long, device=target_ids.device)
@@ -62,7 +64,7 @@ def __init__(self, language_model, take_softmax=True, length_penalty=0.0):
self.take_softmax = take_softmax
self.length_penalty = length_penalty

- def forward(self, input_ids, position_ids, attention_mask,
+ def forward(self, input_ids, position_ids, attention_mask,
dec_input_ids, dec_position_ids, dec_attention_mask, dec_target_ids, dec_logit_mask):
# encoder
outputs, *mems = self.model(input_ids, position_ids, attention_mask, return_memory=True, detach_memory=False)
@@ -71,7 +73,8 @@ def forward(self, input_ids, position_ids, attention_mask,

enc_mems = []
for hidden in mems:
- hidden = hidden.unsqueeze(1).expand(-1,num_choices,-1,-1).reshape(batch_size*num_choices, *hidden.size()[1:])
+ hidden = hidden.unsqueeze(1).expand(-1, num_choices, -1, -1).reshape(batch_size * num_choices,
+ *hidden.size()[1:])
enc_mems.append(hidden)

def build_dec_mask_matrix(seq_length, sep, memory_length=0):
@@ -80,10 +83,10 @@ def build_dec_mask_matrix(seq_length, sep, memory_length=0):

# sep = dec_attention_mask
ids = torch.arange(memory_length, device=sep.device, dtype=sep.dtype).view(1, -1)
- mask = ids < sep.view(-1, 1) # batch * mem
+ mask = ids < sep.view(-1, 1)  # batch * mem
mask = mask.unsqueeze(1).float().expand(-1, seq_length, -1)

- m = m.expand(batch_size*num_choices, -1, -1)
+ m = m.expand(batch_size * num_choices, -1, -1)
m = torch.cat((mask, m), dim=2)
m = m.unsqueeze(1)
return m
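
The reshapes above flatten the choice dimension so every (sample, choice) pair is scored independently, with prompt_pos now flattened the same way; a quick standalone shape check (sizes are made up):

import torch

batch_size, num_choices, seq_len, spell_length = 2, 4, 16, 3
input_ids = torch.randint(0, 100, (batch_size, num_choices, seq_len))
prompt_pos = torch.randint(0, seq_len, (batch_size, num_choices, spell_length))

# (batch, choices, ...) -> (batch * choices, ...), mirroring ClozeModel.forward
flat_ids = input_ids.reshape(-1, *input_ids.size()[2:])  # (8, 16)
flat_pos = prompt_pos.reshape(-1, prompt_pos.size(-1))   # (8, 3)
assert flat_ids.shape == (8, 16) and flat_pos.shape == (8, 3)
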
model/gpt2_modeling.py: 26 additions & 5 deletions
@@ -56,13 +56,14 @@ def __init__(self,
parallel_output=True,
relative_encoding=False,
block_position_encoding=False,
- output_predict=True
+ output_predict=True,
+ spell_length=None
):
super(GPT2Model, self).__init__()

self.parallel_output = parallel_output
self.output_predict = output_predict

+ self.hidden_size = hidden_size
init_method = init_method_normal(std=0.02)

# Word embeddings (parallel).
@@ -82,12 +83,32 @@ def __init__(self,
checkpoint_num_layers,
relative_encoding=relative_encoding,
block_position_encoding=block_position_encoding)

- def forward(self, input_ids, position_ids, attention_mask, *mems, return_memory=False, detach_memory=True):
+ if spell_length is not None:
+ self.spell_length = spell_length
+ self.spell_embeddings = torch.nn.Embedding(self.spell_length, self.hidden_size)
+ self.lstm_head = torch.nn.LSTM(input_size=self.hidden_size,
+ hidden_size=self.hidden_size,
+ num_layers=2,
+ # dropout=self.lstm_dropout,
+ bidirectional=True,
+ batch_first=True)  # .to(torch.device("cuda"))
+ self.mlp_head = torch.nn.Sequential(torch.nn.Linear(2 * self.hidden_size, self.hidden_size),
+ torch.nn.ReLU(),
+ torch.nn.Linear(self.hidden_size, self.hidden_size))

+ def forward(self, input_ids, position_ids, attention_mask, *mems, return_memory=False, detach_memory=True,
+ prompt_pos=None):
# Embeddings.
+ batch_size = input_ids.size(0)
words_embeddings = self.word_embeddings(input_ids)
embeddings = words_embeddings

+ if prompt_pos is not None:
+ embeddings = embeddings.clone()
+ prompt_embeds = self.spell_embeddings.weight.unsqueeze(0)
+ prompt_embeds = self.lstm_head(prompt_embeds)[0]
+ prompt_embeds = self.mlp_head(prompt_embeds)
+ batch_index = torch.arange(batch_size, device=input_ids.device).unsqueeze(1)
+ embeddings[batch_index, prompt_pos] = prompt_embeds
# Transformer.
transformer_output = self.transformer(embeddings, position_ids, attention_mask, mems,
return_memory=return_memory, detach_memory=detach_memory)
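
The block above is a P-tuning-style prompt encoder: spell_length learned embeddings are passed through a two-layer bidirectional LSTM and an MLP, and the results overwrite the word embeddings at the prompt positions. A self-contained sketch of the same mechanism (sizes and positions are made up; the surrounding model is omitted):

import torch

class PromptEncoder(torch.nn.Module):
    def __init__(self, spell_length, hidden_size):
        super().__init__()
        # One learned vector per continuous-prompt slot.
        self.spell_embeddings = torch.nn.Embedding(spell_length, hidden_size)
        # Reparameterize the prompt vectors, as in the diff above.
        self.lstm_head = torch.nn.LSTM(input_size=hidden_size, hidden_size=hidden_size,
                                       num_layers=2, bidirectional=True, batch_first=True)
        self.mlp_head = torch.nn.Sequential(torch.nn.Linear(2 * hidden_size, hidden_size),
                                            torch.nn.ReLU(),
                                            torch.nn.Linear(hidden_size, hidden_size))

    def forward(self):
        # (1, spell, hidden) -> LSTM -> (1, spell, 2 * hidden) -> MLP -> (1, spell, hidden)
        prompt_embeds = self.spell_embeddings.weight.unsqueeze(0)
        return self.mlp_head(self.lstm_head(prompt_embeds)[0])

batch_size, seq_len, hidden_size, spell_length = 2, 10, 16, 3
embeddings = torch.randn(batch_size, seq_len, hidden_size)  # stand-in word embeddings
prompt_pos = torch.tensor([[0, 4, 5], [1, 2, 8]])           # made-up prompt positions
batch_index = torch.arange(batch_size).unsqueeze(1)
embeddings = embeddings.clone()
# The (1, spell, hidden) prompt broadcasts across the batch at the indexed positions.
embeddings[batch_index, prompt_pos] = PromptEncoder(spell_length, hidden_size)()

Nothing changes at the transformer itself: the prompt arrives as ordinary embedding vectors, so only the encoder parameters distinguish this from a fixed text prompt.
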
scripts/finetune_superglue.sh: 2 additions & 1 deletion
@@ -19,7 +19,8 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS finetune_gpt2.py \
--checkpoint-activations \
--batch-size 8 \
--eval-batch-size 16 \
- --save-epoch 5 \
+ --save-epoch 20 \
+ --overwrite \
$MODEL_ARGS \
$TRAIN_ARGS \
$COMMON_ARGS \
tasks/superglue/dataset.py: 4 additions & 4 deletions
@@ -83,8 +83,8 @@ def __init__(self, args, split, tokenizer, for_train=False):
examples.sort(key=lambda x: x.num_choices)
if args.cloze_eval:
pvp = PVPS[task_name](args, tokenizer, processor.get_labels(), args.seq_length, pattern_id=args.pattern_id,
- is_multi_token=args.multi_token, max_segment_length=args.segment_length,
- fast_decode=args.fast_decode, split=split)
+ is_multi_token=args.multi_token, max_segment_length=args.segment_length,
+ fast_decode=args.fast_decode, split=split, continuous_prompt=args.continuous_prompt)
for example in examples:
sample = pvp.encode(example)
self.samples.append(sample)
@@ -386,9 +386,9 @@ def _create_examples(path: str, set_type: str, cloze_eval=True) -> List[InputExample]:
if set_type == 'train' and 'candidates' in example_json and len(candidates) > 9:
for i in range(0, len(candidates), 9):
_meta = copy.deepcopy(meta)
- _meta['candidates'] = candidates[i:i+9]
+ _meta['candidates'] = candidates[i:i + 9]
if len(_meta['candidates']) < 9:
- _meta['candidates'] += candidates[:9-len(_meta['candidates'])]
+ _meta['candidates'] += candidates[:9 - len(_meta['candidates'])]
example = InputExample(guid=guid, text_a=text_a, label=label, meta=_meta, idx=idx)
examples.append(example)
else:
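
The chunking above, for training examples whose meta carries more than nine candidates, slices the candidate list into groups of nine and tops up a short final group from the front of the list; a quick worked check of the slice arithmetic:

candidates = list(range(20))  # e.g. 20 candidates
groups = []
for i in range(0, len(candidates), 9):
    group = candidates[i:i + 9]
    if len(group) < 9:
        group += candidates[:9 - len(group)]  # pad the last group back up to nine
    groups.append(group)
assert groups[2] == [18, 19, 0, 1, 2, 3, 4, 5, 6]
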
tasks/superglue/finetune.py: 5 additions & 3 deletions
@@ -62,15 +62,17 @@ def single_dataset_provider(split):

def main(args):
model_kwargs = {}
+ processor = PROCESSORS[args.task.lower()]()
+ pvp = PVPS[args.task.lower()](args, None, processor.get_labels(), args.seq_length,
+ pattern_id=args.pattern_id, is_multi_token=args.multi_token)
+ if args.continuous_prompt:
+ model_kwargs["spell_length"] = pvp.spell_length
if args.task.lower() == 'wsc' and args.cloze_eval and not args.wsc_negative:
from tasks.language_model.finetune import lm_forward_step
finetune(args, train_valid_datasets_provider, model_kwargs,
end_of_epoch_callback_provider=metrics_func_provider, forward_step=lm_forward_step)
else:
- processor = PROCESSORS[args.task.lower()]()
if args.cloze_eval:
- pvp = PVPS[args.task.lower()](args, None, processor.get_labels(), args.seq_length,
- pattern_id=args.pattern_id, is_multi_token=args.multi_token)
multi_token = pvp.is_multi_token
else:
multi_token = args.task.lower() in MULTI_CHOICE_DATASETS
(Diffs for the remaining three changed files did not load on this page.)
