-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
128 lines (101 loc) · 3.67 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from config import *
import json
import os
import pprint as pp
import random
from datetime import date
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch import optim as optim
def ndcg(scores, labels, k):
    """Compute the batch-mean NDCG@k for 0/1 relevance labels.

    Args:
        scores: 2-D tensor of predicted scores (one row per example, one column
            per candidate); higher scores are ranked first.
        labels: tensor aligned with ``scores`` holding 0/1 relevance; the row
            sum is used as the number of relevant items.
        k: rank cutoff. A cutoff larger than the number of candidates is
            clamped, so the metric is computed over all candidates instead of
            failing with a shape mismatch.

    Returns:
        0-dim tensor with the mean of DCG@k / IDCG@k over the batch. A row with
        no relevant items has IDCG 0 and therefore contributes NaN.
    """
    scores = scores.cpu()
    labels = labels.cpu()
    # Slicing past the available candidates would leave fewer ranked columns
    # than discount weights and break the broadcast in the DCG product.
    k = min(k, scores.size(1))
    rank = (-scores).argsort(dim=1)
    hits = labels.gather(1, rank[:, :k]).float()
    # Discount for 1-based position p is 1 / log2(p + 1).
    weights = 1 / torch.log2(torch.arange(2, 2 + k).float())
    dcg = (hits * weights).sum(1)
    # Ideal DCG places min(#relevant, k) hits at the top: the sum of the first
    # min(n, k) weights. A zero-prefixed cumulative sum gives that in one lookup
    # (index 0 means "no relevant items").
    prefix = torch.cat([weights.new_zeros(1), weights.cumsum(0)])
    ideal_hits = labels.sum(1).clamp(max=k).long()
    idcg = prefix[ideal_hits]
    return (dcg / idcg).mean()
def get_metric(pred_list, topk=10):
    """Return ``(recall, ndcg)`` at ``topk`` for single-target predictions.

    ``pred_list`` holds, for each example, the 0-based rank at which the
    ground-truth item was placed. Any rank below ``topk`` is a hit and earns
    the discounted gain ``1 / log2(rank + 2)``. Both metrics are averaged over
    ``len(pred_list)``.
    """
    hit_ranks = [r for r in pred_list if r < topk]
    total = len(pred_list)
    gain = 0.0
    for r in hit_ranks:
        gain += 1.0 / np.log2(r + 2.0)
    return len(hit_ranks) / total, gain / total
def absolute_recall_mrr_ndcg_for_ks(scores, labels, ks):
    """Compute Recall@k, MRR@k and NDCG@k for each cutoff in ``ks``.

    Args:
        scores: 2-D tensor of predicted scores (one row per example, one column
            per candidate); higher scores are ranked first.
        labels: target item indices, expanded with ``F.one_hot`` to
            ``scores.size(1)`` classes (presumably one index per row of
            ``scores`` -- confirm against callers).
        ks: iterable of integer cutoffs. A cutoff larger than the number of
            candidates is evaluated over all candidates; the result is still
            stored under the requested ``k``.

    Returns:
        dict mapping ``'Recall@k'``, ``'MRR@k'`` and ``'NDCG@k'`` to Python
        floats averaged over the batch.
    """
    metrics = {}
    num_items = scores.size(1)
    labels = F.one_hot(labels, num_classes=num_items)
    answer_count = labels.sum(1)
    labels_float = labels.float()
    device = labels.device
    rank = (-scores).argsort(dim=1)
    cut = rank
    # Descending cutoffs let each iteration narrow the previous slice.
    for k in sorted(ks, reverse=True):
        # Never slice past the available candidates; otherwise the ranked
        # columns and the k-length position/weight vectors cannot broadcast.
        kk = min(k, num_items)
        cut = cut[:, :kk]
        hits = labels_float.gather(1, cut)
        positions = torch.arange(1, kk + 1, device=device)
        # Recall normalizes by min(k, #answers) so a perfect ranking scores 1.
        metrics['Recall@%d' % k] = (
            hits.sum(1) / answer_count.float().clamp(max=kk)
        ).mean().cpu().item()
        metrics['MRR@%d' % k] = (
            (hits / positions.unsqueeze(0)).sum(1).mean().cpu().item()
        )
        # Discount for 1-based position p is 1 / log2(p + 1).
        weights = 1 / torch.log2((positions + 1).float())
        dcg = (hits * weights).sum(1)
        # Ideal DCG puts min(#answers, k) targets at the top: the sum of the
        # first min(n, k) weights, read from a zero-prefixed cumulative sum.
        prefix = torch.cat([weights.new_zeros(1), weights.cumsum(0)])
        idcg = prefix[answer_count.clamp(max=kk)]
        metrics['NDCG@%d' % k] = (dcg / idcg).mean().cpu().item()
    return metrics
class AverageMeterSet(object):
    """A named collection of ``AverageMeter`` objects, created lazily on update."""

    def __init__(self, meters=None):
        # A missing or empty mapping is replaced with a fresh dict.
        self.meters = meters or {}

    def __getitem__(self, key):
        """Return the meter stored under ``key``.

        For an unknown key, a throwaway meter that has recorded a single 0 is
        returned; it is NOT added to the set.
        """
        if key in self.meters:
            return self.meters[key]
        placeholder = AverageMeter()
        placeholder.update(0)
        return placeholder

    def update(self, name, value, n=1):
        """Feed ``value`` (with count ``n``) into the meter named ``name``."""
        if name not in self.meters:
            self.meters[name] = AverageMeter()
        self.meters[name].update(value, n)

    def reset(self):
        """Reset every meter in the set."""
        for meter in self.meters.values():
            meter.reset()

    def _collect(self, attr, format_string):
        # Map each formatted meter name to one attribute of that meter.
        return {
            format_string.format(name): getattr(meter, attr)
            for name, meter in self.meters.items()
        }

    def values(self, format_string='{}'):
        """Latest value of each meter, keyed by ``format_string.format(name)``."""
        return self._collect('val', format_string)

    def averages(self, format_string='{}'):
        """Running average of each meter, keyed by ``format_string.format(name)``."""
        return self._collect('avg', format_string)

    def sums(self, format_string='{}'):
        """Running sum of each meter, keyed by ``format_string.format(name)``."""
        return self._collect('sum', format_string)

    def counts(self, format_string='{}'):
        """Running count of each meter, keyed by ``format_string.format(name)``."""
        return self._collect('count', format_string)
class AverageMeter(object):
    """Track the latest value together with a running sum, count and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, average, sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value and add ``n`` to the count.

        NOTE(review): ``val`` is added to ``sum`` unweighted while ``count``
        grows by ``n``, so ``avg`` is sum(val) / sum(n). If callers pass
        per-batch means with n = batch size, ``avg`` is not a per-sample mean;
        confirm the intended semantics before relying on n != 1.
        """
        self.val = val
        self.sum += val
        self.count += n
        self.avg = self.sum / self.count

    def __format__(self, format):
        # Render "<latest> (<average>)", applying the same spec to both numbers.
        return f"{self.val:{format}} ({self.avg:{format}})"