
Commit d68ef83: add label padding
pankeyu authored and pankeyu committed Nov 30, 2022
1 parent 5368aa7
Showing 6 changed files with 31 additions and 15 deletions.
8 changes: 5 additions & 3 deletions prompt_tasks/PET/class_metrics.py
@@ -23,6 +23,7 @@
 from typing import List
 
 import numpy as np
+import pandas as pd
 from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score, confusion_matrix


@@ -75,7 +76,7 @@ def compute(self, round_num=2) -> dict:
                 }
             }
         """
-        classes, class_metrics, res = sorted(list(set(self.goldens))), {}, {}
+        classes, class_metrics, res = sorted(list(set(self.goldens) | set(self.predictions))), {}, {}
         res['accuracy'] = round(accuracy_score(self.goldens, self.predictions), round_num)  # build the global metrics
         res['precision'] = round(precision_score(self.goldens, self.predictions, average='weighted'), round_num)
         res['recall'] = round(recall_score(self.goldens, self.predictions, average='weighted'), round_num)
@@ -96,8 +97,9 @@ def compute(self, round_num=2) -> dict:
             res['class_metrics'] = class_metrics
         except Exception as e:
             print(f'[Warning] Something wrong when calculate class_metrics: {e}')
-            print(f'goldens: {set(self.goldens)}')
-            print(f'predictions: {set(self.predictions)}')
+            print(f'-> goldens: {set(self.goldens)}')
+            print(f'-> predictions: {set(self.predictions)}')
+            print(f'-> diff elements: {set(self.predictions) - set(self.goldens)}')
             res['class_metrics'] = {}
 
         return res
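
Why the union matters: the one-line change builds classes from both the gold and the predicted labels. If the model predicts a label that never occurs in the gold set, a classes list built from goldens alone has no slot for it, which is exactly what the new diff-elements debug print surfaces. A minimal sketch of the failure mode with plain scikit-learn (illustrative labels, not the repo's evaluator class):

    from sklearn.metrics import confusion_matrix

    goldens = ['A', 'A', 'B']
    predictions = ['A', 'B', 'C']   # 'C' never occurs in the gold labels

    old_classes = sorted(set(goldens))                      # ['A', 'B']: no slot for 'C'
    new_classes = sorted(set(goldens) | set(predictions))   # ['A', 'B', 'C']

    # with old_classes, the (gold 'B', pred 'C') sample has no column and is
    # silently dropped, and per-class loops driven by old_classes never cover 'C'
    print(confusion_matrix(goldens, predictions, labels=new_classes))   # full 3x3 matrix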
3 changes: 3 additions & 0 deletions prompt_tasks/PET/pet.py
@@ -86,6 +86,9 @@ def evaluate_model(model, metric, data_loader, global_step, tokenizer, verbalize
                 token_type_ids=batch['token_type_ids'].to(args.device),
                 attention_mask=batch['attention_mask'].to(args.device)).logits
             mask_labels = batch['mask_labels'].numpy().tolist()  # (batch, label_num)
+            for i in range(len(mask_labels)):  # strip the [PAD] tokens from the labels
+                while tokenizer.pad_token_id in mask_labels[i]:
+                    mask_labels[i].remove(tokenizer.pad_token_id)
             mask_labels = [''.join(tokenizer.convert_ids_to_tokens(t)) for t in mask_labels]  # ids -> text
             predictions = convert_logits_to_ids(logits, batch['mask_positions']).cpu().numpy().tolist()  # (batch, label_num)
             predictions = verbalizer.batch_find_main_label(predictions)  # map each sub-label to its main label
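
The while/remove stripping is quadratic in the label length, which is harmless for labels of at most max_label_len tokens, but the same cleanup can be written as a single pass per label. A standalone sketch with made-up ids (pad_id stands in for tokenizer.pad_token_id, which is 0 for BERT-style vocabularies):

    pad_id = 0
    mask_labels = [[2769, 429, pad_id], [321, pad_id, pad_id]]

    # keep every id except the pad id, one pass per label
    mask_labels = [[t for t in ids if t != pad_id] for ids in mask_labels]
    print(mask_labels)   # [[2769, 429], [321]]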
17 changes: 11 additions & 6 deletions prompt_tasks/PET/utils.py
@@ -73,16 +73,18 @@ def convert_example(
         'mask_labels': []
     }
 
-    for example in examples['text']:
+    for i, example in enumerate(examples['text']):
         try:
             if train_mode:
                 label, content = example.strip().split('\t')
             else:
                 content = example.strip()
 
-            content = content[:max_seq_len-10]  # keep a [MASK] at the tail from being truncated away
+            s, origin_p, o = content.split(',')
             inputs_dict={
-                'textA': content,  # custom parameters passed to the corresponding prompt
+                'Subject': s,  # custom parameters passed to the corresponding prompt
+                'OriginP': origin_p,
+                'Object': o,
                 'MASK': '[MASK]'
             }
             encoded_inputs = template(
@@ -92,15 +94,18 @@ def convert_example(
                 mask_length=max_label_len
             )
         except:
-            print(f'"{example}" -> {traceback.format_exc()}')
+            print(f'Error Line {i+1}: "{example}" -> {traceback.format_exc()}')
             exit()
         tokenized_output['input_ids'].append(encoded_inputs["input_ids"])
         tokenized_output['token_type_ids'].append(encoded_inputs["token_type_ids"])
         tokenized_output['attention_mask'].append(encoded_inputs["attention_mask"])
         tokenized_output['mask_positions'].append(encoded_inputs["mask_position"])
         if train_mode:
-            label_encoded = tokenizer(text=[label])
-            tokenized_output['mask_labels'].append(label_encoded['input_ids'][0][1:-1])
+            label_encoded = tokenizer(text=[label])  # pad the label to the max length
+            label_encoded = label_encoded['input_ids'][0][1:-1]
+            label_encoded = label_encoded[:max_label_len]
+            label_encoded = label_encoded + [tokenizer.pad_token_id] * (max_label_len - len(label_encoded))
+            tokenized_output['mask_labels'].append(label_encoded)
 
     for k, v in tokenized_output.items():
         if return_tensor:
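
The five new label_encoded lines implement a clip-then-pad: the tokenized label (with its [CLS]/[SEP] ids removed by the [1:-1] slice) is truncated to max_label_len and then right-padded with the pad id, so every mask_labels entry has the same length and the batch can be stacked into a tensor. The same logic as a standalone helper (pad_or_truncate is illustrative, not a function in this repo):

    def pad_or_truncate(ids, max_label_len, pad_id):
        # clip to max_label_len, then right-pad with pad_id up to that length
        ids = ids[:max_label_len]
        return ids + [pad_id] * (max_label_len - len(ids))

    print(pad_or_truncate([791, 792, 793], 2, 0))   # [791, 792]
    print(pad_or_truncate([791], 3, 0))             # [791, 0, 0]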
2 changes: 2 additions & 0 deletions prompt_tasks/PET/verbalizer.py
@@ -82,6 +82,8 @@ def find_sub_labels(self, label: Union[list, str]) -> dict:
             }
         }
         """
         if type(label) == list:  # if an id list is passed in, convert it back through the tokenizer
+            while self.tokenizer.pad_token_id in label:
+                label.remove(self.tokenizer.pad_token_id)
             label = ''.join(self.tokenizer.convert_ids_to_tokens(label))
 
         if label not in self.label_dict:
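
The pad ids have to be removed before the join, because otherwise the literal '[PAD]' token text becomes part of the label string used as the label_dict key. A sketch of the corruption with illustrative tokens (not the repo's Verbalizer class):

    tokens = ['sports', '[PAD]', '[PAD]']   # convert_ids_to_tokens of a padded label
    print(''.join(tokens))                  # 'sports[PAD][PAD]' -- never matches label_dict['sports']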
8 changes: 5 additions & 3 deletions prompt_tasks/p-tuning/class_metrics.py
@@ -23,6 +23,7 @@
 from typing import List
 
 import numpy as np
+import pandas as pd
 from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score, confusion_matrix


@@ -75,7 +76,7 @@ def compute(self, round_num=2) -> dict:
                 }
             }
         """
-        classes, class_metrics, res = sorted(list(set(self.goldens))), {}, {}
+        classes, class_metrics, res = sorted(list(set(self.goldens) | set(self.predictions))), {}, {}
         res['accuracy'] = round(accuracy_score(self.goldens, self.predictions), round_num)  # build the global metrics
         res['precision'] = round(precision_score(self.goldens, self.predictions, average='weighted'), round_num)
         res['recall'] = round(recall_score(self.goldens, self.predictions, average='weighted'), round_num)
@@ -96,8 +97,9 @@ def compute(self, round_num=2) -> dict:
             res['class_metrics'] = class_metrics
         except Exception as e:
             print(f'[Warning] Something wrong when calculate class_metrics: {e}')
-            print(f'goldens: {set(self.goldens)}')
-            print(f'predictions: {set(self.predictions)}')
+            print(f'-> goldens: {set(self.goldens)}')
+            print(f'-> predictions: {set(self.predictions)}')
+            print(f'-> diff elements: {set(self.predictions) - set(self.goldens)}')
             res['class_metrics'] = {}
 
         return res
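
Each of the three class_metrics.py copies also gains an import pandas as pd, although no pandas call appears in the hunks shown here; presumably the confusion matrix is rendered as a labeled table elsewhere in the file. A sketch of that common pattern (an assumption about intent, not code from this commit):

    import pandas as pd
    from sklearn.metrics import confusion_matrix

    goldens = ['A', 'A', 'B']
    predictions = ['A', 'B', 'B']
    classes = sorted(set(goldens) | set(predictions))

    cm = confusion_matrix(goldens, predictions, labels=classes)
    df = pd.DataFrame(cm, index=classes, columns=classes)   # rows = gold, columns = predicted
    print(df)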
8 changes: 5 additions & 3 deletions utils/class_metrics.py
@@ -23,6 +23,7 @@
 from typing import List
 
 import numpy as np
+import pandas as pd
 from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score, confusion_matrix


@@ -75,7 +76,7 @@ def compute(self, round_num=2) -> dict:
                 }
             }
         """
-        classes, class_metrics, res = sorted(list(set(self.goldens))), {}, {}
+        classes, class_metrics, res = sorted(list(set(self.goldens) | set(self.predictions))), {}, {}
         res['accuracy'] = round(accuracy_score(self.goldens, self.predictions), round_num)  # build the global metrics
         res['precision'] = round(precision_score(self.goldens, self.predictions, average='weighted'), round_num)
         res['recall'] = round(recall_score(self.goldens, self.predictions, average='weighted'), round_num)
@@ -96,8 +97,9 @@ def compute(self, round_num=2) -> dict:
             res['class_metrics'] = class_metrics
         except Exception as e:
             print(f'[Warning] Something wrong when calculate class_metrics: {e}')
-            print(f'goldens: {set(self.goldens)}')
-            print(f'predictions: {set(self.predictions)}')
+            print(f'-> goldens: {set(self.goldens)}')
+            print(f'-> predictions: {set(self.predictions)}')
+            print(f'-> diff elements: {set(self.predictions) - set(self.goldens)}')
             res['class_metrics'] = {}
 
         return res
