-
-
Notifications
You must be signed in to change notification settings - Fork 150
/
Copy pathtest_errors.py
169 lines (126 loc) · 3.08 KB
/
test_errors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.numpy as jnp
import pytest
def _f(x):
    """Return ``jax.nn.relu(x)`` after guarding ``x`` with ``eqx.error_if``.

    The guard uses the predicate ``x < 0`` and the message
    ``"x must be non-negative"``.
    """
    checked = eqx.error_if(x, x < 0, "x must be non-negative")
    return jax.nn.relu(checked)
# Strangely, JAX raises different errors depending on context, so the checks
# that use this accept either a ValueError or a RuntimeError. The same
# `pytest.raises` object is reused by each `with _error:` block in the tests.
_error = pytest.raises((ValueError, RuntimeError))
def test_basic():
    """Call `_f` eagerly and under `jax.jit` with a valid and an invalid input.

    The input 1.0 must not raise; the input -1.0 must raise one of the
    exception types accepted by `_error`.
    """
    variants = (_f, jax.jit(_f))

    # Non-negative input: both variants run cleanly.
    for fn in variants:
        fn(1.0)

    # Negative input: both variants must raise.
    for fn in variants:
        with _error:
            fn(-1.0)
def test_vmap():
    """Call `jax.vmap(_f)`, eagerly and jitted, on valid and invalid batches.

    A batch with no negative entry must not raise; a batch with one or with
    two negative entries must raise one of the types accepted by `_error`.
    """
    batched = jax.vmap(_f)
    variants = (batched, jax.jit(batched))

    all_valid = jnp.array([1.0, 1.0])
    one_invalid = jnp.array([1.0, -1.0])
    all_invalid = jnp.array([-1.0, -1.0])

    for fn in variants:
        fn(all_valid)

    for fn in variants:
        for batch in (one_invalid, all_invalid):
            with _error:
                fn(batch)
def test_jvp():
    """Call `jax.jvp` of `_f`, eagerly and jitted, with valid and invalid primals.

    With primal 1.0 neither tangent sign raises; with primal -1.0 both
    tangent signs must raise one of the types accepted by `_error`.
    """

    def jvp_of_f(primal, tangent):
        return jax.jvp(_f, (primal,), (tangent,))

    tangents = (1.0, -1.0)

    for fn in (jvp_of_f, jax.jit(jvp_of_f)):
        # Valid primal: the tangent's sign does not matter.
        for tangent in tangents:
            fn(1.0, tangent)
        # Invalid primal: must raise for either tangent.
        for tangent in tangents:
            with _error:
                fn(-1.0, tangent)
def test_grad():
    """Call `jax.grad(_f)`, eagerly and jitted, at a valid and an invalid point.

    Input 1.0 must not raise; input -1.0 must raise one of the types
    accepted by `_error`.
    """
    grad_f = jax.grad(_f)

    for fn in (grad_f, jax.jit(grad_f)):
        fn(1.0)
        with _error:
            fn(-1.0)
def test_grad2():
    """Smoke test: a jitted gradient through `error_if` on a tuple runs.

    The differentiated function marks `x` with
    `eqxi.nondifferentiable_backward`, passes `(x, y)` through `eqx.error_if`
    with predicate `z`, and returns `y`. The call is not wrapped in any
    `raises` context, so the test passes only if it completes without error.
    """

    def f(x, y, z):
        x = eqxi.nondifferentiable_backward(x)
        x, y = eqx.error_if((x, y), z, "oops")
        return y

    # Same composition as stacking `@jax.jit` over `@jax.grad`.
    jitted_grad = jax.jit(jax.grad(f))
    jitted_grad(1.0, 1.0, True)
def test_tracetime():
    """`error_if` with the literal predicate `True` raises when called under jit."""

    def f(x):
        return eqx.error_if(x, True, "hi")

    jitted = jax.jit(f)

    with pytest.raises(Exception):
        jitted(1.0)
def test_nan_tracetime():
    """With `on_error="nan"` and a literal `True` predicate, a warning and NaN result.

    The jitted call must emit a `UserWarning`, and its return value must be NaN.
    """

    def f(x):
        return eqx.error_if(x, True, "hi", on_error="nan")

    jitted = jax.jit(f)

    with pytest.warns(UserWarning):
        result = jitted(1.0)

    assert jnp.isnan(result)
def test_nan():
    """With `on_error="nan"` and the predicate passed as an argument, the result is NaN."""

    def f(x, pred):
        return eqx.error_if(x, pred, "hi", on_error="nan")

    result = jax.jit(f)(1.0, True)
    assert jnp.isnan(result)
def test_assert_dce():
    """Run a function containing `eqxi.assert_dce` under jit and with jit disabled.

    Two separately defined, identically written jitted functions are used:
    the first is called normally, the second inside `jax.disable_jit()`.
    Neither call is expected to raise.
    """

    def under_jit(x):
        incremented = x + 1
        eqxi.assert_dce(incremented, msg="oh no")
        return incremented

    jax.jit(under_jit)(1.0)

    def without_jit(x):
        incremented = x + 1
        eqxi.assert_dce(incremented, msg="oh no")
        return incremented

    jitted = jax.jit(without_jit)
    with jax.disable_jit():
        jitted(1.0)
def test_traceback_runtime_eqx(caplog):
    """Check the runtime error produced by `eqx.error_if` under nested `filter_jit`.

    `f` calls `g`, both wrapped in `eqx.filter_jit`; `g` guards its input
    with `eqx.error_if(x, x > 0, "egads")` and is called with 1.0, so the
    predicate is true. The test verifies that:

    * an exception is raised at all (previously a bare ``try``/``except``
      let the test pass vacuously when nothing was raised);
    * nothing was captured by `caplog`;
    * the exception has no `__cause__`;
    * the stripped message starts with "Above is the stack outside of JIT"
      and mentions both "egads" and "EQX_ON_ERROR".
    """

    @eqx.filter_jit
    def f(x):
        return g(x)

    @eqx.filter_jit
    def g(x):
        return eqx.error_if(x, x > 0, "egads")

    # `pytest.raises` fails the test if the call returns normally.
    with pytest.raises(Exception) as excinfo:
        f(jnp.array(1.0))

    e = excinfo.value
    assert caplog.text == ""
    assert e.__cause__ is None
    msg = str(e).strip()
    assert msg.startswith("Above is the stack outside of JIT")
    assert "egads" in msg
    assert "EQX_ON_ERROR" in msg
def test_traceback_runtime_custom():
    """Check that a user exception raised in a callback is not given the Equinox hint.

    `g` (wrapped in `eqx.filter_jit`, called via the `filter_jit`-wrapped
    `f`) invokes `jax.pure_callback` with a callback that raises
    `FooException("egads")`. The test verifies that:

    * an exception is raised at all (previously a bare ``try``/``except``
      let the test pass vacuously when nothing was raised);
    * its message contains "egads";
    * its message does not contain "EQX_ON_ERROR".
    """

    class FooException(Exception):
        pass

    @eqx.filter_jit
    def f(x):
        return g(x)

    @eqx.filter_jit
    def g(x):
        def _raises():
            raise FooException("egads")

        return jax.pure_callback(_raises, x)  # pyright: ignore

    # `pytest.raises` fails the test if the call returns normally.
    with pytest.raises(Exception) as excinfo:
        f(jnp.array(1.0))

    e = excinfo.value
    # `e.__cause__` is deliberately not checked: it varies by Python version
    # and JAX version.
    assert "egads" in str(e)
    assert "EQX_ON_ERROR" not in str(e)