Skip to content

Commit

Permalink
Update cutmixup_collator.py (facebookresearch#196)
Browse files Browse the repository at this point in the history
Summary:
Fixed collator compatibility issues w/ MoCo and SimCLR

Pull Request resolved: facebookresearch#196

Reviewed By: blefaudeux

Differential Revision: D26651453

Pulled By: growlix

fbshipit-source-id: 120d14005236ca9e7ced948fdfd2a3a0ca5770bb
  • Loading branch information
growlix authored and facebook-github-bot committed Feb 25, 2021
1 parent 4cffc1c commit a7a2828
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ config:
TRANSFORMS:
- name: RandomResizedCrop
size: 224
- name: VisslRandAugment
- name: RandAugment
magnitude: 5
weight_choice: 0
- name: ToTensor
Expand All @@ -35,7 +35,6 @@ config:
}
TEST:
DATA_SOURCES: [disk_folder]
# DATA_PATHS: ["<path to test folder>"]
LABEL_SOURCES: [disk_folder]
DATASET_NAMES: [imagenet1k_debug_folder]
BATCHSIZE_PER_REPLICA: 32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ config:
- name: RandomResizedCrop
size: 224
- name: RandomHorizontalFlip
- name: VisslRandAugment
- name: RandAugment
magnitude: 9
magnitude_std: 0.5
increasing_severity: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ config:
LABEL_SOURCES: [disk_folder]
DATASET_NAMES: [imagenet1k_folder]
BATCHSIZE_PER_REPLICA: 32
DROP_LAST: True
TRANSFORMS:
- name: RandomResizedCrop
size: 224
- name: RandomHorizontalFlip
- name: VisslRandAugment
- name: RandAugment
magnitude: 9
magnitude_std: 0.5
increasing_severity: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ config:
HOOKS:
TENSORBOARD_SETUP:
USE_TENSORBOARD: True
EXPERIMENT_LOG_DIR: "/checkpoint/ito/vision_transformer/1gpu_test"
FLUSH_EVERY_N_MIN: 20
CHECKPOINT:
DIR: "."
Expand All @@ -13,19 +12,18 @@ config:
NUM_DATALOADER_WORKERS: 8
TRAIN:
DATA_SOURCES: [disk_folder]
# DATA_PATHS: ["<path to train folder>"]
LABEL_SOURCES: [disk_folder]
DATASET_NAMES: [imagenet1k_debug_folder]
DATASET_NAMES: [imagenet1k_folder]
LABEL_TYPE: "zero"
BATCHSIZE_PER_REPLICA: 16
DROP_LAST: True
TRANSFORMS:
- name: ImgPilToMultiCrop
total_num_crops: 2
size_crops: [224]
num_crops: [2]
crop_scales: [[0.14, 1]]
- name: RandomHorizontalFlip
- name: VisslRandAugment
- name: RandAugment
magnitude: 9
magnitude_std: 0.5
increasing_severity: True
Expand All @@ -42,14 +40,15 @@ config:
std: [0.229, 0.224, 0.225]
COLLATE_FUNCTION: cutmixup_collator
COLLATE_FUNCTION_PARAMS: {
"ssl_method": "swav",
"mixup_alpha": 1.0, # mixup alpha value, mixup is active if > 0.
"cutmix_alpha": 1.0, # cutmix alpha value, cutmix is active if > 0.
"prob": 1.0, # probability of applying mixup or cutmix per batch or element
"switch_prob": 0.5, # probability of switching to cutmix instead of mixup when both are active
"mode": "batch", # how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
"correct_lam": True, # apply lambda correction when cutmix bbox clipped by image borders
"label_smoothing": 0.1, # apply label smoothing to the mixed target tensor
"num_classes": 1000 # number of classes for target
"num_classes": 1 # number of classes for target
}
MODEL:
TRUNK:
Expand Down
2 changes: 1 addition & 1 deletion vissl/config/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ config:
# no valid values other than 0 or None), unclear if beneficial. Default =
# None.
# TRANSFORMS:
# - name: VisslRandAugment
# - name: RandAugment
# magnitude: 9
# magnitude_std: 0.5
# num_layers: 2
Expand Down
8 changes: 4 additions & 4 deletions vissl/data/collators/cutmixup_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ def data_back_to_input_form(data, labels, data_valid, data_idx):
for sample_i in range(num_images):
sample_input_form = {"data": [], "data_valid": [], "data_idx": [], "label": []}
for duplicate_i in range(num_duplicates):
sample_input_form["data"].append(data[duplicate_i][sample_i])
sample_input_form["label"].append(labels[duplicate_i][sample_i])
valid_and_idx_i = sample_i + (num_duplicates * duplicate_i)
sample_input_form["data_idx"].append(data_idx[valid_and_idx_i])
sample_input_form["data_valid"].append(data_valid[valid_and_idx_i])
sample_input_form["data"].append(data[duplicate_i][sample_i])
sample_input_form["label"].append(labels[duplicate_i][sample_i].tolist())
sample_input_form["data_idx"].append(data_idx[valid_and_idx_i].item())
sample_input_form["data_valid"].append(data_valid[valid_and_idx_i].item())
data_input_form.append(sample_input_form)
return data_input_form

Expand Down
2 changes: 1 addition & 1 deletion vissl/data/ssl_transforms/rand_auto_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@


# Modification/Addition
@register_transform("VisslRandAugment")
@register_transform("RandAugment")
class RandAugment(ClassyTransform):
"""
Create a RandAugment transform.
Expand Down

0 comments on commit a7a2828

Please sign in to comment.