Skip to content

Commit

Permalink
wsc binary prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
thomas0809 committed Mar 12, 2021
1 parent 20234c1 commit ae99495
Show file tree
Hide file tree
Showing 13 changed files with 149 additions and 55 deletions.
5 changes: 2 additions & 3 deletions config_tasks/task_multirc.sh
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
EXPERIMENT_NAME=${MODEL_TYPE}-MultiRC
TASK_NAME=multirc
DATA_PATH="/root/data/superglue/MultiRC"
MAX_SEQ_LEN=512
MAX_SEQ_LEN=430

LR_RANGE=(1e-5)
EPOCH_RANGE=(10)

LR_SINGLE=1e-5
EPOCH_SINGLE=12

TRAIN_ARGS="--batch-size 16 \
--lr-decay-style linear \
TRAIN_ARGS="--lr-decay-style linear \
--warmup 0.1 \
--weight-decay 1.0e-1"

Expand Down
11 changes: 5 additions & 6 deletions config_tasks/task_wsc.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
EXPERIMENT_NAME=${MODEL_TYPE}-WSC
TASK_NAME=wsc
EXPERIMENT_NAME=${MODEL_TYPE}-${TASK_NAME}
DATA_PATH="/root/data/superglue/WSC-negative"
MAX_SEQ_LEN=128

Expand All @@ -9,15 +9,14 @@ EPOCH_RANGE=(20)
LR_SINGLE=1e-5
EPOCH_SINGLE=20

TRAIN_ARGS="--batch-size 8 \
--lr-decay-style linear \
TRAIN_ARGS="--lr-decay-style linear \
--warmup 0.1 \
--weight-decay 0.1 \
--length-penalty 1 \
--loss-func mix \
--wsc-negative \
--length-penalty 1"
--wsc-negative"

COMMON_ARGS="--save-interval 10000 \
--log-interval 50 \
--eval-interval 1000 \
--eval-iters 100"
--eval-iters 100"
2 changes: 2 additions & 0 deletions data_utils/tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,8 @@ def IdToToken(self, Id, type_token=False):
return Id.token
if type_token:
return self.type_id_map[Id].token
if Id in self.command_id_map:
return self.command_id_map[Id].token
return self.text_tokenizer.decoder[Id]

def TokenToId(self, token, type_token=False):
Expand Down
25 changes: 19 additions & 6 deletions finetune_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from tasks.data_utils import build_data_loader
from utils import get_sample_writer, get_log_dir, print_and_save_args
from model import GPT2Model
from model import GPT2Model, VerbalizerModel
from arguments import get_args

# coding=utf-8
Expand Down Expand Up @@ -111,11 +111,21 @@ def print_masked_text(batch_id):
if logit_mask[batch_id][i]:
target_positions.append(i)
print(target_positions)
print(tokenizer.DecodeIds(tokens[batch_id][target_positions].tolist()))
print(tokenizer.DecodeIds(target_ids[batch_id][target_positions].tolist()))
print(position_ids[batch_id][:, target_positions])

if not args.fast_decode:
print([tokenizer.IdToToken(token) for token in tokens[batch_id][target_positions].tolist()])
print([tokenizer.IdToToken(token) for token in target_ids[batch_id].tolist()])
print(labels[batch_id].item())
# print([tokenizer.IdToToken(token) for token in target_ids[batch_id][target_positions].tolist()])
# print(position_ids[batch_id][:, target_positions])

# print_masked_text(0)
# print_masked_text(1)
if not args.multi_token:
logits, lm_logits, *mems = model(tokens, position_ids, attention_mask, target_ids, logit_mask)
# batch_size = logits.size(0)
# lm_labels = target_ids[range(batch_size), labels]
# loss_func = torch.nn.CrossEntropyLoss()
# lm_loss = loss_func(lm_logits, lm_labels)
elif not args.fast_decode:
logits, *mems = model(tokens, position_ids, attention_mask, target_ids, logit_mask)
else:
dec_input_ids, dec_position_ids, dec_attention_mask = data['dec_text'], data['dec_position'], data['dec_mask']
Expand All @@ -125,6 +135,7 @@ def print_masked_text(batch_id):
tokens, labels, position_ids, attention_mask = data['text'], data['label'], data['position'], data[
'attention_mask']
logits, *mems = model(tokens, position_ids, attention_mask)

if "segment_id" in data:
from torch_scatter import scatter_sum
if "loss_mask" in data:
Expand All @@ -151,6 +162,8 @@ def print_masked_text(batch_id):
else:
raise NotImplementedError

# loss = loss + lm_loss

# Reduce loss for logging.

return loss, mems, 'bert'
Expand Down
20 changes: 20 additions & 0 deletions format_log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Turn a grid-search log into a tab-separated table.

The input file is expected to alternate between free-text lines describing
a hyper-parameter setting and JSON object lines (starting with ``{``)
holding the results for that setting.  The list of reported keys is printed
first, followed by one tab-separated row per JSON line.

Usage: python format_log.py GRID_LOG
"""
import sys
import json


def parse_log(lines):
    """Pair every JSON result line with the most recent non-JSON line.

    Args:
        lines: iterable of text lines (e.g. an open file).

    Returns:
        A list of ``(hyper, result)`` tuples, where ``hyper`` is the stripped
        text of the last line that did not start with ``{`` (``''`` if there
        was none yet) and ``result`` is the decoded JSON object.

    Raises:
        json.JSONDecodeError: if a line starting with ``{`` is not valid JSON.
    """
    records = []
    hyper = ''
    for line in lines:
        if line.startswith('{'):
            records.append((hyper, json.loads(line)))
        else:
            hyper = line.strip()
    return records


def numeric_keys(records):
    """Return the keys of the first result whose values are not strings.

    Only the first record is inspected; an empty record list yields ``[]``
    instead of raising ``IndexError``.
    """
    if not records:
        return []
    first_result = records[0][1]
    return [key for key, value in first_result.items()
            if not isinstance(value, str)]


def format_rows(records, keys):
    """Render each record as ``hyper<TAB>v1<TAB>v2...`` with 4 decimals.

    Raises:
        KeyError: if a record lacks one of ``keys``.
    """
    return [
        hyper + '\t' + '\t'.join(f'{res[key]:.4f}' for key in keys)
        for hyper, res in records
    ]


def main(argv=None):
    """Command-line entry point; returns the process exit status."""
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        print(f'usage: {argv[0] if argv else "format_log.py"} GRID_LOG',
              file=sys.stderr)
        return 2

    with open(argv[1]) as f:
        data = parse_log(f)

    if not data:
        # Previously this case crashed with an IndexError on data[0].
        print(f'no JSON result lines found in {argv[1]}', file=sys.stderr)
        return 1

    keys = numeric_keys(data)
    print(keys)
    for row in format_rows(data, keys):
        print(row)
    return 0


if __name__ == '__main__':
    sys.exit(main())

17 changes: 13 additions & 4 deletions model/downstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,27 @@ def build_dec_mask_matrix(seq_length, sep, memory_length=0):


class VerbalizerModel(torch.nn.Module):
def __init__(self, language_model):
def __init__(self, language_model, hidden_size=None, vocab_size=None, num_class=None):
super().__init__()
self.model = language_model
# self.dense = torch.nn.Linear(hidden_size, hidden_size)
# self.layer_norm = torch.nn.LayerNorm(hidden_size)
# self.final = torch.nn.Linear(hidden_size, num_class)

def forward(self, input_ids, position_ids, attention_mask, target_ids, logit_mask):
assert len(input_ids.shape) == 2
outputs, *mems = self.model(input_ids, position_ids, attention_mask)
# Original
batch_ids = torch.arange(outputs.size(0), dtype=attention_mask.dtype, device=attention_mask.device)
output = outputs[batch_ids, attention_mask]
target_output = outputs[batch_ids, attention_mask]
batch_ids = batch_ids.unsqueeze(1).expand_as(target_ids)
output = output[batch_ids, target_ids]
return (output, *mems)
output = target_output[batch_ids, target_ids]

# output = self.layer_norm(self.dense(output))
# output = self.final(output)
lm_logits = target_output

return (output, lm_logits, *mems)


class PoolingModel(torch.nn.Module):
Expand Down
5 changes: 3 additions & 2 deletions scripts/finetune_superglue.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
source config_tasks/model_blocklm_roberta_large.sh $2
source config_tasks/model_blocklm_roberta_large.sh
source $1

CHECKPOINT_PATH="/root/data/finetune_checkpoints"

MASTER_PORT=$(shuf -n 1 -i 10000-65535)
DISTRIBUTED_ARGS="--nproc_per_node 2 --nnodes 1 --node_rank 0 --master_addr localhost --master_port $MASTER_PORT"
DISTRIBUTED_ARGS="--nproc_per_node 4 --nnodes 1 --node_rank 0 --master_addr localhost --master_port $MASTER_PORT"
DATESTR=$(date +"%m-%d-%H-%M")

mkdir logs
Expand All @@ -24,4 +24,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS finetune_gpt2.py \
$COMMON_ARGS \
--epochs ${EPOCH_SINGLE} \
--lr ${LR_SINGLE} \
--overwrite \
2>&1 | tee logs/log-${EXPERIMENT_NAME}.txt
42 changes: 31 additions & 11 deletions scripts/finetune_superglue_grid.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,26 @@ N_GPU=2
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
DISTRIBUTED_ARGS="--nproc_per_node ${N_GPU} --nnodes 1 --node_rank 0 --master_addr localhost --master_port $MASTER_PORT"

GRID_LOG=logs/grid_${EXPERIMENT_NAME}.txt
echo $EXPERIMENT_NAME > $GRID_LOG
DATESTR=$(date +"%m-%d-%H-%M")
GRID_LOG=logs/grid_${EXPERIMENT_NAME}_${DATESTR}.txt

for lr in 5e-6 1e-5 2e-5
for lr in 1e-5 #2e-5
do
for bs in 16 32
for bs in 16 #32
do
for seed in 1 2 3 4 5
for epoch in 40 #10 20 40
do
HYPER=${lr}-${bs}-${seed}
DATESTR=$(date +"%m-%d-%H-%M")
for warmup in 0.1 #0.06 0
do
for wd in 0.1 0.01 0
do
for beta2 in 0.98 # 0.999
do
for eps in 1e-6 # 1e-8
do
for seed in 1 2 3 # 4 5
do
HYPER=${lr}-b${bs}-ep${epoch}-wm${warmup}-wd${wd}-${seed}
PER_GPU_BS=$((bs/N_GPU))
python -m torch.distributed.launch $DISTRIBUTED_ARGS finetune_gpt2.py \
--finetune \
Expand All @@ -29,17 +38,28 @@ do
--seq-length ${MAX_SEQ_LEN} \
--eval-batch-size 16 \
$MODEL_ARGS \
$TRAIN_ARGS \
$COMMON_ARGS \
--epochs ${EPOCH_SINGLE} \
--lr-decay-style linear \
--epochs ${epoch} \
--lr ${lr} \
--weight-decay ${wd} \
--warmup ${warmup} \
--batch-size ${PER_GPU_BS} \
--seed ${seed} \
--optimizer adam \
--adam-beta2 ${beta2} \
--adam-eps ${eps} \
--overwrite \
2>&1 | tee logs/log-${EXPERIMENT_NAME}-${HYPER}.txt
echo $lr $bs $seed >> $GRID_LOG
echo $lr $bs $epoch $warmup $seed >> $GRID_LOG
cat runs/${EXPERIMENT_NAME}/${HYPER}/results.json >> $GRID_LOG
done
done
done
done
done
done
done
done
done

echo $EXPERIMENT_NAME >> $GRID_LOG
30 changes: 19 additions & 11 deletions tasks/superglue/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,16 @@
SPLIT_TYPES = [TRAIN_SET, DEV_SET, TEST_SET, UNLABELED_SET]


def get_output_func(task_name):
return PROCESSORS[task_name]().output_prediction
def get_output_func(task_name, args):
return PROCESSORS[task_name](args).output_prediction


class GlueDataset(Dataset):

def __init__(self, args, split, tokenizer, for_train=False):
task_name = args.task.lower()
data_dir = args.data_dir
processor = PROCESSORS[task_name]()
processor = PROCESSORS[task_name](args)
print_rank_0(
f"Creating {task_name} dataset from file at {data_dir} (split={split})"
)
Expand Down Expand Up @@ -107,7 +107,8 @@ class DataProcessor(ABC):
task
"""

def __init__(self):
def __init__(self, args):
self.args = args
self.num_truncated = 0

def output_prediction(self, predictions, examples, output_file):
Expand Down Expand Up @@ -171,11 +172,7 @@ def encode(self, example: InputExample, tokenizer, args):

class RteProcessor(DataProcessor):
"""Processor for the RTE data set."""

def __init__(self):
super().__init__()
self.mnli_processor = MnliProcessor()


def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")

Expand Down Expand Up @@ -315,8 +312,7 @@ def get_classifier_input(self, example: InputExample, tokenizer):
text_b = target
return text_a, text_b

@staticmethod
def _create_examples(path: str, set_type: str, cloze_eval=True) -> List[InputExample]:
def _create_examples(self, path: str, set_type: str, cloze_eval=True) -> List[InputExample]:
examples = []

with open(path, encoding='utf8') as f:
Expand Down Expand Up @@ -374,6 +370,17 @@ def _create_examples(path: str, set_type: str, cloze_eval=True) -> List[InputExa
text_a = ' '.join(words_a)
meta['span1_index'], meta['span2_index'] = span1_index, span2_index

if self.args.task == 'wsc1':
example = InputExample(guid=guid, text_a=text_a, text_b=span1_text,
label=label, meta=meta, idx=idx)
examples.append(example)
if set_type == 'train' and label == 'True':
for cand in candidates:
example = InputExample(guid=guid, text_a=text_a, text_b=cand,
label='False', meta=meta, idx=idx)
examples.append(example)
continue

if cloze_eval and set_type == 'train' and label != 'True':
continue
if set_type == 'train' and 'candidates' in example_json and len(candidates) > 9:
Expand Down Expand Up @@ -970,6 +977,7 @@ def _create_examples(self, path: str) -> List[InputExample]:
"rte": RteProcessor,
"cb": CbProcessor,
"wsc": WscProcessor,
"wsc1": WscProcessor,
"boolq": BoolQProcessor,
"copa": CopaProcessor,
"multirc": MultiRcProcessor,
Expand Down
17 changes: 7 additions & 10 deletions tasks/superglue/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
"rte": [("accuracy", accuracy_metric)],
"boolq": [("accuracy", accuracy_metric)],
"wic": [("accuracy", accuracy_metric)],
'wsc': [("accuracy", accuracy_metric)],
"wsc": [("accuracy", accuracy_metric)],
"wsc1": [("accuracy", accuracy_metric)],
"cb": [("accuracy", accuracy_metric), ("f1-macro", f1_macro_metric)],
"multirc": [("f1a", f1_metric), ("em", multirc_em), ("acc", accuracy_metric)]
}
Expand All @@ -50,7 +51,7 @@ def metrics_func_provider(args, tokenizer, is_test):
def single_dataset_provider(split):
return GlueDataset(args, split, tokenizer)

output_func = get_output_func(args.task.lower())
output_func = get_output_func(args.task.lower(), args)
eval_func = None
if args.task.lower() == 'wsc' and args.cloze_eval and not args.wsc_negative:
from tasks.language_model.finetune import classify_evaluate
Expand All @@ -67,17 +68,13 @@ def main(args):
finetune(args, train_valid_datasets_provider, model_kwargs,
end_of_epoch_callback_provider=metrics_func_provider, forward_step=lm_forward_step)
else:
processor = PROCESSORS[args.task.lower()]()
if args.cloze_eval:
pvp = PVPS[args.task.lower()](args, None, processor.get_labels(), args.seq_length,
pattern_id=args.pattern_id, is_multi_token=args.multi_token)
multi_token = pvp.is_multi_token
else:
multi_token = args.task.lower() in MULTI_CHOICE_DATASETS
processor = PROCESSORS[args.task.lower()](args)
multi_token = args.task.lower() in ['copa', 'wsc', 'record']
args.multi_token = multi_token
if not multi_token:
model_kwargs["model_type"] = "multiple_choice" if args.cloze_eval else "classification"
model_kwargs["multi_token"] = False
model_kwargs["num_labels"] = len(PROCESSORS[args.task.lower()]().get_labels())
model_kwargs["num_labels"] = len(processor.get_labels())
else:
model_kwargs["model_type"] = "multiple_choice"
model_kwargs["multi_token"] = True
Expand Down
Loading

0 comments on commit ae99495

Please sign in to comment.