
NeRF-QAT

NeRF (Neural Radiance Fields) is a method that achieves state-of-the-art results for synthesizing novel views of complex scenes.

This project is a NeRF implementation that uses QAT (Quantization Aware Training) to reduce the model size by roughly 4x, at the cost of less than a 4% increase in loss. The code is based on the PyTorch implementation nerf-pytorch (cited below), with the addition of a quantization module and quantized-model storage.

Installation

git clone https://github.com/haujinnn/nerf-qat.git
cd nerf-qat
pip install -r requirements.txt

How to Run

Download data for two example datasets: lego and fern

bash download_example_data.sh

To train the lego QAT NeRF:

python run_nerf_qat.py --config configs/lego_qat.txt

After training, you can find the following video at logs/lego_qat/lego_test_spiral_100000_rgb.mp4.


To train the fern QAT NeRF:

python run_nerf_qat.py --config configs/fern_qat.txt

After training for 50k iterations (~3 hours on a Quadro RTX 6000/8000), you can find the following videos at logs/fern_qat/fern_qat_spiral_50000_rgb.mp4 and logs/fern_qat/fern_qat_spiral_50000_disp.mp4.


To test QAT NeRF:

python run_nerf_qat.py --config configs/{DATASET}_qat.txt

Replace {DATASET} with lego | fern | etc.

Code

The process of QAT is as follows; a minimal end-to-end sketch appears after the list.

  1. Add fake-quantization modules
  2. Fuse modules
  3. Prepare for quantization
  4. Run the training loop
  5. Convert & save

1. Adding Quantization Modules

Add the fake-quantization stubs to the NeRF_qat class in run_nerf_helpers.py:

# Model
class NeRF_qat(nn.Module):
    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
        
        ...
        
        ################
        self.quant = torch.ao.quantization.QuantStub()         # QuantStub converts tensors from floating point to quantized
        self.dequant = torch.ao.quantization.DeQuantStub()     # DeQuantStub converts tensors from quantized to floating point
        ################

    def forward(self, x):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
        h = input_pts
        
        ################
        h = self.quant(h)
        ################

        ...

        ################
        outputs = self.dequant(outputs)
        ################

        return outputs  

2. Saving Quantized Model

Key mapping is required before saving the model because the parameter structure (and hence the state-dict keys) changes when the model is converted to its quantized form. The key-mapping code is as follows.

# Currently modifying
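
While that code is in progress, here is a minimal sketch of the save/load pattern for a converted model (Tiny, make_converted, and the file name are illustrative, not the repository's API). The key point: load_state_dict only works once the target model has been converted to the same quantized structure, because conversion replaces plain weight/bias keys with quantized ones.

import torch
import torch.nn as nn

# Toy model; stands in for NeRF_qat.
class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.fc = nn.Linear(3, 4)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

def make_converted():
    # Re-run the same prepare/convert steps used during training so the
    # state-dict keys line up with the saved quantized checkpoint.
    m = Tiny().train()
    m.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
    torch.ao.quantization.prepare_qat(m, inplace=True)
    m(torch.rand(2, 3))                  # one dummy pass so observers have statistics
    return torch.ao.quantization.convert(m.eval())

m_int8 = make_converted()
print(list(m_int8.state_dict().keys()))
# e.g. ['quant.scale', 'quant.zero_point', 'fc.scale', 'fc.zero_point',
#       'fc._packed_params.dtype', 'fc._packed_params._packed_params']
torch.save(m_int8.state_dict(), "tiny_int8.pth")

# Loading: rebuild and convert the same architecture first, then load.
m2 = make_converted()
m2.load_state_dict(torch.load("tiny_int8.pth"))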

Method

1. NeRF (Neural Radiance Fields)

NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis
Ben Mildenhall*, Pratul P. Srinivasan*, Matthew Tancik* (UC Berkeley), Jonathan T. Barron (Google Research), Ravi Ramamoorthi (UC San Diego), Ren Ng (UC Berkeley)
*denotes equal contribution

A neural radiance field is a simple fully connected network (weights are ~5 MB) trained to reproduce input views of a single scene using a rendering loss. The network maps directly from spatial location and viewing direction (5D input) to color and opacity (4D output), acting as the "volume", so volume rendering can be used to differentiably render new views.
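
In code terms, that mapping is just an MLP. A schematic sketch of the 5D -> 4D function (not the repository's exact model, which additionally applies positional encoding to its inputs):

import torch
import torch.nn as nn

# Schematic of the 5D -> 4D mapping described above.
class RadianceField(nn.Module):
    def __init__(self, width=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(5, width), nn.ReLU(),   # input: (x, y, z, theta, phi)
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, 4),              # output: (r, g, b, density)
        )

    def forward(self, pos_and_dir):           # (N, 5)
        return self.mlp(pos_and_dir)          # (N, 4)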

2. QAT (Quantization Aware Training)

Quantization Aware Training for Static Quantization

Diagram:

# original model
# all tensors and computations are in floating point
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
                      /
    linear_weight_fp32

# model with fake_quants for modeling quantization numerics during training
previous_layer_fp32 -- fq -- linear_fp32 -- activation_fp32 -- fq -- next_layer_fp32
                           /
   linear_weight_fp32 -- fq

# quantized model
# weights and activations are in int8
previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
                     /
   linear_weight_int8

Quantization Aware Training (QAT) models the effects of quantization during training, allowing for higher accuracy than other quantization methods. QAT can be used for static, dynamic, or weight-only quantization. During training, all calculations are done in floating point, with fake_quant modules simulating INT8 behavior by clamping and rounding. After model conversion, weights and activations are quantized, and activations are fused into the preceding layer where possible.
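
A minimal sketch of what a fake_quant module computes (the scale and zero point here are illustrative; real observers learn them from activation statistics during training):

import torch

def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Quantize -> dequantize in floating point: round to the nearest
    # INT8 level, clamp to the int8 range, then map back to floats.
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
    return (q - zero_point) * scale

x = torch.tensor([0.07, 1.234, -99.0])
print(fake_quantize(x, scale=0.1, zero_point=0))
# -> tensor([0.1000, 1.2000, -12.8000]); -99.0 clamps to qmin * scale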

Citation

@misc{lin2020nerfpytorch,
  title={NeRF-pytorch},
  author={Yen-Chen, Lin},
  publisher={GitHub},
  journal={GitHub repository},
  howpublished={\url{https://github.com/yenchenlin/nerf-pytorch/}},
  year={2020}
}
