Skip to content

Commit

Permalink
Update Bert-VITS2 japanese Extra
Browse files: browse the repository at this point in the history
  • Loading branch information
Artrajz committed Mar 13, 2024
1 parent d5afaa8 commit 6c619cc
Show file tree
Hide file tree
Showing 10 changed files with 1,870 additions and 30 deletions.
29 changes: 27 additions & 2 deletions bert_vits2/bert_vits2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from bert_vits2.get_emo import get_emo
from bert_vits2.models import SynthesizerTrn
from bert_vits2.models_v230 import SynthesizerTrn as SynthesizerTrn_v230
from bert_vits2.models_ja_extra import SynthesizerTrn as SynthesizerTrn_ja_extra
from bert_vits2.text import *
from bert_vits2.text.cleaner import clean_text
from bert_vits2.utils import process_legacy_versions
Expand All @@ -29,6 +30,7 @@ def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs):

self.bert_model_names = {}
self.zh_bert_extra = False
self.ja_bert_extra = False
self.ja_bert_dim = 1024
self.num_tones = num_tones
self.pinyinPlus = None
Expand Down Expand Up @@ -164,6 +166,20 @@ def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs):
self.bert_extra_str_map.update({"zh": "_extra"})
self.text_extra_str_map.update({"zh": "_v240"})

elif self.version is not None and self.version in ["ja_extra"]:
"""
deberta-v2-large-japanese-char-wwm
"""
self.version = "ja_extra"
self.hps_ms.model.emotion_embedding = 2
self.hps_ms.model.n_layers_trans_flow = 6
self.lang = ["ja"]
self.num_tones = num_tones
self.ja_bert_extra = True
self.bert_model_names.update({"ja": "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM"})
self.bert_extra_str_map.update({"ja": "_extra"})
self.text_extra_str_map.update({"ja": "_extra"})

else:
logging.debug("Version information not found. Loaded as the newest version: v2.3.")
self.version = "2.3"
Expand All @@ -185,6 +201,8 @@ def load_model(self, model_handler):

if self.version in ["2.3", "extra", "2.4"]:
Synthesizer = SynthesizerTrn_v230
elif self.version == "ja_extra":
Synthesizer = SynthesizerTrn_ja_extra
else:
Synthesizer = SynthesizerTrn

Expand Down Expand Up @@ -235,6 +253,9 @@ def get_text(self, text, language_str, hps, style_text=None, style_weight=0.7):
if self.zh_bert_extra:
zh_bert = bert
ja_bert, en_bert = None, None
elif self.ja_bert_extra:
ja_bert = bert
zh_bert, en_bert = None, None
elif language_str == "zh":
zh_bert = bert
ja_bert = torch.zeros(self.ja_bert_dim, len(phone))
Expand Down Expand Up @@ -287,8 +308,12 @@ def _infer(self, id, phones, tones, lang_ids, zh_bert, ja_bert, en_bert, sdp_rat
x_tst = phones.to(self.device).unsqueeze(0)
tones = tones.to(self.device).unsqueeze(0)
lang_ids = lang_ids.to(self.device).unsqueeze(0)
zh_bert = zh_bert.to(self.device).unsqueeze(0)
if not self.zh_bert_extra:
if self.zh_bert_extra:
zh_bert = zh_bert.to(self.device).unsqueeze(0)
elif self.ja_bert_extra:
ja_bert = ja_bert.to(self.device).unsqueeze(0)
else:
zh_bert = zh_bert.to(self.device).unsqueeze(0)
ja_bert = ja_bert.to(self.device).unsqueeze(0)
en_bert = en_bert.to(self.device).unsqueeze(0)
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(self.device)
Expand Down
43 changes: 20 additions & 23 deletions bert_vits2/commons.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


Expand All @@ -16,8 +14,8 @@ def get_padding(kernel_size, dilation=1):


def convert_pad_shape(pad_shape):
    """Flatten a nested per-dimension padding spec into one flat list.

    Args:
        pad_shape: Sequence of per-dimension sequences, e.g.
            ``[[0, 0], [1, 2]]``.

    Returns:
        A flat list containing the inner items, with the outer entries
        taken in reverse order (last entry first), e.g. ``[1, 2, 0, 0]``.
    """
    # Reverse the outer order once, then flatten. Doing the reverse/flatten
    # pass a second time would iterate over scalars and fail.
    layer = pad_shape[::-1]
    return [item for sublist in layer for item in sublist]


Expand All @@ -30,7 +28,9 @@ def intersperse(lst, item):
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """Compute KL(P||Q) element-wise from means and log-scales.

    The formula treats ``exp(2 * logs)`` as the variance, i.e. ``logs_p`` and
    ``logs_q`` are used as log standard deviations.

    Args:
        m_p: Mean of P.
        logs_p: Log-scale of P.
        m_q: Mean of Q.
        logs_q: Log-scale of Q.

    Returns:
        Tensor with the element-wise divergence (no reduction is applied).
    """
    kl = (logs_q - logs_p) - 0.5
    # The quadratic term must be accumulated exactly once.
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    return kl


Expand All @@ -46,33 +46,31 @@ def rand_gumbel_like(x):


def slice_segments(x, ids_str, segment_size=4):
    """Extract one fixed-length window per batch item along the last axis.

    Args:
        x: 3-D tensor; it is indexed as ``(batch, channels, time)`` here.
        ids_str: Integer tensor with one start index per batch item
            (``x.size(0)`` elements).
        segment_size: Number of consecutive positions to take from each item.

    Returns:
        Tensor of shape ``(x.size(0), x.size(1), segment_size)`` where item
        ``i`` holds ``x[i, :, ids_str[i]:ids_str[i] + segment_size]``.
    """
    batch, channels = x.size(0), x.size(1)
    # Offsets 0..segment_size-1 are broadcast onto each start index so every
    # window is gathered in a single call instead of a per-item Python loop.
    offsets = torch.arange(segment_size, device=x.device)
    gather_indices = ids_str.view(batch, 1, 1).repeat(1, channels, 1) + offsets
    return torch.gather(x, 2, gather_indices)


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Pick a random start index per batch item and slice a window from it.

    Args:
        x: 3-D tensor whose size is unpacked as ``(b, d, t)``.
        x_lengths: Optional per-item valid lengths (tensor) or a single int.
            When ``None``, the full last-axis size ``t`` is used for all items.
        segment_size: Length of the window taken from each item.

    Returns:
        Tuple ``(ret, ids_str)``: the sliced windows produced by
        ``slice_segments`` and the start indices that were drawn.
    """
    b, _, t = x.size()
    if x_lengths is None:
        x_lengths = t
    # torch.clamp only accepts tensors; normalising here keeps the default
    # (plain int) path working as well as caller-supplied length tensors.
    lengths = torch.as_tensor(x_lengths, device=x.device)
    # Never let the upper bound go negative when an item is shorter than
    # the requested segment.
    ids_str_max = torch.clamp(lengths - segment_size + 1, min=0)
    ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(
length, channels, min_timescale=1.0, max_timescale=1.0e4):
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
(num_timescales - 1))
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
num_timescales - 1
)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
Expand Down Expand Up @@ -108,8 +106,8 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):


def convert_pad_shape(pad_shape):
    """Return the items of a nested padding spec as one flat list.

    Args:
        pad_shape: Sequence of per-dimension sequences, for example
            ``[[0, 0], [1, 2]]``.

    Returns:
        A flat list built by walking the outer sequence from last entry to
        first and concatenating the inner items, for example ``[1, 2, 0, 0]``.
    """
    # A single reverse followed by a single flatten; the nested input is
    # consumed exactly once.
    layer = pad_shape[::-1]
    return [item for sublist in layer for item in sublist]


Expand All @@ -130,7 +128,6 @@ def generate_path(duration, mask):
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
device = duration.device

b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
Expand All @@ -157,5 +154,5 @@ def clip_grad_value_(parameters, clip_value, norm_type=2):
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1. / norm_type)
total_norm = total_norm ** (1.0 / norm_type)
return total_norm
Loading

0 comments on commit 6c619cc

Please sign in to comment.