Skip to content

Commit

Permalink
update the single and aligned/paired image datasets, add the unaligned/unpaired images dataset, additional functions and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
victorca25 committed May 26, 2021
1 parent a5861eb commit 8131225
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 98 deletions.
8 changes: 4 additions & 4 deletions codes/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ def create_dataset(dataset_opt: dict) -> torch.utils.data.Dataset:
:param dataset_opt: Dataset configuration from opt file
"""
mode = dataset_opt['mode']
if mode == 'LR':
from data.LR_dataset import LRDataset as D
if mode == 'single' or mode == 'LR':
from data.single_dataset import SingleDataset as D
elif mode in ['aligned', 'LRHR', 'LRHROTF', 'LRHRC']:
from data.aligned_dataset import AlignedDataset as D
elif mode == 'unaligned':
from data.unaligned_dataset import UnalignedDataset as D
elif mode == 'LRHRseg_bg':
from data.LRHR_seg_bg_dataset import LRHRSeg_BG_Dataset as D
elif mode == 'VLRHR':
Expand All @@ -51,8 +53,6 @@ def create_dataset(dataset_opt: dict) -> torch.utils.data.Dataset:
from data.DVD_dataset import DVDDataset as D
elif mode == 'DVDI':
from data.DVD_dataset import DVDIDataset as D
elif mode == 'unaligned':
from data.unaligned_dataset import UnalignedDataset as D
else:
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
dataset = D(dataset_opt)
Expand Down
119 changes: 59 additions & 60 deletions codes/data/aligned_dataset.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from dataops.common import _init_lmdb, channel_convert

# from dataops.debug import tmp_vis, describe_numpy, describe_tensor

from data.base_dataset import BaseDataset, get_dataroots_paths, read_imgs_from_path, read_single_dataset, read_split_single_dataset

from data.base_dataset import BaseDataset, get_dataroots_paths, read_imgs_from_path, get_single_dataroot_path, read_split_single_dataset
from dataops.augmentations import (generate_A_fn, image_type, get_default_imethod, dim_change_fn,
shape_change_fn, random_downscale_B, paired_imgs_check,
get_unpaired_params, get_augmentations, get_totensor_params, get_totensor,
set_transforms, get_ds_kernels, get_noise_patches,
get_params, image_size, image_channels, scale_params, scale_opt, get_transform,
Scale, modcrop)
# from dataops.debug import tmp_vis, describe_numpy, describe_tensor


class AlignedDataset(BaseDataset):
Expand All @@ -25,10 +22,9 @@ class AlignedDataset(BaseDataset):
def __init__(self, opt):
"""Initialize this dataset class.
Parameters:
opt (Option dictionary) -- stores all the experiment flags
opt (Option dictionary): stores all the experiment flags
"""
super(AlignedDataset, self).__init__(opt, keys_ds=['LR','HR'])
self.LR_env, self.HR_env = None, None # environment for lmdb
self.vars = self.opt.get('outputs', 'LRHR') #'AB'
self.ds_kernels = get_ds_kernels(self.opt)
self.noise_patches = get_noise_patches(self.opt)
Expand All @@ -37,22 +33,26 @@ def __init__(self, opt):
# get images paths (and optional environments for lmdb) from dataroots
dir_AB = self.opt.get('dataroot', None) or self.opt.get('dataroot_AB', None)
if dir_AB:
self.AB_env = None
self.AB_paths = read_single_dataset(self.opt, dir_AB)
self.AB_env = None # environment for lmdb
self.AB_paths = get_single_dataroot_path(self.opt, dir_AB)
if self.opt.get('data_type') == 'lmdb':
self.AB_env = _init_lmdb(dir_AB)
else:
self.paths_LR, self.paths_HR = get_dataroots_paths(self.opt, strict=False, keys_ds=self.keys_ds)
self.A_paths, self.B_paths = get_dataroots_paths(self.opt, strict=False, keys_ds=self.keys_ds)
self.AB_paths = None
self.A_env, self.B_env = None, None # environment for lmdb

if self.opt.get('data_type') == 'lmdb':
self.LR_env = _init_lmdb(self.opt.get('dataroot_'+self.keys_ds[0]))
self.HR_env = _init_lmdb(self.opt.get('dataroot_'+self.keys_ds[1]))
self.A_env = _init_lmdb(self.opt.get(f'dataroot_{self.keys_ds[0]}'))
self.B_env = _init_lmdb(self.opt.get(f'dataroot_{self.keys_ds[1]}'))

# get reusable totensor params
self.totensor_params = get_totensor_params(self.opt)

def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index: a random integer for data indexing
index (int): a random integer for data indexing
Returns a dictionary that contains A, B, A_paths and B_paths
(or LR, HR, LR_paths and HR_paths)
A (tensor): an image in the input domain
Expand All @@ -65,58 +65,58 @@ def __getitem__(self, index):

######## Read the images ########
if self.AB_paths:
img_LR, img_HR, LR_path, HR_path = read_split_single_dataset(
img_A, img_B, A_path, B_path = read_split_single_dataset(
self.opt, index, self.AB_paths, self.AB_env)
else:
img_LR, img_HR, LR_path, HR_path = read_imgs_from_path(
self.opt, index, self.paths_LR, self.paths_HR, self.LR_env, self.HR_env)
img_A, img_B, A_path, B_path = read_imgs_from_path(
self.opt, index, self.A_paths, self.B_paths, self.A_env, self.B_env)

######## Modify the images ########

# for the validation / test phases
if self.opt['phase'] != 'train':
img_type = image_type(img_HR)
# HR modcrop
img_HR = modcrop(img_HR, scale=scale, img_type=img_type)
# modcrop and downscale LR if enabled
img_type = image_type(img_B)
# B/HR modcrop
img_B = modcrop(img_B, scale=scale, img_type=img_type)
# modcrop and downscale A/LR if enabled
if self.opt['lr_downscale']:
img_LR = modcrop(img_LR, scale=scale, img_type=img_type)
img_A = modcrop(img_A, scale=scale, img_type=img_type)
# TODO: 'pil' images will use default method for scaling
img_LR, _ = Scale(img_LR, scale,
img_A, _ = Scale(img_A, scale,
algo=self.opt.get('lr_downscale_types', 777), img_type=img_type)

# change color space if necessary
# TODO: move to get_transform()
color_HR = self.opt.get('color', None) or self.opt.get('color_HR', None)
if color_HR:
img_HR = channel_convert(image_channels(img_HR), color_HR, [img_HR])[0]
color_LR = self.opt.get('color', None) or self.opt.get('color_LR', None)
if color_LR:
img_LR = channel_convert(image_channels(img_LR), color_LR, [img_LR])[0]
color_B = self.opt.get('color', None) or self.opt.get('color_HR', None)
if color_B:
img_B = channel_convert(image_channels(img_B), color_B, [img_B])[0]
color_A = self.opt.get('color', None) or self.opt.get('color_LR', None)
if color_A:
img_A = channel_convert(image_channels(img_A), color_A, [img_A])[0]

######## Augmentations ########

#Augmentations during training
if self.opt['phase'] == 'train':

default_int_method = get_default_imethod(image_type(img_LR))
default_int_method = get_default_imethod(image_type(img_A))

# random HR downscale
img_LR, img_HR = random_downscale_B(img_A=img_LR, img_B=img_HR,
img_A, img_B = random_downscale_B(img_A=img_A, img_B=img_B,
opt=self.opt)

# validate there's an img_LR, if not, use img_HR
if img_LR is None:
img_LR = img_HR
print("Image LR: ", LR_path, ("was not loaded correctly, using HR pair to downscale on the fly."))
# validate there's an img_A, if not, use img_B
if img_A is None:
img_A = img_B
print(f"Image A: {A_path} was not loaded correctly, using B pair to downscale on the fly.")

# validate proper dimensions between paired images, generate A if needed
img_LR, img_HR = paired_imgs_check(
img_LR, img_HR, opt=self.opt, ds_kernels=self.ds_kernels)
img_A, img_B = paired_imgs_check(
img_A, img_B, opt=self.opt, ds_kernels=self.ds_kernels)

# get and apply the paired transformations below
transform_params = get_params(
scale_opt(self.opt, scale), image_size(img_LR))
scale_opt(self.opt, scale), image_size(img_A))
A_transform = get_transform(
scale_opt(self.opt, scale),
transform_params,
Expand All @@ -127,51 +127,50 @@ def __getitem__(self, index):
scale_params(transform_params, scale),
# grayscale=(output_nc == 1),
method=default_int_method)
img_LR = A_transform(img_LR)
img_HR = B_transform(img_HR)
img_A = A_transform(img_A)
img_B = B_transform(img_B)

# Below are the On The Fly augmentations

# get and apply the unpaired transformations below
lr_aug_params, hr_aug_params = get_unpaired_params(self.opt)
a_aug_params, b_aug_params = get_unpaired_params(self.opt)

lr_augmentations = get_augmentations(
a_augmentations = get_augmentations(
self.opt,
params=lr_aug_params,
params=a_aug_params,
noise_patches=self.noise_patches,
)
hr_augmentations = get_augmentations(
b_augmentations = get_augmentations(
self.opt,
params=hr_aug_params,
params=b_aug_params,
noise_patches=self.noise_patches,
)

img_LR = lr_augmentations(img_LR)
img_HR = hr_augmentations(img_HR)
img_A = a_augmentations(img_A)
img_B = b_augmentations(img_B)

# Alternative position for changing the colorspace of LR.
# color_LR = self.opt.get('color', None) or self.opt.get('color_LR', None)
# if color_LR:
# img_LR = channel_convert(image_channels(img_LR), color_LR, [img_LR])[0]
# Alternative position for changing the colorspace of A/LR.
# color_A = self.opt.get('color', None) or self.opt.get('color_A', None)
# if color_A:
# img_A = channel_convert(image_channels(img_A), color_A, [img_A])[0]

######## Convert images to PyTorch Tensors ########

totensor_params = get_totensor_params(self.opt)
tensor_transform = get_totensor(self.opt, params=totensor_params, toTensor=True, grayscale=False)
img_LR = tensor_transform(img_LR)
img_HR = tensor_transform(img_HR)
tensor_transform = get_totensor(
self.opt, params=self.totensor_params, toTensor=True, grayscale=False)
img_A = tensor_transform(img_A)
img_B = tensor_transform(img_B)

if LR_path is None:
LR_path = HR_path
if A_path is None:
A_path = B_path
if self.vars == 'AB':
return {'A': img_LR, 'B': img_HR, 'A_paths': LR_path, 'B_paths': HR_path}
return {'A': img_A, 'B': img_B, 'A_path': A_path, 'B_path': B_path}
else:
return {'LR': img_LR, 'HR': img_HR, 'LR_path': LR_path, 'HR_path': HR_path}
return {'LR': img_A, 'HR': img_B, 'LR_path': A_path, 'HR_path': B_path}

def __len__(self):
"""Return the total number of images in the dataset."""
if self.AB_paths:
return len(self.AB_paths)
else:
return len(self.paths_HR)

return len(self.B_paths)
28 changes: 24 additions & 4 deletions codes/data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ def name(self):



############# Testing below

def process_img_paths(images_paths=None, data_type='img', max_dataset_size=float("inf")):
if not images_paths:
return images_paths
Expand All @@ -77,9 +75,10 @@ def format_paths(dataroot=None):
Check if dataroot is a list of directories or a single directory.
Note: lmdb will not currently work with a list
"""
# if receiving a single path in str format, convert to list

if dataroot:
if type(dataroot) is str:
# if receiving a single path in str format, convert to list
dataroot = os.path.join(dataroot)
dataroot = [dataroot]
else:
Expand Down Expand Up @@ -108,6 +107,9 @@ def paired_dataset_validation(A_images_paths, B_images_paths, data_type='img', m


def check_data_keys(opt, keys_ds=['LR', 'HR']):
if len(keys_ds) < 2:
return opt

keys_A = ['LR', 'A', 'lq']
keys_B = ['HR', 'B', 'gt']

Expand Down Expand Up @@ -307,7 +309,7 @@ def read_imgs_from_path(opt, index, paths_A, paths_B, A_env, B_env):
return img_A, img_B, A_path, B_path


def read_single_dataset(opt=None, dataroot=None, max_dataset_size=float("inf")):
def get_single_dataroot_path(opt=None, dataroot=None, max_dataset_size=float("inf")):
images_paths = format_paths(dataroot)
img_paths = process_img_paths(images_paths, opt['data_type'], max_dataset_size)

Expand Down Expand Up @@ -336,3 +338,21 @@ def read_split_single_dataset(opt, index, AB_paths, env):
return img_A, img_B, A_path, B_path


def read_single_dataset(opt, index, paths, env, idx_case=None, d_size=None):
    """Read a single image and its path from a list of image paths.

    Parameters:
        opt (Option dictionary): 'img_loader' selects the loader that is
            forwarded to read_img (defaults to 'cv2' when not set)
        index (int): the requested index into paths
        paths (list): image paths to select from
        env: forwarded unchanged to read_img (presumably the lmdb
            environment, or None for regular image files -- confirm
            against read_img)
        idx_case (str): how index is remapped before the lookup:
            'inrange' or 'serial' wrap it with a modulo by d_size,
            'random' replaces it with a random index in [0, d_size - 1],
            any other value (including None) uses index as given
        d_size (int): dataset size used for the remapping; when omitted
            it falls back to len(paths)
    Returns the image produced by read_img and the path it was read from
    """
    loader = opt.get('img_loader', 'cv2')

    if idx_case in ('inrange', 'serial', 'random'):
        # d_size is optional in the signature, but the remapping below
        # needs a number: default to the size of the paths list instead
        # of failing with a TypeError on None.
        if d_size is None:
            d_size = len(paths)

        if idx_case == 'random':
            # randomize the index for domain B in unpaired dataset to
            # avoid fixed pairs.
            index = random.randint(0, d_size - 1)
        else:
            # make sure index is within the range of dataset.
            # if used for both unpaired datasets, will always return
            # the same pair of images
            index = index % d_size

    path = paths[index]
    img = read_img(env=env, path=path, loader=loader)

    return img, path

Loading

0 comments on commit 8131225

Please sign in to comment.