-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_distribution.py
54 lines (38 loc) · 1.5 KB
/
plot_distribution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import math
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
# Compare the distribution of teacher score margins (positive-pair score minus
# negative-pair score) between MS MARCO and the Amazon Shopping Queries data,
# drawing one filled KDE curve per dataset into a single PDF figure.
file2 = "esci-data-main/shopping_queries_dataset/teacher_dynamic_scores.tsv"
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
file1 = "/Users/shuaiwang/workspace/2023s1PHD/extra/Balance_training/Data_Sophia/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv"
# Legend label used for each input file.
file_dict = {file1: "MS MARCO",
             file2: "Amazon Shopping Queries"}
# Divisor applied to every raw margin of a file before plotting. The Amazon
# margins are divided by 3 (as in the original script) -- presumably to bring
# them onto a scale comparable with MS MARCO; TODO confirm the rationale.
margin_divisor = {file1: 1.0, file2: 3.0}
out = "esci-data-main/shopping_queries_dataset/a.pdf"
files = [file1, file2]


def _read_margins(path, divisor):
    """Return ``(pos_score - neg_score) / divisor`` for every row of *path*.

    Each non-blank line must hold exactly five whitespace-separated fields:
    positive score, negative score, query id, positive doc id, negative doc id.
    Only the two scores are used; a row with a different field count raises
    ``ValueError`` so malformed input is not silently plotted.
    """
    margins = []
    with open(path, encoding="utf-8") as f:
        for line in tqdm(f):
            fields = line.split()
            if not fields:  # tolerate blank lines, e.g. a trailing newline
                continue
            pos_score, neg_score, _qid, _pos_docid, _neg_docid = fields
            margins.append((float(pos_score) - float(neg_score)) / divisor)
    return margins


plt.figure(figsize=(10, 6), dpi=80)
for file in files:
    data_margins = _read_margins(file, margin_divisor[file])
    if not data_margins:
        # Nothing to plot, and max()/min() below would raise on an empty list.
        print(f"{file_dict[file]}: no margins read from {file}; skipping")
        continue
    # sns.distplot is deprecated since seaborn 0.11; kdeplot draws the same
    # filled KDE curve that distplot(hist=False, kde=True) produced.
    sns.kdeplot(x=data_margins, fill=True, linewidth=3,
                label=file_dict[file])
    # Report the range of each dataset (previously only the last file's
    # range was printed, without saying which file it belonged to).
    print(file_dict[file], max(data_margins), min(data_margins))
plt.legend(prop={'size': 16})
plt.xlabel('Margin Score', size=16)
plt.ylabel('Density', size=16)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.tight_layout()
plt.savefig(out)