NeRF (Neural Radiance Fields) is a method that achieves state-of-the-art results for synthesizing novel views of complex scenes.
This project is a NeRF implementation that uses QAT (Quantization Aware Training) to reduce the model size by a factor of four with less than a 4% change in loss. It is based on the PyTorch implementation nerf-pytorch (https://github.com/yenchenlin/nerf-pytorch), with an added quantization module and quantized model saving. Here are some videos generated by this repository:
git clone https://github.com/haujinnn/nerf-qat.git
cd nerf-qat
pip install -r requirements.txt
Download data for two example datasets, lego and fern:
bash download_example_data.sh
To train a QAT NeRF on lego:
python run_nerf_qat.py --config configs/lego_qat.txt
After training, you can find the rendered video at logs/lego_qat/lego_test_spiral_100000_rgb.mp4.
To train a QAT NeRF on fern:
python run_nerf_qat.py --config configs/fern_qat.txt
After training for 50k iterations (~3 hours on a Quadro RTX 6000/8000), you can find the rendered videos at logs/fern_qat/fern_qat_spiral_200000_rgb.mp4 and logs/fern_qat/fern_qat_spiral_50000_disp.mp4.
To test QAT NeRF:
python run_nerf_qat.py --config configs/{DATASET}_qat.txt
Replace {DATASET} with lego, fern, etc.
The QAT process consists of the following steps:
- Add fake quantization modules
- Fuse modules
- Prepare the model for quantization
- Run the training loop
- Convert and save the model
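The five steps above can be sketched end to end with PyTorch's eager-mode quantization API (torch.ao.quantization). This is a minimal illustration on a hypothetical toy model (TinyMLP), not the code in run_nerf_qat.py; the layer sizes, the toy training target, the step count, and the saved_size helper are all assumptions made for the example.

```python
import io

import torch
import torch.nn as nn
import torch.ao.quantization as tq


class TinyMLP(nn.Module):
    """Hypothetical stand-in for NeRF_qat: Linear+ReLU layers wrapped in quant stubs."""

    def __init__(self, in_ch=3, width=256, out_ch=4):
        super().__init__()
        self.quant = tq.QuantStub()
        self.fc1 = nn.Linear(in_ch, width)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(width, width)
        self.relu2 = nn.ReLU()
        self.out = nn.Linear(width, out_ch)
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        return self.dequant(self.out(x))


def saved_size(model):
    """Number of bytes torch.save writes for the model's state_dict (hypothetical helper)."""
    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)
    return buf.tell()


torch.manual_seed(0)
model = TinyMLP()  # modules start in training mode, which prepare_qat requires

# Step 1 (add fake quantization modules) is the QuantStub/DeQuantStub pair above.
# Step 2, module fusion: each Linear+ReLU pair becomes one fused module.
model = tq.fuse_modules_qat(model, [["fc1", "relu1"], ["fc2", "relu2"]])

# Step 3, prepare: attach a QAT qconfig and insert fake-quant/observer modules.
model.qconfig = tq.get_default_qat_qconfig(torch.backends.quantized.engine)
tq.prepare_qat(model, inplace=True)

# Step 4, training loop: ordinary float training; fake quants simulate INT8 rounding.
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for _ in range(50):
    x = torch.rand(256, 3)
    target = torch.cat([x, x.sum(dim=1, keepdim=True)], dim=1)  # toy regression target
    loss = nn.functional.mse_loss(model(x), target)
    opt.zero_grad()
    loss.backward()
    opt.step()

# Step 5, convert and save: swap in INT8 modules, then serialize the state_dict.
model.eval()
int8_model = tq.convert(model)
float_bytes = saved_size(TinyMLP())
int8_bytes = saved_size(int8_model)
print(f"float32: {float_bytes} bytes, int8: {int8_bytes} bytes")
```

Note that eager-mode fusion only works on modules, so ReLUs must be nn.ReLU modules rather than F.relu calls (nerf-pytorch's NeRF uses F.relu); how NeRF_qat handles this is not shown here.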
To add the fake quantization modules, modify class NeRF_qat in run_nerf_helpers.py:
import torch
import torch.nn as nn

# Model
class NeRF_qat(nn.Module):
    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
        ...
        ################
        # QuantStub converts tensors from floating point to quantized (a fake-quant point during QAT)
        self.quant = torch.ao.quantization.QuantStub()
        # DeQuantStub converts tensors from quantized back to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()
        ################

    def forward(self, x):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
        h = input_pts
        ################
        h = self.quant(h)
        ################
        ...
        ################
        outputs = self.dequant(outputs)
        ################
        return outputs
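To see what the two stubs in NeRF_qat do at each stage, here is a minimal sketch on a hypothetical one-layer model (StubDemo) that uses the same QuantStub/DeQuantStub pattern. Before preparation the stubs pass float tensors straight through; after torch.ao.quantization.convert they become real Quantize/DeQuantize modules. The QAT qconfig is taken for whatever quantized engine the local PyTorch build reports, which is an assumption about your setup.

```python
import torch
import torch.nn as nn
import torch.ao.quantization as tq


class StubDemo(nn.Module):
    """Hypothetical one-layer model using the same stub pattern as NeRF_qat."""

    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.fc = nn.Linear(3, 4)
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))


torch.manual_seed(0)
x = torch.rand(8, 3)

model = StubDemo()
before = model.quant(x)  # an unprepared QuantStub simply returns its input (still float32)

model.qconfig = tq.get_default_qat_qconfig(torch.backends.quantized.engine)
tq.prepare_qat(model, inplace=True)
model(x)  # one training-mode pass so the fake-quant observers record an input range

model.eval()
converted = tq.convert(model)
q = converted.quant(x)       # QuantStub became Quantize: float32 -> quint8
back = converted.dequant(q)  # DeQuantStub became DeQuantize: quint8 -> float32
print(before.dtype, type(converted.quant).__name__, q.dtype, back.dtype)
```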
Key mapping is required before saving the model, because quantization changes the module structure and therefore the keys of the state_dict. The key mapping code is currently being modified.
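The key mapping code itself is not shown here. As a hedged illustration only, a generic prefix-renaming helper for a state_dict could look like the sketch below; remap_state_dict_keys and the prefix table are hypothetical, and the correct mapping depends on how NeRF_qat is fused and converted. For instance, in PyTorch eager mode a float model fused with fuse_modules stores the Linear of a fused Linear+ReLU under a .0 sub-key (e.g. fc.0.weight), which is the kind of rename the example undoes.

```python
from collections import OrderedDict


def remap_state_dict_keys(state_dict, prefix_map):
    """Return a copy of state_dict with key prefixes renamed via prefix_map.

    prefix_map is a hypothetical {old_prefix: new_prefix} table; the first
    matching prefix wins and keys with no match are kept unchanged.
    """
    remapped = OrderedDict()
    for key, value in state_dict.items():
        new_key = key
        for old, new in prefix_map.items():
            if key.startswith(old):
                new_key = new + key[len(old):]
                break
        # Refuse to silently overwrite a tensor if two keys map to the same name.
        if new_key in remapped:
            raise KeyError(f"key collision after remapping: {new_key}")
        remapped[new_key] = value
    return remapped


# Values would normally be tensors; strings keep the example dependency-free.
fused = OrderedDict([("fc.0.weight", "W"), ("fc.0.bias", "b"), ("out.weight", "W2")])
print(list(remap_state_dict_keys(fused, {"fc.0.": "fc."})))
# ['fc.weight', 'fc.bias', 'out.weight']
```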
NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis
Ben Mildenhall*¹, Pratul P. Srinivasan*¹, Matthew Tancik*¹, Jonathan T. Barron², Ravi Ramamoorthi³, Ren Ng¹
¹UC Berkeley, ²Google Research, ³UC San Diego
*Denotes equal contribution
A neural radiance field is a simple fully connected network (weights are ~5MB) trained to reproduce input views of a single scene using a rendering loss. The network directly maps from spatial location and viewing direction (5D input) to color and opacity (4D output), acting as the "volume" so we can use volume rendering to differentiably render new views.
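The volume rendering step can be sketched as the standard NeRF quadrature: each sample gets an opacity alpha_i = 1 - exp(-sigma_i * delta_i), a transmittance T_i = prod_{j<i} (1 - alpha_j), and a weight w_i = T_i * alpha_i, and the pixel color is sum_i w_i * c_i. The function below is a simplified sketch in the spirit of nerf-pytorch's raw2outputs; the name composite is hypothetical, and it takes sample distances directly, ignoring ray-direction scaling and density noise.

```python
import torch


def composite(rgb, sigma, t_vals):
    """Alpha-composite samples along rays with the NeRF volume rendering quadrature.

    rgb:    [R, S, 3] colors at S samples per ray (already in [0, 1])
    sigma:  [R, S]    non-negative densities
    t_vals: [R, S]    increasing sample distances along each ray
    """
    deltas = t_vals[:, 1:] - t_vals[:, :-1]
    # Treat the last interval as effectively infinite, as the original NeRF code does.
    deltas = torch.cat([deltas, torch.full_like(deltas[:, :1], 1e10)], dim=-1)
    alpha = 1.0 - torch.exp(-sigma * deltas)  # opacity of each segment
    # Exclusive cumulative product gives T_i = prod_{j<i} (1 - alpha_j); 1e-10 avoids exact zeros.
    ones = torch.ones_like(alpha[:, :1])
    trans = torch.cumprod(torch.cat([ones, 1.0 - alpha + 1e-10], dim=-1), dim=-1)[:, :-1]
    weights = alpha * trans                         # [R, S]
    color = (weights[..., None] * rgb).sum(dim=-2)  # [R, 3]
    return color, weights
```

Because every operation is differentiable, a rendering loss on color back-propagates to the densities and colors predicted by the network.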
Quantization Aware Training for Static Quantization
Diagram:
# original model
# all tensors and computations are in floating point
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                      /
    linear_weight_fp32

# model with fake_quants for modeling quantization numerics during training
previous_layer_fp32 -- fq -- linear_fp32 -- activation_fp32 -- fq -- next_layer_fp32
                            /
   linear_weight_fp32 -- fq

# quantized model
# weights and activations are in int8
previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
                      /
    linear_weight_int8
Quantization Aware Training (QAT) models the effects of quantization during training, which usually gives higher accuracy than other quantization methods such as post-training quantization. QAT can be used for static, dynamic, or weight-only quantization. During training, all calculations are done in floating point, with fake_quant modules modeling the effects of quantization by clamping and rounding to simulate INT8. After model conversion, weights and activations are quantized, and activations are fused into the preceding layer where possible.
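The clamp-and-round behaviour of a fake_quant module can be written out directly. A minimal sketch, assuming per-tensor affine quantization to the full unsigned 8-bit range [0, 255] with a hypothetical scale and zero point; the hand-written fake_quant function is checked against PyTorch's built-in torch.fake_quantize_per_tensor_affine.

```python
import torch


def fake_quant(x, scale, zero_point, qmin=0, qmax=255):
    """Quantize then dequantize x in float, as a per-tensor affine fake-quant does."""
    # Map to the integer grid and clamp to the 8-bit range (values outside saturate).
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    # Map back to float; the result now carries the rounding and clamping error.
    return (q - zero_point) * scale


x = torch.tensor([-1.0, -0.26, 0.0, 0.37, 2.5])
scale, zero_point = 0.01, 128  # hypothetical qparams covering roughly [-1.28, 1.27]

manual = fake_quant(x, scale, zero_point)
builtin = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, 0, 255)
print(manual)
print(torch.allclose(manual, builtin))
```

The output stays float32, which is why training can proceed with ordinary gradients; 2.5 saturates to the top of the range, showing the clamping error the network learns to tolerate.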
@misc{lin2020nerfpytorch,
  title={NeRF-pytorch},
  author={Lin, Yen-Chen},
  publisher={GitHub},
  journal={GitHub repository},
  howpublished={\url{https://github.com/yenchenlin/nerf-pytorch/}},
  year={2020}
}