"""Test that ASR models build, run forward/backward, and step in each dtype/device combination."""

import pytest
import torch

from espnet.nets.asr_interface import dynamic_import_asr


@pytest.mark.parametrize(
    "dtype, device, model, conf",
    [
        (dtype, device, nn, conf)
        for nn, conf in [
            (
                "transformer",
                dict(adim=4, eunits=3, dunits=3, elayers=2, dlayers=2, mtlalpha=0.0),
            ),
            (
                "transformer",
                dict(
                    adim=4,
                    eunits=3,
                    dunits=3,
                    elayers=2,
                    dlayers=2,
                    mtlalpha=0.5,
                    ctc_type="builtin",
                ),
            ),
            (
                "transformer",
                dict(
                    adim=4,
                    eunits=3,
                    dunits=3,
                    elayers=2,
                    dlayers=2,
                    mtlalpha=0.5,
                    ctc_type="builtin",
                ),
            ),
            (
                "rnn",
                dict(adim=4, eunits=3, dunits=3, elayers=2, dlayers=2, mtlalpha=0.0),
            ),
            (
                "rnn",
                dict(
                    adim=4,
                    eunits=3,
                    dunits=3,
                    elayers=2,
                    dlayers=2,
                    mtlalpha=0.5,
                    ctc_type="builtin",
                ),
            ),
            (
                "rnn",
                dict(
                    adim=4,
                    eunits=3,
                    dunits=3,
                    elayers=2,
                    dlayers=2,
                    mtlalpha=0.5,
                    ctc_type="builtin",
                ),
            ),
        ]
        for dtype in ("float16", "float32", "float64")
        for device in ("cpu", "cuda")
    ],
)
def test_train_pytorch_dtype(dtype, device, model, conf):
    # Skip combinations that the current environment cannot run.
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("no cuda device is available")
    if device == "cpu" and dtype == "float16":
        pytest.skip("cpu float16 implementation is not available in pytorch yet")

    # Build a small model and move it to the requested dtype/device.
    idim = 10
    odim = 10
    model = dynamic_import_asr(model, "pytorch").build(idim, odim, **conf)
    dtype = getattr(torch, dtype)
    device = torch.device(device)
    model.to(dtype=dtype, device=device)

    # Dummy batch: two utterances of 10 and 7 frames with 3-token targets.
    x = torch.rand(2, 10, idim, dtype=dtype, device=device)
    ilens = torch.tensor([10, 7], device=device)
    y = torch.randint(1, odim, (2, 3), device=device)
    opt = torch.optim.Adam(model.parameters())

    # One forward/backward/step cycle; the loss must keep the requested dtype
    # and gradients must be populated.
    loss = model(x, ilens, y)
    assert loss.dtype == dtype
    model.zero_grad()
    loss.backward()
    assert any(p.grad is not None for p in model.parameters())
    opt.step()
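

# A minimal way to run only a subset of these parametrizations locally
# (assuming pytest is installed in the environment), e.g. the float32/CPU cases:
#
#     pytest test_train_dtype.py -k "float32 and cpu"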