Skip to content

Commit

Permalink
Updated path_batch_shrink
Browse files Browse the repository at this point in the history
  • Loading branch information
rosinality committed Dec 23, 2019
1 parent 2492eaf commit 9388348
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,4 @@ dmypy.json

wandb/
*.lmdb/
*.pkl
Binary file modified inception_ffhq.pkl
Binary file not shown.
19 changes: 13 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, devic
pbar = range(args.iter)

if get_rank() == 0:
pbar = tqdm(pbar, dynamic_ncols=True)
pbar = tqdm(pbar, dynamic_ncols=True, smoothing=0.01)

mean_path_length = 0

Expand All @@ -121,6 +121,7 @@ def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, devic
g_loss_val = 0
path_loss = torch.tensor(0.0, device=device)
path_lengths = torch.tensor(0.0, device=device)
mean_path_length_avg = 0
loss_dict = {}

sample_z = torch.randn(8 * 8, args.latent, device=device)
Expand Down Expand Up @@ -186,16 +187,20 @@ def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, devic
# mean_path_length_avg = 0

generator.zero_grad()
g_loss.backward(retain_graph=g_regularize)
g_loss.backward(retain_graph=g_regularize and args.path_batch_shrink == 1)
g_optim.step()

if g_regularize:
if args.path_batch_shrink > 1:
if args.mixing > 0 and random.random() < args.mixing:
noise3 = make_noise(args.batch, args.latent, 2, device)
noise3 = make_noise(
args.batch // args.path_batch_shrink, args.latent, 2, device
)

else:
noise3 = make_noise(args.batch, args.latent, 1, device)
noise3 = make_noise(
args.batch // args.path_batch_shrink, args.latent, 1, device
)
noise3 = [noise3]

fake_img, latents = generator(noise3, return_latents=True)
Expand All @@ -214,7 +219,9 @@ def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, devic
gather_grad(generator.parameters())
g_optim.step()

mean_path_length_avg = reduce_sum(mean_path_length) / get_world_size()
mean_path_length_avg = (
reduce_sum(mean_path_length).item() / get_world_size()
)

loss_dict['path'] = path_loss
loss_dict['path_length'] = path_lengths.mean()
Expand Down Expand Up @@ -291,7 +298,7 @@ def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, devic
parser.add_argument('--path_regularize', type=float, default=2)
parser.add_argument('--path_batch_shrink', type=int, default=1)
parser.add_argument('--d_reg_every', type=int, default=16)
parser.add_argument('--g_reg_every', type=int, default=4)
parser.add_argument('--g_reg_every', type=int, default=8)
parser.add_argument('--mixing', type=float, default=0.9)
parser.add_argument('--ckpt', type=str, default=None)
parser.add_argument('--lr', type=float, default=0.002)
Expand Down

0 comments on commit 9388348

Please sign in to comment.