Commit 01c908d: Update artifusion.py
zhenqi-he authored Nov 14, 2023 · 1 parent 7a2816a

1 changed file with 12 additions and 12 deletions:
model_training/scripts/improved_diffusion/artifusion.py
@@ -181,10 +181,10 @@ def _forward(self, x,emb):
         # print(self.emb_layers)
         nW, x_windows = window_partition(shifted_x, self.window_size)
         x_windows = x_windows.view(B, -1, self.window_size * self.window_size,C)
-        print('in swinTransformer block, emb before: ',emb.shape)
+        # print('in swinTransformer block, emb before: ',emb.shape)
         time_embed=self.emb_layers(emb).unsqueeze(1)
         time_embed=time_embed.repeat(1,nW,1,1)
-        print('in swinTransformer block, emb after: ',time_embed.shape)
+        # print('in swinTransformer block, emb after: ',time_embed.shape)
         x_windows = torch.concat((x_windows,time_embed),dim=-2)
         x_windows = x_windows.flatten(0,1)
         # x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
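For reference, the pattern this hunk instruments and then silences can be reproduced in isolation: the projected time embedding becomes one extra token per attention window, appended along the token axis before windowed attention. A minimal sketch with assumed shapes (illustrative only, not the repository's code):

    import torch

    B, nW, ws, C = 2, 4, 49, 96                # batch, windows, tokens per window, channels
    x_windows = torch.randn(B, nW, ws, C)      # windowed patch tokens
    emb = torch.randn(B, C)                    # per-sample time embedding (post-projection)

    time_embed = emb.unsqueeze(1).unsqueeze(1)               # (B, 1, 1, C)
    time_embed = time_embed.repeat(1, nW, 1, 1)              # (B, nW, 1, C): one token per window
    x_windows = torch.cat((x_windows, time_embed), dim=-2)   # (B, nW, ws + 1, C)
    x_windows = x_windows.flatten(0, 1)                      # (B * nW, ws + 1, C) for window attention
    print(x_windows.shape)                                   # torch.Size([8, 50, 96])

The two prints commented out above were checking exactly these shapes before and after the repeat.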
@@ -537,28 +537,28 @@ def inner_dtype(self):
     def forward(self, x, timesteps, y=None):


-        print("timesteps shape ",timesteps.shape)
-        print("input x shape ",x.shape)
+        # print("timesteps shape ",timesteps.shape)
+        # print("input x shape ",x.shape)
         x = self.patch_embed(x)
-        print("input x shape after patch_emb",x.shape)
+        # print("input x shape after patch_emb",x.shape)
         x = self.pos_drop(x)
-        print("input x shape after pos_drop",x.shape)
+        # print("input x shape after pos_drop",x.shape)
         assert (y is not None) == (
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"

         hs = []
         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-        print('emb shape after time_embed in forward ',emb.shape)
+        # print('emb shape after time_embed in forward ',emb.shape)
         if self.num_classes is not None:
             assert y.shape == (x.shape[0],)
             emb = emb + self.label_emb(y)

         h = x.type(self.inner_dtype)
-        print("downsample before, h shape ",h.shape)
+        # print("downsample before, h shape ",h.shape)
         for module in self.input_blocks:
-            print("emb,",emb.shape)
-            print("downsample, h shape ",h.shape)
+            # print("emb,",emb.shape)
+            # print("downsample, h shape ",h.shape)
             hs.append(h)
             h = module(h, emb)

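The timestep_embedding(timesteps, self.model_channels) call above comes from the improved-diffusion codebase this file builds on, where it is the standard sinusoidal embedding. A self-contained version for reference (reconstructed from the upstream convention, so treat the details as assumptions rather than this repository's exact code):

    import math
    import torch

    def timestep_embedding(timesteps, dim, max_period=10000):
        """Sinusoidal timestep embeddings, as in improved-diffusion."""
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:  # pad odd dims with a zero column
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    emb = timestep_embedding(torch.tensor([0, 10, 999]), 96)
    print(emb.shape)  # torch.Size([3, 96])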
@@ -569,7 +569,7 @@ def forward(self, x, timesteps, y=None):
         h = self.norm(h)
         # print("len hs ", len(hs))
         for inx, layer_up in enumerate(self.output_blocks_layers_up):
-            print("updample h shape ",h.shape)
+            # print("updample h shape ",h.shape)
             if inx == 0:
                 h = layer_up(h,emb)
             else:
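The else branch is cut off in this view. In Swin-UNet-style decoders the usual pattern is that the first up layer consumes the bottleneck alone, while later layers first concatenate the matching saved encoder feature from hs and project it back down. A hypothetical sketch of that control flow (the skip handling here is an assumption, not this file's code):

    import torch
    import torch.nn as nn

    C = 96
    fuse = nn.Linear(2 * C, C)                 # stand-in for a concat-back projection
    layers_up = [nn.Identity(), nn.Identity(), nn.Identity()]  # stand-in up layers

    h = torch.randn(2, 49, C)                  # bottleneck tokens
    hs = [torch.randn(2, 49, C) for _ in layers_up]  # saved encoder features

    for inx, layer_up in enumerate(layers_up):
        if inx == 0:
            h = layer_up(h)                    # bottleneck passes through alone
        else:
            skip = hs[-inx]                    # matching encoder feature for this depth
            h = layer_up(fuse(torch.cat([h, skip], dim=-1)))
    print(h.shape)                             # torch.Size([2, 49, 96])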
@@ -641,7 +641,7 @@ def get_feature_vectors(self, x, timesteps, y=None):
             result["up"].append(h.type(x.dtype))
         return result

-class SuperResModel(SwinUNetModel):
+class SuperResModel(ArtiFusionModel):

     def __init__(self, in_channels, *args, **kwargs):
         super().__init__(in_channels * 2, *args, **kwargs)
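The rename makes SuperResModel inherit from ArtiFusionModel, and the in_channels * 2 reflects the usual super-resolution setup in improved-diffusion-derived models: the upsampled low-resolution conditioning image is concatenated with the noisy input along the channel axis. A self-contained sketch of that convention (the forward override is not shown in this diff; the stand-in base class and names are assumptions):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Base(nn.Module):
        """Stand-in for ArtiFusionModel; just passes through what it receives."""
        def forward(self, x, timesteps, **kwargs):
            return x

    class SuperResSketch(Base):
        def forward(self, x, timesteps, low_res=None, **kwargs):
            _, _, new_h, new_w = x.shape
            upsampled = F.interpolate(low_res, (new_h, new_w), mode="bilinear")
            x = torch.cat([x, upsampled], dim=1)  # channels double: hence in_channels * 2
            return super().forward(x, timesteps, **kwargs)

    model = SuperResSketch()
    x = torch.randn(1, 3, 64, 64)              # noisy high-resolution input
    low_res = torch.randn(1, 3, 16, 16)        # low-resolution conditioning image
    out = model(x, torch.tensor([10]), low_res=low_res)
    print(out.shape)                           # torch.Size([1, 6, 64, 64])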
