diff --git a/diffusion/SRNdataset.py b/diffusion/SRNdataset.py
index 30928f7..0ed21c5 100644
--- a/diffusion/SRNdataset.py
+++ b/diffusion/SRNdataset.py
@@ -41,7 +41,8 @@ def __iter__(self):
class SRNDataset(Dataset):
- def __init__(self, split, path='./data/SRN/cars_train', pickle_file='./data/cars.pickle', imgsize=128):
+ def __init__(self, split, path='./data/SRN/cars_train', pickle_file='./data/cars.pickle', imgsize=128,
+ use_hue_loss=False):
self.imgsize = imgsize
self.path = path
super().__init__()
@@ -51,6 +52,7 @@ def __init__(self, split, path='./data/SRN/cars_train', pickle_file='./data/cars
random.seed(0)
random.shuffle(all_the_vid)
+ self.use_hue_loss = use_hue_loss
self.split = split
if split == 'train':
self.ids = all_the_vid[:int(len(all_the_vid) * 0.9)]
@@ -87,7 +89,7 @@ def __getitem__(self, idx):
imgs = np.stack(imgs, 0)
hue_delta = 0
- if self.split == 'train':
+ if self.split == 'train' and self.use_hue_loss:
hue_delta = random.random() - 0.5
adjust_img = torchvision.transforms.functional.adjust_hue(torch.Tensor(imgs[1]), hue_delta)
imgs[1] = adjust_img.numpy()
@@ -97,7 +99,6 @@ def __getitem__(self, idx):
return imgs, R, T, K, hue_delta
-
# if __name__ == "__main__":
#
# from torch.utils.data import DataLoader
diff --git a/diffusion/train.py b/diffusion/train.py
index 48fad31..96eb03c 100644
--- a/diffusion/train.py
+++ b/diffusion/train.py
@@ -19,8 +19,10 @@
def main(args):
- d = SRNDataset('train', path=args.data_path, pickle_file=args.pickle_path, imgsize=args.image_size)
- d_val = SRNDataset('val', path=args.data_path, pickle_file=args.pickle_path, imgsize=args.image_size)
+ d = SRNDataset('train', path=args.data_path, pickle_file=args.pickle_path, imgsize=args.image_size,
+ use_hue_loss=args.use_hue_decoder)
+ d_val = SRNDataset('val', path=args.data_path, pickle_file=args.pickle_path, imgsize=args.image_size,
+ use_hue_loss=False)
loader = MultiEpochsDataLoader(d, batch_size=args.batch_size,
shuffle=True, drop_last=True,
@@ -74,7 +76,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):
optimizer.zero_grad()
- logsnr = utils.logsnr_schedule_cosine(torch.rand((B, )))
+ logsnr = utils.logsnr_schedule_cosine(torch.rand((B,)))
loss = utils.p_losses(model, img=img, R=R, T=T, K=K, logsnr=logsnr, hue_delta=hue_delta,
loss_type="l2", cond_prob=0.1, use_hue_loss=args.use_hue_decoder)