Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cli/paraformer] ali-paraformer inference #2067

Merged
merged 21 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[cli/paraformer] ali-paraformer load and infer work
  • Loading branch information
Mddct committed Oct 20, 2023
commit 000c7af5c36afafd3fa3ae4599af237e902c67bb
214 changes: 214 additions & 0 deletions wenet/paraformer/experiment/attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
from typing import Optional, Tuple
from wenet.transformer.attention import MultiHeadedAttention
from torch import nn
import math
import torch


class MultiHeadedAttentionSANM(MultiHeadedAttention):
"""Multi-Head Attention layer.
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""

def __init__(self,
n_head,
in_feat,
n_feat,
dropout_rate,
kernel_size,
sanm_shfit=0):
"""Construct an MultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
# We assume d_v always equals d_k
# self.linear_q = nn.Linear(n_feat, n_feat)
# self.linear_k = nn.Linear(n_feat, n_feat)
# self.linear_v = nn.Linear(n_feat, n_feat)
del self.linear_q, self.linear_k, self.linear_v
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)

self.fsmn_block = nn.Conv1d(n_feat,
n_feat,
kernel_size,
stride=1,
padding=0,
groups=n_feat,
bias=False)
# padding
self.left_padding = (kernel_size - 1) // 2
if sanm_shfit > 0:
self.left_padding = self.left_padding + sanm_shfit
self.right_padding = kernel_size - 1 - self.left_padding
self.pad_fn = nn.ConstantPad1d((self.left_padding, self.right_padding),
0.0)

def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

x = query
b, t, _ = x.size()
q_k_v = self.linear_q_k_v(x)
q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
q = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
1, 2) # (batch, head, time1, d_k)
k = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
1, 2) # (batch, head, time2, d_k)
v = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
1, 2) # (batch, head, time2, d_k)

return q, k, v

def forward_fsmn(self,
inputs: torch.Tensor,
mask: torch.Tensor,
mask_shfit_chunk=None):
b, _, t, _ = inputs.size()
inputs = inputs.transpose(1, 2).view(b, t, -1)
if mask.size(2) > 0: # time2 > 0
# TODO(Mddct): make sure mask is right
if mask_shfit_chunk is not None:
mask = mask * mask_shfit_chunk
mask = mask.transpose(1, 2) # [B,T,1]
inputs = inputs * mask
x = inputs.transpose(1, 2)
# x = torch.nn.functional.pad(x, (self.left_padding, self.right_padding),
# value=0.0,
# mode='constant')
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x += inputs
x = self.dropout(x)
return x * mask

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
mask_shfit_chunk: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
q, k, v = self.forward_qkv(query, key, value)
if cache.size(0) > 0:
key_cache, value_cache = torch.split(cache,
cache.size(-1) // 2,
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
# NOTE(Mddct): we need know fsmn_memory's cache, but paraformer is nonstreamming
# refactor later if streaming model is available
new_cache = torch.cat((k, v), dim=-1)
fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
att = self.forward_attention(v, scores, mask)
return att + fsmn_memory, new_cache


class DummyMultiHeadSANM(MultiHeadedAttentionSANM):
"""A dummy multihead attention for Paraformer befroe cross attention
"""

def __init__(self,
n_head,
in_feat,
n_feat,
dropout_rate,
kernel_size,
sanm_shfit=0):
super().__init__(n_head, in_feat, n_feat, dropout_rate, kernel_size,
sanm_shfit)
del self.linear_q_k_v
del self.linear_out

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
mask_shfit_chunk: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:

inputs = query

x = inputs.transpose(1, 2)
x = self.pad_fn(x)
# TODO(Mddct): cache here for future streaming
cache: torch.Tensor
x = self.fsmn_block(x)
x = x.transpose(1, 2)
if x.size(1) != inputs.size(1):
inputs = inputs[:, -1, :]

x = x + inputs
x = self.dropout(x)
if mask is not None:
x = x * mask.transpose(1, 2)
return x, cache


class MultiHeadAttentionCross(MultiHeadedAttentionSANM):

def __init__(self,
n_head,
in_feat,
n_feat,
dropout_rate,
kernel_size,
sanm_shfit=0,
target_size: Optional[int] = None):
super().__init__(n_head, in_feat, n_feat, dropout_rate, kernel_size,
sanm_shfit)
del self.linear_q_k_v
del self.fsmn_block
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k_v = nn.Linear(
n_feat if target_size is None else target_size, n_feat * 2)

def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# NOTE(Mddct): here value == key
_ = value

x = query
b = x.size(0)
q = self.linear_q(x)
q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(
1, 2) # (batch, head, time1, d_k)

k_v = self.linear_k_v(key)
k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(
1, 2) # (batch, head, time2, d_k)
v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(
1, 2) # (batch, head, time2, d_k)

return q_h, k_h, v_h

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
mask_shfit_chunk: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
q, k, v = self.forward_qkv(query, key, key)
q = q * self.d_k**(-0.5)
scores = torch.matmul(q, k.transpose(-2, -1))

# TODO(Mddct): support future streaming paraformer
cache: torch.Tensor
return self.forward_attention(v, scores, mask), cache
42 changes: 42 additions & 0 deletions wenet/paraformer/experiment/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# network architecture
# encoder related
encoder: SanEncoder
encoder_conf:
output_size: 512 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 50 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
input_layer: 'conv2d' # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: true
kernel_size: 11
sanm_shfit: 0

# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 16
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1
att_layer_num: 16
kernel_size: 11
sanm_shfit: 0

predictor_conf:
idim: 512
threshold: 1.0
l_order: 1
r_order: 1
tail_threshold: 0.45
cnn_groups: 1
# smooth_factor2: 0.25
# noise_threshold2: 0.01
# upsample_times: 3
# use_cif1_cnn: false # TODO: support in the future, has no effect on model wer (timestamp related)
# upsample_type: cnn_blstm # TODO: support in the future, has no effect on model wer (timestamp related)
Loading
Loading