# chat_dataset.py (forked from SCIR-HI/Med-ChatGLM)
""" PyTorch ChatGLM Dataset. """
import json
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
"THUDM/chatglm-6b", trust_remote_code=True)


def get_masks_and_position_ids(
    seq, seq_len, context_length, device, gmask=False, position_encoding_2d=True
):
    """Build the attention mask and (optionally 2D) position ids for one padded sequence.

    `seq_len` is the length of the prompt part, so the mask token (150001) sits at
    index `seq_len - 2` and the bos token (150004) at `seq_len - 1`;
    `context_length` is the padded length of the whole sequence.
    """
    mask_position = (
        seq_len - 2
    )  # is equal to `seq.index(mask_token)` or `seq.index(150001)`
    attention_mask = torch.ones(
        (1, context_length, context_length), device=device)
    attention_mask.tril_()
    attention_mask[..., : mask_position - 1] = 1
    # True marks positions that must not be attended to.
    attention_mask = (attention_mask < 0.5).bool()

    if position_encoding_2d:
        # is equal to `seq_length = seq.index(150004)`
        seq_length = seq_len - 1
        position_ids = torch.arange(
            context_length, dtype=torch.long, device=device)
        if not gmask:
            position_ids[seq_length:] = mask_position
        block_position_ids = torch.cat(
            (
                torch.zeros(seq_length, dtype=torch.long, device=device),
                torch.arange(
                    context_length - seq_length, dtype=torch.long, device=device
                )
                + 1,
            )
        )
        position_ids = torch.stack((position_ids, block_position_ids), dim=0)
    else:
        position_ids = torch.arange(
            context_length, dtype=torch.long, device=device)
        if not gmask:
            position_ids[context_length - 1:] = mask_position
    return attention_mask, position_ids
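

# Illustrative example (not from the original file): with seq_len = 4,
# context_length = 6, gmask=False and position_encoding_2d=True, the function
# above returns
#   position_ids[0] = [0, 1, 2, 2, 2, 2]  (absolute positions, frozen at mask_position)
#   position_ids[1] = [0, 0, 0, 1, 2, 3]  (block positions counting within the generated span)
# together with a boolean mask that is True wherever attention is blocked.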


def chat_data_collator(features: list) -> dict:
    # Compute the loss only on the target part of each example.
    len_ids = [len(feature["input_ids"]) for feature in features]
    longest = max(len_ids) + 1
    input_ids = []
    attention_mask_list = []
    position_ids_list = []
    labels_list = []
    # Sort by length (longest first) before padding everything to `longest`.
    for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
        ids = feature["input_ids"]
        seq_len = feature["seq_len"]
        # Loss is masked (-100) over the prompt prefix and the trailing padding.
        labels = (
            [-100] * (seq_len - 1)
            + ids[(seq_len - 1):]
            + [tokenizer.eos_token_id]
            + [-100] * (longest - ids_l - 1)
        )
        ids = ids + [tokenizer.eos_token_id] * (longest - ids_l)
        _ids = torch.LongTensor(ids)
        attention_mask, position_ids = get_masks_and_position_ids(
            ids, seq_len, longest, _ids.device, gmask=False
        )
        labels_list.append(torch.LongTensor(labels))
        input_ids.append(_ids)
        attention_mask_list.append(attention_mask)
        position_ids_list.append(position_ids)
    input_ids = torch.stack(input_ids)
    labels = torch.stack(labels_list)
    attention_mask = torch.stack(attention_mask_list)
    position_ids = torch.stack(position_ids_list)
    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
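

# Illustrative example (not from the original file): with prompt tokens p0..p2
# (seq_len = 3), target tokens t0, t1 (ids length 5) and longest = 7, the
# collator above produces, for that example,
#   input_ids = [p0, p1, p2, t0, t1, eos, eos]
#   labels    = [-100, -100, p2, t0, t1, eos, -100]
# i.e. only positions from seq_len - 1 onward (plus one eos) contribute to the loss.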


class Chat_Dataset(Dataset):
    """Reads a JSON-lines file with "context"/"target" fields and tokenizes each example."""

    def __init__(self, data_dir, max_seq_length) -> None:
        super().__init__()
        self.content = self.load_json(data_dir)
        self.encoded_content = self.encode(
            tokenizer, self.content, max_seq_length)
        self.features = self.encoded_content[0].keys()

    def load_json(self, data_dir):
        content = []
        with open(data_dir, "r") as f:
            lines = f.readlines()
        for line in lines:
            line = line.strip()
            content.append(json.loads(line))
        return content

    def __getitem__(self, index):
        return self.encoded_content[index]

    def __len__(self):
        return len(self.encoded_content)

    def get_ori_item(self, index):
        # Return the original, un-tokenized example.
        return self.content[index]

    def encode(self, tokenizer, content, max_seq_length):
        encoded_content = []
        for example in content:
            prompt = example["context"]
            target = example["target"]
            prompt_ids = tokenizer.encode(
                prompt, max_length=max_seq_length, truncation=True)
            target_ids = tokenizer.encode(
                target, max_length=max_seq_length, truncation=True, add_special_tokens=False
            )
            input_ids = prompt_ids + target_ids + [tokenizer.eos_token_id]
            # `seq_len` marks where the prompt ends so the collator can mask its loss.
            encoded_content.append(
                {"input_ids": input_ids, "seq_len": len(prompt_ids)})
        return encoded_content
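

# A minimal usage sketch (not part of the original file). It assumes a JSON-lines
# training file with "context" and "target" fields; the path and hyperparameters
# below are placeholders, not values from the original repository.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = Chat_Dataset("data/train.json", max_seq_length=512)  # hypothetical path
    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=chat_data_collator,  # pads the batch and builds masks / position ids
    )
    batch = next(iter(loader))
    print({k: tuple(v.shape) for k, v in batch.items()})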