Haoru Zhao*, Zonghui Guo*, Shishi Qiao, Zhaorui Gu, Bing Zheng, Junyu Dong, Haiyong Zheng
Here we provide the PyTorch implementation of our latest version. If you need the code for our previous ACM MM version ("TransCNN-HAE: Transformer-CNN Hybrid AutoEncoder for Blind Image Inpainting"), please refer to the released version.
- Linux
- Python 3.7
- NVIDIA GPU + CUDA CuDNN
- Clone this repo:
git clone https://github.com/zhenglab/TransCNN-HAE.git
cd TransCNN-HAE
- Install PyTorch 1.7 and other dependencies (e.g., torchvision).
- For Conda users, you can create a new Conda environment using:
conda create --name <env> --file requirements.txt
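Before training, it can help to confirm that the environment matches the requirements listed above (Python 3.7, PyTorch 1.7, a CUDA-capable GPU). The script below is a minimal sketch of our own and is not part of the repository. The helper names `version_prefix_matches` and `check_environment` are hypothetical. PyTorch is imported lazily, so the script still runs (and reports the problem) when PyTorch is missing.

```python
import importlib
import sys


def version_prefix_matches(installed, required):
    """Return True if `installed` starts with the dotted components of `required`.

    Local build tags such as "+cu110" are ignored, so "1.7.1+cu110" matches "1.7",
    while "1.10.0" does not.
    """
    installed_parts = installed.split("+")[0].split(".")
    required_parts = required.split(".")
    return installed_parts[:len(required_parts)] == required_parts


def check_environment(python_required="3.7", torch_required="1.7"):
    """Collect human-readable problems with the current environment (empty list means OK)."""
    problems = []
    python_version = "{}.{}.{}".format(*sys.version_info[:3])
    if not version_prefix_matches(python_version, python_required):
        problems.append("Python {} found, {} expected".format(python_version, python_required))

    try:
        torch = importlib.import_module("torch")  # imported lazily on purpose
    except ImportError:
        problems.append("PyTorch is not installed")
        return problems

    if not version_prefix_matches(str(torch.__version__), torch_required):
        problems.append("PyTorch {} found, {} expected".format(torch.__version__, torch_required))
    if not torch.cuda.is_available():
        problems.append("CUDA is not available to PyTorch")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    for issue in issues:
        print("WARNING:", issue)
    if not issues:
        print("Environment looks OK")
```

Matching on a version prefix rather than an exact string lets any 1.7.x patch release (including CUDA-tagged builds) pass the check.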
- Train our TransCNN-HAE+:
python train.py --path=./checkpoints/config.yml
- Train our TransCNN-HAE+wCIA:
python train.py --path=./checkpoints/configwithCIA.yml
The model is saved automatically every 10,000 iterations. Before testing, rename the file g.pth_$iter_number$ to g.pth, then run the testing command.
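Renaming by hand is error-prone once several snapshots have accumulated. The sketch below picks the snapshot with the highest iteration number and renames it to g.pth. It assumes, hypothetically, that snapshots are written to the checkpoints directory as `g.pth_<iteration>` (for example `g.pth_20000`). The helpers `find_latest_checkpoint` and `promote_checkpoint` are our own illustration and are not part of the repository.

```python
import os
import re


def find_latest_checkpoint(checkpoint_dir, prefix="g.pth_"):
    """Return the path of the snapshot with the highest iteration number, or None.

    Assumes (hypothetically) snapshot names of the form "<prefix><iteration>".
    Iterations are compared as integers, so g.pth_20000 beats g.pth_9000.
    """
    pattern = re.compile(re.escape(prefix) + r"(\d+)$")
    best_iter, best_path = -1, None
    for name in os.listdir(checkpoint_dir):
        match = pattern.match(name)
        if match and int(match.group(1)) > best_iter:
            best_iter = int(match.group(1))
            best_path = os.path.join(checkpoint_dir, name)
    return best_path


def promote_checkpoint(checkpoint_dir, target_name="g.pth"):
    """Rename the latest snapshot to `target_name` so the test script can find it."""
    latest = find_latest_checkpoint(checkpoint_dir)
    if latest is None:
        raise FileNotFoundError("no g.pth_<iteration> snapshot found in " + checkpoint_dir)
    target = os.path.join(checkpoint_dir, target_name)
    os.replace(latest, target)  # overwrites an existing g.pth, if any
    return target
```

Usage: `promote_checkpoint("./checkpoints")`, then run the testing command. Comparing iterations numerically rather than sorting file names avoids picking `g.pth_9000` over `g.pth_10000`. Because `os.replace` overwrites an existing g.pth, copy it elsewhere first if you want to keep it.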
- Test our TransCNN-HAE+:
python test.py --path=./checkpoints/config.yml
- Test our TransCNN-HAE+wCIA:
python test.py --path=./checkpoints/configwithCIA.yml
- Download the pre-trained models from BaiduCloud (access code: 3w72) and put g.pth in the checkpoints directory.
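Before running test.py, a quick check that the config file and g.pth are where the commands expect them can save a failed run. This is a hypothetical helper of our own, not part of the repository, and it assumes g.pth sits in the same directory as the config file, matching the checkpoints layout used by the commands above.

```python
import os


def missing_test_inputs(config_path, checkpoint_name="g.pth"):
    """List the files needed for testing that are missing on disk.

    Assumes (hypothetically) that the checkpoint lives next to the config file,
    e.g. ./checkpoints/g.pth alongside ./checkpoints/config.yml.
    """
    checkpoint_path = os.path.join(os.path.dirname(config_path), checkpoint_name)
    return [path for path in (config_path, checkpoint_path) if not os.path.isfile(path)]


if __name__ == "__main__":
    for config in ("./checkpoints/config.yml", "./checkpoints/configwithCIA.yml"):
        missing = missing_test_inputs(config)
        print(config, "OK" if not missing else "missing: " + ", ".join(missing))
```

The check only confirms that the files exist; it does not verify that g.pth matches the chosen config.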