Can control more aspects of projection in the command line (init lr, constant lr, etc.), can now re-center W according to a seed or projected vector for random interpolations in generate.py, general code linting, and update README
PDillis committed Jun 21, 2021
1 parent dab4f40 commit 8a05574
Showing 5 changed files with 223 additions and 127 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -14,7 +14,8 @@ This README will eventually be updated to correctly show how to run the following
details.
* **Vertical mirroring for training**: Use `--mirror-y=True` when training your model to mirror your training images along the horizontal axis.
* [**Project in W+**](https://arxiv.org/abs/1904.03189): Use `--project-in-wplus` when running `projector.py` to project
- in the W+ latent space. Use `--help` for better guidance for now.
+ in the W+ latent space. Use `--help` for better guidance for now. Thanks to [Peter Baylies](https://github.com/pbaylies)
+ on how to do this cleanly.
* **Save all steps in the projection**: Running `projector.py` with `--save-every-step` will save all the frames of
the projection video as different `.jpg` files, as well as the projected disentangled latent vector at each step. These
will be saved in the `.npy` format in the run dir.
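
Since `--save-every-step` writes each step's projected dlatent to the run dir as `.npy`, those files can be reloaded later to re-synthesize any intermediate result. A minimal sketch, assuming the stock stylegan2-ada-pytorch pickle layout (`'G_ema'` key) and hypothetical file paths; check your run dir for the exact names this fork writes:

```python
import pickle

import numpy as np
import torch

# Load the generator; stylegan2-ada-pytorch pickles store the EMA copy under 'G_ema'.
with open('network.pkl', 'rb') as f:  # hypothetical path
    G = pickle.load(f)['G_ema'].cuda()

# One saved projection step; shape assumed to be [1, num_ws, w_dim].
w = torch.from_numpy(np.load('projected_w_step0500.npy')).cuda()  # hypothetical name
img = G.synthesis(w, noise_mode='const')  # NCHW tensor with values in [-1, 1]
```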
109 changes: 74 additions & 35 deletions generate.py
@@ -9,12 +9,12 @@
"""Generate images using pretrained network pickle."""

import os
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Tuple
import click

import dnnlib
from torch_utils.gen_utils import num_range, parse_fps, compress_video, double_slowdown, \
-    make_run_dir, z_to_img, w_to_img, get_w_from_file, create_image_grid, save_config, parse_slowdown
+    make_run_dir, z_to_img, w_to_img, get_w_from_file, create_image_grid, save_config, parse_slowdown, get_w_from_seed

import scipy
import numpy as np
@@ -131,14 +131,16 @@ def generate_images(
        print('warn: --class=lbl ignored when running on an unconditional network')

    if training_snapshot:
+        # Note: this doesn't really work, so more work is warranted
        print('Recreating the snapshot grid...')
        size_dict = {'1080p': (1920, 1080, 3, 2), '4k': (3840, 2160, 7, 4), '8k': (7680, 4320, 7, 4)}
        grid_width = int(np.clip(size_dict[snapshot_size][0] // G.img_resolution, size_dict[snapshot_size][2], 32))
        grid_height = int(np.clip(size_dict[snapshot_size][1] // G.img_resolution, size_dict[snapshot_size][3], 32))
        num_images = grid_width * grid_height

        rnd = np.random.RandomState(0)
-        all_indices = list(range(15654))  # irrelevant
+        torch.manual_seed(0)
+        all_indices = list(range(70000))  # irrelevant
        rnd.shuffle(all_indices)

        grid_z = rnd.randn(num_images, G.z_dim)  # TODO: generate with torch, as in the training_loop.py file
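
As a concrete check of the sizing logic above: for a 1024×1024 network targeting a `'1080p'` snapshot, the integer divisions both give 1, so the grid is clamped up to the per-size minimums stored in `size_dict`:

```python
import numpy as np

size_dict = {'1080p': (1920, 1080, 3, 2), '4k': (3840, 2160, 7, 4), '8k': (7680, 4320, 7, 4)}
img_resolution, snapshot_size = 1024, '1080p'  # example values
grid_width = int(np.clip(size_dict[snapshot_size][0] // img_resolution, size_dict[snapshot_size][2], 32))
grid_height = int(np.clip(size_dict[snapshot_size][1] // img_resolution, size_dict[snapshot_size][3], 32))
print(grid_width, grid_height)  # -> 3 2, i.e. a 3x2 grid of six 1024px images
```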
@@ -193,18 +195,28 @@ def generate_images(
        'description': description,
        'projected_w': projected_w
    }
    # Save the run configuration
    save_config(ctx=ctx, run_dir=run_dir)
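
`save_config` lives in `torch_utils.gen_utils` and is not part of this diff; a plausible stand-in that matches the call site, assuming it simply serializes the click context's `obj` dict into the run dir:

```python
import json
import os

def save_config(ctx, run_dir: str, save_name: str = 'config.json') -> None:
    """Hypothetical sketch: dump ctx.obj as pretty-printed JSON next to the results."""
    with open(os.path.join(run_dir, save_name), 'w') as f:
        json.dump(ctx.obj, f, indent=4, sort_keys=True)
```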


# ----------------------------------------------------------------------------


+def _parse_new_center(s: str) -> Tuple[str, Union[int, np.ndarray]]:
+    """Get a new center for the W latent space (a seed or projected dlatent; to be transformed later)"""
+    try:
+        new_center = int(s)  # it's a seed
+        return s, new_center
+    except ValueError:
+        new_center = get_w_from_file(s)  # it's a projected dlatent
+        return s, new_center
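
The returned tuple keeps the raw string (handy for logging in the saved config) alongside the parsed value that the video code consumes:

```python
# Seed branch: the string parses as an int.
assert _parse_new_center('42') == ('42', 42)

# Dlatent branch: int() raises ValueError, so the string is treated as a path
# and handed to get_w_from_file (hypothetical file):
# s, w = _parse_new_center('out/projected_w.npz')  # w is a np.ndarray
```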


@main.command(name='random-video')
@click.pass_context
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--seeds', type=num_range, help='List of random seeds', required=True)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--new-center', type=_parse_new_center, help='New center for the W latent space; a seed (int) or a path to a projected dlatent (.npy/.npz)', default=None)
@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--grid-width', '-gw', type=int, help='Video grid width / number of columns', default=None, show_default=True)
@@ -220,6 +232,7 @@ def random_interpolation_video(
    network_pkl: Union[str, os.PathLike],
    seeds: Optional[List[int]],
    truncation_psi: float,
+    new_center: Tuple[str, Union[int, np.ndarray]],
    class_idx: Optional[int],
    noise_mode: str,
    grid_width: int,
@@ -310,26 +323,6 @@ def random_interpolation_video(
    # Name of the video
    mp4_name = f'{grid_width}x{grid_height}-slerp-{slowdown}xslowdown'

-    # Save the configuration used
-    ctx.obj = {
-        'network_pkl': network_pkl,
-        'seeds': seeds,
-        'truncation_psi': truncation_psi,
-        'class_idx': class_idx,
-        'noise_mode': noise_mode,
-        'grid_width': grid_width,
-        'grid_height': grid_height,
-        'slowdown': slowdown,
-        'duration_sec': duration_sec,
-        'video_fps': fps,
-        'run_dir': run_dir,
-        'description': description,
-        'compress': compress,
-        'smoothing_sec': smoothing_sec
-    }
-    # Save the run configuration
-    save_config(ctx=ctx, run_dir=run_dir)

    # Labels.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
@@ -347,17 +340,42 @@ def random_interpolation_video(
                                                                frames=num_frames)
        slowdown //= 2
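
`double_slowdown` is imported from `torch_utils.gen_utils` and not shown in this diff. A plausible sketch consistent with the call site: each pass doubles the frame count and duration, so halving `slowdown` per iteration terminates for power-of-two values:

```python
import numpy as np

def double_slowdown(latents: np.ndarray, duration_sec: float, frames: int):
    """Hypothetical sketch: insert linear midpoints between consecutive latents."""
    z = np.empty((2 * frames, *latents.shape[1:]), dtype=latents.dtype)
    z[0::2] = latents
    # Midpoints; np.roll wraps at the end, which keeps a looping video seamless.
    z[1::2] = 0.5 * (latents + np.roll(latents, -1, axis=0))
    return z, 2 * duration_sec, 2 * frames
```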

-    def make_frame(t):
-        frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1))
-        latents = torch.from_numpy(all_latents[frame_idx]).to(device)
-        # Get the images with the labels
-        images = z_to_img(G, latents, label, truncation_psi, noise_mode)
-        # Generate the grid for this timestamp
-        grid = create_image_grid(images, grid_size)
-        # Grayscale => RGB
-        if grid.shape[2] == 1:
-            grid = grid.repeat(3, 2)
-        return grid
+    if new_center is None:
+        def make_frame(t):
+            frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1))
+            latents = torch.from_numpy(all_latents[frame_idx]).to(device)
+            # Get the images with the labels
+            images = z_to_img(G, latents, label, truncation_psi, noise_mode)
+            # Generate the grid for this timestamp
+            grid = create_image_grid(images, grid_size)
+            # Grayscale => RGB
+            if grid.shape[2] == 1:
+                grid = grid.repeat(3, 2)
+            return grid
+
+    else:
+        new_center, new_center_value = new_center
+        # We get the new center using the int or recovered dlatent
+        if isinstance(new_center_value, int):
+            new_w_avg = get_w_from_seed(G, device, new_center_value, truncation_psi=1.0)  # We want the pure dlatent
+        elif isinstance(new_center_value, np.ndarray):
+            new_w_avg = torch.from_numpy(new_center_value).to(device)
+
+        def make_frame(t):
+            frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1))
+            latents = torch.from_numpy(all_latents[frame_idx]).to(device)
+            # Do the truncation trick with this new center
+            w = G.mapping(latents, None)
+            w = new_w_avg + (w - new_w_avg) * truncation_psi
+            # Get the images with the new center
+            images = w_to_img(G, w, noise_mode)
+            # Generate the grid for this timestamp
+            grid = create_image_grid(images, grid_size)
+            # Grayscale => RGB
+            if grid.shape[2] == 1:
+                grid = grid.repeat(3, 2)
+            return grid
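
The `else` branch above is the ordinary truncation trick with the learned `w_avg` swapped for a user-chosen center, i.e. w′ = w_new + ψ·(w − w_new). A standalone sanity check with toy tensors:

```python
import torch

num_ws, w_dim = 16, 512  # typical dlatent shape for a 1024px generator
w = torch.randn(1, num_ws, w_dim)          # mapped dlatent for one frame
new_w_avg = torch.randn(1, num_ws, w_dim)  # center recovered from a seed or projection
psi = 0.7
w_trunc = new_w_avg + (w - new_w_avg) * psi
# psi=1 leaves w untouched; psi=0 collapses every frame onto the new center.
assert torch.allclose(new_w_avg + (w - new_w_avg) * 1.0, w)
```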


    # Generate video using the respective make_frame function
    videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
@@ -371,6 +389,27 @@ def make_frame(t):
    if compress:
        compress_video(original_video=final_video, original_video_name=mp4_name, outdir=run_dir, ctx=ctx)

+    # Save the configuration used
+    new_center = 'w_avg' if new_center is None else new_center
+    ctx.obj = {
+        'network_pkl': network_pkl,
+        'seeds': seeds,
+        'truncation_psi': truncation_psi,
+        'new_center': new_center,
+        'class_idx': class_idx,
+        'noise_mode': noise_mode,
+        'grid_width': grid_width,
+        'grid_height': grid_height,
+        'slowdown': slowdown,
+        'duration_sec': duration_sec,
+        'video_fps': fps,
+        'run_dir': run_dir,
+        'description': description,
+        'compress': compress,
+        'smoothing_sec': smoothing_sec
+    }
+    save_config(ctx=ctx, run_dir=run_dir)


# ----------------------------------------------------------------------------

