Commit

merge with github
thomas0809 committed Mar 2, 2021
2 parents 1e9f2c3 + d37283d commit 84e0753
Showing 22 changed files with 817 additions and 141 deletions.
8 changes: 7 additions & 1 deletion arguments.py
@@ -155,6 +155,7 @@ def add_training_args(parser):
group.add_argument('--warmup', type=float, default=0.01,
help='percentage of data to warmup on (.01 = 1% of all '
'training iters). Default 0.01')
+ group.add_argument('--switch-linear', action='store_true', help="Switch to linear decay for cosine decay")
# model checkpointing
group.add_argument('--save', type=str, default=None,
help='Output directory to save checkpoints to.')
@@ -192,6 +193,9 @@ def add_training_args(parser):
# BlockLM training args
group.add_argument('--block-lm', action='store_true', help="whether use the BlockLM pre-training")
group.add_argument('--bert-prob', type=float, default=0.5)
+ group.add_argument('--infill-prob', type=float, default=0.5)
+ group.add_argument('--avg-block-length', type=int, default=3)
+ group.add_argument('--task-mask', action='store_true', help="Use different mask for generation and blank filling")
group.add_argument('--no-shuffle-block', action='store_true', help="not shuffle the blocks when filling the blank")
group.add_argument('--no-block-position', action='store_true',
help='Use (rough) absolute positions instead of block positions')
@@ -333,7 +337,9 @@ def add_finetune_config_args(parser):
help='The token to pool the sequence representation', default='cls')
group.add_argument('--cloze-eval', action='store_true', help='Evaluation dataset with cloze task')
group.add_argument('--multi-token', action='store_true', help='Use multi token for cloze evaluation')
- group.add_argument('--loss-func', type=str, choices=["cross_entropy", "hinge"], default="cross_entropy")
+ group.add_argument('--segment-length', type=int, default=0, help="The maximum segment length for cloze evaluation")
+ group.add_argument('--loss-func', type=str, choices=["cross_entropy", "hinge", "generative", "mix"],
+                    default="cross_entropy")
group.add_argument('--pattern-id', type=int, default=0)
group.add_argument('--fast-decode', action='store_true', help="Fast decode for multi-token cloze")
group.add_argument('--eval-valid', action='store_true', help="Whether evaluate on the valid set")
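
The new options surface on the parsed namespace with dashes turned into underscores (args.task_mask, args.infill_prob, args.avg_block_length), which is how configure_data.py further down consumes them. A minimal, self-contained sketch of that mapping; the parser wiring here is illustrative only, not the repository's actual arguments.py:

    import argparse

    # Sketch: register the BlockLM flags added in this commit and inspect
    # the attribute names argparse derives from them.
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('train')
    group.add_argument('--block-lm', action='store_true')
    group.add_argument('--task-mask', action='store_true',
                       help="Use different mask for generation and blank filling")
    group.add_argument('--infill-prob', type=float, default=0.5)
    group.add_argument('--avg-block-length', type=int, default=3)
    group.add_argument('--switch-linear', action='store_true')

    args = parser.parse_args(['--block-lm', '--task-mask', '--avg-block-length', '4'])
    print(args.task_mask, args.infill_prob, args.avg_block_length)  # True 0.5 4
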
44 changes: 26 additions & 18 deletions blocklm_utils.py
@@ -27,7 +27,8 @@ def index_in_list(lst, val, start=None):
class ConstructBlockStrategy:
def __init__(self, args, tokenizer, max_seq_length, bert_prob=1.0, infill_prob=0.5, min_gpt_ratio=0.5,
block_ratio=0.15, average_block_length=3, max_block_length=40, average_gap_length=3,
- block_position_encoding=True, encoder_decoder=False, shuffle_blocks=True, sentinel_token=False):
+ block_position_encoding=True, encoder_decoder=False, shuffle_blocks=True, sentinel_token=False,
+ task_mask=False):
self.args = args
self.tokenizer = tokenizer
self.count = 0
@@ -49,6 +50,8 @@ def __init__(self, args, tokenizer, max_seq_length, bert_prob=1.0, infill_prob=0
self.shuffle_blocks = shuffle_blocks
self.sentinel_token = sentinel_token
self.gap_length_distribution = [poisson.pmf(i, average_gap_length) for i in range(0, max_block_length)]
+ self.generation_mask = 'gMASK' if task_mask else 'MASK'
+ self.generation_mask = self.tokenizer.get_command(self.generation_mask).Id

@staticmethod
def sample_spans(span_lengths, total_length, rng, offset=0):
@@ -108,7 +111,7 @@ def sample_span_in_document(self, tokens, masked_lengths, rng):
mask_index += current_count
return mask_spans

- def make_block_data(self, tokens, loss_masks, attention_mask, block_spans, rng):
+ def make_block_data(self, tokens, loss_masks, attention_mask, block_spans, rng, generation_task=False):
position_ids = np.ones(len(tokens), dtype=np.long)
for start, end in block_spans:
position_ids[start + 1: end] = 0
@@ -142,9 +145,13 @@ def make_block_data(self, tokens, loss_masks, attention_mask, block_spans, rng):
source_tokens, source_position_ids = [], []
last = 0
for start, end, idx in block_spans:
- mask_token = 'MASK' if idx == 0 else f'MASK{idx}'
+ if generation_task:
+     mask_id = self.generation_mask
+ else:
+     mask_token = 'MASK' if idx == 0 else f'MASK{idx}'
+     mask_id = self.tokenizer.get_command(mask_token).Id
source_tokens.append(tokens[last: start])
- source_tokens.append([self.tokenizer.get_command(mask_token).Id])
+ source_tokens.append([mask_id])
source_position_ids.append(position_ids[last: start])
source_position_ids.append([position_ids[start]])
last = end
@@ -169,14 +176,15 @@ def make_block_data(self, tokens, loss_masks, attention_mask, block_spans, rng):
position_ids = [position_ids, block_position_ids]
return tokens, targets, loss_masks, position_ids

- def generate_blank_data(self, sample, masked_lengths, attention_mask, rng):
+ def generate_blank_data(self, sample, masked_lengths, attention_mask, rng, generation_task=False):
rng.shuffle(masked_lengths)
tokens, loss_masks = sample['text'], sample['loss_mask']
assert tokens[0] == self.tokenizer.get_command('ENC').Id
block_spans = self.sample_span_in_document(tokens, masked_lengths, rng)
if len(block_spans) < len(masked_lengths):
return None
- data = self.make_block_data(tokens, loss_masks, attention_mask, block_spans, rng)
+ data = self.make_block_data(tokens, loss_masks, attention_mask, block_spans, rng,
+                             generation_task=generation_task)
return data

def construct_blocks(self, samples):
@@ -229,29 +237,29 @@ def construct_blocks(self, samples):
tokens, loss_masks = sample['text'], sample['loss_mask']
source_tokens, target_tokens = tokens[:division], tokens[division:]
target_masks = loss_masks[division:]
- tokens = np.concatenate((source_tokens, [self.tokenizer.get_command('MASK').Id,
-                                          self.tokenizer.get_command('sop').Id], target_tokens[:-1],
-                          [self.tokenizer.get_command('pad').Id]))
- targets = np.concatenate((source_tokens, [self.tokenizer.get_command('MASK').Id], target_tokens,
-                           [self.tokenizer.get_command('pad').Id]))
+ tokens = np.concatenate((
+     source_tokens, [self.generation_mask, self.tokenizer.get_command('sop').Id],
+     target_tokens[:-1], [self.tokenizer.get_command('pad').Id]))
+ targets = np.concatenate(
+     (source_tokens, [self.generation_mask], target_tokens, [self.tokenizer.get_command('pad').Id]))
loss_masks = np.concatenate((np.zeros(len(source_tokens) + 1, dtype=np.long), target_masks, [0]))
token_batch.append(tokens)
target_batch.append(targets)
loss_mask_batch.append(loss_masks)
+ position_ids = np.arange(len(source_tokens) + len(target_tokens) + 2, dtype=np.long)
+ position_ids[len(source_tokens) + 1:] = len(source_tokens)
if self.block_position_encoding:
- position_ids = np.arange(len(source_tokens) + len(target_tokens) + 2, dtype=np.long)
- position_ids[len(source_tokens) + 1:] = len(source_tokens)
block_position_ids = np.concatenate(
(np.zeros(len(source_tokens), dtype=np.long),
np.arange(len(target_tokens) + 2, dtype=np.long)))
position_id_batch.append([position_ids, block_position_ids])
else:
- position_ids = np.arange(len(source_tokens) + len(target_tokens) + 2, dtype=np.long)
- position_ids[len(source_tokens) + 1:] -= 1
- position_id_batch.append(position_ids)
+ block_position_ids = np.concatenate((np.zeros(len(source_tokens) + 1, dtype=np.long),
+                                      np.ones(len(target_tokens) + 1, dtype=np.long)))
+ position_id_batch.append([position_ids, block_position_ids])
else:
tokens, targets, loss_masks, position_ids = self.generate_blank_data(sample, [generation_length],
- attention_mask, rng)
+ attention_mask, rng,
+ generation_task=True)
token_batch.append(tokens)
target_batch.append(targets)
loss_mask_batch.append(loss_masks)
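
The generation branch above pairs every token with two position channels: the first freezes at len(source_tokens) across the [MASK]/sop/target span, the second counts positions inside that span. A standalone re-derivation of those two arrays for a toy source of length 4 and target of length 3 (a simplified sketch, not the repository function; np.int64 stands in for np.long):

    import numpy as np

    def generation_position_ids(num_source, num_target):
        # tokens are: source + [MASK, sop] + target[:-1] + [pad], hence the "+ 2".
        position_ids = np.arange(num_source + num_target + 2, dtype=np.int64)
        position_ids[num_source + 1:] = num_source           # everything after [MASK] shares its position
        block_position_ids = np.concatenate(
            (np.zeros(num_source, dtype=np.int64),            # source tokens: block position 0
             np.arange(num_target + 2, dtype=np.int64)))      # [MASK], sop, target tokens: 0, 1, 2, ...
        return position_ids, block_position_ids

    pos, block_pos = generation_position_ids(4, 3)
    print(pos)        # [0 1 2 3 4 4 4 4 4]
    print(block_pos)  # [0 0 0 0 0 1 2 3 4]
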
8 changes: 4 additions & 4 deletions config/ds_blockta_large.sh
@@ -23,10 +23,10 @@ gpt_options=" \
--tokenizer-model-type roberta \
--split 949,50,1 \
--distributed-backend nccl \
- --lr-decay-style cosine \
- --lr-decay-iters 300000 \
- --lr-decay-ratio 0.01 \
- --warmup .1 \
+ --lr-decay-style linear \
+ --lr-decay-iters 500000 \
+ --lr-decay-ratio 0.025 \
+ --warmup .06 \
--checkpoint-activations \
--deepspeed-activation-checkpointing \
--fp16 \
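
The schedule change swaps cosine decay for linear decay, stretches it to 500k iterations, and shortens warmup. The two decay shapes, with a floor of lr-decay-ratio * lr, can be compared with a small sketch; the warmup and floor handling here is a common convention and may differ in detail from the repository's learning-rate scheduler:

    import math

    def lr_at(step, base_lr, decay_iters, warmup_frac, decay_ratio, style):
        # Linear warmup to base_lr, then decay toward decay_ratio * base_lr.
        warmup_steps = int(warmup_frac * decay_iters)
        if step < warmup_steps:
            return base_lr * step / max(1, warmup_steps)
        progress = min(1.0, (step - warmup_steps) / max(1, decay_iters - warmup_steps))
        floor = decay_ratio * base_lr
        if style == 'linear':
            return base_lr - (base_lr - floor) * progress
        return floor + 0.5 * (base_lr - floor) * (1.0 + math.cos(math.pi * progress))

    # New settings: --lr-decay-style linear --lr-decay-iters 500000 --lr-decay-ratio 0.025 --warmup .06
    for step in (50000, 250000, 450000):
        print(step,
              lr_at(step, 1e-4, 500000, 0.06, 0.025, 'linear'),
              lr_at(step, 1e-4, 500000, 0.06, 0.025, 'cosine'))
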
25 changes: 25 additions & 0 deletions config_tasks/seq_cnndm_org.sh
@@ -0,0 +1,25 @@
EXPERIMENT_NAME=${MODEL_TYPE}-cnndm_org
TASK_NAME=cnn_dm_original
DATA_PATH="/root/data/cnn_dm_original"

TRAIN_ARGS="--epochs 15 \
--batch-size 8 \
--lr 3e-5 \
--lr-decay-style linear \
--warmup 0.06 \
--weight-decay 1.0e-1
--label-smoothing 0.1"

COMMON_ARGS="--save-interval 10000 \
--log-interval 50 \
--eval-interval 1000 \
--eval-iters 100"

TASK_ARGS="--src-seq-length 608 \
--tgt-seq-length 160 \
--min-tgt-length 55 \
--length-penalty 0.7 \
--no-repeat-ngram-size 3 \
--num-beams 5 \
--select-topk \
--eval-batch-size 4"
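
Among the decoding flags, --no-repeat-ngram-size 3 conventionally bans any token that would recreate a trigram already present in the hypothesis during beam search. A generic sketch of that rule, not the project's decoder:

    def banned_next_tokens(hypothesis, no_repeat_ngram_size):
        # Token ids that would complete an n-gram already seen in `hypothesis`.
        n = no_repeat_ngram_size
        if len(hypothesis) < n - 1:
            return set()
        seen = {}
        for i in range(len(hypothesis) - n + 1):
            prefix = tuple(hypothesis[i:i + n - 1])
            seen.setdefault(prefix, set()).add(hypothesis[i + n - 1])
        return seen.get(tuple(hypothesis[-(n - 1):]), set())

    print(banned_next_tokens([5, 7, 9, 5, 7], 3))  # {9}: "5 7 9" has already occurred
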
6 changes: 4 additions & 2 deletions configure_data.py
@@ -57,7 +57,7 @@ def prepare_tokenizer(args):
add_sentinel_token = args.max_position_embeddings
tokenizer = make_tokenizer(args.tokenizer_type, None, args.tokenizer_path, args.vocab_size,
args.tokenizer_model_type, add_block_symbols=args.block_lm, cache_dir=args.cache_dir,
- add_sentinel_token=add_sentinel_token)
+ add_sentinel_token=add_sentinel_token, add_task_mask=args.task_mask)
if mpu.get_model_parallel_rank() == 0:
num_tokens = tokenizer.num_tokens
eod_token = tokenizer.get_command('eos').Id
@@ -115,10 +115,12 @@ def make_data_loader(dataset, tokenizer, batch_size, num_iters, args):
use_block = args.block_lm or args.encoder_decoder
if use_block:
strategy = ConstructBlockStrategy(args, tokenizer, args.max_position_embeddings, bert_prob=args.bert_prob,
+ infill_prob=args.infill_prob, average_block_length=args.avg_block_length,
shuffle_blocks=not args.no_shuffle_block,
block_position_encoding=not args.no_block_position,
sentinel_token=args.sentinel_token,
- encoder_decoder=args.encoder_decoder)
+ encoder_decoder=args.encoder_decoder,
+ task_mask=args.task_mask)
data_loader = torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
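
ConstructBlockStrategy now receives the infill probability, average block length, and task-mask switch straight from the command-line arguments. Its construct_blocks(samples) method operates on a whole list of samples, which is the shape of a DataLoader collate function, so presumably that is where it plugs in. A hedged sketch of such wiring, with a toy dataset and a stand-in collate in place of the real strategy:

    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        # Placeholder; the real samples carry 'text' and 'loss_mask' arrays.
        def __init__(self, samples):
            self.samples = samples
        def __len__(self):
            return len(self.samples)
        def __getitem__(self, idx):
            return self.samples[idx]

    def collate_like_construct_blocks(samples):
        # Stand-in for strategy.construct_blocks: turn a list of samples into batched tensors.
        text = torch.stack([torch.as_tensor(s['text']) for s in samples])
        loss_mask = torch.stack([torch.as_tensor(s['loss_mask']) for s in samples])
        return {'text': text, 'loss_mask': loss_mask}

    data = [{'text': [1, 2, 3, 4], 'loss_mask': [0, 1, 1, 1]} for _ in range(8)]
    loader = DataLoader(ToyDataset(data), batch_size=4, collate_fn=collate_like_construct_blocks)
    for batch in loader:
        print(batch['text'].shape)  # torch.Size([4, 4])
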
8 changes: 7 additions & 1 deletion data_utils/tokenization.py
@@ -716,7 +716,7 @@ class BertWordPieceTokenizer(Tokenizer):
"""

def __init__(self, tokenizer_model_type=None, cache_dir=None, add_block_symbols=False, add_sentinel_token=0,
- **kwargs):
+ add_task_mask=False, **kwargs):
# default to bert-large-uncased tokenizer
if tokenizer_model_type not in PRETRAINED_VOCAB_ARCHIVE_MAP:
tokenizer_model_type = 'bert-large-uncased'
@@ -751,6 +751,12 @@ def __init__(self, tokenizer_model_type=None, cache_dir=None, add_block_symbols=
])
self.num_tokens += 2
self.num_command_tokens += 2
+ if add_task_mask:
+     self._command_tokens.extend([
+         CommandToken('gMASK', '[gMASK]', self.num_tokens)
+     ])
+     self.num_tokens += 1
+     self.num_command_tokens += 1
if add_sentinel_token > 0:
for i in range(1, add_sentinel_token):
self._command_tokens.extend([CommandToken(f'MASK{i}', f'[MASK{i}]', self.num_tokens),
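
With add_task_mask set, the tokenizer appends one extra command token, 'gMASK', whose id is then resolved via get_command('gMASK').Id, exactly as blocklm_utils.py does above. A simplified model of that bookkeeping; CommandToken and the class below are stand-ins for the richer data_utils versions:

    from collections import namedtuple

    # Stand-in for data_utils' CommandToken: a name, a surface form, and an id.
    CommandToken = namedtuple('CommandToken', ['name', 'token', 'Id'])

    class TinyCommandVocab:
        def __init__(self, num_tokens, add_task_mask=False):
            self.num_tokens = num_tokens
            self._command_tokens = [CommandToken('MASK', '[MASK]', self.num_tokens)]
            self.num_tokens += 1
            if add_task_mask:
                # Mirrors the diff: a separate mask id for the generation objective.
                self._command_tokens.append(CommandToken('gMASK', '[gMASK]', self.num_tokens))
                self.num_tokens += 1
            self._name_map = {tok.name: tok for tok in self._command_tokens}

        def get_command(self, name):
            return self._name_map[name]

    vocab = TinyCommandVocab(num_tokens=30000, add_task_mask=True)
    print(vocab.get_command('MASK').Id, vocab.get_command('gMASK').Id)  # 30000 30001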