# forked from kuangliu/pytorch-retinanet
# loss.py
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import one_hot_embedding
from torch.autograd import Variable
class FocalLoss(nn.Module):
    """Focal loss + smooth-L1 localization loss for RetinaNet-style detectors.

    Classification targets use label 0 for background and 1..num_classes for
    objects; label -1 marks anchors to ignore. The background column is dropped
    from the one-hot encoding, so background anchors contribute only through
    the (1 - t) negative term.
    """

    def __init__(self, num_classes=20):
        super(FocalLoss, self).__init__()
        self.num_classes = num_classes

    def _one_hot(self, y, device, dtype):
        """One-hot encode labels, drop the background column (index 0).

        Replaces the old utils.one_hot_embedding + Variable(...).cuda() path:
        F.one_hot is equivalent to torch.eye(K)[y], and .to(device) makes the
        loss work on CPU as well as GPU (the old code hard-coded .cuda()).

        Args:
            y: (LongTensor) class labels in [0, num_classes], sized [N,].
        Returns:
            (tensor) one-hot targets sized [N, num_classes] on `device`.
        """
        t = F.one_hot(y, self.num_classes + 1)  # [N, 1+num_classes]
        return t[:, 1:].to(device=device, dtype=dtype)  # exclude background

    def focal_loss(self, x, y):
        """Focal loss (Lin et al., "Focal Loss for Dense Object Detection").

        Args:
            x: (tensor) raw class logits, sized [N, num_classes].
            y: (LongTensor) class labels, sized [N,].
        Returns:
            (tensor) scalar focal loss (sum over all elements).
        """
        alpha = 0.25  # positive/negative balancing factor
        gamma = 2     # focusing exponent; down-weights easy examples
        t = self._one_hot(y, x.device, x.dtype)  # [N, num_classes]
        p = x.sigmoid()
        pt = p * t + (1 - p) * (1 - t)           # pt = p if t > 0 else 1-p
        w = alpha * t + (1 - alpha) * (1 - t)    # w = alpha if t > 0 else 1-alpha
        w = w * (1 - pt).pow(gamma)
        # reduction='sum' replaces the deprecated size_average=False.
        return F.binary_cross_entropy_with_logits(x, t, w, reduction='sum')

    def focal_loss_alt(self, x, y):
        """Numerically friendlier focal-loss variant (fixed gamma = 2).

        Uses sigmoid(2*xt + 1) as a smooth surrogate so the log never sees
        an exact zero.

        Args:
            x: (tensor) raw class logits, sized [N, num_classes].
            y: (LongTensor) class labels, sized [N,].
        Returns:
            (tensor) scalar focal loss (sum over all elements).
        """
        alpha = 0.25
        t = self._one_hot(y, x.device, x.dtype)
        xt = x * (2 * t - 1)          # xt = x if t > 0 else -x
        pt = (2 * xt + 1).sigmoid()
        w = alpha * t + (1 - alpha) * (1 - t)
        loss = -w * pt.log() / 2
        return loss.sum()

    def forward(self, loc_preds, loc_targets, cls_preds, cls_targets):
        """Compute loss between (loc_preds, loc_targets) and (cls_preds, cls_targets).

        Args:
            loc_preds: (tensor) predicted locations, sized [batch_size, #anchors, 4].
            loc_targets: (tensor) encoded target locations, sized [batch_size, #anchors, 4].
            cls_preds: (tensor) predicted class confidences, sized [batch_size, #anchors, #classes].
            cls_targets: (LongTensor) encoded target labels, sized [batch_size, #anchors].
        Returns:
            (tensor) loss = (SmoothL1Loss(loc_preds, loc_targets)
                             + FocalLoss(cls_preds, cls_targets)) / num_pos.
        """
        batch_size, num_boxes = cls_targets.size()
        pos = cls_targets > 0  # [N, #anchors]; label 0 is background
        # Clamp to >= 1: a batch with no positive anchors (background-only
        # images) previously divided by zero here.
        num_pos = max(pos.long().sum().item(), 1)

        # loc_loss = SmoothL1Loss over positive anchors only.
        mask = pos.unsqueeze(2).expand_as(loc_preds)        # [N, #anchors, 4]
        masked_loc_preds = loc_preds[mask].view(-1, 4)      # [#pos, 4]
        masked_loc_targets = loc_targets[mask].view(-1, 4)  # [#pos, 4]
        loc_loss = F.smooth_l1_loss(masked_loc_preds, masked_loc_targets,
                                    reduction='sum')

        # cls_loss = FocalLoss over non-ignored anchors (label > -1).
        pos_neg = cls_targets > -1  # exclude ignored anchors
        mask = pos_neg.unsqueeze(2).expand_as(cls_preds)
        masked_cls_preds = cls_preds[mask].view(-1, self.num_classes)
        cls_loss = self.focal_loss_alt(masked_cls_preds, cls_targets[pos_neg])

        # .item() replaces the removed 0-dim indexing loss.data[0].
        print('loc_loss: %.3f | cls_loss: %.3f'
              % (loc_loss.item() / num_pos, cls_loss.item() / num_pos),
              end=' | ')
        loss = (loc_loss + cls_loss) / num_pos
        return loss