Merge pull request espnet#5579 from ftshijt/spk_inference
Speaker embedding extractor (with ESPnet pre-trained speaker model)
sw005320 authored Jan 10, 2024
2 parents 31d2c57 + d0740d1 commit 3b2e0d3
Showing 26 changed files with 516 additions and 724 deletions.
@@ -9,31 +9,33 @@
def get_parser():
parser = argparse.ArgumentParser(
description="""
Replaces xvectors in a specified xvector directory with the average xvector
Replaces spk_embeds in a specified spk_embed directory with the average spk_embed
for a given speaker.
The xvectors generally reside in dump/xvector/<data_subset>/xvector.scp, whereas
speaker-averaged xvectors reside in dump/xvector/<data_subset>/spk_xvector.scp.
The spk_embeds generally reside in
dump/${spk_embed_tag}/<data_subset>/${spk_embed_tag}.scp, whereas
speaker-averaged spk_embeds reside in
dump/${spk_embed_tag}/<data_subset>/spk_${spk_embed_tag}.scp.
The old xvector.scp file will be renamed to xvector.scp.bak and
The old spk_embed.scp file will be renamed to spk_embed.scp.bak and
the corresponding .ark files are left unchanged.
If no speaker id is provided, the average xvector for the speaker who
If no speaker id is provided, the average spk_embed for the speaker who
the utterance belongs to will be used in each case.
At inference time in a TTS task, you are unlikely to have the xvector
for that sentence in particular. Thus, using average xvectors
At inference time in a TTS task, you are unlikely to have the spk_embed
for that sentence in particular. Thus, using average spk_embeds
during training may yield better performance at inference time.
This is also useful for conditioning inference on a particular speaker.
To transform the training data, this script should be run after
xvectors are extracted (stage 2), but before training commences (stage 6).
spk_embeds are extracted (stage 3), but before training commences (stage 7).
"""
)
parser.add_argument(
"--xvector-path",
"--utt-embed-path",
type=str,
required=True,
help="Path to the xvector file to be modified.",
help="Path to the spk_embed file to be modified.",
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
@@ -47,30 +49,30 @@ def get_parser():
help="Path to the relevant utt2spk file, if the source speakers are used",
)
parser.add_argument(
"--spk-xvector-path",
"--spk-embed-path",
type=str,
required=True,
help="The path to the spk_xvector.scp file for the speakers being used.",
help="The path to the spk_{spk_embed_tag}.scp for the speakers being used.",
)
return parser


def check_args(args):
xvector_path = args.xvector_path
spk_xvector_path = args.spk_xvector_path
utt_embed_path = args.utt_embed_path
spk_embed_path = args.spk_embed_path
utt2spk = args.utt2spk

if not os.path.exists(xvector_path):
if not os.path.exists(utt_embed_path):
sys.stderr.write(
f"Error: provided --xvector-path ({xvector_path}) does not exist. "
f"Error: provided --utt-embed-path ({utt_embed_path}) does not exist. "
)
sys.stderr.write("Exiting...\n")
sys.stderr.flush()
exit(1)

if not os.path.exists(spk_xvector_path):
if not os.path.exists(spk_embed_path):
sys.stderr.write(
f"Error: provided --spk-xvector-path ({spk_xvector_path}) does not exist. "
f"Error: provided --spk-embed-path ({spk_embed_path}) does not exist. "
)
sys.stderr.write("Exiting...\n")
sys.stderr.flush()
@@ -88,34 +90,37 @@ def check_args(args):
check_args(args)
spk_id = args.spk_id
utt2spk = args.utt2spk
xvector_path = args.xvector_path
spk_xvector_path = args.spk_xvector_path
utt_embed_path = args.utt_embed_path
spk_embed_path = args.spk_embed_path

print(f"Loading {spk_xvector_path}...")
spk_xvector_paths = {}
with open(spk_xvector_path) as spembfile:
print(f"Loading {spk_embed_path}...")
spk_embed_paths = {}
with open(spk_embed_path) as spembfile:
for line in spembfile.readlines():
spkid, spembpath = line.split()
spk_xvector_paths[spkid] = spembpath
spk_embed_paths[spkid] = spembpath

if spk_id and (spk_id not in spk_xvector_paths):
if spk_id and (spk_id not in spk_embed_paths):
sys.stderr.write(
f"Error: provided --spk-id: {spk_id} not present in --spk-xvector-path."
f"Error: provided --spk-id: {spk_id} not present in --spk-embed-path."
)
sys.stderr.write("Exiting...\n")
sys.stderr.flush()
exit(1)

print("Backing up xvector file...")
print(os.path.dirname(xvector_path))
shutil.copy(xvector_path, f"{os.path.dirname(xvector_path)}/xvector.scp.bak")
print("Backing up utt_embed file...")
print(os.path.dirname(utt_embed_path))
shutil.copy(
utt_embed_path,
f"{os.path.dirname(utt_embed_path)}/os.path.filename(utt_embed_path).bak",
)

utt2xvector = []
with open(args.xvector_path) as f:
utt2spk_embed = []
with open(args.utt_embed_path) as f:
for line in f.readlines():
utt, xvector = line.split()
utt2xvector.append((utt, spk_xvector_paths[spk_id]))
utt, spk_embed = line.split()
utt2spk_embed.append((utt, spk_embed_paths[spk_id]))

with open(args.xvector_path, "w") as f:
for utt, xvector in utt2xvector:
f.write(f"{utt} {xvector}\n")
with open(args.utt_embed_path, "w") as f:
for utt, spk_embed in utt2spk_embed:
f.write(f"{utt} {spk_embed}\n")
@@ -27,9 +27,15 @@ def get_parser():
parser.add_argument(
"--toolkit",
type=str,
help="Toolkit for Extracting X-vectors.",
help="Toolkit for Extracting speaker speaker embeddingss.",
choices=["espnet", "speechbrain", "rawnet"],
)
parser.add_argument(
"--spk_embed_tag",
type=str,
help="the target data name (e.g., xvector for xvector.scp)",
default="spk_embed",
)
parser.add_argument("--verbose", type=int, default=1, help="Verbosity level.")
parser.add_argument("--device", type=str, default="cuda:0", help="Inference device")
parser.add_argument(
@@ -38,15 +44,16 @@
parser.add_argument(
"out_folder",
type=Path,
help="Output folder to save the xvectors.",
help="Output folder to save the speaker embeddings.",
)
return parser


class XVExtractor:
class SpkEmbedExtractor:
def __init__(self, args, device):
self.toolkit = args.toolkit
self.device = device
self.tgt_sr = 16000  # NOTE(jiatong): following the 16 kHz convention

if self.toolkit == "speechbrain":
from speechbrain.dataio.preprocess import AudioNormalizer
@@ -81,8 +88,37 @@ def __init__(self, args, device):
)["model"]
)
self.model.to(device).eval()
elif self.toolkit == "espnet":
from espnet2.bin.spk_inference import Speech2Embedding

# NOTE(jiatong): the config file defaults to None;
# it is assumed to share the same path as the model file
speech2embedding_kwargs = dict(
batch_size=1,
dtype="float32",
train_config=None,
model_file=args.pretrained_model,
)

if args.pretrained_model.endswith("pth"):
logging.info(
"the provided model path is end with pth,"
"assume it not a huggingface model"
)
model_tag = None
else:
logging.info(
"the provided model path is not end with pth,"
"assume use huggingface model"
)
model_tag = args.pretrained_model

self.speech2embedding = Speech2Embedding.from_pretrained(
model_tag=model_tag,
**speech2embedding_kwargs,
)
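A minimal usage sketch of the ESPnet branch above, relying only on the Speech2Embedding call pattern shown in this diff; the Hugging Face model tag is hypothetical and should be replaced with a real pre-trained ESPnet speaker model:

import numpy as np
import torch
from espnet2.bin.spk_inference import Speech2Embedding

# "espnet/voxcelebs12_ecapa" is a hypothetical model tag used for illustration
speech2embedding = Speech2Embedding.from_pretrained(
    model_tag="espnet/voxcelebs12_ecapa",
    batch_size=1,
    dtype="float32",
)
wav = np.zeros(16000 * 3, dtype=np.float32)  # three seconds of 16 kHz audio
embedding = speech2embedding(torch.from_numpy(wav))  # returns a torch tensor
print(embedding.shape)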

def rawnet_extract_embd(self, audio, n_samples=48000, n_segments=10):
def _rawnet_extract_embd(self, audio, n_samples=48000, n_segments=10):
if len(audio.shape) > 1:
raise ValueError(
"RawNet3 supports mono input only."
@@ -102,13 +138,31 @@ def rawnet_extract_embd(self, audio, n_samples=48000, n_segments=10):
output = self.model(audios)
return output.mean(0).detach().cpu().numpy()

def _espnet_extract_embd(self, audio):
if len(audio.shape) == 2:
logging.info(
"Not support multi-channel input for ESPnet pre-trained model"
f"Input data has shape {audio.shape}, default set avg across channel"
)
audio = np.mean(audio, axis=0)
elif len(audio.shape) > 1:
raise ValueError(f"Input data has shape {audio.shape} thatis not support")
audio = torch.from_numpy(audio.astype(np.float32)).to(self.device)
output = self.speech2embedding(audio)
return output.cpu().numpy()

def __call__(self, wav, in_sr):
if self.toolkit == "speechbrain":
wav = self.audio_norm(torch.from_numpy(wav), in_sr).to(self.device)
embeds = self.model.encode_batch(wav).detach().cpu().numpy()[0]
elif self.toolkit == "rawnet":
wav = librosa.resample(wav, orig_sr=in_sr, target_sr=16000)
embeds = self.rawnet_extract_embd(wav)
if in_sr != self.tgt_sr:
wav = librosa.resample(wav, orig_sr=in_sr, target_sr=self.tgt_sr)
embeds = self._rawnet_extract_embd(wav)
elif self.toolkit == "espnet":
if in_sr != self.tgt_sr:
wav = librosa.resample(wav, orig_sr=in_sr, target_sr=self.tgt_sr)
embeds = self._espnet_extract_embd(wav)
return embeds


@@ -134,7 +188,7 @@ def main(argv):
else:
device = "cpu"

if args.toolkit in ("speechbrain", "rawnet"):
if args.toolkit in ("speechbrain", "rawnet", "espnet"):
# Prepare spk2utt for the mean speaker embedding
spk2utt = dict()
with open(os.path.join(args.in_folder, "spk2utt"), "r") as reader:
@@ -145,33 +199,33 @@
wav_scp = SoundScpReader(os.path.join(args.in_folder, "wav.scp"), np.float32)
os.makedirs(args.out_folder, exist_ok=True)
writer_utt = kaldiio.WriteHelper(
"ark,scp:{0}/xvector.ark,{0}/xvector.scp".format(args.out_folder)
"ark,scp:{0}/{1}.ark,{0}/{1}.scp".format(
args.out_folder, args.spk_embed_tag
)
)
writer_spk = kaldiio.WriteHelper(
"ark,scp:{0}/spk_xvector.ark,{0}/spk_xvector.scp".format(args.out_folder)
"ark,scp:{0}/spk_{1}.ark,{0}/spk_{1}.scp".format(
args.out_folder, args.spk_embed_tag
)
)

xv_extractor = XVExtractor(args, device)
spk_embed_extractor = SpkEmbedExtractor(args, device)

for speaker in tqdm(spk2utt):
xvectors = list()
spk_embeddings = list()
for utt in spk2utt[speaker]:
in_sr, wav = wav_scp[utt]
# X-vector Embedding
embeds = xv_extractor(wav, in_sr)
# Speaker Embedding
embeds = spk_embed_extractor(wav, in_sr)
writer_utt[utt] = np.squeeze(embeds)
xvectors.append(embeds)
spk_embeddings.append(embeds)

# Speaker Normalization
embeds = np.mean(np.stack(xvectors, 0), 0)
embeds = np.mean(np.stack(spk_embeddings, 0), 0)
writer_spk[speaker] = embeds
writer_utt.close()
writer_spk.close()
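For reference, the per-speaker embeddings dumped above can be read back with kaldiio. A minimal sketch, assuming the default spk_embed tag and a hypothetical out_folder of dump/embed:

import kaldiio

# spk_spk_embed.scp maps each speaker to its averaged embedding
with kaldiio.ReadHelper("scp:dump/embed/spk_spk_embed.scp") as reader:
    for speaker, embedding in reader:
        print(speaker, embedding.shape)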

elif args.toolkit == "espnet":
raise NotImplementedError(
"Follow details at: https://github.com/espnet/espnet/issues/3040"
)
else:
raise ValueError(
"Unkown type of toolkit. Only supported: speechbrain, rawnet, espnet, kaldi"