Skip to content

Commit

Permalink
Add data augmentation
Browse files Browse the repository at this point in the history
selen_erkan committed Apr 23, 2022
1 parent 91158d0 commit 65475ab
Showing 2 changed files with 16 additions and 5 deletions.
19 changes: 15 additions & 4 deletions src/dataset.py
Original file line number Diff line number Diff line change
@@ -8,16 +8,27 @@


class XrayDataset(Dataset):
def __init__(self, root, csv_path, transform=None):
def __init__(self, root, csv_path, data_augmentation=False):
super(Dataset, self).__init__()

if transform is None:
if data_augmentation:
self.transform = transforms.Compose([
transforms.PILToTensor(),
transforms.Resize(224),
transforms.RandomApply(transforms=[transforms.RandomHorizontalFlip(p=1)], p=0.5),
transforms.RandomApply(transforms=[transforms.RandomAffine(degrees=10, shear=0)], p=0.3),
transforms.RandomApply(transforms=[
transforms.ColorJitter(brightness=(0.9, 1), contrast=(0.3), saturation=(0.5, 1), hue=(-0.1, 0.1))],
p=0.3),
transforms.RandomApply(transforms=[transforms.RandomEqualize()], p=0.5),
transforms.RandomApply(transforms=[transforms.RandomPerspective(distortion_scale=0.1, p=1)], p=0.1),
transforms.RandomApply(transforms=[transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))], p=0.1),
])
else:
self.transform = transform
self.transform = transforms.Compose([
transforms.PILToTensor(),
transforms.Resize(224),
])

self.root = root
csv_df = pd.read_csv(csv_path, sep=';')
@@ -35,7 +46,7 @@ def normalize(self, img):
img = img[0:3, :, :]
if img.shape != torch.Size([3, 224, 224]):
img = img.unsqueeze(0)
assert(img.shape == torch.Size([3, 224, 224])), img.shape
assert (img.shape == torch.Size([3, 224, 224])), img.shape
if len(img.shape) < 2:
print("error, dimension lower than 2 for image")

2 changes: 1 addition & 1 deletion src/main.py
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@

# =============================================================================
# Dataset and dataloader
dataset = XrayDataset(PATH, CSV_FILE_PATH)
dataset = XrayDataset(PATH, CSV_FILE_PATH, data_augmentation=True)
dataloaders = get_dataloader(dataset, batch_size=BATCH_SIZE)

# Model

0 comments on commit 65475ab

Please sign in to comment.