test_scheduler.py
import chainer
import numpy
import pytest
import torch

from espnet.scheduler import scheduler
from espnet.scheduler.chainer import ChainerScheduler
from espnet.scheduler.pytorch import PyTorchScheduler


@pytest.mark.parametrize("name", scheduler.SCHEDULER_DICT.keys())
def test_scheduler(name):
    # Every registered scheduler should build from a key and return float scales.
    s = scheduler.dynamic_import_scheduler(name).build("lr")
    assert s.key == "lr"
    assert isinstance(s.scale(0), float)
    assert isinstance(s.scale(1000), float)


def test_pytorch_scheduler():
    # The optimizer's lr should track the Noam scale at each step.
    warmup = 30000
    s = scheduler.NoamScheduler.build("lr", warmup=warmup)
    net = torch.nn.Linear(2, 1)
    o = torch.optim.SGD(net.parameters(), lr=1.0)
    so = PyTorchScheduler([s], o)
    so.step(0)
    for g in o.param_groups:
        assert g["lr"] == s.scale(0)
    so.step(warmup)
    for g in o.param_groups:
        # At the end of warmup the scale should have reached ~1.0.
        numpy.testing.assert_allclose(g["lr"], 1.0, rtol=1e-4)


def test_chainer_scheduler():
    # Same check as above, driving a Chainer optimizer instead of PyTorch.
    warmup = 30000
    s = scheduler.NoamScheduler.build("lr", warmup=warmup)
    net = chainer.links.Linear(2, 1)
    o = chainer.optimizers.SGD(lr=1.0)
    o.setup(net)
    so = ChainerScheduler([s], o)
    so.step(0)
    assert o.lr == s.scale(0)
    so.step(warmup)
    numpy.testing.assert_allclose(o.lr, 1.0, rtol=1e-4)
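

# --- Illustrative sketch, not part of the original espnet test suite ---
# The assertions above expect scale(warmup) to be ~1.0. A minimal sketch of a
# standard Noam schedule, normalized so its peak value is 1.0, shows why: the
# linear warmup term and the inverse-sqrt decay term meet at step == warmup.
# espnet's NoamScheduler may differ in details (e.g. step offsets), so this is
# only an assumption-laden reference computation, not its implementation.
def _noam_scale_sketch(step, warmup):
    step = max(step, 1)
    normalize = warmup ** 0.5  # equals 1 / (warmup * warmup ** -1.5)
    return normalize * min(step ** -0.5, step * warmup ** -1.5)


if __name__ == "__main__":
    # At step == warmup both terms equal warmup ** -0.5, so the scale is 1.0.
    numpy.testing.assert_allclose(_noam_scale_sketch(30000, 30000), 1.0, rtol=1e-4)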