Added samples
rosinality committed Dec 19, 2019
1 parent 7550be4 commit 6893bdd
Showing 5 changed files with 45 additions and 6 deletions.
12 changes: 9 additions & 3 deletions README.md
@@ -2,9 +2,9 @@

Implementation of Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958) in PyTorch

-## WARNING
+## Notice

-Currently I didn't fully validated my implementations. Also I have tried implement model and trainers to closely match original implementations as much as possible, but I could have missed details. So pleas use this implementation with cautions.
+I have tried to match the official implementation as closely as possible, but I may have missed some details, so please use this implementation with care.

## Usage

@@ -18,4 +18,10 @@ Then you can train model in distributed settings

> python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train.py --batch BATCH_SIZE LMDB_PATH
-train.py supports Weights & Biases logging. If you want to use it, add --wandb arguments to the script.
+train.py supports Weights & Biases logging. If you want to use it, add the --wandb argument to the script.
+
+## Samples
+
+![Sample with truncation](sample.png)
+
+At 40,000 iterations (trained on 1.28M images).
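For example, a concrete version of the launch command above, training on 4 GPUs with W&B logging enabled, could look like this (the GPU count, port, batch size, and LMDB path are illustrative placeholders):

> python -m torch.distributed.launch --nproc_per_node=4 --master_port=23456 train.py --batch 8 --wandb /path/to/lmdb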
4 changes: 3 additions & 1 deletion distributed.py
@@ -55,10 +55,12 @@ def reduce_sum(tensor):


def gather_grad(params):
+    world_size = get_world_size()
+
    for param in params:
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
-            param.grad.data.div_(get_world_size())
+            param.grad.data.div_(world_size)


def all_gather(data):
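For context, a minimal sketch of how the revised `gather_grad` might be used inside a distributed training step; the model, loss, and process-group setup here are assumptions, not part of this commit:

```python
from distributed import gather_grad  # the helper shown above


def training_step(model, loss):
    # Assumes the torch.distributed process group is already initialized,
    # e.g. by launching the script with torch.distributed.launch.
    model.zero_grad()
    loss.backward()

    # gather_grad all-reduces (sums) each parameter's gradient across
    # ranks and divides by the world size, leaving averaged gradients.
    gather_grad(model.parameters())
```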
22 changes: 20 additions & 2 deletions model.py
@@ -360,6 +360,8 @@ def __init__(
        lr_mlp=0.01,
    ):
        super().__init__()

+        self.style_dim = style_dim
+
        layers = [PixelNorm()]

Expand Down Expand Up @@ -423,9 +425,25 @@ def __init__(
            in_channel = out_channel

        self.n_latent = log_size * 2 - 2

-    def forward(self, styles, return_latents=False):
+    def mean_latent(self, n_latent):
+        latent_in = torch.randn(
+            n_latent, self.style_dim, device=self.input.input.device
+        )
+        latent = self.style(latent_in).mean(0, keepdim=True)
+
+        return latent
+
+    def forward(self, styles, return_latents=False, truncation=0, truncation_latent=None):
        styles = [self.style(s) for s in styles]

+        if truncation > 0:
+            style_t = []
+
+            for style in styles:
+                style_t.append(style + truncation * (truncation_latent - style))
+
+            styles = style_t
+
        if len(styles) < 2:
            inject_index = self.n_latent
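A minimal sketch of how the new `mean_latent` method and `truncation` arguments might be used at sampling time; the constructor arguments and the `(image, latent)` return pair are assumptions, not confirmed by this diff:

```python
import torch

from model import Generator

g_ema = Generator(size=256, style_dim=512, n_mlp=8).to('cuda')
g_ema.eval()

with torch.no_grad():
    # Estimate the mean latent from 4096 random style vectors.
    mean_latent = g_ema.mean_latent(4096)

    z = torch.randn(16, 512, device='cuda')
    # With this forward(), truncation=0 means no truncation; 0.7 moves
    # each style 70% of the way toward the mean latent.
    sample, _ = g_ema([z], truncation=0.7, truncation_latent=mean_latent)
```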
Binary file added sample.png
13 changes: 13 additions & 0 deletions train.py
@@ -231,6 +231,19 @@ def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
        fake_score_val = loss_reduced['fake_score'].mean().item()
        path_length_val = loss_reduced['path_length'].mean().item()

+        if get_rank() == 0 or get_rank() == 1:
+            if (i + 1) % 256 == 0:
+                torch.save(
+                    {
+                        'g': generator.module.state_dict(),
+                        'd': discriminator.module.state_dict(),
+                        'g_ema': g_ema.state_dict(),
+                        'g_optim': g_optim.state_dict(),
+                        'd_optim': d_optim.state_dict(),
+                    },
+                    f'checkpoint/{get_rank()}-{str(i).zfill(6)}.pt',
+                )
+
        if get_rank() == 0:
            pbar.set_description(
                (
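A minimal sketch of how one of these checkpoints might be restored; the filename follows the f-string above (rank 0 at i = 255), while `generator`, `discriminator`, `g_ema`, and the optimizers are assumed to be constructed as in train.py:

```python
import torch

# Ranks 0 and 1 each write a checkpoint every 256 iterations; at i = 255,
# rank 0 writes 'checkpoint/0-000255.pt'.
ckpt = torch.load('checkpoint/0-000255.pt', map_location='cpu')

generator.module.load_state_dict(ckpt['g'])
discriminator.module.load_state_dict(ckpt['d'])
g_ema.load_state_dict(ckpt['g_ema'])
g_optim.load_state_dict(ckpt['g_optim'])
d_optim.load_state_dict(ckpt['d_optim'])
```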
