forked from rosinality/stylegan2-pytorch
generate_and_copy_lips.py
import argparse

import torch
from torchvision import utils
from model import Generator
from tqdm import tqdm

from detect_lips import LipDetector
from lip_optimizer import LipOptimizer


def generate(args, g_ema, device, mean_latent):
    # Draw a batch of random z latents and synthesize images with the EMA generator.
    with torch.no_grad():
        g_ema.eval()

        sample_z = torch.randn(args.sample, args.latent, device=device)

        sample, _ = g_ema(
            [sample_z], truncation=args.truncation, truncation_latent=mean_latent
        )

        return sample


def generate_from_sample_z(args, g_ema: Generator, device, mean_latent, sample_z):
    # Synthesize images from an externally supplied batch of z latents.
    g_ema.eval()

    sample, _ = g_ema(
        [sample_z.to(device)],
        truncation=args.truncation,
        truncation_latent=mean_latent,
    )

    return sample


def generate_from_sample_w(args, g_ema: Generator, device, mean_latent, sample_w):
    # Synthesize images from latents already in W space
    # (input_is_latent=True skips the mapping network).
    g_ema.eval()

    sample, _ = g_ema(
        [sample_w.to(device)],
        truncation=args.truncation,
        truncation_latent=mean_latent,
        input_is_latent=True,
    )

    return sample


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(description="Generate samples from the generator")

    parser.add_argument(
        "--size", type=int, default=1024, help="output image size of the generator"
    )
    parser.add_argument(
        "--sample",
        type=int,
        default=1,
        help="number of samples to be generated for each image",
    )
    parser.add_argument(
        "--pics", type=int, default=20, help="number of images to be generated"
    )
    parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
    parser.add_argument(
        "--truncation_mean",
        type=int,
        default=4096,
        help="number of vectors to calculate mean for the truncation",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="stylegan2-ffhq-config-f.pt",
        help="path to the model checkpoint",
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier of the generator. config-f = 2, else = 1",
    )

    args = parser.parse_args()

    args.latent = 512
    args.n_mlp = 8

    # Build the generator and load the EMA weights from the checkpoint.
    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    checkpoint = torch.load(args.ckpt)

    g_ema.load_state_dict(checkpoint["g_ema"])

    # Estimate the mean latent only when truncation is actually applied.
    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    # Load lip detection module
    lip_detector = LipDetector()

    # Load lip optimizer
    lip_optimizer = LipOptimizer(g_ema, lip_detector, args)

    # lip_optimizer.optimize_retarget(
    #     torch.randn(args.sample, args.latent, device=device),
    #     torch.randn(args.sample, args.latent, device=device),
    # )
    lip_optimizer.optimize_smile(torch.randn(args.sample, args.latent, device=device))
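
For reference, a minimal usage sketch that is not part of the original script: it assumes the checkpoint and the detect_lips / lip_optimizer modules are importable next to this file, and it exercises the otherwise unused generate_from_sample_z helper together with torchvision's save_image. The normalize=True flag assumes the generator emits images roughly in [-1, 1], as StyleGAN2 generators typically do.

    # Hypothetical follow-up snippet (not in the original file): write one random sample to disk.
    sample_z = torch.randn(args.sample, args.latent, device=device)
    sample = generate_from_sample_z(args, g_ema, device, mean_latent, sample_z)
    utils.save_image(sample, "sample.png", nrow=1, normalize=True)

A typical invocation with the flags defined above might look like:

    python generate_and_copy_lips.py --size 1024 --truncation 0.7 --ckpt stylegan2-ffhq-config-f.pt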