From eb3f3109c255698fa6eb999a39f1eba3fea34eae Mon Sep 17 00:00:00 2001 From: Mddct Date: Thu, 5 Sep 2024 00:22:07 +0800 Subject: [PATCH] add low latency vocos --- .../codec/vocos_low_latency/discriminators.py | 254 ++++++++++++ wenet/codec/vocos_low_latency/train.py | 303 ++++++++++++++ wenet/codec/vocos_low_latency/vocos_my.py | 377 ++++++++++++++++++ 3 files changed, 934 insertions(+) create mode 100644 wenet/codec/vocos_low_latency/discriminators.py create mode 100644 wenet/codec/vocos_low_latency/train.py create mode 100644 wenet/codec/vocos_low_latency/vocos_my.py diff --git a/wenet/codec/vocos_low_latency/discriminators.py b/wenet/codec/vocos_low_latency/discriminators.py new file mode 100644 index 000000000..89b498cd2 --- /dev/null +++ b/wenet/codec/vocos_low_latency/discriminators.py @@ -0,0 +1,254 @@ +from typing import List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import Conv2d +from torch.nn.utils import weight_norm +from torchaudio.transforms import Spectrogram + + +class MultiPeriodDiscriminator(nn.Module): + """ + Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + periods (tuple[int]): Tuple of periods for each discriminator. + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. + """ + + def __init__(self, + periods: Tuple[int, ...] = (2, 3, 5, 7, 11), + num_embeddings: Optional[int] = None): + super().__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(period=p, num_embeddings=num_embeddings) + for p in periods + ]) + + def forward( + self, + y: torch.Tensor, + y_hat: torch.Tensor, + bandwidth_id: Optional[torch.Tensor] = None + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], + List[List[torch.Tensor]], List[List[torch.Tensor]]]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorP(nn.Module): + + def __init__( + self, + period: int, + in_channels: int = 1, + kernel_size: int = 5, + stride: int = 3, + lrelu_slope: float = 0.1, + num_embeddings: Optional[int] = None, + ): + super().__init__() + self.period = period + self.convs = nn.ModuleList([ + weight_norm( + Conv2d(in_channels, + 32, (kernel_size, 1), (stride, 1), + padding=(kernel_size // 2, 0))), + weight_norm( + Conv2d(32, + 128, (kernel_size, 1), (stride, 1), + padding=(kernel_size // 2, 0))), + weight_norm( + Conv2d(128, + 512, (kernel_size, 1), (stride, 1), + padding=(kernel_size // 2, 0))), + weight_norm( + Conv2d(512, + 1024, (kernel_size, 1), (stride, 1), + padding=(kernel_size // 2, 0))), + weight_norm( + Conv2d(1024, + 1024, (kernel_size, 1), (1, 1), + padding=(kernel_size // 2, 0))), + ]) + if num_embeddings is not None: + self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, + embedding_dim=1024) + torch.nn.init.zeros_(self.emb.weight) + + self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, + padding=(1, 0))) + self.lrelu_slope = lrelu_slope + + def forward( + self, + x: torch.Tensor, + cond_embedding_id: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + x = x.unsqueeze(1) + fmap = [] + # 1d to 2d + b, c, 
t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = torch.nn.functional.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for i, l in enumerate(self.convs): + x = l(x) + x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) + if i > 0: + fmap.append(x) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiResolutionDiscriminator(nn.Module): + + def __init__( + self, + fft_sizes: Tuple[int, ...] = (2048, 1024, 512), + num_embeddings: Optional[int] = None, + ): + """ + Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512). + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. + """ + + super().__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorR(window_length=w, num_embeddings=num_embeddings) + for w in fft_sizes + ]) + + def forward( + self, + y: torch.Tensor, + y_hat: torch.Tensor, + bandwidth_id: torch.Tensor = None + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], + List[List[torch.Tensor]], List[List[torch.Tensor]]]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + + def __init__( + self, + window_length: int, + num_embeddings: Optional[int] = None, + channels: int = 32, + hop_factor: float = 0.25, + bands: Tuple[Tuple[float, float], + ...] 
= ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), + (0.75, 1.0)), + ): + super().__init__() + self.window_length = window_length + self.hop_factor = hop_factor + self.spec_fn = Spectrogram(n_fft=window_length, + hop_length=int(window_length * hop_factor), + win_length=window_length, + power=None) + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + convs = lambda: nn.ModuleList([ + weight_norm(nn.Conv2d(2, channels, (3, 9), + (1, 1), padding=(1, 4))), + weight_norm( + nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm( + nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm( + nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm( + nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))), + ]) + self.band_convs = nn.ModuleList( + [convs() for _ in range(len(self.bands))]) + + if num_embeddings is not None: + self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, + embedding_dim=channels) + torch.nn.init.zeros_(self.emb.weight) + + self.conv_post = weight_norm( + nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))) + + def spectrogram(self, x): + # Remove DC offset + x = x - x.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + x = self.spec_fn(x) + x = torch.view_as_real(x) + # x = rearrange(x, "b f t c -> b c t f") + x = x.transpose(1, 3) + # Split into bands + x_bands = [x[..., b[0]:b[1]] for b in self.bands] + return x_bands + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None): + x_bands = self.spectrogram(x) + fmap = [] + x = [] + for band, stack in zip(x_bands, self.band_convs): + for i, layer in enumerate(stack): + band = layer(band) + band = torch.nn.functional.leaky_relu(band, 0.1) + if i > 0: + fmap.append(band) + x.append(band) + x = torch.cat(x, dim=-1) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + + return x, fmap diff --git a/wenet/codec/vocos_low_latency/train.py b/wenet/codec/vocos_low_latency/train.py new file mode 100644 index 000000000..78938eccd --- /dev/null +++ b/wenet/codec/vocos_low_latency/train.py @@ -0,0 +1,303 @@ +from dataclasses import dataclass +from typing import List, Tuple + +import torch +import torchaudio +from wenet.codec.vocos_low_latency.discriminators import ( + MultiPeriodDiscriminator, MultiResolutionDiscriminator) +from wenet.codec.vocos_low_latency.vocos_my import (Vocosv1, vocos_config) +from wenet.utils.scheduler import WarmupLR +from wenet.utils.train_utils import init_distributed, init_summarywriter + + +def compute_discriminator_loss(disc_real_outputs: List[torch.Tensor], + disc_generated_outputs: List[torch.Tensor]): + loss = torch.zeros(1, + device=disc_real_outputs[0].device, + dtype=disc_real_outputs[0].dtype) + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean(torch.clamp(1 - dr, min=0)) + g_loss = torch.mean(torch.clamp(1 + dg, min=0)) + loss += r_loss + g_loss + r_losses.append(r_loss) + g_losses.append(g_loss) + + return loss, r_losses, g_losses + + +def compute_generator_loss( + disc_outputs: List[torch.Tensor] +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + loss = torch.zeros(1, + device=disc_outputs[0].device, + dtype=disc_outputs[0].dtype) + 
gen_losses = []
+    for dg in disc_outputs:
+        g_loss = torch.mean(torch.clamp(1 - dg, min=0))
+        gen_losses.append(g_loss)
+        loss += g_loss
+
+    return loss, gen_losses
+
+
+def compute_feature_matching_loss(
+        fmap_r: List[List[torch.Tensor]],
+        fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
+    loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            loss += torch.mean(torch.abs(rl - gl))
+
+    return loss
+
+
+class MelSpecReconstructionLoss(torch.nn.Module):
+    """
+    L1 distance between the mel-scaled magnitude spectrograms of the ground
+    truth sample and the generated sample.
+    """
+
+    def __init__(
+        self,
+        sample_rate: int = 24000,
+        n_fft: int = 1024,
+        hop_length: int = 256,
+        n_mels: int = 100,
+    ):
+        super().__init__()
+
+        self.mel_spec = torchaudio.transforms.MelSpectrogram(
+            sample_rate=sample_rate,
+            n_fft=n_fft,
+            hop_length=hop_length,
+            n_mels=n_mels,
+            center=True,
+            power=1,
+        )
+
+    def forward(self, y_hat, y) -> torch.Tensor:
+        """
+        Args:
+            y_hat (Tensor): Predicted audio waveform.
+            y (Tensor): Ground truth audio waveform.
+
+        Returns:
+            Tensor: L1 loss between the mel-scaled magnitude spectrograms.
+        """
+        mel_hat = torch.log(torch.clip(self.mel_spec(y_hat), min=1e-7))
+        mel = torch.log(torch.clip(self.mel_spec(y), min=1e-7))
+        loss = torch.nn.functional.l1_loss(mel, mel_hat)
+
+        return loss
+
+
+@dataclass
+class TrainConfig:
+    config: vocos_config
+
+    # number of steps to train with the mel loss only,
+    # before the GAN losses kick in
+    pretrain_mel_steps = 0
+    mrd_loss_coeff = 0.1
+    mel_loss_coeff = 45
+
+    # optimizer config
+    opt_disc_config = {"lr": 0.001, 'betas': (0.8, 0.9)}
+    opt_gen_config = {"lr": 0.001, 'betas': (0.8, 0.9)}
+
+    # scheduler config
+    disc_scheduler_config = {'warmup_steps': 25000}
+    gen_scheduler_config = {'warmup_steps': 25000}
+
+
+@dataclass
+class TrainState:
+    model: Vocosv1
+    multiperioddisc: MultiPeriodDiscriminator
+    multiresddisc: MultiResolutionDiscriminator
+
+    scheduler_d: torch.optim.lr_scheduler._LRScheduler
+    scheduler_g: torch.optim.lr_scheduler._LRScheduler
+    optimizer_d: torch.optim.Optimizer
+    optimizer_g: torch.optim.Optimizer
+
+    def __call__(self, input, input_lens):
+        return self.model(input, input_lens)
+
+
+def create_state(model, multiperioddisc, multiresddisc, opt_disc, opt_gen,
+                 opt_d_scheduler, opt_g_scheduler):
+    return TrainState(model=model,
+                      multiperioddisc=multiperioddisc,
+                      multiresddisc=multiresddisc,
+                      optimizer_d=opt_disc,
+                      optimizer_g=opt_gen,
+                      scheduler_d=opt_d_scheduler,
+                      scheduler_g=opt_g_scheduler)
+
+
+def train_step(batch,
+               state: TrainState,
+               train_config: TrainConfig,
+               mel_loss_fn,
+               global_step: int = 0,
+               **kwargs):
+
+    mels, mels_lens = batch['mels'], batch['mels_lens']
+    audio, _ = batch['wavs'], batch['wavs_lens']
+    metrics = {}
+    for idx in [0, 1]:
+        # 1 train discriminator
+        if idx == 0 and global_step >= train_config.pretrain_mel_steps:
+            with torch.no_grad():
+                audio_hat, audio_hat_lens = state(mels, mels_lens)
+
+            real_score_mp, gen_score_mp, _, _ = state.multiperioddisc(
+                y=audio,
+                y_hat=audio_hat,
+                **kwargs,
+            )
+            real_score_mrd, gen_score_mrd, _, 
_ = state.multiresddisc( + y=audio, + y_hat=audio_hat, + **kwargs, + ) + loss_mp, loss_mp_real, _ = compute_discriminator_loss( + disc_real_outputs=real_score_mp, + disc_generated_outputs=gen_score_mp, + ) + loss_mrd, loss_mrd_real, _ = compute_discriminator_loss( + disc_real_outputs=real_score_mrd, + disc_generated_outputs=gen_score_mrd, + ) + loss_mp /= len(loss_mp_real) + loss_mrd /= len(loss_mrd_real) + loss = loss_mp + train_config.mrd_loss_coeff * loss_mrd + + metrics["discriminator/total"] = loss + metrics["discriminator/multi_period_loss"] = loss_mp + metrics["discriminator/multi_res_loss"] = loss_mrd + state.optimizer_d.zero_grad() + loss.backward() + state.optimizer_d.step() + + else: + audio_hat, audio_hat_lens = state(mels, mels_lens) + if global_step >= train_config.pretrain_mel_steps: + _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = state.multiperioddisc( + y=audio, + y_hat=audio_hat, + **kwargs, + ) + _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = state.multiresddisc( + y=audio, + y_hat=audio_hat, + **kwargs, + ) + loss_gen_mp, list_loss_gen_mp = compute_generator_loss( + disc_outputs=gen_score_mp) + loss_gen_mrd, list_loss_gen_mrd = compute_generator_loss( + disc_outputs=gen_score_mrd) + loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp) + loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd) + loss_fm_mp = compute_feature_matching_loss( + fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp) + loss_fm_mrd = compute_feature_matching_loss( + fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd) + + else: + loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = 0 + + mel_loss = mel_loss_fn(audio_hat, audio) + loss = (loss_gen_mp + train_config.mrd_loss_coeff * loss_gen_mrd + + loss_fm_mp + train_config.mrd_loss_coeff * loss_fm_mrd + + train_config.mel_loss_coeff * mel_loss) + state.optimizer_g.zero_grad() + loss.backward() + state.optimizer_g.step() + + metrics["generator/total_loss"] = loss + metrics["generator/mel_loss"] = mel_loss + metrics["generator/multi_period_loss"] = loss_gen_mp + metrics["generator/multi_res_loss"] = loss_gen_mrd + metrics["generator/feature_matching_mp"] = loss_fm_mp + metrics["generator/feature_matching_mrd"] = loss_fm_mrd + + return metrics + + +def main(): + # TODO: args + args = "" + _, _, rank = init_distributed(args) + + # init dataset + train_iter = ... + eval_iter = ... 
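+
+    # Illustrative sketch only (shapes below are hypothetical): train_step
+    # expects each batch to be a dict of padded mels and waveforms, e.g.
+    #   batch = {
+    #       'mels': torch.zeros(8, 200, 100),       # (B, frames, n_mels)
+    #       'mels_lens': torch.full((8,), 200),     # valid frames per utt
+    #       'wavs': torch.zeros(8, 200 * 256),      # (B, samples), hop = 256
+    #       'wavs_lens': torch.full((8,), 200 * 256),
+    #   }
+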
+    # init model
+    model_config = vocos_config()
+    model = Vocosv1(model_config)
+    multiperioddisc = MultiPeriodDiscriminator()
+    multiresddisc = MultiResolutionDiscriminator()
+
+    # train config
+    config = TrainConfig(config=model_config)
+    # Tensorboard
+    writer = init_summarywriter(args)
+
+    disc_params = [
+        {
+            "params": multiperioddisc.parameters()
+        },
+        {
+            "params": multiresddisc.parameters()
+        },
+    ]
+    gen_params = [
+        {
+            "params": model.parameters()
+        },
+    ]
+    opt_disc = torch.optim.AdamW(disc_params, **config.opt_disc_config)
+    opt_gen = torch.optim.AdamW(gen_params, **config.opt_gen_config)
+    scheduler_disc = WarmupLR(opt_disc, **config.disc_scheduler_config)
+    scheduler_gen = WarmupLR(opt_gen, **config.gen_scheduler_config)
+
+    import torch.distributed as dist
+    if dist.is_initialized():
+        model = torch.nn.parallel.DistributedDataParallel(
+            model, find_unused_parameters=False)
+        multiperioddisc = torch.nn.parallel.DistributedDataParallel(
+            multiperioddisc, find_unused_parameters=False)
+        multiresddisc = torch.nn.parallel.DistributedDataParallel(
+            multiresddisc, find_unused_parameters=False)
+
+    train_state = create_state(model, multiperioddisc, multiresddisc, opt_disc,
+                               opt_gen, scheduler_disc, scheduler_gen)
+
+    global_step = 0
+    mel_loss_fn = MelSpecReconstructionLoss()
+    for batch in train_iter:
+        metrics = train_step(batch, train_state, config, mel_loss_fn,
+                             global_step)
+        train_state.scheduler_d.step()
+        train_state.scheduler_g.step()
+        global_step += 1
+        # TODO:
+        # write metrics to tensorboard
+        # interval logging
diff --git a/wenet/codec/vocos_low_latency/vocos_my.py b/wenet/codec/vocos_low_latency/vocos_my.py
new file mode 100644
index 000000000..679329a23
--- /dev/null
+++ b/wenet/codec/vocos_low_latency/vocos_my.py
@@ -0,0 +1,377 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple
+import torch
+from wenet.transformer.attention import T_CACHE, MultiHeadedAttention
+from wenet.transformer.convolution import ConvolutionModule
+
+from wenet.transformer.encoder_layer import ConformerEncoderLayer
+from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES, WENET_MLP_CLASSES
+from wenet.utils.mask import causal_or_lookahead_mask, make_non_pad_mask
+
+
+@dataclass
+class vocos_config:
+    # convolution config
+    causal: bool = True
+    activation: str = 'gelu'
+    conv_norm: str = 'batch_norm'
+    conv_bias: bool = False
+    norm_eps: float = 1e-6
+    linear_units: int = 1536
+    kernel_size: int = 15
+
+    # conformer config
+    att_blocks: int = 3
+    input_size: int = 100
+    output_size: int = 256
+    attention_heads: int = 4
+    attention_dropout_rate: float = 0.1
+    qkv_bias: bool = False
+    use_sdpa: bool = False
+    n_kv_head: int = 1  # MQA
+
+    # head config
+    n_fft: int = 1024  # 2048
+    hop_length: int = 256  # 640
+    padding: str = 'center'
+
+    dropout_rate: float = 0.1
+    norm_type: str = 'rms_norm'
+    mlp_type: str = 'position_wise_feed_forward'
+    mlp_bias: bool = False
+    head_dim: int = 512
+
+
+# https://github.com/gemelo-ai/vocos/blob/main/vocos/heads.py#L26
+class ISTFTHead(torch.nn.Module):
+    """
+    ISTFT Head module for predicting STFT complex coefficients.
+
+    Args:
+        dim (int): Hidden dimension of the model.
+        n_fft (int): Size of Fourier transform.
+        hop_length (int): The distance between neighboring sliding window frames, which should align with
+            the resolution of the input features.
+        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "center".
+    """
+
+    def __init__(self, config: vocos_config):
+        super().__init__()
+        self.dim = config.output_size
+        self.n_fft = config.n_fft
+        self.hop_length = config.hop_length
+        self.padding = config.padding
+        self.win_length = self.n_fft
+        # register as a buffer so the window follows the module's device
+        self.register_buffer("window", torch.hann_window(self.win_length))
+
+        out_dim = self.n_fft + 2
+        self.out = torch.nn.Linear(self.dim, out_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the ISTFTHead module.
+
+        Args:
+            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                L is the sequence length, and H denotes the model dimension.
+
+        Returns:
+            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+        """
+        x = self.out(x).transpose(1, 2)
+        mag, p = x.chunk(2, dim=1)
+        mag = torch.exp(mag)
+        mag = torch.clip(
+            mag, max=1e2)  # safeguard to prevent excessively large magnitudes
+        # wrapping happens here. These two lines produce the real and
+        # imaginary parts
+        x = torch.cos(p)
+        y = torch.sin(p)
+        # recalculating the phase here does not produce anything new and
+        # only costs time:
+        # phase = torch.atan2(y, x)
+        # S = mag * torch.exp(phase * 1j)
+        # better to directly produce the complex value
+        S = mag * (x + 1j * y)
+        # NOTE: only "center" padding is handled here
+        return torch.istft(S,
+                           self.n_fft,
+                           self.hop_length,
+                           self.win_length,
+                           self.window,
+                           center=True)
+
+
+class ConvNoattLayer(ConformerEncoderLayer):
+
+    def __init__(self,
+                 size: int,
+                 feed_forward: Optional[torch.nn.Module] = None,
+                 feed_forward_macaron: Optional[torch.nn.Module] = None,
+                 conv_module: Optional[torch.nn.Module] = None,
+                 dropout_rate: float = 0.1,
+                 normalize_before: bool = True,
+                 layer_norm_type: str = 'layer_norm',
+                 norm_eps: float = 0.00001):
+        super().__init__(size, None, feed_forward, feed_forward_macaron,
+                         conv_module, dropout_rate, normalize_before,
+                         layer_norm_type, norm_eps)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: torch.Tensor,
+        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # whether to use macaron style
+        if self.feed_forward_macaron is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_ff_macaron(x)
+            x = residual + self.ff_scale * self.dropout(
+                self.feed_forward_macaron(x))
+            if not self.normalize_before:
+                x = self.norm_ff_macaron(x)
+
+        # convolution module
+        # Fake new cnn cache here, and then change it in conv_module
+        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+        if self.conv_module is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_conv(x)
+            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
+            x = residual + self.dropout(x)
+
+            if not self.normalize_before:
+                x = self.norm_conv(x)
+
+        # feed forward module
+        residual = x
+        if self.normalize_before:
+            x = self.norm_ff(x)
+
+        assert self.feed_forward is not None
+        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm_ff(x)
+
+        if self.conv_module is not None:
+            x = self.norm_final(x)
+
+        return x, mask, new_cnn_cache
+
+
+class ConformerConvBeforeAttLayer(ConformerEncoderLayer):
+
+    def __init__(self,
+                 size: int,
+                 self_attn: torch.nn.Module,
+                 feed_forward: Optional[torch.nn.Module] = None,
+                 feed_forward_macaron: Optional[torch.nn.Module] = None,
+                 conv_module: Optional[torch.nn.Module] = None,
+                 dropout_rate: float = 0.1,
+                 normalize_before: bool = True,
+                 layer_norm_type: str = 
'layer_norm', + norm_eps: float = 0.00001): + super().__init__(size, self_attn, feed_forward, feed_forward_macaron, + conv_module, dropout_rate, normalize_before, + layer_norm_type, norm_eps) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: T_CACHE = (torch.zeros( + (0, 0, 0, 0)), torch.zeros((0, 0, 0, 0))), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): positional encoding, must not be None + for ConformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2) + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). + """ + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + # convolution module + # Fake new cnn cache here, and then change it in conv_module + new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) + x = residual + self.dropout(x) + + if not self.normalize_before: + x = self.norm_conv(x) + + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, + att_cache) + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + assert self.feed_forward is not None + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + return x, mask, new_att_cache, new_cnn_cache + + +class Vocosv1(torch.nn.Module): + # modify: + # two causal convolution -> three conformer encoder -> head + + def __init__(self, config: vocos_config) -> None: + super().__init__() + + activation = WENET_ACTIVATION_CLASSES[config.activation]() + attention_class = MultiHeadedAttention + attention_layer_args = ( + config.attention_heads, + config.output_size, + config.attention_dropout_rate, + config.qkv_bias, + config.qkv_bias, + config.qkv_bias, + config.use_sdpa, + config.n_kv_head, + config.head_dim, + ) + mlp_class = WENET_MLP_CLASSES[config.mlp_type] + # feed-forward module definition + positionwise_layer_args = ( + config.output_size, + config.linear_units, + config.dropout_rate, + activation, + config.mlp_bias, # mlp bias + ) + + # convolution module definition + 
convolution_layer_args = (
+            config.output_size,
+            config.kernel_size,
+            activation,
+            config.conv_norm,
+            config.causal,
+            config.conv_bias,
+        )
+        self.linear = torch.nn.Linear(config.input_size, config.output_size)
+        # first two convolution layers
+        self.conv1 = ConvNoattLayer(config.output_size,
+                                    mlp_class(*positionwise_layer_args),
+                                    mlp_class(*positionwise_layer_args),
+                                    ConvolutionModule(*convolution_layer_args),
+                                    config.dropout_rate,
+                                    layer_norm_type=config.norm_type,
+                                    norm_eps=config.norm_eps)
+        self.conv2 = ConvNoattLayer(config.output_size,
+                                    mlp_class(*positionwise_layer_args),
+                                    mlp_class(*positionwise_layer_args),
+                                    ConvolutionModule(*convolution_layer_args),
+                                    config.dropout_rate,
+                                    layer_norm_type=config.norm_type,
+                                    norm_eps=config.norm_eps)
+
+        self.encoders = torch.nn.ModuleList([
+            ConformerConvBeforeAttLayer(
+                config.output_size,
+                attention_class(*attention_layer_args),
+                mlp_class(*positionwise_layer_args),
+                mlp_class(*positionwise_layer_args),
+                ConvolutionModule(*convolution_layer_args),
+                config.dropout_rate,
+                True,  # normalize before
+                'rms_norm',
+                config.norm_eps,
+            ) for _ in range(config.att_blocks)
+        ])
+
+        self.head = ISTFTHead(config)
+
+        self.config = config
+
+    def forward(self, input: torch.Tensor, input_len: torch.Tensor):
+        """ forward for training
+        """
+        x = self.linear(input)
+        mask = make_non_pad_mask(input_len)  # [B,T]
+
+        # TODO: add ln here
+
+        x, _, _ = self.conv1(x, mask, mask.unsqueeze(1))
+        x, _, _ = self.conv2(x, mask, mask.unsqueeze(1))
+
+        causal_att_mask = causal_or_lookahead_mask(mask.unsqueeze(1), 0, 13)
+
+        # TODO: use sdpa here
+        for layer in self.encoders:
+            x, causal_att_mask, _, _ = layer(x, causal_att_mask, None,
+                                             mask.unsqueeze(1))
+
+        audio = self.head(x)
+
+        return audio, (mask.sum(1) - 1) * self.config.hop_length
+
+
+# input = torch.rand(1, 24000)
+# input_len = torch.tensor([24000])
+
+# import vocos
+
+# feature = vocos.feature_extractors.MelSpectrogramFeatures()
+
+# config = vocos_config()
+# model = Vocosv1(config)
+# print(model)
+# mels = feature(input)
+# print(mels.shape)
+
+# mels = mels.transpose(1, 2)
+# gen, gen_lens = model(mels, torch.tensor([mels.shape[1]]))
+# print(gen.shape, gen_lens, input.shape)
+
+# print(feature(gen).shape)
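+
+# A minimal smoke-test sketch without the external `vocos` package
+# (assumption: the log-mel setup below mirrors MelSpecReconstructionLoss in
+# train.py; it is only an illustration, not part of the recipe):
+
+# import torchaudio
+# mel_fn = torchaudio.transforms.MelSpectrogram(sample_rate=24000,
+#                                               n_fft=1024,
+#                                               hop_length=256,
+#                                               n_mels=100,
+#                                               center=True,
+#                                               power=1)
+# wav = torch.rand(1, 24000)
+# mels = torch.log(torch.clip(mel_fn(wav), min=1e-7))  # (B, n_mels, frames)
+# model = Vocosv1(vocos_config())
+# audio, audio_lens = model(mels.transpose(1, 2),
+#                           torch.tensor([mels.shape[-1]]))
+# print(audio.shape, audio_lens)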