Skip to content

Commit

Permalink
Create and process the celebA dataset in (32,32,3) dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
snknitin committed Apr 26, 2018
1 parent f1c1bee commit 91891a2
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 11 deletions.
13 changes: 3 additions & 10 deletions Code/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,7 @@
import pandas as pd
import numpy as np
from keras.datasets import mnist, cifar10







import shared_utils as su


def load_dataset(hps):
Expand All @@ -34,9 +28,8 @@ def load_dataset(hps):
if hps.module == 'celeba':
# load CelebA data
width, height, channels = hps.train_celeba_dimensions
(X_train, y_train), (X_test, y_test) = celeba.load_data()
# rescale -1 to 1
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = su.load_data()



# defining input dims
Expand Down
17 changes: 17 additions & 0 deletions Code/shared_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import pandas as pd
import numpy as np
import h5py
from keras import backend as K
# import skimage
# from skimage import data, color, exposure
Expand Down Expand Up @@ -63,3 +64,19 @@ def squash(vectors, axis=-1):
return scale * vectors


def load_data():
"""
Loads the CelebA dataset from the hdf5 file and processes it
:return:
"""

with h5py.File(os.path.join(os.path.dirname(os.getcwd()),"Data/celebA/CelebA_32_data.h5"), "r") as hf:
# Loading the data as floats
X_real_train = hf["data"][:].astype(np.float32)
# Transpose to make channels the last dimension
X_real_train = X_real_train.transpose(0, 2, 3, 1)
# Normalizing the pixels
X_real_train = (X_real_train- 127.5) / 127.5
np.random.shuffle(X_real_train)
# Split to 80%
return X_real_train[:162080]
2 changes: 1 addition & 1 deletion Data/create_celeba.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def check_HDF5(size=64):
help='Plot the images to make sure the data processing went OK')
args = parser.parse_args()

data_dir = "../../data/processed"
data_dir = os.path.join(os.getcwd(),"celebA/")

build_HDF5(args.jpeg_dir, size=args.img_size)

Expand Down

0 comments on commit 91891a2

Please sign in to comment.