Skip to content

Commit

Permalink
sample rate
Browse files Browse the repository at this point in the history
  • Loading branch information
实一 committed Dec 12, 2022
1 parent 50e1ac2 commit 77b447e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 32 deletions.
40 changes: 9 additions & 31 deletions data/s2t_data/unify_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from pathlib import Path
import soundfile as sf
import librosa
import torchaudio
from typing import List

Expand Down Expand Up @@ -187,7 +188,8 @@ def __init__(
lang="zh",
text2phone_path=None,
train_stage=2,
n_frames_per_step=1
n_frames_per_step=1,
sample_rate=16000,
):
super().__init__(split, dataset, bpe, src_dict, tgt_dict)
self.phone_dict = phone_dict
Expand All @@ -213,6 +215,7 @@ def __init__(
self.data_cfg.get_feature_transforms(split, split.startswith("train"))
)
self.n_frames_per_step = n_frames_per_step
self.sample_rate = sample_rate
self.blank_id = self.phone_dict.index("<blank>")
self.phone_mask_idx = self.phone_dict.index("<mask>")
self.text2phone_tokenizer = None
Expand Down Expand Up @@ -353,11 +356,8 @@ def process_speech_text_pair(self, index, dataset=None):
speed = random.choice([0.9, 1.0, 1.1])
else:
speed = 1.0
# wav, sr = sf.read(BytesIO(base64.urlsafe_b64decode(wav_data)))
wav, sr = sf.read(wav_data)
# if speech_id == "BAC009S0002W0122":
# print(speech_id, wav, sr)
# speed = 0.9
# wav, sr = sf.read(wav_data)
wav, sr = librosa.load(wav_data, self.sample_rate)
# spec_augmentation
fbank = self.prepare_fbank(torch.tensor([wav], dtype=torch.float32), sr, speed, speech_id)

Expand Down Expand Up @@ -467,28 +467,6 @@ def encode_phone(self, phone_item):
line=phone_item, add_if_not_exist=False, append_eos=False).long()
return tokens

def _phone_seq_augmentation(self, phone_seq, max_repeat=3):
    """Randomly stretch a phone sequence with blank padding and repeats.

    For every phone in ``phone_seq`` two independent draws from
    ``random.randint(1, 10)`` are made:

    * First draw below 6: that many blanks (``self.blank_id``) are inserted
      before the phone, capped at ``max_repeat``. Otherwise a single blank
      is inserted only when the phone equals the previous one, so that two
      identical neighbouring phones are always separated by a blank.
    * Second draw below 6: the phone is emitted that many times, capped at
      ``max_repeat``. Otherwise it is emitted exactly once.

    Args:
        phone_seq: Iterable of phone ids.
        max_repeat: Upper bound on the length of any inserted blank run and
            on the number of times a single phone is repeated. Values below
            1 are treated as 1 so that every phone is still emitted and the
            blank separator is never dropped.

    Returns:
        A new list with the augmented phone ids; ``phone_seq`` is not
        modified.
    """
    # The original code used max(draw, max_repeat), which turned the
    # parameter into a lower bound; min() makes it the cap its name promises.
    cap = max(1, max_repeat)

    augmented = []
    previous_phone = self.blank_id
    for phone in phone_seq:
        sil_repeat = random.randint(1, 10)
        if sil_repeat < 6:
            augmented.extend([self.blank_id] * min(sil_repeat, cap))
        # If two consecutive phone items are identical, a blank must be
        # inserted between them.
        elif phone == previous_phone:
            augmented.append(self.blank_id)

        # Repetition of the phone itself.
        repeat = random.randint(1, 10)
        if repeat < 6:
            augmented.extend([phone] * min(repeat, cap))
        else:
            augmented.append(phone)

        previous_phone = phone
    return augmented

def add_noise_to_phone(self, phone, p, random_p=0.1):
num_to_mask = int(math.ceil(phone.size(0) * p))
indices = torch.randperm(phone.size(0))[:num_to_mask]
Expand Down Expand Up @@ -553,9 +531,9 @@ def collater(self, samples, pad_to_length=None):

mask = False
mask_prob = None
# if self.split == "train" and self.train_stage != 1:
# mask = True
# mask_prob = 0.3
if self.split == "train" and self.train_stage != 1:
mask = True
mask_prob = 0.3

res_v1 = collate(
samples_v1,
Expand Down
7 changes: 6 additions & 1 deletion tasks/speech_tasks/unify_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ class UnifySpeechTextConfig(OFAConfig):
default=1,
metadata={"help": "n_frames_per_step of fbank"}
)
sample_rate: int = field(
default=16000,
metadata={"help": "sample rate of wav"}
)
phone_dict_path: Optional[str] = field(
default=None,
metadata={"help": "phone_dict_path"}
Expand Down Expand Up @@ -227,7 +231,8 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
lang=self.cfg.lang,
text2phone_path=self.text2phone_path,
train_stage=self.train_stage,
n_frames_per_step=self.cfg.n_frames_per_step
n_frames_per_step=self.cfg.n_frames_per_step,
sample_rate=self.cfg.sample_rate,
)

def get_batch_iterator(
Expand Down

0 comments on commit 77b447e

Please sign in to comment.