Commit 35b07f3

convertweightenv

zengxianyu committed Apr 3, 2022
1 parent bef283a
Showing 3 changed files with 371 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -45,6 +45,9 @@ For example, if you cloned repositories in ~/stylegan2 and downloaded stylegan2-
This will create the converted stylegan2-ffhq-config-f.pt file. The conversion is split into two scripts so that each can run in its own environment: convert_weight1.py needs TensorFlow (it reads the official pickle and dumps the weights as numpy arrays), while convert_weight2.py needs only PyTorch and writes the final .pt checkpoint.

> python convert_weight1.py --repo ~/stylegan2 stylegan2-ffhq-config-f.pkl
> python convert_weight2.py stylegan2-ffhq-config-f.pkl

### Generate samples

> python generate.py --sample N_FACES --pics N_PICS --ckpt PATH_CHECKPOINT
266 changes: 266 additions & 0 deletions convert_weight1.py
@@ -0,0 +1,266 @@
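"""Stage 1: read the official StyleGAN2 TF pickle and dump all weights as numpy arrays."""
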
import argparse
import os
import sys
import pickle
import math

import numpy as np


def convert_modconv(vars, source_name, target_name, flip=False):
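    """Convert one TF modulated-conv layer into PyTorch state-dict entries."""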
    weight = vars[source_name + "/weight"].value().eval()
    mod_weight = vars[source_name + "/mod_weight"].value().eval()
    mod_bias = vars[source_name + "/mod_bias"].value().eval()
    noise = vars[source_name + "/noise_strength"].value().eval()
    bias = vars[source_name + "/bias"].value().eval()

    # TF stores conv weights as (H, W, in, out); the PyTorch modulated conv
    # expects (1, out, in, H, W), hence the transpose and added leading axis
    dic = {
        "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
        "conv.modulation.weight": mod_weight.transpose((1, 0)),
        "conv.modulation.bias": mod_bias + 1,
        "noise.weight": np.array([noise]),
        "activate.bias": bias,
    }

    dic_torch = {}

    # values stay as numpy arrays; convert_weight2.py turns them into tensors
    for k, v in dic.items():
        dic_torch[target_name + "." + k] = v

    if flip:
        # flip kernels spatially (used for the upsampling Conv0_up layers)
        dic_torch[target_name + ".conv.weight"] = np.flip(
            dic_torch[target_name + ".conv.weight"], [3, 4]
        )

    return dic_torch


def convert_conv(vars, source_name, target_name, bias=True, start=0):
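    """Convert a plain TF conv layer; start offsets the index of the target submodule."""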
    weight = vars[source_name + "/weight"].value().eval()

    dic = {"weight": weight.transpose((3, 2, 0, 1))}

    if bias:
        dic["bias"] = vars[source_name + "/bias"].value().eval()

    dic_torch = {}

    dic_torch[target_name + f".{start}.weight"] = dic["weight"]

    if bias:
        dic_torch[target_name + f".{start + 1}.bias"] = dic["bias"]

    return dic_torch


def convert_torgb(vars, source_name, target_name):
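    """Convert a TF ToRGB layer into PyTorch state-dict entries."""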
    weight = vars[source_name + "/weight"].value().eval()
    mod_weight = vars[source_name + "/mod_weight"].value().eval()
    mod_bias = vars[source_name + "/mod_bias"].value().eval()
    bias = vars[source_name + "/bias"].value().eval()

    dic = {
        "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
        "conv.modulation.weight": mod_weight.transpose((1, 0)),
        "conv.modulation.bias": mod_bias + 1,
        "bias": bias.reshape((1, 3, 1, 1)),
    }

    dic_torch = {}

    for k, v in dic.items():
        dic_torch[target_name + "." + k] = v

    return dic_torch


def convert_dense(vars, source_name, target_name):
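    """Convert a TF dense layer; TF stores weights as (in, out), PyTorch as (out, in)."""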
    weight = vars[source_name + "/weight"].value().eval()
    bias = vars[source_name + "/bias"].value().eval()

    dic = {"weight": weight.transpose((1, 0)), "bias": bias}

    dic_torch = {}

    for k, v in dic.items():
        dic_torch[target_name + "." + k] = v

    return dic_torch


def update(state_dict, new):
    # the original converter checked keys and shapes against an existing state
    # dict; here the dict is built from scratch, so entries are simply assigned
    for k, v in new.items():
        state_dict[k] = v


def discriminator_fill_statedict(statedict, vars, size):
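    """Fill a PyTorch discriminator state dict from TF vars, top resolution down to 4x4."""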
    log_size = int(math.log(size, 2))

    update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))

    conv_i = 1

    for i in range(log_size - 2, 0, -1):
        reso = 4 * 2 ** i
        update(
            statedict,
            convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
        )
        update(
            statedict,
            convert_conv(
                vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
            ),
        )
        update(
            statedict,
            convert_conv(
                vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
            ),
        )
        conv_i += 1

    update(statedict, convert_conv(vars, "4x4/Conv", "final_conv"))
    update(statedict, convert_dense(vars, "4x4/Dense0", "final_linear.0"))
    update(statedict, convert_dense(vars, "Output", "final_linear.1"))

    return statedict


def fill_statedict(state_dict, vars, size, n_mlp):
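    """Fill a PyTorch generator state dict from the TF G_mapping/G_synthesis vars."""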
    log_size = int(math.log(size, 2))

    # mapping network: TF Dense{i} is loaded into style.{i + 1} of the PyTorch model
    for i in range(n_mlp):
        update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}"))

    update(
        state_dict,
        {"input.input": vars["G_synthesis/4x4/Const/const"].value().eval()},
    )

    update(state_dict, convert_torgb(vars, "G_synthesis/4x4/ToRGB", "to_rgb1"))

    for i in range(log_size - 2):
        reso = 4 * 2 ** (i + 1)
        update(
            state_dict,
            convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"),
        )

    update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "conv1"))

    conv_i = 0

    for i in range(log_size - 2):
        reso = 4 * 2 ** (i + 1)
        update(
            state_dict,
            convert_modconv(
                vars,
                f"G_synthesis/{reso}x{reso}/Conv0_up",
                f"convs.{conv_i}",
                flip=True,
            ),
        )
        update(
            state_dict,
            convert_modconv(
                vars, f"G_synthesis/{reso}x{reso}/Conv1", f"convs.{conv_i + 1}"
            ),
        )
        conv_i += 2

    # per-layer noise buffers: one at 4x4, two per higher resolution
    for i in range((log_size - 2) * 2 + 1):
        update(
            state_dict,
            {f"noises.noise_{i}": vars[f"G_synthesis/noise{i}"].value().eval()},
        )

    return state_dict


if __name__ == "__main__":
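    # stage 1 runs in the TensorFlow environment of the official repository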
    parser = argparse.ArgumentParser(
        description="TensorFlow to PyTorch checkpoint converter (stage 1: dump weights as numpy arrays)"
    )
    parser.add_argument(
        "--repo",
        type=str,
        required=True,
        help="path to the official StyleGAN2 repository with dnnlib/ folder",
    )
    parser.add_argument(
        "--gen", action="store_true", help="convert the generator weights"
    )
    parser.add_argument(
        "--disc", action="store_true", help="convert the discriminator weights"
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier factor. config-f = 2, else = 1",
    )
    parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")

    args = parser.parse_args()

    # the official repo provides dnnlib, which must be importable to unpickle
    sys.path.append(args.repo)

    import dnnlib
    from dnnlib import tflib

    tflib.init_tf()

    with open(args.path, "rb") as f:
        generator, discriminator, g_ema = pickle.load(f)

    size = g_ema.output_shape[2]
    latent_avg = g_ema.vars["dlatent_avg"].value().eval()

    # store one fixed latent so stage 2 can render the same sanity-check image
    n_sample = 1
    z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")

    # recover n_mlp by counting the Dense layers of the mapping network
    n_mlp = 0
    mapping_layers_names = g_ema.__getstate__()["components"]["mapping"].list_layers()
    for layer in mapping_layers_names:
        if layer[0].startswith("Dense"):
            n_mlp += 1

    state_dict = fill_statedict({}, g_ema.vars, size, n_mlp)
    state_dict["latent_avg"] = latent_avg
    state_dict["size"] = size
    state_dict["n_mlp"] = n_mlp
    state_dict["n_sample"] = n_sample
    state_dict["z"] = z

    name = os.path.splitext(os.path.basename(args.path))[0]
    with open(name + "_numpy.pkl", "wb") as f:
        pickle.dump(state_dict, f)
102 changes: 102 additions & 0 deletions convert_weight2.py
@@ -0,0 +1,102 @@
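"""Stage 2: load the numpy weights from stage 1 and save a PyTorch .pt checkpoint."""
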
import argparse
import os
import pickle



if __name__ == "__main__":
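    # stage 2 runs in the PyTorch environment; TensorFlow is not needed here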
    device = "cuda"

    parser = argparse.ArgumentParser(
        description="TensorFlow to PyTorch checkpoint converter (stage 2: build the .pt file)"
    )
    parser.add_argument(
        "--gen", action="store_true", help="convert the generator weights"
    )
    parser.add_argument(
        "--disc", action="store_true", help="convert the discriminator weights"
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier factor. config-f = 2, else = 1",
    )
    parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")

    args = parser.parse_args()

    import torch
    from torchvision import utils
    from model import Generator

    # load the numpy state dict written by convert_weight1.py
    name = os.path.splitext(os.path.basename(args.path))[0]
    with open(name + "_numpy.pkl", "rb") as f:
        state_dict = pickle.load(f)

    size = state_dict.pop("size")
    n_mlp = state_dict.pop("n_mlp")
    z = state_dict.pop("z")
    n_sample = state_dict.pop("n_sample")
    latent_avg = state_dict.pop("latent_avg")

    g = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
    for k, v in state_dict.items():
        state_dict[k] = torch.from_numpy(v)
    # strict=False: entries absent from the converted dict keep their defaults
    g.load_state_dict(state_dict, strict=False)

    latent_avg = torch.from_numpy(latent_avg)

    ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}

    # converting the training generator or discriminator needs the TF objects,
    # which are not available in this stage; the original single-script
    # converter handled these cases
    if args.gen:
        raise NotImplementedError

    if args.disc:
        raise NotImplementedError

    torch.save(ckpt, name + ".pt")

    # sanity check: render the fixed latent that stage 1 stored as z
    # (the original converter also compared against a TF-generated image,
    # but that check requires the TensorFlow environment)
    g = g.to(device)

    with torch.no_grad():
        img_pt, _ = g(
            [torch.from_numpy(z).to(device)],
            truncation=0.5,
            truncation_latent=latent_avg.to(device),
            randomize_noise=False,
        )

    # note: newer torchvision versions rename the range argument to value_range
    utils.save_image(
        img_pt, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
    )
