Updates for FVD + support more pre-trained model features
wilson1yan committed Jun 27, 2021
1 parent 621cc8a commit 765a245
Showing 7 changed files with 33 additions and 46 deletions.
19 changes: 14 additions & 5 deletions README.md
@@ -11,15 +11,15 @@ We present VideoGPT: a conceptually simple architecture for scaling likelihood b
## Installation
Change the `cudatoolkit` version compatible to your machine.
```bash
$ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
$ pip install git+https://github.com/wilson1yan/VideoGPT.git
conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
pip install git+https://github.com/wilson1yan/VideoGPT.git
```

### Sparse Attention (Optional)
For limited compute scenarios, it may be beneficial to use [sparse attention](https://arxiv.org/abs/1904.10509).
```bash
$ sudo apt-get install llvm-9-dev
$ DS_BUILD_SPARSE_ATTN=1 pip install deepspeed
sudo apt-get install llvm-9-dev
DS_BUILD_SPARSE_ATTN=1 pip install deepspeed
```
After installing `deepspeed`, you can train a sparse transformer by setting the flag `--attn_type sparse` in `scripts/train_videogpt.py`. The default supported sparsity configuration is an N-d strided sparsity layout; however, you can write your own arbitrary layouts to use (see the sketch below).
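
As a minimal sketch, a sparse-attention training run might look like the following (assuming `deepspeed` was built as above; any dataset and logging flags your setup needs are omitted):

```bash
python scripts/train_videogpt.py --attn_type sparse
```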

@@ -107,6 +107,12 @@ Use the `scripts/train_vqvae.py` script to train a VQ-VAE. Execute `python scrip
* `--resolution 128`: spatial resolution to train on
* `--sequence_length 16`: temporal resolution, or video clip length
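
For instance, a hypothetical VQ-VAE run at the defaults listed above (a sketch; any required data-path flags are omitted):

```bash
python scripts/train_vqvae.py --resolution 128 --sequence_length 16
```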

## Using Pretrained VideoGPTs
There are two available pre-trained VideoGPT models:
* `bair_gpt`: single frame-conditional BAIR model using discrete encodings from `bair_stride4x2x2` VQ-VAE
* `ucf101_uncond_gpt`: unconditional UCF101 model using discrete encodings from `ucf101_stride4x4x4` VQ-VAE
Note that both pre-trained models use sparse attention. Fine-tuning therefore requires sparse attention to be installed; sampling does not (see the sketch below).
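
As a minimal sketch, either model can be loaded by name with `load_videogpt` (weights are fetched via `videogpt/download.py`; the device handling here is illustrative):

```python
import torch
from videogpt import load_videogpt

# Load the frame-conditional BAIR model; pass 'ucf101_uncond_gpt' for the UCF-101 model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpt = load_videogpt('bair_gpt', device=device)
```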

## Training VideoGPT
You can download a pretrained VQ-VAE, or train your own. Afterwards, use the `scripts/train_videogpt.py` script to train a VideoGPT model for sampling. Execute `python scripts/train_videogpt.py -h` for information on all available training settings. A subset of the more relevant settings is listed below, along with default values.
### VideoGPT Specific Settings
@@ -132,7 +138,10 @@ You can download a pretrained VQ-VAE, or train your own. Afterwards, use the `sc
* `--sequence_length 16`: temporal resolution, or video clip length
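
For example, a hedged sketch of a training invocation using the setting above (checkpoint and dataset flags your setup requires are omitted):

```bash
python scripts/train_videogpt.py --sequence_length 16
```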

## Sampling VideoGPT
After training, the VideoGPT model can be sampled using the `scripts/sample_videogpt.py`. You may need to install `ffmpeg`: `sudo apt-get install ffmpeg`
VideoGPT models can be sampled using the `scripts/sample_videogpt.py` script, as shown below. You can specify either a path to a checkpoint saved during training or the name of a pretrained model. You may need to install `ffmpeg`: `sudo apt-get install ffmpeg`
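
For example (per the script's arguments, `--ckpt` takes a checkpoint path or pretrained model name and `--n` the number of samples):

```bash
python scripts/sample_videogpt.py --ckpt ucf101_uncond_gpt --n 8
```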

## Evaluation
Evaluation is done primarily using [Frechet Video Distance (FVD)](https://arxiv.org/abs/1812.01717) for BAIR and Kinetics, and [Inception Score](https://arxiv.org/abs/1606.03498) for UCF-101. Inception Score can be computed by generating samples and using the code from the [TGANv2 repo](https://github.com/pfnet-research/tgan2). FVD can be computed through `python scripts/compute_fvd.py`, which runs a PyTorch port of the [original codebase](https://github.com/google-research/google-research/tree/master/frechet_video_distance).
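
For instance, a sketch of computing FVD for the pretrained BAIR model (flag names are from `scripts/compute_fvd.py`; the trial count is an arbitrary choice):

```bash
python scripts/compute_fvd.py --ckpt bair_gpt --n_trials 4
```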

## Reproducing Paper Results
Note that this repo is primarily designed for simplicity and for building on our method. Reproducing the full paper results can be done using code found at a [separate repo](https://github.com/wilson1yan/VideoGPT-Paper). However, be aware that the code there is not as clean.
7 changes: 4 additions & 3 deletions scripts/compute_fvd.py
@@ -1,14 +1,15 @@
import os
import functools
import argparse
from videogpt.download import load_i3d_pretrained
from tqdm import tqdm
import numpy as np

import torch
import torch.multiprocessing as mp
import torch.distributed as dist

from videogpt.fvd.fvd import get_fvd_logits, frechet_distance, load_fvd_model
from videogpt.fvd.fvd import get_fvd_logits, frechet_distance
from videogpt import VideoData, VideoGPT, load_videogpt


@@ -47,7 +48,7 @@ def main_worker(rank, size, args_in):
loader = VideoData(args).test_dataloader()

#################### Load I3D ########################################
i3d = load_fvd_model(device)
i3d = load_i3d_pretrained(device)

#################### Compute FVD ###############################
fvds = []
@@ -122,7 +123,7 @@ def eval_fvd(i3d, videogpt, loader, device):

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--ckpt', type=str, required=True)
parser.add_argument('--ckpt', type=str, default='bair_gpt')
parser.add_argument('--n_trials', type=int, default=1, help="Number of trials to compute mean/std")
parser.add_argument('--port', type=int, default=23452)
args = parser.parse_args()
2 changes: 1 addition & 1 deletion scripts/sample_videogpt.py
@@ -7,7 +7,7 @@


parser = argparse.ArgumentParser()
parser.add_argument('--ckpt', type=str, required=None)
parser.add_argument('--ckpt', type=str, default='ucf101_uncond_gpt')
parser.add_argument('--n', type=int, default=8)
args = parser.parse_args()
n = args.n
2 changes: 1 addition & 1 deletion videogpt/__init__.py
@@ -2,5 +2,5 @@
from .vqvae import VQVAE
from .gpt import VideoGPT
from .data import VideoData
from .download import load_vqvae, load_videogpt, download
from .download import load_vqvae, load_videogpt, load_i3d_pretrained, download

12 changes: 12 additions & 0 deletions videogpt/download.py
@@ -64,6 +64,7 @@ def load_vqvae(model_name, device=torch.device('cpu')):

_VIDEOGPT = {
'bair_gpt': '1fNTtJAgO6grEtPNrufkpbee1CfGztW-1', # 1-frame conditional, 16 frames of 64 x 64 images
'ucf101_uncond_gpt': '1QkF_Sb2XVRgSbFT_SxQ6aZUeDFoliPQq', # unconditional, 16 frames of 128 x 128 images
}

def load_videogpt(model_name, device=torch.device('cpu')):
@@ -73,3 +74,14 @@ def load_videogpt(model_name, device=torch.device('cpu')):
gpt.eval()

return gpt


_I3D_PRETRAINED_ID = '1mQK8KD8G6UWRa5t87SRMm5PVXtlpneJT'

def load_i3d_pretrained(device=torch.device('cpu')):
from .fvd.pytorch_i3d import InceptionI3d
i3d = InceptionI3d(400, in_channels=3).to(device)
filepath = download(_I3D_PRETRAINED_ID, 'i3d_pretrained_400.pt')
i3d.load_state_dict(torch.load(filepath, map_location=device))
i3d.eval()
return i3d
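
Putting the new pieces together, a hypothetical end-to-end FVD sketch (the uint8 `(batch, time, height, width, channel)` video layout and the argument order of `frechet_distance` are assumptions here; `scripts/compute_fvd.py` is the authoritative usage):

```python
import numpy as np
import torch

from videogpt import load_i3d_pretrained
from videogpt.fvd.fvd import get_fvd_logits, frechet_distance

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
i3d = load_i3d_pretrained(device)  # downloads i3d_pretrained_400.pt on first use

# Random stand-ins for real and generated videos (assumed uint8, BTHWC layout).
real_videos = np.random.randint(0, 256, size=(16, 16, 64, 64, 3), dtype=np.uint8)
fake_videos = np.random.randint(0, 256, size=(16, 16, 64, 64, 3), dtype=np.uint8)

real_embeddings = get_fvd_logits(real_videos, i3d, device)
fake_embeddings = get_fvd_logits(fake_videos, i3d, device)
print(frechet_distance(fake_embeddings, real_embeddings))
```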
17 changes: 0 additions & 17 deletions videogpt/fvd/fvd.py
@@ -1,13 +1,5 @@
import argparse
import numpy as np

import torch
import torch.nn.functional as F
import torch.utils.data as data

from ..data import preprocess as preprocess_single
from .pytorch_i3d import InceptionI3d
import os


def preprocess(videos, target_resolution=224):
@@ -22,15 +14,6 @@ def get_fvd_logits(videos, i3d, device):
embeddings = get_logits(i3d, videos, device)
return embeddings

def load_fvd_model(device):
i3d = InceptionI3d(400, in_channels=3).to(device)
current_dir = os.path.dirname(os.path.abspath(__file__))
i3d_path = os.path.join(current_dir, 'i3d_pretrained_400.pt')
i3d.load_state_dict(torch.load(i3d_path, map_location=device))
i3d.eval()
return i3d


# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161
def _symmetric_matrix_square_root(mat, eps=1e-10):
u, s, v = torch.svd(mat)
20 changes: 1 addition & 19 deletions videogpt/fvd/pytorch_i3d.py
@@ -1,16 +1,9 @@
# https://github.com/piergiaj/pytorch-i3d
# Original code from https://github.com/piergiaj/pytorch-i3d
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np

import os
import sys
from collections import OrderedDict


class MaxPool3dSamePadding(nn.MaxPool3d):

def compute_pad(self, dim, s):
@@ -22,15 +15,12 @@ def compute_pad(self, dim, s):
def forward(self, x):
# compute 'same' padding
(batch, channel, t, h, w) = x.size()
#print t,h,w
out_t = np.ceil(float(t) / float(self.stride[0]))
out_h = np.ceil(float(h) / float(self.stride[1]))
out_w = np.ceil(float(w) / float(self.stride[2]))
#print out_t, out_h, out_w
pad_t = self.compute_pad(0, t)
pad_h = self.compute_pad(1, h)
pad_w = self.compute_pad(2, w)
#print pad_t, pad_h, pad_w

pad_t_f = pad_t // 2
pad_t_b = pad_t - pad_t_f
@@ -40,8 +30,6 @@ def forward(self, x):
pad_w_b = pad_w - pad_w_f

pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
#print x.size()
#print pad
x = F.pad(x, pad)
return super(MaxPool3dSamePadding, self).forward(x)

@@ -90,15 +78,12 @@ def compute_pad(self, dim, s):
def forward(self, x):
# compute 'same' padding
(batch, channel, t, h, w) = x.size()
#print t,h,w
out_t = np.ceil(float(t) / float(self._stride[0]))
out_h = np.ceil(float(h) / float(self._stride[1]))
out_w = np.ceil(float(w) / float(self._stride[2]))
#print out_t, out_h, out_w
pad_t = self.compute_pad(0, t)
pad_h = self.compute_pad(1, h)
pad_w = self.compute_pad(2, w)
#print pad_t, pad_h, pad_w

pad_t_f = pad_t // 2
pad_t_b = pad_t - pad_t_f
@@ -108,10 +93,7 @@ def forward(self, x):
pad_w_b = pad_w - pad_w_f

pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
#print x.size()
#print pad
x = F.pad(x, pad)
#print x.size()

x = self.conv3d(x)
if self._use_batch_norm:
