Skip to content

Commit

Permalink
readme updated
Browse files Browse the repository at this point in the history
  • Loading branch information
cpatel committed Jun 17, 2020
1 parent ae52d44 commit df83e43
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 29 deletions.
62 changes: 47 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
[Under Development]

# TailorNet Training and Models
This repository contains training code for "TailorNet: Predicting Clothing in 3D as a Function of Human Pose, Shape and Garment Style" (CVPR 2020 Oral)

# TailorNet
This repository contains training and inference code for the following paper:
```
TailorNet: Predicting Clothing in 3D as a Function of Human Pose, Shape and Garment Style
Chaitanya Patel*, Zhouyingcheng Liao*, Gerard Pons-Moll
CVPR 2020 (ORAL)
```
[[arxiv](https://arxiv.org/abs/2003.04583)]
[[project website](https://virtualhumans.mpi-inf.mpg.de/tailornet/)]
[[Dataset Repo](https://github.com/zycliao/TailorNet_dataset)]
Expand All @@ -17,25 +19,55 @@ scipy

## How to Run
- Download and prepare SMPL model and TailorNet data from [here](https://github.com/zycliao/TailorNet_dataset).
- Set DATA_DIR variable in `global_var.py` file.
- Download trained models from here. [Coming Soon]
- Set DATA_DIR and SMPL paths in `global_var.py` file accordingly.
- Download trained models' weights. Unzip it and set paths of downloaded models to `LF_MODEL_PATH`, `HF_MODEL_PATH` and
`SS2G_MODEL_PATH` variables in `global_var.py`.
- [old-t-shirt_female_weights](https://datasets.d2.mpi-inf.mpg.de/tailornet/old-t-shirt_female_weights.zip)
- [Other garments coming soon]
- Set output path in `run_tailornet.py` and run it to predict garments on some random inputs. You can play with
different inputs. You can also run inference on motion sequence data.
- To visualize predicted garment using blender, run `python run_tailornet.py render`. (blender needs to be installed.)
- To visualize predicted garment using blender, run `python run_tailornet.py render`. (Blender needs to be installed.)


#### If you download trained model weights for multiple garments...
... then you can merge downloaded weights directories to follow a directory structure similar to the following.
```
weights_folder
----tn_orig_baseline
--------{garment_class}_{gender} (e.g. t-shirt_female)
------------lin.pth.tar (model weights)
------------params.json (some model params)
----tn_orig_lf
--------{garment_class}_{gender}
------------lin.pth.tar
------------params.json
----tn_orig_ss2g
--------{garment_class}_{gender}
------------lin.pth.tar
------------params.json
----tn_orig_hf
--------{garment_class}_{gender}
------------{shape_idx}_{style_idx} (e.g. 000_023 pivot)
----------------lin.pth.tar
----------------params.json
```
Then you won't need to change model checkpoint paths for multiple garments.


## Training TailorNet yourself
- Set appropriate global variables in `global_var.py`, especially LOG_DIR where training logs will be stored.
- Set appropriate config variables in `trainer/base_trainer.py` and run `python trainer/base_trainer.py` to train
simple MLP baseline.
- Set global variables in `global_var.py`, especially LOG_DIR where training logs will be stored.
- Set config variables in `trainer/base_trainer.py` and run `python trainer/base_trainer.py` to train
TailorNet MLP baseline.
- Similarly, run `trainer/lf_trainer.py` for training low frequency predictor and `trainer/ss2g_trainer.py` for
shape-style-to-garment(in canonical pose) model.
- Run `python trainer/hf_trainer.py --shape_style <shape1>_<style1> <shape2>_<style2>` to train pivot high frequency
predictors for pivots `<shape1>_<style1>` and `<shape2>_<style2>`. See `DATA_DIR/<garment_class>_<gender>/pivots.txt`
to know available pivots.
- Run `python trainer/hf_trainer.py --shape_style <shape1>_<style1> <shape2>_<style2> ...` to train pivot high
frequency predictors for pivots `<shape1>_<style1>`, `<shape2>_<style2>`, and so on. See
`DATA_DIR/<garment_class>_<gender>/pivots.txt` to know available pivots.
- Use `models.tailornet_model.TailorNetModel` to do prediction.

## Citation
Cite us:
Cite us if you use our model, code or data:
```
@inproceedings{patel20tailornet,
title = {TailorNet: Predicting Clothing in 3D as a Function of Human Pose, Shape and Garment Style},
Expand Down
19 changes: 14 additions & 5 deletions global_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,23 @@
# Dataset root directory. Change it to point to downloaded data root directory.
DATA_DIR = '/BS/cloth-anim/static00/tailor_data'

# Set the paths to SMPL model
SMPL_PATH_NEUTRAL = '/BS/RVH/work/data/smpl_models/neutral/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl'
SMPL_PATH_MALE = '/BS/RVH/work/data/smpl_models/lrotmin/lbs_tj10smooth6_0fixed_normalized/male/model.pkl'
SMPL_PATH_FEMALE = '/BS/RVH/work/data/smpl_models/lrotmin/lbs_tj10smooth6_0fixed_normalized/female/model.pkl'

# Log directory where training logs, checkpoints and visualizations will be stored
LOG_DIR = '/BS/cpatel/work/data/learn_anim'

# TailorNet trained models path
LF_MODEL_PATH = "/BS/cpatel/work/data/learn_anim/tn_orig_lf"
HF_MODEL_PATH = "/BS/cpatel/work/data/learn_anim/tn_orig_hf"
SS2G_MODEL_PATH = "/BS/cpatel/work/data/learn_anim/tn_orig_ss2g"

# --------------------------------------------------------------------
# Variable below hardly need to be changed
# --------------------------------------------------------------------

# Available genders
GENDERS = ['neutral', 'male', 'female']

Expand All @@ -21,11 +35,6 @@
# template which make <garment_class> garment
GAR_INFO_FILE = 'garment_class_info.pkl'

# Set the paths to SMPL model
SMPL_PATH_NEUTRAL = '/BS/RVH/work/data/smpl_models/neutral/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl'
SMPL_PATH_MALE = '/BS/RVH/work/data/smpl_models/lrotmin/lbs_tj10smooth6_0fixed_normalized/male/model.pkl'
SMPL_PATH_FEMALE = '/BS/RVH/work/data/smpl_models/lrotmin/lbs_tj10smooth6_0fixed_normalized/female/model.pkl'

# # Skirt template path
# SKIRT_TEMPLATE = "/BS/cpatel/work/data/garment/Skirt/smooth_Skirt.ply"

Expand Down
18 changes: 9 additions & 9 deletions models/tailornet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, lf_logdir, hf_logdir, ss2g_logdir, garment_class, gender):
self.train_pivots = pivots_ds.ss_list

self.hf_runners = [
hf_runner("{}/{}_{}/{}_{}".format(hf_logdir, garment_class, gender, shape_idx, style_idx))
hf_runner("{}/{}_{}".format(hf_logdir, shape_idx, style_idx))
for shape_idx, style_idx in self.train_pivots
]
self.lf_runner = lf_runner(lf_logdir)
Expand Down Expand Up @@ -92,13 +92,13 @@ def interp4(self, thetas, betas, gammas, pred_disp_pivot, sigma=0.5):

def get_best_runner(garment_class='t-shirt', gender='female', lf_logdir=None, hf_logdir=None, ss2g_logdir=None):
"""Helper function to get TailorNet runner."""
prefix = "tn_orig"
if lf_logdir is None:
lf_logdir = "/BS/cpatel/work/data/learn_anim/{}_lf/{}_{}/".format(prefix, garment_class, gender)
if hf_logdir is None:
hf_logdir = "/BS/cpatel/work/data/learn_anim/{}_hf".format(prefix)
if ss2g_logdir is None:
ss2g_logdir = "/BS/cpatel/work/data/learn_anim/{}_ss2g/{}_{}".format(prefix, garment_class, gender)
lf_logdir = global_var.LF_MODEL_PATH if lf_logdir is None else lf_logdir
hf_logdir = global_var.HF_MODEL_PATH if hf_logdir is None else hf_logdir
ss2g_logdir = global_var.SS2G_MODEL_PATH if ss2g_logdir is None else ss2g_logdir

lf_logdir = os.path.join(lf_logdir, "{}_{}".format(garment_class, gender))
hf_logdir = os.path.join(hf_logdir, "{}_{}".format(garment_class, gender))
ss2g_logdir = os.path.join(ss2g_logdir, "{}_{}".format(garment_class, gender))
runner = TailorNetModel(lf_logdir, hf_logdir, ss2g_logdir, garment_class, gender)
return runner

Expand All @@ -113,7 +113,7 @@ def evaluate():
gender = 'female'
garment_class = 'old-t-shirt'

dataset = MultiStyleShape(garment_class=garment_class, gender=gender, split='train_test')
dataset = MultiStyleShape(garment_class=garment_class, gender=gender, split='test')
dataloader = DataLoader(dataset, batch_size=32, num_workers=0, shuffle=False, drop_last=False)
print(len(dataset))

Expand Down

0 comments on commit df83e43

Please sign in to comment.