forked from HarderThenHarder/transformers_tasks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
115 lines (98 loc) · 6.55 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# !/usr/bin/env python3
"""
==== No Bugs in code, just some Random Unexpected FEATURES ====
┌─────────────────────────────────────────────────────────────┐
│┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐│
││Esc│!1 │@2 │#3 │$4 │%5 │^6 │&7 │*8 │(9 │)0 │_- │+= │|\ │`~ ││
│├───┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴───┤│
││ Tab │ Q │ W │ E │ R │ T │ Y │ U │ I │ O │ P │{[ │}] │ BS ││
│├─────┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴─────┤│
││ Ctrl │ A │ S │ D │ F │ G │ H │ J │ K │ L │: ;│" '│ Enter ││
│├──────┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴────┬───┤│
││ Shift │ Z │ X │ C │ V │ B │ N │ M │< ,│> .│? /│Shift │Fn ││
│└─────┬──┴┬──┴──┬┴───┴───┴───┴───┴───┴──┬┴───┴┬──┴┬─────┴───┘│
│ │Fn │ Alt │ Space │ Alt │Win│ HHKB │
│ └───┴─────┴───────────────────────┴─────┴───┘ │
└─────────────────────────────────────────────────────────────┘
工具类。
Author: pankeyu
Date: 2023/01/04
"""
import traceback
import numpy as np
def convert_example(examples: dict, tokenizer, max_source_seq_len: int, max_target_seq_len: int):
"""
将样本数据转换为模型接收的输入数据。
Args:
examples (dict): 训练数据样本, e.g. -> {
"text": [
'"广东桥头镇位于东莞市东北部,[MASK],全镇总面积56平方公里,常住人口10多万人,是闻名中外的供水香港工程的源头"中[MASK]位置的文本是: 地处东江南岸,毗邻惠州市,属埔田地区',
'"[MASK],著名画家,毕业于中央美术学院"中[MASK]位置的文本是: 任之',
...
]
}
max_source_seq_len (int): encoder 最大输入长度
max_target_seq_len (int): decoder 最大输入长度
Returns:
dict (str: np.array) -> tokenized_output = {
'input_ids': [[1525, 10, ...], [758, 2345, ...]],
'attention_mask': [[1, 1, ...], [1, 1, ...]],
'decoder_input_ids': [[0, 822, ...], [0, 10, ...]],
'labels': [[822, 10, ...], [125, 58...]]
}
"""
tokenized_output = {
'input_ids': [], # encoder 输入
'attention_mask': [], # encoder attention mask
'decoder_input_ids': [], # decoder 输入(right shift)
'labels': [] # decoder 标签
}
for example in examples['text']:
try:
origin_text, mask_labels = example.split('\t')
mask_labels = mask_labels + tokenizer.eos_token
output_ids = tokenizer.encode( # 处理 decoder 输入
text=mask_labels,
truncation=True,
max_length=max_target_seq_len
)
decoder_input_ids = output_ids[:-2] # bert-tokenizer 会加一个[CLS],这个就当成<eos>了,因为是 right-shift
# 所以要-1,又因为 bert-tokenizer会加一个[SEP],所以要-2
decoder_input_ids = decoder_input_ids + [tokenizer.pad_token_id] * (max_target_seq_len - len(decoder_input_ids))
lables = output_ids[1:-1] # 去掉 [CLS] 和 [SEP]
lables = lables + [-100] * (max_target_seq_len - len(lables)) # -100 用于忽略在计算 label loss 时忽略 padding token
inputs = tokenizer( # 处理 encoder 输入
text=origin_text,
truncation=True,
max_length=max_source_seq_len,
padding='max_length'
)
except:
print(f'"{example}" -> {traceback.format_exc()}')
continue
tokenized_output['input_ids'].append(inputs["input_ids"])
tokenized_output['attention_mask'].append(inputs["attention_mask"])
tokenized_output['decoder_input_ids'].append(decoder_input_ids)
tokenized_output['labels'].append(lables)
for k, v in tokenized_output.items():
tokenized_output[k] = np.array(v)
return tokenized_output
if __name__ == '__main__':
from rich import print
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("uer/t5-small-chinese-cluecorpussmall")
tokenizer.eos_token = tokenizer.sep_token
tokenizer.bos_token = tokenizer.cls_token
res = convert_example({
"text": [
'"[MASK],著名画家,毕业于中央美术学院"中[MASK]位置的文本是: 任之'
]
},
tokenizer=tokenizer,
max_source_seq_len=50,
max_target_seq_len=20
)
print(res)
print('input_ids: ', tokenizer.convert_ids_to_tokens(res['input_ids'][0]))
print('decoder_input_ids: ', tokenizer.convert_ids_to_tokens(res['decoder_input_ids'][0]))
print('labels: ', tokenizer.convert_ids_to_tokens(res['labels'][0]))