fix dataloader read issue
StephAO committed Mar 2, 2020
1 parent 7be6812 commit b38f5e9
Showing 5 changed files with 53 additions and 37 deletions.
arguments.py (2 changes: 1 addition & 1 deletion)
@@ -154,7 +154,7 @@ def add_training_args(parser):
     group.add_argument('--continual-learning', type=str2bool, nargs='?',
                        const=True, default=False,
                        help='If true, train new and old losses separately.')
-    group.add_argument('--always_mlm', type=str2bool, nargs='?',
+    group.add_argument('--always-mlm', type=str2bool, nargs='?',
                        const=True, default=False,
                        help='If true, train new and old losses separately.')
     group.add_argument('--no-aux', action='store_true',
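
Note that argparse converts dashes in option names to underscores, so the renamed --always-mlm flag is still read back as args.always_mlm in pretrain_bert.py. The sketch below shows how a str2bool helper compatible with the type=str2bool, nargs='?', const=True pattern is commonly written; the repository defines its own str2bool, so this exact implementation is an assumption, included only to illustrate how the flag parses.

import argparse

def str2bool(v):
    # Accept common spellings of true/false; hypothetical stand-in for the repo's helper.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser()
parser.add_argument('--always-mlm', type=str2bool, nargs='?', const=True, default=False)
print(parser.parse_args(['--always-mlm']).always_mlm)             # True: bare flag falls back to const
print(parser.parse_args(['--always-mlm', 'false']).always_mlm)    # False: explicit value goes through str2bool
print(parser.parse_args([]).always_mlm)                           # False: default
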
data_utils/datasets.py (4 changes: 4 additions & 0 deletions)
@@ -828,10 +828,14 @@ def get_sentence(self, target_seq_length, num_sents, rng, non_contiguous=False,
         while diff_doc and idx == self.idx:
             idx = rng.randint(0, self.ds_len - 1)
         doc = self.sentence_split(self.get_doc(idx))
 
+        print(doc)
         # Get enough sentences for target length
         if len(doc) < 2:
             print(idx, doc, "YIKES")
+            print(self.ds.split_inds[idx])
+            self.ds.wrapped_data.set_flag()
+            print(self.ds.wrapped_data[self.ds.split_inds[idx]])
             doc = self.sentence_split(self.get_doc(rng.randint(0, self.ds_len - 1)))
         end_idx = rng.randint(0, len(doc) - 1)
         start_idx = end_idx - 1
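
The datasets.py change is purely diagnostic: when get_sentence hits a document with fewer than two sentences (the symptom of a bad read), it prints the offending index and calls set_flag() on the underlying loader so that the loader dumps what it actually read on its next access. A simplified sketch of that one-shot flag handshake, using a hypothetical class rather than the repo's lazy_array_loader:

class FlaggableLoader:
    # Minimal stand-in for the debug flag added to lazy_array_loader.
    def __init__(self, records):
        self.records = records
        self.flag = False

    def set_flag(self):
        # Caller requests verbose output for the next read only.
        self.flag = True

    def __getitem__(self, index):
        rtn = self.records[index]
        if self.flag:
            print("------>", index, rtn)  # one-shot diagnostic dump
            self.flag = False
        return rtn

loader = FlaggableLoader(["first doc", "second doc"])
loader.set_flag()
loader[1]  # prints the diagnostic line
loader[0]  # silent again
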
data_utils/lazy_loader.py (51 changes: 27 additions & 24 deletions)
@@ -18,7 +18,7 @@
 import pickle as pkl
 import time
 from itertools import accumulate
-from threading import Lock
+from filelock import FileLock
 
 import torch
 
@@ -109,22 +109,19 @@ class lazy_array_loader(object):
     """
     def __init__(self, path, data_type='data', mem_map=False, map_fn=None):
         lazypath = get_lazy_path(path)
-        datapath = os.path.join(lazypath, data_type)
-        #get file where array entries are concatenated into one big string
-        self._file = open(datapath, 'rb')
-        self.file = self._file
-        #memory map file if necessary
-        self.mem_map = mem_map
-        if self.mem_map:
-            self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ)
+        self.datapath = os.path.join(lazypath, data_type)
         lenpath = os.path.join(lazypath, data_type+'.len.pkl')
         self.lens = pkl.load(open(lenpath, 'rb'))
         self.ends = list(accumulate(self.lens))
         self.dumb_ends = list(self.ends)
-        self.read_lock = Lock()
+        self.read_lock = FileLock("bert_data_filelock.txt")
         self.process_fn = map_fn
         self.map_fn = map_fn
         self._tokenizer = None
+        self.flag = False
+
+    def set_flag(self):
+        self.flag = True
 
     def SetTokenizer(self, tokenizer):
         """
@@ -156,7 +153,10 @@ def __getitem__(self, index):
             except OSError as e:
                 print(e)
                 print(index, start, end)
-                return None
+                return None
+            if self.flag:
+                print("------>", index, start, end, rtn)
+                self.flag = False
             if self.map_fn is not None:
                 return self.map_fn(rtn)
             else:
@@ -176,25 +176,28 @@ def __getitem__(self, index):
     def __len__(self):
         return len(self.ends)
 
-    def file_read(self, start=0, end=None):
+    def file_read(self, start=0, end=None, flag=False):
         """read specified portion of file"""
 
         # atomic reads to avoid race conditions with multiprocess dataloader
-        self.read_lock.acquire()
-        # seek to start of file read
-        self.file.seek(start)
-        # read to end of file if no end point provided
-        if end is None:
-            rtn = self.file.read()
-        #else read amount needed to reach end point
-        else:
-            rtn = self.file.read(end-start)
-        self.read_lock.release()
+        with self.read_lock:
+            with open(self.datapath, 'rb') as f:
+                offset = 0
+                # seek to start of file read
+                while f.tell() != end or offset != start:
+                    f.seek(start)
+                    offset = f.tell()
+                    rtn = f.read(end-start)
+                if f.tell() != end or offset != start:
+                    print("Locking isn't working. Expected ({}, {}), Received: ({}, {}), Off by ({}, {})".format(start, end, offset, f.tell(), offset - start, f.tell() - end), flush=True)
+                if self.flag:
+                    print("???", rtn)
+                    print(f.tell(), f.read(25))
         #TODO: @raulp figure out mem map byte string bug
         #if mem map'd need to decode byte string to string
         rtn = rtn.decode('utf-8')
         # rtn = str(rtn)
-        if self.mem_map:
-            rtn = rtn.decode('unicode_escape')
+        #if self.mem_map:
+        #    rtn = rtn.decode('unicode_escape')
         return rtn
 
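This file carries the actual fix. The previous version kept a single file handle open for the lifetime of the loader and guarded reads with threading.Lock, which only serializes threads inside one process. PyTorch DataLoader workers are separate processes, and workers forked after the file was opened share that handle's offset, so concurrent seek/read pairs could interleave and return bytes from the wrong span, which is presumably the source of the truncated documents caught in datasets.py. The commit switches to filelock.FileLock, an OS-level lock keyed on a lock file and therefore visible to every process, and reopens the data file on each read so every read owns its own offset. A minimal sketch of the pattern, using a hypothetical class name and an arbitrary lock-file path:

from filelock import FileLock

class AtomicSpanReader:
    # Sketch of cross-process-safe span reads; not the repository's class.
    def __init__(self, datapath, lockpath="bert_data_filelock.txt"):
        self.datapath = datapath
        self.read_lock = FileLock(lockpath)  # lock file shared by all worker processes

    def read_span(self, start, end):
        # The file lock keeps the seek+read pair atomic across processes,
        # and a fresh handle per call means there is no shared offset to race on.
        with self.read_lock:
            with open(self.datapath, 'rb') as f:
                f.seek(start)
                return f.read(end - start)

The while loop and the "Locking isn't working" message in the committed file_read are defensive checks that the seek landed where it was asked to; with a per-call handle they should not normally fire.
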
evaluate/config/test_bert.conf (6 changes: 3 additions & 3 deletions)
@@ -25,15 +25,15 @@ dropout = 0.1 // following BERT paper
 optimizer = bert_adam
 batch_size = 32
 max_epochs = 3
-lr = .00001
+lr = .00002
 min_lr = .0000001
 lr_patience = 4
 patience = 20
 max_vals = 10000
 
 // Tasks
-pretrain_tasks = "qqp,rte,sst,sts-b" // glue
-target_tasks = "qqp,rte,sst,sts-b" // glue
+pretrain_tasks = glue
+target_tasks = glue
 
 // Control-flow stuff
 do_pretrain = 0
pretrain_bert.py (27 changes: 18 additions & 9 deletions)
@@ -266,6 +266,8 @@ def next_stage():
         return {k: total_tokens for k in modes}
     assert len(modes) == len(stage_splits[stage_idx])
     current_stage = {k: v for k, v in zip(modes, stage_splits[stage_idx])}
+    print("Starting stage {} of {}, with task distribution: ".format(stage_idx, len(stage_splits)))
+    print(current_stage)
     stage_idx += 1
     return current_stage
 
@@ -279,6 +281,8 @@ def get_mode_from_stage(current_stage, args):
     :return: selected mode
     """
     modes = args.modes.split(',')
+    if args.always_mlm:
+        modes = modes[1:]
     p = np.array([current_stage[m] for m in modes])
     p /= np.sum(p)
     return [np.random.choice(modes, p=p)]
@@ -307,7 +311,7 @@ def train_epoch(epoch, model, optimizer, train_data, lr_scheduler, criterion, ti
 
     train_data.dataset.set_args(modes)
     sent_tasks = [m for m in modes if m in train_data.dataset.sentence_tasks]
-    tok_tasks = [m for m in modes if m not in train_data.dataset.sentence_tasks]
+    tok_tasks = [m for m in modes if m not in ([train_data.dataset.sentence_tasks] + ["mlm"])]
 
     data_iters = iter(train_data)
 
@@ -316,7 +320,7 @@ def train_epoch(epoch, model, optimizer, train_data, lr_scheduler, criterion, ti
         # ERNIE 2.0's continual multi task learning
         if args.continual_learning:
             # test 1
-            modes_ = get_mode_from_stage(current_stage)
+            modes_ = get_mode_from_stage(current_stage, args)
             if args.always_mlm:
                 # test 2
                 modes_ = ['mlm'] + modes_
@@ -333,8 +337,8 @@ def train_epoch(epoch, model, optimizer, train_data, lr_scheduler, criterion, ti
         # Summing all tasks
         else:
             # test 5 when incremental is False, test 6 when incremental is True
-            sent_task = [] if len(sent_tasks) == 0 else sent_tasks[iteration % len(sent_tasks)]
-            modes_ = ['mlm'] + [sent_task] + tok_tasks
+            sent_task = [] if len(sent_tasks) == 0 else [sent_tasks[iteration % len(sent_tasks)]]
+            modes_ = ['mlm'] + sent_task + tok_tasks
 
 
         while True:
@@ -354,11 +358,16 @@ def train_epoch(epoch, model, optimizer, train_data, lr_scheduler, criterion, ti
 
         log_tokens += num_tokens.item()
         tot_tokens += num_tokens.item()
-        for m in modes_:
-            current_stage[m] = max(0, current_stage[m] - num_tokens.item())
-
-        if sum(current_stage.values()) == 0:
-            current_stage = next_stage()
+        if args.continual_learning:
+            for m in modes_:
+                if args.always_mlm and m == "mlm":
+                    continue
+                current_stage[m] = max(0, current_stage[m] - num_tokens.item())
+
+            if sum(current_stage.values()) == 0:
+                ns = next_stage()
+                for m in ns:
+                    current_stage[m] = ns[m]
 
         # Update learning rate.
         lr_scheduler.step(step_num=(epoch-1) * max_tokens + tot_tokens)
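
The training-loop edits make the ERNIE-2.0-style stage bookkeeping run only when --continual-learning is set, stop charging consumed tokens against the always-on mlm objective, and copy the next stage's budgets into the existing current_stage dict key by key instead of rebinding the name. The sketch below illustrates that budget-driven task scheduling in isolation; the mode names, budgets and per-batch token count are made up for illustration and are not the repository's values.

import numpy as np

stage_splits = [{"mlm": 0, "nsp": 1000}, {"mlm": 0, "nsp": 500, "sop": 500}]
stage_idx = 0
always_mlm = True

def next_stage():
    # Return a copy of the next stage's per-task token budgets.
    global stage_idx
    stage = dict(stage_splits[stage_idx])
    stage_idx += 1
    return stage

current_stage = next_stage()

for iteration in range(100):
    # Sample the auxiliary task in proportion to its remaining token budget,
    # skipping mlm because it is trained at every step anyway.
    modes = [m for m in current_stage if not (always_mlm and m == "mlm")]
    p = np.array([current_stage[m] for m in modes], dtype=float)
    p /= p.sum()
    mode = np.random.choice(modes, p=p)

    num_tokens = 100  # stand-in for the number of tokens in this batch
    current_stage[mode] = max(0, current_stage[mode] - num_tokens)

    if sum(current_stage.values()) == 0:
        if stage_idx >= len(stage_splits):
            break  # every stage's budget is spent
        ns = next_stage()
        for m in ns:
            current_stage[m] = ns[m]  # update in place, as the commit does
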
