forked from facebookresearch/c3dpo_nrsfm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate.py
65 lines (46 loc) · 1.95 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import torch
from dataset.dataset_zoo import dataset_zoo
from dataset.eval_zoo import eval_zoo
from experiment import init_model_from_dir
from tools.model_io import download_model
from tools.cache_preds import cache_preds
from tabulate import tabulate
def eval_model(dataset_name, *, batch_size=1024, num_workers=8):
    """Evaluate the pretrained model associated with ``dataset_name``.

    Fetches the model directory via ``download_model`` (``force_download=False``),
    builds the model from it, switches it to eval mode, runs it over the
    dataset loaded with ``sets_to_load=('val',)``, and scores the cached
    predictions with the dataset-specific script returned by ``eval_zoo``.

    Args:
        dataset_name: Key passed to ``download_model``, ``dataset_zoo`` and
            ``eval_zoo``.
        batch_size: Batch size of the evaluation DataLoader (defaults to the
            previously hard-coded 1024).
        num_workers: Number of DataLoader worker processes (defaults to the
            previously hard-coded 8).

    Returns:
        The first value returned by the evaluation script. Presumably a
        mapping of metric name -> value (the ``__main__`` block indexes it
        with 'EVAL_MPJPE_best' and 'EVAL_stress') — confirm against eval_zoo.
    """
    model_dir = download_model(dataset_name, force_download=False)
    model, _ = init_model_from_dir(model_dir)
    model.eval()

    # Only the third returned dataset is used; the first two are discarded.
    _, _, dataset_test = dataset_zoo(
        dataset_name=dataset_name, sets_to_load=('val',),
        force_download=False)

    # Pinned memory only speeds up host->GPU transfers; requesting it on a
    # machine without CUDA triggers a warning (or an error on older PyTorch).
    loader_test = torch.utils.data.DataLoader(
        dataset_test,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        batch_size=batch_size,
        shuffle=False)  # deterministic order for evaluation

    eval_script, cache_vars, eval_vars = eval_zoo(dataset_name)
    cached_preds = cache_preds(model, loader_test, cache_vars=cache_vars)
    results, _ = eval_script(cached_preds, eval_vars=eval_vars)
    return results


if __name__ == '__main__':
    # Run every benchmark first; the summary table is printed only once all
    # evaluations have finished.
    benchmark_names = ('h36m', 'h36m_hourglass', 'pascal3d_hrnet',
                       'pascal3d', 'up3d_79kp')
    metric_keys = ('EVAL_MPJPE_best', 'EVAL_stress')

    results = {name: eval_model(name) for name in benchmark_names}

    print('\n\nRESULTS:')
    table = [
        [name] + [metrics[key] for key in metric_keys]
        for name, metrics in results.items()
    ]
    print(tabulate(table, headers=['dataset', 'MPJPE', 'Stress']))
# Reference output of a full run:
# RESULTS:
# dataset                MPJPE      Stress
# --------------  -----------  ----------
# h36m              95.6338     41.5864
# h36m_hourglass   145.021      84.693
# pascal3d_hrnet    56.8909     40.1775
# pascal3d          36.6413     31.0768
# up3d_79kp          0.0672771   0.0406902