Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cli/paraformer] ali-paraformer inference #2067

Merged
merged 21 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
rm positionwise_feed_forward.py/lfr.py
  • Loading branch information
Mddct committed Oct 23, 2023
commit 3d25e2e8473133d54bb0b39e04bee273bf634d77
72 changes: 0 additions & 72 deletions wenet/paraformer/ali_paraformer/lfr.py

This file was deleted.

104 changes: 101 additions & 3 deletions wenet/paraformer/ali_paraformer/model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
""" NOTE(Mddct): This file is experimental and is used to export paraformer
"""

import math
from typing import Dict, List, Optional, Tuple
import torch
from wenet.cif.predictor import Predictor
from wenet.paraformer.ali_paraformer.attention import (DummyMultiHeadSANM,
MultiHeadAttentionCross,
MultiHeadedAttentionSANM
)
from wenet.paraformer.ali_paraformer.lfr import LFR
from wenet.paraformer.ali_paraformer.positionwise_feed_forward import \
PositionwiseFeedForwardDecoderSANM
from wenet.transformer.search import DecodeResult
from wenet.transformer.encoder import BaseEncoder
from wenet.transformer.decoder import TransformerDecoder
Expand All @@ -20,6 +18,106 @@
from wenet.utils.mask import make_non_pad_mask


class LFR(torch.nn.Module):

def __init__(self, m: int = 7, n: int = 6) -> None:
"""
Actually, this implements stacking frames and skipping frames.
if m = 1 and n = 1, just return the origin features.
if m = 1 and n > 1, it works like skipping.
if m > 1 and n = 1, it works like stacking but only support right frames.
if m > 1 and n > 1, it works like LFR.

"""
super().__init__()

self.m = m
self.n = n

self.left_padding_nums = math.ceil((self.m - 1) // 2)

def forward(self, input: torch.Tensor,
input_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
B, _, D = input.size()
n_lfr = torch.ceil(input_lens / self.n)
# print(n_lfr)
# right_padding_nums >= 0
prepad_nums = input_lens + self.left_padding_nums

right_padding_nums = torch.where(
self.m >= (prepad_nums - self.n * (n_lfr - 1)),
self.m - (prepad_nums - self.n * (n_lfr - 1)),
0,
)
T_all = self.left_padding_nums + input_lens + right_padding_nums

new_len = T_all // self.n

T_all_max = T_all.max().int()

tail_frames_index = (input_lens - 1).view(B, 1, 1).repeat(1, 1,
D) # [B,1,D]

tail_frames = torch.gather(input, 1, tail_frames_index)
tail_frames = tail_frames.repeat(1, right_padding_nums.max().int(), 1)
head_frames = input[:, 0:1, :].repeat(1, self.left_padding_nums, 1)

# stack
input = torch.cat([head_frames, input, tail_frames], dim=1)

index = torch.arange(T_all_max,
device=input.device,
dtype=input_lens.dtype).unsqueeze(0).repeat(
B, 1) # [B, T_all_max]
# [B, T_all_max]
index_mask = index < (self.left_padding_nums + input_lens).unsqueeze(1)

tail_index_mask = torch.logical_not(
index >= (T_all.unsqueeze(1))) & index_mask
tail = torch.ones(T_all_max,
dtype=input_lens.dtype,
device=input.device).unsqueeze(0).repeat(B, 1) * (
T_all_max - 1) # [B, T_all_max]
indices = torch.where(torch.logical_or(index_mask, tail_index_mask),
index, tail)
input = torch.gather(input, 1, indices.unsqueeze(2).repeat(1, 1, D))

input = input.unfold(1, self.m, step=self.n).transpose(2, 3)
# new len
return input.reshape(B, -1, D * self.m), new_len


class PositionwiseFeedForwardDecoderSANM(torch.nn.Module):
"""Positionwise feed forward layer.

Args:
idim (int): Input dimenstion.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.

"""

def __init__(self,
idim,
hidden_units,
dropout_rate,
adim=None,
activation=torch.nn.ReLU()):
"""Construct an PositionwiseFeedForward object."""
super(PositionwiseFeedForwardDecoderSANM, self).__init__()
self.w_1 = torch.nn.Linear(idim, hidden_units)
self.w_2 = torch.nn.Linear(hidden_units,
idim if adim is None else adim,
bias=False)
self.dropout = torch.nn.Dropout(dropout_rate)
self.activation = activation
self.norm = torch.nn.LayerNorm(hidden_units)

def forward(self, x):
"""Forward function."""
return self.w_2(self.norm(self.dropout(self.activation(self.w_1(x)))))


class SinusoidalPositionEncoder(torch.nn.Module):
"""https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/modules/embedding.py#L387
"""
Expand Down
32 changes: 0 additions & 32 deletions wenet/paraformer/ali_paraformer/positionwise_feed_forward.py

This file was deleted.

Loading