[Refactor] Add reduction_override in MSELoss (#5437)
* [Refactor] Add reduction_override in MSELoss

* add loss test unit

* add loss test unit
BIGWangYuDong authored Jun 29, 2021
1 parent 30a7073 commit 33bc4c5
Showing 2 changed files with 100 additions and 8 deletions.
19 changes: 13 additions & 6 deletions mmdet/models/losses/mse_loss.py
@@ -26,7 +26,12 @@ def __init__(self, reduction='mean', loss_weight=1.0):
         self.reduction = reduction
         self.loss_weight = loss_weight
 
-    def forward(self, pred, target, weight=None, avg_factor=None):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None):
         """Forward function of loss.
 
         Args:
@@ -36,14 +41,16 @@ def forward(self, pred, target, weight=None, avg_factor=None):
                 prediction. Defaults to None.
             avg_factor (int, optional): Average factor that is used to average
                 the loss. Defaults to None.
+            reduction_override (str, optional): The reduction method used to
+                override the original reduction method of the loss.
+                Defaults to None.
 
         Returns:
             torch.Tensor: The calculated loss
         """
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss = self.loss_weight * mse_loss(
-            pred,
-            target,
-            weight,
-            reduction=self.reduction,
-            avg_factor=avg_factor)
+            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
         return loss
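
With this change, the reduction used by MSELoss can be switched per call instead of only at construction time. A minimal usage sketch (assuming an mmdet build that includes this commit; the tensor shapes mirror the tests below):

import torch

from mmdet.models.losses import MSELoss

loss_fn = MSELoss(reduction='mean', loss_weight=1.0)
pred = torch.rand((10, 4))
target = torch.rand((10, 4))

loss_mean = loss_fn(pred, target)                             # falls back to self.reduction ('mean')
loss_sum = loss_fn(pred, target, reduction_override='sum')    # overrides the reduction for this call only
loss_none = loss_fn(pred, target, reduction_override='none')  # keeps the elementwise losses

Note that reduction_override=None (the default) leaves the module's own reduction in effect, so existing call sites keep their behaviour.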
89 changes: 87 additions & 2 deletions tests/test_models/test_loss.py
@@ -1,8 +1,12 @@
 import pytest
 import torch
 
-from mmdet.models.losses import (BoundedIoULoss, CIoULoss, DIoULoss, GIoULoss,
-                                 IoULoss)
+from mmdet.models.losses import (BalancedL1Loss, BoundedIoULoss, CIoULoss,
+                                 CrossEntropyLoss, DIoULoss,
+                                 DistributionFocalLoss, FocalLoss,
+                                 GaussianFocalLoss, GIoULoss, IoULoss, L1Loss,
+                                 MSELoss, QualityFocalLoss, SmoothL1Loss,
+                                 VarifocalLoss)
 
 
 @pytest.mark.parametrize(
@@ -14,3 +18,84 @@ def test_iou_type_loss_zeros_weight(loss_class):
 
     loss = loss_class()(pred, target, weight)
     assert loss == 0.
+
+
+@pytest.mark.parametrize('loss_class', [
+    IoULoss, BoundedIoULoss, GIoULoss, DIoULoss, CIoULoss, MSELoss, L1Loss,
+    SmoothL1Loss, BalancedL1Loss, FocalLoss, QualityFocalLoss,
+    GaussianFocalLoss, DistributionFocalLoss, VarifocalLoss, CrossEntropyLoss
+])
+def test_loss_with_reduction_override(loss_class):
+    pred = torch.rand((10, 4))
+    target = torch.rand((10, 4))
+
+    with pytest.raises(AssertionError):
+        # a reduction_override other than None, 'none', 'mean' or 'sum'
+        # is not allowed
+        reduction_override = True
+        loss_class()(pred, target, reduction_override=reduction_override)
+
+
+@pytest.mark.parametrize('loss_class', [
+    IoULoss, BoundedIoULoss, GIoULoss, DIoULoss, CIoULoss, MSELoss, L1Loss,
+    SmoothL1Loss, BalancedL1Loss
+])
+def test_regression_losses(loss_class):
+    pred = torch.rand((10, 4))
+    target = torch.rand((10, 4))
+
+    # Test loss forward
+    loss = loss_class()(pred, target)
+    assert isinstance(loss, torch.Tensor)
+
+    # Test loss forward with reduction_override
+    loss = loss_class()(pred, target, reduction_override='mean')
+    assert isinstance(loss, torch.Tensor)
+
+    # Test loss forward with avg_factor
+    loss = loss_class()(pred, target, avg_factor=10)
+    assert isinstance(loss, torch.Tensor)
+
+    with pytest.raises(ValueError):
+        # loss can be evaluated with avg_factor only if
+        # reduction is None, 'none' or 'mean'
+        reduction_override = 'sum'
+        loss_class()(
+            pred, target, avg_factor=10, reduction_override=reduction_override)
+
+    # Test loss forward with avg_factor and reduction_override
+    for reduction_override in [None, 'none', 'mean']:
+        loss = loss_class()(
+            pred, target, avg_factor=10, reduction_override=reduction_override)
+        assert isinstance(loss, torch.Tensor)
+
+
+@pytest.mark.parametrize('loss_class', [FocalLoss, CrossEntropyLoss])
+def test_classification_losses(loss_class):
+    pred = torch.rand((10, 5))
+    target = torch.randint(0, 5, (10, ))
+
+    # Test loss forward
+    loss = loss_class()(pred, target)
+    assert isinstance(loss, torch.Tensor)
+
+    # Test loss forward with reduction_override
+    loss = loss_class()(pred, target, reduction_override='mean')
+    assert isinstance(loss, torch.Tensor)
+
+    # Test loss forward with avg_factor
+    loss = loss_class()(pred, target, avg_factor=10)
+    assert isinstance(loss, torch.Tensor)
+
+    with pytest.raises(ValueError):
+        # loss can be evaluated with avg_factor only if
+        # reduction is None, 'none' or 'mean'
+        reduction_override = 'sum'
+        loss_class()(
+            pred, target, avg_factor=10, reduction_override=reduction_override)
+
+    # Test loss forward with avg_factor and reduction_override
+    for reduction_override in [None, 'none', 'mean']:
+        loss = loss_class()(
+            pred, target, avg_factor=10, reduction_override=reduction_override)
+        assert isinstance(loss, torch.Tensor)
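
The new tests cover both the assertion added above (any reduction_override outside None, 'none', 'mean', 'sum' fails) and the ValueError raised when avg_factor is combined with a 'sum' reduction. The helper below is an illustrative sketch of that second check, written for this page rather than copied from mmdet's internals; the name reduce_with_avg_factor is made up here:

import torch


def reduce_with_avg_factor(loss, reduction='mean', avg_factor=None):
    # Reduce an elementwise loss tensor, optionally averaging by a custom
    # denominator (avg_factor) instead of the number of elements.
    if avg_factor is None:
        if reduction == 'mean':
            return loss.mean()
        if reduction == 'sum':
            return loss.sum()
        return loss  # 'none': keep the elementwise losses
    if reduction == 'mean':
        return loss.sum() / avg_factor
    if reduction == 'none':
        return loss
    # A custom averaging factor is meaningless for a plain sum, so reject it.
    raise ValueError('avg_factor can not be used with reduction="sum"')


elementwise = torch.rand((10, 4))
print(reduce_with_avg_factor(elementwise, reduction='mean', avg_factor=10))

The parametrized cases above can be run locally with pytest tests/test_models/test_loss.py.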
