From 6c619cc806cea6a63dc310e52e0c16da2b025d66 Mon Sep 17 00:00:00 2001 From: Artrajz <969242373@qq.com> Date: Wed, 13 Mar 2024 16:09:08 +0800 Subject: [PATCH] Update Bert-VITS2 japanese Extra --- bert_vits2/bert_vits2.py | 29 +- bert_vits2/commons.py | 43 +- bert_vits2/models_ja_extra.py | 1016 ++++++++++++++++++++++++ bert_vits2/text/cleaner.py | 3 +- bert_vits2/text/japanese.py | 6 +- bert_vits2/text/japanese_bert_extra.py | 42 + bert_vits2/text/japanese_extra.py | 524 ++++++++++++ bert_vits2/text/japanese_mora_list.py | 230 ++++++ manager/TTSManager.py | 3 + manager/model_handler.py | 4 +- 10 files changed, 1870 insertions(+), 30 deletions(-) create mode 100644 bert_vits2/models_ja_extra.py create mode 100644 bert_vits2/text/japanese_bert_extra.py create mode 100644 bert_vits2/text/japanese_extra.py create mode 100644 bert_vits2/text/japanese_mora_list.py diff --git a/bert_vits2/bert_vits2.py b/bert_vits2/bert_vits2.py index 3afd750..72ef626 100644 --- a/bert_vits2/bert_vits2.py +++ b/bert_vits2/bert_vits2.py @@ -9,6 +9,7 @@ from bert_vits2.get_emo import get_emo from bert_vits2.models import SynthesizerTrn from bert_vits2.models_v230 import SynthesizerTrn as SynthesizerTrn_v230 +from bert_vits2.models_ja_extra import SynthesizerTrn as SynthesizerTrn_ja_extra from bert_vits2.text import * from bert_vits2.text.cleaner import clean_text from bert_vits2.utils import process_legacy_versions @@ -29,6 +30,7 @@ def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs): self.bert_model_names = {} self.zh_bert_extra = False + self.ja_bert_extra = False self.ja_bert_dim = 1024 self.num_tones = num_tones self.pinyinPlus = None @@ -164,6 +166,20 @@ def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs): self.bert_extra_str_map.update({"zh": "_extra"}) self.text_extra_str_map.update({"zh": "_v240"}) + elif self.version is not None and self.version in ["ja_extra"]: + """ + deberta-v2-large-japanese-char-wwm + """ + self.version = "ja_extra" + self.hps_ms.model.emotion_embedding = 2 + self.hps_ms.model.n_layers_trans_flow = 6 + self.lang = ["ja"] + self.num_tones = num_tones + self.ja_bert_extra = True + self.bert_model_names.update({"ja": "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM"}) + self.bert_extra_str_map.update({"ja": "_extra"}) + self.text_extra_str_map.update({"ja": "_extra"}) + else: logging.debug("Version information not found. Loaded as the newest version: v2.3.") self.version = "2.3" @@ -185,6 +201,8 @@ def load_model(self, model_handler): if self.version in ["2.3", "extra", "2.4"]: Synthesizer = SynthesizerTrn_v230 + elif self.version == "ja_extra": + Synthesizer = SynthesizerTrn_ja_extra else: Synthesizer = SynthesizerTrn @@ -235,6 +253,9 @@ def get_text(self, text, language_str, hps, style_text=None, style_weight=0.7): if self.zh_bert_extra: zh_bert = bert ja_bert, en_bert = None, None + elif self.ja_bert_extra: + ja_bert = bert + zh_bert, en_bert = None, None elif language_str == "zh": zh_bert = bert ja_bert = torch.zeros(self.ja_bert_dim, len(phone)) @@ -287,8 +308,12 @@ def _infer(self, id, phones, tones, lang_ids, zh_bert, ja_bert, en_bert, sdp_rat x_tst = phones.to(self.device).unsqueeze(0) tones = tones.to(self.device).unsqueeze(0) lang_ids = lang_ids.to(self.device).unsqueeze(0) - zh_bert = zh_bert.to(self.device).unsqueeze(0) - if not self.zh_bert_extra: + if self.zh_bert_extra: + zh_bert = zh_bert.to(self.device).unsqueeze(0) + elif self.ja_bert_extra: + ja_bert = ja_bert.to(self.device).unsqueeze(0) + else: + zh_bert = zh_bert.to(self.device).unsqueeze(0) ja_bert = ja_bert.to(self.device).unsqueeze(0) en_bert = en_bert.to(self.device).unsqueeze(0) x_tst_lengths = torch.LongTensor([phones.size(0)]).to(self.device) diff --git a/bert_vits2/commons.py b/bert_vits2/commons.py index 9704898..1bc7d9a 100644 --- a/bert_vits2/commons.py +++ b/bert_vits2/commons.py @@ -1,7 +1,5 @@ import math -import numpy as np import torch -from torch import nn from torch.nn import functional as F @@ -16,8 +14,8 @@ def get_padding(kernel_size, dilation=1): def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] + layer = pad_shape[::-1] + pad_shape = [item for sublist in layer for item in sublist] return pad_shape @@ -30,7 +28,9 @@ def intersperse(lst, item): def kl_divergence(m_p, logs_p, m_q, logs_q): """KL(P||Q)""" kl = (logs_q - logs_p) - 0.5 - kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2. * logs_q) + kl += ( + 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + ) return kl @@ -46,33 +46,31 @@ def rand_gumbel_like(x): def slice_segments(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, :, idx_str:idx_end] - return ret + gather_indices = ids_str.view(x.size(0), 1, 1).repeat( + 1, x.size(1), 1 + ) + torch.arange(segment_size, device=x.device) + return torch.gather(x, 2, gather_indices) def rand_slice_segments(x, x_lengths=None, segment_size=4): b, d, t = x.size() if x_lengths is None: x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0) + ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long) ret = slice_segments(x, ids_str, segment_size) return ret, ids_str -def get_timing_signal_1d( - length, channels, min_timescale=1.0, max_timescale=1.0e4): +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): position = torch.arange(length, dtype=torch.float) num_timescales = channels // 2 - log_timescale_increment = ( - math.log(float(max_timescale) / float(min_timescale)) / - (num_timescales - 1)) + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) inv_timescales = min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) signal = F.pad(signal, [0, 0, 0, channels % 2]) @@ -108,8 +106,8 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] + layer = pad_shape[::-1] + pad_shape = [item for sublist in layer for item in sublist] return pad_shape @@ -130,7 +128,6 @@ def generate_path(duration, mask): duration: [b, 1, t_x] mask: [b, 1, t_y, t_x] """ - device = duration.device b, _, t_y, t_x = mask.shape cum_duration = torch.cumsum(duration, -1) @@ -157,5 +154,5 @@ def clip_grad_value_(parameters, clip_value, norm_type=2): total_norm += param_norm.item() ** norm_type if clip_value is not None: p.grad.data.clamp_(min=-clip_value, max=clip_value) - total_norm = total_norm ** (1. / norm_type) + total_norm = total_norm ** (1.0 / norm_type) return total_norm diff --git a/bert_vits2/models_ja_extra.py b/bert_vits2/models_ja_extra.py new file mode 100644 index 0000000..6f9e30f --- /dev/null +++ b/bert_vits2/models_ja_extra.py @@ -0,0 +1,1016 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from bert_vits2 import commons +from bert_vits2 import modules +from bert_vits2 import attentions + +from torch.nn import Conv1d, ConvTranspose1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm + +from bert_vits2.commons import init_weights, get_padding +from bert_vits2.text import symbols, num_tones, num_languages + +from vector_quantize_pytorch import VectorQuantize + + +class DurationDiscriminator(nn.Module): # vits2 + def __init__( + self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0 + ): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d( + filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_2 = modules.LayerNorm(filter_channels) + self.dur_proj = nn.Conv1d(1, filter_channels, 1) + + self.LSTM = nn.LSTM( + 2 * filter_channels, filter_channels, batch_first=True, bidirectional=True + ) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + self.output_layer = nn.Sequential( + nn.Linear(2 * filter_channels, 1), nn.Sigmoid() + ) + + def forward_probability(self, x, dur): + dur = self.dur_proj(dur) + x = torch.cat([x, dur], dim=1) + x = x.transpose(1, 2) + x, _ = self.LSTM(x) + output_prob = self.output_layer(x) + return output_prob + + def forward(self, x, x_mask, dur_r, dur_hat, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + + output_probs = [] + for dur in [dur_r, dur_hat]: + output_prob = self.forward_probability(x, dur) + output_probs.append(output_prob) + + return output_probs + + +class TransformerCouplingBlock(nn.Module): + def __init__( + self, + channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + n_flows=4, + gin_channels=0, + share_parameter=False, + ): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + + self.wn = ( + attentions.FFT( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + isflow=True, + gin_channels=self.gin_channels, + ) + if share_parameter + else None + ) + + for i in range(n_flows): + self.flows.append( + modules.TransformerCouplingLayer( + channels, + hidden_channels, + kernel_size, + n_layers, + n_heads, + p_dropout, + filter_channels, + mean_only=True, + wn_sharing_parameter=self.wn, + gin_channels=self.gin_channels, + ) + ) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class StochasticDurationPredictor(nn.Module): + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + p_dropout, + n_flows=4, + gin_channels=0, + ): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = modules.Log() + self.flows = nn.ModuleList() + self.flows.append(modules.ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append( + modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3) + ) + self.flows.append(modules.Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = modules.DDSConv( + filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout + ) + self.post_flows = nn.ModuleList() + self.post_flows.append(modules.ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append( + modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3) + ) + self.post_flows.append(modules.Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = modules.DDSConv( + filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout + ) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = ( + torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) + * x_mask + ) + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum( + (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] + ) + logq = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) + - logdet_tot_q + ) + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = ( + torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) + - logdet_tot + ) + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = ( + torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) + * noise_scale + ) + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + + +class DurationPredictor(nn.Module): + def __init__( + self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0 + ): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d( + filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_2 = modules.LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class Bottleneck(nn.Sequential): + def __init__(self, in_dim, hidden_dim): + c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False) + c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False) + super().__init__(*[c_fc1, c_fc2]) + + +class Block(nn.Module): + def __init__(self, in_dim, hidden_dim) -> None: + super().__init__() + self.norm = nn.LayerNorm(in_dim) + self.mlp = MLP(in_dim, hidden_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.mlp(self.norm(x)) + return x + + +class MLP(nn.Module): + def __init__(self, in_dim, hidden_dim): + super().__init__() + self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False) + self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False) + self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False) + + def forward(self, x: torch.Tensor): + x = F.silu(self.c_fc1(x)) * self.c_fc2(x) + x = self.c_proj(x) + return x + + +class TextEncoder(nn.Module): + def __init__( + self, + n_vocab, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + gin_channels=0, + ): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + self.emb = nn.Embedding(len(symbols), hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) + self.tone_emb = nn.Embedding(num_tones, hidden_channels) + nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels ** -0.5) + self.language_emb = nn.Embedding(num_languages, hidden_channels) + nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels ** -0.5) + self.bert_proj = nn.Conv1d(1024, hidden_channels, 1) + # self.bert_pre_proj = nn.Conv1d(2048, 1024, 1) + # self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1) + self.in_feature_net = nn.Sequential( + # input is assumed to an already normalized embedding + nn.Linear(512, 1028, bias=False), + nn.GELU(), + nn.LayerNorm(1028), + *[Block(1028, 512) for _ in range(1)], + nn.Linear(1028, 512, bias=False), + # normalize before passing to VQ? + # nn.GELU(), + # nn.LayerNorm(512), + ) + self.emo_vq = VectorQuantize( + dim=512, + # codebook_size=128, + codebook_size=256, + codebook_dim=16, + # codebook_dim=32, + commitment_weight=0.1, + decay=0.99, + heads=32, + kmeans_iters=20, + separate_codebook_per_head=True, + stochastic_sample_codes=True, + threshold_ema_dead_code=2, + use_cosine_sim=True, + ) + self.out_feature_net = nn.Linear(512, hidden_channels) + + self.encoder = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + gin_channels=self.gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, tone, language, bert, emo, g=None): + bert_emb = self.bert_proj(bert).transpose(1, 2) + # en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2) + emo_emb = self.in_feature_net(emo) + emo_emb, _, loss_commit = self.emo_vq(emo_emb.unsqueeze(1)) + loss_commit = loss_commit.mean() + emo_emb = self.out_feature_net(emo_emb) + x = ( + self.emb(x) + + self.tone_emb(tone) + + self.language_emb(language) + + bert_emb + # + en_bert_emb + + emo_emb + ) * math.sqrt( + self.hidden_channels + ) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + + x = self.encoder(x * x_mask, x_mask, g=g) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask, loss_commit + + +class ResidualCouplingBlock(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0, + ): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2 ** i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for layer in self.ups: + remove_weight_norm(layer) + for layer in self.resblocks: + layer.remove_weight_norm() + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 1024, + 1024, + (kernel_size, 1), + 1, + padding=(get_padding(kernel_size, 1), 0), + ) + ), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for layer in self.convs: + x = layer(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for layer in self.convs: + x = layer(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [ + DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods + ] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class WavLMDiscriminator(nn.Module): + """docstring for Discriminator.""" + + def __init__( + self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False + ): + super(WavLMDiscriminator, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.pre = norm_f( + Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0) + ) + + self.convs = nn.ModuleList( + [ + norm_f( + nn.Conv1d( + initial_channel, initial_channel * 2, kernel_size=5, padding=2 + ) + ), + norm_f( + nn.Conv1d( + initial_channel * 2, + initial_channel * 4, + kernel_size=5, + padding=2, + ) + ), + norm_f( + nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2) + ), + ] + ) + + self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1)) + + def forward(self, x): + x = self.pre(x) + + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + x = torch.flatten(x, 1, -1) + + return x + + +class ReferenceEncoder(nn.Module): + """ + inputs --- [N, Ty/r, n_mels*r] mels + outputs --- [N, ref_enc_gru_size] + """ + + def __init__(self, spec_channels, gin_channels=0): + super().__init__() + self.spec_channels = spec_channels + ref_enc_filters = [32, 32, 64, 64, 128, 128] + K = len(ref_enc_filters) + filters = [1] + ref_enc_filters + convs = [ + weight_norm( + nn.Conv2d( + in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + ) + for i in range(K) + ] + self.convs = nn.ModuleList(convs) + # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501 + + out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) + self.gru = nn.GRU( + input_size=ref_enc_filters[-1] * out_channels, + hidden_size=256 // 2, + batch_first=True, + ) + self.proj = nn.Linear(128, gin_channels) + + def forward(self, inputs, mask=None): + N = inputs.size(0) + out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] + for conv in self.convs: + out = conv(out) + # out = wn(out) + out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] + + out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] + T = out.size(1) + N = out.size(0) + out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] + + self.gru.flatten_parameters() + memory, out = self.gru(out) # out --- [1, N, 128] + + return self.proj(out.squeeze(0)) + + def calculate_channels(self, L, kernel_size, stride, pad, n_convs): + for i in range(n_convs): + L = (L - kernel_size + 2 * pad) // stride + 1 + return L + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__( + self, + n_vocab, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=256, + gin_channels=256, + use_sdp=True, + n_flow_layer=4, + n_layers_trans_flow=6, + flow_share_parameter=False, + use_transformer_flow=True, + **kwargs + ): + super().__init__() + self.n_vocab = n_vocab + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + self.n_layers_trans_flow = n_layers_trans_flow + self.use_spk_conditioned_encoder = kwargs.get( + "use_spk_conditioned_encoder", True + ) + self.use_sdp = use_sdp + self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False) + self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01) + self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6) + self.current_mas_noise_scale = self.mas_noise_scale_initial + if self.use_spk_conditioned_encoder and gin_channels > 0: + self.enc_gin_channels = gin_channels + self.enc_p = TextEncoder( + n_vocab, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + gin_channels=self.enc_gin_channels, + ) + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + if use_transformer_flow: + self.flow = TransformerCouplingBlock( + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers_trans_flow, + 5, + p_dropout, + n_flow_layer, + gin_channels=gin_channels, + share_parameter=flow_share_parameter, + ) + else: + self.flow = ResidualCouplingBlock( + inter_channels, + hidden_channels, + 5, + 1, + n_flow_layer, + gin_channels=gin_channels, + ) + self.sdp = StochasticDurationPredictor( + hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels + ) + self.dp = DurationPredictor( + hidden_channels, 256, 3, 0.5, gin_channels=gin_channels + ) + + if n_speakers >= 1: + self.emb_g = nn.Embedding(n_speakers, gin_channels) + else: + self.ref_enc = ReferenceEncoder(spec_channels, gin_channels) + + def infer( + self, + x, + x_lengths, + sid, + tone, + language, + ja_bert, + emo, + noise_scale=0.667, + length_scale=1, + noise_scale_w=0.8, + max_len=None, + sdp_ratio=0, + y=None, + **kwargs + ): + # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, ja_bert) + # g = self.gst(y) + if self.n_speakers > 0: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + else: + g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1) + x, m_p, logs_p, x_mask, _ = self.enc_p( + x, x_lengths, tone, language, ja_bert, emo, g=g + ) + logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * ( + sdp_ratio + ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio) + w = torch.exp(logw) * x_mask * length_scale + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to( + x_mask.dtype + ) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = commons.generate_path(w_ceil, attn_mask) + + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + o = self.dec((z * y_mask)[:, :, :max_len], g=g) + return o, attn, y_mask, (z, z_p, m_p, logs_p) diff --git a/bert_vits2/text/cleaner.py b/bert_vits2/text/cleaner.py index 8baa52e..52ef399 100644 --- a/bert_vits2/text/cleaner.py +++ b/bert_vits2/text/cleaner.py @@ -1,5 +1,5 @@ from bert_vits2.text import chinese, japanese, english, cleaned_text_to_sequence, japanese_v111, chinese_v100, \ - japanese_v200, english_v200, english_v230, chinese_v240 + japanese_v200, english_v200, english_v230, chinese_v240, japanese_extra language_module_map = { 'zh': chinese, @@ -11,6 +11,7 @@ 'en_v200': english_v200, 'en_v230': english_v230, 'zh_v240': chinese_v240, + 'ja_extra': japanese_extra, } diff --git a/bert_vits2/text/japanese.py b/bert_vits2/text/japanese.py index 2350cb4..284088d 100644 --- a/bert_vits2/text/japanese.py +++ b/bert_vits2/text/japanese.py @@ -413,16 +413,16 @@ def g2p(norm_text, tokenizer, **kwargs): if __name__ == "__main__": - from transformers import AutoTokenizer + from manager import model_handler - tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese") + tokenizer, _ = model_handler.get_bert_model("DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM") text = "hello,こんにちは、世界ー!……" from bert_vits2.text.japanese_bert import get_bert_feature text = text_normalize(text) print(text) - phones, tones, word2ph = g2p(text) + phones, tones, word2ph = g2p(text, tokenizer) bert = get_bert_feature(text, word2ph) print(phones, tones, word2ph, bert.shape) diff --git a/bert_vits2/text/japanese_bert_extra.py b/bert_vits2/text/japanese_bert_extra.py new file mode 100644 index 0000000..1455342 --- /dev/null +++ b/bert_vits2/text/japanese_bert_extra.py @@ -0,0 +1,42 @@ +import torch + +from contants import config +from bert_vits2.text.japanese import text2sep_kata + + +def get_bert_feature(text, word2ph, tokenizer, model, device=config.system.device, style_text=None, style_weight=0.7, + **kwargs): + text = "".join(text2sep_kata(text)[0]) + if style_text: + style_text = "".join(text2sep_kata(style_text)[0]) + + with torch.no_grad(): + inputs = tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(device) + res = model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].float().cpu() + if style_text: + style_inputs = tokenizer(style_text, return_tensors="pt") + for i in style_inputs: + style_inputs[i] = style_inputs[i].to(device) + style_res = model(**style_inputs, output_hidden_states=True) + style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].float().cpu() + style_res_mean = style_res.mean(0) + + assert len(word2ph) == len(text) + 2 + word2phone = word2ph + phone_level_feature = [] + for i in range(len(word2phone)): + if style_text: + repeat_feature = ( + res[i].repeat(word2phone[i], 1) * (1 - style_weight) + + style_res_mean.repeat(word2phone[i], 1) * style_weight + ) + else: + repeat_feature = res[i].repeat(word2phone[i], 1) + phone_level_feature.append(repeat_feature) + + phone_level_feature = torch.cat(phone_level_feature, dim=0) + + return phone_level_feature.T diff --git a/bert_vits2/text/japanese_extra.py b/bert_vits2/text/japanese_extra.py new file mode 100644 index 0000000..61b7184 --- /dev/null +++ b/bert_vits2/text/japanese_extra.py @@ -0,0 +1,524 @@ +# Convert Japanese text to phonemes which is +# compatible with Julius https://github.com/julius-speech/segmentation-kit +import re +import unicodedata + +import pyopenjtalk +from num2words import num2words + +from bert_vits2.text import punctuation +from bert_vits2.text.japanese_mora_list import ( + mora_kata_to_mora_phonemes, +) + +# 子音の集合 +COSONANTS = set( + [ + cosonant + for cosonant, _ in mora_kata_to_mora_phonemes.values() + if cosonant is not None + ] +) + +# 母音の集合 +VOWELS = {"a", "i", "u", "e", "o"} + +# 正規化で記号を変換するための辞書 +rep_map = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + ".": ".", + "…": "...", + "···": "...", + "・・・": "...", + "·": ",", + "・": ",", + "、": ",", + "$": ".", + "“": "'", + "”": "'", + '"': "'", + "‘": "'", + "’": "'", + "(": "'", + ")": "'", + "(": "'", + ")": "'", + "《": "'", + "》": "'", + "【": "'", + "】": "'", + "[": "'", + "]": "'", + "—": "-", + "−": "-", + # "~": "-", # これは長音記号「ー」として扱うよう変更 + # "~": "-", # これは長音記号「ー」として扱うよう変更 + "「": "'", + "」": "'", +} + + +def text_normalize(text): + """ + 日本語のテキストを正規化する。 + 結果は、ちょうど次の文字のみからなる: + - ひらがな + - カタカナ(全角長音記号「ー」が入る!) + - 漢字 + - 半角アルファベット(大文字と小文字) + - ギリシャ文字 + - `.` (句点`。`や`…`の一部や改行等) + - `,` (読点`、`や`:`等) + - `?` (疑問符`?`) + - `!` (感嘆符`!`) + - `'` (`「`や`」`等) + - `-` (`―`(ダッシュ、長音記号ではない)や`-`等) + + 注意点: + - 三点リーダー`…`は`...`に変換される(`なるほど…。` → `なるほど....`) + - 数字は漢字に変換される(`1,100円` → `千百円`、`52.34` → `五十二点三四`) + - 読点や疑問符等の位置・個数等は保持される(`??あ、、!!!` → `??あ,,!!!`) + """ + # print(f"Before normalization: {text}") + # ここでアルファベットは半角になり、三点リーダは`...`になる + res = unicodedata.normalize("NFKC", text) + + res = japanese_convert_numbers_to_words(res) # 「100円」→「百円」等 + + # 「~」と「~」も長音記号として扱う + res = res.replace("~", "ー") + res = res.replace("~", "ー") + + res = replace_punctuation(res) # 句読点等正規化、読めない文字を削除 + + # 結合文字の濁点・半濁点を削除 + # 通常の「ば」等はそのままのこされる、「あ゛」は上で「あ゙」になりここで「あ」になる + res = res.replace("\u3099", "") # 結合文字の濁点を削除、る゙ → る + res = res.replace("\u309A", "") # 結合文字の半濁点を削除、な゚ → な + return res + + +def replace_punctuation(text: str) -> str: + """句読点等を「.」「,」「!」「?」「'」「-」に正規化し、OpenJTalkで読みが取得できるもののみ残す: + 漢字・平仮名・カタカナ、アルファベット、ギリシャ文字 + """ + pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) + + # 句読点を辞書で置換 + replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) + + replaced_text = re.sub( + # ↓ ひらがな、カタカナ、漢字 + r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005" + # ↓ 半角アルファベット(大文字と小文字) + + r"\u0041-\u005A\u0061-\u007A" + # ↓ 全角アルファベット(大文字と小文字) + + r"\uFF21-\uFF3A\uFF41-\uFF5A" + # ↓ ギリシャ文字 + + r"\u0370-\u03FF\u1F00-\u1FFF" + # ↓ "!", "?", "…", ",", ".", "'", "-", 但し`…`はすでに`...`に変換されている + + "".join(punctuation) + r"]+", + # 上述以外の文字を削除 + "", + replaced_text, + ) + + return replaced_text + + +_NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+") +_CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"} +_CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])") +_NUMBER_RX = re.compile(r"[0-9]+(\.[0-9]+)?") + + +def japanese_convert_numbers_to_words(text: str) -> str: + res = _NUMBER_WITH_SEPARATOR_RX.sub(lambda m: m[0].replace(",", ""), text) + res = _CURRENCY_RX.sub(lambda m: m[2] + _CURRENCY_MAP.get(m[1], m[1]), res) + res = _NUMBER_RX.sub(lambda m: num2words(m[0], lang="ja"), res) + return res + + +def g2p(norm_text: str, tokenizer, **kwargs) -> tuple[list[str], list[int], list[int]]: + """ + 他で使われるメインの関数。`text_normalize()`で正規化された`norm_text`を受け取り、 + - phones: 音素のリスト(ただし`!`や`,`や`.`等punctuationが含まれうる) + - tones: アクセントのリスト、0(低)と1(高)からなり、phonesと同じ長さ + - word2ph: 元のテキストの各文字に音素が何個割り当てられるかを表すリスト + のタプルを返す。 + ただし`phones`と`tones`の最初と終わりに`_`が入り、応じて`word2ph`の最初と最後に1が追加される。 + """ + # pyopenjtalkのフルコンテキストラベルを使ってアクセントを取り出すと、punctuationの位置が消えてしまい情報が失われてしまう: + # 「こんにちは、世界。」と「こんにちは!世界。」と「こんにちは!!!???世界……。」は全て同じになる。 + # よって、まずpunctuation無しの音素とアクセントのリストを作り、 + # それとは別にpyopenjtalk.run_frontend()で得られる音素リスト(こちらはpunctuationが保持される)を使い、 + # アクセント割当をしなおすことによってpunctuationを含めた音素とアクセントのリストを作る。 + + # punctuationがすべて消えた、音素とアクセントのタプルのリスト + phone_tone_list_wo_punct = g2phone_tone_wo_punct(norm_text) + + # sep_text: 単語単位の単語のリスト + # sep_kata: 単語単位の単語のカタカナ読みのリスト + sep_text, sep_kata = text2sep_kata(norm_text) + + # sep_phonemes: 各単語ごとの音素のリストのリスト + sep_phonemes = handle_long([kata2phoneme_list(i) for i in sep_kata]) + + # phone_w_punct: sep_phonemesを結合した、punctuationを元のまま保持した音素列 + phone_w_punct: list[str] = [] + for i in sep_phonemes: + phone_w_punct += i + + # punctuation無しのアクセント情報を使って、punctuationを含めたアクセント情報を作る + phone_tone_list = align_tones(phone_w_punct, phone_tone_list_wo_punct) + # word2phは厳密な解答は不可能なので(「今日」「眼鏡」等の熟字訓が存在)、 + # Bert-VITS2では、単語単位の分割を使って、単語の文字ごとにだいたい均等に音素を分配する + + # sep_textから、各単語を1文字1文字分割して、文字のリスト(のリスト)を作る + sep_tokenized: list[list[str]] = [] + for i in sep_text: + if i not in punctuation: + sep_tokenized.append(tokenizer.tokenize(i)) # ここでおそらく`i`が文字単位に分割される + else: + sep_tokenized.append([i]) + + # 各単語について、音素の数と文字の数を比較して、均等っぽく分配する + word2ph = [] + for token, phoneme in zip(sep_tokenized, sep_phonemes): + phone_len = len(phoneme) + word_len = len(token) + word2ph += distribute_phone(phone_len, word_len) + + # 最初と最後に`_`記号を追加、アクセントは0(低)、word2phもそれに合わせて追加 + phone_tone_list = [("_", 0)] + phone_tone_list + [("_", 0)] + word2ph = [1] + word2ph + [1] + + phones = [phone for phone, _ in phone_tone_list] + tones = [tone for _, tone in phone_tone_list] + + assert len(phones) == sum(word2ph), f"{len(phones)} != {sum(word2ph)}" + + return phones, tones, word2ph + + +def g2phone_tone_wo_punct(text: str) -> list[tuple[str, int]]: + """ + テキストに対して、音素とアクセント(0か1)のペアのリストを返す。 + ただし「!」「.」「?」等の非音素記号(punctuation)は全て消える(ポーズ記号も残さない)。 + 非音素記号を含める処理は`align_tones()`で行われる。 + また「っ」は「cl」でなく「q」に変換される(「ん」は「N」のまま)。 + 例: "こんにちは、世界ー。。元気?!" → + [('k', 0), ('o', 0), ('N', 1), ('n', 1), ('i', 1), ('ch', 1), ('i', 1), ('w', 1), ('a', 1), ('s', 1), ('e', 1), ('k', 0), ('a', 0), ('i', 0), ('i', 0), ('g', 1), ('e', 1), ('N', 0), ('k', 0), ('i', 0)] + """ + prosodies = pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True) + result: list[tuple[str, int]] = [] + current_phrase: list[tuple[str, int]] = [] + current_tone = 0 + for i, letter in enumerate(prosodies): + # 特殊記号の処理 + + # 文頭記号、無視する + if letter == "^": + assert i == 0, "Unexpected ^" + # アクセント句の終わりに来る記号 + elif letter in ("$", "?", "_", "#"): + # 保持しているフレーズを、アクセント数値を0-1に修正し結果に追加 + result.extend(fix_phone_tone(current_phrase)) + # 末尾に来る終了記号、無視(文中の疑問文は`_`になる) + if letter in ("$", "?"): + assert i == len(prosodies) - 1, f"Unexpected {letter}" + # あとは"_"(ポーズ)と"#"(アクセント句の境界)のみ + # これらは残さず、次のアクセント句に備える。 + current_phrase = [] + # 0を基準点にしてそこから上昇・下降する(負の場合は上の`fix_phone_tone`で直る) + current_tone = 0 + # アクセント上昇記号 + elif letter == "[": + current_tone = current_tone + 1 + # アクセント下降記号 + elif letter == "]": + current_tone = current_tone - 1 + # それ以外は通常の音素 + else: + if letter == "cl": # 「っ」の処理 + letter = "q" + current_phrase.append((letter, current_tone)) + return result + + +def text2sep_kata(norm_text: str) -> tuple[list[str], list[str]]: + """ + `text_normalize`で正規化済みの`norm_text`を受け取り、それを単語分割し、 + 分割された単語リストとその読み(カタカナor記号1文字)のリストのタプルを返す。 + 単語分割結果は、`g2p()`の`word2ph`で1文字あたりに割り振る音素記号の数を決めるために使う。 + 例: + `私はそう思う!って感じ?` → + ["私", "は", "そう", "思う", "!", "って", "感じ", "?"], ["ワタシ", "ワ", "ソー", "オモウ", "!", "ッテ", "カンジ", "?"] + """ + # parsed: OpenJTalkの解析結果 + parsed = pyopenjtalk.run_frontend(norm_text) + sep_text: list[str] = [] + sep_kata: list[str] = [] + for parts in parsed: + # word: 実際の単語の文字列 + # yomi: その読み、但し無声化サインの`’`は除去 + word, yomi = replace_punctuation(parts["string"]), parts["pron"].replace( + "’", "" + ) + """ + ここで`yomi`の取りうる値は以下の通りのはず。 + - `word`が通常単語 → 通常の読み(カタカナ) + (カタカナからなり、長音記号も含みうる、`アー` 等) + - `word`が`ー` から始まる → `ーラー` や `ーーー` など + - `word`が句読点や空白等 → `、` + - `word`が`?` → `?`(全角になる) + 他にも`word`が読めないキリル文字アラビア文字等が来ると`、`になるが、正規化でこの場合は起きないはず。 + また元のコードでは`yomi`が空白の場合の処理があったが、これは起きないはず。 + 処理すべきは`yomi`が`、`の場合のみのはず。 + """ + assert yomi != "", f"Empty yomi: {word}" + if yomi == "、": + # wordは正規化されているので、`.`, `,`, `!`, `'`, `-`のいずれか + if word not in ( + ".", + ",", + "!", + "'", + "-", + ): + # ここはpyopenjtalkが読めない文字等のときに起こる + raise ValueError(f"Cannot read: {word} in:\n{norm_text}") + # yomiは元の記号のままに変更 + yomi = word + elif yomi == "?": + assert word == "?", f"yomi `?` comes from: {word}" + yomi = "?" + sep_text.append(word) + sep_kata.append(yomi) + return sep_text, sep_kata + + +# ESPnetの実装から引用、変更点無し +# https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py +def pyopenjtalk_g2p_prosody(text: str, drop_unvoiced_vowels: bool = True) -> list[str]: + """Extract phoneme + prosoody symbol sequence from input full-context labels. + + The algorithm is based on `Prosodic features control by symbols as input of + sequence-to-sequence acoustic modeling for neural TTS`_ with some r9y9's tweaks. + + Args: + text (str): Input text. + drop_unvoiced_vowels (bool): whether to drop unvoiced vowels. + + Returns: + List[str]: List of phoneme + prosody symbols. + + Examples: + #>>> from espnet2.text.phoneme_tokenizer import pyopenjtalk_g2p_prosody + #>>> pyopenjtalk_g2p_prosody("こんにちは。") + ['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$'] + + .. _`Prosodic features control by symbols as input of sequence-to-sequence acoustic + modeling for neural TTS`: https://doi.org/10.1587/transinf.2020EDP7104 + + """ + labels = pyopenjtalk.make_label(pyopenjtalk.run_frontend(text)) + N = len(labels) + + phones = [] + for n in range(N): + lab_curr = labels[n] + + # current phoneme + p3 = re.search(r"\-(.*?)\+", lab_curr).group(1) + # deal unvoiced vowels as normal vowels + if drop_unvoiced_vowels and p3 in "AEIOU": + p3 = p3.lower() + + # deal with sil at the beginning and the end of text + if p3 == "sil": + assert n == 0 or n == N - 1 + if n == 0: + phones.append("^") + elif n == N - 1: + # check question form or not + e3 = _numeric_feature_by_regex(r"!(\d+)_", lab_curr) + if e3 == 0: + phones.append("$") + elif e3 == 1: + phones.append("?") + continue + elif p3 == "pau": + phones.append("_") + continue + else: + phones.append(p3) + + # accent type and position info (forward or backward) + a1 = _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", lab_curr) + a2 = _numeric_feature_by_regex(r"\+(\d+)\+", lab_curr) + a3 = _numeric_feature_by_regex(r"\+(\d+)/", lab_curr) + + # number of mora in accent phrase + f1 = _numeric_feature_by_regex(r"/F:(\d+)_", lab_curr) + + a2_next = _numeric_feature_by_regex(r"\+(\d+)\+", labels[n + 1]) + # accent phrase border + if a3 == 1 and a2_next == 1 and p3 in "aeiouAEIOUNcl": + phones.append("#") + # pitch falling + elif a1 == 0 and a2_next == a2 + 1 and a2 != f1: + phones.append("]") + # pitch rising + elif a2 == 1 and a2_next == 2: + phones.append("[") + + return phones + + +def _numeric_feature_by_regex(regex, s): + match = re.search(regex, s) + if match is None: + return -50 + return int(match.group(1)) + + +def fix_phone_tone(phone_tone_list: list[tuple[str, int]]) -> list[tuple[str, int]]: + """ + `phone_tone_list`のtone(アクセントの値)を0か1の範囲に修正する。 + 例: [(a, 0), (i, -1), (u, -1)] → [(a, 1), (i, 0), (u, 0)] + """ + tone_values = set(tone for _, tone in phone_tone_list) + if len(tone_values) == 1: + assert tone_values == {0}, tone_values + return phone_tone_list + elif len(tone_values) == 2: + if tone_values == {0, 1}: + return phone_tone_list + elif tone_values == {-1, 0}: + return [ + (letter, 0 if tone == -1 else 1) for letter, tone in phone_tone_list + ] + else: + raise ValueError(f"Unexpected tone values: {tone_values}") + else: + raise ValueError(f"Unexpected tone values: {tone_values}") + + +def distribute_phone(n_phone: int, n_word: int) -> list[int]: + """ + 左から右に1ずつ振り分け、次にまた左から右に1ずつ増やし、というふうに、 + 音素の数`n_phone`を単語の数`n_word`に分配する。 + """ + phones_per_word = [0] * n_word + for _ in range(n_phone): + min_tasks = min(phones_per_word) + min_index = phones_per_word.index(min_tasks) + phones_per_word[min_index] += 1 + return phones_per_word + + +def handle_long(sep_phonemes: list[list[str]]) -> list[list[str]]: + for i in range(len(sep_phonemes)): + if sep_phonemes[i][0] == "ー": + sep_phonemes[i][0] = sep_phonemes[i - 1][-1] + if "ー" in sep_phonemes[i]: + for j in range(len(sep_phonemes[i])): + if sep_phonemes[i][j] == "ー": + sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1] + return sep_phonemes + + +def align_tones( + phones_with_punct: list[str], phone_tone_list: list[tuple[str, int]] +) -> list[tuple[str, int]]: + """ + 例: + …私は、、そう思う。 + phones_with_punct: + [".", ".", ".", "w", "a", "t", "a", "sh", "i", "w", "a", ",", ",", "s", "o", "o", "o", "m", "o", "u", "."] + phone_tone_list: + [("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0))] + Return: + [(".", 0), (".", 0), (".", 0), ("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), (",", 0), (",", 0), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0), (".", 0)] + """ + result: list[tuple[str, int]] = [] + tone_index = 0 + for phone in phones_with_punct: + if tone_index >= len(phone_tone_list): + # 余ったpunctuationがある場合 → (punctuation, 0)を追加 + result.append((phone, 0)) + elif phone == phone_tone_list[tone_index][0]: + # phone_tone_listの現在の音素と一致する場合 → toneをそこから取得、(phone, tone)を追加 + result.append((phone, phone_tone_list[tone_index][1])) + # 探すindexを1つ進める + tone_index += 1 + elif phone in punctuation: + # phoneがpunctuationの場合 → (phone, 0)を追加 + result.append((phone, 0)) + else: + print(f"phones: {phones_with_punct}") + print(f"phone_tone_list: {phone_tone_list}") + print(f"result: {result}") + print(f"tone_index: {tone_index}") + print(f"phone: {phone}") + raise ValueError(f"Unexpected phone: {phone}") + return result + + +def kata2phoneme_list(text: str) -> list[str]: + """ + 原則カタカナの`text`を受け取り、それをそのままいじらずに音素記号のリストに変換。 + 注意点: + - punctuationが来た場合(punctuationが1文字の場合がありうる)、処理せず1文字のリストを返す + - 冒頭に続く「ー」はそのまま「ー」のままにする(`handle_long()`で処理される) + - 文中の「ー」は前の音素記号の最後の音素記号に変換される。 + 例: + `ーーソーナノカーー` → ["ー", "ー", "s", "o", "o", "n", "a", "n", "o", "k", "a", "a", "a"] + `?` → ["?"] + """ + if text in punctuation: + return [text] + # `text`がカタカナ(`ー`含む)のみからなるかどうかをチェック + if re.fullmatch(r"[\u30A0-\u30FF]+", text) is None: + raise ValueError(f"Input must be katakana only: {text}") + sorted_keys = sorted(mora_kata_to_mora_phonemes.keys(), key=len, reverse=True) + pattern = "|".join(map(re.escape, sorted_keys)) + + def mora2phonemes(mora: str) -> str: + cosonant, vowel = mora_kata_to_mora_phonemes[mora] + if cosonant is None: + return f" {vowel}" + return f" {cosonant} {vowel}" + + spaced_phonemes = re.sub(pattern, lambda m: mora2phonemes(m.group()), text) + + # 長音記号「ー」の処理 + long_pattern = r"(\w)(ー*)" + long_replacement = lambda m: m.group(1) + (" " + m.group(1)) * len(m.group(2)) + spaced_phonemes = re.sub(long_pattern, long_replacement, spaced_phonemes) + return spaced_phonemes.strip().split(" ") + + +if __name__ == "__main__": + from manager import model_handler + + tokenizer, _ = model_handler.get_bert_model("DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM") + text = "hello,こんにちは、世界ー~!……" + + from bert_vits2.text.japanese_bert import get_bert_feature + + text = text_normalize(text) + print(text) + + phones, tones, word2ph = g2p(text) + print(phones, tones, word2ph) + bert = get_bert_feature(text, word2ph) + + print(phones, tones, word2ph, bert.shape) diff --git a/bert_vits2/text/japanese_mora_list.py b/bert_vits2/text/japanese_mora_list.py new file mode 100644 index 0000000..d3a1dc4 --- /dev/null +++ b/bert_vits2/text/japanese_mora_list.py @@ -0,0 +1,230 @@ +""" +VOICEVOXのソースコードからお借りして最低限に改造したコード。 +https://github.com/VOICEVOX/voicevox_engine/blob/master/voicevox_engine/tts_pipeline/mora_list.py +""" +""" +以下のモーラ対応表はOpenJTalkのソースコードから取得し、 +カタカナ表記とモーラが一対一対応するように改造した。 +ライセンス表記: +----------------------------------------------------------------- + The Japanese TTS System "Open JTalk" + developed by HTS Working Group + http://open-jtalk.sourceforge.net/ +----------------------------------------------------------------- + + Copyright (c) 2008-2014 Nagoya Institute of Technology + Department of Computer Science + +All rights reserved. + +Redistribution and use in source and binary forms, with or +without modification, are permitted provided that the following +conditions are met: + +- Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. +- Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. +- Neither the name of the HTS working group nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS +BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. +""" +from typing import Optional + +# (カタカナ, 子音, 母音)の順。子音がない場合はNoneを入れる。 +# 但し「ン」と「ッ」は母音のみという扱いで、それぞれ「N」「q (clから変更)」 +# また「デェ = dy e」はpyopenjtalkの出力(de e)と合わないため削除 +_mora_list_minimum: list[tuple[str, Optional[str], str]] = [ + ("ヴォ", "v", "o"), + ("ヴェ", "v", "e"), + ("ヴィ", "v", "i"), + ("ヴァ", "v", "a"), + ("ヴ", "v", "u"), + ("ン", None, "N"), + ("ワ", "w", "a"), + ("ロ", "r", "o"), + ("レ", "r", "e"), + ("ル", "r", "u"), + ("リョ", "ry", "o"), + ("リュ", "ry", "u"), + ("リャ", "ry", "a"), + ("リェ", "ry", "e"), + ("リ", "r", "i"), + ("ラ", "r", "a"), + ("ヨ", "y", "o"), + ("ユ", "y", "u"), + ("ヤ", "y", "a"), + ("モ", "m", "o"), + ("メ", "m", "e"), + ("ム", "m", "u"), + ("ミョ", "my", "o"), + ("ミュ", "my", "u"), + ("ミャ", "my", "a"), + ("ミェ", "my", "e"), + ("ミ", "m", "i"), + ("マ", "m", "a"), + ("ポ", "p", "o"), + ("ボ", "b", "o"), + ("ホ", "h", "o"), + ("ペ", "p", "e"), + ("ベ", "b", "e"), + ("ヘ", "h", "e"), + ("プ", "p", "u"), + ("ブ", "b", "u"), + ("フォ", "f", "o"), + ("フェ", "f", "e"), + ("フィ", "f", "i"), + ("ファ", "f", "a"), + ("フ", "f", "u"), + ("ピョ", "py", "o"), + ("ピュ", "py", "u"), + ("ピャ", "py", "a"), + ("ピェ", "py", "e"), + ("ピ", "p", "i"), + ("ビョ", "by", "o"), + ("ビュ", "by", "u"), + ("ビャ", "by", "a"), + ("ビェ", "by", "e"), + ("ビ", "b", "i"), + ("ヒョ", "hy", "o"), + ("ヒュ", "hy", "u"), + ("ヒャ", "hy", "a"), + ("ヒェ", "hy", "e"), + ("ヒ", "h", "i"), + ("パ", "p", "a"), + ("バ", "b", "a"), + ("ハ", "h", "a"), + ("ノ", "n", "o"), + ("ネ", "n", "e"), + ("ヌ", "n", "u"), + ("ニョ", "ny", "o"), + ("ニュ", "ny", "u"), + ("ニャ", "ny", "a"), + ("ニェ", "ny", "e"), + ("ニ", "n", "i"), + ("ナ", "n", "a"), + ("ドゥ", "d", "u"), + ("ド", "d", "o"), + ("トゥ", "t", "u"), + ("ト", "t", "o"), + ("デョ", "dy", "o"), + ("デュ", "dy", "u"), + ("デャ", "dy", "a"), + # ("デェ", "dy", "e"), + ("ディ", "d", "i"), + ("デ", "d", "e"), + ("テョ", "ty", "o"), + ("テュ", "ty", "u"), + ("テャ", "ty", "a"), + ("ティ", "t", "i"), + ("テ", "t", "e"), + ("ツォ", "ts", "o"), + ("ツェ", "ts", "e"), + ("ツィ", "ts", "i"), + ("ツァ", "ts", "a"), + ("ツ", "ts", "u"), + ("ッ", None, "q"), # 「cl」から「q」に変更 + ("チョ", "ch", "o"), + ("チュ", "ch", "u"), + ("チャ", "ch", "a"), + ("チェ", "ch", "e"), + ("チ", "ch", "i"), + ("ダ", "d", "a"), + ("タ", "t", "a"), + ("ゾ", "z", "o"), + ("ソ", "s", "o"), + ("ゼ", "z", "e"), + ("セ", "s", "e"), + ("ズィ", "z", "i"), + ("ズ", "z", "u"), + ("スィ", "s", "i"), + ("ス", "s", "u"), + ("ジョ", "j", "o"), + ("ジュ", "j", "u"), + ("ジャ", "j", "a"), + ("ジェ", "j", "e"), + ("ジ", "j", "i"), + ("ショ", "sh", "o"), + ("シュ", "sh", "u"), + ("シャ", "sh", "a"), + ("シェ", "sh", "e"), + ("シ", "sh", "i"), + ("ザ", "z", "a"), + ("サ", "s", "a"), + ("ゴ", "g", "o"), + ("コ", "k", "o"), + ("ゲ", "g", "e"), + ("ケ", "k", "e"), + ("グヮ", "gw", "a"), + ("グ", "g", "u"), + ("クヮ", "kw", "a"), + ("ク", "k", "u"), + ("ギョ", "gy", "o"), + ("ギュ", "gy", "u"), + ("ギャ", "gy", "a"), + ("ギェ", "gy", "e"), + ("ギ", "g", "i"), + ("キョ", "ky", "o"), + ("キュ", "ky", "u"), + ("キャ", "ky", "a"), + ("キェ", "ky", "e"), + ("キ", "k", "i"), + ("ガ", "g", "a"), + ("カ", "k", "a"), + ("オ", None, "o"), + ("エ", None, "e"), + ("ウォ", "w", "o"), + ("ウェ", "w", "e"), + ("ウィ", "w", "i"), + ("ウ", None, "u"), + ("イェ", "y", "e"), + ("イ", None, "i"), + ("ア", None, "a"), +] +_mora_list_additional: list[tuple[str, Optional[str], str]] = [ + ("ヴョ", "by", "o"), + ("ヴュ", "by", "u"), + ("ヴャ", "by", "a"), + ("ヲ", None, "o"), + ("ヱ", None, "e"), + ("ヰ", None, "i"), + ("ヮ", "w", "a"), + ("ョ", "y", "o"), + ("ュ", "y", "u"), + ("ヅ", "z", "u"), + ("ヂ", "j", "i"), + ("ヶ", "k", "e"), + ("ャ", "y", "a"), + ("ォ", None, "o"), + ("ェ", None, "e"), + ("ゥ", None, "u"), + ("ィ", None, "i"), + ("ァ", None, "a"), +] + +# 例: "vo" -> "ヴォ", "a" -> "ア" +mora_phonemes_to_mora_kata: dict[str, str] = { + (consonant or "") + vowel: kana for [kana, consonant, vowel] in _mora_list_minimum +} + +# 例: "ヴォ" -> ("v", "o"), "ア" -> (None, "a") +mora_kata_to_mora_phonemes: dict[str, tuple[Optional[str], str]] = { + kana: (consonant, vowel) + for [kana, consonant, vowel] in _mora_list_minimum + _mora_list_additional +} diff --git a/manager/TTSManager.py b/manager/TTSManager.py index ad3c0f9..2710004 100644 --- a/manager/TTSManager.py +++ b/manager/TTSManager.py @@ -481,6 +481,9 @@ def bert_vits2_infer_v2(self, state, encode=True): if model.zh_bert_extra: infer_func = model.infer state["lang"] = "zh" + elif model.ja_bert_extra: + infer_func = model.infer + state["lang"] = "ja" elif state["lang"].lower() == "auto": infer_func = model.infer_multilang else: diff --git a/manager/model_handler.py b/manager/model_handler.py index 5ef2ea2..07f34b8 100644 --- a/manager/model_handler.py +++ b/manager/model_handler.py @@ -18,6 +18,7 @@ from bert_vits2.text.japanese_bert_v200 import get_bert_feature as ja_bert_v200 from bert_vits2.text.english_bert_mock_v200 import get_bert_feature as en_bert_v200 from bert_vits2.text.chinese_bert_extra import get_bert_feature as zh_bert_extra +from bert_vits2.text.japanese_bert_extra import get_bert_feature as ja_bert_extra class ModelHandler: @@ -114,7 +115,8 @@ def __init__(self, device=config.system.device): } self.lang_bert_func_map = {"zh": zh_bert, "en": en_bert, "ja": ja_bert, "ja_v111": ja_bert_v111, - "ja_v200": ja_bert_v200, "en_v200": en_bert_v200, "zh_extra": zh_bert_extra} + "ja_v200": ja_bert_v200, "en_v200": en_bert_v200, "zh_extra": zh_bert_extra, + "ja_extra": ja_bert_extra} self.bert_models = {} # Value: (tokenizer, model, reference_count) self.emotion = None