From b995d1ada5045683c3129fa8353a8bbbd19d5dc6 Mon Sep 17 00:00:00 2001 From: akanazawa Date: Fri, 20 Apr 2018 17:54:25 -0700 Subject: [PATCH] added training code + data --- demo.py | 2 + do_train.sh | 13 + doc/train.md | 109 ++++ prepare_datasets.sh | 55 ++ requirements.txt | 2 +- src/RunModel.py | 2 +- src/config.py | 208 +++++- src/data_loader.py | 331 ++++++++++ src/datasets/__init__.py | 0 src/datasets/coco_to_tfrecords.py | 352 ++++++++++ src/datasets/common.py | 229 +++++++ src/datasets/convert_datasets.sh | 15 + src/datasets/lsp_to_tfrecords.py | 152 +++++ src/datasets/mpi_inf_3dhp/__init__.py | 0 .../mpi_inf_3dhp/read_mpi_inf_3dhp.py | 107 ++++ src/datasets/mpi_inf_3dhp_to_tfrecords.py | 285 ++++++++ src/datasets/mpii_to_tfrecords.py | 295 +++++++++ src/datasets/smpl_to_tfrecords.py | 119 ++++ src/main.py | 30 + src/models.py | 111 +++- src/ops.py | 60 ++ src/tf_smpl/batch_smpl.py | 2 +- src/trainer.py | 606 ++++++++++++++++++ src/util/data_utils.py | 345 ++++++++++ src/util/renderer.py | 38 +- 25 files changed, 3420 insertions(+), 48 deletions(-) create mode 100755 do_train.sh create mode 100644 doc/train.md create mode 100755 prepare_datasets.sh create mode 100644 src/data_loader.py create mode 100644 src/datasets/__init__.py create mode 100644 src/datasets/coco_to_tfrecords.py create mode 100644 src/datasets/common.py create mode 100644 src/datasets/convert_datasets.sh create mode 100644 src/datasets/lsp_to_tfrecords.py create mode 100644 src/datasets/mpi_inf_3dhp/__init__.py create mode 100644 src/datasets/mpi_inf_3dhp/read_mpi_inf_3dhp.py create mode 100644 src/datasets/mpi_inf_3dhp_to_tfrecords.py create mode 100644 src/datasets/mpii_to_tfrecords.py create mode 100644 src/datasets/smpl_to_tfrecords.py create mode 100644 src/main.py create mode 100644 src/ops.py create mode 100644 src/trainer.py create mode 100644 src/util/data_utils.py diff --git a/demo.py b/demo.py index e530df3f3..c65f45485 100644 --- a/demo.py +++ b/demo.py @@ -127,6 +127,8 @@ def main(img_path, json_path=None): if __name__ == '__main__': config = flags.FLAGS config(sys.argv) + # Using pre-trained model, change this to use your own. + config.load_path = src.config.PRETRAINED_MODEL config.batch_size = 1 diff --git a/do_train.sh b/do_train.sh new file mode 100755 index 000000000..5e1bd3337 --- /dev/null +++ b/do_train.sh @@ -0,0 +1,13 @@ +# TODO: Replace with where you downloaded your resnet_v2_50. +PRETRAINED=/scratch1/projects/tf_datasets/models/resnet_v2_50/resnet_v2_50.ckpt +# TODO: Replace with where you generated tf_record! +DATA_DIR=/scratch1/storage/hmr_release_files/test_tf_datasets/ + +CMD="python -m src.main --d_lr 1e-4 --e_lr 1e-5 --log_img_step 1000 --pretrained_model_path=${PRETRAINED} --data_dir ${DATA_DIR} --e_loss_weight 60. --batch_size=64 --use_3d_label True --e_3d_weight 60. --datasets lsp,lsp_ext,mpii,h36m,coco,mpi_inf_3dhp --epoch 75 --log_dir logs" + +# To pick up training/training from a previous model, set LP +# LP='logs/' +# CMD="python -m src.main --d_lr 1e-4 --e_lr 1e-5 --log_img_step 1000 --load_path=${LP} --e_loss_weight 60. --batch_size=64 --use_3d_label True --e_3d_weight 60. --datasets lsp lsp_ext mpii h36m coco mpi_inf_3dhp --epoch 75" + +echo $CMD +$CMD diff --git a/doc/train.md b/doc/train.md new file mode 100644 index 000000000..0a6911e6b --- /dev/null +++ b/doc/train.md @@ -0,0 +1,109 @@ +## Pre-reqs + +### Download required models + +1. 
Download the mean SMPL parameters (initialization)
+```
+wget https://people.eecs.berkeley.edu/~kanazawa/cachedir/hmr/neutral_smpl_mean_params.h5
+```
+
+Store this inside `hmr/models/`, along with the neutral SMPL model
+(`neutral_smpl_with_cocoplus_reg.pkl`).
+
+
+2. Download the pre-trained resnet-50 from
+[Tensorflow](https://github.com/tensorflow/models/tree/master/research/slim#Pretrained)
+```
+wget http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz && tar -xf resnet_v2_50_2017_04_14.tar.gz
+```
+
+3. In `do_train.sh`, set `PRETRAINED` to the path of this model (`resnet_v2_50.ckpt`).
+
+### Download datasets.
+Download these datasets somewhere.
+
+- [LSP](http://sam.johnson.io/research/lsp_dataset.zip) and [LSP extended](http://sam.johnson.io/research/lspet_dataset.zip)
+- [COCO](http://cocodataset.org/#download): we used the 2014 Train split. You also need to
+  install the [COCO API](https://github.com/cocodataset/cocoapi) for Python.
+- [MPII](http://human-pose.mpi-inf.mpg.de/#download)
+- [MPI-INF-3DHP](http://human-pose.mpi-inf.mpg.de/#download)
+
+For Human3.6M, download the pre-computed tfrecords [here](https://drive.google.com/file/d/14RlfDlREouBCNsR1QGDP0qpOUIu5LlV5/view?usp=sharing).
+Note that this is 9.1GB! I advise doing this in a directory outside of the HMR code base.
+```
+wget https://angjookanazawa.com/cachedir/hmr/tf_records_human36m.tar.gz
+```
+
+If you use these datasets, please consider citing them.
+
+## MoSh Data.
+We provide the MoShed data using the neutral SMPL model.
+Please note that usage of this data is for [**non-commercial scientific research only**](http://mosh.is.tue.mpg.de/data_license).
+
+If you use any of the MoSh data, please cite:
+```
+@article{Loper:SIGASIA:2014,
+  title = {{MoSh}: Motion and Shape Capture from Sparse Markers},
+  author = {Loper, Matthew M. and Mahmood, Naureen and Black, Michael J.},
+  journal = {ACM Transactions on Graphics, (Proc. SIGGRAPH Asia)},
+  volume = {33},
+  number = {6},
+  pages = {220:1--220:13},
+  publisher = {ACM},
+  address = {New York, NY, USA},
+  month = nov,
+  year = {2014},
+  url = {http://doi.acm.org/10.1145/2661229.2661273},
+  month_numeric = {11}
+}
+```
+
+[Download link to MoSh](https://drive.google.com/file/d/1b51RMzi_5DIHeYh2KNpgEs8LVaplZSRP/view?usp=sharing)
+
+## TFRecord Generation
+
+All the data has to be converted into TFRecords and saved to a `DATA_DIR` of
+your choice.
+
+1. Make a `DATA_DIR` where you will save the tf_records. For example:
+```
+mkdir ~/hmr/tf_datasets/
+```
+
+2. Edit `prepare_datasets.sh` with the paths to where you downloaded the datasets,
+and set `DATA_DIR` to the directory you just made.
+
+3. From the root HMR directory (where the README is), run `prepare_datasets.sh`, which calls the tfrecord conversion scripts:
+```
+sh prepare_datasets.sh
+```
+
+This takes a while! If there is an issue, consider running the commands line by line.
+
+4. Move the downloaded Human3.6M tfrecords `tf_records_human36m.tar.gz` into the
+`DATA_DIR` and untar it:
+```
+tar -xf tf_records_human36m.tar.gz
+```
+
+5. In `do_train.sh` and/or `src/config.py`, set `DATA_DIR` to the path where you saved the
+tf_records.
+
+
+## Training
+Finally we can start training!
+A sample training script (with the parameters used in the paper) is in
+`do_train.sh`.
+
+Update the paths at the beginning of this script and run:
+```
+sh do_train.sh
+```
+
+The training writes to a log directory that you can specify.
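+For example, with the default `--log_dir logs` in `do_train.sh`, each run gets its own
+subdirectory under `logs/`, so you can point TensorBoard at it with:
+```
+tensorboard --logdir logs
+```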
+Setup tensorboard to this directory to monitor the training progress like so: +![Teaser Image](https://akanazawa.github.io/hmr/resources/images/tboard_ex.png) + +It's important to visually monitor the training! Make sure that the images +loaded look right. + diff --git a/prepare_datasets.sh b/prepare_datasets.sh new file mode 100755 index 000000000..e1e953731 --- /dev/null +++ b/prepare_datasets.sh @@ -0,0 +1,55 @@ +# --------------------------- +# ----- SET YOUR PATH!! ----- +# --------------------------- +# This is where you want all of your tf_records to be saved: +DATA_DIR=/scratch1/storage/hmr_release_files/test_tf_datasets + +# This is the directory that contains README.txt +LSP_DIR=/scratch1/storage/human_datasets/lsp_dataset + +# This is the directory that contains README.txt +LSP_EXT_DIR=/scratch1/storage/human_datasets/lsp_extended + +# This is the directory that contains 'images' and 'annotations' +MPII_DIR=/scratch1/storage/human_datasets/mpii + +# This is the directory that contains README.txt +COCO_DIR=/scratch1/storage/coco + +# This is the directory that contains README.txt, S1..S8, etc +MPI_INF_3DHP_DIR=/scratch1/storage/mpi_inf_3dhp + +## Mosh +# This is the path to the directory that contains neutrSMPL_* directories +MOSH_DIR=/scratch1/storage/human_datasets/neutrMosh +# --------------------------- + + +# --------------------------- +# Run each command below from this directory. I advice to run each one independently. +# --------------------------- + +# ----- LSP ----- +python -m src.datasets.lsp_to_tfrecords --img_directory $LSP_DIR --output_directory $DATA_DIR/lsp + +# ----- LSP-extended ----- +python -m src.datasets.lsp_to_tfrecords --img_directory $LSP_EXT_DIR --output_directory $DATA_DIR/lsp_ext + +# ----- MPII ----- +python -m src.datasets.mpii_to_tfrecords --img_directory $MPII_DIR --output_directory $DATA_DIR/mpii + +# ----- COCO ----- +python -m src.datasets.coco_to_tfrecords --data_directory $COCO_DIR --output_directory $DATA_DIR/coco + +# ----- MPI-INF-3DHP ----- +python -m src.datasets.mpi_inf_3dhp_to_tfrecords --data_directory $MPI_INF_3DHP_DIR --output_directory $DATA_DIR/mpi_inf_3dhp + +# ----- Mosh data, for each dataset ----- +# CMU: +python -m src.datasets.smpl_to_tfrecords --data_directory $MOSH_DIR --output_directory $DATA_DIR/mocap_neutrMosh --dataset_name 'neutrSMPL_CMU' + +# H3.6M: +python -m src.datasets.smpl_to_tfrecords --data_directory $MOSH_DIR --output_directory $DATA_DIR/mocap_neutrMosh --dataset_name 'neutrSMPL_H3.6' + +# jointLim: +python -m src.datasets.smpl_to_tfrecords --data_directory $MOSH_DIR --output_directory $DATA_DIR/mocap_neutrMosh --dataset_name 'neutrSMPL_jointLim' diff --git a/requirements.txt b/requirements.txt index 00d6967b0..f4f5539ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -# python requiremenst +# python requirements pip>=9.0 scipy numpy diff --git a/src/RunModel.py b/src/RunModel.py index 66ef3866e..f3e14a8a2 100644 --- a/src/RunModel.py +++ b/src/RunModel.py @@ -113,7 +113,7 @@ def build_test_model_ief(self): poses = theta_here[:, self.num_cam:(self.num_cam + self.num_theta)] shapes = theta_here[:, (self.num_cam + self.num_theta):] - verts, Js = self.smpl(shapes, poses, get_skin=True) + verts, Js, _ = self.smpl(shapes, poses, get_skin=True) # Project to 2D! 
pred_kp = self.proj_fn(Js, cams, name='proj_2d_stage%d' % i) diff --git a/src/config.py b/src/config.py index 3edcf55bc..0e728830a 100644 --- a/src/config.py +++ b/src/config.py @@ -1,11 +1,23 @@ -""" +""" Sets default args Note all data format is NHWC because slim resnet wants NHWC. """ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys from absl import flags import os.path as osp +from os import makedirs +from glob import glob +from datetime import datetime +import json + +import numpy as np + curr_path = osp.dirname(osp.abspath(__file__)) model_dir = osp.join(curr_path, '..', 'models') if not osp.exists(model_dir): @@ -15,15 +27,17 @@ SMPL_MODEL_PATH = osp.join(model_dir, 'neutral_smpl_with_cocoplus_reg.pkl') SMPL_FACE_PATH = osp.join(curr_path, '../src/tf_smpl', 'smpl_faces.npy') -# Default model path +# Default pred-trained model path for the demo. PRETRAINED_MODEL = osp.join(model_dir, 'model.ckpt-667589') flags.DEFINE_string('smpl_model_path', SMPL_MODEL_PATH, 'path to the neurtral smpl model') flags.DEFINE_string('smpl_face_path', SMPL_FACE_PATH, - 'path to smpl mesh faces (for easy rendering)') -flags.DEFINE_string('load_path', PRETRAINED_MODEL, 'path to trained model') -flags.DEFINE_integer('batch_size', 1, + 'path to smpl mesh faces (for easy rendering)') +flags.DEFINE_string('load_path', None, 'path to trained model') +flags.DEFINE_string('pretrained_model_path', None, + 'if not None, fine-tunes from this ckpt') +flags.DEFINE_integer('batch_size', 8, 'Input image size to the network after preprocessing') # Don't change if testing: @@ -32,7 +46,189 @@ flags.DEFINE_string('data_format', 'NHWC', 'Data format') flags.DEFINE_integer('num_stage', 3, '# of times to iterate regressor') flags.DEFINE_string('model_type', 'resnet_fc3_dropout', - 'What kind of networks to use') + 'Specifies which network to use') flags.DEFINE_string( 'joint_type', 'cocoplus', 'cocoplus (19 keypoints) or lsp 14 keypoints, returned by SMPL') + +# Training settings: +# TODO! If you want to train, change this to your 'tf_datasets' or specify it with the flag. 
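+# For example: DATA_DIR = '~/hmr/tf_datasets/' if you followed doc/train.md.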
+DATA_DIR = '/scratch1/projects/tf_datasets/' + +flags.DEFINE_string('data_dir', DATA_DIR, 'Where to save training models') +flags.DEFINE_string('log_dir', 'logs', 'Where to save training models') +flags.DEFINE_string('model_dir', None, 'Where model will be saved -- filled automatically') +flags.DEFINE_integer('log_img_step', 100, 'How often to visualize img during training') +flags.DEFINE_integer('epoch', 100, '# of epochs to train') + +flags.DEFINE_list('datasets', ['lsp', 'lsp_ext', 'mpii', 'coco'], + 'datasets to use for training') +flags.DEFINE_list('mocap_datasets', ['CMU', 'H3.6', 'jointLim'], + 'datasets to use for adversarial prior training') + +# Model config +flags.DEFINE_boolean( + 'encoder_only', False, + 'if set, no adversarial prior is trained = monsters') + +flags.DEFINE_boolean( + 'use_3d_label', True, + 'Uses 3D labels if on.') + +# Hyper parameters: +flags.DEFINE_float('e_lr', 0.001, 'Encoder learning rate') +flags.DEFINE_float('d_lr', 0.001, 'Adversarial prior learning rate') +flags.DEFINE_float('e_wd', 0.0001, 'Encoder weight decay') +flags.DEFINE_float('d_wd', 0.0001, 'Adversarial prior weight decay') + +flags.DEFINE_float('e_loss_weight', 60, 'weight on E_kp losses') +flags.DEFINE_float('d_loss_weight', 1, 'weight on discriminator') + + +flags.DEFINE_float('e_3d_weight', 1, 'weight on E_3d') + +# Data augmentation +flags.DEFINE_integer('trans_max', 20, 'Value to jitter translation') +flags.DEFINE_float('scale_max', 1.23, 'Max value of scale jitter') +flags.DEFINE_float('scale_min', 0.8, 'Min value of scale jitter') + + +def get_config(): + config = flags.FLAGS + config(sys.argv) + + if 'resnet' in config.model_type: + setattr(config, 'img_size', 224) + # Slim resnet wants NHWC.. + setattr(config, 'data_format', 'NHWC') + + return config + + +# ----- For training ----- # + + +def prepare_dirs(config, prefix=['HMR']): + # Continue training from a load_path + if config.load_path: + if not osp.exists(config.load_path): + print("load_path: %s doesnt exist..!!!" % config.load_path) + import ipdb + ipdb.set_trace() + print('continuing from %s!' % config.load_path) + + # Check for changed training parameter: + # Load prev config param path + param_path = glob(osp.join(config.load_path, '*.json'))[0] + + with open(param_path, 'r') as fp: + prev_config = json.load(fp) + dict_here = config.__dict__ + ignore_keys = ['load_path', 'log_img_step', 'pretrained_model_path'] + diff_keys = [ + k for k in dict_here + if k not in ignore_keys and k in prev_config.keys() + and prev_config[k] != dict_here[k] + ] + + for k in diff_keys: + if k == 'load_path' or k == 'log_img_step': + continue + if prev_config[k] is None and dict_here[k] is not None: + print("%s is different!! before: None after: %g" % + (k, dict_here[k])) + elif prev_config[k] is not None and dict_here[k] is None: + print("%s is different!! before: %g after: None" % + (k, prev_config[k])) + else: + print("%s is different!! before: " % k) + print(prev_config[k]) + print("now:") + print(dict_here[k]) + + if len(diff_keys) > 0: + print("really continue??") + import ipdb + ipdb.set_trace() + + config.model_dir = config.load_path + + else: + postfix = [] + + # If config.dataset is not the same as default, add that to name. 
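+        # e.g. passing --datasets lsp,coco (not the default set) yields the postfix 'coco-lsp' below.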
+ default_dataset = [ + 'lsp', 'lsp_ext', 'mpii', 'h36m', 'coco', 'mpi_inf_3dhp' + ] + default_mocap = ['CMU', 'H3.6', 'jointLim'] + + if sorted(config.datasets) != sorted(default_dataset): + has_all_default = np.all( + [name in config.datasets for name in default_dataset]) + if has_all_default: + new_names = [ + name for name in sorted(config.datasets) + if name not in default_dataset + ] + postfix.append('default+' + '-'.join(sorted(new_names))) + else: + postfix.append('-'.join(sorted(config.datasets))) + if sorted(config.mocap_datasets) != sorted(default_mocap): + postfix.append('-'.join(config.mocap_datasets)) + + postfix.append(config.model_type) + + if config.num_stage != 3: + prefix += ["T%d" % config.num_stage] + + postfix.append("Elr%1.e" % config.e_lr) + + if config.e_loss_weight != 1: + postfix.append("kp-weight%g" % config.e_loss_weight) + + if not config.encoder_only: + postfix.append("Dlr%1.e" % config.d_lr) + if config.d_loss_weight != 1: + postfix.append("d-weight%g" % config.d_loss_weight) + + if config.use_3d_label: + print('Using 3D labels!!') + prefix.append("3DSUP") + if config.e_3d_weight != 1: + postfix.append("3dsup-weight%g" % config.e_3d_weight) + + # Data: + # Jitter amount: + if config.trans_max != 20: + postfix.append("transmax-%d" % config.trans_max) + if config.scale_max != 1.23: + postfix.append("scmax_%.3g" % config.scale_max) + if config.scale_min != 0.8: + postfix.append("scmin-%.3g" % config.scale_min) + + prefix = '_'.join(prefix) + postfix = '_'.join(postfix) + + time_str = datetime.now().strftime("%b%d_%H%M") + + save_name = "%s_%s_%s" % (prefix, postfix, time_str) + config.model_dir = osp.join(config.log_dir, save_name) + + for path in [config.log_dir, config.model_dir]: + if not osp.exists(path): + print('making %s' % path) + makedirs(path) + + +def save_config(config): + param_path = osp.join(config.model_dir, "params.json") + + print("[*] MODEL dir: %s" % config.model_dir) + print("[*] PARAM path: %s" % param_path) + + config_dict = {} + for k in dir(config): + config_dict[k] = config.__getattr__(k) + + with open(param_path, 'w') as fp: + json.dump(config_dict, fp, indent=4, sort_keys=True) diff --git a/src/data_loader.py b/src/data_loader.py new file mode 100644 index 000000000..643580258 --- /dev/null +++ b/src/data_loader.py @@ -0,0 +1,331 @@ +""" +Data loader with data augmentation. +Only used for training. 
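+Builds TF input queues that yield batches of cropped images, 2D keypoints,
+and (for 3D datasets) 3D joint and SMPL labels.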
+""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from os.path import join +from glob import glob + +import tensorflow as tf + +from .tf_smpl.batch_lbs import batch_rodrigues +from .util import data_utils + +_3D_DATASETS = ['h36m', 'up', 'mpi_inf_3dhp'] + + +def num_examples(datasets): + _NUM_TRAIN = { + 'lsp': 1000, + 'lsp_ext': 10000, + 'mpii': 20000, + 'h36m': 312188, + 'coco': 79344, + 'mpi_inf_3dhp': 147221, # without S8 + # Below is number for MOSH/mocap: + 'H3.6': 1559985, # without S9 and S11, + 'CMU': 3934267, + 'jointLim': 181968, + } + + if not isinstance(datasets, list): + datasets = [datasets] + total = 0 + + use_dict = _NUM_TRAIN + + for d in datasets: + total += use_dict[d] + return total + + +class DataLoader(object): + def __init__(self, config): + self.config = config + + self.use_3d_label = config.use_3d_label + + self.dataset_dir = config.data_dir + self.datasets = config.datasets + self.mocap_datasets = config.mocap_datasets + self.batch_size = config.batch_size + self.data_format = config.data_format + self.output_size = config.img_size + # Jitter params: + self.trans_max = config.trans_max + self.scale_range = [config.scale_min, config.scale_max] + + self.image_normalizing_fn = data_utils.rescale_image + + def load(self): + if self.use_3d_label: + image_loader = self.get_loader_w3d() + else: + image_loader = self.get_loader() + + return image_loader + + def get_loader(self): + """ + Outputs: + image_batch: batched images as per data_format + label_batch: batched keypoint labels N x K x 3 + """ + files = data_utils.get_all_files(self.dataset_dir, self.datasets) + + do_shuffle = self.split is 'train' + fqueue = tf.train.string_input_producer( + files, shuffle=do_shuffle, name="input") + image, label = self.read_data(fqueue, has_3d=False) + min_after_dequeue = 5000 + num_threads = 8 + capacity = min_after_dequeue + 3 * self.batch_size + + pack_these = [image, label] + pack_name = ['image', 'label'] + + all_batched = tf.train.shuffle_batch( + pack_these, + batch_size=self.batch_size, + num_threads=num_threads, + capacity=capacity, + min_after_dequeue=min_after_dequeue, + enqueue_many=False, + name='input_batch_train') + batch_dict = {} + for name, batch in zip(pack_name, all_batched): + batch_dict[name] = batch + + return batch_dict + + def get_loader_w3d(self): + """ + Similar to get_loader, but outputs are: + image_batch: batched images as per data_format + label_batch: batched keypoint labels N x K x 3 + label3d_batch: batched keypoint labels N x (216 + 10 + 42) + 216=24*3*3 pose, 10 shape, 42=14*3 3D joints + (3D datasets only have 14 joints annotated) + has_gt3d_batch: batched indicator for + existence of [3D joints, 3D SMPL] labels N x 2 - bool + Note 3D SMPL is only available for H3.6M. + + + Problem is that those datasets without pose/shape do not have them + in the tfrecords. There's no way to check for this in TF, + so, instead make 2 string_input_producers, one for data without 3d + and other for data with 3d. + And send [2 x *] to train.*batch + """ + datasets_no3d = [d for d in self.datasets if d not in _3D_DATASETS] + datasets_yes3d = [d for d in self.datasets if d in _3D_DATASETS] + + files_no3d = data_utils.get_all_files(self.dataset_dir, datasets_no3d) + files_yes3d = data_utils.get_all_files(self.dataset_dir, + datasets_yes3d) + + # Make sure we have dataset with 3D. 
+ if len(files_yes3d) == 0: + print("Dont run this without any datasets with gt 3d") + import ipdb; ipdb.set_trace() + exit(1) + + do_shuffle = True + + fqueue_yes3d = tf.train.string_input_producer( + files_yes3d, shuffle=do_shuffle, name="input_w3d") + image, label, label3d, has_smpl3d = self.read_data( + fqueue_yes3d, has_3d=True) + + if len(files_no3d) != 0: + fqueue_no3d = tf.train.string_input_producer( + files_no3d, shuffle=do_shuffle, name="input_wout3d") + image_no3d, label_no3d = self.read_data(fqueue_no3d, has_3d=False) + label3d_no3d = tf.zeros_like(label3d) + image = tf.parallel_stack([image, image_no3d]) + label = tf.parallel_stack([label, label_no3d]) + label3d = tf.parallel_stack([label3d, label3d_no3d]) + # 3D joint is always available for data with 3d. + has_3d_joints = tf.constant([True, False], dtype=tf.bool) + has_3d_smpl = tf.concat([has_smpl3d, [False]], axis=0) + else: + # If no "no3d" images, need to make them 1 x * + image = tf.expand_dims(image, 0) + label = tf.expand_dims(label, 0) + label3d = tf.expand_dims(label3d, 0) + has_3d_joints = tf.constant([True], dtype=tf.bool) + has_3d_smpl = has_smpl3d + + # Combine 3D bools. + # each is 2 x 1, column is [3d_joints, 3d_smpl] + has_3dgt = tf.stack([has_3d_joints, has_3d_smpl], axis=1) + + min_after_dequeue = 2000 + capacity = min_after_dequeue + 3 * self.batch_size + + image_batch, label_batch, label3d_batch, bool_batch = tf.train.shuffle_batch( + [image, label, label3d, has_3dgt], + batch_size=self.batch_size, + num_threads=8, + capacity=capacity, + min_after_dequeue=min_after_dequeue, + enqueue_many=True, + name='input_batch_train_3d') + + if self.data_format == 'NCHW': + image_batch = tf.transpose(image_batch, [0, 3, 1, 2]) + elif self.data_format == 'NHWC': + pass + else: + raise Exception("[!] Unkown data_format: {}".format( + self.data_format)) + + batch_dict = { + 'image': image_batch, + 'label': label_batch, + 'label3d': label3d_batch, + 'has3d': bool_batch, + } + + return batch_dict + + def get_smpl_loader(self): + """ + Loads dataset in form of queue, loads shape/pose of smpl. + returns a batch of pose & shape + """ + + data_dirs = [ + join(self.dataset_dir, 'mocap_neutrMosh', + 'neutrSMPL_%s_*.tfrecord' % dataset) + for dataset in self.mocap_datasets + ] + files = [] + for data_dir in data_dirs: + files += glob(data_dir) + + if len(files) == 0: + print('Couldnt find any files!!') + import ipdb + ipdb.set_trace() + + return self.get_smpl_loader_from_files(files) + + def get_smpl_loader_from_files(self, files): + """ + files = list of tf records. + """ + with tf.name_scope('input_smpl_loader'): + filename_queue = tf.train.string_input_producer( + files, shuffle=True) + + mosh_batch_size = self.batch_size * self.config.num_stage + + min_after_dequeue = 1000 + capacity = min_after_dequeue + 3 * mosh_batch_size + + pose, shape = data_utils.read_smpl_data(filename_queue) + pose_batch, shape_batch = tf.train.batch( + [pose, shape], + batch_size=mosh_batch_size, + num_threads=4, + capacity=capacity, + name='input_smpl_batch') + + return pose_batch, shape_batch + + def read_data(self, filename_queue, has_3d=False): + with tf.name_scope(None, 'read_data', [filename_queue]): + reader = tf.TFRecordReader() + _, example_serialized = reader.read(filename_queue) + if has_3d: + image, image_size, label, center, fname, pose, shape, gt3d, has_smpl3d = data_utils.parse_example_proto( + example_serialized, has_3d=has_3d) + # Need to send pose bc image can get flipped. 
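+                # (image_preprocessing may randomly mirror the image; the returned pose and gt3d are flipped to match)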
+ image, label, pose, gt3d = self.image_preprocessing( + image, image_size, label, center, pose=pose, gt3d=gt3d) + + # Convert pose to rotation. + # Do not ignore the global!! + rotations = batch_rodrigues(tf.reshape(pose, [-1, 3])) + gt3d_flat = tf.reshape(gt3d, [-1]) + # Label 3d is: + # [rotations, shape-beta, 3Djoints] + # [216=24*3*3, 10, 42=14*3] + label3d = tf.concat( + [tf.reshape(rotations, [-1]), shape, gt3d_flat], 0) + else: + image, image_size, label, center, fname = data_utils.parse_example_proto( + example_serialized) + image, label = self.image_preprocessing( + image, image_size, label, center) + + # label should be K x 3 + label = tf.transpose(label) + + if has_3d: + return image, label, label3d, has_smpl3d + else: + return image, label + + def image_preprocessing(self, + image, + image_size, + label, + center, + pose=None, + gt3d=None): + margin = tf.to_int32(self.output_size / 2) + with tf.name_scope(None, 'image_preprocessing', + [image, image_size, label, center]): + visibility = label[2, :] + keypoints = label[:2, :] + + # Randomly shift center. + print('Using translation jitter: %d' % self.trans_max) + center = data_utils.jitter_center(center, self.trans_max) + + # Pad image with safe margin. + # Extra 50 for safety. + margin_safe = margin + self.trans_max + 50 + image_pad = data_utils.pad_image_edge(image, margin_safe) + center_pad = center + margin_safe + keypoints_pad = keypoints + tf.to_float(margin_safe) + + start_pt = center_pad - margin + + # Crop image pad. + start_pt = tf.squeeze(start_pt) + bbox_begin = tf.stack([start_pt[1], start_pt[0], 0]) + bbox_size = tf.stack([self.output_size, self.output_size, 3]) + + crop = tf.slice(image_pad, bbox_begin, bbox_size) + x_crop = keypoints_pad[0, :] - tf.to_float(start_pt[0]) + y_crop = keypoints_pad[1, :] - tf.to_float(start_pt[1]) + + crop_kp = tf.stack([x_crop, y_crop, visibility]) + + if pose is not None: + crop, crop_kp, new_pose, new_gt3d = data_utils.random_flip( + crop, crop_kp, pose, gt3d) + else: + crop, crop_kp = data_utils.random_flip(crop, crop_kp) + + # Normalize kp output to [-1, 1] + final_vis = tf.cast(crop_kp[2, :] > 0, tf.float32) + final_label = tf.stack([ + 2.0 * (crop_kp[0, :] / self.output_size) - 1.0, + 2.0 * (crop_kp[1, :] / self.output_size) - 1.0, final_vis + ]) + # Preserving non_vis to be 0. 
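+            # (final_vis is 0/1 per keypoint, so this zeroes the x, y of non-visible joints)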
+ final_label = final_vis * final_label + + # rescale image from [0, 1] to [-1, 1] + crop = self.image_normalizing_fn(crop) + if pose is not None: + return crop, final_label, new_pose, new_gt3d + else: + return crop, final_label diff --git a/src/datasets/__init__.py b/src/datasets/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/datasets/coco_to_tfrecords.py b/src/datasets/coco_to_tfrecords.py new file mode 100644 index 000000000..cfba81005 --- /dev/null +++ b/src/datasets/coco_to_tfrecords.py @@ -0,0 +1,352 @@ +""" Convert Coco to TFRecords """ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from os.path import join, exists +from os import makedirs + +import numpy as np +import tensorflow as tf + +from pycocotools.coco import COCO + +from .common import convert_to_example, ImageCoder, resize_img + +tf.app.flags.DEFINE_string('data_directory', '/scratch1/storage/coco/', + 'data directory: top of coco') +tf.app.flags.DEFINE_string('output_directory', + '/scratch1/projects/tf_datasets/coco_wmask/', + 'Output data directory') + +tf.app.flags.DEFINE_integer('train_shards', 500, + 'Number of shards in training TFRecord files.') +tf.app.flags.DEFINE_integer('validation_shards', 500, + 'Number of shards in validation TFRecord files.') +FLAGS = tf.app.flags.FLAGS + +joint_names = [ + 'R Ankle', 'R Knee', 'R Hip', 'L Hip', 'L Knee', 'L Ankle', 'R Wrist', + 'R Elbow', 'R Shoulder', 'L Shoulder', 'L Elbow', 'L Wrist', 'Neck', + 'Head', 'Nose', 'L Eye', 'R Eye', 'L Ear', 'R Ear' +] + + +def convert_coco2universal(kp): + """ + Mapping from COCO joints (kp: 17 x 3) to + Universal 19 joints (14 lsp)+ (5 coco faces). + + Permutes and adds extra 0 two rows for missing head and neck + returns: 19 x 3 + """ + + UNIVERSAL_BODIES = [ + 16, # R ankle + 14, # R knee + 12, # R hip + 11, # L hip + 13, # L knee + 15, # L ankle + 10, # R Wrist + 8, # R Elbow + 6, # R shoulder + 5, # L shoulder + 7, # L Elbow + 9, # L Wrist + ] + UNIVERSAL_HEADS = range(5) + new_kp = np.vstack((kp[UNIVERSAL_BODIES, :], np.zeros((2, 3)), + kp[UNIVERSAL_HEADS, :])) + return new_kp + + +def get_anns_details(anns, coco, min_vis=5, min_max_height=60): + """ + anns is the list of annotations + coco is the cocoAPI + + extracts the boundingbox (using the mask) + and the keypoints for each person. + + Ignores the person if there is no or < min_vis keypoints + Ignores the person if max bbox length is <= min_max_height + """ + points_other_than_faceshoulder = [ + 16, # R ankle + 14, # R knee + 12, # R hip + 11, # L hip + 13, # L knee + 15, # L ankle + 10, # R Wrist + 8, # R Elbow + 7, # L Elbow + 9, # L Wrist + ] + filtered_anns = [] + kps = [] + centers, bboxes = [], [] + masks = [] + for ann in anns: + if 'keypoints' not in ann or type(ann['keypoints']) != list: + # Ignore those without keypoints + continue + if ann['num_keypoints'] == 0: + continue + + if 'segmentation' in ann: + # Use the mask to compute center + mask = coco.annToMask(ann) + # import ipdb; ipdb.set_trace() + # import matplotlib.pyplot as plt + # plt.ion() + # plt.figure(1) + # plt.imshow(mask) + # plt.pause(1e-3) + # this is N x 2 (in [x, y]) of fgpts + fg_pts = np.transpose(np.nonzero(mask))[:, ::-1] + min_pt = np.min(fg_pts, axis=0) + max_pt = np.max(fg_pts, axis=0) + bbox = [min_pt, max_pt[0] - min_pt[0], max_pt[1] - min_pt[1]] + center = (min_pt + max_pt) / 2. 
+ else: + print('No segmentation!') + import ipdb + ipdb.set_trace() + + kp_raw = np.array(ann['keypoints']) + x = kp_raw[0::3] + y = kp_raw[1::3] + v = kp_raw[2::3] + # At least min_vis many visible (not occluded) kps. + if sum(v == 2) >= min_vis and max(bbox[2:]) > min_max_height: + # If only face & shoulder visible, skip. + if np.all(v[points_other_than_faceshoulder] == 0): + continue + kp = np.vstack([x, y, v]).T + kps.append(kp) + filtered_anns.append(ann) + centers.append(center) + bboxes.append(bbox) + masks.append(mask) + + return filtered_anns, kps, bboxes, centers, masks + + +def parse_people(kps, centers, masks): + ''' + Parses people i.e. figures out scale from annotation. + Input: + + Returns: + people - list of tuple (kp, img_scale, obj_pos) in this image. + ''' + # No single persons in this image. + if len(kps) == 0: + return [] + + # Read each human: + people = [] + + for kp, center, mask in zip(kps, centers, masks): + # Universal joints! + joints = convert_coco2universal(kp).T + # Scale person to be roughly 150x height + visible = joints[2, :].astype(bool) + min_pt = np.min(joints[:2, visible], axis=1) + max_pt = np.max(joints[:2, visible], axis=1) + person_height = np.linalg.norm(max_pt - min_pt) + + R_ank = joint_names.index('R Ankle') + L_ank = joint_names.index('L Ankle') + + # If ankles are visible + if visible[R_ank] or visible[L_ank]: + my_scale = 150. / person_height + else: + L_should = joint_names.index('L Shoulder') + L_hip = joint_names.index('L Hip') + R_should = joint_names.index('R Shoulder') + R_hip = joint_names.index('R Hip') + # Torso points left should, right shold, right hip, left hip + # torso_points = joints[:, [9, 8, 2, 3]] + torso_heights = [] + if visible[L_should] and visible[L_hip]: + torso_heights.append( + np.linalg.norm(joints[:2, L_should] - joints[:2, L_hip])) + if visible[R_should] and visible[R_hip]: + torso_heights.append( + np.linalg.norm(joints[:2, R_should] - joints[:2, R_hip])) + # Make torso 75px + if len(torso_heights) > 0: + my_scale = 75. / np.mean(torso_heights) + else: # No torso! + body_inds = np.array([0, 1, 2, 3, 4, 5, 6, 7, 10, 11]) + if np.all(visible[body_inds] == 0): + print('Face only! skip..') + continue + else: + my_scale = 50. / person_height + + people.append((joints, my_scale, center, mask)) + + return people + + +def add_to_tfrecord(coco, img_id, img_dir, coder, writer, is_train): + """ + Add each "single person" in this image. + coco - coco API + + Returns: + The number of people added. + """ + # Get annotation id for this guy + # Cat ids is [1] for human.. 
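+    # (in the COCO API, category id 1 is 'person'; iscrowd=False skips crowd annotations)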
+ ann_id = coco.getAnnIds(imgIds=img_id, catIds=[1], iscrowd=False) + anns = coco.loadAnns(ann_id) + # coco.showAnns(anns) + filtered_anns, kps, bboxes, centers, masks = get_anns_details( + anns, coco, min_vis=6, min_max_height=60) + + # Figure out the scale and pack each one in a tuple + people = parse_people(kps, centers, masks) + + if len(people) == 0: + # print('No single persons in img %d' % img_id) + return 0 + + # Add each people to tf record + img_data = coco.loadImgs(img_id)[0] + image_path = join(img_dir, img_data['file_name']) + with tf.gfile.FastGFile(image_path, 'rb') as f: + image_data = f.read() + + image = coder.decode_jpeg(image_data) + + for joints, scale, pos, mask in people: + # Scale image: + image_scaled, scale_factors = resize_img(image, scale) + height, width = image_scaled.shape[:2] + joints_scaled = np.copy(joints) + joints_scaled[0, :] *= scale_factors[0] + joints_scaled[1, :] *= scale_factors[1] + # center = pos * scale_factors + + visible = joints[2, :].astype(bool) + min_pt = np.min(joints_scaled[:2, visible], axis=1) + max_pt = np.max(joints_scaled[:2, visible], axis=1) + center = (min_pt + max_pt) / 2. + + ## Crop 400x400 around this image.. + margin = 200 + start_pt = np.maximum(center - margin, 0).astype(int) + end_pt = (center + margin).astype(int) + end_pt[0] = min(end_pt[0], width) + end_pt[1] = min(end_pt[1], height) + image_scaled = image_scaled[start_pt[1]:end_pt[1], start_pt[0]:end_pt[ + 0], :] + # Update others oo. + joints_scaled[0, :] -= start_pt[0] + joints_scaled[1, :] -= start_pt[1] + center -= start_pt + height, width = image_scaled.shape[:2] + + # Vis: + """ + import matplotlib.pyplot as plt + plt.ion() + plt.clf() + fig = plt.figure(1) + ax = fig.add_subplot(121) + image_with_skel = draw_skeleton(image, joints[:2, :], vis=visible, radius=(np.mean(image.shape[:2]) * 0.01).astype(int)) + ax.imshow(image_with_skel) + ax.axis('off') + # ax.imshow(image) + # ax.scatter(joints[0, visible], joints[1, visible]) + # ax.scatter(joints[0, ~visible], joints[1, ~visible], color='green') + ax.scatter(pos[0], pos[1], color='red') + ax = fig.add_subplot(122) + image_with_skel_scaled = draw_skeleton(image_scaled, joints_scaled[:2, :], vis=visible, radius=max(4, (np.mean(image_scaled.shape[:2]) * 0.01).astype(int))) + ax.imshow(image_with_skel_scaled) + ax.scatter(center[0], center[1], color='red') + # ax.imshow(image_scaled) + # ax.scatter(joints_scaled[0, visible], joints_scaled[1, visible]) + # ax.scatter(pos_scaled[0], pos_scaled[1], color='red') + ax.axis('on') + plt.draw() + plt.pause(0.01) + """ + + # Encode image: + image_data_scaled = coder.encode_jpeg(image_scaled) + example = convert_to_example(image_data_scaled, image_path, height, + width, joints_scaled, center) + writer.write(example.SerializeToString()) + + # Finally return how many were written. 
+ return len(people) + + +def process_coco(data_dir, out_dir, num_shards, is_train=True): + + if is_train: + data_type = 'train2014' + out_path = join(out_dir, 'train_%04d_wmeta.tfrecord') + else: + data_type = 'val2014' + out_path = join(out_dir, 'val_%04d_wmeta.tfrecord') + + anno_file = join(data_dir, + 'annotations/person_keypoints_%s.json' % data_type) + img_dir = join(data_dir, 'images', data_type) + # initialize COCO api for person keypoints annotations + coco = COCO(anno_file) + catIds = coco.getCatIds(catNms=['person']) + img_inds = coco.getImgIds(catIds=catIds) + # Only run on 'single person's + coder = ImageCoder() + + i = 0 + # Count on shards + fidx = 0 + num_ppl = 0 + total_num_ppl = 0 + while i < len(img_inds): + tf_filename = out_path % fidx + print('Starting tfrecord file %s' % tf_filename) + with tf.python_io.TFRecordWriter(tf_filename) as writer: + # Count on total ppl in each shard + num_ppl = 0 + while i < len(img_inds) and num_ppl < num_shards: + if i % 100 == 0: + print('Reading img %d/%d' % (i, len(img_inds))) + num_ppl += add_to_tfrecord(coco, img_inds[i], img_dir, coder, + writer, is_train) + i += 1 + total_num_ppl += num_ppl + + fidx += 1 + + print('Made %d shards, with total # of people: %d' % + (fidx - 1, total_num_ppl)) + + +def main(unused_argv): + print('Saving results to %s' % FLAGS.output_directory) + + if not exists(FLAGS.output_directory): + makedirs(FLAGS.output_directory) + process_coco( + FLAGS.data_directory, + FLAGS.output_directory, + FLAGS.train_shards, + is_train=True) + # do_valid + # _process_coco(FLAGS.data_directory, FLAGS.output_directory, FLAGS.validation_shards, is_train=False) + + +if __name__ == '__main__': + tf.app.run() diff --git a/src/datasets/common.py b/src/datasets/common.py new file mode 100644 index 000000000..46ad985ee --- /dev/null +++ b/src/datasets/common.py @@ -0,0 +1,229 @@ +""" +Helpers for tfrecord conversion. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import numpy as np + + +class ImageCoder(object): + """Helper class that provides TensorFlow image coding utilities. + Taken from + https://github.com/tensorflow/models/blob/master/inception/inception/data/build_image_data.py + """ + + def __init__(self): + # Create a single Session to run all image coding calls. + self._sess = tf.Session() + + # Initializes function that converts PNG to JPEG data. + self._png_data = tf.placeholder(dtype=tf.string) + image = tf.image.decode_png(self._png_data, channels=3) + self._png_to_jpeg = tf.image.encode_jpeg( + image, format='rgb', quality=100) + + # Initializes function that decodes RGB JPEG data. 
+ self._decode_jpeg_data = tf.placeholder(dtype=tf.string) + self._decode_jpeg = tf.image.decode_jpeg( + self._decode_jpeg_data, channels=3) + + self._encode_jpeg_data = tf.placeholder(dtype=tf.uint8) + self._encode_jpeg = tf.image.encode_jpeg( + self._encode_jpeg_data, format='rgb') + + self._decode_png_data = tf.placeholder(dtype=tf.string) + self._decode_png = tf.image.decode_png( + self._decode_png_data, channels=3) + + self._encode_png_data = tf.placeholder(dtype=tf.uint8) + self._encode_png = tf.image.encode_png(self._encode_png_data) + + def png_to_jpeg(self, image_data): + return self._sess.run( + self._png_to_jpeg, feed_dict={ + self._png_data: image_data + }) + + def decode_jpeg(self, image_data): + image = self._sess.run( + self._decode_jpeg, feed_dict={ + self._decode_jpeg_data: image_data + }) + assert len(image.shape) == 3 + assert image.shape[2] == 3 + return image + + def encode_jpeg(self, image): + image_data = self._sess.run( + self._encode_jpeg, feed_dict={ + self._encode_jpeg_data: image + }) + return image_data + + def encode_png(self, image): + image_data = self._sess.run( + self._encode_png, feed_dict={ + self._encode_png_data: image + }) + return image_data + + def decode_png(self, image_data): + image = self._sess.run( + self._decode_png, feed_dict={ + self._decode_png_data: image_data + }) + assert len(image.shape) == 3 + assert image.shape[2] == 3 + return image + + +def int64_feature(value): + """Wrapper for inserting int64 features into Example proto.""" + if not isinstance(value, list) and not isinstance(value, np.ndarray): + value = [value] + return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) + + +def float_feature(value): + """Wrapper for inserting float features into Example proto.""" + if not isinstance(value, list) and not isinstance(value, np.ndarray): + value = [value] + return tf.train.Feature(float_list=tf.train.FloatList(value=value)) + + +def bytes_feature(value): + """Wrapper for inserting bytes features into Example proto.""" + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def convert_to_example(image_data, image_path, height, width, label, center): + """Build an Example proto for an image example. + Args: + image_data: string, JPEG encoding of RGB image; + image_path: string, path to this image file + labels: 3 x 14 joint location + visibility --> This could be 3 x 19 + height, width: integers, image shapes in pixels. + center: 2 x 1 center of the tight bbox + Returns: + Example proto + """ + from os.path import basename + + image_format = 'JPEG' + add_face = False + if label.shape[1] == 19: + add_face = True + # Split and save facepts on it's own. 
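+        # (the 19-keypoint 'cocoplus' layout = 14 LSP joints followed by 5 face points: nose, eyes, ears)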
+ face_pts = label[:, 14:] + label = label[:, :14] + + feat_dict = { + 'image/height': int64_feature(height), + 'image/width': int64_feature(width), + 'image/center': int64_feature(center.astype(np.int)), + 'image/x': float_feature(label[0, :].astype(np.float)), + 'image/y': float_feature(label[1, :].astype(np.float)), + 'image/visibility': int64_feature(label[2, :].astype(np.int)), + 'image/format': bytes_feature(tf.compat.as_bytes(image_format)), + 'image/filename': bytes_feature( + tf.compat.as_bytes(basename(image_path))), + 'image/encoded': bytes_feature(tf.compat.as_bytes(image_data)), + } + if add_face: + # 3 x 5 + feat_dict.update({ + 'image/face_pts': + float_feature(face_pts.ravel().astype(np.float)) + }) + + example = tf.train.Example(features=tf.train.Features(feature=feat_dict)) + + return example + + +def convert_to_example_wmosh(image_data, image_path, height, width, label, + center, gt3d, pose, shape, scale_factors, + start_pt, cam): + """Build an Example proto for an image example. + Args: + image_data: string, JPEG encoding of RGB image; + image_path: string, path to this image file + labels: 3 x 14 joint location + visibility + height, width: integers, image shapes in pixels. + center: 2 x 1 center of the tight bbox + gt3d: 14x3 3D joint locations + scale_factors: 2 x 1, scale factor used to scale image. + start_pt: the left corner used to crop the _scaled_ image to 300x300 + cam: (3,), [f, px, py] intrinsic camera parameters. + Returns: + Example proto + """ + from os.path import basename + image_format = 'JPEG' + if label.shape[0] != 3: + label = label.T + if label.shape[1] > 14: + print('This shouldnt be happening') + import ipdb + ipdb.set_trace() + if pose is None: + has_3d = 0 + # Use -1 to save. + pose = -np.ones(72) + shape = -np.ones(10) + else: + has_3d = 1 + + example = tf.train.Example( + features=tf.train.Features(feature={ + 'image/height': + int64_feature(height), + 'image/width': + int64_feature(width), + 'image/center': + int64_feature(center.astype(np.int)), + 'image/x': + float_feature(label[0, :].astype(np.float)), + 'image/y': + float_feature(label[1, :].astype(np.float)), + 'image/visibility': + int64_feature(label[2, :].astype(np.int)), + 'image/format': + bytes_feature(tf.compat.as_bytes(image_format)), + 'image/filename': + bytes_feature(tf.compat.as_bytes(basename(image_path))), + 'image/encoded': + bytes_feature(tf.compat.as_bytes(image_data)), + 'mosh/pose': + float_feature(pose.astype(np.float)), + 'mosh/shape': + float_feature(shape.astype(np.float)), + 'mosh/gt3d': + float_feature(gt3d.ravel().astype(np.float)), + 'meta/scale_factors': + float_feature(np.array(scale_factors).astype(np.float)), + 'meta/crop_pt': + int64_feature(start_pt.astype(np.int)), + 'meta/has_3d': + int64_feature(has_3d), + 'image/cam': + float_feature(cam.astype(np.float)), + })) + + return example + + +def resize_img(img, scale_factor): + import cv2 + import numpy as np + new_size = (np.floor(np.array(img.shape[0:2]) * scale_factor)).astype(int) + new_img = cv2.resize(img, (new_size[1], new_size[0])) + # This is scale factor of [height, width] i.e. 
[y, x] + actual_factor = [ + new_size[0] / float(img.shape[0]), new_size[1] / float(img.shape[1]) + ] + return new_img, actual_factor diff --git a/src/datasets/convert_datasets.sh b/src/datasets/convert_datasets.sh new file mode 100644 index 000000000..da0058a62 --- /dev/null +++ b/src/datasets/convert_datasets.sh @@ -0,0 +1,15 @@ +# Change to your +OUT_DIR='/Users/kanazawa/projects/tf_datasets/' + +# Change to where each dataset directory is: +LSP_DIR='/scratch1/storage/human_datasets/lsp_dataset/' +LSP_EXT_DIR='/scratch1/storage/human_datasets/lsp_extended/' +MPII_DIR='/scratch1/storage/human_datasets/mpii/' + +# LSP: +python lsp_to_tfrecords.py --img_directory $LSP_DIR --output_directory $OUT_DIR/lsp +# LSP-extended: +python lsp_to_tfrecords.py --img_directory $LSP_EXT_DIR --output_directory $OUT_DIR/lsp_ext + +# MPII: +python mpii_to_tfrecords.py --img_directory $MPII_DIR --output_directory $OUT_DIR/mpii diff --git a/src/datasets/lsp_to_tfrecords.py b/src/datasets/lsp_to_tfrecords.py new file mode 100644 index 000000000..af2109275 --- /dev/null +++ b/src/datasets/lsp_to_tfrecords.py @@ -0,0 +1,152 @@ +""" +Convert LSP/LSP extended to TFRecords. +In LSP, the first 1000 is training and the last 1000 is test/validation. +All of LSP extended is training. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from os import makedirs +from os.path import join, exists +from glob import glob + +import numpy as np + +import tensorflow as tf + +from .common import convert_to_example, ImageCoder + +tf.app.flags.DEFINE_string('img_directory', + '/scratch1/storage/human_datasets/lsp_dataset', + 'image data directory') +tf.app.flags.DEFINE_string( + 'output_directory', '/Users/kanazawa/projects/datasets/tf_datasets/lsp/', + 'Output data directory') + +tf.app.flags.DEFINE_integer('train_shards', 500, + 'Number of shards in training TFRecord files.') +tf.app.flags.DEFINE_integer('validation_shards', 500, + 'Number of shards in validation TFRecord files.') + +FLAGS = tf.app.flags.FLAGS + + +def _add_to_tfrecord(image_path, label, coder, writer, is_lsp_ext=False): + with tf.gfile.FastGFile(image_path, 'rb') as f: + image_data = f.read() + + image = coder.decode_jpeg(image_data) + height, width = image.shape[:2] + assert image.shape[2] == 3 + + # LSP 3-D dim, 0 means visible 1 means invisible. + # But in LSP-ext, 0 means invis, 1 means visible + # Negate this + if is_lsp_ext: + visible = label[2, :].astype(bool) + else: + visible = np.logical_not(label[2, :]) + label[2, :] = visible.astype(label.dtype) + min_pt = np.min(label[:2, visible], axis=1) + max_pt = np.max(label[:2, visible], axis=1) + center = (min_pt + max_pt) / 2. + """ + import matplotlib.pyplot as plt + plt.ion() + plt.clf() + fig = plt.figure(1) + ax = fig.add_subplot(111) + plt.imshow(image) + plt.scatter(label[0, visible], label[1, visible]) + plt.scatter(center[0], center[1]) + # bwidth, bheight = max_pt - min_pt + 1 + # rect = plt.Rectangle(min_pt, bwidth, bheight, fc='None', ec='green') + # ax.add_patch(rect) + import ipdb; ipdb.set_trace() + """ + + example = convert_to_example(image_data, image_path, height, width, label, + center) + + writer.write(example.SerializeToString()) + + +def package(img_paths, labels, out_path, num_shards): + """ + packages the images and labels into multiple tfrecords. + """ + is_lsp_ext = True if len(img_paths) == 10000 else False + coder = ImageCoder() + + i = 0 + fidx = 0 + while i < len(img_paths): + # Open new TFRecord file. 
+ tf_filename = out_path % fidx + print('Starting tfrecord file %s' % tf_filename) + with tf.python_io.TFRecordWriter(tf_filename) as writer: + j = 0 + while i < len(img_paths) and j < num_shards: + if i % 100 == 0: + print('Converting image %d/%d' % (i, len(img_paths))) + _add_to_tfrecord( + img_paths[i], + labels[:, :, i], + coder, + writer, + is_lsp_ext=is_lsp_ext) + i += 1 + j += 1 + + fidx += 1 + + +def load_mat(fname): + import scipy.io as sio + res = sio.loadmat(fname) + # this is 3 x 14 x 2000 + return res['joints'] + + +def process_lsp(img_dir, out_dir, num_shards_train, num_shards_test): + """Process a complete data set and save it as a TFRecord. + LSP has 2000 images, first 1000 is train, last 1000 is test. + + Args: + img_dir: string, root path to the data set. + num_shards: integer number of shards for this data set. + """ + # Load labels 3 x 14 x N + labels = load_mat(join(img_dir, 'joints.mat')) + if labels.shape[0] != 3: + labels = np.transpose(labels, (1, 0, 2)) + + all_images = sorted([f for f in glob(join(img_dir, 'images/*.jpg'))]) + + if len(all_images) == 10000: + # LSP-extended is all train. + train_out = join(out_dir, 'train_%03d.tfrecord') + package(all_images, labels, train_out, num_shards_train) + else: + train_out = join(out_dir, 'train_%03d.tfrecord') + + package(all_images[:1000], labels[:, :, :1000], train_out, + num_shards_train) + + test_out = join(out_dir, 'test_%03d.tfrecord') + package(all_images[1000:], labels[:, :, 1000:], test_out, + num_shards_test) + + +def main(unused_argv): + print('Saving results to %s' % FLAGS.output_directory) + + if not exists(FLAGS.output_directory): + makedirs(FLAGS.output_directory) + process_lsp(FLAGS.img_directory, FLAGS.output_directory, + FLAGS.train_shards, FLAGS.validation_shards) + + +if __name__ == '__main__': + tf.app.run() diff --git a/src/datasets/mpi_inf_3dhp/__init__.py b/src/datasets/mpi_inf_3dhp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/datasets/mpi_inf_3dhp/read_mpi_inf_3dhp.py b/src/datasets/mpi_inf_3dhp/read_mpi_inf_3dhp.py new file mode 100644 index 000000000..7a0420609 --- /dev/null +++ b/src/datasets/mpi_inf_3dhp/read_mpi_inf_3dhp.py @@ -0,0 +1,107 @@ +""" +Open up mpi_inf_3dhp. + +TRAINING: +For each subject & sequence there is annot.mat +What is in annot.mat: + 'frames': number of frames, N + 'univ_annot3': (14,) for each camera of N x 84 -> Why is there univ for each camera if it's univ..? + 'annot3': (14,) for each camera of N x 84 + 'annot2': (14,) for each camera of N x 56 + 'cameras': + + In total there are 28 joints, but H3.6M subsets are used. + + The image frames are unpacked in: + BASE_DIR/S%d/Seq%d/video_%d/frame_%06.jpg + + +TESTING: + 'valid_frame': N_frames x 1 + 'annot2': N_frames x 1 x 17 x 2 + 'annot3': N_frames x 1 x 17 x 3 + 'univ_annot3': N_frames x 1 x 17 x 3 + 'bb_crop': this is N_frames x 34 (not sure what this is..) + 'activity_annotation': N_frames x 1 (of integer indicating activity type + The test images are already in jpg. 
+""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from os.path import join + + +def get_paths(base_dir, sub_id, seq_id): + data_dir = join(base_dir, 'S%d' % sub_id, 'Seq%d' % seq_id) + anno_path = join(data_dir, 'annot.mat') + img_dir = join(data_dir, 'imageFrames') + return img_dir, anno_path + + +def read_mat(path): + from scipy.io import loadmat + res = loadmat(path, struct_as_record=True, squeeze_me=True) + + cameras = res['cameras'] + annot2 = np.stack(res['annot2']) + annot3 = np.stack(res['annot3']) + frames = res['frames'] + + # univ_annot3 = np.stack(res['univ_annot3']) + + return frames, cameras, annot2, annot3 + + +def mpi_inf_3dhp_to_lsp_idx(): + # For training, this joint_idx gives names 17 + raw_to_h36m17_idx = np.array( + [8, 6, 15, 16, 17, 10, 11, 12, 24, 25, 26, 19, 20, 21, 5, 4, 7]) - 1 + names_17 = [ + 'Head', 'Neck', 'R Shoulder', 'R Elbow', 'R Wrist', 'L Shoulder', + 'L Elbow', 'L Wrist', 'R Hip', 'R Knee', 'R Ankle', 'L Hip', 'L Knee', + 'L Ankle', 'Pelvis', 'Spine', 'Head' + ] + want_names = [ + 'R Ankle', 'R Knee', 'R Hip', 'L Hip', 'L Knee', 'L Ankle', 'R Wrist', + 'R Elbow', 'R Shoulder', 'L Shoulder', 'L Elbow', 'L Wrist', 'Neck', + 'Head' + ] + + h36m17_to_lsp_idx = [names_17.index(j) for j in want_names] + + raw_to_lsp_idx = raw_to_h36m17_idx[h36m17_to_lsp_idx] + + return raw_to_lsp_idx, h36m17_to_lsp_idx + + +def read_camera(base_dir): + cam_path = join(base_dir, 'S1/Seq1/camera.calibration') + lines = [] + with open(cam_path, 'r') as f: + for line in f: + content = [x for x in line.strip().split(' ') if x] + lines.append(content) + + def get_cam_info(block): + cam_id = int(block[0][1]) + # Intrinsic + intrinsic = block[4][1:] + K = np.array([np.float(cont) for cont in intrinsic]).reshape(4, 4) + # Extrinsic: + extrinsic = block[5][1:] + Ext = np.array([float(cont) for cont in extrinsic]).reshape(4, 4) + return cam_id, K, Ext + + # Skip header + lines = lines[1:] + # each camera is 7 lines long. + num_cams = int(len(lines) / 7) + cams = {} + for i in range(num_cams): + cam_id, K, Ext = get_cam_info(lines[7 * i:7 * i + 7]) + cams[cam_id] = K + + return cams diff --git a/src/datasets/mpi_inf_3dhp_to_tfrecords.py b/src/datasets/mpi_inf_3dhp_to_tfrecords.py new file mode 100644 index 000000000..3ca14d5c0 --- /dev/null +++ b/src/datasets/mpi_inf_3dhp_to_tfrecords.py @@ -0,0 +1,285 @@ +""" Convert MPI_INF_3DHP to TFRecords """ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from os.path import join, exists +from os import makedirs + +import numpy as np + +import tensorflow as tf + +from .common import convert_to_example_wmosh, ImageCoder, resize_img +from .mpi_inf_3dhp.read_mpi_inf_3dhp import get_paths, read_mat, mpi_inf_3dhp_to_lsp_idx, read_camera + +tf.app.flags.DEFINE_string('data_directory', '/scratch1/storage/mpi_inf_3dhp/', + 'data directory: top of mpi-inf-3dhp') +tf.app.flags.DEFINE_string('output_directory', + '/scratch1/projects/tf_datasets/mpi_inf_3dhp/', + 'Output data directory') + +tf.app.flags.DEFINE_string('split', 'train', 'train or trainval') +tf.app.flags.DEFINE_integer('train_shards', 500, + 'Number of shards in training TFRecord files.') + +FLAGS = tf.app.flags.FLAGS +MIN_VIS_PTS = 8 # This many points must be within the image. 
+ +# To go to h36m joints: +# training joints have 28 joints +# test joints are 17 (H3.6M subset in CPM order) +joint_idx2lsp, test_idx2lsp = mpi_inf_3dhp_to_lsp_idx() + + +def sample_frames(gt3ds): + use_these = np.zeros(gt3ds.shape[0], bool) + # Always use_these first frame. + use_these[0] = True + prev_kp3d = gt3ds[0] + for itr, kp3d in enumerate(gt3ds): + if itr > 0: + # Check if any joint moved more than 200mm. + if not np.any(np.linalg.norm(prev_kp3d - kp3d, axis=1) >= 200): + continue + use_these[itr] = True + prev_kp3d = kp3d + + return use_these + + +def get_all_data(base_dir, sub_id, seq_id, cam_ids, all_cam_info): + img_dir, anno_path = get_paths(base_dir, sub_id, seq_id) + # Get data for all cameras. + frames, _, annot2, annot3 = read_mat(anno_path) + + all_gt2ds, all_gt3ds, all_img_paths = [], [], [] + all_cams = [] + for cam_id in cam_ids: + base_path = join(img_dir, 'video_%d' % cam_id, 'frame_%06d.jpg') + num_frames = annot2[cam_id].shape[0] + gt2ds = annot2[cam_id].reshape(num_frames, -1, 2) + gt3ds = annot3[cam_id].reshape(num_frames, -1, 3) + # Convert N x 28 x . to N x 14 x 2, N x 14 x 3 + gt2ds = gt2ds[:, joint_idx2lsp, :] + gt3ds = gt3ds[:, joint_idx2lsp, :] + img_paths = [base_path % (frame + 1) for frame in frames] + if gt3ds.shape[0] != len(img_paths): + print('Not same paths?') + import ipdb + ipdb.set_trace() + use_these = sample_frames(gt3ds) + all_gt2ds.append(gt2ds[use_these]) + all_gt3ds.append(gt3ds[use_these]) + K = all_cam_info[cam_id] + flength = 0.5 * (K[0, 0] + K[1, 1]) + ppt = K[:2, 2] + flengths = np.tile(flength, (np.sum(use_these), 1)) + ppts = np.tile(ppt, (np.sum(use_these), 1)) + cams = np.hstack((flengths, ppts)) + all_cams.append(cams) + all_img_paths += np.array(img_paths)[use_these].tolist() + + all_gt2ds = np.vstack(all_gt2ds) + all_gt3ds = np.vstack(all_gt3ds) + all_cams = np.vstack(all_cams) + + return all_img_paths, all_gt2ds, all_gt3ds, all_cams + + +def check_good(image, gt2d): + h, w, _ = image.shape + + x_in = np.logical_and(gt2d[:, 0] < w, gt2d[:, 0] >= 0) + y_in = np.logical_and(gt2d[:, 1] < h, gt2d[:, 1] >= 0) + + ok_pts = np.logical_and(x_in, y_in) + + return np.sum(ok_pts) >= MIN_VIS_PTS + + +def add_to_tfrecord(im_path, + gt2d, + gt3d, + cam, + coder, + writer, + model=None, + sub_path=None): + """ + gt2d is 14 x 2 (lsp order) + gt3d is 14 x 3 + cam is (3,) + returns: + success = 1 if this is a good image + 0 if most of the kps are outside the image + """ + # Read image + if not exists(im_path): + # print('!!--%s doesnt exist! Skipping..--!!' % im_path) + return False + with tf.gfile.FastGFile(im_path, 'rb') as f: + image_data = f.read() + image = coder.decode_jpeg(coder.png_to_jpeg(image_data)) + assert image.shape[2] == 3 + + good = check_good(image, gt2d) + if not good: + if FLAGS.split == 'test': + print('Why no good?? shouldnt happen') + import ipdb + ipdb.set_trace() + return False + + # All kps are visible in mpi_inf_3dhp. + min_pt = np.min(gt2d, axis=0) + max_pt = np.max(gt2d, axis=0) + person_height = np.linalg.norm(max_pt - min_pt) + center = (min_pt + max_pt) / 2. + scale = 150. 
/ person_height + + image_scaled, scale_factors = resize_img(image, scale) + height, width = image_scaled.shape[:2] + joints_scaled = np.copy(gt2d) + joints_scaled[:, 0] *= scale_factors[0] + joints_scaled[:, 1] *= scale_factors[1] + center_scaled = np.round(center * scale_factors).astype(np.int) + # scale camera: Flength, px, py + cam_scaled = np.copy(cam) + cam_scaled[0] *= scale + cam_scaled[1] *= scale_factors[0] + cam_scaled[2] *= scale_factors[1] + + # Crop 300x300 around the center + margin = 150 + start_pt = np.maximum(center_scaled - margin, 0).astype(int) + end_pt = (center_scaled + margin).astype(int) + end_pt[0] = min(end_pt[0], width) + end_pt[1] = min(end_pt[1], height) + image_scaled = image_scaled[start_pt[1]:end_pt[1], start_pt[0]:end_pt[ + 0], :] + # Update others too. + joints_scaled[:, 0] -= start_pt[0] + joints_scaled[:, 1] -= start_pt[1] + center_scaled -= start_pt + # Update principal point: + cam_scaled[1] -= start_pt[0] + cam_scaled[2] -= start_pt[1] + height, width = image_scaled.shape[:2] + + # Fix units: mm -> meter + gt3d = gt3d / 1000. + + # Encode image: + image_data_scaled = coder.encode_jpeg(image_scaled) + label = np.vstack([joints_scaled.T, np.ones((1, joints_scaled.shape[0]))]) + # pose and shape is not existent. + pose, shape = None, None + example = convert_to_example_wmosh( + image_data_scaled, im_path, height, width, label, center_scaled, gt3d, + pose, shape, scale_factors, start_pt, cam_scaled) + writer.write(example.SerializeToString()) + + return True + + +def save_to_tfrecord(out_name, im_paths, gt2ds, gt3ds, cams, num_shards): + coder = ImageCoder() + i = 0 + # Count on shards + fidx = 0 + # Count failures + num_bad = 0 + while i < len(im_paths): + tf_filename = out_name % fidx + print('Starting tfrecord file %s' % tf_filename) + with tf.python_io.TFRecordWriter(tf_filename) as writer: + j = 0 + while i < len(im_paths) and j < num_shards: + if i % 100 == 0: + print('Reading img %d/%d' % (i, len(im_paths))) + success = add_to_tfrecord(im_paths[i], gt2ds[i], gt3ds[i], + cams[i], coder, writer) + i += 1 + if success: + j += 1 + else: + num_bad += 1 + + fidx += 1 + + print('Done, wrote to %s, num skipped %d' % (out_name, num_bad)) + + +def process_mpi_inf_3dhp_train(data_dir, out_dir, is_train=False): + if is_train: + out_dir = join(out_dir, 'train') + print('!train set!') + sub_ids = range(1, 8) # No S8! + seq_ids = range(1, 3) + cam_ids = [0, 1, 2, 4, 5, 6, 7, 8] + else: # Full set!! + out_dir = join(out_dir, 'trainval') + print('doing the full train-val set!') + sub_ids = range(1, 9) + seq_ids = range(1, 3) + cam_ids = [0, 1, 2, 4, 5, 6, 7, 8] + + if not exists(out_dir): + makedirs(out_dir) + + out_path = join(out_dir, 'train_%04d.tfrecord') + num_shards = FLAGS.train_shards + + # Load all data & shuffle it,, + all_gt2ds, all_gt3ds, all_img_paths = [], [], [] + all_cams = [] + all_cam_info = read_camera(data_dir) + + for sub_id in sub_ids: + for seq_id in seq_ids: + print('collecting S%d, Seq%d' % (sub_id, seq_id)) + # Collect all data for each camera. + # img_paths: N list + # gt2ds/gt3ds: N x 17 x 2, N x 17 x 3 + img_paths, gt2ds, gt3ds, cams = get_all_data( + data_dir, sub_id, seq_id, cam_ids, all_cam_info) + + all_img_paths += img_paths + all_gt2ds.append(gt2ds) + all_gt3ds.append(gt3ds) + all_cams.append(cams) + + all_gt2ds = np.vstack(all_gt2ds) + all_gt3ds = np.vstack(all_gt3ds) + all_cams = np.vstack(all_cams) + assert (all_gt3ds.shape[0] == len(all_img_paths)) + # Now shuffle it all. 
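+    # A single permutation index is shared across paths, 2D, 3D and cams so the
+    # rows stay aligned; shuffling only changes the order in which examples are
+    # written to the shards.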
+ shuffle_id = np.random.permutation(len(all_img_paths)) + all_img_paths = np.array(all_img_paths)[shuffle_id] + all_gt2ds = all_gt2ds[shuffle_id] + all_gt3ds = all_gt3ds[shuffle_id] + all_cams = all_cams[shuffle_id] + + save_to_tfrecord(out_path, all_img_paths, all_gt2ds, all_gt3ds, all_cams, + num_shards) + + +def main(unused_argv): + print('Saving results to %s' % FLAGS.output_directory) + + if not exists(FLAGS.output_directory): + makedirs(FLAGS.output_directory) + + if FLAGS.split == 'train' or FLAGS.split == 'trainval': + is_train = FLAGS.split == 'train' + process_mpi_inf_3dhp_train( + FLAGS.data_directory, FLAGS.output_directory, is_train=is_train) + else: + print('Unknown split %s' % FLAGS.split) + import ipdb + ipdb.set_trace() + + +if __name__ == '__main__': + tf.app.run() diff --git a/src/datasets/mpii_to_tfrecords.py b/src/datasets/mpii_to_tfrecords.py new file mode 100644 index 000000000..5b3d4cd94 --- /dev/null +++ b/src/datasets/mpii_to_tfrecords.py @@ -0,0 +1,295 @@ +""" +Convert MPII to TFRecords. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from os import makedirs +from os.path import join, exists +from time import time + +import numpy as np + +import tensorflow as tf + +from .common import convert_to_example, ImageCoder, resize_img + +tf.app.flags.DEFINE_string('img_directory', + '/scratch1/storage/human_datasets/mpii', + 'image data directory') +tf.app.flags.DEFINE_string( + 'output_directory', '/Users/kanazawa/projects/datasets/tf_datasets/mpii', + 'Output data directory') + +tf.app.flags.DEFINE_integer('train_shards', 500, + 'Number of shards in training TFRecord files.') +tf.app.flags.DEFINE_integer('validation_shards', 500, + 'Number of shards in validation TFRecord files.') + +FLAGS = tf.app.flags.FLAGS + + +def load_anno(fname): + import scipy.io as sio + t0 = time() + print('Reading annotation..') + res = sio.loadmat(fname, struct_as_record=False, squeeze_me=True) + print('took %g sec..' % (time() - t0)) + + return res['RELEASE'] + + +def convert_is_visible(is_visible): + """ + this field is u'1' or empty numpy array.. + """ + if isinstance(is_visible, np.ndarray): + assert (is_visible.size == 0) + return 0 + else: + return int(is_visible) + + +def read_joints(rect): + """ + Reads joints in the common joint order. + Assumes rect has annopoints as field. + + Returns: + joints: 3 x |common joints| + """ + # Mapping from MPII joints to LSP joints (0:13). In this roder: + _COMMON_JOINT_IDS = [ + 0, # R ankle + 1, # R knee + 2, # R hip + 3, # L hip + 4, # L knee + 5, # L ankle + 10, # R Wrist + 11, # R Elbow + 12, # R shoulder + 13, # L shoulder + 14, # L Elbow + 15, # L Wrist + 8, # Neck top + 9, # Head top + ] + assert ('annopoints' in rect._fieldnames) + points = rect.annopoints.point + if not isinstance(points, np.ndarray): + # There is only one! so ignore this image + return None + # Not all joints are there.. read points in a dict. + read_points = {} + + for point in points: + vis = convert_is_visible(point.is_visible) + read_points[point.id] = np.array([point.x, point.y, vis]) + + # Go over each common joint ids + joints = np.zeros((3, len(_COMMON_JOINT_IDS))) + for i, jid in enumerate(_COMMON_JOINT_IDS): + if jid in read_points.keys(): + joints[:, i] = read_points[jid] + # If it's annotated, then use it as visible + # (in this visible = 0 iff no gt label) + joints[2, i] = 1. + + return joints + + +def parse_people(anno_info, single_persons): + ''' + Parses people from rect annotation. 
+ Assumes input is train data. + Input: + img_dir: str + anno_info: annolist[img_id] obj + single_persons: rect id idx for "single" people + + Returns: + people - list of annotated single-people in this image. + Its Entries are tuple (label, img_scale, obj_pos) + ''' + # No single persons in this image. + if single_persons.size == 0: + return [] + + rects = anno_info.annorect + if not isinstance(rects, np.ndarray): + rects = np.array([rects]) + + # Read each human: + people = [] + + for ridx in single_persons: + rect = rects[ridx - 1] + pos = np.array([rect.objpos.x, rect.objpos.y]) + joints = read_joints(rect) + if joints is None: + continue + # Compute the scale using the keypoints so the person is 150px. + visible = joints[2, :].astype(bool) + # If ankles are visible + if visible[0] or visible[5]: + min_pt = np.min(joints[:2, visible], axis=1) + max_pt = np.max(joints[:2, visible], axis=1) + person_height = np.linalg.norm(max_pt - min_pt) + scale = 150. / person_height + else: + # Torso points left should, right shold, right hip, left hip + # torso_points = joints[:, [8, 9, 3, 2]] + torso_heights = [] + if visible[13] and visible[2]: + torso_heights.append( + np.linalg.norm(joints[:2, 13] - joints[:2, 2])) + if visible[13] and visible[3]: + torso_heights.append( + np.linalg.norm(joints[:2, 13] - joints[:2, 3])) + # Make torso 75px + if len(torso_heights) > 0: + scale = 75. / np.mean(torso_heights) + else: + if visible[8] and visible[2]: + torso_heights.append( + np.linalg.norm(joints[:2, 8] - joints[:2, 2])) + if visible[9] and visible[3]: + torso_heights.append( + np.linalg.norm(joints[:2, 9] - joints[:2, 3])) + if len(torso_heights) > 0: + scale = 56. / np.mean(torso_heights) + else: + # Skip, person is too close. + continue + + people.append((joints, scale, pos)) + + return people + + +def add_to_tfrecord(anno, img_id, img_dir, coder, writer, is_train): + """ + Add each "single person" in this image. + anno - the entire annotation file. + + Returns: + The number of people added. + """ + anno_info = anno.annolist[img_id] + # Make it consistent,, always a numpy array. + single_persons = anno.single_person[img_id] + if not isinstance(single_persons, np.ndarray): + single_persons = np.array([single_persons]) + + people = parse_people(anno_info, single_persons) + + if len(people) == 0: + return 0 + + # Add each people to tf record + image_path = join(img_dir, anno_info.image.name) + with tf.gfile.FastGFile(image_path, 'rb') as f: + image_data = f.read() + image = coder.decode_jpeg(image_data) + + for joints, scale, pos in people: + # Scale image: + image_scaled, scale_factors = resize_img(image, scale) + height, width = image_scaled.shape[:2] + joints_scaled = np.copy(joints) + joints_scaled[0, :] *= scale_factors[0] + joints_scaled[1, :] *= scale_factors[1] + + visible = joints[2, :].astype(bool) + min_pt = np.min(joints_scaled[:2, visible], axis=1) + max_pt = np.max(joints_scaled[:2, visible], axis=1) + center = (min_pt + max_pt) / 2. + + ## Crop 600x600 around this image.. + margin = 300 + start_pt = np.maximum(center - margin, 0).astype(int) + end_pt = (center + margin).astype(int) + end_pt[0] = min(end_pt[0], width) + end_pt[1] = min(end_pt[1], height) + image_scaled = image_scaled[start_pt[1]:end_pt[1], start_pt[0]:end_pt[ + 0], :] + # Update others oo. 
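+        # Keypoints and the person center are shifted into the crop's
+        # coordinate frame so they stay consistent with the cropped image.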
+ joints_scaled[0, :] -= start_pt[0] + joints_scaled[1, :] -= start_pt[1] + center -= start_pt + height, width = image_scaled.shape[:2] + + # Encode image: + image_data_scaled = coder.encode_jpeg(image_scaled) + + example = convert_to_example(image_data_scaled, image_path, height, + width, joints_scaled, center) + writer.write(example.SerializeToString()) + + # Finally return how many were written. + return len(people) + + +def process_mpii(anno, img_dir, out_dir, num_shards, is_train=True): + all_ids = np.array(range(len(anno.annolist))) + if is_train: + out_path = join(out_dir, 'train_%03d.tfrecord') + img_inds = all_ids[anno.img_train.astype('bool')] + else: + out_path = join(out_dir, 'test_%03d.tfrecord') + img_inds = all_ids[np.logical_not(anno.img_train)] + print('Not implemented for test data') + exit(1) + + # MPII annotation is tricky (maybe the way scipy reads them) + # If there's only 1 person in the image, annorect is not an array + # So just go over each image, and add every single_person in that image + # add_to_tfrecords returns the # of ppl added. + # So it's possible some shards go over the limit but this is ok. + + coder = ImageCoder() + + i = 0 + # Count on shards + fidx = 0 + num_ppl = 0 + while i < len(img_inds): + + tf_filename = out_path % fidx + print('Starting tfrecord file %s' % tf_filename) + with tf.python_io.TFRecordWriter(tf_filename) as writer: + # Count on total ppl in each shard + num_ppl = 0 + while i < len(img_inds) and num_ppl < num_shards: + if i % 100 == 0: + print('Reading img %d/%d' % (i, len(img_inds))) + num_ppl += add_to_tfrecord(anno, img_inds[i], img_dir, coder, + writer, is_train) + i += 1 + + fidx += 1 + + +def main(unused_argv): + print('Saving results to %s' % FLAGS.output_directory) + + if not exists(FLAGS.output_directory): + makedirs(FLAGS.output_directory) + + anno_mat = join(FLAGS.img_directory, 'annotations', + 'mpii_human_pose_v1_u12_1.mat') + anno = load_anno(anno_mat) + + img_dir = join(FLAGS.img_directory, 'images') + process_mpii( + anno, + img_dir, + FLAGS.output_directory, + FLAGS.train_shards, + is_train=True) + + +if __name__ == '__main__': + tf.app.run() diff --git a/src/datasets/smpl_to_tfrecords.py b/src/datasets/smpl_to_tfrecords.py new file mode 100644 index 000000000..214ed5610 --- /dev/null +++ b/src/datasets/smpl_to_tfrecords.py @@ -0,0 +1,119 @@ +""" +Convert MoCap SMPL data to tfrecords. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from os import makedirs +from os.path import join, exists +import numpy as np +from glob import glob +import cPickle as pickle + +import tensorflow as tf + +from .common import float_feature + +tf.app.flags.DEFINE_string( + 'dataset_name', 'neutrSMPL_CMU', + 'neutrSMPL_CMU, neutrSMPL_H3.6, or neutrSMPL_jointLim') +tf.app.flags.DEFINE_string('data_directory', + '/scratch1/storage/human_datasets/neutrMosh/', + 'data directory where SMPL npz/pkl lies') +tf.app.flags.DEFINE_string('output_directory', + '/scratch1/projects/tf_datasets/mocap_neutrMosh/', + 'Output data directory') + +tf.app.flags.DEFINE_integer('num_shards', 10000, + 'Number of shards in TFRecord files.') + +FLAGS = tf.app.flags.FLAGS + + +def convert_to_example(pose, shape=None): + """Build an Example proto for an image example. 
+ Args: + pose: 72-D vector, float + shape: 10-D vector, float + Returns: + Example proto + """ + if shape is None: + example = tf.train.Example(features=tf.train.Features( + feature={ + 'pose': float_feature(pose.astype(np.float)) + })) + else: + example = tf.train.Example(features=tf.train.Features( + feature={ + 'pose': float_feature(pose.astype(np.float)), + 'shape': float_feature(shape.astype(np.float)), + })) + + return example + + +def process_smpl_mocap(all_pkls, out_dir, num_shards, dataset_name): + all_poses, all_shapes, all_shapes_unique = [], [], [] + for pkl in all_pkls: + with open(pkl, 'rb') as f: + res = pickle.load(f) + all_poses.append(res['poses']) + num_poses_here = res['poses'].shape[0] + all_shapes.append( + np.tile(np.reshape(res['betas'], (10, 1)), num_poses_here)) + all_shapes_unique.append(res['betas']) + + all_poses = np.vstack(all_poses) + all_shapes = np.hstack(all_shapes).T + + out_path = join(out_dir, '%s_%%03d.tfrecord' % dataset_name) + + # shuffle results + num_mocap = all_poses.shape[0] + shuffle_id = np.random.permutation(num_mocap) + all_poses = all_poses[shuffle_id] + all_shapes = all_shapes[shuffle_id] + + i = 0 + fidx = 0 + while i < num_mocap: + # Open new TFRecord file. + tf_filename = out_path % fidx + print('Starting tfrecord file %s' % tf_filename) + with tf.python_io.TFRecordWriter(tf_filename) as writer: + j = 0 + while i < num_mocap and j < num_shards: + if i % 10000 == 0: + print('Converting mosh %d/%d' % (i, num_mocap)) + example = convert_to_example(all_poses[i], shape=all_shapes[i]) + writer.write(example.SerializeToString()) + i += 1 + j += 1 + + fidx += 1 + + +def main(unused_argv): + data_dir = join(FLAGS.data_directory, FLAGS.dataset_name) + # Ignore H3.6M test subjects!! + all_pkl = sorted([ + f for f in glob(join(data_dir, '*/*.pkl')) + if 'S9' not in f and 'S11' not in f + ]) + if len(all_pkl) == 0: + print('Something is wrong with the path bc I cant find any pkls!') + import ipdb; ipdb.set_trace() + + print('Saving results to %s' % FLAGS.output_directory) + + if not exists(FLAGS.output_directory): + makedirs(FLAGS.output_directory) + + process_smpl_mocap(all_pkl, FLAGS.output_directory, FLAGS.num_shards, + FLAGS.dataset_name) + + +if __name__ == '__main__': + tf.app.run() diff --git a/src/main.py b/src/main.py new file mode 100644 index 000000000..f02bc5a48 --- /dev/null +++ b/src/main.py @@ -0,0 +1,30 @@ +""" Driver for train """ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from .config import get_config, prepare_dirs, save_config +from .data_loader import DataLoader +from .trainer import HMRTrainer + + +def main(config): + prepare_dirs(config) + + # Load data on CPU + with tf.device("/cpu:0"): + data_loader = DataLoader(config) + image_loader = data_loader.load() + smpl_loader = data_loader.get_smpl_loader() + + trainer = HMRTrainer(config, image_loader, smpl_loader) + save_config(config) + trainer.train() + + +if __name__ == '__main__': + config = get_config() + main(config) diff --git a/src/models.py b/src/models.py index ff0edcbac..55315f500 100644 --- a/src/models.py +++ b/src/models.py @@ -5,6 +5,8 @@ @Encoder_resnet_v1_101 @Encoder_fc3_dropout +@Discriminator_separable_rotations + Helper: @get_encoder_fn_separate """ @@ -18,6 +20,7 @@ from tensorflow.contrib.layers.python.layers.initializers import variance_scaling_initializer + def Encoder_resnet(x, is_training=True, weight_decay=0.001, reuse=False): """ Resnet v2-50 @@ 
-47,35 +50,6 @@ def Encoder_resnet(x, is_training=True, weight_decay=0.001, reuse=False): variables = tf.contrib.framework.get_variables('resnet_v2_50') return net, variables -def Encoder_resnet_v1_101(x, - weight_decay, - is_training=True, - reuse=False): - """ - Resnet v1-101 encoder, adds 2 fc layers after Resnet. - Assumes input is [batch, height_in, width_in, channels]!! - Input: - - x: N x H x W x 3 - - weight_decay: float - - reuse: bool-> True if test - - Outputs: - - net: N x F - - variables: tf variables - """ - from tensorflow.contrib.slim.python.slim.nets import resnet_v1 - with tf.name_scope("Encoder_resnet_v1_101", [x]): - with slim.arg_scope( - resnet_v1.resnet_arg_scope(weight_decay=weight_decay)): - net, end_points = resnet_v1.resnet_v1_101( - x, - num_classes=None, - is_training=is_training, - reuse=reuse, - scope='resnet_v1_101') - net = tf.reshape(net, [net.shape.as_list()[0], -1]) - variables = tf.contrib.framework.get_variables('resnet_v1_101') - return net, variables def Encoder_fc3_dropout(x, num_output=85, @@ -123,16 +97,83 @@ def get_encoder_fn_separate(model_type): """ encoder_fn = None threed_fn = None - if 'resnet_v1_101' in model_type: - encoder_fn = Encoder_resnet_v1_101 - elif 'resnet' in model_type: + if 'resnet' in model_type: encoder_fn = Encoder_resnet - + else: + print('Unknown encoder %s!' % model_type) + exit(1) + if 'fc3_dropout' in model_type: threed_fn = Encoder_fc3_dropout - + if encoder_fn is None or threed_fn is None: print('Dont know what encoder to use for %s' % model_type) - import ipdb; ipdb.set_trace() + import ipdb + ipdb.set_trace() return encoder_fn, threed_fn + + +def Discriminator_separable_rotations( + poses, + shapes, + weight_decay, +): + """ + 23 Discriminators on each joint + 1 for all joints + 1 for shape. + To share the params on rotations, this treats the 23 rotation matrices + as a "vertical image": + Do 1x1 conv, then send off to 23 independent classifiers. + + Input: + - poses: N x 23 x 1 x 9, NHWC ALWAYS!! + - shapes: N x 10 + - weight_decay: float + + Outputs: + - prediction: N x (1+23) or N x (1+23+1) if do_joint is on. 
+ - variables: tf variables + """ + data_format = "NHWC" + with tf.name_scope("Discriminator_sep_rotations", [poses, shapes]): + with tf.variable_scope("D") as scope: + with slim.arg_scope( + [slim.conv2d, slim.fully_connected], + weights_regularizer=slim.l2_regularizer(weight_decay)): + with slim.arg_scope([slim.conv2d], data_format=data_format): + poses = slim.conv2d(poses, 32, [1, 1], scope='D_conv1') + poses = slim.conv2d(poses, 32, [1, 1], scope='D_conv2') + theta_out = [] + for i in range(0, 23): + theta_out.append( + slim.fully_connected( + poses[:, i, :, :], + 1, + activation_fn=None, + scope="pose_out_j%d" % i)) + theta_out_all = tf.squeeze(tf.stack(theta_out, axis=1)) + + # Do shape on it's own: + shapes = slim.stack( + shapes, + slim.fully_connected, [10, 5], + scope="shape_fc1") + shape_out = slim.fully_connected( + shapes, 1, activation_fn=None, scope="shape_final") + """ Compute joint correlation prior!""" + nz_feat = 1024 + poses_all = slim.flatten(poses, scope='vectorize') + poses_all = slim.fully_connected( + poses_all, nz_feat, scope="D_alljoints_fc1") + poses_all = slim.fully_connected( + poses_all, nz_feat, scope="D_alljoints_fc2") + poses_all_out = slim.fully_connected( + poses_all, + 1, + activation_fn=None, + scope="D_alljoints_out") + out = tf.concat([theta_out_all, + poses_all_out, shape_out], 1) + + variables = tf.contrib.framework.get_variables(scope) + return out, variables diff --git a/src/ops.py b/src/ops.py new file mode 100644 index 000000000..860dbe3d5 --- /dev/null +++ b/src/ops.py @@ -0,0 +1,60 @@ +""" +TF util operations. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +def keypoint_l1_loss(kp_gt, kp_pred, scale=1., name=None): + """ + computes: \Sum_i [0.5 * vis[i] * |kp_gt[i] - kp_pred[i]|] / (|vis|) + Inputs: + kp_gt : N x K x 3 + kp_pred: N x K x 2 + """ + with tf.name_scope(name, "keypoint_l1_loss", [kp_gt, kp_pred]): + kp_gt = tf.reshape(kp_gt, (-1, 3)) + kp_pred = tf.reshape(kp_pred, (-1, 2)) + + vis = tf.expand_dims(tf.cast(kp_gt[:, 2], tf.float32), 1) + res = tf.losses.absolute_difference(kp_gt[:, :2], kp_pred, weights=vis) + return res + + +def compute_3d_loss(params_pred, params_gt, has_gt3d): + """ + Computes the l2 loss between 3D params pred and gt for those data that has_gt3d is True. + Parameters to compute loss over: + 3Djoints: 14*3 = 42 + rotations:(24*9)= 216 + shape: 10 + total input: 226 (gt SMPL params) or 42 (just joints) + + Inputs: + params_pred: N x {226, 42} + params_gt: N x {226, 42} + # has_gt3d: (N,) bool + has_gt3d: N x 1 tf.float32 of {0., 1.} + """ + with tf.name_scope("3d_loss", [params_pred, params_gt, has_gt3d]): + weights = tf.expand_dims(tf.cast(has_gt3d, tf.float32), 1) + res = tf.losses.mean_squared_error( + params_gt, params_pred, weights=weights) * 0.5 + return res + + +def align_by_pelvis(joints): + """ + Assumes joints is N x 14 x 3 in LSP order. + Then hips are: [3, 2] + Takes mid point of these points, then subtracts it. + """ + with tf.name_scope("align_by_pelvis", [joints]): + left_id = 3 + right_id = 2 + pelvis = (joints[:, left_id, :] + joints[:, right_id, :]) / 2. 
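+        # Subtracting the mid-hip point makes the 3D joint loss translation
+        # invariant: predictions and ground truth are compared in a common,
+        # pelvis-centered frame.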
+ return joints - tf.expand_dims(pelvis, axis=1) diff --git a/src/tf_smpl/batch_smpl.py b/src/tf_smpl/batch_smpl.py index d832e86c2..837cdd01a 100644 --- a/src/tf_smpl/batch_smpl.py +++ b/src/tf_smpl/batch_smpl.py @@ -155,7 +155,7 @@ def __call__(self, beta, theta, get_skin=False, name=None): joints = tf.stack([joint_x, joint_y, joint_z], axis=2) if get_skin: - return verts, joints + return verts, joints, Rs else: return joints diff --git a/src/trainer.py b/src/trainer.py new file mode 100644 index 000000000..17a33f58e --- /dev/null +++ b/src/trainer.py @@ -0,0 +1,606 @@ +""" +HMR trainer. +From an image input, trained a model that outputs 85D latent vector +consisting of [cam (3 - [scale, tx, ty]), pose (72), shape (10)] +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from .data_loader import num_examples + +from .ops import keypoint_l1_loss, compute_3d_loss, align_by_pelvis +from .models import Discriminator_separable_rotations, get_encoder_fn_separate + +from .tf_smpl.batch_lbs import batch_rodrigues +from .tf_smpl.batch_smpl import SMPL +from .tf_smpl.projection import batch_orth_proj_idrot + +from tensorflow.python.ops import control_flow_ops + +from time import time +import tensorflow as tf +import numpy as np + +from os.path import join, dirname +import deepdish as dd + +# For drawing +from .util import renderer as vis_util + + +class HMRTrainer(object): + def __init__(self, config, data_loader, mocap_loader): + """ + Args: + config + if no 3D label is available, + data_loader is a dict + else + data_loader is a dict + mocap_loader is a tuple (pose, shape) + """ + # Config + path + self.config = config + self.model_dir = config.model_dir + self.load_path = config.load_path + + self.data_format = config.data_format + self.smpl_model_path = config.smpl_model_path + self.pretrained_model_path = config.pretrained_model_path + self.encoder_only = config.encoder_only + self.use_3d_label = config.use_3d_label + + # Data size + self.img_size = config.img_size + self.num_stage = config.num_stage + self.batch_size = config.batch_size + self.max_epoch = config.epoch + + self.num_cam = 3 + self.proj_fn = batch_orth_proj_idrot + + self.num_theta = 72 # 24 * 3 + self.total_params = self.num_theta + self.num_cam + 10 + + # Data + num_images = num_examples(config.datasets) + num_mocap = num_examples(config.mocap_datasets) + + self.num_itr_per_epoch = num_images / self.batch_size + self.num_mocap_itr_per_epoch = num_mocap / self.batch_size + + # First make sure data_format is right + if self.data_format == 'NCHW': + # B x H x W x 3 --> B x 3 x H x W + data_loader['image'] = tf.transpose(data_loader['image'], + [0, 3, 1, 2]) + + self.image_loader = data_loader['image'] + self.kp_loader = data_loader['label'] + + if self.use_3d_label: + self.poseshape_loader = data_loader['label3d'] + # image_loader[3] is N x 2, first column is 3D_joints gt existence, + # second column is 3D_smpl gt existence + self.has_gt3d_joints = data_loader['has3d'][:, 0] + self.has_gt3d_smpl = data_loader['has3d'][:, 1] + + self.pose_loader = mocap_loader[0] + self.shape_loader = mocap_loader[1] + + self.global_step = tf.Variable(0, name='global_step', trainable=False) + self.log_img_step = config.log_img_step + + # For visualization: + num2show = np.minimum(6, self.batch_size) + # Take half from front & back + self.show_these = tf.constant( + np.hstack( + [np.arange(num2show / 2), self.batch_size - np.arange(3) - 1]), + tf.int32) + + # Model spec + 
self.model_type = config.model_type + self.keypoint_loss = keypoint_l1_loss + + # Optimizer, learning rate + self.e_lr = config.e_lr + self.d_lr = config.d_lr + # Weight decay + self.e_wd = config.e_wd + self.d_wd = config.d_wd + self.e_loss_weight = config.e_loss_weight + self.d_loss_weight = config.d_loss_weight + self.e_3d_weight = config.e_3d_weight + + self.optimizer = tf.train.AdamOptimizer + + # Instantiate SMPL + self.smpl = SMPL(self.smpl_model_path) + self.E_var = [] + self.build_model() + + # Logging + init_fn = None + if self.use_pretrained(): + # Make custom init_fn + print("Fine-tuning from %s" % self.pretrained_model_path) + if 'resnet_v2_50' in self.pretrained_model_path: + resnet_vars = [ + var for var in self.E_var if 'resnet_v2_50' in var.name + ] + self.pre_train_saver = tf.train.Saver(resnet_vars) + elif 'pose-tensorflow' in self.pretrained_model_path: + resnet_vars = [ + var for var in self.E_var if 'resnet_v1_101' in var.name + ] + self.pre_train_saver = tf.train.Saver(resnet_vars) + else: + self.pre_train_saver = tf.train.Saver() + + def load_pretrain(sess): + self.pre_train_saver.restore(sess, self.pretrained_model_path) + + init_fn = load_pretrain + + self.saver = tf.train.Saver(keep_checkpoint_every_n_hours=5) + self.summary_writer = tf.summary.FileWriter(self.model_dir) + self.sv = tf.train.Supervisor( + logdir=self.model_dir, + global_step=self.global_step, + saver=self.saver, + summary_writer=self.summary_writer, + init_fn=init_fn) + gpu_options = tf.GPUOptions(allow_growth=True) + self.sess_config = tf.ConfigProto( + allow_soft_placement=False, + log_device_placement=False, + gpu_options=gpu_options) + + def use_pretrained(self): + """ + Returns true only if: + 1. model_type is "resnet" + 2. pretrained_model_path is not None + 3. model_dir is NOT empty, meaning we're picking up from previous + so fuck this pretrained model. + """ + if ('resnet' in self.model_type) and (self.pretrained_model_path is + not None): + # Check is model_dir is empty + import os + if os.listdir(self.model_dir) == []: + return True + + return False + + def load_mean_param(self): + mean = np.zeros((1, self.total_params)) + # Initialize scale at 0.9 + mean[0, 0] = 0.9 + mean_path = join( + dirname(self.smpl_model_path), 'neutral_smpl_mean_params.h5') + mean_vals = dd.io.load(mean_path) + + mean_pose = mean_vals['pose'] + # Ignore the global rotation. + mean_pose[:3] = 0. + mean_shape = mean_vals['shape'] + + # This initializes the global pose to be up-right when projected + mean_pose[0] = np.pi + + mean[0, 3:] = np.hstack((mean_pose, mean_shape)) + mean = tf.constant(mean, tf.float32) + self.mean_var = tf.Variable( + mean, name="mean_param", dtype=tf.float32, trainable=True) + self.E_var.append(self.mean_var) + init_mean = tf.tile(self.mean_var, [self.batch_size, 1]) + return init_mean + + def build_model(self): + img_enc_fn, threed_enc_fn = get_encoder_fn_separate(self.model_type) + # Extract image features. 
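+        # The encoder runs once per image; the iterative regression stages
+        # below reuse this feature, concatenated with the current 85D estimate
+        # (3 cam + 72 pose + 10 shape).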
+ self.img_feat, self.E_var = img_enc_fn( + self.image_loader, weight_decay=self.e_wd, reuse=False) + + loss_kps = [] + if self.use_3d_label: + loss_3d_joints, loss_3d_params = [], [] + # For discriminator + fake_rotations, fake_shapes = [], [] + # Start loop + theta_prev = self.load_mean_param() + + # For visualizations + self.all_verts = [] + self.all_pred_kps = [] + self.all_pred_cams = [] + self.all_delta_thetas = [] + self.all_theta_prev = [] + + for i in np.arange(self.num_stage): + print('Iteration %d' % i) + # ---- Compute outputs + state = tf.concat([self.img_feat, theta_prev], 1) + + if i == 0: + delta_theta, threeD_var = threed_enc_fn( + state, + num_output=self.total_params, + reuse=False) + self.E_var.extend(threeD_var) + else: + delta_theta, _ = threed_enc_fn( + state, num_output=self.total_params, reuse=True) + + # Compute new theta + theta_here = theta_prev + delta_theta + # cam = N x 3, pose N x self.num_theta, shape: N x 10 + cams = theta_here[:, :self.num_cam] + poses = theta_here[:, self.num_cam:(self.num_cam + self.num_theta)] + shapes = theta_here[:, (self.num_cam + self.num_theta):] + # Rs_wglobal is Nx24x3x3 rotation matrices of poses + verts, Js, pred_Rs = self.smpl(shapes, poses, get_skin=True) + pred_kp = batch_orth_proj_idrot( + Js, cams, name='proj2d_stage%d' % i) + # --- Compute losses: + loss_kps.append(self.e_loss_weight * self.keypoint_loss( + self.kp_loader, pred_kp)) + pred_Rs = tf.reshape(pred_Rs, [-1, 24, 9]) + if self.use_3d_label: + loss_poseshape, loss_joints = self.get_3d_loss( + pred_Rs, shapes, Js) + loss_3d_params.append(loss_poseshape) + loss_3d_joints.append(loss_joints) + + # Save pred_rotations for Discriminator + fake_rotations.append(pred_Rs[:, 1:, :]) + fake_shapes.append(shapes) + + # Save things for visualiations: + self.all_verts.append(tf.gather(verts, self.show_these)) + self.all_pred_kps.append(tf.gather(pred_kp, self.show_these)) + self.all_pred_cams.append(tf.gather(cams, self.show_these)) + + # Finally update to end iteration. + theta_prev = theta_here + + if not self.encoder_only: + self.setup_discriminator(fake_rotations, fake_shapes) + + # Gather losses. + with tf.name_scope("gather_e_loss"): + # Just the last loss. + self.e_loss_kp = loss_kps[-1] + + if self.encoder_only: + self.e_loss = self.e_loss_kp + else: + self.e_loss = self.d_loss_weight * self.e_loss_disc + self.e_loss_kp + + if self.use_3d_label: + self.e_loss_3d = loss_3d_params[-1] + self.e_loss_3d_joints = loss_3d_joints[-1] + + self.e_loss += (self.e_loss_3d + self.e_loss_3d_joints) + + if not self.encoder_only: + with tf.name_scope("gather_d_loss"): + self.d_loss = self.d_loss_weight * ( + self.d_loss_real + self.d_loss_fake) + + # For visualizations, only save selected few into: + # B x T x ... + self.all_verts = tf.stack(self.all_verts, axis=1) + self.all_pred_kps = tf.stack(self.all_pred_kps, axis=1) + self.all_pred_cams = tf.stack(self.all_pred_cams, axis=1) + self.show_imgs = tf.gather(self.image_loader, self.show_these) + self.show_kps = tf.gather(self.kp_loader, self.show_these) + + # Don't forget to update batchnorm's moving means. 
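+        # tf.GraphKeys.UPDATE_OPS collects the batch-norm moving-average ops;
+        # making them a dependency of e_loss ensures they run every step.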
+ print('collecting batch norm moving means!!') + bn_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + if bn_ops: + self.e_loss = control_flow_ops.with_dependencies( + [tf.group(*bn_ops)], self.e_loss) + + # Setup optimizer + print('Setting up optimizer..') + d_optimizer = self.optimizer(self.d_lr) + e_optimizer = self.optimizer(self.e_lr) + + self.e_opt = e_optimizer.minimize( + self.e_loss, global_step=self.global_step, var_list=self.E_var) + if not self.encoder_only: + self.d_opt = d_optimizer.minimize(self.d_loss, var_list=self.D_var) + + self.setup_summaries(loss_kps) + + print('Done initializing trainer!') + + def setup_summaries(self, loss_kps): + # Prepare Summary + always_report = [ + tf.summary.scalar("loss/e_loss_kp_noscale", + self.e_loss_kp / self.e_loss_weight), + tf.summary.scalar("loss/e_loss", self.e_loss), + ] + if self.encoder_only: + print('ENCODER ONLY!!!') + else: + always_report.extend([ + tf.summary.scalar("loss/d_loss", self.d_loss), + tf.summary.scalar("loss/d_loss_fake", self.d_loss_fake), + tf.summary.scalar("loss/d_loss_real", self.d_loss_real), + tf.summary.scalar("loss/e_loss_disc", + self.e_loss_disc / self.d_loss_weight), + ]) + # loss at each stage. + for i in np.arange(self.num_stage): + name_here = "loss/e_losses_noscale/kp_loss_stage%d" % i + always_report.append( + tf.summary.scalar(name_here, loss_kps[i] / self.e_loss_weight)) + if self.use_3d_label: + always_report.append( + tf.summary.scalar("loss/e_loss_3d_params_noscale", + self.e_loss_3d / self.e_3d_weight)) + always_report.append( + tf.summary.scalar("loss/e_loss_3d_joints_noscale", + self.e_loss_3d_joints / self.e_3d_weight)) + + if not self.encoder_only: + summary_occ = [] + # Report D output for each joint. + smpl_names = [ + 'Left_Hip', 'Right_Hip', 'Waist', 'Left_Knee', 'Right_Knee', + 'Upper_Waist', 'Left_Ankle', 'Right_Ankle', 'Chest', + 'Left_Toe', 'Right_Toe', 'Base_Neck', 'Left_Shoulder', + 'Right_Shoulder', 'Upper_Neck', 'Left_Arm', 'Right_Arm', + 'Left_Elbow', 'Right_Elbow', 'Left_Wrist', 'Right_Wrist', + 'Left_Finger', 'Right_Finger' + ] + # d_out is 25 (or 24), last bit is shape, first 24 is pose + # 23(relpose) + 1(jointpose) + 1(shape) => 25 + d_out_pose = self.d_out[:, :24] + for i, name in enumerate(smpl_names): + summary_occ.append( + tf.summary.histogram("d_out/%s" % name, d_out_pose[i])) + summary_occ.append( + tf.summary.histogram("d_out/all_joints", d_out_pose[23])) + summary_occ.append( + tf.summary.histogram("d_out/beta", self.d_out[:, 24])) + + self.summary_op_occ = tf.summary.merge( + summary_occ, collections=['occasional']) + self.summary_op_always = tf.summary.merge(always_report) + + def setup_discriminator(self, fake_rotations, fake_shapes): + # Compute the rotation matrices of "rea" pose. + # These guys are in 24 x 3. + real_rotations = batch_rodrigues(tf.reshape(self.pose_loader, [-1, 3])) + real_rotations = tf.reshape(real_rotations, [-1, 24, 9]) + # Ignoring global rotation. N x 23*9 + # The # of real rotation is B*num_stage so it's balanced. 
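+        # Dropping index 0 removes the global rotation, leaving the 23 relative
+        # joint rotations (each a flattened 3x3 matrix) that the discriminator sees.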
+ real_rotations = real_rotations[:, 1:, :] + all_fake_rotations = tf.reshape( + tf.concat(fake_rotations, 0), + [self.batch_size * self.num_stage, -1, 9]) + comb_rotations = tf.concat( + [real_rotations, all_fake_rotations], 0, name="combined_pose") + + comb_rotations = tf.expand_dims(comb_rotations, 2) + all_fake_shapes = tf.concat(fake_shapes, 0) + comb_shapes = tf.concat( + [self.shape_loader, all_fake_shapes], 0, name="combined_shape") + + disc_input = { + 'weight_decay': self.d_wd, + 'shapes': comb_shapes, + 'poses': comb_rotations + } + + self.d_out, self.D_var = Discriminator_separable_rotations( + **disc_input) + + self.d_out_real, self.d_out_fake = tf.split(self.d_out, 2) + # Compute losses: + with tf.name_scope("comp_d_loss"): + self.d_loss_real = tf.reduce_mean( + tf.reduce_sum((self.d_out_real - 1)**2, axis=1)) + self.d_loss_fake = tf.reduce_mean( + tf.reduce_sum((self.d_out_fake)**2, axis=1)) + # Encoder loss + self.e_loss_disc = tf.reduce_mean( + tf.reduce_sum((self.d_out_fake - 1)**2, axis=1)) + + def get_3d_loss(self, Rs, shape, Js): + """ + Rs is N x 24 x 3*3 rotation matrices of pose + Shape is N x 10 + Js is N x 19 x 3 joints + + Ground truth: + self.poseshape_loader is a long vector of: + relative rotation (24*9) + shape (10) + 3D joints (14*3) + """ + Rs = tf.reshape(Rs, [self.batch_size, -1]) + params_pred = tf.concat([Rs, shape], 1, name="prep_params_pred") + # 24*9+10 = 226 + gt_params = self.poseshape_loader[:, :226] + loss_poseshape = self.e_3d_weight * compute_3d_loss( + params_pred, gt_params, self.has_gt3d_smpl) + # 14*3 = 42 + gt_joints = self.poseshape_loader[:, 226:] + pred_joints = Js[:, :14, :] + # Align the joints by pelvis. + pred_joints = align_by_pelvis(pred_joints) + pred_joints = tf.reshape(pred_joints, [self.batch_size, -1]) + gt_joints = tf.reshape(gt_joints, [self.batch_size, 14, 3]) + gt_joints = align_by_pelvis(gt_joints) + gt_joints = tf.reshape(gt_joints, [self.batch_size, -1]) + + loss_joints = self.e_3d_weight * compute_3d_loss( + pred_joints, gt_joints, self.has_gt3d_joints) + + return loss_poseshape, loss_joints + + def visualize_img(self, img, gt_kp, vert, pred_kp, cam, renderer): + """ + Overlays gt_kp and pred_kp on img. + Draws vert with text. + Renderer is an instance of SMPLRenderer. + """ + gt_vis = gt_kp[:, 2].astype(bool) + loss = np.sum((gt_kp[gt_vis, :2] - pred_kp[gt_vis])**2) + debug_text = {"sc": cam[0], "tx": cam[1], "ty": cam[2], "kpl": loss} + # Fix a flength so i can render this with persp correct scale + f = 5. + tz = f / cam[0] + cam_for_render = 0.5 * self.img_size * np.array([f, 1, 1]) + cam_t = np.array([cam[1], cam[2], tz]) + # Undo pre-processing. 
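+        # Input images were normalized to [-1, 1]; map back to [0, 1] for display.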
+ input_img = (img + 1) * 0.5 + rend_img = renderer(vert + cam_t, cam_for_render, img=input_img) + rend_img = vis_util.draw_text(rend_img, debug_text) + + # Draw skeleton + gt_joint = ((gt_kp[:, :2] + 1) * 0.5) * self.img_size + pred_joint = ((pred_kp + 1) * 0.5) * self.img_size + img_with_gt = vis_util.draw_skeleton( + input_img, gt_joint, draw_edges=False, vis=gt_vis) + skel_img = vis_util.draw_skeleton(img_with_gt, pred_joint) + + combined = np.hstack([skel_img, rend_img / 255.]) + + # import matplotlib.pyplot as plt + # plt.ion() + # plt.imshow(skel_img) + # import ipdb; ipdb.set_trace() + return combined + + def draw_results(self, result): + from StringIO import StringIO + import matplotlib.pyplot as plt + + # This is B x H x W x 3 + imgs = result["input_img"] + # B x 19 x 3 + gt_kps = result["gt_kp"] + if self.data_format == 'NCHW': + imgs = np.transpose(imgs, [0, 2, 3, 1]) + # This is B x T x 6890 x 3 + est_verts = result["e_verts"] + # B x T x 19 x 2 + joints = result["joints"] + # B x T x 3 + cams = result["cam"] + + img_summaries = [] + + for img_id, (img, gt_kp, verts, joints, cams) in enumerate( + zip(imgs, gt_kps, est_verts, joints, cams)): + # verts, joints, cams are a list of len T. + all_rend_imgs = [] + for vert, joint, cam in zip(verts, joints, cams): + rend_img = self.visualize_img(img, gt_kp, vert, joint, cam, + self.renderer) + all_rend_imgs.append(rend_img) + combined = np.vstack(all_rend_imgs) + + sio = StringIO() + plt.imsave(sio, combined, format='png') + vis_sum = tf.Summary.Image( + encoded_image_string=sio.getvalue(), + height=combined.shape[0], + width=combined.shape[1]) + img_summaries.append( + tf.Summary.Value(tag="vis_images/%d" % img_id, image=vis_sum)) + + img_summary = tf.Summary(value=img_summaries) + self.summary_writer.add_summary( + img_summary, global_step=result['step']) + + def train(self): + # For rendering! 
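+        # SMPLRenderer overlays the predicted mesh on the input images for the
+        # visual summaries written every log_img_step.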
+ self.renderer = vis_util.SMPLRenderer( + img_size=self.img_size, + face_path=self.config.smpl_face_path) + + step = 0 + + with self.sv.managed_session(config=self.sess_config) as sess: + while not self.sv.should_stop(): + fetch_dict = { + "summary": self.summary_op_always, + "step": self.global_step, + "e_loss": self.e_loss, + # The meat + "e_opt": self.e_opt, + "loss_kp": self.e_loss_kp + } + if not self.encoder_only: + fetch_dict.update({ + # For D: + "d_opt": self.d_opt, + "d_loss": self.d_loss, + "loss_disc": self.e_loss_disc, + }) + if self.use_3d_label: + fetch_dict.update({ + "loss_3d_params": self.e_loss_3d, + "loss_3d_joints": self.e_loss_3d_joints + }) + + if step % self.log_img_step == 0: + fetch_dict.update({ + "input_img": self.show_imgs, + "gt_kp": self.show_kps, + "e_verts": self.all_verts, + "joints": self.all_pred_kps, + "cam": self.all_pred_cams, + }) + if not self.encoder_only: + fetch_dict.update({ + "summary_occasional": + self.summary_op_occ + }) + + t0 = time() + result = sess.run(fetch_dict) + t1 = time() + + self.summary_writer.add_summary( + result['summary'], global_step=result['step']) + + e_loss = result['e_loss'] + step = result['step'] + + epoch = float(step) / self.num_itr_per_epoch + if self.encoder_only: + print("itr %d/(epoch %.1f): time %g, Enc_loss: %.4f" % + (step, epoch, t1 - t0, e_loss)) + else: + d_loss = result['d_loss'] + print( + "itr %d/(epoch %.1f): time %g, Enc_loss: %.4f, Disc_loss: %.4f" + % (step, epoch, t1 - t0, e_loss, d_loss)) + + if step % self.log_img_step == 0: + if not self.encoder_only: + self.summary_writer.add_summary( + result['summary_occasional'], + global_step=result['step']) + self.draw_results(result) + + self.summary_writer.flush() + if epoch > self.max_epoch: + self.sv.request_stop() + + step += 1 + + print('Finish training on %s' % self.model_dir) diff --git a/src/util/data_utils.py b/src/util/data_utils.py new file mode 100644 index 000000000..c4d53c06d --- /dev/null +++ b/src/util/data_utils.py @@ -0,0 +1,345 @@ +""" +Utils for data loading for training. +""" + +from os.path import join +from glob import glob + +import tensorflow as tf + + +def parse_example_proto(example_serialized, has_3d=False): + """Parses an Example proto. 
+ It's contents are: + + 'image/height' : _int64_feature(height), + 'image/width' : _int64_feature(width), + 'image/x' : _float_feature(label[0,:].astype(np.float)), + 'image/y' : _float_feature(label[1,:].astype(np.float)), + 'image/visibility' : _int64_feature(label[2,:].astype(np.int)), + 'image/format' : _bytes_feature + 'image/filename' : _bytes_feature + 'image/encoded' : _bytes_feature + 'image/face_points' : _float_feature, + this is the 2D keypoints of the face points in coco 5*3 (x,y,vis) = 15 + + if has_3d is on, it also has: + 'mosh/pose' : float_feature(pose.astype(np.float)), + 'mosh/shape' : float_feature(shape.astype(np.float)), + # gt3d is 14x3 + 'mosh/gt3d' : float_feature(shape.astype(np.float)), + """ + feature_map = { + 'image/encoded': + tf.FixedLenFeature([], dtype=tf.string, default_value=''), + 'image/height': + tf.FixedLenFeature([1], dtype=tf.int64, default_value=-1), + 'image/width': + tf.FixedLenFeature([1], dtype=tf.int64, default_value=-1), + 'image/filename': + tf.FixedLenFeature([], dtype=tf.string, default_value=''), + 'image/center': + tf.FixedLenFeature((2, 1), dtype=tf.int64), + 'image/visibility': + tf.FixedLenFeature((1, 14), dtype=tf.int64), + 'image/x': + tf.FixedLenFeature((1, 14), dtype=tf.float32), + 'image/y': + tf.FixedLenFeature((1, 14), dtype=tf.float32), + 'image/face_pts': + tf.FixedLenFeature( + (1, 15), + dtype=tf.float32, + default_value=[ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0. + ]), + } + if has_3d: + feature_map.update({ + 'mosh/pose': + tf.FixedLenFeature((72, ), dtype=tf.float32), + 'mosh/shape': + tf.FixedLenFeature((10, ), dtype=tf.float32), + 'mosh/gt3d': + tf.FixedLenFeature((14 * 3, ), dtype=tf.float32), + # has_3d is for pose and shape: 0 for mpi_inf_3dhp, 1 for h3.6m. + 'meta/has_3d': + tf.FixedLenFeature((1), dtype=tf.int64, default_value=[0]), + }) + + features = tf.parse_single_example(example_serialized, feature_map) + + height = tf.cast(features['image/height'], dtype=tf.int32) + width = tf.cast(features['image/width'], dtype=tf.int32) + center = tf.cast(features['image/center'], dtype=tf.int32) + fname = tf.cast(features['image/filename'], dtype=tf.string) + fname = tf.Print(fname, [fname], message="image name: ") + + face_pts = tf.reshape( + tf.cast(features['image/face_pts'], dtype=tf.float32), [3, 5]) + + vis = tf.cast(features['image/visibility'], dtype=tf.float32) + x = tf.cast(features['image/x'], dtype=tf.float32) + y = tf.cast(features['image/y'], dtype=tf.float32) + + label = tf.concat([x, y, vis], 0) + label = tf.concat([label, face_pts], 1) + + image = decode_jpeg(features['image/encoded']) + image_size = tf.concat([height, width], 0) + + if has_3d: + pose = tf.cast(features['mosh/pose'], dtype=tf.float32) + shape = tf.cast(features['mosh/shape'], dtype=tf.float32) + gt3d = tf.reshape( + tf.cast(features['mosh/gt3d'], dtype=tf.float32), [14, 3]) + has_smpl3d = tf.cast(features['meta/has_3d'], dtype=tf.bool) + return image, image_size, label, center, fname, pose, shape, gt3d, has_smpl3d + else: + return image, image_size, label, center, fname + + +def rescale_image(image): + """ + Rescales image from [0, 1] to [-1, 1] + Resnet v2 style preprocessing. + """ + # convert to [0, 1]. 
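+    # (The decoded image is in [0, 1]; the two ops below shift and scale it to [-1, 1].)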
+ image = tf.subtract(image, 0.5) + image = tf.multiply(image, 2.0) + return image + + +def get_all_files(dataset_dir, datasets, split='train'): + # Dataset with different name path + diff_name = ['h36m', 'mpi_inf_3dhp'] + + data_dirs = [ + join(dataset_dir, dataset, '%s_*.tfrecord' % split) + for dataset in datasets if dataset not in diff_name + ] + if 'h36m' in datasets: + data_dirs.append( + join(dataset_dir, 'tf_records_human36m_wjoints', split, + '*.tfrecord')) + if 'mpi_inf_3dhp' in datasets: + data_dirs.append( + join(dataset_dir, 'mpi_inf_3dhp', split, '*.tfrecord')) + + all_files = [] + for data_dir in data_dirs: + all_files += sorted(glob(data_dir)) + + return all_files + + +def read_smpl_data(filename_queue): + """ + Parses a smpl Example proto. + It's contents are: + 'pose' : 72-D float + 'shape' : 10-D float + """ + with tf.name_scope(None, 'read_smpl_data', [filename_queue]): + reader = tf.TFRecordReader() + _, example_serialized = reader.read(filename_queue) + + feature_map = { + 'pose': tf.FixedLenFeature((72, ), dtype=tf.float32), + 'shape': tf.FixedLenFeature((10, ), dtype=tf.float32) + } + + features = tf.parse_single_example(example_serialized, feature_map) + pose = tf.cast(features['pose'], dtype=tf.float32) + shape = tf.cast(features['shape'], dtype=tf.float32) + + return pose, shape + + +def decode_jpeg(image_buffer, name=None): + """Decode a JPEG string into one 3-D float image Tensor. + Args: + image_buffer: scalar string Tensor. + name: Optional name for name_scope. + Returns: + 3-D float Tensor with values ranging from [0, 1). + """ + with tf.name_scope(name, 'decode_jpeg', [image_buffer]): + # Decode the string as an RGB JPEG. + # Note that the resulting image contains an unknown height and width + # that is set dynamically by decode_jpeg. In other words, the height + # and width of image is unknown at compile-time. + image = tf.image.decode_jpeg(image_buffer, channels=3) + + # convert to [0, 1]. + image = tf.image.convert_image_dtype(image, dtype=tf.float32) + return image + + +def jitter_center(center, trans_max): + with tf.name_scope(None, 'jitter_center', [center, trans_max]): + rand_trans = tf.random_uniform( + [2, 1], minval=-trans_max, maxval=trans_max, dtype=tf.int32) + return center + rand_trans + + +def jitter_scale(image, image_size, keypoints, center, scale_range): + with tf.name_scope(None, 'jitter_scale', [image, image_size, keypoints]): + scale_factor = tf.random_uniform( + [1], + minval=scale_range[0], + maxval=scale_range[1], + dtype=tf.float32) + new_size = tf.to_int32(tf.to_float(image_size) * scale_factor) + new_image = tf.image.resize_images(image, new_size) + + # This is [height, width] -> [y, x] -> [col, row] + actual_factor = tf.to_float( + tf.shape(new_image)[:2]) / tf.to_float(image_size) + x = keypoints[0, :] * actual_factor[1] + y = keypoints[1, :] * actual_factor[0] + + cx = tf.cast(center[0], actual_factor.dtype) * actual_factor[1] + cy = tf.cast(center[1], actual_factor.dtype) * actual_factor[0] + + return new_image, tf.stack([x, y]), tf.cast( + tf.stack([cx, cy]), tf.int32) + + +def pad_image_edge(image, margin): + """ Pads image in each dimension by margin, in numpy: + image_pad = np.pad(image, + ((margin, margin), + (margin, margin), (0, 0)), mode='edge') + tf doesn't have edge repeat mode,, so doing it with tile + Assumes image has 3 channels!! + """ + + def repeat_col(col, num_repeat): + # col is N x 3, ravels + # i.e. 
to N*3 and repeats, then put it back to num_repeat x N x 3 + with tf.name_scope(None, 'repeat_col', [col, num_repeat]): + return tf.reshape( + tf.tile(tf.reshape(col, [-1]), [num_repeat]), + [num_repeat, -1, 3]) + + with tf.name_scope(None, 'pad_image_edge', [image, margin]): + top = repeat_col(image[0, :, :], margin) + bottom = repeat_col(image[-1, :, :], margin) + + image = tf.concat([top, image, bottom], 0) + # Left requires another permute bc how img[:, 0, :]->(h, 3) + left = tf.transpose(repeat_col(image[:, 0, :], margin), perm=[1, 0, 2]) + right = tf.transpose( + repeat_col(image[:, -1, :], margin), perm=[1, 0, 2]) + image = tf.concat([left, image, right], 1) + + return image + + +def random_flip(image, kp, pose=None, gt3d=None): + """ + mirrors image L/R and kp, also pose if supplied + """ + + uniform_random = tf.random_uniform([], 0, 1.0) + mirror_cond = tf.less(uniform_random, .5) + + if pose is not None: + new_image, new_kp, new_pose, new_gt3d = tf.cond( + mirror_cond, lambda: flip_image(image, kp, pose, gt3d), + lambda: (image, kp, pose, gt3d)) + return new_image, new_kp, new_pose, new_gt3d + else: + new_image, new_kp = tf.cond(mirror_cond, lambda: flip_image(image, kp), + lambda: (image, kp)) + return new_image, new_kp + + +def flip_image(image, kp, pose=None, gt3d=None): + """ + Flipping image and kp. + kp is 3 x N! + pose is 72D + gt3d is 14 x 3 + """ + image = tf.reverse(image, [1]) + new_kp = kp + + new_x = tf.cast(tf.shape(image)[0], dtype=kp.dtype) - kp[0, :] - 1 + new_kp = tf.concat([tf.expand_dims(new_x, 0), kp[1:, :]], 0) + # Swap left and right limbs by gathering them in the right order + # For COCO+ + swap_inds = tf.constant( + [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 16, 15, 18, 17]) + new_kp = tf.transpose(tf.gather(tf.transpose(new_kp), swap_inds)) + + if pose is not None: + new_pose = reflect_pose(pose) + new_gt3d = reflect_joints3d(gt3d) + return image, new_kp, new_pose, new_gt3d + else: + return image, new_kp + + +def reflect_pose(pose): + """ + Input is a 72-Dim vector. + Global rotation (first 3) is left alone. 
+ """ + with tf.name_scope("reflect_pose", [pose]): + """ + # How I got the indices: + right = [11, 8, 5, 2, 14, 17, 19, 21, 23] + left = [10, 7, 4, 1, 13, 16, 18, 20, 22] + new_map = {} + for r_id, l_id in zip(right, left): + for axis in range(0, 3): + rind = r_id * 3 + axis + lind = l_id * 3 + axis + new_map[rind] = lind + new_map[lind] = rind + asis = [id for id in np.arange(0, 24) if id not in right + left] + for a_id in asis: + for axis in range(0, 3): + aind = a_id * 3 + axis + new_map[aind] = aind + swap_inds = np.array([new_map[k] for k in sorted(new_map.keys())]) + """ + swap_inds = tf.constant([ + 0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13, 14, 18, + 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33, 34, 35, 30, 31, 32, + 36, 37, 38, 42, 43, 44, 39, 40, 41, 45, 46, 47, 51, 52, 53, 48, 49, + 50, 57, 58, 59, 54, 55, 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, + 67, 68 + ], tf.int32) + + # sign_flip = np.tile([1, -1, -1], (24)) (with the first 3 kept) + sign_flip = tf.constant( + [ + 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, + -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, + -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, + 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, 1, -1, + -1, 1, -1, -1 + ], + dtype=pose.dtype) + + new_pose = tf.gather(pose, swap_inds) * sign_flip + + return new_pose + + +def reflect_joints3d(joints): + """ + Assumes input is 14 x 3 (the LSP skeleton subset of H3.6M) + """ + swap_inds = tf.constant([5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13]) + with tf.name_scope("reflect_joints3d", [joints]): + joints_ref = tf.gather(joints, swap_inds) + flip_mat = tf.constant([[-1, 0, 0], [0, 1, 0], [0, 0, 1]], tf.float32) + joints_ref = tf.transpose( + tf.matmul(flip_mat, joints_ref, transpose_b=True)) + # Assumes all joints3d are mean subtracted + joints_ref = joints_ref - tf.reduce_mean(joints_ref, axis=0) + return joints_ref diff --git a/src/util/renderer.py b/src/util/renderer.py index d3c7b4aef..60a3e5cba 100644 --- a/src/util/renderer.py +++ b/src/util/renderer.py @@ -30,7 +30,6 @@ def __init__(self, self.h = img_size self.flength = flength - def __call__(self, verts, cam=None, @@ -84,7 +83,7 @@ def __call__(self, def rotated(self, verts, deg, - cam=None, + cam=None, axis='y', img=None, do_alpha=True, @@ -311,7 +310,7 @@ def draw_skeleton(input_image, joints, draw_edges=True, vis=None, radius=None): image = input_image.copy() input_is_float = False - if isinstance(image, np.float): + if np.issubdtype(image.dtype, np.float): input_is_float = True max_val = image.max() if max_val <= 2.: # should be 1 but sometimes it's slightly above 1 @@ -323,7 +322,11 @@ def draw_skeleton(input_image, joints, draw_edges=True, vis=None, radius=None): joints = joints.T joints = np.round(joints).astype(int) - jcolors = ['light_pink', 'light_pink', 'light_pink', 'pink', 'pink', 'pink', 'light_blue', 'light_blue', 'light_blue', 'blue', 'blue', 'blue', 'purple', 'purple', 'red', 'green', 'green', 'white', 'white'] + jcolors = [ + 'light_pink', 'light_pink', 'light_pink', 'pink', 'pink', 'pink', + 'light_blue', 'light_blue', 'light_blue', 'blue', 'blue', 'blue', + 'purple', 'purple', 'red', 'green', 'green', 'white', 'white' + ] if joints.shape[1] == 19: # parent indices -1 means no parents @@ -421,3 +424,30 @@ def draw_skeleton(input_image, joints, draw_edges=True, vis=None, radius=None): image = image.astype(np.float32) return image + + +def draw_text(input_image, content): + """ + content is a dict. 
draws key: val on image + Assumes key is str, val is float + """ + import numpy as np + import cv2 + image = input_image.copy() + input_is_float = False + if np.issubdtype(image.dtype, np.float): + input_is_float = True + image = (image * 255).astype(np.uint8) + + black = np.array([0, 0, 0]) + margin = 15 + start_x = 5 + start_y = margin + for key in sorted(content.keys()): + text = "%s: %.2g" % (key, content[key]) + cv2.putText(image, text, (start_x, start_y), 0, 0.45, black) + start_y += margin + + if input_is_float: + image = image.astype(np.float32) / 255. + return image
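+
+# A minimal usage sketch with hypothetical values, mirroring how
+# HMRTrainer.visualize_img calls this helper on the rendered overlay:
+#   overlay = draw_text(rend_img, {'sc': 0.9, 'tx': 0.01, 'ty': -0.02, 'kpl': 0.15})
+# Each "key: value" pair is printed on its own line starting from the top-left corner.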