Skip to content

Commit

Permalink
mixup data augmentation (facebookresearch#469)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#469

This diff implements the mixup data augmentation in the paper `mixup: Beyond Empirical Risk Minimization` (https://arxiv.org/abs/1710.09412)

Differential Revision: D20911088

fbshipit-source-id: 479ab9c74b9d326fec0a5f14a6ff82b11a198d54
  • Loading branch information
stephenyan1231 authored and facebook-github-bot committed Apr 8, 2020
1 parent 6214d10 commit 6c66c2d
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 5 deletions.
32 changes: 32 additions & 0 deletions classy_vision/dataset/transforms/mixup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from classy_vision.generic.util import convert_to_one_hot
from torch.distributions.beta import Beta


def mixup_transform(sample, num_classes, alpha):
"""
This implements the mixup data augmentation in the paper
"mixup: Beyond Empirical Risk Minimization" (https://arxiv.org/abs/1710.09412)
Args:
sample (Dict[str, Any]): the batch data
alpha (float): the hyperparameter of Beta distribution used to sample mixup
coefficient.
"""
assert (
sample["target"].ndim == 1
), "Currently mixup only supports single-label classification"
sample["target"] = convert_to_one_hot(sample["target"].view(-1, 1), num_classes)

c = Beta(torch.tensor([alpha]), torch.tensor([alpha])).sample()

for key in ["input", "target"]:
sample[key] = c * sample[key] + (1.0 - c) * sample[key].flip([0])

return sample
9 changes: 4 additions & 5 deletions classy_vision/generic/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,11 +736,10 @@ def maybe_convert_to_one_hot(target, model_output):
):
target = convert_to_one_hot(target.view(-1, 1), model_output.shape[1])

assert (target.shape == model_output.shape) and (
torch.min(target.eq(0) + target.eq(1)) == 1
), (
"Target must be one-hot/multi-label encoded and of the "
"same shape as model_output."
# target are not necessarily hard 0/1 encoding. It can be soft
# (i.e. fractional) in some cases, such as mixup label
assert target.shape == model_output.shape, (
"Target must of the " "same shape as model_output."
)

return target
Expand Down
28 changes: 28 additions & 0 deletions classy_vision/tasks/classification_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
import torch
import torch.nn as nn
from classy_vision.dataset import ClassyDataset, build_dataset
from classy_vision.dataset.transforms.mixup import mixup_transform
from classy_vision.generic.distributed_util import (
all_reduce_mean,
barrier,
init_distributed_data_parallel_model,
is_distributed_training_run,
)
from classy_vision.generic.util import (
convert_to_one_hot,
copy_model_to_gpu,
recursive_copy_to_gpu,
update_classy_state,
Expand Down Expand Up @@ -139,6 +141,7 @@ def __init__(self):
BroadcastBuffersMode.DISABLED
)
self.amp_args = None
self.mixup_args = None
self.perf_log = []
self.last_batch = None
self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED
Expand Down Expand Up @@ -306,6 +309,20 @@ def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
logging.info(f"AMP enabled with args {amp_args}")
return self

def set_mixup_args(self, mixup_args: Optional[Dict[str, Any]]):
"""Disable / enable mixup data augmentation
Args::
mixup_args: expect to include the follow keys in the dictionary
num_classes (int): number of dataset classes
alpha (float): the hyperparameter of Beta distribution used to
sample mixup coefficient.
"""
self.mixup_args = mixup_args
if mixup_args is None:
logging.info(f"mixup disabled")
return self

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
"""Instantiates a ClassificationTask from a configuration.
Expand Down Expand Up @@ -348,6 +365,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
.set_optimizer(optimizer)
.set_meters(meters)
.set_amp_args(amp_args)
.set_mixup_args(config.get("mixup"))
.set_distributed_options(
broadcast_buffers_mode=BroadcastBuffersMode[
config.get("broadcast_buffers", "disabled").upper()
Expand Down Expand Up @@ -697,6 +715,11 @@ def eval_step(self, use_gpu):
+ "'target' keys"
)

if self.mixup_args is not None:
sample["target"] = convert_to_one_hot(
sample["target"].view(-1, 1), self.mixup_args["num_classes"]
)

# Copy sample to GPU
target = sample["target"]
if use_gpu:
Expand Down Expand Up @@ -743,6 +766,11 @@ def train_step(self, use_gpu):
+ "'target' keys"
)

if self.mixup_args is not None:
sample = mixup_transform(
sample, self.mixup_args["num_classes"], self.mixup_args["alpha"]
)

# Copy sample to GPU
target = sample["target"]
if use_gpu:
Expand Down

0 comments on commit 6c66c2d

Please sign in to comment.