wikipedia_preprocessing.py
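# Preprocesses an English Wikipedia dump into fixed-length tokenized chunks,
# presumably for BERT-style masked-language-model pretraining: the script uses
# the bert-base-uncased tokenizer and keeps a special-tokens mask per chunk.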
import random

import transformers as ts
from datasets import load_dataset, load_from_disk

random.seed(42)  # fix the RNG so the per-article chunk sampling is reproducible
# Load the 2022-03-01 English Wikipedia snapshot, caching it locally.
raw_dataset = load_dataset("wikipedia", "20220301.en", cache_dir="dataset/wikipedia")
tokenizer = ts.AutoTokenizer.from_pretrained("bert-base-uncased")
k = 10  # maximum number of tokenized chunks to keep per article
def mapping_function(batch):
    """Tokenize each article into overlapping 256-token chunks, keeping at most k per article."""
    outputs = {
        "input_ids": [],
        "attention_mask": [],
        "token_type_ids": [],
        "special_tokens_mask": [],
    }
    for text in batch["text"]:
        # Split the article into overlapping windows: 256 tokens with a
        # 128-token stride. The special-tokens mask is kept so a downstream
        # MLM collator can avoid masking [CLS]/[SEP].
        output = tokenizer(
            text,
            max_length=256,
            truncation=True,
            stride=128,
            return_overflowing_tokens=True,
            return_special_tokens_mask=True,
        )
        chunks = list(zip(
            output["input_ids"],
            output["attention_mask"],
            output["token_type_ids"],
            output["special_tokens_mask"],
        ))
        # Downsample very long articles so no single article contributes
        # more than k chunks to the dataset.
        if len(chunks) > k:
            chunks = random.sample(chunks, k)
        input_ids, attention_mask, token_type_ids, special_tokens_mask = zip(*chunks)
        outputs["input_ids"] += input_ids
        outputs["attention_mask"] += attention_mask
        outputs["token_type_ids"] += token_type_ids
        outputs["special_tokens_mask"] += special_tokens_mask
    return outputs
dataset = raw_dataset["train"].map(
    mapping_function,
    remove_columns=raw_dataset["train"].column_names,
    batched=True,  # batched map lets one input article expand into several chunk rows
)
dataset_path = "tokenizedDatasets/wikipedia-256/"
dataset.save_to_disk(dataset_path)
print(load_from_disk(dataset_path))  # sanity check: reload the saved dataset and print its schema
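# --- Usage sketch (illustrative, not part of this script) ---
# One way to consume the saved chunks for masked-language-model training:
# transformers' DataCollatorForLanguageModeling pads each batch and applies
# random masking, using the stored special_tokens_mask to avoid masking
# [CLS]/[SEP]. The batch size and masking probability below are arbitrary
# placeholders, not values taken from this repository.
#
#   from torch.utils.data import DataLoader
#
#   tokenized = load_from_disk(dataset_path)
#   collator = ts.DataCollatorForLanguageModeling(
#       tokenizer=tokenizer, mlm=True, mlm_probability=0.15
#   )
#   loader = DataLoader(tokenized, batch_size=16, collate_fn=collator)
#   batch = next(iter(loader))  # dict with input_ids, attention_mask, labels, ...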