
Commit

Code has been refactored
PDillis committed Apr 10, 2021
1 parent 1972dc2 commit 56be8eb
Showing 4 changed files with 118 additions and 96 deletions.
12 changes: 6 additions & 6 deletions dnnlib/util.py
@@ -349,13 +349,13 @@ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
"""Takes in a list of tuples of (src, dst) paths and copies files.
Will create all necessary directories."""
for file in files:
target_dir_name = os.path.dirname(file[1])
# target_dir_name = os.path.dirname(file[1])
#
# # will create all intermediate-level directories
# if not os.path.exists(target_dir_name):
# os.makedirs(target_dir_name)

# will create all intermediate-level directories
if not os.path.exists(target_dir_name):
os.makedirs(target_dir_name)

shutil.copyfile(file[0], file[1])
shutil.copy(file[0], file[1])
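For context on the copyfile → copy swap above: shutil.copy also copies the source file's permission bits and accepts a directory as destination, whereas shutil.copyfile copies only the file contents and requires a full target path. A minimal, self-contained sketch of the difference (the temporary files below are purely illustrative):

import os
import shutil
import tempfile

src_dir, dst_dir = tempfile.mkdtemp(), tempfile.mkdtemp()
src = os.path.join(src_dir, 'example.txt')
with open(src, 'w') as f:
    f.write('hello')

# copyfile: destination must be a full file path; only the data is copied
shutil.copyfile(src, os.path.join(dst_dir, 'example.txt'))
# copy: destination may be a directory; the permission bits are copied as well
shutil.copy(src, dst_dir)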


# URL helpers
83 changes: 20 additions & 63 deletions generate.py
@@ -12,9 +12,11 @@
from typing import List, Optional, Union
from locale import atof
import click
import imageio

import dnnlib
from torch_utils.gen_utils import num_range, parse_fps, compress_video, double_slowdown, make_run_dir, w_to_img
from torch_utils.gen_utils import num_range, parse_fps, compress_video, double_slowdown, \
make_run_dir, z_to_img, w_to_img, get_w_from_file, create_image_grid

import scipy
import numpy as np
@@ -34,58 +36,18 @@
def main():
pass

# ----------------------------------------------------------------------------


def create_image_grid(images, grid_size=None):
"""
Create a grid with the fed images
Args:
images (np.array): array of images
grid_size (tuple(int)): size of grid (grid_width, grid_height)
Returns:
grid (np.array): image grid of size grid_size
"""
# Sanity check
assert images.ndim == 3 or images.ndim == 4, f'Images has {images.ndim} dimensions (shape: {images.shape})!'
num, img_h, img_w, c = images.shape
# If user specifies the grid shape, use it
if grid_size is not None:
grid_w, grid_h = tuple(grid_size)
# If one of the sides is None, then we must infer it
if grid_w is None:
grid_w = num // grid_h + 1
elif grid_h is None:
grid_h = num // grid_w + 1

# Otherwise, we can infer it by the number of images (priority is given to grid_w)
else:
grid_w = max(int(np.ceil(np.sqrt(num))), 1)
grid_h = max((num - 1) // grid_w + 1, 1)

# Sanity check
assert grid_w * grid_h >= num, 'Number of rows and columns must be greater than the number of images!'
# Get the grid
grid = np.zeros([grid_h * img_h, grid_w * img_h] + list(images.shape[-1:]), dtype=images.dtype)
# Paste each image in the grid
for idx in range(num):
x = (idx % grid_w) * img_w
y = (idx // grid_w) * img_h
grid[y:y + img_h, x:x + img_w, ...] = images[idx]
return grid


# ----------------------------------------------------------------------------
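create_image_grid is deleted here but not dropped: the updated import block at the top of generate.py now pulls it from torch_utils.gen_utils instead. A usage sketch, assuming the helper keeps the signature shown in the removed code above (the dummy images are illustrative only):

import numpy as np
from torch_utils.gen_utils import create_image_grid

images = np.zeros([6, 64, 64, 3], dtype=np.uint8)    # six dummy 64x64 RGB images
grid = create_image_grid(images)                     # grid shape inferred from the image count
wide = create_image_grid(images, grid_size=(6, 1))   # explicit (grid_width, grid_height)
print(grid.shape, wide.shape)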


@main.command(name='generate-images')
@main.command(name='images')
@click.pass_context
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--network', 'network_pkl', type=click.Path(exists=True, dir_okay=False), help='Network pickle filename', required=True)
@click.option('--seeds', type=num_range, help='List of random seeds')
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--projected-w', help='Projection result file; can be either .npy or .npz files', type=click.Path(dir_okay=False), metavar='FILE')
@click.option('--projected-w', help='Projection result file; can be either .npy or .npz files', type=click.Path(exists=True, dir_okay=False), metavar='FILE')
@click.option('--save-grid', help='Use flag to save image grid', is_flag=True, show_default=True)
@click.option('--grid-width', '-gw', type=int, help='Grid width (number of columns)', default=None)
@click.option('--grid-height', '-gh', type=int, help='Grid height (number of rows)', default=None)
@@ -142,15 +104,10 @@ def generate_images(
if seeds is not None:
print('warn: --seeds is ignored when using --projected-w')
print(f'Generating images from projected W "{projected_w}"')
if projected_w.endswith('.npy'):
ws = np.load(projected_w)
elif projected_w.endswith('.npz'):
ws = np.load(projected_w)['w']
else:
ctx.fail(f'Projected W latent vector "{projected_w}" has wrong file format! Use either ".npy" or ".npz" formats.')
ws = get_w_from_file(projected_w)
ws = torch.tensor(ws, device=device)
assert ws.shape[1:] == (G.num_ws, G.w_dim)
n_digits = int(np.log10(len(ws))) + 1 # number of digits to correctly generate and save the .jpg images
n_digits = int(np.log10(len(ws))) + 1 # number of digits for naming the .jpg images
for idx, w in enumerate(ws):
img = w_to_img(G, w, noise_mode)
PIL.Image.fromarray(img, 'RGB').save(f'{run_dir}/proj{idx:0{n_digits}d}.jpg')
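The .npy/.npz branching that used to live here is now delegated to get_w_from_file. A minimal sketch of what such a helper could look like, inferred from the removed inline code; the actual implementation in torch_utils.gen_utils may differ, for example in how it reports an unsupported extension:

import numpy as np

def get_w_from_file(projected_w: str) -> np.ndarray:
    """Load projected W latents from a .npy or .npz file."""
    if projected_w.endswith('.npy'):
        return np.load(projected_w)
    if projected_w.endswith('.npz'):
        return np.load(projected_w)['w']
    raise ValueError(f'Projected W latent vector "{projected_w}" has the wrong file format! Use ".npy" or ".npz".')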
@@ -174,8 +131,7 @@ def generate_images(
for seed_idx, seed in enumerate(seeds):
print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
img = G(z=z, c=label, truncation_psi=truncation_psi, noise_mode=noise_mode)
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
img = z_to_img(G, z, label, truncation_psi, noise_mode)[0]
if save_grid:
images.append(img)
PIL.Image.fromarray(img, 'RGB').save(f'{run_dir}/seed{seed:04d}.jpg')
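The forward pass through G followed by the permute/clamp conversion to uint8 that was removed above is exactly the pattern z_to_img now wraps. A sketch under that assumption (the real helper in torch_utils.gen_utils may handle more cases):

import numpy as np
import torch

def z_to_img(G, z: torch.Tensor, label, truncation_psi: float, noise_mode: str) -> np.ndarray:
    """Map a batch of z latents to uint8 images of shape [batch, height, width, channels]."""
    img = G(z=z, c=label, truncation_psi=truncation_psi, noise_mode=noise_mode)
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return img.cpu().numpy()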
@@ -187,15 +143,16 @@ def generate_images(
PIL.Image.fromarray(create_image_grid(np.array(images)), 'RGB').save(f'{run_dir}/grid.png')
# The user tells the specific shape of the grid, but one value may be None
elif None in (grid_width, grid_height):
PIL.Image.fromarray(create_image_grid(np.array(images), (grid_width, grid_height)), 'RGB').save(f'{run_dir}/grid.png')
PIL.Image.fromarray(create_image_grid(np.array(images), (grid_width, grid_height)),
'RGB').save(f'{run_dir}/grid.png')


# ----------------------------------------------------------------------------


def _parse_slowdown(slowdown: Union[str, int]) -> int:
"""Function to parse the 'slowdown' parameter by the user. Will approximate to the nearest power of 2."""
# TODO: slowdown should be any int, we can modify the code to be slowed down to whatever amount we want
# TODO: slowdown should be any int
if not isinstance(slowdown, int):
slowdown = atof(slowdown)
assert slowdown > 0
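The rest of _parse_slowdown is collapsed in this view; its docstring says the value is approximated to the nearest power of 2. One common way to do that, purely as an illustration and not necessarily what the collapsed code does:

import numpy as np

def round_to_power_of_two(x: float) -> int:
    """Round a positive value to the nearest power of 2."""
    assert x > 0
    return int(2 ** int(np.rint(np.log2(x))))

print([round_to_power_of_two(v) for v in (1, 1.5, 3, 5, 6)])  # [1, 2, 4, 4, 8]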
@@ -206,7 +163,7 @@ def _parse_slowdown(slowdown: Union[str, int]) -> int:

@main.command(name='random-video')
@click.pass_context
@click.option('--network', 'network_pkl', type=click.Path(exists=True), help='Network pickle filename', required=True)
@click.option('--network', 'network_pkl', type=click.Path(exists=True, dir_okay=False), help='Network pickle filename', required=True)
@click.option('--seeds', type=num_range, help='List of random seeds', required=True)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
@@ -216,7 +173,7 @@ def _parse_slowdown(slowdown: Union[str, int]) -> int:
@click.option('--slowdown', type=_parse_slowdown, help='Slow down the video by this amount; will be approximated to the nearest power of 2', default='1', show_default=True)
@click.option('--duration-sec', '-sec', type=float, help='Duration length of the video', default=30.0, show_default=True)
@click.option('--fps', type=parse_fps, help='Video FPS.', default=30, show_default=True)
@click.option('--compress', is_flag=True, help='Add flag to compress the final mp4 file via ffmpeg-python (same resolution, lower file size)')
@click.option('--compress', is_flag=True, help='Add flag to compress the final mp4 file with ffmpeg-python (same resolution, lower file size)')
@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out'), show_default=True, metavar='DIR')
@click.option('--desc', type=str, help='Description name for the directory path to save results', default='random-video', show_default=True)
def random_interpolation_video(
@@ -257,7 +214,8 @@ def random_interpolation_video(
with dnnlib.util.open_url(network_pkl) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

# Create the run dir with the given name description
# Create the run dir with the given name description; add slowdown if different than the default (1)
desc = f'{desc}-{slowdown}xslowdown' if slowdown != 1 else desc
run_dir = make_run_dir(outdir, desc)

# Number of frames in the video and its total duration in seconds
@@ -274,14 +232,14 @@ def random_interpolation_video(
# Get the grid width and height according to num, giving priority to the number of columns
grid_width = max(int(np.ceil(np.sqrt(num_seeds))), 1)
grid_height = max((num_seeds - 1) // grid_width + 1, 1)
grid_size = [grid_width, grid_height]
grid_size = (grid_width, grid_height)
shape = [num_frames, G.z_dim] # This is per seed
# Get the z latents
all_latents = np.stack([np.random.RandomState(seed).randn(*shape).astype(np.float32) for seed in seeds], axis=1)

# If only one seed is provided, but the specific grid shape is specified:
elif None not in (grid_width, grid_height) and len(seeds) == 1:
grid_size = [grid_width, grid_height]
grid_size = (grid_width, grid_height)
shape = [num_frames, np.prod(grid_size), G.z_dim]
# Since we have one seed, we use it to generate all latents
all_latents = np.random.RandomState(*seeds).randn(*shape).astype(np.float32)
@@ -290,7 +248,7 @@
elif None not in (grid_width, grid_height) and len(seeds) >= 1:
# Case is similar to the first one
num_seeds = len(seeds)
grid_size = [grid_width, grid_height]
grid_size = (grid_width, grid_height)
available_slots = np.prod(grid_size)
if available_slots < num_seeds:
diff = num_seeds - available_slots
@@ -332,8 +290,7 @@ def make_frame(t):
frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1))
latents = torch.from_numpy(all_latents[frame_idx]).to(device)
# Get the images with the labels
images = G(z=latents, c=label, truncation_psi=truncation_psi, noise_mode=noise_mode)
images = (images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
images = z_to_img(G, latents, label, truncation_psi, noise_mode)
# Generate the grid for this timestamp
grid = create_image_grid(images, grid_size)
# Grayscale => RGB
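As a side note on random_interpolation_video above: when only seeds are given, the grid shape is inferred from their count, giving priority to the number of columns. A worked example (the value 10 is an arbitrary illustration):

import numpy as np

num_seeds = 10
grid_width = max(int(np.ceil(np.sqrt(num_seeds))), 1)    # 4 columns
grid_height = max((num_seeds - 1) // grid_width + 1, 1)  # 3 rows
print((grid_width, grid_height))                         # (4, 3): 10 videos on a 4x3 grid, two slots left empty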
34 changes: 14 additions & 20 deletions style_mixing.py
@@ -14,7 +14,7 @@
import click

import dnnlib
from torch_utils.gen_utils import parse_fps, compress_video, make_run_dir
from torch_utils.gen_utils import parse_fps, compress_video, make_run_dir, w_to_img

import numpy as np
import PIL.Image
@@ -87,14 +87,14 @@ def main():


@main.command(name='grid')
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--rows', 'row_seeds', type=num_range, help='Random seeds to use for image rows', required=True)
@click.option('--cols', 'col_seeds', type=num_range, help='Random seeds to use for image columns', required=True)
@click.option('--network', 'network_pkl', type=click.Path(exists=True, dir_okay=False), help='Network pickle filename', required=True)
@click.option('--row-seeds', '-rows', 'row_seeds', type=num_range, help='Random seeds to use for image rows', required=True)
@click.option('--col-seeds', '-cols', 'col_seeds', type=num_range, help='Random seeds to use for image columns', required=True)
@click.option('--styles', 'col_styles', type=num_range, help='Style layers to use; can pass "coarse", "middle", "fine", or a list or range of ints', default='0-6', show_default=True)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out'), show_default=True, metavar='DIR')
@click.option('--desc', type=str, help='Description name for the directory path to save results', default='style-mix-grid', show_default=True)
@click.option('--desc', type=str, help='Description name for the directory path to save results', default='stylemix-grid', show_default=True)
def generate_style_mix(
network_pkl: str,
row_seeds: List[int],
@@ -138,18 +138,16 @@ def generate_style_mix(
w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}

print('Generating images...')
all_images = G.synthesis(all_w, noise_mode=noise_mode)
all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
all_images = w_to_img(G, all_w, noise_mode)
image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))}

print('Generating style-mixed images...')
for row_seed in row_seeds:
for col_seed in col_seeds:
w = w_dict[row_seed].clone()
w[col_styles] = w_dict[col_seed][col_styles]
image = G.synthesis(w[np.newaxis], noise_mode=noise_mode)
image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
image_dict[(row_seed, col_seed)] = image[0].cpu().numpy()
image = w_to_img(G, w, noise_mode)[0]
image_dict[(row_seed, col_seed)] = image

print('Saving images...')
for (row_seed, col_seed), image in image_dict.items():
Expand Down Expand Up @@ -177,7 +175,7 @@ def generate_style_mix(

@main.command(name='video')
@click.pass_context
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--network', 'network_pkl', type=click.Path(exists=True, dir_okay=False), help='Network pickle filename', required=True)
@click.option('--row-seed', '-row', 'row_seed', type=int, help='Random seed to use for video row', required=True)
@click.option('--col-seeds', '-cols', 'col_seeds', type=num_range, help='Random seeds to use for image columns', required=True)
@click.option('--styles', 'col_styles', type=num_range, help='Style layers to use; can pass "coarse", "middle", "fine", or a list or range of ints', default='0-6', show_default=True)
@@ -188,7 +186,7 @@ def generate_style_mix(
@click.option('--duration-sec', type=float, help='Duration of the video in seconds', default=30, show_default=True)
@click.option('--fps', type=parse_fps, help='Video FPS.', default=30, show_default=True)
@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out'), show_default=True, metavar='DIR')
@click.option('--desc', type=str, help='Description name for the directory path to save results', default='style-mix-video', show_default=True)
@click.option('--desc', type=str, help='Description name for the directory path to save results', default='stylemix-video', show_default=True)
def random_stylemix_video(
ctx: click.Context,
network_pkl: str,
@@ -284,8 +282,7 @@ def make_frame(t):
# Replace the values defined by col_styles
w_col[:, col_styles] = src_w[frame_idx, col_styles]
# Generate the style-mixed images
col_images = G.synthesis(w_col, noise_mode=noise_mode)
col_images = (col_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
col_images = w_to_img(G, w_col, noise_mode)
# Paste them in their respective spot in the grid
for row, image in enumerate(list(col_images)):
canvas.paste(PIL.Image.fromarray(image, 'RGB'), (col * H, row * W))
@@ -299,8 +296,7 @@ def make_frame(t):
canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len([row_seed]) + 1)), 'black')

# Generate all destination images (first row; static images)
dst_images = G.synthesis(dst_w, noise_mode=noise_mode)
dst_images = (dst_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
dst_images = w_to_img(G, dst_w, noise_mode)
# Paste them in the canvas
for col, dst_image in enumerate(list(dst_images)):
canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), ((col + 1) * H, 0))
@@ -309,8 +305,7 @@ def make_frame(t):
# Get the frame number according to time t
frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1))
# Get the image at this frame (first column; video)
src_image = G.synthesis(src_w[frame_idx].unsqueeze(0), noise_mode=noise_mode) # [18, 512] -> [1, 18, 512]
src_image = (src_image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()[0]
src_image = w_to_img(G, src_w[frame_idx], noise_mode)[0]
# Paste it to the lower left
canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), (0, H))

@@ -321,8 +316,7 @@ def make_frame(t):
# Replace the values defined by col_styles
w_col[:, col_styles] = src_w[frame_idx, col_styles]
# Generate these style-mixed images
col_images = G.synthesis(w_col, noise_mode=noise_mode)
col_images = (col_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
col_images = w_to_img(G, w_col, noise_mode)
# Paste them in their respective spot in the grid
for row, image in enumerate(list(col_images)):
canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * H, (row + 1) * W))
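Throughout this file the G.synthesis call followed by the permute/clamp conversion is collapsed into a single w_to_img call, with callers indexing [0] when they want one image. A sketch of what such a helper might look like, inferred from the removed code and the call sites above; the real implementation in torch_utils.gen_utils may differ:

import numpy as np
import torch

def w_to_img(G, w: torch.Tensor, noise_mode: str = 'const') -> np.ndarray:
    """Synthesize uint8 images of shape [batch, height, width, channels] from W latents."""
    if w.ndim == 2:
        w = w.unsqueeze(0)  # promote a single [num_ws, w_dim] latent to a batch of one
    img = G.synthesis(w, noise_mode=noise_mode)
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return img.cpu().numpy()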