diff --git a/examples/pix2pose/model.py b/examples/pix2pose/model.py
new file mode 100644
index 000000000..f8093300a
--- /dev/null
+++ b/examples/pix2pose/model.py
@@ -0,0 +1,208 @@
+import numpy as np
+
+from tensorflow.keras.layers import (Conv2D, Activation, UpSampling2D, Dense,
+                                     Conv2DTranspose, Dropout, Input, Flatten,
+                                     Reshape, LeakyReLU, BatchNormalization,
+                                     Concatenate)
+from tensorflow.keras.models import Model
+import tensorflow as tf
+
+
+def loss_color_wrapped(rotation_matrices):
+    def loss_color_unwrapped(color_image, predicted_color_image):
+        min_loss = tf.float32.max
+
+        # Bring the image into the range between 0 and 1
+        color_image = (color_image + 1) * 0.5
+
+        # Calculate masks for the object and the background
+        # (they are independent of the rotation)
+        mask_object = tf.repeat(tf.expand_dims(tf.math.reduce_max(
+            tf.math.ceil(color_image), axis=-1), axis=-1), repeats=3, axis=-1)
+        mask_background = tf.ones(tf.shape(mask_object)) - mask_object
+
+        # Bring the image back into the range between -1 and 1
+        color_image = (color_image * 2) - 1
+
+        # Iterate over all possible rotations
+        for rotation_matrix in rotation_matrices:
+
+            real_color_image = tf.identity(color_image)
+
+            # Add a small epsilon value to avoid the discontinuity problem
+            real_color_image = real_color_image + tf.ones_like(real_color_image) * 0.0001
+
+            # Rotate the object
+            real_color_image = tf.einsum(
+                'ij,mklj->mkli',
+                tf.convert_to_tensor(np.array(rotation_matrix), dtype=tf.float32),
+                real_color_image)
+            # real_color_image = tf.where(tf.math.less(real_color_image, 0), tf.ones_like(real_color_image) + real_color_image, real_color_image)
+
+            # Set the background to be all -1
+            real_color_image *= mask_object
+            real_color_image += (mask_background * tf.constant(-1.))
+
+            # Get the number of pixels
+            num_pixels = tf.math.reduce_prod(tf.shape(real_color_image)[1:3])
+            beta = 3
+
+            # Calculate the masked differences between the real and predicted images
+            diff_object = tf.math.abs(predicted_color_image * mask_object
+                                      - real_color_image * mask_object)
+            diff_background = tf.math.abs(predicted_color_image * mask_background
+                                          - real_color_image * mask_background)
+
+            # Calculate the total loss for this rotation and keep the minimum
+            loss_colors = tf.cast((1 / num_pixels), dtype=tf.float32) * (
+                beta * tf.math.reduce_sum(diff_object, axis=[1, 2, 3])
+                + tf.math.reduce_sum(diff_background, axis=[1, 2, 3]))
+            min_loss = tf.math.minimum(loss_colors, min_loss)
+
+        return min_loss
+
+    return loss_color_unwrapped
+
+
+def loss_error(real_error_image, predicted_error_image):
+    # Normalize by the number of pixels
+    num_pixels = tf.math.reduce_prod(tf.shape(real_error_image)[1:3])
+    loss_error = tf.cast((1 / num_pixels), dtype=tf.float32) * tf.math.reduce_sum(
+        tf.math.square(predicted_error_image
+                       - tf.clip_by_value(tf.math.abs(real_error_image),
+                                          tf.float32.min, 1.)),
+        axis=[1, 2, 3])
+
+    return loss_error
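+
+
+# For reference (a sketch, not a verbatim quote of the Pix2Pose paper): the two
+# losses above correspond to
+#
+#   L_color = min_R (1 / n) * sum_p [beta * M_obj(p) + (1 - M_obj(p))] * |I_pred(p) - (R I_gt)(p)|
+#
+#   L_error = (1 / n) * sum_p (E_pred(p) - min(|E_gt(p)|, 1))^2
+#
+# where n is the number of pixels per image, M_obj is the object mask, R runs
+# over the provided symmetry rotations and beta (= 3 above) weights object
+# pixels more strongly than background pixels. The symbols are only
+# illustrative; the implementation above is authoritative.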
+
+
+def Generator():
+    bn_axis = 3
+
+    input = Input((128, 128, 3), name='input_image')
+
+    # First layer of the encoder
+    e1_1 = Conv2D(64, (5, 5), strides=(2, 2), padding='same', name='encoder_conv2D_1_1')(input)
+    e1_1 = BatchNormalization(bn_axis)(e1_1)
+    e1_1 = LeakyReLU()(e1_1)
+
+    e1_2 = Conv2D(64, (5, 5), strides=(2, 2), padding='same', name='encoder_conv2D_1_2')(input)
+    e1_2 = BatchNormalization(bn_axis)(e1_2)
+    e1_2 = LeakyReLU()(e1_2)
+
+    e1 = Concatenate()([e1_1, e1_2])
+
+    # Second layer of the encoder
+    e2_1 = Conv2D(128, (5, 5), strides=(2, 2), padding='same', name='encoder_conv2D_2_1')(e1)
+    e2_1 = BatchNormalization(bn_axis)(e2_1)
+    e2_1 = LeakyReLU()(e2_1)
+
+    e2_2 = Conv2D(128, (5, 5), strides=(2, 2), padding='same', name='encoder_conv2D_2_2')(e1)
+    e2_2 = BatchNormalization(bn_axis)(e2_2)
+    e2_2 = LeakyReLU()(e2_2)
+
+    e2 = Concatenate()([e2_1, e2_2])
+
+    # Third layer of the encoder
+    e3_1 = Conv2D(128, (5, 5), strides=(2, 2), padding='same', name='encoder_conv2D_3_1')(e2)
+    e3_1 = BatchNormalization(bn_axis)(e3_1)
+    e3_1 = LeakyReLU()(e3_1)
+
+    e3_2 = Conv2D(128, (5, 5), strides=(2, 2), padding='same', name='encoder_conv2D_3_2')(e2)
+    e3_2 = BatchNormalization(bn_axis)(e3_2)
+    e3_2 = LeakyReLU()(e3_2)
+
+    e3 = Concatenate()([e3_1, e3_2])
+
+    # Fourth layer of the encoder
+    e4_1 = Conv2D(256, (5, 5), strides=(2, 2), padding='same', name='encoder_conv2D_4_1')(e3)
+    e4_1 = BatchNormalization(bn_axis)(e4_1)
+    e4_1 = LeakyReLU()(e4_1)
+
+    e4_2 = Conv2D(256, (5, 5), strides=(2, 2), padding='same', name='encoder_conv2D_4_2')(e3)
+    e4_2 = BatchNormalization(bn_axis)(e4_2)
+    e4_2 = LeakyReLU()(e4_2)
+
+    e4 = Concatenate()([e4_1, e4_2])
+
+    # Latent dimension
+    x = Flatten()(e4)
+    x = Dense(256)(x)
+    x = Dense(8 * 8 * 256)(x)
+    x = Reshape((8, 8, 256))(x)
+
+    # First layer of the decoder
+    d1_1 = Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', name='decoder_conv2D_1_1')(x)
+    d1_1 = BatchNormalization(bn_axis)(d1_1)
+    d1_1 = LeakyReLU()(d1_1)
+
+    d1 = Concatenate()([d1_1, e3_2])
+
+    # Second layer of the decoder
+    d2_1 = Conv2D(256, (5, 5), strides=(1, 1), padding='same', name='decoder_conv2D_2_1')(d1)
+    d2_1 = BatchNormalization(bn_axis)(d2_1)
+    d2_1 = LeakyReLU()(d2_1)
+
+    d2_2 = Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', name='decoder_conv2D_2_2')(d2_1)
+    d2_2 = BatchNormalization(bn_axis)(d2_2)
+    d2_2 = LeakyReLU()(d2_2)
+
+    d2 = Concatenate()([d2_2, e2_2])
+
+    # Third layer of the decoder
+    d3_1 = Conv2D(256, (5, 5), strides=(1, 1), padding='same', name='decoder_conv2D_3_1')(d2)
+    d3_1 = BatchNormalization(bn_axis)(d3_1)
+    d3_1 = LeakyReLU()(d3_1)
+
+    d3_2 = Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', name='decoder_conv2D_3_2')(d3_1)
+    d3_2 = BatchNormalization(bn_axis)(d3_2)
+    d3_2 = LeakyReLU()(d3_2)
+
+    d3 = Concatenate()([d3_2, e1_2])
+
+    # Fourth layer
+    d4_1 = Conv2D(128, (5, 5), strides=(1, 1), padding='same', name='decoder_conv2D_4_1')(d3)
+    d4_1 = BatchNormalization(bn_axis)(d4_1)
+    d4_1 = LeakyReLU()(d4_1)
+
+    # Define the two outputs
+    color_output = Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same')(d4_1)
+    color_output = Activation('tanh', name='color_output')(color_output)
+
+    error_output = Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same')(d4_1)
+    error_output = Activation('sigmoid', name='error_output')(error_output)
+
+    # Define model
+    model = Model(inputs=[input], outputs=[color_output, error_output])
+
+    return model
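+
+# Quick shape check (illustrative only, not used by the training code):
+#
+#   generator = Generator()
+#   color, error = generator(tf.zeros((1, 128, 128, 3)))
+#   # color -> (1, 128, 128, 3), tanh output in [-1, 1]
+#   # error -> (1, 128, 128, 1), sigmoid output in [0, 1]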
+
+
+def Discriminator():
+    bn_axis = 3
+
+    input = Input((128, 128, 3), name='input_image')
+
+    # First layer of the discriminator
+    d1 = Conv2D(64, (3, 3), strides=(2, 2), padding='same', name='discriminator_conv2D_1_1')(input)
+    d1 = BatchNormalization(bn_axis)(d1)
+    d1 = LeakyReLU(0.2)(d1)
+
+    # Second layer of the discriminator
+    d2 = Conv2D(128, (3, 3), strides=(2, 2), padding='same', name='discriminator_conv2D_2_1')(d1)
+    d2 = BatchNormalization(bn_axis)(d2)
+    d2 = LeakyReLU(0.2)(d2)
+
+    # Third layer of the discriminator
+    d3 = Conv2D(256, (3, 3), strides=(2, 2), padding='same', name='discriminator_conv2D_3_1')(d2)
+    d3 = BatchNormalization(bn_axis)(d3)
+    d3 = LeakyReLU(0.2)(d3)
+
+    # Fourth layer of the discriminator
+    d4 = Conv2D(512, (3, 3), strides=(2, 2), padding='same', name='discriminator_conv2D_4_1')(d3)
+    d4 = BatchNormalization(bn_axis)(d4)
+    d4 = LeakyReLU(0.2)(d4)
+
+    # Fifth layer of the discriminator
+    d5 = Conv2D(512, (3, 3), strides=(2, 2), padding='same', name='discriminator_conv2D_5_1')(d4)
+    d5 = BatchNormalization(bn_axis)(d5)
+    d5 = LeakyReLU(0.2)(d5)
+
+    # Sixth layer of the discriminator
+    d6 = Conv2D(512, (3, 3), strides=(2, 2), padding='same', name='discriminator_conv2D_6_1')(d5)
+    d6 = BatchNormalization(bn_axis)(d6)
+    d6 = LeakyReLU(0.2)(d6)
+
+    # Seventh layer of the discriminator
+    d7 = Conv2D(512, (3, 3), strides=(2, 2), padding='same', name='discriminator_conv2D_7_1')(d6)
+    d7 = BatchNormalization(bn_axis)(d7)
+    d7 = LeakyReLU(0.2)(d7)
+
+    flatten = Flatten()(d7)
+    output = Dense(1, activation='sigmoid', name='discriminator_output')(flatten)
+    discriminator_model = Model(inputs=input, outputs=[output])
+    return discriminator_model
\ No newline at end of file
diff --git a/examples/pix2pose/pipelines.py b/examples/pix2pose/pipelines.py
new file mode 100644
index 000000000..ee4299670
--- /dev/null
+++ b/examples/pix2pose/pipelines.py
@@ -0,0 +1,174 @@
+import numpy as np
+import os
+import glob
+import random
+from tensorflow.keras.utils import Sequence
+
+from paz.abstract import SequentialProcessor, Processor
+from paz.abstract.sequence import SequenceExtra
+from paz.pipelines import RandomizeRenderedImage
+from paz import processors as pr
+
+
+class GeneratedImageProcessor(Processor):
+    """
+    Loads pre-generated images
+    """
+    def __init__(self, path_images, background_images_paths, num_occlusions=1, split=pr.TRAIN, no_ambiguities=False):
+        super(GeneratedImageProcessor, self).__init__()
+        self.copy = pr.Copy()
+        self.augment = RandomizeRenderedImage(background_images_paths, num_occlusions)
+        preprocessors_input = [pr.NormalizeImage()]
+        preprocessors_output = [NormalizeImageTanh()]
+        self.preprocess_input = SequentialProcessor(preprocessors_input)
+        self.preprocess_output = SequentialProcessor(preprocessors_output)
+        self.split = split
+
+        # Total number of images
+        self.num_images = len(glob.glob(os.path.join(path_images, "image_original/*")))
+
+        # Load all images into memory to save time
+        self.images_original = [np.load(os.path.join(path_images, "image_original/image_original_{}.npy".format(str(i).zfill(7)))) for i in range(self.num_images)]
+
+        if no_ambiguities:
+            self.images_colors = [np.load(os.path.join(path_images, "image_colors_no_ambiguities/image_colors_no_ambiguities_{}.npy".format(str(i).zfill(7)))) for i in range(self.num_images)]
+        else:
+            self.images_colors = [np.load(os.path.join(path_images, "image_colors/image_colors_{}.npy".format(str(i).zfill(7)))) for i in range(self.num_images)]
+
+        self.alpha_original = [np.load(os.path.join(path_images, "alpha_original/alpha_original_{}.npy".format(str(i).zfill(7)))) for i in range(self.num_images)]
+
+    def call(self):
+        index = random.randint(0, self.num_images - 1)
+        image_original = self.images_original[index]
+        image_colors = self.images_colors[index]
+        alpha_original = self.alpha_original[index]
+
+        if self.split == pr.TRAIN:
+            image_original = self.augment(image_original, alpha_original)
+
+        image_original = self.preprocess_input(image_original)
+        image_colors = self.preprocess_output(image_colors)
+
+        return image_original, image_colors
+
+
+class GeneratedImageGenerator(SequentialProcessor):
+    def __init__(self, path_images, size, background_images_paths, num_occlusions=1, split=pr.TRAIN):
+        super(GeneratedImageGenerator, self).__init__()
+        self.add(GeneratedImageProcessor(
+            path_images, background_images_paths, num_occlusions, split))
+        self.add(pr.SequenceWrapper(
+            {0: {'input_image': [size, size, 3]}},
+            {1: {'color_output': [size, size, 3]}, 0: {'error_output': [size, size, 1]}}))
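+
+# A batch drawn from this generator (see GeneratingSequencePix2Pose below and
+# train.py) is a tuple of two dictionaries, roughly:
+#
+#   batch[0]['input_image']  -> (batch_size, size, size, 3), augmented input
+#   batch[1]['color_output'] -> (batch_size, size, size, 3), target color image
+#   batch[1]['error_output'] -> (batch_size, size, size, 1), filled in on the
+#                               fly by GeneratingSequencePix2Pose.process_batch
+#
+# The exact nesting is defined by paz's pr.SequenceWrapper; the shapes above
+# only restate the spec passed to it.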
+
+
+def make_batch_discriminator(generator, input_images, color_output_images, label):
+    """Creates a batch of training data for the discriminator. For real images
+    the label is 1, for fake images the label is 0.
+    """
+    if label == 1:
+        return color_output_images, np.ones(len(color_output_images))
+    elif label == 0:
+        predictions = generator.predict(input_images)
+        return predictions[0], np.zeros(len(predictions[0]))
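+
+# Typical usage (mirroring train.py): one real and one fake batch per step,
+#
+#   X_real, y_real = make_batch_discriminator(generator, input_images, color_images, 1)
+#   X_fake, y_fake = make_batch_discriminator(generator, input_images, color_images, 0)
+#
+# where `generator` is the color/error Generator model, so the fake batch is
+# its predicted color image (predictions[0]).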
+
+
+class GeneratingSequencePix2Pose(SequenceExtra):
+    """Sequence generator used for generating samples.
+
+    The GeneratingSequence class from paz.abstract cannot be used here because
+    not all of the training data is available at the start: the error images
+    depend on the predicted color images, so they have to be generated
+    on-the-fly during training. This is done here.
+
+    # Arguments
+        processor: Function used for generating and processing ``samples``.
+        model: Keras model
+        batch_size: Int.
+        num_steps: Int. Number of steps for each epoch.
+        as_list: Bool, if True ``inputs`` and ``labels`` are dispatched as
+            lists. If false ``inputs`` and ``labels`` are dispatched as
+            dictionaries.
+        rotation_matrices: List of rotation matrices describing the symmetries
+            of the object.
+    """
+    def __init__(self, processor, model, batch_size, num_steps, as_list=False, rotation_matrices=None):
+        self.num_steps = num_steps
+        self.model = model
+        self.rotation_matrices = rotation_matrices
+        super(GeneratingSequencePix2Pose, self).__init__(
+            processor, batch_size, as_list)
+
+    def __len__(self):
+        return self.num_steps
+
+    def rotate_image(self, image, rotation_matrix):
+        # A pixel belongs to the object if its channel sum differs from the
+        # background value of -3 (all three channels equal to -1)
+        mask_image = np.ma.masked_not_equal(np.sum(image, axis=-1), -1. * 3).mask.astype(float)
+        mask_image = np.repeat(mask_image[..., np.newaxis], 3, axis=-1)
+        mask_background = np.ones_like(mask_image) - mask_image
+
+        # Rotate the object and keep the background at -1
+        image_rotated = np.einsum('ij,klj->kli', rotation_matrix, image)
+        image_rotated *= mask_image
+        image_rotated += (mask_background * -1.)
+
+        return image_rotated
+
+    def process_batch(self, inputs, labels, batch_index):
+        input_images, samples = list(), list()
+        for sample_arg in range(self.batch_size):
+            sample = self.pipeline()
+            samples.append(sample)
+            input_image = sample['inputs'][self.ordered_input_names[0]]
+            input_images.append(input_image)
+
+        input_images = np.asarray(input_images)
+
+        # This call is very important: if model.predict(...) were used instead,
+        # the results would be wrong. BatchNormalization behaves differently in
+        # training and inference mode, and model.predict(...) runs in inference
+        # mode, so its predictions would differ from the ones the model is
+        # trained on, and the error images generated here would then be wrong
+        # as well.
+        predictions = self.model(input_images, training=True)
+
+        # Calculate the errors between the target output and the predicted output
+        for sample_arg in range(self.batch_size):
+            sample = samples[sample_arg]
+
+            # List of tuples of the form (error, error_image)
+            stored_errors = []
+
+            # Iterate over all rotation matrices to find the object pose
+            # with the smallest error
+            for rotation_matrix in self.rotation_matrices:
+                color_image_rotated = self.rotate_image(sample['labels']['color_output'], rotation_matrix)
+                error_image = np.sum(predictions['color_output'][sample_arg] - color_image_rotated, axis=-1, keepdims=True)
+
+                error_value = np.sum(np.abs(error_image))
+                stored_errors.append((error_value, error_image))
+
+            # Select the error image with the smallest error
+            minimal_error_pair = min(stored_errors, key=lambda t: t[0])
+            sample['labels'][self.ordered_label_names[0]] = minimal_error_pair[1]
+            self._place_sample(sample['inputs'], sample_arg, inputs)
+            self._place_sample(sample['labels'], sample_arg, labels)
+
+        return inputs, labels
+
+
+class NormalizeImageTanh(Processor):
+    """
+    Normalizes an image so that its values are between -1 and 1
+    """
+    def __init__(self):
+        super(NormalizeImageTanh, self).__init__()
+
+    def call(self, image):
+        return (image / 127.5) - 1
+
+
+class DenormalizeImageTanh(Processor):
+    """
+    Transforms an image from the value range -1 to 1 back to 0 to 255
+    """
+    def __init__(self):
+        super(DenormalizeImageTanh, self).__init__()
+
+    def call(self, image):
+        return (image + 1.0) * 127.5
diff --git a/examples/pix2pose/pix2pose.sh b/examples/pix2pose/pix2pose.sh
new file mode 100644
index 000000000..fb315cb9f
--- /dev/null
+++ b/examples/pix2pose/pix2pose.sh
@@ -0,0 +1 @@
+python3 train.py --images_directory /home/fabian/.keras/tless_obj05/pix2pose/normal_coloring --background_images_directory /home/fabian/.keras/backgrounds --batch_size 4 --steps_per_epoch 5 --image_size 128 --rotation_matrices /home/fabian/Uni/masterarbeit/src/paz/examples/pix2pose/rotation_matrices/2_fold_symmetry_rotation_matrices.npy
\ No newline at end of file
diff --git a/examples/pix2pose/train.py b/examples/pix2pose/train.py
new file mode 100644
index 000000000..c7adce3f1
--- /dev/null
+++ b/examples/pix2pose/train.py
@@ -0,0 +1,137 @@
+import os
+import glob
+import argparse
+import numpy as np
+import time
+
+from tensorflow.keras.callbacks import CSVLogger
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.layers import Input
+from tensorflow.keras.models import Model
+
+from pipelines import GeneratingSequencePix2Pose, GeneratedImageGenerator, make_batch_discriminator
+from model import Generator, Discriminator, loss_color_wrapped, loss_error
+
+
+description = 'Training script for the Pix2Pose model'
+root_path = os.path.join(os.path.expanduser('~'), '.keras/')
+parser = argparse.ArgumentParser(description=description)
+parser.add_argument('-cl', '--class_name', default='tless05', type=str,
+                    help='Class name to be added to model save path')
+parser.add_argument('-id', '--background_images_directory', type=str,
+                    help='Path to directory containing background images')
+parser.add_argument('-pi', '--images_directory', type=str,
+                    help='Path to pre-generated images (npy format)')
+parser.add_argument('-bs', '--batch_size', default=4, type=int,
+                    help='Batch size for training')
+parser.add_argument('-lr', '--learning_rate', default=0.001, type=float,
+                    help='Initial learning rate for Adam')
+parser.add_argument('-ld', '--image_size', default=128, type=int,
+                    help='Size of the side of a square image e.g. 64')
+parser.add_argument('-e', '--max_num_epochs', default=10000, type=int,
+                    help='Maximum number of epochs before finishing')
+parser.add_argument('-st', '--steps_per_epoch', default=5, type=int,
+                    help='Steps per epoch')
+parser.add_argument('-oc', '--num_occlusions', default=2, type=int,
+                    help='Number of occlusions')
+parser.add_argument('-sa', '--save_path',
+                    default=os.path.join(
+                        os.path.expanduser('~'), '.keras/paz/models'),
+                    type=str, help='Path for writing model weights and logs')
+parser.add_argument('-rm', '--rotation_matrices',
+                    type=str, help='Path to npy file with a list of rotation matrices', required=True)
+parser.add_argument('-de', '--description',
+                    type=str, help='Description of the model')
+args = parser.parse_args()
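+
+# The directory passed via --images_directory is expected to contain "train"
+# and "test" subdirectories (see the sequencer creation below), each holding
+# the .npy files read by GeneratedImageProcessor in pipelines.py, e.g.:
+#
+#   <images_directory>/train/image_original/image_original_0000000.npy
+#   <images_directory>/train/image_colors/image_colors_0000000.npy
+#   <images_directory>/train/alpha_original/alpha_original_0000000.npy
+#
+# The file names above are illustrative; they follow the zero-padded pattern
+# used in GeneratedImageProcessor.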
+
+# Building the whole GAN model
+dcgan_input = Input(shape=(128, 128, 3))
+discriminator = Discriminator()
+generator = Generator()
+color_output, error_output = generator(dcgan_input)
+discriminator.trainable = False
+discriminator_output = discriminator(color_output)
+dcgan = Model(inputs=[dcgan_input],
+              outputs={"color_output": color_output,
+                       "error_output": error_output,
+                       "discriminator_output": discriminator_output})
+
+# For the loss function Pix2Pose needs to know all rotations under which the pose looks the same
+rotation_matrices = np.load(args.rotation_matrices)
+loss_color = loss_color_wrapped(rotation_matrices)
+
+# Set the losses and their weights
+optimizer = Adam(args.learning_rate, amsgrad=True)
+losses = {"color_output": loss_color,
+          "error_output": loss_error,
+          "discriminator_output": "binary_crossentropy"}
+lossWeights = {"color_output": 100.0, "error_output": 50.0, "discriminator_output": 1.0}
+dcgan.compile(optimizer=optimizer, loss=losses, loss_weights=lossWeights, run_eagerly=True)
+
+discriminator.trainable = True
+discriminator.compile(loss=['binary_crossentropy'], optimizer=optimizer)
+
+# Creating the sequencers
+background_image_paths = glob.glob(os.path.join(args.background_images_directory, '*.jpg'))
+processor_train = GeneratedImageGenerator(os.path.join(args.images_directory, "train"), args.image_size, background_image_paths, num_occlusions=0)
+processor_test = GeneratedImageGenerator(os.path.join(args.images_directory, "test"), args.image_size, background_image_paths, num_occlusions=0)
+sequence_train = GeneratingSequencePix2Pose(processor_train, dcgan, args.batch_size, args.steps_per_epoch, rotation_matrices=rotation_matrices)
+sequence_test = GeneratingSequencePix2Pose(processor_test, dcgan, args.batch_size, args.steps_per_epoch, rotation_matrices=rotation_matrices)
+
+# Making the directory for saving model weights and logs
+model_name = '_'.join([dcgan.name, args.class_name])
+save_path = os.path.join(args.save_path, model_name)
+if not os.path.exists(save_path):
+    os.makedirs(save_path)
+
+# Setting the callbacks
+log = CSVLogger(os.path.join(save_path, '%s.log' % model_name))
+log.model = dcgan
+
+callbacks = [log]
+
+for callback in callbacks:
+    callback.on_train_begin()
+
+for num_epoch in range(args.max_num_epochs):
+    sequence_iterator_train = sequence_train.__iter__()
+    sequence_iterator_test = sequence_test.__iter__()
+
+    for callback in callbacks:
+        callback.on_epoch_begin(num_epoch)
+
+    for num_batch in range(args.steps_per_epoch):
+        # Train the discriminator, first on a real batch, then on a fake batch
+        discriminator.trainable = True
+        batch = next(sequence_iterator_train)
+
+        X_discriminator_real, y_discriminator_real = make_batch_discriminator(generator, batch[0]['input_image'], batch[1]['color_output'], 1)
+        loss_discriminator_real = discriminator.train_on_batch(X_discriminator_real, y_discriminator_real)
+
+        X_discriminator_fake, y_discriminator_fake = make_batch_discriminator(generator, batch[0]['input_image'], batch[1]['color_output'], 0)
+        loss_discriminator_fake = discriminator.train_on_batch(X_discriminator_fake, y_discriminator_fake)
+
+        loss_discriminator = (loss_discriminator_real + loss_discriminator_fake) / 2.
+
+        # Train the generator with the discriminator weights frozen
+        discriminator.trainable = False
+        loss_dcgan, loss_color_output, loss_dcgan_discriminator, loss_error_output = dcgan.train_on_batch(
+            batch[0]['input_image'],
+            {"color_output": batch[1]['color_output'],
+             "error_output": batch[1]['error_output'],
+             "discriminator_output": np.ones((args.batch_size, 1))})
+
+        # Test the network
+        batch_test = next(sequence_iterator_test)
+        loss_dcgan_test, loss_color_output_test, loss_dcgan_discriminator_test, loss_error_output_test = dcgan.test_on_batch(
+            batch_test[0]['input_image'],
+            {"color_output": batch_test[1]['color_output'],
+             "error_output": batch_test[1]['error_output'],
+             "discriminator_output": np.ones((args.batch_size, 1))})
+
+        print("Loss DCGAN: {}".format(loss_dcgan))
+
+    for callback in callbacks:
+        callback.on_epoch_end(num_epoch, logs={
+            'loss_discriminator': loss_discriminator,
+            'loss_dcgan': loss_dcgan,
+            'loss_color_output': loss_color_output,
+            'loss_dcgan_discriminator': loss_dcgan_discriminator,
+            'loss_error_output': loss_error_output,
+            'loss_dcgan_test': loss_dcgan_test,
+            'loss_color_output_test': loss_color_output_test,
+            'loss_dcgan_discriminator_test': loss_dcgan_discriminator_test,
+            'loss_error_output_test': loss_error_output_test})
+
+
+for callback in callbacks:
+    callback.on_train_end()
\ No newline at end of file