From 7515630558be4c1bfc339c19391f7eef4f7b732e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Ciscyy=E2=80=9D?= <“wenvoi@163.com”> Date: Mon, 8 Aug 2022 19:43:22 +0800 Subject: [PATCH] EIoU SIoU alpha --- utils/loss.py | 4 ++- utils/metrics.py | 91 +++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 86 insertions(+), 9 deletions(-) diff --git a/utils/loss.py b/utils/loss.py index ab84454..221e8b8 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -from utils.metrics import bbox_iou, box_iou +from utils.metrics import bbox_iou, box_iou, bbox_alpha_iou from utils.torch_utils import de_parallel, is_parallel from utils.general import xywh2xyxy import torch.nn.functional as F @@ -135,7 +135,9 @@ def __call__(self, p, targets): # predictions, targets, model pxy = ps[:, :2].sigmoid() * 2 - 0.5 pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i] pbox = torch.cat((pxy, pwh), 1) # predicted box + # IoU update add EIoU, SIoU iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target) + # iou = bbox_alpha_iou(pbox.T, tbox[i], x1y1x2y2=False, alpha=3, CIoU=True) # use alpha IoU lbox += (1.0 - iou).mean() # iou loss # Objectness diff --git a/utils/metrics.py b/utils/metrics.py index 857fa5d..62503fb 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -198,8 +198,58 @@ def print(self): for i in range(self.nc + 1): print(' '.join(map(str, self.matrix[i]))) +def bbox_alpha_iou(box1, box2, x1y1x2y2=False, GIoU=False, DIoU=False, CIoU=False, alpha=2, eps=1e-9): + # Returns tsqrt_he IoU of box1 to box2. box1 is 4, box2 is nx4 + box2 = box2.T + + # Get the coordinates of bounding boxes + if x1y1x2y2: # x1, y1, x2, y2 = box1 + b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] + b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] + else: # transform from xywh to xyxy + b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2 + b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2 + b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2 + b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2 + + # Intersection area + inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \ + (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0) -def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7): + # Union Area + w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps + w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps + union = w1 * h1 + w2 * h2 - inter + eps + + # change iou into pow(iou+eps) + # iou = inter / union + iou = torch.pow(inter/union + eps, alpha) + # beta = 2 * alpha + if GIoU or DIoU or CIoU: + cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width + ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height + if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 + c2 = (cw ** 2 + ch ** 2) ** alpha + eps # convex diagonal + rho_x = torch.abs(b2_x1 + b2_x2 - b1_x1 - b1_x2) + rho_y = torch.abs(b2_y1 + b2_y2 - b1_y1 - b1_y2) + rho2 = ((rho_x ** 2 + rho_y ** 2) / 4) ** alpha # center distance + if DIoU: + return iou - rho2 / c2 # DIoU + elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 + v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2) + with torch.no_grad(): + alpha_ciou = v / ((1 + eps) - inter / union + v) + # return iou - (rho2 / c2 + v * alpha_ciou) # CIoU + return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)) # CIoU + else: # GIoU https://arxiv.org/pdf/1902.09630.pdf + # c_area = cw * ch + eps # convex area + # return iou - (c_area - union) / c_area # GIoU + c_area = torch.max(cw * ch + eps, union) # convex area + return iou - torch.pow((c_area - union) / c_area + eps, alpha) # GIoU + else: + return iou # torch.log(iou+eps) or iou + +def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, EIoU=False, SIoU=False, eps=1e-7): # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4 box2 = box2.T @@ -223,21 +273,46 @@ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps= union = w1 * h1 + w2 * h2 - inter + eps iou = inter / union - if CIoU or DIoU or GIoU: + if CIoU or DIoU or GIoU or EIoU or SIoU: cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height - if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 + if CIoU or DIoU or EIoU or SIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared - if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 + if DIoU: #DIoU + return iou - rho2 / c2 # DIoU + elif CIoU: #CIoU https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2) with torch.no_grad(): alpha = v / (v - iou + (1 + eps)) - return iou - (rho2 / c2 + v * alpha) # CIoU - return iou - rho2 / c2 # DIoU - c_area = cw * ch + eps # convex area - return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf + return iou - (rho2 / c2 + v * alpha) # CIoU + elif SIoU:# SIoU + s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5) + sin_alpha_1 = torch.abs(s_cw) / sigma + sin_alpha_2 = torch.abs(s_ch) / sigma + threshold = pow(2, 0.5) / 2 + sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1) + angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2) + rho_x = (s_cw / cw) ** 2 + rho_y = (s_ch / ch) ** 2 + gamma = angle_cost - 2 + distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y) + omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2) + omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2) + shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4) + return iou - 0.5 * (distance_cost + shape_cost) + else:# EIoU + w_dis=torch.pow(b1_x2-b1_x1-b2_x2+b2_x1, 2) + h_dis=torch.pow(b1_y2-b1_y1-b2_y2+b2_y1, 2) + cw2=torch.pow(cw , 2)+eps + ch2=torch.pow(ch , 2)+eps + return iou-(rho2/c2+w_dis/cw2+h_dis/ch2) + else: + c_area = cw * ch + eps # convex area + return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf return iou # IoU