Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Commit

Permalink
Make accuracy meter consistent with one-hot targets (#349)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #349

This diff makes accuracy meter consistent with both class index and one-hot targets. When the targets are single-labeled, this meter works exactly as present. For multi-labeled targets, accuracy at top k considers a sample as correctly classified if it outputs atleast one correct class within top k.

This diff also makes both accuracy and PR meters infer target type (integer or 0/1) and calculate metrics accordingly.

Reviewed By: mannatsingh

Differential Revision: D19388765

fbshipit-source-id: 7ac63b4f16f454da28ae4f34b6163b362952aaf0
  • Loading branch information
simran2905 authored and facebook-github-bot committed Jan 22, 2020
1 parent c763a34 commit 1aeeeeb
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 83 deletions.
22 changes: 22 additions & 0 deletions classy_vision/generic/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,28 @@ def convert_to_one_hot(targets, classes):
return one_hot_targets


def maybe_convert_to_one_hot(target, model_output):
"""
This function infers whether target is integer or 0/1 encoded
and converts it to 0/1 encoding if necessary.
"""
target_shape_list = list(target.size())

if len(target_shape_list) == 1 or (
len(target_shape_list) == 2 and target_shape_list[1] == 1
):
target = convert_to_one_hot(target.view(-1, 1), model_output.shape[1])

assert (target.shape == model_output.shape) and (
torch.min(target.eq(0) + target.eq(1)) == 1
), (
"Target must be one-hot/multi-label encoded and of the "
"same shape as model_output."
)

return target


def bind_method_to_class(method, cls):
"""
Binds an already bound method to the provided class.
Expand Down
25 changes: 16 additions & 9 deletions classy_vision/meters/accuracy_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@

import torch
from classy_vision.generic.distributed_util import all_reduce_sum
from classy_vision.generic.util import is_pos_int
from classy_vision.generic.util import is_pos_int, maybe_convert_to_one_hot
from classy_vision.meters import ClassyMeter

from . import register_meter


@register_meter("accuracy")
class AccuracyMeter(ClassyMeter):
"""Meter to calculate top-k accuracy for single label
"""Meter to calculate top-k accuracy for single label/ multi label
image classification task.
"""

Expand Down Expand Up @@ -134,20 +134,27 @@ def update(self, model_output, target, **kwargs):
args:
model_output: tensor of shape (B, C) where each value is
either logit or class probability.
target: tensor of shape (B).
Note: For binary classification, C=2.
target: tensor of shape (B, C), which is one-hot /
multi-label encoded, or tensor of shape (B) /
(B, 1), integer encoded
"""
# Due to dummy samples, in some corner cases, the whole batch could
# be dummy samples, in that case we want to not update meters on that
# process
if model_output.shape[0] == 0:
return
_, pred = model_output.topk(max(self._topk), dim=1, largest=True, sorted=True)

correct_predictions = pred.eq(target.unsqueeze(1).expand_as(pred))
# Convert target to 0/1 encoding if isn't
target = maybe_convert_to_one_hot(target, model_output)

_, pred = model_output.topk(max(self._topk), dim=1, largest=True, sorted=True)
for i, k in enumerate(self._topk):
self._curr_correct_predictions_k[i] += (
correct_predictions[:, :k].float().sum().item()
torch.gather(target, dim=1, index=pred[:, :k])
.long()
.max(dim=1)
.values.sum()
.item()
)
self._curr_sample_count += model_output.shape[0]

Expand All @@ -165,8 +172,8 @@ def validate(self, model_output_shape, target_shape):
model_output_shape
)
assert (
len(target_shape) == 1
), "target_shape must be (B) \
len(target_shape) > 0 and len(target_shape) < 3
), "target_shape must be (B) or (B, C) \
Found shape {}".format(
target_shape
)
Expand Down
42 changes: 10 additions & 32 deletions classy_vision/meters/precision_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch
from classy_vision.generic.distributed_util import all_reduce_sum
from classy_vision.generic.util import convert_to_one_hot, is_pos_int
from classy_vision.generic.util import is_pos_int, maybe_convert_to_one_hot
from classy_vision.meters import ClassyMeter

from . import register_meter
Expand All @@ -21,24 +21,16 @@ class PrecisionAtKMeter(ClassyMeter):
image classification task. Note, ties are resolved randomly.
"""

def __init__(self, topk, target_is_one_hot=True, num_classes=-1):
def __init__(self, topk):
"""
args:
topk: list of int `k` values.
target_is_one_hot: boolean, if class labels are one-hot encoded.
num_classes: int, number of classes.
"""
assert isinstance(topk, list), "topk must be a list"
assert len(topk) > 0, "topk list should have at least one element"
assert [is_pos_int(x) for x in topk], "each value in topk must be >= 1"
if not target_is_one_hot:
assert (
type(num_classes) == int and num_classes > 0
), "num_classes must be positive integer"

self._topk = topk
self._target_is_one_hot = target_is_one_hot
self._num_classes = num_classes

# _total_* variables store running, in-sync totals for the
# metrics. These should not be communicated / summed.
Expand All @@ -65,11 +57,7 @@ def from_config(cls, config: Dict[str, Any]) -> "PrecisionAtKMeter":
Returns:
A PrecisionAtKMeter instance.
"""
return cls(
topk=config["topk"],
target_is_one_hot=config.get("target_is_one_hot", True),
num_classes=config.get("num_classes", -1),
)
return cls(topk=config["topk"])

@property
def name(self):
Expand Down Expand Up @@ -147,29 +135,19 @@ def update(self, model_output, target, **kwargs):
args:
model_output: tensor of shape (B, C) where each value is
either logit or class probability.
target: tensor of shape (B, C), one-hot encoded
or integer encoded or tensor of shape (B),
integer encoded.
Note: For binary classification, C=2.
For integer encoded target, C=1.
target: tensor of shape (B, C), which is one-hot /
multi-label encoded, or tensor of shape (B) /
(B, 1), integer encoded
"""
target_shape_list = list(target.size())

if self._target_is_one_hot is False:
assert len(target_shape_list) == 1 or (
len(target_shape_list) == 2 and target_shape_list[1] == 1
), "Integer encoded target must be single labeled"
target = convert_to_one_hot(target.view(-1, 1), self._num_classes)

assert (
torch.min(target.eq(0) + target.eq(1)) == 1
), "Target must be one-hot encoded vector"

# Due to dummy samples, in some corner cases, the whole batch could
# be dummy samples, in that case we want to not update meters on that
# process
if model_output.shape[0] == 0:
return

# Convert target to 0/1 encoding if isn't
target = maybe_convert_to_one_hot(target, model_output)

_, pred_classes = model_output.topk(
max(self._topk), dim=1, largest=True, sorted=True
)
Expand Down
42 changes: 9 additions & 33 deletions classy_vision/meters/recall_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch
from classy_vision.generic.distributed_util import all_reduce_sum
from classy_vision.generic.util import convert_to_one_hot, is_pos_int
from classy_vision.generic.util import is_pos_int, maybe_convert_to_one_hot
from classy_vision.meters import ClassyMeter

from . import register_meter
Expand All @@ -24,20 +24,12 @@ def __init__(self, topk, target_is_one_hot=True, num_classes=None):
"""
args:
topk: list of int `k` values.
target_is_one_hot: boolean, if class labels are one-hot encoded.
num_classes: int, number of classes.
"""
assert isinstance(topk, list), "topk must be a list"
assert len(topk) > 0, "topk list should have at least one element"
assert [is_pos_int(x) for x in topk], "each value in topk must be >= 1"
if not target_is_one_hot:
assert (
type(num_classes) == int and num_classes > 0
), "num_classes must be positive integer"

self._topk = topk
self._target_is_one_hot = target_is_one_hot
self._num_classes = num_classes

# _total_* variables store running, in-sync totals for the
# metrics. These should not be communicated / summed.
Expand All @@ -64,11 +56,7 @@ def from_config(cls, config: Dict[str, Any]) -> "RecallAtKMeter":
Returns:
A RecallAtKMeter instance.
"""
return cls(
topk=config["topk"],
target_is_one_hot=config.get("target_is_one_hot", True),
num_classes=config.get("num_classes", None),
)
return cls(topk=config["topk"])

@property
def name(self):
Expand Down Expand Up @@ -146,31 +134,19 @@ def update(self, model_output, target, **kwargs):
args:
model_output: tensor of shape (B, C) where each value is
either logit or class probability.
target: tensor of shape (B, C), one-hot encoded
or integer encoded or tensor of shape (B),
integer encoded.
Note:
For binary classification, C=2. For integer encoded target, C=1.
target: tensor of shape (B, C), which is one-hot /
multi-label encoded, or tensor of shape (B) /
(B, 1), integer encoded
"""

target_shape_list = list(target.size())

if self._target_is_one_hot is False:
assert len(target_shape_list) == 1 or (
len(target_shape_list) == 2 and target_shape_list[1] == 1
), "Integer encoded target must be single labeled"
target = convert_to_one_hot(target.view(-1, 1), self._num_classes)

assert (
torch.min(target.eq(0) + target.eq(1)) == 1
), "Target must be one-hot encoded vector"
# Due to dummy samples, in some corner cases, the whole batch could
# be dummy samples, in that case we want to not update meters on that
# process
if model_output.shape[0] == 0:
return

# Convert target to 0/1 encoding if isn't
target = maybe_convert_to_one_hot(target, model_output)

_, pred_classes = model_output.topk(
max(self._topk), dim=1, largest=True, sorted=True
)
Expand Down
2 changes: 1 addition & 1 deletion classy_vision/meters/video_accuracy_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_meter("video_accuracy")
class VideoAccuracyMeter(VideoMeter):
"""Meter to calculate top-k video-level accuracy for single label
"""Meter to calculate top-k video-level accuracy for single/multi label
video classification task.
Video-level accuarcy is computed by averaging clip-level predictions and
Expand Down
63 changes: 61 additions & 2 deletions test/meters_accuracy_meter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,65 @@ def test_double_meter_update_and_reset(self):

self.meter_update_and_reset_test(meter, model_outputs, targets, expected_value)

def test_single_meter_update_and_reset_onehot(self):
"""
This test verifies that the meter works as expected on a single
update + reset + same single update with onehot target.
"""
meter = AccuracyMeter(topk=[1, 2])

# Batchsize = 3, num classes = 3, score is a value in {1, 2,
# 3}...3 is the highest score
model_output = torch.tensor([[3, 2, 1], [3, 1, 2], [1, 3, 2]])

# Class 0 is the correct class for sample 1, class 2 for sample 2, etc
target = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

# Only the first sample has top class correct, first and third
# sample have correct class in top 2
expected_value = {"top_1": 1 / 3.0, "top_2": 2 / 3.0}

self.meter_update_and_reset_test(meter, model_output, target, expected_value)

def test_single_meter_update_and_reset_multilabel(self):
"""
This test verifies that the meter works as expected on a single
update + reset + same single update with multilabel target.
"""
meter = AccuracyMeter(topk=[1, 2])

# Batchsize = 7, num classes = 3, score is a value in {1, 2,
# 3}...3 is the highest score
model_output = torch.tensor(
[
[3, 2, 1],
[3, 1, 2],
[1, 3, 2],
[1, 2, 3],
[2, 1, 3],
[2, 3, 1],
[1, 3, 2],
]
)

target = torch.tensor(
[
[1, 1, 0],
[0, 0, 1],
[1, 0, 0],
[0, 0, 1],
[0, 1, 1],
[1, 1, 1],
[1, 0, 1],
]
)

# 1st, 4th, 5th, 6th sample has top class correct, 2nd and 7th have at least
# one correct class in top 2.
expected_value = {"top_1": 4 / 7.0, "top_2": 6 / 7.0}

self.meter_update_and_reset_test(meter, model_output, target, expected_value)

def test_meter_invalid_model_output(self):
meter = AccuracyMeter(topk=[1, 2])
# This model output has 3 dimensions instead of expected 2
Expand All @@ -69,8 +128,8 @@ def test_meter_invalid_model_output(self):
def test_meter_invalid_target(self):
meter = AccuracyMeter(topk=[1, 2])
model_output = torch.tensor([[3, 2, 1], [3, 1, 2], [1, 3, 2]])
# Target has 2 dimensions instead of expected 1
target = torch.tensor([[0, 1, 2], [0, 1, 2]])
# Target has 3 dimensions instead of expected 1 or 2
target = torch.tensor([[[0, 1, 2], [0, 1, 2]]])

self.meter_invalid_meter_input_test(meter, model_output, target)

Expand Down
4 changes: 2 additions & 2 deletions test/meters_precision_meter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def test_non_onehot_target(self):
This test verifies that the meter works as expected on a single
update + reset + same single update.
"""
meter = PrecisionAtKMeter(topk=[1, 2], target_is_one_hot=False, num_classes=3)
meter = PrecisionAtKMeter(topk=[1, 2])

# Batchsize = 2, num classes = 3, score is probability of class
model_outputs = [
Expand All @@ -195,7 +195,7 @@ def test_non_onehot_target_one_dim_target(self):
This test verifies that the meter works as expected on a single
update + reset + same single update with one dimensional targets.
"""
meter = PrecisionAtKMeter(topk=[1, 2], target_is_one_hot=False, num_classes=3)
meter = PrecisionAtKMeter(topk=[1, 2])

# Batchsize = 2, num classes = 3, score is probability of class
model_outputs = [
Expand Down
4 changes: 2 additions & 2 deletions test/meters_recall_meter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_non_onehot_target(self):
This test verifies that the meter works as expected on a single
update + reset + same single update.
"""
meter = RecallAtKMeter(topk=[1, 2], target_is_one_hot=False, num_classes=3)
meter = RecallAtKMeter(topk=[1, 2])

# Batchsize = 2, num classes = 3, score is probability of class
model_outputs = [
Expand All @@ -191,7 +191,7 @@ def test_non_onehot_target(self):

self.meter_update_and_reset_test(meter, model_outputs, targets, expected_value)

def test_non_onehot_target(self):
def test_non_onehot_target_one_dim_target(self):
"""
This test verifies that the meter works as expected on a single
update + reset + same single update with one dimensional targets.
Expand Down
4 changes: 2 additions & 2 deletions test/meters_video_accuracy_meter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def test_meter_invalid_target(self):
[[3, 2, 1], [3, 1, 2], [1, 2, 2], [1, 2, 3], [2, 2, 2], [1, 3, 2]],
dtype=torch.float,
)
# Target has 2 dimensions instead of expected 1
target = torch.tensor([[0, 1, 2], [0, 1, 2]])
# Target has 3 dimensions instead of expected 1 or 2
target = torch.tensor([[[0, 1, 2], [0, 1, 2]]])

self.meter_invalid_meter_input_test(meter, model_output, target)
# Target of clips from the same video is not consistent
Expand Down

0 comments on commit 1aeeeeb

Please sign in to comment.