forked from facebookresearch/ClassyVision
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
mixup data augmentation (facebookresearch#469)
Summary: Pull Request resolved: facebookresearch#469 This diff implements the mixup data augmentation in the paper `mixup: Beyond Empirical Risk Minimization` (https://arxiv.org/abs/1710.09412) Differential Revision: D20911088 fbshipit-source-id: 479ab9c74b9d326fec0a5f14a6ff82b11a198d54
- Loading branch information
1 parent
6214d10
commit 6c66c2d
Showing
3 changed files
with
64 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from classy_vision.generic.util import convert_to_one_hot | ||
from torch.distributions.beta import Beta | ||
|
||
|
||
def mixup_transform(sample, num_classes, alpha): | ||
""" | ||
This implements the mixup data augmentation in the paper | ||
"mixup: Beyond Empirical Risk Minimization" (https://arxiv.org/abs/1710.09412) | ||
Args: | ||
sample (Dict[str, Any]): the batch data | ||
alpha (float): the hyperparameter of Beta distribution used to sample mixup | ||
coefficient. | ||
""" | ||
assert ( | ||
sample["target"].ndim == 1 | ||
), "Currently mixup only supports single-label classification" | ||
sample["target"] = convert_to_one_hot(sample["target"].view(-1, 1), num_classes) | ||
|
||
c = Beta(torch.tensor([alpha]), torch.tensor([alpha])).sample() | ||
|
||
for key in ["input", "target"]: | ||
sample[key] = c * sample[key] + (1.0 - c) * sample[key].flip([0]) | ||
|
||
return sample |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters