Skip to content

Commit

Permalink
update tips about training postnet
Browse files Browse the repository at this point in the history
  • Loading branch information
yerfor committed Feb 6, 2023
1 parent c694897 commit 2c4c146
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 17 deletions.
Binary file added assets/tips_to_select_postnet_ckpt.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/prepare_env/install_guide_lrs3.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[中文文档](docs/prepare_env/zh/install_guide_lrs3-zh.md)
[中文文档](./zh/install_guide_lrs3-zh.md)

This guide is about building a python env to process the LRS3-TED dataset.

Expand Down
2 changes: 1 addition & 1 deletion docs/prepare_env/install_guide_nerf.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[中文文档](docs/prepare_env/zh/install_guide_nerf-zh.md)
[中文文档](./zh/install_guide_nerf-zh.md)

This guide is about building a python environment, which is necessary to process the dataset for NeRF and train the GeneFace.

Expand Down
2 changes: 1 addition & 1 deletion docs/process_data/process_lrs3.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[中文文档](docs/process_data/zh/process_lrs3-zh.md)
[中文文档](./zh/process_lrs3-zh.md)

# Process the Target Person Video

Expand Down
2 changes: 1 addition & 1 deletion docs/process_data/process_target_person_video.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[中文文档](docs/process_data/zh/process_target_person_video-zh.md)
[中文文档](./zh/process_target_person_video-zh.md)

# Process the Target Person Video

Expand Down
45 changes: 35 additions & 10 deletions docs/train_models/train_models-zh.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# 训练 GeneFace!
GeneFace 包含三个模块:1)一个训练于LRS3数据集并通用于所有说话人的`语音转动作`模块;2)一个适用于特定说话人的`动作后处理`网络,它被训练于LRS3数据集和对应说话人的视频数据;3)一个适用于特定说话人的`基于NeRF的渲染器`,它被训练于对应说话人的视频数据。

要训练GeneFace,请首先按照我们在`docs/prepare_env`文档和`docs/process_data`文档中的步骤,分别完成搭建环境和准备数据集。
GeneFace 包含三个模块:1)一个训练于LRS3数据集并通用于所有说话人的 `语音转动作`模块;2)一个适用于特定说话人的 `动作后处理`网络,它被训练于LRS3数据集和对应说话人的视频数据;3)一个适用于特定说话人的 `基于NeRF的渲染器`,它被训练于对应说话人的视频数据。

要训练GeneFace,请首先按照我们在 `docs/prepare_env`文档和 `docs/process_data`文档中的步骤,分别完成搭建环境和准备数据集。

[这个链接](https://github.com/yerfor/GeneFace/releases/tag/v1.0.0)中,我们还准备了GeneFace的预训练模型,其中:

* `lrs3.zip` 包含了在LRS3数据集上训练的模型 (包括一个`lm3d_vae`模型以实现语音转动作的变换,和一个`syncnet`以实现对语音-嘴形对齐程度的衡量),这些模型是通用于所有说话人视频的。
* `May.zip` 包含了我们在`May.mp4`视频上训练的所有模型(包括一个`postnet`以对`lm3d_vae`产生的3D landmark进行后处理,以及一个`lm3d_nerf``lm3d_nerf_torso`分别渲染说话人的头部和躯干部位。)对每个说话人视频,你都需要新训练这三个模型。
* `lrs3.zip` 包含了在LRS3数据集上训练的模型 (包括一个 `lm3d_vae`模型以实现语音转动作的变换,和一个 `syncnet`以实现对语音-嘴形对齐程度的衡量),这些模型是通用于所有说话人视频的。
* `May.zip` 包含了我们在 `May.mp4`视频上训练的所有模型(包括一个 `postnet`以对 `lm3d_vae`产生的3D landmark进行后处理,以及一个 `lm3d_nerf` `lm3d_nerf_torso`分别渲染说话人的头部和躯干部位。)对每个说话人视频,你都需要新训练这三个模型。

## 步骤1. 训练SyncNet模型
注意:我们在[这个链接](https://github.com/yerfor/GeneFace/releases/tag/v1.0.0)`lrs3.zip`文件中提供了预训练好的SyncNet,你可以将其下载并提取出其中的`syncnet`文件夹,并将它放到`checkpoints/lrs3/syncnet`路径中。

注意:我们在[这个链接](https://github.com/yerfor/GeneFace/releases/tag/v1.0.0)`lrs3.zip`文件中提供了预训练好的SyncNet,你可以将其下载并提取出其中的 `syncnet`文件夹,并将它放到 `checkpoints/lrs3/syncnet`路径中。

如果你想要从头训练SyncNet,请执行以下命令行(你需要首先准备好LRS3数据集):

Expand All @@ -23,7 +25,7 @@ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/lrs3/lm3d_syncn

## 步骤2. 训练Audio2Motion模型

注意:我们在[这个链接](https://github.com/yerfor/GeneFace/releases/tag/v1.0.0)`lrs3.zip`文件中提供了预训练好的audio2motion模型,你可以将其下载并提取出其中的`lm3d_vae`文件夹,并将它放到`checkpoints/lrs3/lm3d_vae`路径中。
注意:我们在[这个链接](https://github.com/yerfor/GeneFace/releases/tag/v1.0.0) `lrs3.zip`文件中提供了预训练好的audio2motion模型,你可以将其下载并提取出其中的 `lm3d_vae`文件夹,并将它放到 `checkpoints/lrs3/lm3d_vae`路径中。

如果你想要从头训练audio2motion模型,请执行以下命令行(你需要首先准备好LRS3数据集):

Expand All @@ -33,12 +35,11 @@ export PYTHONPATH=./
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/lrs3/lm3d_vae_sync.yaml --exp_name=checkpoints/lrs3/lm3d_vae
```

注意名为`lm3d_vae`的audio2motion模型适用于所有说话人视频,所以你只需要训练它一次!
注意名为 `lm3d_vae`的audio2motion模型适用于所有说话人视频,所以你只需要训练它一次!

## 步骤3. 训练PostNet模型


注意:我们在[这个链接](https://github.com/yerfor/GeneFace/releases/tag/v1.0.0)`May.zip`文件中提供了专用于`data/raw/videos/May.mp4`视频的预训练好的Postnet模型,你可以将其下载并提取出其中的`postnet`文件夹,并将它放到`checkpoints/May/postnet`路径中。
注意:我们在[这个链接](https://github.com/yerfor/GeneFace/releases/tag/v1.0.0)`May.zip`文件中提供了专用于 `data/raw/videos/May.mp4`视频的预训练好的Postnet模型,你可以将其下载并提取出其中的 `postnet`文件夹,并将它放到 `checkpoints/May/postnet`路径中。

如果你想要从头训练postnet模型,请执行以下命令行(你需要首先准备好LRS3数据集和对应的说话人视频数据集):

Expand All @@ -50,9 +51,33 @@ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/May/lm3d

注意postnet模型仅适用于对应的说话人视频,所以对每个新的说话人视频你都需要训练一个新的postnet。

#### 训练小tips:选择合适步数的checkpoint

由于我们的postnet的训练属于对抗域适应(Adversarial Domain Adaptation)过程,而对抗训练的训练过程被广泛公认是不稳定的。例如当训练步数过多时,可能导致模型出现模式坍塌,比如postnet可能会学到将输入的任意表情都映射到同一个target person domain的表情(体现在validation sync/mse loss上升)。因此为了避免最终得到的人脸表情的lip-sync性能下降过大,我们应该early stop,即选择步数较小的checkpoint。但同时,当步数过小的时候,postnet可能还欠拟合,无法保证能够将各种各样的表情成功地映射到target person domain(体现在adversarial loss未收敛)。

因此,在实际操作中,我们一般根据三个原则来选择合适步数的checkpoint:(1)validation sync/mse loss越低越好;(2)adversarial loss达到收敛。(3)尽量选择步数较小的checkpoint。

下图我们展示了一个实例,它是训练 `May.mp4`时我们选择合适的postnet checkpoint的过程。我们发现6k步的时候,`val/mse``val/sync`较小,并且 `tr/disc_neg_conf``tr/disc_pos_conf`都约等于0.5(这意味着discriminator已经无法区分正样本和postnet产生的负样本之间的差异),因此我们选择6k步的checkpoint。

<p align="center">
<br>
<img src="../../assets/tips_to_select_postnet_ckpt.png" width="1000"/>
<br>
</p>

最后,为了快速验证选择的postnet checkpoint的lip-sync性能。我们还提供了一个3D landmark的可视化脚本。运行以下脚本(你可能需要修改以下 `.sh``.py`文件内的路径名):

```
conda activate geneface
bash infer_postnet.sh # use the selected postnet checkpoint to predict the 3D landmark sequence.
python utils/visualization/lm_visualizer.py # visualize the 3D landmark sequence.
```

你能在 `./3d_landmark.mp4`路径中看到可视化的3d landmark视频。

## 步骤4. 训练基于NeRF的渲染器

注意:我们在[这个链接](https://github.com/yerfor/GeneFace/releases/tag/v1.0.0)`May.zip`文件中提供了专用于`data/raw/videos/May.mp4`视频的预训练好的NeRF模型,你可以将其下载并提取出其中的`lm3d_nerf``lm3d_nerf_torso`文件夹,并将它放到`checkpoints/May/lm3d_nerf``checkpoints/May/lm3d_nerf_torso`路径中。
注意:我们在[这个链接](https://github.com/yerfor/GeneFace/releases/tag/v1.0.0) `May.zip`文件中提供了专用于 `data/raw/videos/May.mp4`视频的预训练好的NeRF模型,你可以将其下载并提取出其中的 `lm3d_nerf` `lm3d_nerf_torso`文件夹,并将它放到 `checkpoints/May/lm3d_nerf` `checkpoints/May/lm3d_nerf_torso`路径中。

如果你想要从头训练NeRF模型,请执行以下命令行(你需要首先准备好LRS3数据集和对应的说话人视频数据集):

Expand Down
26 changes: 25 additions & 1 deletion docs/train_models/train_models.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[中文文档](docs/train_models/train_models-zh.md)
[中文文档](./train_models-zh.md)

# Train GeneFace!

Expand Down Expand Up @@ -53,6 +53,30 @@ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/May/lm3d

Note that the Post-net is person-specific, so for each target person video, you need to train a new Post-net.

#### tips: choosing the appropriate checkpoint

Since our postnet belongs to **Adversarial** Domain Adaptation, whose training process is widely considered to be unstable. For example, training the model for too many steps may lead to model collapse. For example, when mode collapse occurs, the postnet may map abitrary input landmark into the same landmark in the target person domain (which results in rises in validation sync/mse loss). Therefore, to avoid degradation of the lip-sync performance, we should make an early stop, i.e., select a checkpoint trained with a small number of iterations. However, at the same time, if the number of iterations is too small, postnet may be underfitting and cannot successfully map the landmarks into the target person domain (which means the adversarial loss is not converged).

Therefore, in practice, we choose the checkpoint with the appropriate number of iterations according to three principles: (1) validation sync/mse loss should be as low as possible; (2) the adversarial loss should be convergenced. (3) a small number of iterations is desirable.

The following figure shows an example of the process of selecting the appropriate postnet checkpoint when training `May.mp4`. We found that `val/mse` and `val/sync` are relatively low at 6k steps. Besides, `tr/disc_neg_conf` and `tr/disc_pos_conf` are both about 0.5 (which means that the discriminator cannot distinguish between the (GT) positive samples and the (postnet-generated) negative samples), so we choose the checkpoint at 6k steps.

<p align="center">
<br>
<img src="../../assets/tips_to_select_postnet_ckpt.png" width="1000"/>
<br>
</p>

Finally, to quickly verify the lip-sync performance of the selected postnet checkpoint, we also provide a script to visualize the predicted 3D landmark. Run the following script (you may need to modify the path names in the following `.sh` and `.py` files):

```
conda activate geneface
bash infer_postnet.sh # use the selected postnet checkpoint to predict the 3D landmark sequence.
python utils/visualization/lm_visualizer.py # visualize the 3D landmark sequence.
```

You can see the visual 3d landmark video in `./3d_landmark.mp4`.

## Step4. Train the NeRF-based Render

NOTE: We provide the pre-trained NeRF model for the target person video named `data/raw/videos/May.mp4` in `May.zip` at [this link](https://github.com/yerfor/GeneFace/releases/tag/v1.0.0), you can download it and extract the `lm3d_nerf` and `lm3d_nef_torso` folder, then place it into the path `checkpoints/May/lm3d_nerf` and `checkpoints/May/lm3d_nerf_torso`, respectively.
Expand Down
9 changes: 7 additions & 2 deletions utils/visualization/lm_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def render_idexp_npy_to_lm_video(npy_name, out_video_name, audio_name=None):
for i_img in range(len(lm3d)):
lm2d = lm3d[i_img ,:, :2] # [68, 2]
img = np.ones([WH, WH, 3], dtype=np.uint8) * 255

for i in range(len(lm2d)):
x, y = lm2d[i]
if i in eye_idx:
Expand All @@ -35,15 +36,19 @@ def render_idexp_npy_to_lm_video(npy_name, out_video_name, audio_name=None):
color = (255,0,0)
img = cv2.circle(img, center=(x,y), radius=3, color=color, thickness=-1)
font = cv2.FONT_HERSHEY_SIMPLEX
img = cv2.putText(img, f"{i}", org=(x,y), fontFace=font, fontScale=0.3, color=(255,0,0))
img = cv2.flip(img, 0)
for i in range(len(lm2d)):
x, y = lm2d[i]
y = WH - y
img = cv2.putText(img, f"{i}", org=(x,y), fontFace=font, fontScale=0.3, color=(255,0,0))

out_name = os.path.join(tmp_img_dir, f'{format(i_img, "05d")}.png')
cv2.imwrite(out_name, img)
imgs_to_video(tmp_img_dir, out_video_name, audio_name)
os.system(f"rm -r {tmp_img_dir}")

if __name__ == '__main__':
npy_name = f"infer_out/May/pred_lm3d/zozo.npy"
out_path = "./1.mp4"
out_path = "./3d_landmark.mp4"
audio_path = "data/raw/val_wavs/zozo.wav"
render_idexp_npy_to_lm_video(npy_name, out_path, audio_path)

0 comments on commit 2c4c146

Please sign in to comment.