From 01356eb8d4a21f1ec2aa42c423916726d345d2f0 Mon Sep 17 00:00:00 2001 From: libn Date: Wed, 24 May 2023 15:58:55 +0800 Subject: [PATCH] fix bugs in hue_loss --- diffusion/SRNdataset.py | 7 ++++--- diffusion/train.py | 8 +++++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/diffusion/SRNdataset.py b/diffusion/SRNdataset.py index 30928f7..0ed21c5 100644 --- a/diffusion/SRNdataset.py +++ b/diffusion/SRNdataset.py @@ -41,7 +41,8 @@ def __iter__(self): class SRNDataset(Dataset): - def __init__(self, split, path='./data/SRN/cars_train', pickle_file='./data/cars.pickle', imgsize=128): + def __init__(self, split, path='./data/SRN/cars_train', pickle_file='./data/cars.pickle', imgsize=128, + use_hue_loss=False): self.imgsize = imgsize self.path = path super().__init__() @@ -51,6 +52,7 @@ def __init__(self, split, path='./data/SRN/cars_train', pickle_file='./data/cars random.seed(0) random.shuffle(all_the_vid) + self.use_hue_loss = use_hue_loss self.split = split if split == 'train': self.ids = all_the_vid[:int(len(all_the_vid) * 0.9)] @@ -87,7 +89,7 @@ def __getitem__(self, idx): imgs = np.stack(imgs, 0) hue_delta = 0 - if self.split == 'train': + if self.split == 'train' and self.use_hue_loss: hue_delta = random.random() - 0.5 adjust_img = torchvision.transforms.functional.adjust_hue(torch.Tensor(imgs[1]), hue_delta) imgs[1] = adjust_img.numpy() @@ -97,7 +99,6 @@ def __getitem__(self, idx): return imgs, R, T, K, hue_delta - # if __name__ == "__main__": # # from torch.utils.data import DataLoader diff --git a/diffusion/train.py b/diffusion/train.py index 48fad31..96eb03c 100644 --- a/diffusion/train.py +++ b/diffusion/train.py @@ -19,8 +19,10 @@ def main(args): - d = SRNDataset('train', path=args.data_path, pickle_file=args.pickle_path, imgsize=args.image_size) - d_val = SRNDataset('val', path=args.data_path, pickle_file=args.pickle_path, imgsize=args.image_size) + d = SRNDataset('train', path=args.data_path, pickle_file=args.pickle_path, imgsize=args.image_size, + use_hue_loss=args.use_hue_decoder) + d_val = SRNDataset('val', path=args.data_path, pickle_file=args.pickle_path, imgsize=args.image_size, + use_hue_loss=False) loader = MultiEpochsDataLoader(d, batch_size=args.batch_size, shuffle=True, drop_last=True, @@ -74,7 +76,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args): optimizer.zero_grad() - logsnr = utils.logsnr_schedule_cosine(torch.rand((B, ))) + logsnr = utils.logsnr_schedule_cosine(torch.rand((B,))) loss = utils.p_losses(model, img=img, R=R, T=T, K=K, logsnr=logsnr, hue_delta=hue_delta, loss_type="l2", cond_prob=0.1, use_hue_loss=args.use_hue_decoder)