Skip to content

Commit

Permalink
fix bugs in hue_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
cuttle-fish-my committed May 24, 2023
1 parent 1e80eed commit 01356eb
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
7 changes: 4 additions & 3 deletions diffusion/SRNdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def __iter__(self):

class SRNDataset(Dataset):

def __init__(self, split, path='./data/SRN/cars_train', pickle_file='./data/cars.pickle', imgsize=128):
def __init__(self, split, path='./data/SRN/cars_train', pickle_file='./data/cars.pickle', imgsize=128,
use_hue_loss=False):
self.imgsize = imgsize
self.path = path
super().__init__()
Expand All @@ -51,6 +52,7 @@ def __init__(self, split, path='./data/SRN/cars_train', pickle_file='./data/cars

random.seed(0)
random.shuffle(all_the_vid)
self.use_hue_loss = use_hue_loss
self.split = split
if split == 'train':
self.ids = all_the_vid[:int(len(all_the_vid) * 0.9)]
Expand Down Expand Up @@ -87,7 +89,7 @@ def __getitem__(self, idx):

imgs = np.stack(imgs, 0)
hue_delta = 0
if self.split == 'train':
if self.split == 'train' and self.use_hue_loss:
hue_delta = random.random() - 0.5
adjust_img = torchvision.transforms.functional.adjust_hue(torch.Tensor(imgs[1]), hue_delta)
imgs[1] = adjust_img.numpy()
Expand All @@ -97,7 +99,6 @@ def __getitem__(self, idx):

return imgs, R, T, K, hue_delta


# if __name__ == "__main__":
#
# from torch.utils.data import DataLoader
Expand Down
8 changes: 5 additions & 3 deletions diffusion/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@


def main(args):
d = SRNDataset('train', path=args.data_path, pickle_file=args.pickle_path, imgsize=args.image_size)
d_val = SRNDataset('val', path=args.data_path, pickle_file=args.pickle_path, imgsize=args.image_size)
d = SRNDataset('train', path=args.data_path, pickle_file=args.pickle_path, imgsize=args.image_size,
use_hue_loss=args.use_hue_decoder)
d_val = SRNDataset('val', path=args.data_path, pickle_file=args.pickle_path, imgsize=args.image_size,
use_hue_loss=False)

loader = MultiEpochsDataLoader(d, batch_size=args.batch_size,
shuffle=True, drop_last=True,
Expand Down Expand Up @@ -74,7 +76,7 @@ def train(model, optimizer, loader, loader_val, writer, now, step, args):

optimizer.zero_grad()

logsnr = utils.logsnr_schedule_cosine(torch.rand((B, )))
logsnr = utils.logsnr_schedule_cosine(torch.rand((B,)))

loss = utils.p_losses(model, img=img, R=R, T=T, K=K, logsnr=logsnr, hue_delta=hue_delta,
loss_type="l2", cond_prob=0.1, use_hue_loss=args.use_hue_decoder)
Expand Down

0 comments on commit 01356eb

Please sign in to comment.