Skip to content

Commit

Permalink
add plot
Browse files Browse the repository at this point in the history
  • Loading branch information
zzhanghub committed Aug 4, 2020
1 parent 074e35c commit 54dc4da
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 9 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
*__pycache__

/Result
13 changes: 7 additions & 6 deletions READM.md → README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
<h3 align="center">PyTorch-Based Evaluation Tool for Co-Saliency Detection</h3>

<p align="center">
Automatically evaluate 8 metrics and draw 3 types of curves
Automatically evaluate 8 metrics and draw 4 types of curves
<br />
<a href="https://zhaozhang.net/coca.html"><strong>⭐ Project Home »</strong></a>
<br />
Expand All @@ -18,7 +18,7 @@

***
**Eval Co-SOD** is an extended version of [Evaluate-SOD](https://github.com/Hanqer/Evaluate-SOD) for **co-saliency detection task**.
It provides eight metrics and three curves:
It provides eight metrics and four curves:
* Metrics:
* Mean Absolute Error (MAE)
* Maximum F-measure (max-Fm)
Expand All @@ -32,6 +32,7 @@ It provides eight metrics and three curves:
* Precision-Recall (PR) curve
* Receiver Operating Characteristic (ROC) curve
* F-measure curve
* E-measure curve


evaluation tool for co-saliency detection methods.
Expand Down Expand Up @@ -75,7 +76,7 @@ The structure of `root_dir` should be organized as follows:
```shell
--methods method1+method2+method3 (Multiple items are connected with '+')
--datasets dataset1+dataset2+dataset3
--save_dir ../Result (Path to save results)
--save_dir ./Result (Path to save results)
--root_dir ../SalMaps
```

Expand All @@ -84,13 +85,13 @@ The structure of `root_dir` should be organized as follows:
sh eval.sh
```

### 2. Draw the 3 types of curves
### 2. Draw the 4 types of curves
1. Configure `plot_curve.sh`
```shell
--methods method1+method2+method3 (Multiple items are connected with '+')
--datasets dataset1+dataset2+dataset3
--out_dir ../Result/Curves (Path to save results)
--res_dir ../Result
--out_dir ./Result/Curves (Path to save results)
--res_dir ./Result/Detail
```

2. Run by
Expand Down
Binary file removed Result/Detail/CoCA_GICD_retest.pth
Binary file not shown.
1 change: 0 additions & 1 deletion Result/result.txt

This file was deleted.

2 changes: 1 addition & 1 deletion eval.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
CUDA_VISIBLE_DEVICES=0 python main.py --methods GICD_retest --datasets CoCA --save_dir ./Result --root_dir ../SalMaps
CUDA_VISIBLE_DEVICES=0 python main.py --methods GICD_retest --datasets CoCA --save_dir ./Result/Detail --root_dir ../SalMaps
173 changes: 173 additions & 0 deletions plot_curve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import numpy as np
import matplotlib.pyplot as plt
import argparse
import os
import torch

styles = ['.-r', '.--b', '.--g', '.--c', '.-m', '.-y', '.-k', '.-c']
lines = ['-', '--', '--', '--', '-', '-', '-', '-']
points = ['*', '.', '.', '.', '.', '.', '.', '.']
colors = ['r', 'b', 'g', 'c', 'm', 'orange', 'k', 'navy']


def main(cfg):
method_names = cfg.methods.split('+')
dataset_names = cfg.datasets.split('+')
os.makedirs(cfg.out_dir, exist_ok=True)
# plt.style.use('seaborn-white')

# Plot PR Cureve
for dataset in dataset_names:
plt.figure()
idx_style = 0
for method in method_names:
iRes = torch.load(
os.path.join(cfg.res_dir, dataset + '_' + method + '.pth'))
imax = np.argmax(iRes['Fm'])
plt.plot(
iRes['Prec'],
iRes['Recall'],
# styles[idx_style],
color=colors[idx_style],
linestyle=lines[idx_style],
marker=points[idx_style],
markevery=[imax, imax],
label=method)
idx_style += 1

plt.grid(True, zorder=-1)
# plt.xlim(0, 1)
# plt.ylim(0, 1.02)
plt.ylabel('Precision', fontsize=25)
plt.xlabel('Recall', fontsize=25)

plt.legend(loc='lower left', prop={'size': 15})
plt.savefig(os.path.join(cfg.out_dir, 'PR_' + dataset + '.png'),
dpi=600,
bbox_inches='tight')
plt.close()

# Plot Fm Cureve
for dataset in dataset_names:
plt.figure()
idx_style = 0
for method in method_names:
iRes = torch.load(
os.path.join(cfg.res_dir, dataset + '_' + method + '.pth'))
imax = np.argmax(iRes['Fm'])
plt.plot(
np.arange(0, 255),
iRes['Fm'],
# styles[idx_style],
color=colors[idx_style],
linestyle=lines[idx_style],
marker=points[idx_style],
label=method,
markevery=[imax, imax])
idx_style += 1
plt.grid(True, zorder=-1)
# plt.ylim(0, 1)
plt.ylabel('F-measure', fontsize=25)
plt.xlabel('Threshold', fontsize=25)

plt.legend(loc='lower left', prop={'size': 15})
plt.savefig(os.path.join(cfg.out_dir, 'Fm_' + dataset + '.png'),
dpi=600,
bbox_inches='tight')
plt.close()

# Plot Em Cureve
for dataset in dataset_names:
plt.figure()
idx_style = 0
for method in method_names:
iRes = torch.load(
os.path.join(cfg.res_dir, dataset + '_' + method + '.pth'))
imax = np.argmax(iRes['Em'])
plt.plot(
np.arange(0, 255),
iRes['Em'],
# styles[idx_style],
color=colors[idx_style],
linestyle=lines[idx_style],
marker=points[idx_style],
label=method,
markevery=[imax, imax])
idx_style += 1
plt.grid(True, zorder=-1)
plt.ylim(0, 1)
plt.ylabel('E-measure', fontsize=16)
plt.xlabel('Threshold', fontsize=16)

plt.legend(loc='lower left', prop={'size': 15})
plt.savefig(os.path.join(cfg.out_dir, 'Em_' + dataset + '.png'),
dpi=600,
bbox_inches='tight')
plt.close()

# Plot ROC Cureve
for dataset in dataset_names:
plt.figure()
idx_style = 0
for method in method_names:
iRes = torch.load(
os.path.join(cfg.res_dir, dataset + '_' + method + '.pth'))
imax = np.argmax(iRes['Fm'])
plt.plot(
iRes['FPR'],
iRes['TPR'],
# styles[idx_style][1:],
color=colors[idx_style],
linestyle=lines[idx_style],
label=method)
idx_style += 1

plt.grid(True, zorder=-1)
plt.xlim(0, 1)
plt.ylim(0, 1.02)
plt.ylabel('TPR', fontsize=16)
plt.xlabel('FPR', fontsize=16)

plt.legend(loc='lower right')
plt.savefig(os.path.join(cfg.out_dir, 'ROC_' + dataset + '.png'),
dpi=600,
bbox_inches='tight')
plt.close()

# Plot Sm-MAE
for dataset in dataset_names:
plt.figure()
plt.gca().invert_xaxis()
idx_style = 0
for method in method_names:
iRes = torch.load(
os.path.join(cfg.res_dir, dataset + '_' + method + '.pth'))
plt.scatter(iRes['MAE'],
iRes['Sm'],
marker=points[idx_style],
c=colors[idx_style],
s=120)
plt.annotate(method,
xy=(iRes['MAE'], iRes['Sm']),
xytext=(iRes['MAE'] - 0.001, iRes['Sm'] - 0.001),
fontsize=14)
idx_style += 1

plt.grid(True, zorder=-1)
# plt.xlim(0, 1)
plt.ylim(0, 1)
plt.ylabel('S-measure', fontsize=16)
plt.xlabel('MAE', fontsize=16)
plt.savefig(os.path.join(cfg.out_dir, 'Sm-MAE_' + dataset + '.png'),
bbox_inches='tight')
plt.close()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--methods', type=str, default=None)
parser.add_argument('--datasets', type=str, default=None)
parser.add_argument('--res_dir', type=str, default='./')
parser.add_argument('--out_dir', type=str, default=None)
config = parser.parse_args()
main(config)
2 changes: 1 addition & 1 deletion plot_curve.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
python plot_curve.py --methods GICD_retest --datasets CoCA --out_dir ./Result/Curves --res_dir ./Result
python plot_curve.py --methods GICD_retest --datasets CoCA --out_dir ./Result/Curves --res_dir ./Result/Detail

0 comments on commit 54dc4da

Please sign in to comment.