first step to fix load/save - not working yet

StephAO committed Feb 16, 2020
1 parent 1639ebf commit 0960934
Showing 4 changed files with 22 additions and 17 deletions.
23 changes: 13 additions & 10 deletions arguments.py
@@ -87,11 +87,11 @@ def add_training_args(parser):
 help='gradient clipping')
 group.add_argument('--epochs', type=int, default=8,
 help='upper epoch limit')
-group.add_argument('--log-interval', type=int, default=100000,
+group.add_argument('--log-interval', type=int, default=1000000,
 help='report interval')
 group.add_argument('--train-iters', type=int, default=1000000,
 help='number of iterations per epoch')
-group.add_argument('--train-tokens', type=int, default=500000000,
+group.add_argument('--train-tokens', type=int, default=100000000,
 help='number of tokens per epoch')
 group.add_argument('--seed', type=int, default=1234,
 help='random seed')
@@ -112,23 +112,23 @@ def add_training_args(parser):
 help='Output directory to save checkpoints to.')
 group.add_argument('--save-iters', type=int, default=None,
 help='Save every so often iterations.')
-group.add_argument('--save-optim', action='store_true',
+group.add_argument('--save-optim', default=True,
 help='Save current optimizer.')
-group.add_argument('--save-rng', action='store_true',
+group.add_argument('--save-rng', default=True,
 help='Save current rng state.')
-group.add_argument('--save-all-rng', action='store_true',
+group.add_argument('--save-all-rng', default=True,
 help='Save current rng state of each rank in '
 'distributed training.')
 group.add_argument('--load', type=str, default=None,
 help='Path to a particular model checkpoint. \
 (ex. `savedir/model.1000.pt`)')
-group.add_argument('--load-optim', action='store_true',
+group.add_argument('--load-optim', default=True,
 help='Load most recent optimizer corresponding '
 'to `--load`.')
-group.add_argument('--load-rng', action='store_true',
+group.add_argument('--load-rng', default=True,
 help='Load most recent rng state corresponding '
 'to `--load`.')
-group.add_argument('--load-all-rng', action='store_true',
+group.add_argument('--load-all-rng', default=True,
 help='Load most recent rng state of each rank in '
 'distributed training corresponding to `--load`('
 'complementary to `--save-all-rng`).')
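
For context on the save/load flags above: with standard argparse, replacing action='store_true' by a bare default=True changes the command-line behavior, because the option now expects an explicit value when it is passed. A minimal illustrative sketch (only the '--save-optim' name comes from the diff, the rest is made up):

    import argparse

    parser = argparse.ArgumentParser()
    # Old form: boolean flag, False unless '--save-optim' is given.
    # parser.add_argument('--save-optim', action='store_true')
    # New form from the diff: defaults to True, but with no action or type
    # the option now consumes one string value if it appears on the CLI.
    parser.add_argument('--save-optim', default=True)

    print(parser.parse_args([]).save_optim)                      # True (the default)
    print(parser.parse_args(['--save-optim', 'no']).save_optim)  # 'no' -- a string, not a bool
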
@@ -170,7 +170,7 @@ def add_evaluation_args(parser):
 group.add_argument('--eval-iters', type=int, default=2000,
 help='number of iterations per epoch to run '
 'validation/test for')
-group.add_argument('--eval-tokens', type=int, default=5000000, #00,
+group.add_argument('--eval-tokens', type=int, default=1000000, #00,
 help='number of tokens per epoch to run '
 'validation/test for')
 group.add_argument('--eval-seq-length', type=int, default=None,
@@ -286,7 +286,7 @@ def get_args():
 m = re.search(r'(?m)^Cpus_allowed:\s*(.*)$',
 open('/proc/self/status').read())
 nw = bin(int(m.group(1).replace(',', ''), 16)).count('1')
-args.num_workers = int(0.85 * nw) # leave 1 cpu for main process
+args.num_workers = int(0.80 * nw) # leave cpu for main process

 args.model_type += '_inc' if args.incremental else ''
 args.model_type += '_alt' if args.alternating else ''
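
For reference, the worker-count heuristic in the hunk above parses the process's CPU affinity mask out of /proc/self/status and takes 80% of the allowed CPUs. A standalone sketch of the same idea (Linux-only; names other than those in the diff are illustrative):

    import os
    import re

    status = open('/proc/self/status').read()
    mask = re.search(r'(?m)^Cpus_allowed:\s*(.*)$', status).group(1)
    nw = bin(int(mask.replace(',', ''), 16)).count('1')  # popcount of the affinity mask
    nw_alt = len(os.sched_getaffinity(0))                # an alternative way to count allowed CPUs on Linux
    num_workers = int(0.80 * nw)                         # leave some CPU headroom for the main process
    print(nw, nw_alt, num_workers)
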
@@ -297,6 +297,9 @@ def get_args():
 args.rank = int(os.getenv('RANK', '0'))
 args.world_size = int(os.getenv("WORLD_SIZE", '1'))

+args.save_all_rng = args.save_all_rng and args.world_size > 1
+args.load_all_rng = args.load_all_rng and args.world_size > 1
+
 args.dynamic_loss_scale = True

 args.fp32_embedding = False
7 changes: 4 additions & 3 deletions data_utils/datasets.py
@@ -829,8 +829,9 @@ def get_sentence(self, target_seq_length, num_sents, rng, non_contiguous=False,
 while diff_doc and idx == self.idx:
 idx = rng.randint(0, self.ds_len - 1)
 doc = self.sentence_split(self.get_doc(idx))

-# Get enough sentences for target length
+if len(doc) < 2:
+print(idx, doc, "YIKES")
 end_idx = rng.randint(0, len(doc) - 1)
 start_idx = end_idx - 1
 total_length = 0
@@ -859,8 +860,8 @@ def get_sentence(self, target_seq_length, num_sents, rng, non_contiguous=False,


 if len(sentences) < num_sent_required:
-print(doc)
-print(len(sentences), num_sent_required)
+print(idx, doc)
+#print(len(sentences), num_sent_required)
 # TODO get rid of this
 #print(doc)
 sentences = [self.sentence_tokenize("Data processing is hard."), self.sentence_tokenize("Sorry about the mistakes.")]
Binary file modified idf.p
Binary file not shown.
9 changes: 5 additions & 4 deletions pretrain_bert.py
@@ -103,6 +103,7 @@ def setup_model_and_optimizer(args, tokenizer):
 if args.load is not None:
 epoch, i, total_iters = load_checkpoint(model, optimizer,
 lr_scheduler, args)
+args.resume_dataloader = True
 if args.resume_dataloader:
 args.epoch = epoch
 args.mid_epoch_iters = i
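
The added args.resume_dataloader = True above forces the epoch and iteration returned by load_checkpoint back into args, so training and the dataloader pick up where the checkpoint left off. A minimal sketch of the kind of save/load round trip this relies on (the dict keys and helper names are illustrative, not the repo's actual checkpoint format):

    import torch

    def save_checkpoint_sketch(path, epoch, iteration, model, optimizer):
        torch.save({'epoch': epoch,
                    'iteration': iteration,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'rng': torch.get_rng_state()}, path)

    def load_checkpoint_sketch(path, model, optimizer):
        state = torch.load(path, map_location='cpu')
        model.load_state_dict(state['model'])
        optimizer.load_state_dict(state['optimizer'])
        torch.set_rng_state(state['rng'])
        return state['epoch'], state['iteration']
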
@@ -160,7 +160,7 @@ def forward_step(data, model, criterion, modes, args):
 if "rg" in modes:
 aux_labels['rg'] = torch.autograd.Variable(torch.arange(tokens[0].shape[0]).long()).cuda()
 if "fs" in modes:
-aux_labels['fs'] = torch.autograd.Variable(torch.ones(args.batch_size * 2 * args.seq_length).long()).cuda()
+aux_labels['fs'] = torch.autograd.Variable(torch.ones(tokens[0].shape[0] * 2 * args.seq_length).long()).cuda()
 # Forward model.
 scores = model(modes, tokens, types, tasks, att_mask, checkpoint_activations=args.checkpoint_activations)
 assert sorted(list(scores.keys())) == sorted(modes)
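
The 'fs' change above sizes the auxiliary labels from the tensor's actual batch dimension rather than args.batch_size, presumably because the two can differ (for example on a short final batch or during evaluation). A small illustrative example with made-up sizes:

    import torch

    seq_length, batch_size = 128, 32
    tokens = torch.zeros(batch_size - 5, seq_length).long()            # e.g. a short final batch of 27
    labels_from_config = torch.ones(batch_size * 2 * seq_length)       # 8192 labels
    labels_from_batch = torch.ones(tokens.shape[0] * 2 * seq_length)   # 6912 labels, matches the data
    print(labels_from_config.shape, labels_from_batch.shape)
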
@@ -334,7 +335,7 @@ def train_epoch(epoch, model, optimizer, train_data, lr_scheduler, criterion, ti
 save_checkpoint(model_suffix, epoch, iteration, model, optimizer,
 lr_scheduler, args)

-print("Learnt using {} tokens over {} iterations this epoch".format(tot_tokens, tot_iteration))
+print("Learnt using {} tokens over {} iterations this epoch".format(tot_tokens, tot_iteration + iteration))
 return tot_iteration, skipped_iters

def evaluate(epoch, data_source, model, criterion, elapsed_time, args, test=False):
@@ -499,7 +500,7 @@ def main():
 total_iters += iteration
 skipped_iters += skipped

-if args.save and False:
+if args.save:
 ck_path = 'ck/model_{}.pt'.format(epoch)
 print('saving ck model to:',
 os.path.join(args.save, ck_path))
@@ -510,7 +511,7 @@ def main():

 if val_loss < best_val_loss:
 best_val_loss = val_loss
-if args.save and False:
+if args.save:
 best_path = 'best/model.pt'
 print('saving best model to:',
 os.path.join(args.save, best_path))
