-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathcheck_fitters.py
executable file
·121 lines (106 loc) · 3.87 KB
/
check_fitters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python
"""
Run each bumps fitter on a reference model and check that every fit
converges to the expected overall chisq.

Usage::

    check_fitters.py [fitter ...]

With no arguments every fitter listed in bumps.fitters.FIT_AVAILABLE_IDS is
tried.  The model and its expected chisq are chosen in main() (currently
doc/examples/curvefit/curve.py with a target chisq of 1.760).  Each fit is
run through run.py in a child process, with its output stored in a
temporary directory.  The script exits with status 1 if any fitter fails.
"""
from __future__ import print_function
import sys
import os
from os.path import join as joinpath, realpath, dirname
import tempfile
import shutil
import glob
import subprocess
import numpy as np
sys.dont_write_bytecode = True


def decode(b):
    """
    Convert subprocess output to text.

    Objects with a *decode* method (byte strings) are decoded as UTF-8.
    Undecodable bytes are replaced with U+FFFD instead of raising, so that
    stray non-UTF-8 output from a child process cannot surface as a
    UnicodeDecodeError.  Anything without *decode* (e.g. text that is already
    decoded) is returned unchanged.
    """
    if hasattr(b, 'decode'):
        return b.decode('utf-8', 'replace')
    return b
# Locate the source tree from the script's real location.  Resolve the
# script itself (not just its directory) so a symlinked copy of this script
# still finds run.py and the bumps package next to the original.
ROOT = dirname(realpath(__file__))
# Put the source tree first on the path so the in-tree bumps package is
# imported ahead of any other copy on sys.path.
sys.path.insert(0, ROOT)
# Ask bumps for a list of available fitters.  This import must come after
# the sys.path change above.
from bumps.fitters import FIT_AVAILABLE_IDS  # noqa: E402

RUNPY = joinpath(ROOT, 'run.py')
EXAMPLEDIR = joinpath(ROOT, 'doc', 'examples')
def clear_directory(path, recursive=False):
    """
    Remove every non-directory entry (regular files, symlinks, ...) in *path*.

    If *recursive* is True, subdirectories are emptied and removed as well;
    otherwise they are left in place.  Symbolic links are always unlinked and
    never followed, so a link to a directory outside *path* cannot cause
    that directory's contents to be deleted.

    This does not remove *path* itself.  Use *shutil.rmtree* if you want to
    delete the entire tree.
    """
    for name in os.listdir(path):
        target = joinpath(path, name)
        # os.path.isdir follows symlinks, so check for a link first;
        # otherwise a link to a directory would be recursed into.
        if os.path.islink(target) or not os.path.isdir(target):
            os.unlink(target)
        elif recursive:
            clear_directory(target, recursive)
            os.rmdir(target)
def run_fit(fit_args, model_args, store, seed=1):
    """
    Run a single fit through run.py in a child Python process.

    The child is started as ``python run.py <fit_args> <model_args>
    --store=<store> --seed=<seed> --batch``.  Its stderr is merged into its
    stdout, and the stripped, decoded output is echoed once the child exits.

    Raises KeyboardInterrupt when the child exits with an error and its
    output mentions KeyboardInterrupt, so the interruption propagates to the
    caller.  Raises RuntimeError (its message includes the command line) for
    any other non-zero exit.
    """
    command_parts = [sys.executable, RUNPY] + fit_args + model_args
    command_parts += ['--store=' + store, '--seed=%d' % seed, '--batch']
    try:
        raw = subprocess.check_output(command_parts, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as exc:
        text = decode(exc.output.strip())
        if text:
            print(text)
        if "KeyboardInterrupt" not in text:
            raise RuntimeError("fit failed:\n" + " ".join(command_parts))
        raise KeyboardInterrupt()
    text = decode(raw.strip())
    if text:
        print(text)
def check_fit(fitter, store, targets):
    """
    Check the overall chisq values reported in a fit's ``.err`` file.

    *fitter* is the fitter id, used only in error messages.  *store* is the
    directory the fit wrote its output to, and it must contain exactly one
    ``*.err`` file.  Each line of that file starting with
    ``[overall chisq=`` is matched, in order, against the next entry of
    *targets*, and its value must lie within 1e-2 of that target.

    Raises ValueError if there is not exactly one ``.err`` file, or if a
    chisq value cannot be parsed as a number.  Raises AssertionError if a
    value misses its target, or if the number of chisq lines differs from
    ``len(targets)``.  The AssertionErrors are raised explicitly rather than
    with ``assert`` so the checks still run under ``python -O``.
    """
    errfiles = glob.glob(joinpath(store, "*.err"))
    if not errfiles:
        raise ValueError("error in %s: no err file created"%fitter)
    elif len(errfiles) > 1:
        raise ValueError("error in %s: too many err files created"%fitter)

    prefix = "[overall chisq="
    model_index = 0
    with open(errfiles[0]) as fid:
        for line in fid:
            if not line.startswith(prefix):
                continue
            if model_index >= len(targets):
                raise AssertionError(
                    "error in %s: too many models found"%fitter)
            # Keep only the number: drop an uncertainty suffix such as
            # "(15)" and anything following a ',' or ']'.  float() accepts
            # "inf"/"nan" in any case, so infinite chisq needs no special
            # handling.
            text = line[len(prefix):]
            for sep in ("(", ",", "]"):
                text = text.split(sep)[0]
            value = float(text)
            target = targets[model_index]
            # Written as "not <" so that a NaN value also fails the check.
            if not abs(value - target) < 1e-2:
                raise AssertionError(
                    "error in %s: expected %.3f but got %.3f"
                    % (fitter, target, value))
            model_index += 1
    if model_index != len(targets):
        raise AssertionError("error in %s: not enough models found"%fitter)
def run_fits(model_args, store, fitters=FIT_AVAILABLE_IDS, seed=1, target=0):
    """
    Run every fitter in *fitters* on the same model and return the failures.

    For each fitter id, the model described by *model_args* is fitted with
    run_fit() using *store* and *seed*.  The single reported overall chisq is
    then checked against *target*.  Any Exception raised during the fit or
    the check is printed, and that fitter id is recorded as failed.
    Exceptions that are not Exception subclasses (e.g. KeyboardInterrupt)
    propagate.  The top level of *store* is cleared after each fitter.

    Returns the list of fitter ids that failed, in the order they were run.
    """
    def succeeded(fitter_id):
        # True if the fit ran and hit the target; prints the reason otherwise.
        try:
            run_fit(["--fit=" + fitter_id], model_args, store, seed=seed)
            check_fit(fitter_id, store, [target])
        except Exception as exc:
            print(str(exc))
            return False
        return True

    failures = []
    for fitter_id in fitters:
        print("====== fitter: %s" % fitter_id)
        passed = succeeded(fitter_id)
        clear_directory(store)
        if not passed:
            failures.append(fitter_id)
    return failures
def main():
    """
    Check every requested fitter against the reference model.

    Fitter ids are taken from the command line.  When none are given, all
    ids in FIT_AVAILABLE_IDS are used.  Fit output goes to a temporary
    directory, which is removed even if the run is interrupted.  The process
    exits with status 1 when any fitter fails.
    """
    fitters = sys.argv[1:] if len(sys.argv) > 1 else FIT_AVAILABLE_IDS

    # TODO: use a test function that defines residuals.  An alternative
    # model is EXAMPLEDIR/test_functions/model.py with arguments
    # ['gauss', '3'] (or '"fk(rosenbrock, 3)"') and target chisq 0.
    model_args = [joinpath(EXAMPLEDIR, "curvefit", "curve.py")]
    target = 1.760
    seed = 1

    store = tempfile.mkdtemp(prefix="bumps-test-")
    try:
        failed = run_fits(model_args, store, fitters=fitters, seed=seed,
                          target=target)
    finally:
        # run_fit re-raises KeyboardInterrupt, which run_fits does not catch;
        # clean up the temporary store in that case too.
        shutil.rmtree(store)

    if failed:
        print("======")
        print("Fits failed for: %s"%(", ".join(failed),))
        sys.exit(1)


if __name__ == "__main__":
    main()