Skip to content

Commit

Permalink
fix bugs in hue_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
cuttle-fish-my committed May 19, 2023
1 parent 68ec3ad commit fcf031b
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions diffusion/xunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,10 @@ class XUNet(torch.nn.Module):
use_pos_emb: bool = True
use_ref_pose_emb: bool = True

def __init__(self, **kwargs):
def __init__(self, use_hue_decoder=False, **kwargs):
self.__dict__.update(kwargs)
super().__init__()

self.use_hue_decoder = use_hue_decoder
assert self.H % (2 ** (
len(self.ch_mult) - 1)) == 0, f"Size of the image must me multiple of {2 ** (len(self.ch_mult) - 1)}"
assert self.W % (2 ** (
Expand Down Expand Up @@ -426,18 +426,19 @@ def __init__(self, **kwargs):
use_attn=self.num_resolutions in self.attn_resolutions)

# hue_delta prediction
self.hue_decoder = torch.nn.Sequential(
torch.nn.Conv2d(self.dim_out[-1] * 2, self.dim_out[-1] * 4, kernel_size=3, padding=1), # 1024 -> 2048
torch.nn.AvgPool2d(kernel_size=2, stride=2), # 8x8 -> 4x4
torch.nn.ReLU(),
torch.nn.Conv2d(self.dim_out[-1] * 4, self.dim_out[-1] // 4, kernel_size=1, stride=1), # 2048 -> 128
torch.nn.ReLU(),
torch.nn.Flatten(start_dim=1), # 128x4x4 -> 2048
torch.nn.Linear(2048, 256), # 2048 -> 256
torch.nn.ReLU(),
torch.nn.Linear(256, 1), # 256 -> 1
torch.nn.Tanh()
)
if self.use_hue_decoder:
self.hue_decoder = torch.nn.Sequential(
torch.nn.Conv2d(self.dim_out[-1] * 2, self.dim_out[-1] * 4, kernel_size=3, padding=1), # 1024 -> 2048
torch.nn.AvgPool2d(kernel_size=2, stride=2), # 8x8 -> 4x4
torch.nn.ReLU(),
torch.nn.Conv2d(self.dim_out[-1] * 4, self.dim_out[-1] // 4, kernel_size=1, stride=1), # 2048 -> 128
torch.nn.ReLU(),
torch.nn.Flatten(start_dim=1), # 128x4x4 -> 2048
torch.nn.Linear(2048, 256), # 2048 -> 256
torch.nn.ReLU(),
torch.nn.Linear(256, 1), # 256 -> 1
torch.nn.Tanh()
)

# Upsampling
self.upsample = torch.nn.ModuleDict()
Expand Down Expand Up @@ -521,8 +522,8 @@ def forward(self, batch, *, cond_mask):
emb = logsnr_emb[..., None, None] + pose_embs[-1]

h = self.middle(h, emb)

hue_delta = self.hue_decoder(rearrange(h, 'b f c h w -> b (f c) h w'))
if self.use_hue_decoder:
hue_delta = self.hue_decoder(rearrange(h, 'b f c h w -> b (f c) h w'))
# hue_delta = rearrange(hue_delta, '(b f) d -> b f d', b=B, f=2)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
Expand All @@ -541,7 +542,10 @@ def forward(self, batch, *, cond_mask):
h = torch.nn.functional.silu(self.lastgn(h)) # [B, F, self.ch, 128, 128]
pred_noise = self.lastconv(rearrange(h, 'b f c h w -> (b f) c h w'))
pred_noise = rearrange(pred_noise, '(b f) c h w -> b f c h w', b=B)[:, 1]
return pred_noise, hue_delta
if self.use_hue_decoder:
return pred_noise, hue_delta
else:
return pred_noise

# if __name__ == "__main__":
# h, w = 56, 56
Expand Down

0 comments on commit fcf031b

Please sign in to comment.