-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathtest_numba_type_checking.py
170 lines (134 loc) · 5.9 KB
/
test_numba_type_checking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import typing
import unittest
from datetime import datetime
import csp
from csp import ts
class TestNumbaTypeChecking(unittest.TestCase):
    """Type-checking tests for nodes declared with ``csp.numba_node``.

    Every test is currently skipped (see ``_NUMBA_SKIP_REASON``). They are kept
    so the expected graph-build and run-time type-checking behavior stays
    documented.
    """

    # Shared skip reason so all four tests stay in sync.
    _NUMBA_SKIP_REASON = "numba not yet used, tests fail on newer numba we get in our 3.8 build"

    @unittest.skip(_NUMBA_SKIP_REASON)
    def test_graph_build_type_checking(self):
        """Wrong ts / scalar argument types raise TypeError while the graph is built."""

        @csp.numba_node
        def typed_ts(x: ts[int]):
            if csp.ticked(x):
                pass

        @csp.numba_node
        def typed_scalar(x: ts[int], y: str):
            if csp.ticked(x):
                pass

        @csp.graph
        def graph():
            i = csp.const(5)
            typed_ts(i)
            typed_scalar(i, "xyz")

            with self.assertRaisesRegex(TypeError, "Expected ts\\[int\\] for argument 'x', got ts\\[str\\]"):
                s = csp.const("xyz")
                ## THIS SHOULD RAISE, passing ts[str] but typed takes ts[int]
                typed_ts(s)

            with self.assertRaisesRegex(TypeError, "Expected str for argument 'y', got 123 \\(int\\)"):
                ## THIS SHOULD RAISE, passing int instead of str
                typed_scalar(i, 123)

        csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))

    @unittest.skip(_NUMBA_SKIP_REASON)
    def test_runtime_type_check(self):
        """A node declared to output ts[int] accepts an int input and fails at run time for a str input."""

        ## native output type
        @csp.numba_node
        def typed_int(x: ts["T"]) -> ts[int]:
            if csp.ticked(x):
                return x

        # TODO: Uncomment
        # @csp.numba_node
        # def typed_alarm(v: '~T', alarm_type: 'V') -> outputs(ts['V']):
        #     with csp.alarms():
        #         alarm = csp.alarm( 'V' )
        #     with csp.start():
        #         csp.schedule_alarm(alarm, timedelta(), v)
        #
        #     if csp.ticked(alarm):
        #         return alarm

        # Valid
        csp.run(typed_int, csp.const(5), starttime=datetime(2020, 2, 7))

        # Invalid: "T" resolves to str, which does not match the declared ts[int] output
        with self.assertRaisesRegex(RuntimeError, "Unable to resolve getter function for type.*"):
            csp.run(typed_int, csp.const("5"), starttime=datetime(2020, 2, 7))

        # TODO: uncomment
        # # valid
        # csp.run(typed_alarm, 5, int, starttime=datetime(2020, 2, 7))
        # csp.run(typed_alarm, 5, object, starttime=datetime(2020, 2, 7))
        # csp.run(typed_alarm, [1, 2, 3], [int], starttime=datetime(2020, 2, 7))
        #
        # # Invalid
        # with self.assertRaisesRegex(TypeError,
        #         '"typed_alarm" node expected output type on output #0 to be of type "str" got type "int"'):
        #     csp.run(typed_alarm, 5, str, starttime=datetime(2020, 2, 7))
        #
        # with self.assertRaisesRegex(TypeError,
        #         '"typed_alarm" node expected output type on output #0 to be of type "bool" got type "int"'):
        #     csp.run(typed_alarm, 5, bool, starttime=datetime(2020, 2, 7))
        #
        # with self.assertRaisesRegex(TypeError,
        #         '"typed_alarm" node expected output type on output #0 to be of type "str" got type "list"'):
        #     csp.run(typed_alarm, [1, 2, 3], str, starttime=datetime(2020, 2, 7))

    @unittest.skip(_NUMBA_SKIP_REASON)
    def test_dict_type_resolutions(self):
        """Dict-typed scalar arguments (``{K: V}`` and ``typing.Dict``) are checked at graph build time."""

        @csp.numba_node
        def typed_dict_int_int(x: {int: int}):
            pass

        @csp.numba_node
        def typed_dict_int_int2(x: typing.Dict[int, int]):
            pass

        # Annotation now matches the node's name (it previously said {int: int}).
        @csp.numba_node
        def typed_dict_int_float(x: {int: float}):
            pass

        @csp.numba_node
        def typed_dict_float_float(x: {float: float}):
            pass

        @csp.numba_node
        def typed_dict(x: {"T": "V"}):
            pass

        @csp.numba_node
        def deep_nested_generic_resolution(x: "T1", y: "T2", z: {"T1": {"T2": [{"T1"}]}}):
            pass

        @csp.graph
        def graph():
            d_i_i = csp.const({1: 2, 3: 4})
            csp.add_graph_output("o1", d_i_i)
            # Ok int dict expected
            typed_dict_int_int({1: 2, 3: 4})
            # Ok int dict expected
            typed_dict_int_int2({1: 2, 3: 4})
            # int keys/values are accepted where floats are expected
            typed_dict_float_float({1: 2})
            typed_dict_float_float({1.0: 2})
            typed_dict_float_float({})

            with self.assertRaisesRegex(TypeError, r"Expected typing\.Dict\[int, int\] for argument 'x', got .*"):
                # Passing a float value instead of expected ints
                typed_dict_int_int2({1: 2, 3: 4.0})

            l_good = csp.const.using(T={int: float})({})
            csp.add_graph_output("o2", l_good)
            l_good = csp.const.using(T={int: float})({2: 1})
            csp.add_graph_output("o3", l_good)
            l_good = csp.const.using(T={int: float})({2: 1.0})
            csp.add_graph_output("o4", l_good)

        csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))

    @unittest.skip(_NUMBA_SKIP_REASON)
    def test_list_type_resolutions(self):
        """List-typed scalar arguments (``[T]`` and ``typing.List``) are checked at graph build time."""

        @csp.numba_node
        def typed_list_int(x: [int]):
            pass

        @csp.numba_node
        def typed_list_int2(x: typing.List[int]):
            pass

        @csp.numba_node
        def typed_list_float(x: [float]):
            pass

        # Decorated like the graphs in the other tests (it was previously a plain function).
        @csp.graph
        def graph():
            l_i = csp.const([1, 2, 3, 4])
            # Route the const to an output (as the dict test does) instead of leaving it unused.
            csp.add_graph_output("o1", l_i)
            typed_list_int([])
            typed_list_int([1, 2, 3])
            typed_list_int2([1, 2, 3])
            # int elements are accepted where floats are expected
            typed_list_float([1, 2, 3])
            typed_list_float([1, 2, 3.0])

            with self.assertRaisesRegex(TypeError, r"Expected typing\.List\[int\] for argument 'x', got .*"):
                # Passing a float value instead of expected ints
                typed_list_int([1, 2, 3.0])

        csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1))
if __name__ == "__main__":
    # Running the module directly hands control to the unittest CLI runner.
    unittest.main()