-
Notifications
You must be signed in to change notification settings - Fork 54
/
rank_loss.py
211 lines (166 loc) · 9.42 KB
/
rank_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import torch.nn.functional as F
import torch
from torch.nn import BCELoss, BCEWithLogitsLoss
from itertools import product
class RankLoss:
    """Learning-to-rank losses operating on slates of scored items.

    Every loss takes ``y_pred`` (model scores) and an optional ``y_true``
    (relevance labels), both of shape [batch_size, slate_length]. When
    ``y_true`` is omitted, the first item of every slate is treated as the only
    relevant one (label 1, every other item 0). Where a loss accepts
    ``padded_value_indicator``, items whose label equals it are treated as
    padding and excluded from the loss.
    """

    @staticmethod
    def _default_targets(y_pred):
        """Return labels shaped like ``y_pred`` with only column 0 set to 1."""
        y_true = torch.zeros_like(y_pred)
        y_true[:, 0] = 1
        return y_true

    @staticmethod
    def _mask_padded(y_pred, y_true, padded_value_indicator):
        """Return copies of ``y_pred``/``y_true`` with padded items set to -inf.

        An item is padded when its label equals ``padded_value_indicator``.
        Setting both tensors to -inf there gives padded items zero softmax
        weight, and lets the ``isinf``/``isfinite`` pair filters drop every pair
        that involves one. ``masked_fill`` is out-of-place, so the caller's
        tensors are never modified. Labels are cast to ``y_pred``'s dtype so
        integer labels can hold -inf.
        """
        padded_mask = y_true == padded_value_indicator
        y_pred = y_pred.masked_fill(padded_mask, float("-inf"))
        y_true = y_true.to(y_pred.dtype).masked_fill(padded_mask, float("-inf"))
        return y_pred, y_true

    @staticmethod
    def softmax_ce_loss(y_pred, *args, **kwargs):
        """Softmax cross-entropy with item 0 of every slate as the target class.

        Extra positional/keyword arguments (e.g. ``y_true``) are accepted for
        call-compatibility with the other losses and are ignored.
        :param y_pred: predictions from the model, shape [batch_size, slate_length]
        :return: loss value, a torch.Tensor
        """
        # Targets are created on y_pred's device instead of a hard-coded
        # .cuda(), so the loss also works on CPU and any other device.
        target = torch.zeros(y_pred.size(0), dtype=torch.long, device=y_pred.device)
        return F.cross_entropy(y_pred, target)

    @staticmethod
    def pointwise_rmse(y_pred, y_true=None, padded_value_indicator=-100):
        """Per-slate root-mean-square error, averaged over the batch.

        :param y_pred: predictions from the model, shape [batch_size, slate_length]
        :param y_true: ground truth labels, shape [batch_size, slate_length];
            defaults to "first item relevant"
        :param padded_value_indicator: label value marking a padded item; padded
            items are excluded from both the error sum and the item count
        :return: loss value, a torch.Tensor (NaN for a slate whose items are all padded)
        """
        if y_true is None:
            y_true = RankLoss._default_targets(y_pred)
        padded_mask = y_true == padded_value_indicator
        valid_mask = (~padded_mask).float()
        # Zero both sides at padded positions so they add no error to the
        # numerator; the denominator already counts only valid items.
        errors = y_true.masked_fill(padded_mask, 0.) - y_pred.masked_fill(padded_mask, 0.)
        mean_squared_errors = torch.sum(errors ** 2, dim=1) / torch.sum(valid_mask, dim=1)
        return torch.mean(torch.sqrt(mean_squared_errors))

    @staticmethod
    def pointwise_bce(y_pred, y_true=None):
        """Binary cross-entropy on raw scores (logits), averaged over all items.

        :param y_pred: predictions (logits), shape [batch_size, slate_length]
        :param y_true: binary targets, shape [batch_size, slate_length];
            defaults to "first item relevant"
        :return: loss value, a torch.Tensor
        """
        if y_true is None:
            y_true = RankLoss._default_targets(y_pred)
        # The fused logits form stays accurate for large |y_pred|. There,
        # sigmoid() saturates to 0/1 and binary_cross_entropy clamps log() at -100.
        return F.binary_cross_entropy_with_logits(y_pred, y_true.to(y_pred.dtype))

    @staticmethod
    def list_net(y_pred, y_true=None, padded_value_indicator=-100, eps=1e-10):
        """
        ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach".
        :param y_pred: predictions from the model, shape [batch_size, slate_length]
        :param y_true: ground truth labels, shape [batch_size, slate_length]
        :param eps: epsilon value, used for numerical stability
        :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
        :return: loss value, a torch.Tensor
        """
        if y_true is None:
            y_true = RankLoss._default_targets(y_pred)
        # Padded items become -inf, so both softmaxes give them zero weight.
        y_pred, y_true = RankLoss._mask_padded(y_pred, y_true, padded_value_indicator)
        preds_smax = F.softmax(y_pred, dim=1)
        true_smax = F.softmax(y_true, dim=1)
        # eps keeps log() finite where the predicted probability is 0
        # (including padded items, whose true weight is 0 anyway).
        preds_log = torch.log(preds_smax + eps)
        return torch.mean(-torch.sum(true_smax * preds_log, dim=1))

    @staticmethod
    def rank_net(y_pred, y_true=None, padded_value_indicator=-100, weight_by_diff=False,
                 weight_by_diff_powed=False):
        """
        RankNet loss introduced in "Learning to Rank using Gradient Descent".
        :param y_pred: predictions from the model, shape [batch_size, slate_length]
        :param y_true: ground truth labels, shape [batch_size, slate_length]
        :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
        :param weight_by_diff: flag indicating whether to weight the score differences by ground truth differences.
        :param weight_by_diff_powed: flag indicating whether to weight the score differences by the squared ground truth differences.
        :return: loss value, a torch.Tensor
        """
        if y_true is None:
            y_true = RankLoss._default_targets(y_pred)
        # Padded items become -inf; any pair involving one then has an inf/NaN
        # label difference and is dropped by the_mask below.
        y_pred, y_true = RankLoss._mask_padded(y_pred, y_true, padded_value_indicator)

        # here we generate every pair of indices from the range of document length in the batch
        document_pairs_candidates = list(product(range(y_true.shape[1]), repeat=2))

        pairs_true = y_true[:, document_pairs_candidates]
        selected_pred = y_pred[:, document_pairs_candidates]

        # here we calculate the relative true relevance of every candidate pair
        true_diffs = pairs_true[:, :, 0] - pairs_true[:, :, 1]
        pred_diffs = selected_pred[:, :, 0] - selected_pred[:, :, 1]

        # here we filter just the pairs that are 'positive' and did not involve a padded instance
        # we can do that since in the candidate pairs we had symmetric pairs so we can stick with
        # positive ones for a simpler loss function formulation
        the_mask = (true_diffs > 0) & (~torch.isinf(true_diffs))

        pred_diffs = pred_diffs[the_mask]

        weight = None
        if weight_by_diff:
            abs_diff = torch.abs(true_diffs)
            weight = abs_diff[the_mask]
        elif weight_by_diff_powed:
            true_pow_diffs = torch.pow(pairs_true[:, :, 0], 2) - torch.pow(pairs_true[:, :, 1], 2)
            abs_diff = torch.abs(true_pow_diffs)
            weight = abs_diff[the_mask]

        # here we 'binarize' true relevancy diffs since for a pairwise loss we just need to know
        # whether one document is better than the other and not about the actual difference in
        # their relevancy levels
        true_diffs = (true_diffs > 0).type(torch.float32)
        true_diffs = true_diffs[the_mask]

        return BCEWithLogitsLoss(weight=weight)(pred_diffs, true_diffs)

    @staticmethod
    def lambda_loss(y_pred, y_true=None, eps=1e-10, padded_value_indicator=-100, weighing_scheme=None, k=None,
                    sigma=1., mu=10., reduction="mean", reduction_log="binary"):
        """
        LambdaLoss framework for LTR losses implementations, introduced in "The LambdaLoss Framework for Ranking Metric Optimization".
        Contains implementations of different weighing schemes corresponding to e.g. LambdaRank or RankNet.
        :param y_pred: predictions from the model, shape [batch_size, slate_length]
        :param y_true: ground truth labels, shape [batch_size, slate_length]
        :param eps: epsilon value, used for numerical stability
        :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
        :param weighing_scheme: name of a module-level weighing-scheme function (e.g. "ndcgLoss2_scheme"),
            looked up via globals(); None means every pair has weight 1.0
        :param k: rank at which the loss is truncated
        :param sigma: score difference weight used in the sigmoid function
        :param mu: optional weight used in NDCGLoss2++ weighing scheme
        :param reduction: losses reduction method, could be either a sum or a mean
        :param reduction_log: logarithm variant used prior to masking and loss reduction, either binary or natural
        :return: loss value, a torch.Tensor
        """
        if y_true is None:
            y_true = RankLoss._default_targets(y_pred)
        device = y_pred.device
        # Padded items become -inf: sorted last by prediction, and dropped by
        # the isfinite() pair mask below.
        y_pred, y_true = RankLoss._mask_padded(y_pred, y_true, padded_value_indicator)

        # Here we sort the true and predicted relevancy scores.
        y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)
        y_true_sorted, _ = y_true.sort(descending=True, dim=-1)

        # After sorting, we can mask out the pairs of indices (i, j) containing index of a padded element.
        true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
        true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
        padded_pairs_mask = torch.isfinite(true_diffs)

        if weighing_scheme != "ndcgLoss1_scheme":
            padded_pairs_mask = padded_pairs_mask & (true_diffs > 0)

        # Restrict to pairs within the top-k predicted positions (all of them when k is None).
        ndcg_at_k_mask = torch.zeros((y_pred.shape[1], y_pred.shape[1]), dtype=torch.bool, device=device)
        ndcg_at_k_mask[:k, :k] = 1

        # Here we clamp the -infs to get correct gains and ideal DCGs (maxDCGs)
        true_sorted_by_preds.clamp_(min=0.)
        y_true_sorted.clamp_(min=0.)

        # Here we find the gains, discounts and ideal DCGs per slate.
        pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device)
        D = torch.log2(1. + pos_idxs.float())[None, :]
        maxDCGs = torch.sum(((torch.pow(2, y_true_sorted) - 1) / D)[:, :k], dim=-1).clamp(min=eps)
        G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None]

        # Here we apply appropriate weighing scheme - ndcgLoss1, ndcgLoss2, ndcgLoss2++ or no weights (=1.0)
        if weighing_scheme is None:
            weights = 1.
        else:
            weights = globals()[weighing_scheme](G, D, mu, true_sorted_by_preds)  # type: ignore

        # We are clamping the array entries to maintain correct backprop (log(0) and division by 0)
        scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]).clamp(min=-1e8, max=1e8)
        # Two padded items give -inf - -inf = NaN, which clamp() keeps. masked_fill
        # is not in-place, so its result must be assigned for the NaNs to be removed.
        scores_diffs = scores_diffs.masked_fill(torch.isnan(scores_diffs), 0.)
        weighted_probas = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) ** weights).clamp(min=eps)
        if reduction_log == "natural":
            losses = torch.log(weighted_probas)
        elif reduction_log == "binary":
            losses = torch.log2(weighted_probas)
        else:
            raise ValueError("Reduction logarithm base can be either natural or binary")

        if reduction == "sum":
            loss = -torch.sum(losses[padded_pairs_mask & ndcg_at_k_mask])
        elif reduction == "mean":
            loss = -torch.mean(losses[padded_pairs_mask & ndcg_at_k_mask])
        else:
            raise ValueError("Reduction method can be either sum or mean")

        return loss
def ndcgLoss1_scheme(G, D, *args):
    """NDCGLoss1 weighing: each item's gain divided by its position discount.

    In lambda_loss, ``G`` holds per-item gains normalized by the ideal DCG and
    ``D`` holds log2(1 + position) discounts. The result gets a trailing
    singleton axis (inserted at dim 2) so it broadcasts against the pairwise
    probability matrix. Extra arguments are ignored.
    """
    discounted_gain = G / D
    return discounted_gain.unsqueeze(2)
def ndcgLoss2_scheme(G, D, *args):
    """NDCGLoss2 weighing for every (i, j) pair of sorted positions.

    The weight is |1/D[0, |i-j|-1] - 1/D[0, |i-j|]| * |G_i - G_j|, with the
    diagonal (i == j) forced to zero. ``D`` is indexed on its first row only,
    so it is expected to have shape [1, slate_length] as built in lambda_loss.
    Extra arguments are ignored.
    """
    slate_len = G.shape[1]
    positions = torch.arange(1, slate_len + 1, device=G.device)
    rank_gap = (positions.unsqueeze(1) - positions.unsqueeze(0)).abs()

    inv_discount = torch.pow(torch.abs(D[0]), -1.)
    # On the diagonal rank_gap - 1 == -1 wraps to the last element; that value
    # is discarded because the diagonal is zeroed right after.
    pair_deltas = torch.abs(inv_discount[rank_gap - 1] - inv_discount[rank_gap])
    pair_deltas.diagonal().zero_()

    gain_gaps = torch.abs(G.unsqueeze(2) - G.unsqueeze(1))
    return pair_deltas.unsqueeze(0) * gain_gaps
def lambdaRank_scheme(G, D, *args):
    """LambdaRank weighing: |1/D_i - 1/D_j| * |G_i - G_j| for every pair (i, j).

    ``G`` and ``D`` are broadcast pairwise along their second axis. Extra
    arguments are ignored.
    """
    inv_d = torch.pow(D, -1.)
    discount_gap = torch.abs(inv_d.unsqueeze(2) - inv_d.unsqueeze(1))
    gain_gap = torch.abs(G.unsqueeze(2) - G.unsqueeze(1))
    return discount_gap * gain_gap
def ndcgLoss2PP_scheme(G, D, *args):
    """NDCGLoss2++ weighing: mu * NDCGLoss2 weights plus LambdaRank weights.

    ``mu`` is taken from the first extra positional argument (lambda_loss
    passes its ``mu`` there).
    """
    mu = args[0]
    scaled_ndcg2 = mu * ndcgLoss2_scheme(G, D)
    return scaled_ndcg2 + lambdaRank_scheme(G, D)
def rankNet_scheme(G, D, *args):
    """RankNet weighing: every pair gets the same constant weight, 1.0.

    All arguments are ignored.
    """
    uniform_weight = 1.
    return uniform_weight
def rankNetWeightedByGTDiff_scheme(G, D, *args):
    """RankNet weighing scaled by the absolute label difference |y_i - y_j|.

    The labels come from the second extra positional argument (lambda_loss
    passes the labels reordered by predicted score). ``G`` and ``D`` are unused.
    """
    labels = args[1]
    return (labels.unsqueeze(2) - labels.unsqueeze(1)).abs()
def rankNetWeightedByGTDiffPowed_scheme(G, D, *args):
    """RankNet weighing scaled by the squared-label difference |y_i^2 - y_j^2|.

    The labels come from the second extra positional argument (lambda_loss
    passes the labels reordered by predicted score). ``G`` and ``D`` are unused.
    """
    squared = torch.pow(args[1], 2)
    return torch.abs(squared.unsqueeze(2) - squared.unsqueeze(1))