Commit

v0.1 release to public

bmild authored and yenchenlin committed Apr 5, 2020
0 parents commit bcb6700
Showing 35 changed files with 2,698 additions and 0 deletions.
9 changes: 9 additions & 0 deletions .gitignore
@@ -0,0 +1,9 @@
**/.ipynb_checkpoints
**/__pycache__
*.png
*.mp4
*.npy
*.npz
*.dae
data/*
logs/*
Empty file added .gitmodules
Empty file.
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 bmild

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
164 changes: 164 additions & 0 deletions README.md
@@ -0,0 +1,164 @@
# NeRF-pytorch


[NeRF](http://www.matthewtancik.com/nerf) is a method that achieves state-of-the-art results for synthesizing novel views of complex scenes. Here are some videos generated by this repository (pre-trained models are provided below):

![](https://user-images.githubusercontent.com/7057863/78472232-cf374a00-7769-11ea-8871-0bc710951839.gif)
![](https://user-images.githubusercontent.com/7057863/78472235-d1010d80-7769-11ea-9be9-51365180e063.gif)

This project is a faithful PyTorch implementation of [NeRF](http://www.matthewtancik.com/nerf) that **reproduces** the results while running **1.3 times faster**. The code is tested to numerically match the authors' TensorFlow implementation [here](https://github.com/bmild/nerf).

## Installation

```
git clone https://github.com/yenchenlin/nerf-pytorch.git
cd nerf-pytorch
pip install -r requirements.txt
cd torchsearchsorted
pip install .
cd ../
```

<details>
<summary> Dependencies (click to expand) </summary>

## Dependencies
- PyTorch 1.4
- matplotlib
- numpy
- imageio
- imageio-ffmpeg
- configargparse

The LLFF data loader requires ImageMagick.

You will also need the [LLFF code](http://github.com/fyusion/llff) (and COLMAP) set up to compute poses if you want to run on your own real data.

</details>

## How To Run?

### Quick Start

Download data for two example datasets: `lego` and `fern`
```
bash download_example_data.sh
```

To train a low-res `lego` NeRF:
```
python run_nerf_torch.py --config configs/config_lego.txt
```
After training for 100k iterations (~4 hours on a single 2080 Ti), you can find the following video at `logs/lego_test/lego_test_spiral_100000_rgb.mp4`.

![](https://user-images.githubusercontent.com/7057863/78473103-9353b300-7770-11ea-98ed-6ba2d877b62c.gif)

---

To train a low-res `fern` NeRF:
```
python run_nerf_torch.py --config configs/config_fern.txt
```
After training for 200k iterations (~8 hours on a single 2080 Ti), you can find the following videos at `logs/fern_test/fern_test_spiral_200000_rgb.mp4` and `logs/fern_test/fern_test_spiral_200000_disp.mp4`.

![](https://user-images.githubusercontent.com/7057863/78473081-58ea1600-7770-11ea-92ce-2bbf6a3f9add.gif)

---

### More Datasets
To play with other scenes presented in the paper, download the data [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). Place the downloaded dataset according to the following directory structure:
```
├── configs
│   ├── ...
├── data
│   ├── nerf_llff_data
│   │   ├── fern
│   │   ├── flower   # downloaded llff dataset
│   │   ├── horns    # downloaded llff dataset
│   │   └── ...
│   ├── nerf_synthetic
│   │   ├── lego
│   │   ├── ship     # downloaded synthetic dataset
│   │   └── ...
```

---

To train NeRF on different datasets:

```
python run_nerf_torch.py --config configs/config_{DATASET}.txt
```

Replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.

---

To test NeRF trained on different datasets:

```
python run_nerf_torch.py --config configs/config_{DATASET}.txt --render_only
```

Replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.


### Pre-trained Models

You can download the pre-trained models [here](https://drive.google.com/drive/folders/1jIr8dkvefrQmv737fFm2isiT6tqpbTbv?usp=sharing). Place the downloaded directory in `./logs` in order to test it later. See the following directory structure for an example:

```
├── logs
│   ├── fern_test
│   ├── flower_test   # downloaded logs
│   ├── trex_test     # downloaded logs
```

### Reproducibility

Tests that ensure the results of all functions and the training loop match the official implementation live in a separate branch, `reproduce`. Check it out and run the tests:
```
git checkout reproduce
py.test
```

## Method

[NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://tancik.com/nerf)
[Ben Mildenhall](https://people.eecs.berkeley.edu/~bmild/)\*<sup>1</sup>,
[Pratul P. Srinivasan](https://people.eecs.berkeley.edu/~pratul/)\*<sup>1</sup>,
[Matthew Tancik](http://tancik.com/)\*<sup>1</sup>,
[Jonathan T. Barron](http://jonbarron.info/)<sup>2</sup>,
[Ravi Ramamoorthi](http://cseweb.ucsd.edu/~ravir/)<sup>3</sup>,
[Ren Ng](https://www2.eecs.berkeley.edu/Faculty/Homepages/yirenng.html)<sup>1</sup> <br>
<sup>1</sup>UC Berkeley, <sup>2</sup>Google Research, <sup>3</sup>UC San Diego
\*denotes equal contribution

<img src='imgs/pipeline.jpg'/>

> A neural radiance field is a simple fully connected network (weights are ~5MB) trained to reproduce input views of a single scene using a rendering loss. The network directly maps from spatial location and viewing direction (5D input) to color and opacity (4D output), acting as the "volume" so we can use volume rendering to differentiably render new views.
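
To make that description concrete, here is a deliberately minimal PyTorch sketch of the idea: a plain MLP that maps a 5D input to color and density, plus the standard volume-rendering quadrature that composites samples along a ray. It is an illustration only, not the network this repository trains (the real model adds positional encoding, skip connections, and a separate view-direction branch):

```
import torch
import torch.nn as nn

class TinyRadianceField(nn.Module):
    """Illustrative 5D -> 4D mapping: (x, y, z, theta, phi) -> (r, g, b, sigma)."""
    def __init__(self, hidden=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(5, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 4),
        )

    def forward(self, x):
        raw = self.mlp(x)
        rgb = torch.sigmoid(raw[..., :3])  # color in [0, 1]
        sigma = torch.relu(raw[..., 3])    # non-negative volume density
        return rgb, sigma

def composite(rgb, sigma, z_vals):
    """Volume-rendering quadrature: accumulate colors along each ray.

    rgb: [n_rays, n_samples, 3], sigma: [n_rays, n_samples],
    z_vals: [n_rays, n_samples] sample depths along each ray.
    """
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], -1)
    alpha = 1. - torch.exp(-sigma * dists)            # opacity per sample
    trans = torch.cumprod(torch.cat(
        [torch.ones_like(alpha[..., :1]), 1. - alpha + 1e-10], -1), -1)[..., :-1]
    weights = alpha * trans                           # contribution per sample
    return (weights[..., None] * rgb).sum(dim=-2)     # [n_rays, 3]
```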

## Citation
Kudos to the authors for their amazing results:
```
@misc{mildenhall2020nerf,
    title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
    author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
    year={2020},
    eprint={2003.08934},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```

However, if you find this implementation or the pre-trained models helpful, please consider citing:
```
@misc{lin2020nerfpytorch,
    title={NeRF-pytorch},
    author={Yen-Chen, Lin},
    howpublished={\url{https://github.com/yenchenlin/nerf-pytorch/}},
    year={2020}
}
```
15 changes: 15 additions & 0 deletions configs/config_fern.txt
@@ -0,0 +1,15 @@
expname = fern_test
basedir = ./logs
datadir = ./data/nerf_llff_data/fern
dataset_type = llff

factor = 8
llffhold = 8

N_rand = 1024
N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1e0

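These config files are plain `key = value` text read with configargparse, so every key doubles as a command-line flag. Below is a minimal sketch of that mechanism; the flag names mirror the keys above, while the actual parser in `run_nerf_torch.py` defines many more options:

```
import configargparse

# Sketch only: a subset of the options run_nerf_torch.py defines.
parser = configargparse.ArgumentParser()
parser.add_argument('--config', is_config_file=True, help='config file path')
parser.add_argument('--expname', type=str)
parser.add_argument('--basedir', type=str, default='./logs')
parser.add_argument('--datadir', type=str)
parser.add_argument('--dataset_type', type=str, default='llff')
parser.add_argument('--factor', type=int, default=8)
parser.add_argument('--llffhold', type=int, default=8)
parser.add_argument('--N_rand', type=int, default=1024)
parser.add_argument('--N_samples', type=int, default=64)
parser.add_argument('--N_importance', type=int, default=0)
parser.add_argument('--use_viewdirs', action='store_true')
parser.add_argument('--raw_noise_std', type=float, default=0.)

args = parser.parse_args(['--config', 'configs/config_fern.txt'])
print(args.expname, args.N_rand)  # fern_test 1024
```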
15 changes: 15 additions & 0 deletions configs/config_flower.txt
@@ -0,0 +1,15 @@
expname = flower_test
basedir = ./logs
datadir = ./data/nerf_llff_data/flower
dataset_type = llff

factor = 8
llffhold = 8

N_rand = 1024
N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1e0

15 changes: 15 additions & 0 deletions configs/config_fortress.txt
@@ -0,0 +1,15 @@
expname = fortress_test
basedir = ./logs
datadir = ./data/nerf_llff_data/fortress
dataset_type = llff

factor = 8
llffhold = 8

N_rand = 1024
N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1e0

15 changes: 15 additions & 0 deletions configs/config_horns.txt
@@ -0,0 +1,15 @@
expname = horns_test
basedir = ./logs
datadir = ./data/nerf_llff_data/horns
dataset_type = llff

factor = 8
llffhold = 8

N_rand = 1024
N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1e0

15 changes: 15 additions & 0 deletions configs/config_lego.txt
@@ -0,0 +1,15 @@
expname = lego_test
basedir = ./logs
datadir = ./data/nerf_synthetic/lego
dataset_type = blender

half_res = True

N_samples = 64
N_importance = 64

use_viewdirs = True

white_bkgd = True

N_rand = 1024
15 changes: 15 additions & 0 deletions configs/config_trex.txt
@@ -0,0 +1,15 @@
expname = trex_test
basedir = ./logs
datadir = ./data/nerf_llff_data/trex
dataset_type = llff

factor = 8
llffhold = 8

N_rand = 1024
N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1e0

6 changes: 6 additions & 0 deletions download_example_data.sh
@@ -0,0 +1,6 @@
wget https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz
mkdir -p data
cd data
wget https://people.eecs.berkeley.edu/~bmild/nerf/nerf_example_data.zip
unzip nerf_example_data.zip
cd ..
Binary file added imgs/pipeline.jpg
91 changes: 91 additions & 0 deletions load_blender.py
@@ -0,0 +1,91 @@
import os
import torch
import numpy as np
import imageio
import json
import torch.nn.functional as F
import cv2


trans_t = lambda t : torch.Tensor([
    [1,0,0,0],
    [0,1,0,0],
    [0,0,1,t],
    [0,0,0,1]]).float()

rot_phi = lambda phi : torch.Tensor([
    [1,0,0,0],
    [0,np.cos(phi),-np.sin(phi),0],
    [0,np.sin(phi), np.cos(phi),0],
    [0,0,0,1]]).float()

rot_theta = lambda th : torch.Tensor([
    [np.cos(th),0,-np.sin(th),0],
    [0,1,0,0],
    [np.sin(th),0, np.cos(th),0],
    [0,0,0,1]]).float()


def pose_spherical(theta, phi, radius):
    # Build a camera-to-world pose on a sphere around the origin:
    # push the camera out by `radius`, tilt it by `phi`, orbit by `theta`,
    # then permute axes into the renderer's coordinate convention.
    c2w = trans_t(radius)
    c2w = rot_phi(phi/180.*np.pi) @ c2w
    c2w = rot_theta(theta/180.*np.pi) @ c2w
    c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
    return c2w


def load_blender_data(basedir, half_res=False, testskip=1):
    splits = ['train', 'val', 'test']
    metas = {}
    for s in splits:
        with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
            metas[s] = json.load(fp)

    all_imgs = []
    all_poses = []
    counts = [0]
    for s in splits:
        meta = metas[s]
        imgs = []
        poses = []
        # Use every frame for training; subsample val/test frames by `testskip`.
        if s=='train' or testskip==0:
            skip = 1
        else:
            skip = testskip

        for frame in meta['frames'][::skip]:
            fname = os.path.join(basedir, frame['file_path'] + '.png')
            imgs.append(imageio.imread(fname))
            poses.append(np.array(frame['transform_matrix']))
        imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
        poses = np.array(poses).astype(np.float32)
        counts.append(counts[-1] + imgs.shape[0])
        all_imgs.append(imgs)
        all_poses.append(poses)

    i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]

    imgs = np.concatenate(all_imgs, 0)
    poses = np.concatenate(all_poses, 0)

    H, W = imgs[0].shape[:2]
    # camera_angle_x is the same for every split, so the last `meta` is fine.
    camera_angle_x = float(meta['camera_angle_x'])
    focal = .5 * W / np.tan(.5 * camera_angle_x)

    # 40 camera poses on a circle for rendering a novel-view video.
    render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)

    if half_res:
        H = H//2
        W = W//2
        focal = focal/2.

        imgs_half_res = np.zeros((imgs.shape[0], H, W, 4))
        for i, img in enumerate(imgs):
            # cv2.resize expects dsize as (width, height).
            imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
        imgs = imgs_half_res

    return imgs, poses, render_poses, [H, W, focal], i_split


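A quick sanity check for the loader (a usage sketch, assuming the example data has been downloaded to `./data/nerf_synthetic/lego`):

```
from load_blender import load_blender_data

imgs, poses, render_poses, hwf, i_split = load_blender_data(
    './data/nerf_synthetic/lego', half_res=True, testskip=8)
i_train, i_val, i_test = i_split
H, W, focal = hwf
print(imgs.shape, poses.shape)    # e.g. (N, 400, 400, 4) and (N, 4, 4)
print(len(i_train), len(i_val), len(i_test))
print(render_poses.shape)         # torch.Size([40, 4, 4])
```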