Skip to content

Commit

Permalink
Remove .tar suffices for checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
zhou13 committed Feb 14, 2020
1 parent c044515 commit 49f1b45
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ The following table reports the performance metrics of several wireframe and lin

| | ShanghaiTech (sAP<sup>10</sup>) | ShanghaiTech (AP<sup>H</sup>) | ShanghaiTech (F<sup>H</sup>) | ShanghaiTech (mAP<sup>J</sup>) |
| :--------------------------------------------------: | :--------------------------------: | :-----------------------------: | :----------------------------: | :------------------------------: |
| [LSD](https://ieeexplore.ieee.org/document/4731268/) | / | 52.0 | 61.0 | / |
| [AFM](https://github.com/cherubicXN/afm_cvpr2019) | 24.4 | 69.5 | 77.2 | 23.3 |
| [Wireframe](https://github.com/huangkuns/wireframe) | 5.1 | 67.8 | 72.6 | 40.9 |
| **L-CNN** | **62.9** | **82.8** | **81.2** | **59.3** |
| [LSD](https://ieeexplore.ieee.org/document/4731268/) | / | 52.0 | 61.0 | / |
| [AFM](https://github.com/cherubicXN/afm_cvpr2019) | 24.4 | 69.5 | 77.2 | 23.3 |
| [Wireframe](https://github.com/huangkuns/wireframe) | 5.1 | 67.8 | 72.6 | 40.9 |
| **L-CNN** | **62.9** | **82.8** | **81.2** | **59.3** |

### Precision-Recall Curves
<p align="center">
Expand Down Expand Up @@ -82,8 +82,8 @@ git clone https://github.com/zhou13/lcnn
cd lcnn
conda create -y -n lcnn
source activate lcnn
# Replace cudatoolkit=10.0 with your CUDA version: https://pytorch.org/
conda install -y pytorch cudatoolkit=10.0 -c pytorch
# Replace cudatoolkit=10.1 with your CUDA version: https://pytorch.org/
conda install -y pytorch cudatoolkit=10.1 -c pytorch
conda install -y tensorboardx -c conda-forge
conda install -y pyyaml docopt matplotlib scikit-image opencv
mkdir data logs post
Expand All @@ -94,7 +94,7 @@ mkdir data logs post
You can download our reference pre-trained models from [Google
Drive](https://drive.google.com/file/d/1NvZkEqWNUBAfuhFPNGiCItjy4iU0UOy2). Those models were
trained with `config/wireframe.yaml` for 312k iterations. Use `demo.py`, `process.py`, and
`eval-*.py` to evaluate the pre-trained models. **Do not try to unzip them!**
`eval-*.py` to evaluate the pre-trained models.

### Detect Wireframes for Your Own Images
To test LCNN on your own images, you need download the pre-trained models and execute
Expand Down Expand Up @@ -144,7 +144,7 @@ python ./train.py -d 0 --identifier baseline config/wireframe.yaml
To generate wireframes on the validation dataset with the pretrained model, execute

```bash
./process.py config/wireframe.yaml <path-to-checkpoint.pth.tar> data/wireframe logs/pretrained-model/npz/000312000
./process.py config/wireframe.yaml <path-to-checkpoint.pth> data/wireframe logs/pretrained-model/npz/000312000
```

### Post Processing
Expand Down
10 changes: 5 additions & 5 deletions lcnn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,17 +144,17 @@ def validate(self):
"model_state_dict": self.model.state_dict(),
"best_mean_loss": self.best_mean_loss,
},
osp.join(self.out, "checkpoint_latest.pth.tar"),
osp.join(self.out, "checkpoint_latest.pth"),
)
shutil.copy(
osp.join(self.out, "checkpoint_latest.pth.tar"),
osp.join(npz, "checkpoint.pth.tar"),
osp.join(self.out, "checkpoint_latest.pth"),
osp.join(npz, "checkpoint.pth"),
)
if self.mean_loss < self.best_mean_loss:
self.best_mean_loss = self.mean_loss
shutil.copy(
osp.join(self.out, "checkpoint_latest.pth.tar"),
osp.join(self.out, "checkpoint_best.pth.tar"),
osp.join(self.out, "checkpoint_latest.pth"),
osp.join(self.out, "checkpoint_best.pth"),
)

if training:
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def main():
# print("epoch_size (valid):", len(val_loader))

if resume_from:
checkpoint = torch.load(osp.join(resume_from, "checkpoint_latest.pth.tar"))
checkpoint = torch.load(osp.join(resume_from, "checkpoint_latest.pth"))

# 2. model
if M.backbone == "stacked_hourglass":
Expand Down

0 comments on commit 49f1b45

Please sign in to comment.