Skip to content

Commit

Permalink
Added counting of Dense mapping layer
Browse files Browse the repository at this point in the history
Counting Dense mapping layer and convert generator and g_ema accordingly
  • Loading branch information
levindabhi authored Jan 25, 2021
1 parent 0ce34eb commit cc50f50
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions convert_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ def discriminator_fill_statedict(statedict, vars, size):
return statedict


def fill_statedict(state_dict, vars, size):
def fill_statedict(state_dict, vars, size, n_mlp):
log_size = int(math.log(size, 2))

for i in range(8):
for i in range(n_mlp):
update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}"))

update(
Expand Down Expand Up @@ -237,9 +237,15 @@ def fill_statedict(state_dict, vars, size):

size = g_ema.output_shape[2]

g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
n_mlp = 0
mapping_layers_names = g_ema.__getstate__()['components']['mapping'].list_layers()
for layer in mapping_layers_names:
if layer[0].startswith('Dense'):
n_mlp += 1

g = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
state_dict = g.state_dict()
state_dict = fill_statedict(state_dict, g_ema.vars, size)
state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp)

g.load_state_dict(state_dict)

Expand All @@ -248,7 +254,7 @@ def fill_statedict(state_dict, vars, size):
ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}

if args.gen:
g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
g_train = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
g_train_state = g_train.state_dict()
g_train_state = fill_statedict(g_train_state, generator.vars, size)
ckpt["g"] = g_train_state
Expand Down Expand Up @@ -292,5 +298,4 @@ def fill_statedict(state_dict, vars, size):

utils.save_image(
img_concat, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
)

)

0 comments on commit cc50f50

Please sign in to comment.