# test_e2e_asr_maskctc.py
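"""Tests for the Mask CTC end-to-end ASR model (pytorch backend)."""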
import argparse

import pytest
import torch

from espnet.nets.pytorch_backend.e2e_asr_maskctc import E2E
from espnet.nets.pytorch_backend.maskctc.add_mask_token import mask_uniform
from espnet.nets.pytorch_backend.transformer import plot


def make_arg(**kwargs):
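    """Return a Namespace of default Mask CTC model arguments, overridden by kwargs."""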
defaults = dict(
adim=2,
aheads=2,
dropout_rate=0.0,
transformer_attn_dropout_rate=None,
elayers=1,
eunits=2,
dlayers=1,
dunits=2,
sym_space="<space>",
sym_blank="<blank>",
transformer_decoder_selfattn_layer_type="selfattn",
transformer_encoder_selfattn_layer_type="selfattn",
transformer_init="pytorch",
transformer_input_layer="conv2d",
transformer_length_normalized_loss=False,
report_cer=False,
report_wer=False,
mtlalpha=0.3,
lsm_weight=0.001,
wshare=4,
char_list=["<blank>", "a", "e", "<eos>"],
ctc_type="builtin",
)
defaults.update(kwargs)
    return argparse.Namespace(**defaults)


def prepare(args):
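    """Build a small E2E model and a padded toy batch (x, ilens, y) for the tests."""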
idim = 10
odim = len(args.char_list)
model = E2E(idim, odim, args)
batchsize = 2
x = torch.randn(batchsize, 15, idim)
ilens = [15, 10]
n_token = model.odim - 2 # w/o <eos>/<sos>, <mask>
y = (torch.rand(batchsize, 10) * n_token % n_token).long()
olens = [7, 6]
for i in range(batchsize):
x[i, ilens[i] :] = -1
y[i, olens[i] :] = model.ignore_id
data = {}
uttid_list = []
for i in range(batchsize):
data["utt%d" % i] = {
"input": [{"shape": [ilens[i], idim]}],
"output": [{"shape": [olens[i]]}],
}
uttid_list.append("utt%d" % i)
    return model, x, torch.tensor(ilens), y, data, uttid_list


def test_mask():
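    """Check the special token indices and the behavior of mask_uniform."""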
args = make_arg()
model, x, ilens, y, data, uttid_list = prepare(args)
# check <sos>/<eos>, <mask> position
n_char = len(args.char_list) + 1
assert model.sos == n_char - 2
assert model.eos == n_char - 2
assert model.mask_token == n_char - 1
yi, yo = mask_uniform(y, model.mask_token, model.eos, model.ignore_id)
assert (
(yi == model.mask_token).detach().numpy()
== (yo != model.ignore_id).detach().numpy()
    ).all()


def _savefn(*args, **kwargs):
    # No-op save function passed to the attention plotting helper below.
    return


# Setting combining Mask CTC decoding with an intermediate CTC loss.
maskctc_interctc = {
"maskctc_n_iterations": 0,
"maskctc_probability_threshold": 0.5,
"elayers": 2,
"intermediate_ctc_weight": 0.3,
"intermediate_ctc_layer": "1",
}


@pytest.mark.parametrize(
"model_dict",
[
({"maskctc_n_iterations": 1, "maskctc_probability_threshold": 0.0}),
({"maskctc_n_iterations": 1, "maskctc_probability_threshold": 0.5}),
({"maskctc_n_iterations": 2, "maskctc_probability_threshold": 0.5}),
({"maskctc_n_iterations": 0, "maskctc_probability_threshold": 0.5}),
maskctc_interctc,
],
)
def test_transformer_trainable_and_decodable(model_dict):
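    """Run one training step, attention plotting, and Mask CTC decoding."""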
args = make_arg(**model_dict)
model, x, ilens, y, data, uttid_list = prepare(args)
# decoding params
recog_args = argparse.Namespace(
maskctc_n_iterations=args.maskctc_n_iterations,
maskctc_probability_threshold=args.maskctc_probability_threshold,
)
# test training
optim = torch.optim.Adam(model.parameters(), 0.01)
loss = model(x, ilens, y)
optim.zero_grad()
loss.backward()
optim.step()
# test attention plot
attn_dict = model.calculate_all_attentions(x[0:1], ilens[0:1], y[0:1])
plot.plot_multi_head_attention(data, uttid_list, attn_dict, "", savefn=_savefn)
# test decoding
with torch.no_grad():
model.recognize(x[0, : ilens[0]].numpy(), recog_args, args.char_list)