From 074e35c80194d2a42156ae9b11af8aad0de3f568 Mon Sep 17 00:00:00 2001
From: zzhang
Date: Tue, 4 Aug 2020 15:36:47 +0800
Subject: [PATCH] first commit

---
 .gitignore                         |   1 +
 README.md                          | 120 +++++++++
 Result/Detail/CoCA_GICD_retest.pth | Bin 0 -> 8940 bytes
 Result/result.txt                  |   1 +
 dataloader.py                      |  33 +++
 eval.sh                            |   1 +
 evaluator.py                       | 382 +++++++++++++++++++++++++++++
 main.py                            |  46 ++++
 plot_curve.sh                      |   1 +
 9 files changed, 585 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 README.md
 create mode 100644 Result/Detail/CoCA_GICD_retest.pth
 create mode 100644 Result/result.txt
 create mode 100644 dataloader.py
 create mode 100755 eval.sh
 create mode 100644 evaluator.py
 create mode 100644 main.py
 create mode 100755 plot_curve.sh

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2812bcc
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+*__pycache__
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e4693e5
--- /dev/null
+++ b/README.md
@@ -0,0 +1,120 @@
+# PyTorch-Based Evaluation Tool for Co-Saliency Detection
+
+Automatically evaluate 8 metrics and draw 3 types of curves
+
+⭐ Project Home »
+
+***
+**Eval Co-SOD** is an extended version of [Evaluate-SOD](https://github.com/Hanqer/Evaluate-SOD) for the **co-saliency detection task**.
+It provides eight metrics and three curves:
+* Metrics:
+  * Mean Absolute Error (MAE)
+  * Maximum F-measure (max-Fm)
+  * Mean F-measure (mean-Fm)
+  * Maximum E-measure (max-Em)
+  * Mean E-measure (mean-Em)
+  * S-measure (Sm)
+  * Average Precision (AP)
+  * Area Under Curve (AUC)
+* Curves:
+  * Precision-Recall (PR) curve
+  * Receiver Operating Characteristic (ROC) curve
+  * F-measure curve
+
+## Prerequisites
+* PyTorch >= 1.0
+
+## Usage
+
+### 1. Prepare your data
+The structure of `root_dir` should be organized as follows:
+```
+.
+├── gt
+│   ├── dataset1
+│   │   ├── accordion
+│   │   │   ├── 51499.png
+│   │   │   └── 186605.png
+│   │   └── alarm clock
+│   │       ├── 51499.png
+│   │       └── 186605.png
+│   ├── dataset2 ...
+│   └── dataset3 ...
+└── pred
+    ├── method1
+    │   ├── dataset1
+    │   │   ├── accordion
+    │   │   │   ├── 51499.png
+    │   │   │   └── 186605.png
+    │   │   └── alarm clock
+    │   │       ├── 51499.png
+    │   │       └── 186605.png
+    │   ├── dataset2 ...
+    │   └── dataset3 ...
+    └── method2 ...
+```
+
+### 2. Evaluate on the 8 metrics
+1. Configure `eval.sh`
+```shell
+--methods method1+method2+method3 (multiple items are connected with '+')
+--datasets dataset1+dataset2+dataset3
+--save_dir ../Result (path to save results)
+--root_dir ../SalMaps
+```
+
+2. Run:
+```
+sh eval.sh
+```
+
+### 3. Draw the 3 types of curves
+1. Configure `plot_curve.sh`
+```shell
+--methods method1+method2+method3 (multiple items are connected with '+')
+--datasets dataset1+dataset2+dataset3
+--out_dir ../Result/Curves (path to save the curves)
+--res_dir ../Result
+```
+
+2. Run:
+```
+sh plot_curve.sh
+```
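+
+### 4. (Optional) Call the evaluator from Python
+If you want to embed the evaluation in your own code, the snippet below is a
+minimal sketch of how `EvalDataset` and `Eval_thread` from this repository can
+be driven directly. The paths are placeholders; point them at one
+method/dataset pair under your `root_dir`.
+```python
+from dataloader import EvalDataset
+from evaluator import Eval_thread
+
+# One (method, dataset) pair; this mirrors what main.py does in its loop.
+loader = EvalDataset(pred_root='../SalMaps/pred/method1/dataset1',
+                     label_root='../SalMaps/gt/dataset1')
+thread = Eval_thread(loader, method='method1', dataset='dataset1',
+                     output_dir='./Result', cuda=True)  # cuda=False for CPU
+print(thread.run())  # also appends a summary line to ./Result/result.txt
+```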
+``` +@inproceedings{zhang2020gicd, + title={Gradient-Induced Co-Saliency Detection}, + author={Zhang, Zhao and Jin, Wenda and Xu, Jun and Cheng, Ming-Ming}, + booktitle={European Conference on Computer Vision (ECCV)}, + year={2020} +} + +@inproceedings{fan2020taking, + title={Taking a Deeper Look at the Co-salient Object Detection}, + author={Fan, Deng-Ping and Lin, Zheng and Ji, Ge-Peng and Zhang, Dingwen and Fu, Huazhu and Cheng, Ming-Ming}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + year={2020} +} +``` + +## Contact +If you have any questions, feel free to contact me via `zzhang🥳mail😲nankai😲edu😲cn` \ No newline at end of file diff --git a/Result/Detail/CoCA_GICD_retest.pth b/Result/Detail/CoCA_GICD_retest.pth new file mode 100644 index 0000000000000000000000000000000000000000..dd9a22257ae78644b7be2fdf03731264690c5392 GIT binary patch literal 8940 zcmZ{q*;7>6lgF#BJFcxLF74WlySP*pDsU@V1QZmta6wx|L|RZlE?~FX+Jdt0`zoNI zpaP-@DhN1{`A^K_?`0<97cuiT^Dq$;@tf~CmyUUu7$f+cJbCi&_n!RbDGQ7Wzs&!B z9`j#YZofIRE~vmb_4hg8W6QBKX@$e#$jQB%e>e4RX7X=oxp^6Pvn?kAl9pm5Gb2Ag zGc7qS`({Q;wq=wh`N5h0ASW$3FXLfao;ByOHFsyyJUH{x@8;%P^Wu%9Ip}3%=UemR z=O)cn`?#BZ%UTdW_ujp~j;SI}MMPW9So*$edpLS}7pGG0$7ESXEdA}`zia1w6_=Kh z4euOyH<9)l@J5`m=B8%f%gT8WlzKNeEhy_=W`0IWZf?p0zr56x%#>U|Yo0yscT=6N zpEW;8ZTIUDoQWgeS!ylxS&MeY&q;E^8)o?>BqYTB1OGJQ9r@Pc8>z{ucWm%(`T>CZEdu#Ym?8~jO$AOzgM*7MmhGDgNL%gD* z6L?J;K}*I!dAdR-yz34C5plqdUjZq2tT@aqXdYu5bUsQYnoH6Tnlb|>_{*eM2KIvd zDFRGR0+wc~eg!{MO(FIG7Y~64rBJzsr4`Lqd9ON+4*BN&-5-!IyFe}oQT9v%n4Jt9 z%2)jfE<(dS+^U*(ev2lEzkv&+7Thi`z{NO61UDHn<&9YhNOK^_r7mUn$0&B6Q=Wle z!L;9jbzBOrk;mz1+46G=2hN4*`23 z0jW#`eJKEse^N>ry24+`Nk@?7;k%K(}f83$K69mH=IIB5AY2T9>72PLnEXLt+H8t9-+ z`y7;X&_R($9JJPSP<)hwPM>hl_R|PxI1hZhj132pVfCgtXw7XHFLE4o_8#0Ve>mtc z7s6`i5*5LfuwQYVgVuA4gYI!B4#6)Sbb$Ny$RPNdKRU?CpJ5L2HwWF80+?M=41N*p z0oPmaAiNz83Y6#Q7YcTfCfG%YU=O%Z>>*3Ahb1zFtp*%UvgSKU#uhs%{|6^==^CJF zJzSr*z^dF16zp-*iGwh{dEoyR0&6rJs5|DQ@Hi*^m;iTWA~1OoX#d4Y8P_rXCJnvn z+fGWl3*&trQ1Af$q9?#G7dhz=m%(h{Y78}TJ?vg?and?|4zrtI!rjUJ;AegZe&7+X zoyWj3op2a@SSucqCs?D33c3)WR>|=;eOcO*~@IGX4Tb1Ahm{rPQDUf?zyb#HGoYLEG}@lC@P$ z-rIcEPx;o*H+0JIr6EmaT7l{Q>3X8yo{ouqhtJwMJF)Ld->hb9hf1vuE})0G$0+{F zG2*GXV`PTH7}k!YvrX*PC-F z5KeQXJDlc8Gn|D|8BTx7({Q|fD(D>}=qE-H-pS5z^5>>-GPoQ=ubza{!+YU$JQJPz z8^FlLaC)48(3TkV+C#$W@}Y1_-3d?qW?Y|9~fm`1ej64B|?L@T0*{tQ9L z=OB!B1p?*kh@30n8lOva9ioG{MBaNzzDl2$_DLHYeNqEQvpn@u08e-+hKIehm3uKX zz%AfAt^n(~z)QLLUJAeMCI4$)O7UT{!5E-CM8)n0O1Gig6yT*lmUszS%_QXaPhgGo zn|KQH95$}g#JpZ;(gq$w{~Kd$D|eXqI;u>X`=?1~E%=*LO*)ihQtokV`o;u`_nUNe z3ykTN82>Z}XqV4nv|HYUVTzUZFbbFIFp?Rb3Zn%60JEGsG1kbnVRYwDSX0?ybUzu! zSc2-AK;PalO5Yepk;}qpgUo~?T;7M0DP5s>nWQ$9q8Q#Y3@2u-?ois!Rfu}~FqDp` z!+3WA1Fezh?T0b337A;~e~nCp;1LgmU?z~}5H){4g+0O}A#|R*LdeIZ=nq*T6n8Cz zmdC0PcL-%}455rg=#0wfQCcXSM+rKn=qMfKp`*xy+z1vufOqT~Fdc<%$qu00?o@ac)@!K?4N`=~RG^dfK}|%c&)JmEr1z%9-{Cu2FgiDixdw z2cs_ncQSy)#{e{HHJS`)GjygPc4xfXq!{$aR5s)FVFJUZ6Uh5B-tCK&i0KJX9N7Uz zdjM%v`Y13F*m^_t+j2BOfwy%b(5a+R>uLvX^#-V)tqyv!LwWW(D8~h3;4p?tybcPQ6-SWJ4mx!X zL&XqBd#^(cK^pN4w9%2f4)RyhsQo_B`6q-IKXc%_E_2WYu0~V^H-K&2hVdbWx>?V? 
zYHR@BVOHv9Q0b$KJno>!Jn6v4I^&>ig1tacmO7vfsvP+4v6p$$q7)G{7L;d?GG969 zzPv?boeY5`N++#UItiI=!l6D_#5Opc#6Tt<)m#~3&$&* zG?_|M`7wmhVCALCv z^xImEmC0KiQgHD50*QB7MzODQ!6~4|!a0P(Q&;V9;76;ifjB*RloZhnQROGmP3m_$PK4 zbZ)mn!FwUUl)S6l4}3gekna$TVV6NWgCW6=800x>&`~A&zK5yZcnw+|j$Tg$dfkyw zhfy%T9Rn(%2)Q-}n;bXD3uP$NaoFKWgBG5`NN)l>ou>`DdKTKyhfZ}O)ZjUT)}1%V zxB$m1;O0exA|NGqUot5FGTL7-@bn6_>Q%LqYrxB7gDzZ$2Y>1KyOLti6a2|DC{KBs zYS5mW2K}4{`H4U0Eo|~0$bu%_l!1s!7~G<)*TAFOxTH+w26!e*jVX-YHRu-$##FvR zx0MXFTxz;>`z%ks6?{Xl#zdTtvJS`6-(URmKS51u*SRx3>&RbI=utI=hWM<|*(o&a zf5oSsf$h0i&B}UyRui?JvF*v*p2yXUuIF|=%j@}G&-`{0(D}el4LVEM2}7k0l|pnb zQAtH57oBBvzR{V-PC~N@Nhc>ePwAAUvzN|eI-#k|rgNN5b}H-X{AZ^^ofUO{v{R+d znsx%!xl|=pm0Wd}wezjYygCJ|T&$C_PS19#wzIZQ;5wJvY2D87I@#;IuN8on0!kBT zb)aN|k_%clC;_2_gjN(Yb~x-xvkHY zQrD7QX?I)2D?QIj*(-^!q`s2;O8V;>KotYJD$w>@+g8@lq) zHHck~*rkcCQgpqdD;HhEsA5J}HLAQ(WsWL$blsy0AXNzIib$14szg$ylCGIl@uaIL zyP%pau5_WLYc5@V*@c*0lj-uzuGDnVrt3E~#m*LWvjv`A?CHW!*MO=RR7Ih#4t1TV z3PxQy>Kam4ld7y#<)yAOb-}4BPhErRY7}=2h3RaOstZ=TcGYF9UDE36R&=3j7rk}? ztZQM22b9ILC9L(jJBuF|3KY-5sGfY#A?ZFlo0pU$WxQ6(!f>R4?@;-sHSR}JTv+8QP6u^ znKhV`(2VdT8l3b?G>_&=k$G$K_ApAASO*Q@3^R!h+ILoF4VJD?7W%#F*X*o@f+ zAiUU=igYlGi$OOuZ?Tz!1q%3^aSFxeMSSJxd{`+ZX4nCMt75@*w}4m()Dnzg$|*6g z@Hn^&rF@B*hV{-8^Wq#SHRa7lklT(ZzB&v3mId74ry%Y@H=489wA9SNN)KWt@R66A zK4@=@e_AVLCYK%rSH}TTc@y;fp}f6VLxZOd!Vt|Q#BiCJgXO_8^Ewtl;d-%H%FUE* zKx8N&qvsU+?|@v3wbF9a4`*I(uEzQ%Tr;@oQJAl==PiE>l zfD1FhC~RM0hM;h(FypcOT45f><`w29+_VZ)nvlZKe4HSj6Aa{0t#hX}a*JD$Rg-Qf11Qjo_L?0Dp}KgRU!Y?*s58*D1S;2hi|) z?4Zg_!qKbD6%cV%=Ds;nZT`Likam}1M=XdH=4vxPPxXhmN;MOB^Q-YtCc*jApjryA15h1Ed=%0L>dL-lFNp-K;SeA^X>u@yN_Ic!eg_ zKIL)GHePJ&2jHL4lKY1EqM@j4)Ee6`e>dxL;S z(SVd)1-aojFai1j_IKO~47|oZ>df6J!0Jq73UKFKsW%^N1bEN|W+njEEkGKcfP3*O zU~CA_5=}LxiF$J@v;fAxBKg#tj+N2?dpEc*3dp?-aQi*QD%8jg<{$WU8_YW>#Tv{9 zc!~&D??Hp^=`Z`W@gIZbP#hC)94P#+|_E>jlzbF4_;QISX8V04(4x z5ECJ0Ruj|zpz2kukxp|xB)w_g0@h%1?KI^XQfVh1z%a<)Hej_!nq$E050u%6r~F(^ z5Z$UNT_s)kj8B5;*lU;Rhc^eF!d233p1ld|#>CWZ+UX)jrHfdfbv)lXapSLFHKVhv zxLEqve~aCslvn*a3D3#D63Z!-SQ32J>Dk0`=6{zpoLCU{LMApJ zqAcz^M7wz85Y|X150O;KvqOZ0Q*{V+8^+H_7y3i;T6>4U78yg_n5gJtrwamVp^MU% zyXf>97hU_&MIqbZY~1Oh3qdfRxq)v-fldMjkGW8W#kpwp85bp9aM7wOF4~iVk(o50 z=Z=ducZhG*O zo38J0i1))$mOPEhust&>Lwq-TM-3}$GB-rJdB0}H{JEQDd&QlE?;ufnXA~Q z@4A}~-GuS!mYYuBfzfo=P0l0XGHl5O&tX@4+e_g?p68aRxjAbCRdrlq3Zh z8^hD3#Zu~~OH$#+LU@fE>n!y+2WiHkr48Zz(uqhcJ-I1EdbQ)V8$W2hh51o%oMrOC zjfH|wZu&vKAhc7|QO0H3O=%7fR%GUQXzv0T1%5!)QV*^D!GoV|{L$+VfYq}B7~2F) zZS~NvfiQ-5cqna;ht3C~^W~t2_INzB3q*b6J`Wbd`eF9* zTet^#2yEsN4;JS~J(R-Z7^~#3@Q&e^JltoV0pCa={0&m1Jc7ew!l?v@q{f5Q85}lO zaMUEhQF8=G#pJ0DTO&9u5;+e0lMHxhfxN?5p$y}0$OjBH$|rDIK6@xqCOvq<(+H?> z1k>8N!4$I~7*E767*Av=@O4En?O6?DG9Z{9Z@^g3reJ!s4W8;iMDGfw`TK%t^8q+M zx`HX=FpP$fV7f|(?T8GfCo#cvGY*|M2|$|<;}sWzDfAZ@U#`JbmKsbu($VY445qkT z7`?v*)5AxI;itir&E>(gf*XRdQq~#UkteLgEO-#TZn0w-o!=VR$#jF-mS N=c4{|$yLh{^gm1446y(J literal 0 HcmV?d00001 diff --git a/Result/result.txt b/Result/result.txt new file mode 100644 index 0000000..4f8843a --- /dev/null +++ b/Result/result.txt @@ -0,0 +1 @@ +CoCA (GICD_retest): 0.1217 mae || 0.5107 max-fm || 0.5037 mean-fm || 0.7163 max-Emeasure || 0.7045 mean-Emeasure || 0.6566 S-measure || 0.4237 AP || 0.8009 AUC. 
diff --git a/dataloader.py b/dataloader.py
new file mode 100644
index 0000000..72548c2
--- /dev/null
+++ b/dataloader.py
@@ -0,0 +1,33 @@
+from torch.utils import data
+import os
+from PIL import Image
+
+
+class EvalDataset(data.Dataset):
+    def __init__(self, pred_root, label_root):
+        pred_dirs = os.listdir(pred_root)
+        label_dirs = os.listdir(label_root)
+
+        dir_name_list = []
+        for idir in pred_dirs:
+            if idir in label_dirs:
+                pred_names = os.listdir(os.path.join(pred_root, idir))
+                label_names = os.listdir(os.path.join(label_root, idir))
+                for iname in pred_names:
+                    if iname in label_names:
+                        dir_name_list.append(os.path.join(idir, iname))
+
+        self.image_path = list(
+            map(lambda x: os.path.join(pred_root, x), dir_name_list))
+        self.label_path = list(
+            map(lambda x: os.path.join(label_root, x), dir_name_list))
+
+    def __getitem__(self, item):
+        pred = Image.open(self.image_path[item]).convert('L')
+        gt = Image.open(self.label_path[item]).convert('L')
+        if pred.size != gt.size:
+            pred = pred.resize(gt.size, Image.BILINEAR)
+        return pred, gt
+
+    def __len__(self):
+        return len(self.image_path)
diff --git a/eval.sh b/eval.sh
new file mode 100755
index 0000000..7176bd6
--- /dev/null
+++ b/eval.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0 python main.py --methods GICD_retest --datasets CoCA --save_dir ./Result --root_dir ../SalMaps
diff --git a/evaluator.py b/evaluator.py
new file mode 100644
index 0000000..7b63e9f
--- /dev/null
+++ b/evaluator.py
@@ -0,0 +1,382 @@
+import os
+import time
+
+import numpy as np
+import torch
+from torchvision import transforms
+
+
+class Eval_thread():
+    def __init__(self, loader, method, dataset, output_dir, cuda):
+        self.loader = loader
+        self.method = method
+        self.dataset = dataset
+        self.cuda = cuda
+        self.output_dir = output_dir
+        self.logfile = os.path.join(output_dir, 'result.txt')
+
+    def run(self):
+        Res = {}
+        start_time = time.time()
+        mae = self.Eval_mae()
+        Res['MAE'] = mae
+
+        Fm, prec, recall = self.Eval_fmeasure()
+        max_f = Fm.max().item()
+        mean_f = Fm.mean().item()
+        prec = prec.cpu().numpy()
+        recall = recall.cpu().numpy()
+        avg_p = self.Eval_AP(prec, recall)  # AP
+        Fm = Fm.cpu().numpy()
+        Res['MaxFm'] = max_f
+        Res['MeanFm'] = mean_f
+        Res['AP'] = avg_p
+        Res['Prec'] = prec
+        Res['Recall'] = recall
+        Res['Fm'] = Fm
+
+        auc, TPR, FPR = self.Eval_auc()
+        TPR = TPR.cpu().numpy()
+        FPR = FPR.cpu().numpy()
+
+        Res['AUC'] = auc
+        Res['TPR'] = TPR
+        Res['FPR'] = FPR
+
+        Em = self.Eval_Emeasure()
+        max_e = Em.max().item()
+        mean_e = Em.mean().item()
+        Em = Em.cpu().numpy()
+        Res['MaxEm'] = max_e
+        Res['MeanEm'] = mean_e
+        Res['Em'] = Em
+
+        s = self.Eval_Smeasure()
+        Res['Sm'] = s
+        os.makedirs(os.path.join(self.output_dir, 'Detail'), exist_ok=True)
+        torch.save(
+            Res,
+            os.path.join(self.output_dir, 'Detail',
+                         self.dataset + '_' + self.method + '.pth'))
+
+        self.LOG(
+            '{} ({}): {:.4f} mae || {:.4f} max-fm || {:.4f} mean-fm || {:.4f} max-Emeasure || {:.4f} mean-Emeasure || {:.4f} S-measure || {:.4f} AP || {:.4f} AUC.\n'
+            .format(self.dataset, self.method, mae, max_f, mean_f, max_e,
+                    mean_e, s, avg_p, auc))
+        return '[cost:{:.4f}s] {} ({}): {:.4f} mae || {:.4f} max-fm || {:.4f} mean-fm || {:.4f} max-Emeasure || {:.4f} mean-Emeasure || {:.4f} S-measure || {:.4f} AP || {:.4f} AUC.'.format(
+            time.time() - start_time, self.dataset, self.method, mae, max_f,
+            mean_f, max_e, mean_e, s, avg_p, auc)
+
+    def Eval_mae(self):
+        print('eval[MAE]:{} dataset with {} method.'.format(
+            self.dataset, self.method))
+        avg_mae, img_num = 0.0, 0.0
+        with torch.no_grad():
+            trans = transforms.Compose([transforms.ToTensor()])
+            for pred, gt in self.loader:
+                if self.cuda:
+                    pred = trans(pred).cuda()
+                    gt = trans(gt).cuda()
+                else:
+                    pred = trans(pred)
+                    gt = trans(gt)
+                mea = torch.abs(pred - gt).mean()
+                if mea == mea:  # for Nan
+                    avg_mae += mea
+                    img_num += 1.0
+            avg_mae /= img_num
+            return avg_mae.item()
+
+    def Eval_fmeasure(self):
+        print('eval[FMeasure]:{} dataset with {} method.'.format(
+            self.dataset, self.method))
+        beta2 = 0.3
+        avg_f, avg_p, avg_r, img_num = 0.0, 0.0, 0.0, 0.0
+
+        with torch.no_grad():
+            trans = transforms.Compose([transforms.ToTensor()])
+            for pred, gt in self.loader:
+                if self.cuda:
+                    pred = trans(pred).cuda()
+                    gt = trans(gt).cuda()
+                    pred = (pred - torch.min(pred)) / (torch.max(pred) -
+                                                       torch.min(pred) + 1e-20)
+                else:
+                    pred = trans(pred)
+                    pred = (pred - torch.min(pred)) / (torch.max(pred) -
+                                                       torch.min(pred) + 1e-20)
+                    gt = trans(gt)
+                prec, recall = self._eval_pr(pred, gt, 255)
+                f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall)
+                f_score[f_score != f_score] = 0  # for Nan
+                avg_f += f_score
+                avg_p += prec
+                avg_r += recall
+                img_num += 1.0
+            Fm = avg_f / img_num
+            avg_p = avg_p / img_num
+            avg_r = avg_r / img_num
+        return Fm, avg_p, avg_r
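+
+    # Fm above is a 255-element vector: _eval_pr sweeps 255 thresholds over
+    # [0, 1), so each entry is an F-score at one threshold. run() reports its
+    # max as max-Fm and its mean as mean-Fm. With beta2 = 0.3 the per-threshold
+    # score is F = (1 + 0.3) * prec * recall / (0.3 * prec + recall), the
+    # weighting conventionally used in (co-)saliency evaluation.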
+
+    def Eval_auc(self):
+        print('eval[AUC]:{} dataset with {} method.'.format(
+            self.dataset, self.method))
+
+        avg_tpr, avg_fpr, avg_auc, img_num = 0.0, 0.0, 0.0, 0.0
+
+        with torch.no_grad():
+            trans = transforms.Compose([transforms.ToTensor()])
+            for pred, gt in self.loader:
+                if self.cuda:
+                    pred = trans(pred).cuda()
+                    pred = (pred - torch.min(pred)) / (torch.max(pred) -
+                                                       torch.min(pred) + 1e-20)
+                    gt = trans(gt).cuda()
+                else:
+                    pred = trans(pred)
+                    pred = (pred - torch.min(pred)) / (torch.max(pred) -
+                                                       torch.min(pred) + 1e-20)
+                    gt = trans(gt)
+                TPR, FPR = self._eval_roc(pred, gt, 255)
+                avg_tpr += TPR
+                avg_fpr += FPR
+                img_num += 1.0
+            avg_tpr = avg_tpr / img_num
+            avg_fpr = avg_fpr / img_num
+
+            sorted_idxes = torch.argsort(avg_fpr)
+            avg_tpr = avg_tpr[sorted_idxes]
+            avg_fpr = avg_fpr[sorted_idxes]
+            avg_auc = torch.trapz(avg_tpr, avg_fpr)
+
+        return avg_auc.item(), avg_tpr, avg_fpr
+
+    def Eval_Emeasure(self):
+        print('eval[EMeasure]:{} dataset with {} method.'.format(
+            self.dataset, self.method))
+        avg_e, img_num = 0.0, 0.0
+        with torch.no_grad():
+            trans = transforms.Compose([transforms.ToTensor()])
+            Em = torch.zeros(255)
+            if self.cuda:
+                Em = Em.cuda()
+            for pred, gt in self.loader:
+                if self.cuda:
+                    pred = trans(pred).cuda()
+                    pred = (pred - torch.min(pred)) / (torch.max(pred) -
+                                                       torch.min(pred) + 1e-20)
+                    gt = trans(gt).cuda()
+                else:
+                    pred = trans(pred)
+                    pred = (pred - torch.min(pred)) / (torch.max(pred) -
+                                                       torch.min(pred) + 1e-20)
+                    gt = trans(gt)
+                Em += self._eval_e(pred, gt, 255)
+                img_num += 1.0
+
+            Em /= img_num
+            return Em
+
+    def Eval_Smeasure(self):
+        print('eval[SMeasure]:{} dataset with {} method.'.format(
+            self.dataset, self.method))
+        alpha, avg_q, img_num = 0.5, 0.0, 0.0
+        with torch.no_grad():
+            trans = transforms.Compose([transforms.ToTensor()])
+            for pred, gt in self.loader:
+                if self.cuda:
+                    pred = trans(pred).cuda()
+                    pred = (pred - torch.min(pred)) / (torch.max(pred) -
+                                                       torch.min(pred) + 1e-20)
+                    gt = trans(gt).cuda()
+                else:
+                    pred = trans(pred)
+                    pred = (pred - torch.min(pred)) / (torch.max(pred) -
+                                                       torch.min(pred) + 1e-20)
+                    gt = trans(gt)
+                y = gt.mean()
+                if y == 0:
+                    x = pred.mean()
+                    Q = 1.0 - x
+                elif y == 1:
+                    x = pred.mean()
+                    Q = x
+                else:
+                    gt[gt >= 0.5] = 1
+                    gt[gt < 0.5] = 0
+                    Q = alpha * self._S_object(
+                        pred, gt) + (1 - alpha) * self._S_region(pred, gt)
+                    if Q.item() < 0:
+                        Q = torch.FloatTensor([0.0])
+                img_num += 1.0
+                avg_q += Q.item()
+            avg_q /= img_num
+        return avg_q
+
+    def LOG(self, output):
+        with open(self.logfile, 'a') as f:
+            f.write(output)
+
+    def _eval_e(self, y_pred, y, num):
+        if self.cuda:
+            score = torch.zeros(num).cuda()
+            thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
+        else:
+            score = torch.zeros(num)
+            thlist = torch.linspace(0, 1 - 1e-10, num)
+        for i in range(num):
+            y_pred_th = (y_pred >= thlist[i]).float()
+            fm = y_pred_th - y_pred_th.mean()
+            gt = y - y.mean()
+            align_matrix = 2 * gt * fm / (gt * gt + fm * fm + 1e-20)
+            enhanced = ((align_matrix + 1) * (align_matrix + 1)) / 4
+            score[i] = torch.sum(enhanced) / (y.numel() - 1 + 1e-20)
+        return score
+
+    def _eval_pr(self, y_pred, y, num):
+        if self.cuda:
+            prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda()
+            thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
+        else:
+            prec, recall = torch.zeros(num), torch.zeros(num)
+            thlist = torch.linspace(0, 1 - 1e-10, num)
+        for i in range(num):
+            y_temp = (y_pred >= thlist[i]).float()
+            tp = (y_temp * y).sum()
+            prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() +
+                                                                    1e-20)
+        return prec, recall
+
+    def _eval_roc(self, y_pred, y, num):
+        if self.cuda:
+            TPR, FPR = torch.zeros(num).cuda(), torch.zeros(num).cuda()
+            thlist = torch.linspace(0, 1 - 1e-10, num).cuda()
+        else:
+            TPR, FPR = torch.zeros(num), torch.zeros(num)
+            thlist = torch.linspace(0, 1 - 1e-10, num)
+        for i in range(num):
+            y_temp = (y_pred >= thlist[i]).float()
+            tp = (y_temp * y).sum()
+            fp = (y_temp * (1 - y)).sum()
+            tn = ((1 - y_temp) * (1 - y)).sum()
+            fn = ((1 - y_temp) * y).sum()
+
+            TPR[i] = tp / (tp + fn + 1e-20)
+            FPR[i] = fp / (fp + tn + 1e-20)
+
+        return TPR, FPR
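+
+    # S-measure decomposition used by Eval_Smeasure above:
+    # S = alpha * S_object + (1 - alpha) * S_region, with alpha = 0.5.
+    # _S_object compares foreground/background statistics; _S_region splits
+    # GT and prediction into four quadrants around the GT centroid and
+    # combines per-quadrant SSIM scores weighted by quadrant area.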
+
+    def _S_object(self, pred, gt):
+        fg = torch.where(gt == 0, torch.zeros_like(pred), pred)
+        bg = torch.where(gt == 1, torch.zeros_like(pred), 1 - pred)
+        o_fg = self._object(fg, gt)
+        o_bg = self._object(bg, 1 - gt)
+        u = gt.mean()
+        Q = u * o_fg + (1 - u) * o_bg
+        return Q
+
+    def _object(self, pred, gt):
+        temp = pred[gt == 1]
+        x = temp.mean()
+        sigma_x = temp.std()
+        score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20)
+
+        return score
+
+    def _S_region(self, pred, gt):
+        X, Y = self._centroid(gt)
+        gt1, gt2, gt3, gt4, w1, w2, w3, w4 = self._divideGT(gt, X, Y)
+        p1, p2, p3, p4 = self._dividePrediction(pred, X, Y)
+        Q1 = self._ssim(p1, gt1)
+        Q2 = self._ssim(p2, gt2)
+        Q3 = self._ssim(p3, gt3)
+        Q4 = self._ssim(p4, gt4)
+        Q = w1 * Q1 + w2 * Q2 + w3 * Q3 + w4 * Q4
+        return Q
+
+    def _centroid(self, gt):
+        rows, cols = gt.size()[-2:]
+        gt = gt.view(rows, cols)
+        if gt.sum() == 0:
+            if self.cuda:
+                X = torch.eye(1).cuda() * round(cols / 2)
+                Y = torch.eye(1).cuda() * round(rows / 2)
+            else:
+                X = torch.eye(1) * round(cols / 2)
+                Y = torch.eye(1) * round(rows / 2)
+        else:
+            total = gt.sum()
+            if self.cuda:
+                i = torch.from_numpy(np.arange(0, cols)).cuda().float()
+                j = torch.from_numpy(np.arange(0, rows)).cuda().float()
+            else:
+                i = torch.from_numpy(np.arange(0, cols)).float()
+                j = torch.from_numpy(np.arange(0, rows)).float()
+            X = torch.round((gt.sum(dim=0) * i).sum() / total + 1e-20)
+            Y = torch.round((gt.sum(dim=1) * j).sum() / total + 1e-20)
+        return X.long(), Y.long()
+
+    def _divideGT(self, gt, X, Y):
+        h, w = gt.size()[-2:]
+        area = h * w
+        gt = gt.view(h, w)
+        LT = gt[:Y, :X]
+        RT = gt[:Y, X:w]
+        LB = gt[Y:h, :X]
+        RB = gt[Y:h, X:w]
+        X = X.float()
+        Y = Y.float()
+        w1 = X * Y / area
+        w2 = (w - X) * Y / area
+        w3 = X * (h - Y) / area
+        w4 = 1 - w1 - w2 - w3
+        return LT, RT, LB, RB, w1, w2, w3, w4
+
+    def _dividePrediction(self, pred, X, Y):
+        h, w = pred.size()[-2:]
+        pred = pred.view(h, w)
+        LT = pred[:Y, :X]
+        RT = pred[:Y, X:w]
+        LB = pred[Y:h, :X]
+        RB = pred[Y:h, X:w]
+        return LT, RT, LB, RB
+
+    def _ssim(self, pred, gt):
+        gt = gt.float()
+        h, w = pred.size()[-2:]
+        N = h * w
+        x = pred.mean()
+        y = gt.mean()
+        sigma_x2 = ((pred - x) * (pred - x)).sum() / (N - 1 + 1e-20)
+        sigma_y2 = ((gt - y) * (gt - y)).sum() / (N - 1 + 1e-20)
+        sigma_xy = ((pred - x) * (gt - y)).sum() / (N - 1 + 1e-20)
+
+        alpha = 4 * x * y * sigma_xy
+        beta = (x * x + y * y) * (sigma_x2 + sigma_y2)
+
+        if alpha != 0:
+            Q = alpha / (beta + 1e-20)
+        elif alpha == 0 and beta == 0:
+            Q = 1.0
+        else:
+            Q = 0
+        return Q
+
+    def Eval_AP(self, prec, recall):
+        # Ref:
+        # https://github.com/facebookresearch/Detectron/blob/05d04d3a024f0991339de45872d02f2f50669b3d/lib/datasets/voc_eval.py#L54
+        print('eval[AP]:{} dataset with {} method.'.format(
+            self.dataset, self.method))
+        ap_r = np.concatenate(([0.], recall, [1.]))
+        ap_p = np.concatenate(([0.], prec, [0.]))
+        sorted_idxes = np.argsort(ap_r)
+        ap_r = ap_r[sorted_idxes]
+        ap_p = ap_p[sorted_idxes]
+        count = ap_r.shape[0]
+
+        for i in range(count - 1, 0, -1):
+            ap_p[i - 1] = max(ap_p[i], ap_p[i - 1])
+
+        i = np.where(ap_r[1:] != ap_r[:-1])[0]
+        ap = np.sum((ap_r[i + 1] - ap_r[i]) * ap_p[i + 1])
+        return ap
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..f786405
--- /dev/null
+++ b/main.py
@@ -0,0 +1,46 @@
+import torch
+import torch.nn as nn
+import argparse
+import os.path as osp
+import os
+from evaluator import Eval_thread
+from dataloader import EvalDataset
+
+
+def main(cfg):
+    root_dir = cfg.root_dir
+    if cfg.save_dir is not None:
+        output_dir = cfg.save_dir
+    else:
+        output_dir = root_dir
+    gt_dir = osp.join(root_dir, 'gt')
+    pred_dir = osp.join(root_dir, 'pred')
+    if cfg.methods is None:
+        method_names = os.listdir(pred_dir)
+    else:
+        method_names = cfg.methods.split('+')
+    if cfg.datasets is None:
+        dataset_names = os.listdir(gt_dir)
+    else:
+        dataset_names = cfg.datasets.split('+')
+
+    threads = []
+    for dataset in dataset_names:
+        for method in method_names:
+            loader = EvalDataset(osp.join(pred_dir, method, dataset),
+                                 osp.join(gt_dir, dataset))
+            thread = Eval_thread(loader, method, dataset, output_dir, cfg.cuda)
+            threads.append(thread)
+    for thread in threads:
+        print(thread.run())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--methods', type=str, default=None)
+    parser.add_argument('--datasets', type=str, default=None)
+    parser.add_argument('--root_dir', type=str, default='./')
+    parser.add_argument('--save_dir', type=str, default=None)
+    # argparse's type=bool treats any non-empty string (even 'False') as True,
+    # so parse the flag explicitly.
+    parser.add_argument('--cuda',
+                        type=lambda s: s.lower() in ('true', '1'),
+                        default=True)
+    config = parser.parse_args()
+    main(config)
diff --git a/plot_curve.sh b/plot_curve.sh
new file mode 100755
index 0000000..534bfb9
--- /dev/null
+++ b/plot_curve.sh
@@ -0,0 +1 @@
+python plot_curve.py --methods GICD_retest --datasets CoCA --out_dir ./Result/Curves --res_dir ./Result