Ensure MetricGAN recipe works with small dataset
pplantinga committed Jul 21, 2023
1 parent dc1e16c commit f053563
Showing 4 changed files with 34 additions and 14 deletions.
7 changes: 6 additions & 1 deletion recipes/Voicebank/dereverb/MetricGAN-U/train.py
@@ -610,6 +610,7 @@ def make_dataloader(
output_keys=["id", "enh_sig", "score"],
)
samples = round(len(dataset) * self.hparams.history_portion)
+ samples = max(samples, 1) # Ensure there's at least 1 sample
else:
samples = self.hparams.number_of_samples

@@ -619,8 +620,12 @@ def make_dataloader(
# Equal weights for all samples, we use "Weighted" so we can do
# both "replacement=False" and a set number of samples, reproducibly
weights = torch.ones(len(dataset))
+ replacement = samples > len(dataset)
sampler = ReproducibleWeightedRandomSampler(
- weights, epoch=epoch, replacement=False, num_samples=samples
+ weights,
+ epoch=epoch,
+ replacement=replacement,
+ num_samples=samples,
)
loader_kwargs["sampler"] = sampler

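For context on the `max(samples, 1)` guard above: on a tiny CI dataset, `round(len(dataset) * self.hparams.history_portion)` can round down to zero, leaving the historical sampler with nothing to draw and no batches to train on. A minimal sketch of the failure and the fix (the dataset size and `history_portion` value below are hypothetical, not taken from the recipe hparams):

# Hypothetical numbers for a tiny test set (not the recipe defaults):
dataset_len = 2        # only two utterances in the replay history
history_portion = 0.1  # fraction of the history replayed each epoch

samples = round(dataset_len * history_portion)  # round(0.2) == 0
samples = max(samples, 1)  # the added guard: always draw at least one sample
print(samples)  # -> 1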
7 changes: 6 additions & 1 deletion recipes/Voicebank/enhance/MetricGAN-U/train.py
@@ -600,6 +600,7 @@ def make_dataloader(
output_keys=["id", "enh_sig", "score"],
)
samples = round(len(dataset) * self.hparams.history_portion)
+ samples = max(samples, 1) # Ensure there's at least one sample
else:
samples = self.hparams.number_of_samples

@@ -609,8 +610,12 @@ def make_dataloader(
# Equal weights for all samples, we use "Weighted" so we can do
# both "replacement=False" and a set number of samples, reproducibly
weights = torch.ones(len(dataset))
+ replacement = samples > len(dataset)
sampler = ReproducibleWeightedRandomSampler(
- weights, epoch=epoch, replacement=False, num_samples=samples
+ weights,
+ epoch=epoch,
+ replacement=replacement,
+ num_samples=samples,
)
loader_kwargs["sampler"] = sampler

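The `replacement` flip, applied identically in both MetricGAN-U recipes above, is what keeps the sampler usable when the requested number of samples exceeds the dataset size: sampling without replacement cannot draw more indices than there are items. A standalone sketch of the same guard using PyTorch's stock `WeightedRandomSampler` (SpeechBrain's `ReproducibleWeightedRandomSampler` additionally seeds the draw with the epoch; the `make_history_sampler` helper below is a hypothetical name, not part of either library):

import torch
from torch.utils.data import WeightedRandomSampler

def make_history_sampler(dataset_len: int, num_samples: int) -> WeightedRandomSampler:
    """Uniform-weight sampler that tolerates num_samples > dataset_len."""
    weights = torch.ones(dataset_len)
    # Without replacement, the underlying torch.multinomial call raises when
    # asked for more samples than there are items, so switch to sampling
    # with replacement whenever the request exceeds the dataset size.
    replacement = num_samples > dataset_len
    return WeightedRandomSampler(
        weights, num_samples=num_samples, replacement=replacement
    )

# With a 3-item history and 8 requested samples, indices necessarily repeat:
sampler = make_history_sampler(dataset_len=3, num_samples=8)
print(list(sampler))  # e.g. [2, 0, 1, 1, 0, 2, 0, 1]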
30 changes: 20 additions & 10 deletions recipes/Voicebank/enhance/MetricGAN/train.py
@@ -85,6 +85,7 @@ def compute_objectives(self, predictions, batch, stage, optim_name=""):

clean_wav, lens = batch.clean_sig
clean_spec = self.compute_feats(clean_wav)
+ clean_paths = batch.clean_wav

ids = self.compute_ids(batch.id, optim_name)

@@ -109,7 +110,7 @@ def compute_objectives(self, predictions, batch, stage, optim_name=""):

# Write enhanced wavs during discriminator training, because we
# compute the actual score here and we can save it
- self.write_wavs(batch.id, ids, predict_wav, target_score, lens)
+ self.write_wavs(ids, predict_wav, clean_paths, target_score, lens)

# D Relearns to estimate the scores of previous epochs
elif optim_name == "D_enh" and self.sub_stage == SubStage.HISTORICAL:
@@ -245,7 +246,7 @@ def est_score(self, deg_spec, ref_spec):
)
return self.modules.discriminator(combined_spec)

- def write_wavs(self, clean_id, batch_id, wavs, scores, lens):
+ def write_wavs(self, batch_id, wavs, clean_paths, scores, lens):
"""Write wavs to files, for historical discriminator training
Arguments
@@ -254,25 +255,24 @@ def write_wavs(self, clean_id, batch_id, wavs, scores, lens):
A list of the utterance ids for the batch
wavs : torch.Tensor
The wavs to write to files
+ clean_paths : list of str
+ The paths to the clean wavs
scores : torch.Tensor
The actual scores for the corresponding utterances
lens : torch.Tensor
The relative lengths of each utterance
"""
lens = lens * wavs.shape[1]
record = {}
- for i, (cleanid, name, pred_wav, length) in enumerate(
- zip(clean_id, batch_id, wavs, lens)
+ for i, (name, pred_wav, clean_path, length) in enumerate(
+ zip(batch_id, wavs, clean_paths, lens)
):
path = os.path.join(self.hparams.MetricGAN_folder, name + ".wav")
data = torch.unsqueeze(pred_wav[: int(length)].cpu(), 0)
torchaudio.save(path, data, self.hparams.Sample_rate)

# Make record of path and score for historical training
score = float(scores[i][0])
- clean_path = os.path.join(
- self.hparams.train_clean_folder, cleanid + ".wav"
- )
record[name] = {
"enh_wav": path,
"score": score,
@@ -446,7 +446,13 @@ def make_dataloader(
dataset = sb.dataio.dataset.DynamicItemDataset(
data=dataset,
dynamic_items=[enh_pipeline],
output_keys=["id", "enh_sig", "clean_sig", "score"],
output_keys=[
"id",
"enh_sig",
"clean_sig",
"score",
"clean_wav",
],
)
samples = round(len(dataset) * self.hparams.history_portion)
else:
@@ -458,8 +464,12 @@ def make_dataloader(
# Equal weights for all samples, we use "Weighted" so we can do
# both "replacement=False" and a set number of samples, reproducibly
weights = torch.ones(len(dataset))
+ replacement = samples > len(dataset)
sampler = ReproducibleWeightedRandomSampler(
- weights, epoch=epoch, replacement=False, num_samples=samples
+ weights,
+ epoch=epoch,
+ replacement=replacement,
+ num_samples=samples,
)
loader_kwargs["sampler"] = sampler

@@ -525,7 +535,7 @@ def dataio_prep(hparams):
json_path=data_info[dataset],
replacements={"data_root": hparams["data_folder"]},
dynamic_items=[audio_pipeline],
output_keys=["id", "noisy_sig", "clean_sig"],
output_keys=["id", "noisy_sig", "clean_sig", "clean_wav"],
)

return datasets
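Taken together, the MetricGAN changes thread the clean-file path through the data pipeline: `dataio_prep` and the historical dataset now expose a `clean_wav` output key, `compute_objectives` reads it as `batch.clean_wav`, and `write_wavs` records the path it was given instead of rebuilding it from `train_clean_folder` plus the utterance id (which only worked when the data actually lived in that folder). Below is a condensed, self-contained sketch of the new loop with plain tensors and path lists in place of a SpeechBrain batch; the `"clean_wav"` record key is an assumption, since the end of the dict lies outside the visible hunk:

import os
import torch
import torchaudio

def write_wavs_sketch(batch_id, wavs, clean_paths, scores, lens,
                      out_folder, sample_rate=16000):
    """Save enhanced wavs and keep a record for historical discriminator
    training (simplified from the recipe's write_wavs)."""
    record = {}
    lens = lens * wavs.shape[1]  # relative lengths -> absolute sample counts
    for i, (name, pred_wav, clean_path, length) in enumerate(
        zip(batch_id, wavs, clean_paths, lens)
    ):
        path = os.path.join(out_folder, name + ".wav")
        data = pred_wav[: int(length)].cpu().unsqueeze(0)
        torchaudio.save(path, data, sample_rate)
        record[name] = {
            "enh_wav": path,
            "score": float(scores[i][0]),
            # The clean path now comes straight from the batch rather than
            # being rebuilt from train_clean_folder; the key name is assumed.
            "clean_wav": clean_path,
        }
    return record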
4 changes: 2 additions & 2 deletions tests/recipes/Voicebank.csv
@@ -3,9 +3,9 @@ ASR+enhancement,Voicebank,recipes/Voicebank/MTL/ASR_enhance/train.py,recipes/Voi
ASR+enhancement,Voicebank,recipes/Voicebank/MTL/ASR_enhance/train.py,recipes/Voicebank/MTL/ASR_enhance/hparams/robust_asr.yaml,recipes/Voicebank/MTL/ASR_enhance/voicebank_prepare.py,recipes/Voicebank/MTL/ASR_enhance/README.md,https://www.dropbox.com/sh/azvcbvu8g5hpgm1/AACDc6QxtNMGZ3IoZLrDiU0Va?dl=0,https://huggingface.co/speechbrain/mtl-mimic-voicebank,--data_folder_rirs=tests/tmp --data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2,
ASR,Voicebank,recipes/Voicebank/ASR/CTC/train.py,recipes/Voicebank/ASR/CTC/hparams/train.yaml,recipes/Voicebank/ASR/CTC/voicebank_prepare.py,recipes/Voicebank/ASR/CTC/README.md,https://www.dropbox.com/sh/w4j0auezgmmo005/AAAjKcoJMdLDp0Pqe3m7CLVaa?dl=0,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --output_neurons=18 --number_of_epochs=2,
Enhancement,Voicebank,recipes/Voicebank/MTL/ASR_enhance/train.py,recipes/Voicebank/MTL/ASR_enhance/hparams/enhance_mimic.yaml,recipes/Voicebank/MTL/ASR_enhance/voicebank_prepare.py,recipes/Voicebank/MTL/ASR_enhance/README.md,https://www.dropbox.com/sh/azvcbvu8g5hpgm1/AACDc6QxtNMGZ3IoZLrDiU0Va?dl=0,https://huggingface.co/speechbrain/mtl-mimic-voicebank,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2,
- Enhancement,Voicebank,recipes/Voicebank/dereverb/MetricGAN-U/train.py,recipes/Voicebank/dereverb/MetricGAN-U/hparams/train_dereverb.yaml,recipes/Voicebank/dereverb/MetricGAN-U/voicebank_revb_prepare.py,recipes/Voicebank/dereverb/MetricGAN-U/README.md,https://www.dropbox.com/sh/r94qn1f5lq9r3p7/AAAZfisBhhkS8cwpzy1O5ADUa?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --tensorboard_train_logger=None --use_tensorboard=False,
+ Enhancement,Voicebank,recipes/Voicebank/dereverb/MetricGAN-U/train.py,recipes/Voicebank/dereverb/MetricGAN-U/hparams/train_dereverb.yaml,recipes/Voicebank/dereverb/MetricGAN-U/voicebank_revb_prepare.py,recipes/Voicebank/dereverb/MetricGAN-U/README.md,https://www.dropbox.com/sh/r94qn1f5lq9r3p7/AAAZfisBhhkS8cwpzy1O5ADUa?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --tensorboard_train_logger=None --use_tensorboard=False --target_metric=srmr --calculate_dnsmos_on_validation_set=False,
Enhancement,Voicebank,recipes/Voicebank/dereverb/spectral_mask/train.py,recipes/Voicebank/dereverb/spectral_mask/hparams/train.yaml,recipes/Voicebank/dereverb/spectral_mask/voicebank_revb_prepare.py,recipes/Voicebank/dereverb/spectral_mask/README.md,https://www.dropbox.com/sh/pw8aer8gcsrdbx7/AADknh7plHF5GBeTRK9VkIKga?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --tensorboard_train_logger=None --use_tensorboard=False,
- Enhancement,Voicebank,recipes/Voicebank/enhance/MetricGAN-U/train.py,recipes/Voicebank/enhance/MetricGAN-U/hparams/train_dnsmos.yaml,recipes/Voicebank/enhance/MetricGAN-U/voicebank_prepare.py,recipes/Voicebank/enhance/MetricGAN-U/README.md,https://www.dropbox.com/sh/h9akxmyel17sc8y/AAAP3Oz5MbXDfMlEXVjOBWV0a?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --use_tensorboard=False,
+ Enhancement,Voicebank,recipes/Voicebank/enhance/MetricGAN-U/train.py,recipes/Voicebank/enhance/MetricGAN-U/hparams/train_dnsmos.yaml,recipes/Voicebank/enhance/MetricGAN-U/voicebank_prepare.py,recipes/Voicebank/enhance/MetricGAN-U/README.md,https://www.dropbox.com/sh/h9akxmyel17sc8y/AAAP3Oz5MbXDfMlEXVjOBWV0a?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --use_tensorboard=False --target_metric=srmr --calculate_dnsmos_on_validation_set=False,
Enhancement,Voicebank,recipes/Voicebank/enhance/MetricGAN/train.py,recipes/Voicebank/enhance/MetricGAN/hparams/train.yaml,recipes/Voicebank/enhance/MetricGAN/voicebank_prepare.py,recipes/Voicebank/enhance/MetricGAN/README.md,https://www.dropbox.com/sh/n5q9vjn0yn1qvk6/AAB-S7i2-XzVm6ux0MrXCvqya?dl=0 ,https://huggingface.co/speechbrain/metricgan-plus-voicebank,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --tensorboard_train_logger=None --use_tensorboard=False,
Enhancement,Voicebank,recipes/Voicebank/enhance/SEGAN/train.py,recipes/Voicebank/enhance/SEGAN/hparams/train.yaml,recipes/Voicebank/enhance/SEGAN/voicebank_prepare.py,recipes/Voicebank/enhance/SEGAN/README.md,https://www.dropbox.com/sh/ez0folswdbqiad4/AADDasepeoCkneyiczjCcvaOa?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2,
Enhancement,Voicebank,recipes/Voicebank/enhance/spectral_mask/train.py,recipes/Voicebank/enhance/spectral_mask/hparams/train.yaml,recipes/Voicebank/enhance/spectral_mask/voicebank_prepare.py,recipes/Voicebank/enhance/spectral_mask/README.md,https://www.dropbox.com/sh/n5q9vjn0yn1qvk6/AAB-S7i2-XzVm6ux0MrXCvqya?dl=0 ,,--data_folder=tests/samples/separation --train_annotation=tests/samples/annotation/enhance_train.json --valid_annotation=tests/samples/annotation/enhance_dev.json --test_annotation=tests/samples/annotation/enhance_dev.json --skip_prep=True --number_of_epochs=2 --use_tensorboard=False,
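The two updated CSV rows add `--target_metric=srmr --calculate_dnsmos_on_validation_set=False` to the MetricGAN-U quick-test invocations, presumably so the small-data runs score with SRMR and skip DNSMOS on validation. These flags reach the recipe as hparam overrides through the usual SpeechBrain entry-point pattern; the snippet below sketches that pattern from memory rather than quoting either train.py:

import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Launched roughly as the CSV row does, e.g.:
#   python train.py hparams/train_dnsmos.yaml \
#       --target_metric=srmr --calculate_dnsmos_on_validation_set=False
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file) as fin:
    # Command-line overrides replace the matching keys in the YAML hparams.
    hparams = load_hyperpyyaml(fin, overrides)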