Update GPT-SoVITS t2s_model
Artrajz committed Mar 15, 2024
1 parent 21ca541 commit 61f6c1f
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions gpt_sovits/AR/models/t2s_model.py
@@ -230,10 +230,15 @@ def __init__(self, config, norm_first=False, top_k=3, flash_attn_enabled: bool =
ignore_index=self.EOS,
)

if not flash_attn_enabled:
self.enable_flash_attn(flash_attn_enabled)

def enable_flash_attn(self, enable: bool = True):

if not enable:
logging.info("Not Using Flash Attention")
self.infer_panel = self.infer_panel_batch_only
else:
self.infer_panel = self.infer_panel_batch_infer_with_flash_attn
logging.info("Using Flash Attention")
blocks = []
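As a reading aid for the hunk above, here is a minimal, self-contained sketch of the dispatch pattern it introduces: enable_flash_attn rebinds infer_panel to either the flash-attention batch implementation or the batch-only fallback. The class skeleton and the method bodies below are placeholders, not the repository's actual implementations.

import logging

class T2SModelSketch:
    """Illustrative stand-in; only the infer_panel dispatch logic is shown."""

    def __init__(self, flash_attn_enabled: bool = True):
        # Default binding uses flash attention; switch to the fallback on request.
        self.infer_panel = self.infer_panel_batch_infer_with_flash_attn
        if not flash_attn_enabled:
            self.enable_flash_attn(flash_attn_enabled)

    def enable_flash_attn(self, enable: bool = True):
        if not enable:
            logging.info("Not Using Flash Attention")
            self.infer_panel = self.infer_panel_batch_only
        else:
            self.infer_panel = self.infer_panel_batch_infer_with_flash_attn
            logging.info("Using Flash Attention")

    def infer_panel_batch_infer_with_flash_attn(self, *args, **kwargs):
        return "flash-attention batch inference"  # placeholder body

    def infer_panel_batch_only(self, *args, **kwargs):
        return "batch-only fallback"  # placeholder body

model = T2SModelSketch(flash_attn_enabled=False)
print(model.infer_panel())  # -> batch-only fallback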

@@ -499,7 +504,7 @@ def pad_y_eos(self, y, y_mask_int, eos_id):
# offset by one position (inputs vs. targets)
return targets[:, :-1], targets[:, 1:]

def infer_panel(
def infer_panel_batch_infer_with_flash_attn(
self,
x, ##### all text tokens
x_lens,
@@ -510,8 +515,10 @@ def infer_panel(
early_stop_num: int = -1,
temperature: float = 1.0,
):

bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = x + bert_feature
x = self.ar_text_position(x)
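
As a standalone illustration of the reordering above (the BERT features are now projected once, up front, and then added to the text-token embeddings), here is a minimal sketch with made-up layer sizes; ar_text_embedding and bert_proj stand in for the model's real layers.

import torch
import torch.nn as nn

# Made-up sizes; the real dimensions come from the model config.
vocab_size, bert_dim, model_dim = 512, 1024, 512

ar_text_embedding = nn.Embedding(vocab_size, model_dim)
bert_proj = nn.Linear(bert_dim, model_dim)

x = torch.randint(0, vocab_size, (2, 7))    # (bsz, x_len) text tokens
bert_feature = torch.randn(2, bert_dim, 7)  # (bsz, bert_dim, x_len)

# Project the BERT features to the model dimension once, up front,
# then add them to the token embeddings (the ordering this commit adopts).
bert_feature = bert_proj(bert_feature.transpose(1, 2))  # (bsz, x_len, model_dim)
x = ar_text_embedding(x)
x = x + bert_feature
print(x.shape)  # torch.Size([2, 7, 512])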

# AR Decoder
@@ -548,28 +555,27 @@ def pad_y_eos(self, y, y_mask_int, eos_id):
y_mask = make_pad_mask(y_lens)
x_mask = make_pad_mask(x_lens)

# (bsz, x_len + y_len)
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
_xy_padding_mask = (
xy_padding_mask.view(bsz, 1, 1, src_len).expand(-1, self.num_head, -1, -1)
)

x_attn_mask_pad = F.pad(
x_mask = F.pad(
x_attn_mask,
(0, y_len), ### extend the all-zero xx block to all-zero xx + all-one xy, (x, x+y)
value=True,
)
y_attn_mask = F.pad( ### extend yy's upper-right ones with zeros for xy on the left, (y, x+y)
y_mask = F.pad( ### extend yy's upper-right ones with zeros for xy on the left, (y, x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
x.device
)
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)

xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).expand(bsz, -1, -1).to(x.device)
# xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool, device=x.device), diagonal=1)
xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, src_len)
xy_attn_mask = xy_mask.logical_or(xy_padding_mask)
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
xy_attn_mask = new_attn_mask
xy_attn_mask = new_attn_mask.masked_fill(xy_attn_mask, float("-inf"))

###### decode #####
y_list = [None] * y.shape[0]
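For context on the mask rework in the last hunk, the sketch below reproduces the overall construction with toy sizes: an x-to-y blocked mask plus a causal mask over y is OR-ed with the per-sample padding mask, expanded per attention head, and converted into an additive float mask with -inf at blocked positions. Variable names and dimensions are illustrative, not copied from the repository.

import torch
import torch.nn.functional as F

# Toy dimensions, chosen only for illustration.
bsz, num_head = 2, 4
x_len, y_len = 3, 5
src_len = x_len + y_len

# Per-sample padding masks: True marks padded positions.
x_pad_mask = torch.zeros(bsz, x_len, dtype=torch.bool)
y_pad_mask = torch.zeros(bsz, y_len, dtype=torch.bool)
y_pad_mask[1, -2:] = True  # pretend sample 1 has two padded y positions

# (bsz, src_len) padding mask over the concatenated x+y sequence.
xy_padding_mask = torch.concat([x_pad_mask, y_pad_mask], dim=1)

# Structural mask: x never attends to y; y attends causally to y and freely to x.
x_attn_mask = F.pad(
    torch.zeros(x_len, x_len, dtype=torch.bool),
    (0, y_len),  # block x -> y
    value=True,
)
y_attn_mask = F.pad(
    torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
    (x_len, 0),  # allow y -> x
    value=False,
)
xy_mask = (
    torch.concat([x_attn_mask, y_attn_mask], dim=0)
    .view(1, src_len, src_len)
    .expand(bsz, -1, -1)
)

# OR the padding mask (broadcast over query positions) into the structural mask.
xy_attn_mask = xy_mask.logical_or(
    xy_padding_mask.view(bsz, 1, src_len).expand(-1, src_len, -1)
)

# One mask per head, as an additive float mask: 0 keeps, -inf blocks.
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, num_head, -1, -1)
float_mask = torch.zeros_like(xy_attn_mask, dtype=torch.float32)
float_mask = float_mask.masked_fill(xy_attn_mask, float("-inf"))
print(float_mask.shape)  # torch.Size([2, 4, 8, 8])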
