Commit 8801350
StephAO committed Feb 3, 2020
2 parents 429ce41 + bfa8fe4 commit 8801350
Showing 1 changed file with 51 additions and 37 deletions.
data_utils/make_dataset.py: 88 changes (51 additions & 37 deletions)
@@ -7,6 +7,7 @@
import queue
import re
import unicodedata
+ import unidecode
from tqdm import tqdm
import torch
from textwrap import shorten
@@ -33,6 +34,9 @@ def get_doc_len(s, tokenizer):
def process_document(document, max_doc_length, tokenizer=None):
str_lens = []
writes = []

+ if type(document) == str:
+ document = document.split("\n")

tok_total, word_total, sentence_total, document_total = 0, 0, 0, 0

@@ -42,16 +46,16 @@
string_document = re.sub(r'[^\w\s.,!?:;\"\'“”‘’]', '', string_document)
# Filter documents where special characters makes up > 10% of the document
if float(len(string_document)) < 0.9 * len(' '.join(document)):
- return None
+ return [], [], 0, 0, 0, 0
# Filter documents containing less than 10 words
- if len(string_document.split(' ')) < 10:
- return None
+ if len(string_document.split(' ')) < 10:
+ return [], [], 0, 0, 0, 0
# Filter documents containing less than 100 characters
if len(string_document) < 100:
- return None
+ return [], [], 0, 0, 0, 0
# Filter documents containing a single sentence
if len(document) < 2:
- return None
+ return [], [], 0, 0, 0, 0

num_toks = 0
doc_len = 0
@@ -70,8 +74,14 @@
continue
if isinstance(s, dict):
s = s['text']
- s += "\n"
+ # Ensure exactly one terminal newline char
+ s = s.strip("\n") + "\n"
+ # Translate some weird utf-8 characters to their more regular counterparts
s = s.translate(DatasetWriter.transl_table)
+ # Remove the rest of the weird utf-8 characters
+ #s = ''.join([chr(c) for c in s.encode('utf-8') if c < 128]) # [9,10,13] + list(range(32,127))])
+ #s = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\xff]', '', s)
+ s = unidecode.unidecode(s)
encoded = unicodedata.normalize('NFKD', s).encode('utf-8') # clean(s)
doc_bytes += encoded
str_cnt += len(encoded)
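Aside (illustration, not part of this commit): a minimal sketch of what the cleaning steps above do to a sample string. The table here is a cut-down stand-in for the DatasetWriter.transl_table defined further down; the sample text and names are illustrative only.

    import unicodedata
    import unidecode

    transl_table = dict([(ord(x), ord(y)) for x, y in zip(u"‘’“”", u"''\"\"")])

    s = "“Fancy” text with ‘curly’ quotes\n\n"
    s = s.strip("\n") + "\n"               # exactly one terminal newline
    s = s.translate(transl_table)          # curly quotes -> plain ASCII quotes
    s = unidecode.unidecode(s)             # transliterate any remaining non-ASCII
    encoded = unicodedata.normalize('NFKD', s).encode('utf-8')
    # encoded == b'"Fancy" text with \'curly\' quotes\n'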
@@ -92,6 +102,7 @@
document_total += 1
num_sents = 0
doc_len = 0
+ num_toks = 0

# Append write data
writes += [doc_bytes + doc_separator]
@@ -111,14 +122,14 @@
# Append write data
writes += [doc_bytes + doc_separator]
str_lens.append(str_cnt + 1) # + 1 for doc separator

return writes, str_lens, tok_total, word_total, sentence_total, document_total

class DatasetWriter:

- transl_table = dict([(ord(x), ord(y)) for x, y in zip(u"‘’´“”–-", u"'''\"\"--")])
+ transl_table = dict([(ord(x), ord(y)) for x, y in zip(u"‘’´“”––-æ", u"'''\"\"---e")])

- def __init__(self, name, read_path, path_ext=None, max_doc_length=512, preamble_len=100, from_text_files=False, split_on_newlines=False):
+ def __init__(self, name, read_path, path_ext=None, max_doc_length=1024, preamble_len=100, from_text_files=False, split_on_newlines=False):
"""
:param name [string]: Name of the dataset
:param read_path Union[string, List[string]]: If using text files, the base read path to the files, else a list of datasets
@@ -158,7 +169,10 @@ def create(self):
doc_iter = self.dataset_iterator(self.read_path) if not self.from_text_files else \
self.text_file_iterator(self.read_path, self.path_ext)
for doc_info in doc_iter:
- writes, str_lens, toks, words, sents, documents = doc_info
+ if len(doc_info) == 1:
+ doc_info = doc_info[0]
+ writes, str_lens, toks, words, sents, documents = doc_info

self.write_document(writes, str_lens)
self.update_stats(toks, words, sents, documents)

@@ -174,6 +188,15 @@ def init_dataset_stats(self):
# self.shortest_len = self.max_doc_length

def print_stats(self):
+ if type(self.tok_total) == torch.Tensor:
+ self.tok_total = self.tok_total.item()
+ if type(self.word_total) == torch.Tensor:
+ self.word_total = self.word_total.item()
+ if type(self.sentence_total) == torch.Tensor:
+ self.sentence_total = self.sentence_total.item()
+ if type(self.document_total) == torch.Tensor:
+ self.document_total = self.document_total.item()

stat_str = ""
stat_str += "Total number of tokens: {}\n".format(self.tok_total)
stat_str += "Total number of words: {}\n".format(self.word_total)
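Aside (illustration, not part of this commit): the new type checks above presumably guard against counts that come back through the DataLoader as tensors rather than plain Python numbers. A small, hypothetical demonstration of that default behavior:

    from torch.utils.data.dataloader import default_collate

    batch = default_collate([5])   # how a DataLoader batches a plain int by default
    print(type(batch))             # <class 'torch.Tensor'>
    print(batch.item())            # 5 -- back to a Python number, as in print_stats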
@@ -191,6 +214,10 @@ def print_stats(self):

def write_document(self, writes, str_lens):
assert len(writes) == len(str_lens)
+ #if type(writes[0]) == tuple:
+ # writes = [w[0] for w in writes]
+ #if type(str_lens[0]) == torch.Tensor:
+ # str_lens = [s.item() for s in str_lens]
for i in range(len(writes)):
self.write_file.write(writes[i])
self.str_lens.append(str_lens[i])
@@ -202,7 +229,6 @@ def update_stats(self, toks, words, sents, documents):
self.sentence_total += sents
self.document_total += documents


def dataset_iterator(self, paths):
data_set_args = {
'path': paths, # ['wikipedia', 'cnn_dailymail', 'gutenberg'],
@@ -228,15 +254,23 @@ def dataset_iterator(self, paths):
print("Starting length:", len(ds))

fd = FilterDataset(ds, tokenizer, self.max_doc_length)
- sampler = torch.utils.data.SequentialSampler(fd)
- batch_sampler = torch.utils.data.BatchSampler(sampler, 1, False)
+ #sampler = torch.utils.data.SequentialSampler(fd)
+ #batch_sampler = torch.utils.data.BatchSampler(sampler, 1, False)

data_loader = torch.utils.data.dataloader.DataLoader(fd,
- batch_sampler=batch_sampler,
- num_workers=31,
+ collate_fn=lambda x: x,
+ num_workers=30,
pin_memory=True)

- for i, doc_info in data_loader:
+ dl_iter = iter(data_loader)
+ for i in tqdm(range(len(ds))):
+ try:
+ doc_info = next(dl_iter)
+ except (TypeError) as e:
+ print("Caught {}".format(e))
+ continue
+ if len(doc_info[0]) == 0:
+ continue
yield doc_info
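Aside (illustration, not part of this commit): the rewritten loop above drives the DataLoader by hand so that a document whose __getitem__ raises can simply be skipped. A self-contained sketch of the same pattern, with all names and values illustrative rather than taken from the repo:

    from torch.utils import data
    from tqdm import tqdm

    class ToyDataset(data.Dataset):             # stand-in for FilterDataset
        def __init__(self, n):
            self.n = n
        def __len__(self):
            return self.n
        def __getitem__(self, idx):
            if idx % 7 == 0:
                raise TypeError("bad document") # simulate a document that fails
            return idx * idx

    ds = ToyDataset(20)
    loader = data.DataLoader(ds, collate_fn=lambda x: x,
                             num_workers=0)     # the commit uses 30 workers; 0 keeps this single-process
    it = iter(loader)
    for _ in tqdm(range(len(ds))):
        try:
            item = next(it)                     # one document per step
        except TypeError as e:
            print("Caught {}".format(e))
            continue
        print(item[0])                          # batch_size defaults to 1, so unwrap the list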

def text_file_iterator(self, base_read_path, read_paths_exts=None):
@@ -288,26 +322,6 @@ def convert_into_sentences(self, text_file):
return paragraphs


- # def worker_init():
- # global ds_
- # global tokenizer_
- # ds_, tokenizer_ = data_utils.make_dataset(**data_set_args)
- # ds_.SetTokenizer(None)
- #
- # def work(self_idx):
- # start_idx = self_idx * bin_size
- # end_idx = min((self_idx + 1) * bin_size, len(ds))
- # word_in_num_docs = {}
- # for i in range(int(start_idx), int(end_idx)):
- # doc = get_doc(ds_, i)
- # tokens = set(sentence_tokenize(tokenizer_, doc))
- # for tok in tokens:
- # word_in_num_docs[tok] = word_in_num_docs.get(tok, 0) + 1
- # print("Finished with bin", self_idx, flush=True)
- #
- # return word_in_num_docs


class FilterDataset(data.Dataset):
"""
Abstract bert dataset.
@@ -343,7 +357,7 @@ def __getitem__(self, idx):
if __name__ == "__main__":
#base_read_path = "/scratch/gobi1/datasets/NLP-Corpus/CNN_dailymail/"
#read_path_extension = ["cnn/stories/", "dailymail/stories/"]
- base_read_path = ['bookcorpus', 'wikipedia'] #"/h/stephaneao/bookcorpus_clean"
+ base_read_path = ['bookcorpus', 'wikipedia'] #"/h/stephaneao/bookcorpus_clean"
read_path_extension = None #["books_large_p1_clean.txt", "books_large_p2_clean.txt"]
with DatasetWriter("bert_corpus", base_read_path, read_path_extension, from_text_files=False, split_on_newlines=True) as dw:
dw.create()
