Skip to content

Commit

Permalink
Add learning rate to training, my full projection code, add configura…
Browse files Browse the repository at this point in the history
…tions for dataset_tool.py, other updates.
  • Loading branch information
PDillis committed May 3, 2021
1 parent 08378e1 commit dab4f40
Show file tree
Hide file tree
Showing 5 changed files with 586 additions and 100 deletions.
127 changes: 103 additions & 24 deletions dataset_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,41 @@
import PIL.Image
from tqdm import tqdm

#----------------------------------------------------------------------------

# ----------------------------------------------------------------------------


def error(msg: str):
    """Report a fatal problem and terminate the program.

    Writes ``msg`` prefixed with ``Error: `` to standard output and then
    exits through ``sys.exit`` with status code 1, so this never returns.
    """
    report = 'Error: ' + msg
    print(report)
    sys.exit(1)

#----------------------------------------------------------------------------

# ----------------------------------------------------------------------------


def maybe_min(a: int, b: Optional[int]) -> int:
    """Return the smaller of ``a`` and ``b``; a ``b`` of None means "no limit".

    When ``b`` is None, ``a`` is returned unchanged.
    """
    return a if b is None else min(a, b)

#----------------------------------------------------------------------------

# ----------------------------------------------------------------------------


def file_ext(name: Union[str, Path]) -> str:
    """Return the extension of ``name`` without its leading dot.

    Only the final path component is inspected, so a dot inside a directory
    name (``'runs.v2/images'``) is never mistaken for an extension.

    Args:
        name: File name or path, as a string or ``Path``.

    Returns:
        The text after the last ``.`` of the final path component, with its
        case preserved, or ``''`` when that component contains no dot.
    """
    basename = os.path.basename(str(name))
    # rpartition yields an empty separator when there is no dot at all; in
    # that case there is no extension, rather than the whole name being one.
    _, dot, ext = basename.rpartition('.')
    return ext if dot else ''

#----------------------------------------------------------------------------

# ----------------------------------------------------------------------------


def is_image_ext(fname: Union[str, Path]) -> bool:
    """Tell whether ``fname`` ends in an extension listed in ``PIL.Image.EXTENSION``.

    The extension is lower-cased and given a leading dot before the lookup.
    NOTE(review): the registry is presumably only populated once
    ``PIL.Image.init()`` has run - confirm callers do that first.
    """
    dotted_ext = '.' + file_ext(fname).lower()
    return dotted_ext in PIL.Image.EXTENSION # type: ignore

#----------------------------------------------------------------------------

# ----------------------------------------------------------------------------


def open_image_folder(source_dir, *, max_images: Optional[int]):
input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
Expand Down Expand Up @@ -75,7 +85,9 @@ def iterate_images():
break
return max_idx, iterate_images()

#----------------------------------------------------------------------------

# ----------------------------------------------------------------------------


def open_image_zip(source, *, max_images: Optional[int]):
with zipfile.ZipFile(source, mode='r') as z:
Expand Down Expand Up @@ -104,7 +116,9 @@ def iterate_images():
break
return max_idx, iterate_images()

#----------------------------------------------------------------------------

# ----------------------------------------------------------------------------


def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
import cv2 # pip install opencv-python
Expand Down Expand Up @@ -132,7 +146,9 @@ def iterate_images():

return max_idx, iterate_images()

#----------------------------------------------------------------------------

# ----------------------------------------------------------------------------


def open_cifar10(tarball: str, *, max_images: Optional[int]):
images = []
Expand Down Expand Up @@ -164,7 +180,9 @@ def iterate_images():

return max_idx, iterate_images()

#----------------------------------------------------------------------------

# ----------------------------------------------------------------------------


def open_mnist(images_gz: str, *, max_images: Optional[int]):
labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
Expand Down Expand Up @@ -194,15 +212,18 @@ def iterate_images():

return max_idx, iterate_images()

#----------------------------------------------------------------------------

# ----------------------------------------------------------------------------


def make_transform(
transform: Optional[str],
output_width: Optional[int],
output_height: Optional[int],
resize_filter: str
) -> Callable[[np.ndarray], Optional[np.ndarray]]:
resample = { 'box': PIL.Image.BOX, 'lanczos': PIL.Image.LANCZOS }[resize_filter]
resample = {'box': PIL.Image.BOX, 'lanczos': PIL.Image.LANCZOS}[resize_filter]

def scale(width, height, img):
w = img.shape[1]
h = img.shape[0]
Expand All @@ -216,7 +237,7 @@ def scale(width, height, img):

def center_crop(width, height, img):
crop = np.min(img.shape[:2])
img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2]
img = PIL.Image.fromarray(img, 'RGB')
img = img.resize((width, height), resample)
return np.array(img)
Expand All @@ -226,28 +247,69 @@ def center_crop_wide(width, height, img):
if img.shape[1] < width or ch < height:
return None

img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
img = img[(img.shape[0] - ch) // 2: (img.shape[0] + ch) // 2] # center-crop: [width0, height0, 3] -> [width, height0, 3]
img = PIL.Image.fromarray(img, 'RGB')
img = img.resize((width, height), resample)
img = img.resize((width, height), resample) # resize: [width, height0, 3] -> [width, height, 3]
img = np.array(img)

canvas = np.zeros([width, width, 3], dtype=np.uint8)
canvas[(width - height) // 2 : (width + height) // 2, :] = img
canvas = np.zeros([width, width, 3], dtype=np.uint8) # square canvas
canvas[(width - height) // 2: (width + height) // 2, :] = img # replace the middle with img
return canvas

def center_crop_tall(width, height, img):
    # Fits an image into a (width x height) portrait frame and centres that
    # frame horizontally on a black (height x height) square canvas.
    # Returns None when the source is too small for the requested size.
    # img is indexed as [rows, columns, channels] throughout.
    src_rows, src_cols = img.shape[0], img.shape[1]
    crop_cols = int(np.round(height * src_cols / src_rows))
    if src_rows < height or crop_cols < width:
        return None

    # Keep only the central crop_cols columns; every row is kept.
    first_col = (src_cols - crop_cols) // 2
    last_col = (src_cols + crop_cols) // 2
    cropped = PIL.Image.fromarray(img[:, first_col:last_col], 'RGB')

    # PIL takes (width, height); the array that comes back is height rows by width columns.
    resized = np.array(cropped.resize((width, height), resample))

    # Square canvas of zeros; the resized image fills the middle columns.
    canvas = np.zeros([height, height, 3], dtype=np.uint8)
    canvas[:, (height - width) // 2: (height + width) // 2] = resized
    return canvas

def resize_pad(width, height, img):
    # Placeholder for the 'resize-pad' transform, which has no implementation yet.
    # Fail loudly: a bare `pass` would return None, the same value the other
    # transforms return to reject an image, so the gap would go unnoticed.
    raise NotImplementedError("the 'resize-pad' transform is not implemented yet")

def multi_crop(width, height, img):
    # Placeholder for the 'multi-crop' transform, which has no implementation yet.
    # Fail loudly: a bare `pass` would return None, the same value the other
    # transforms return to reject an image, so the gap would go unnoticed.
    raise NotImplementedError("the 'multi-crop' transform is not implemented yet")

def cut_crop(width, height, img):
    # Placeholder for the 'cut-crop' transform, which has no implementation yet.
    # Fail loudly: a bare `pass` would return None, the same value the other
    # transforms return to reject an image, so the gap would go unnoticed.
    raise NotImplementedError("the 'cut-crop' transform is not implemented yet")

if transform is None:
return functools.partial(scale, output_width, output_height)
if transform == 'center-crop':
if (output_width is None) or (output_height is None):
error ('must specify --width and --height when using ' + transform + 'transform')
error(f'must specify --width and --height when using {transform} transform')
return functools.partial(center_crop, output_width, output_height)
if transform == 'center-crop-wide':
if (output_width is None) or (output_height is None):
error ('must specify --width and --height when using ' + transform + ' transform')
error(f'must specify --width and --height when using {transform} transform')
return functools.partial(center_crop_wide, output_width, output_height)
if transform == 'center-crop-tall':
if (output_width is None) or (output_height is None):
error(f'must specify --width and --height when using {transform} transform')
return functools.partial(center_crop_tall, output_width, output_height)
if transform == 'resize-pad':
if (output_width is None) or (output_height is None):
error(f'must specify --width or --height when using {transform} transform')
return functools.partial(resize_pad, output_width, output_height)
if transform == 'multi-crop':
if (output_width is None) or (output_height is None):
error(f'must specify --width or --height when using {transform} transform')
return functools.partial(multi_crop, output_width, output_height)
if transform == 'cut-crop':
if (output_width is None) or (output_height is None):
error(f'must specify --width and --height when using {transform} transform')
return functools.partial(cut_crop, output_width, output_height)
assert False, 'unknown transform'

#----------------------------------------------------------------------------

# ----------------------------------------------------------------------------


def open_dataset(source, *, max_images: Optional[int]):
if os.path.isdir(source):
Expand All @@ -267,7 +329,9 @@ def open_dataset(source, *, max_images: Optional[int]):
else:
error(f'Missing input file or directory: {source}')

#----------------------------------------------------------------------------

# ----------------------------------------------------------------------------


def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
dest_ext = file_ext(dest)
Expand Down Expand Up @@ -299,15 +363,17 @@ def folder_write_bytes(fname: str, data: Union[bytes, str]):
fout.write(data)
return dest, folder_write_bytes, lambda: None

#----------------------------------------------------------------------------

# ----------------------------------------------------------------------------


@click.command()
@click.pass_context
@click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
@click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
@click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
@click.option('--resize-filter', help='Filter to use when resizing images for output resolution', type=click.Choice(['box', 'lanczos']), default='lanczos', show_default=True)
@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide', 'center-crop-tall', 'resize-pad', 'multi-crop', 'cut-crop']))
@click.option('--width', help='Output width', type=int)
@click.option('--height', help='Output height', type=int)
def convert_dataset(
Expand Down Expand Up @@ -377,6 +443,14 @@ def convert_dataset(
\b
python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
--transform=center-crop-wide --width 512 --height=384
If your images are taller than they are wide (or you just want vertical bars to
the left and right of the image), you can use --transform=center-crop-tall, along
with the --width and --height options. For example:
\b
python dataset_tool.py --source AFHQ/train/cat --dest /tmp/afhqcat_vertical.zip \\
--transform=center-crop-tall --width 384 --height 512
"""

PIL.Image.init() # type: ignore
Expand Down Expand Up @@ -426,7 +500,7 @@ def convert_dataset(
error(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))

# Save the image as an uncompressed PNG.
img = PIL.Image.fromarray(img, { 1: 'L', 3: 'RGB' }[channels])
img = PIL.Image.fromarray(img, {1: 'L', 3: 'RGB'}[channels])
image_bits = io.BytesIO()
img.save(image_bits, format='png', compress_level=0, optimize=False)
save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
Expand All @@ -438,7 +512,12 @@ def convert_dataset(
save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
close_dest()

#----------------------------------------------------------------------------

# ----------------------------------------------------------------------------


if __name__ == "__main__":
convert_dataset() # pylint: disable=no-value-for-parameter


# ----------------------------------------------------------------------------
Loading

0 comments on commit dab4f40

Please sign in to comment.