Skip to content

Commit

Permalink
update vit example for new API (hpcaitech#98) (hpcaitech#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
ver217 authored Jan 4, 2022
1 parent d09a79b commit f03bcb3
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ def __next__(self):
img = lam * img + (1 - lam) * img[idx, :]
label_a, label_b = label, label[idx]
lam = torch.tensor([lam], device=img.device, dtype=img.dtype)
label = (label_a, label_b, lam)
label = {'targets_a': label_a, 'targets_b': label_b, 'lam': lam}
else:
label = (label, label, torch.ones(
1, device=img.device, dtype=img.dtype))
return (img,), label
return (img,), (label,)
label = {'targets_a': label, 'targets_b': label,
'lam': torch.ones(1, device=img.device, dtype=img.dtype)}
return img, label
return img, label
13 changes: 11 additions & 2 deletions examples/vit_b16_imagenet_data_parallel/mixup.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import torch.nn as nn
from colossalai.registry import LOSSES
import torch


@LOSSES.register_module
class MixupLoss(nn.Module):
def __init__(self, loss_fn_cls):
super().__init__()
self.loss_fn = loss_fn_cls()

def forward(self, inputs, *args):
targets_a, targets_b, lam = args
def forward(self, inputs, targets_a, targets_b, lam):
return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b)


class MixupAccuracy(nn.Module):
def forward(self, logits, targets):
targets = targets['targets_a']
preds = torch.argmax(logits, dim=-1)
correct = torch.sum(targets == preds)
return correct
6 changes: 3 additions & 3 deletions examples/vit_b16_imagenet_data_parallel/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from colossalai.trainer import Trainer, hooks
from colossalai.nn.lr_scheduler import LinearWarmupLR
from dataloader.imagenet_dali_dataloader import DaliDataloader
from mixup import MixupLoss
from mixup import MixupLoss, MixupAccuracy
from timm.models import vit_base_patch16_224
from myhooks import TotalBatchsizeHook

Expand Down Expand Up @@ -62,7 +62,7 @@ def main():
port=args.port,
backend=args.backend
)
# launch from torch
# launch from torch
# colossalai.launch_from_torch(config=args.config)

# get logger
Expand Down Expand Up @@ -96,7 +96,7 @@ def main():
# build hooks
hook_list = [
hooks.LossHook(),
hooks.AccuracyHook(accuracy_func=Accuracy()),
hooks.AccuracyHook(accuracy_func=MixupAccuracy()),
hooks.LogMetricByEpochHook(logger),
hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
TotalBatchsizeHook(),
Expand Down

0 comments on commit f03bcb3

Please sign in to comment.