Merge pull request #1799 from TParcollet/fix_wav2vec2_masking
Fix wav2vec2 masking
TParcollet authored Jan 15, 2023
2 parents 801b150 + 002779c commit 44d1316
Showing 21 changed files with 253 additions and 122 deletions.
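
Note on the change: every recipe below now forwards wav_lens (SpeechBrain's relative utterance lengths, one value in (0, 1] per batch element) into the wav2vec2 module, so the HuggingFace model can mask padded frames instead of attending to them. A minimal sketch of the idea, assuming the wrapper converts relative lengths into a HuggingFace-style attention_mask (the helper name is hypothetical):

import torch

def make_attention_mask(wavs: torch.Tensor, wav_lens: torch.Tensor) -> torch.Tensor:
    # wavs: [batch, time] padded waveforms; wav_lens: relative lengths in (0, 1].
    # Returns a [batch, time] mask with 1 for real samples and 0 for padding.
    abs_lens = (wav_lens * wavs.shape[1]).round().long()
    time_idx = torch.arange(wavs.shape[1], device=wavs.device)
    return (time_idx[None, :] < abs_lens[:, None]).long()

# Hypothetical usage around transformers' Wav2Vec2Model:
# feats = hf_model(wavs, attention_mask=make_attention_mask(wavs, wav_lens))
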
2 changes: 1 addition & 1 deletion recipes/AISHELL-1/ASR/CTC/train_with_wav2vec.py
@@ -43,7 +43,7 @@ def compute_forward(self, batch, stage):
                 wav_lens = torch.cat([wav_lens, wav_lens])
 
         # Forward pass
-        feats = self.modules.wav2vec2(wavs)
+        feats = self.modules.wav2vec2(wavs, wav_lens)
 
         if stage == sb.Stage.TRAIN:
             if hasattr(self.hparams, "SpecAugment"):

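The wav_lens concatenation in this hunk comes from the recipe's environment-corruption augmentation, which appends a corrupted copy of the batch, so the relative lengths are duplicated in lockstep; after this fix they also reach the wav2vec2 forward. A rough sketch of that pattern (augment stands in for the recipe's corruption module):

import torch

def double_batch(wavs, wav_lens, augment):
    # Append an augmented copy of each utterance; the relative lengths repeat
    # unchanged because the corruption preserves duration.
    wavs = torch.cat([wavs, augment(wavs, wav_lens)], dim=0)
    wav_lens = torch.cat([wav_lens, wav_lens])
    return wavs, wav_lens
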
2 changes: 1 addition & 1 deletion recipes/AISHELL-1/ASR/transformer/train_with_wav2vect.py
@@ -33,7 +33,7 @@ def compute_forward(self, batch, stage):
                 tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0)
 
         # compute features
-        feats = self.modules.wav2vec2(wavs)
+        feats = self.modules.wav2vec2(wavs, wav_lens)
         current_epoch = self.hparams.epoch_counter.current
 
         if stage == sb.Stage.TRAIN:

[file name not captured in this view]
@@ -65,7 +65,9 @@ character_coverage: 1.0
 dnn_neurons: 1024
 wav2vec_output_dim: !ref <dnn_neurons>
 freeze_wav2vec: False
+freeze_feature_extractor: False
 dropout: 0.15
+warmup_steps: 500
 
 # Outputs
 output_neurons: 32 # BPE size, index(blank/eos/bos) = 0
@@ -109,6 +111,7 @@ wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
+    freeze_feature_extractor: !ref <freeze_feature_extractor>
     save_path: !ref <save_folder>/wav2vec2_checkpoint
 
 #####

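The new freeze_feature_extractor flag is finer-grained than freeze: freeze holds the whole wav2vec 2.0 model fixed, while freeze_feature_extractor should stop updates only to the convolutional front-end. A sketch of the distinction, assuming the wrapped model exposes its CNN as feature_extractor (true of transformers' Wav2Vec2Model; the wrapper's internals are not shown in this diff):

def apply_freezing(model, freeze=False, freeze_feature_extractor=False):
    # Disable gradients for all of wav2vec 2.0, or only for its CNN front-end.
    if freeze:
        for p in model.parameters():
            p.requires_grad = False
    elif freeze_feature_extractor:
        for p in model.feature_extractor.parameters():
            p.requires_grad = False
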
[file name not captured in this view]
@@ -62,7 +62,9 @@ character_coverage: 1.0
 wav2vec_output_dim: 1024
 dnn_neurons: 1024
 freeze_wav2vec: False
+freeze_feature_extractor: False
 dropout: 0.15
+warmup_steps: 500
 
 # Outputs
 output_neurons: 1000 # BPE size, index(blank/eos/bos) = 0
@@ -109,6 +111,7 @@ wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
+    freeze_feature_extractor: !ref <freeze_feature_extractor>
     save_path: !ref <save_folder>/wav2vec2_checkpoint
 
 #####

19 changes: 10 additions & 9 deletions recipes/CommonVoice/ASR/CTC/hparams/train_fr_with_wav2vec.yaml
@@ -62,6 +62,9 @@ character_coverage: 1.0
 wav2vec_output_dim: 1024
 dnn_neurons: 1024
 freeze_wav2vec: False
+freeze_feature_extractor: False
+dropout: 0.15
+warmup_steps: 500 # The wav2vec 2 model isn't updated for this amount of steps
 
 # Outputs
 output_neurons: 76 # BPE size, index(blank/eos/bos) = 0
@@ -78,36 +81,34 @@ eos_index: 2
 epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
     limit: !ref <number_of_epochs>
 
-augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
-    sample_rate: !ref <sample_rate>
-    speeds: [95, 100, 105]
 
 enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <wav2vec_output_dim>]
     linear1: !name:speechbrain.nnet.linear.Linear
-        n_neurons: 1024
+        n_neurons: !ref <dnn_neurons>
         bias: True
     bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation: !new:torch.nn.LeakyReLU
     drop: !new:torch.nn.Dropout
-        p: 0.15
+        p: !ref <dropout>
     linear2: !name:speechbrain.nnet.linear.Linear
-        n_neurons: 1024
+        n_neurons: !ref <dnn_neurons>
         bias: True
     bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation2: !new:torch.nn.LeakyReLU
     drop2: !new:torch.nn.Dropout
-        p: 0.15
+        p: !ref <dropout>
     linear3: !name:speechbrain.nnet.linear.Linear
-        n_neurons: 1024
+        n_neurons: !ref <dnn_neurons>
         bias: True
     bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation3: !new:torch.nn.LeakyReLU
 
 wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
     source: !ref <wav2vec2_hub>
-    output_norm: True
+    output_norm: False
     freeze: !ref <freeze_wav2vec>
+    freeze_feature_extractor: !ref <freeze_feature_extractor>
     save_path: !ref <save_folder>/wav2vec2_checkpoint
 
 #####

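The hard-coded 1024 and 0.15 values are replaced by !ref so the encoder stays in sync with the dnn_neurons and dropout hyperparameters. HyperPyYAML resolves these references at load time; a small self-contained check (keys are illustrative):

from hyperpyyaml import load_hyperpyyaml

yaml_string = """
dnn_neurons: 1024
dropout: 0.15
linear_size: !ref <dnn_neurons>
drop_p: !ref <dropout>
"""
hparams = load_hyperpyyaml(yaml_string)
assert hparams["linear_size"] == 1024  # resolved from <dnn_neurons>
assert hparams["drop_p"] == 0.15       # resolved from <dropout>
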
[file name not captured in this view]
@@ -63,7 +63,9 @@ character_coverage: 1.0
 wav2vec_output_dim: 1024
 dnn_neurons: 1024
 freeze_wav2vec: False
+freeze_feature_extractor: False
 dropout: 0.15
+warmup_steps: 500
 
 # Outputs
 output_neurons: 1000 # BPE size, index(blank/eos/bos) = 0
@@ -110,6 +112,7 @@ wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
+    freeze_feature_extractor: !ref <freeze_feature_extractor>
     save_path: !ref <save_folder>/wav2vec2_checkpoint
 
 #####

14 changes: 9 additions & 5 deletions recipes/CommonVoice/ASR/CTC/hparams/train_rw_with_wav2vec.yaml
@@ -63,6 +63,9 @@ character_coverage: 1.0
 wav2vec_output_dim: 1024
 dnn_neurons: 1024
 freeze_wav2vec: False
+freeze_feature_extractor: False
+dropout: 0.15
+warmup_steps: 500
 
 # Outputs
 output_neurons: 1000 # BPE size, index(blank/eos/bos) = 0
@@ -86,21 +89,21 @@ augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
 enc: !new:speechbrain.nnet.containers.Sequential
     input_shape: [null, null, !ref <wav2vec_output_dim>]
     linear1: !name:speechbrain.nnet.linear.Linear
-        n_neurons: 1024
+        n_neurons: !ref <dnn_neurons>
         bias: True
     bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation: !new:torch.nn.LeakyReLU
     drop: !new:torch.nn.Dropout
-        p: 0.15
+        p: !ref <dropout>
     linear2: !name:speechbrain.nnet.linear.Linear
-        n_neurons: 1024
+        n_neurons: !ref <dnn_neurons>
         bias: True
     bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation2: !new:torch.nn.LeakyReLU
     drop2: !new:torch.nn.Dropout
-        p: 0.15
+        p: !ref <dropout>
     linear3: !name:speechbrain.nnet.linear.Linear
-        n_neurons: 1024
+        n_neurons: !ref <dnn_neurons>
         bias: True
     bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
     activation3: !new:torch.nn.LeakyReLU
@@ -109,6 +112,7 @@ wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
     source: !ref <wav2vec2_hub>
     output_norm: True
     freeze: !ref <freeze_wav2vec>
+    freeze_feature_extractor: !ref <freeze_feature_extractor>
     save_path: !ref <save_folder>/wav2vec2_checkpoint
 
 #####

68 changes: 39 additions & 29 deletions recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py
@@ -51,7 +51,7 @@ def compute_forward(self, batch, stage):
             wavs = self.hparams.augmentation(wavs, wav_lens)
 
         # Forward pass
-        feats = self.modules.wav2vec2(wavs)
+        feats = self.modules.wav2vec2(wavs, wav_lens)
         x = self.modules.enc(feats)
         logits = self.modules.ctc_lin(x)
         p_ctc = self.hparams.log_softmax(logits)
@@ -88,43 +88,53 @@ def compute_objectives(self, predictions, batch, stage):
 
     def fit_batch(self, batch):
         """Train the parameters given a single batch in input"""
+        should_step = self.step % self.grad_accumulation_factor == 0
         # Managing automatic mixed precision
+        # TOFIX: CTC fine-tuning currently is unstable
+        # This is certainly due to CTC being done in fp16 instead of fp32
         if self.auto_mix_prec:
-
-            if not self.hparams.wav2vec2.freeze:
-                self.wav2vec_optimizer.zero_grad()
-            self.model_optimizer.zero_grad()
-
             with torch.cuda.amp.autocast():
-                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
+                with self.no_sync():
+                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
                 loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-
-            self.scaler.scale(loss).backward()
-            if not self.hparams.wav2vec2.freeze:
-                self.scaler.unscale_(self.wav2vec_optimizer)
-            self.scaler.unscale_(self.model_optimizer)
-
-            if self.check_gradients(loss):
-                if not self.hparams.wav2vec2.freeze:
-                    self.scaler.step(self.wav2vec_optimizer)
-                self.scaler.step(self.model_optimizer)
-
-            self.scaler.update()
+            with self.no_sync(not should_step):
+                self.scaler.scale(
+                    loss / self.grad_accumulation_factor
+                ).backward()
+            if should_step:
+                if not self.hparams.wav2vec2.freeze:
+                    self.scaler.unscale_(self.wav2vec_optimizer)
+                self.scaler.unscale_(self.model_optimizer)
+                if self.check_gradients(loss):
+                    if not self.hparams.wav2vec2.freeze:
+                        if self.optimizer_step >= self.hparams.warmup_steps:
+                            self.scaler.step(self.wav2vec_optimizer)
+                    self.scaler.step(self.model_optimizer)
+                    self.scaler.update()
+                self.zero_grad()
+                self.optimizer_step += 1
         else:
-            outputs = self.compute_forward(batch, sb.Stage.TRAIN)
+            # This is mandatory because HF models have a weird behavior with DDP
+            # on the forward pass
+            with self.no_sync():
+                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
 
             loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            loss.backward()
-
-            if self.check_gradients(loss):
-                if not self.hparams.wav2vec2.freeze:
-                    self.wav2vec_optimizer.step()
-                self.model_optimizer.step()
-
-            if not self.hparams.wav2vec2.freeze:
-                self.wav2vec_optimizer.zero_grad()
-            self.model_optimizer.zero_grad()
-
-        return loss.detach()
+            with self.no_sync(not should_step):
+                (loss / self.grad_accumulation_factor).backward()
+            if should_step:
+                if self.check_gradients(loss):
+                    if not self.hparams.wav2vec2.freeze:
+                        if self.optimizer_step >= self.hparams.warmup_steps:
+                            self.wav2vec_optimizer.step()
+                    self.model_optimizer.step()
+                self.zero_grad()
+                self.optimizer_step += 1
+
+        self.on_fit_batch_end(batch, outputs, loss, should_step)
+        return loss.detach().cpu()
 
     def evaluate_batch(self, batch, stage):
         """Computations needed for validation/test batches"""

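The rewritten fit_batch folds three behaviors together: gradient accumulation (an optimizer step every grad_accumulation_factor batches), DDP-safe forward/backward via no_sync, and a warmup gate that keeps the wav2vec 2.0 optimizer idle for the first warmup_steps updates. A condensed, non-AMP sketch of the control flow against a hypothetical trainer object (names mirror the diff; the object itself is illustrative):

def fit_batch(trainer, batch, accumulation, warmup_steps):
    should_step = trainer.step % accumulation == 0
    outputs = trainer.compute_forward(batch)
    loss = trainer.compute_objectives(outputs, batch)
    # Divide so the accumulated gradient equals the mean over the virtual batch.
    (loss / accumulation).backward()
    if should_step:
        if trainer.check_gradients(loss):  # skip steps with non-finite gradients
            if trainer.optimizer_step >= warmup_steps:
                trainer.wav2vec_optimizer.step()  # wav2vec 2.0 moves only after warmup
            trainer.model_optimizer.step()        # downstream layers always move
        trainer.zero_grad()
        trainer.optimizer_step += 1
    return loss.detach().cpu()
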
2 changes: 1 addition & 1 deletion recipes/CommonVoice/ASR/seq2seq/train_with_wav2vec.py
@@ -53,7 +53,7 @@ def compute_forward(self, batch, stage):
             wavs = self.hparams.augmentation(wavs, wav_lens)
 
         # Forward pass
-        feats = self.modules.wav2vec2(wavs)
+        feats = self.modules.wav2vec2(wavs, wav_lens)
         x = self.modules.enc(feats)
 
         e_in = self.modules.emb(tokens_bos)  # y_in bos + tokens

[file name not captured in this view]
@@ -51,7 +51,7 @@ def compute_forward(self, batch, stage):
         # Forward on w2v2 and take the loss.
         # It has to be on train mode even for eval. Otherwise it would deactivate
         # the loss computation ...
-        out, mask = self.modules.wav2vec2(wavs)
+        out, mask = self.modules.wav2vec2(wavs, wav_lens)
         loss = out.loss
 
         if stage != sb.Stage.TRAIN:

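This pretraining recipe is where the PR title applies most directly: wav2vec 2.0 samples contiguous mask spans over the frame axis for its contrastive task, and without the true lengths those spans can land on padding. A toy, length-aware span sampler (values loosely follow the wav2vec 2.0 paper; the real masking lives inside the HuggingFace model):

import torch

def sample_time_mask(total_frames, real_frames, mask_prob=0.065, span=10):
    # Sample mask spans only inside the first real_frames frames so padded
    # positions are never used as masked prediction targets.
    mask = torch.zeros(total_frames, dtype=torch.bool)
    usable = max(real_frames - span, 1)
    n_spans = max(int(mask_prob * real_frames / span), 1)
    for s in torch.randint(0, usable, (n_spans,)).tolist():
        mask[s : min(s + span, real_frames)] = True
    return mask
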
68 changes: 39 additions & 29 deletions recipes/DVoice/ASR/CTC/train_with_wav2vec2.py
@@ -51,7 +51,7 @@ def compute_forward(self, batch, stage):
             wavs = self.hparams.augmentation(wavs, wav_lens)
 
         # Forward pass
-        feats = self.modules.wav2vec2(wavs)
+        feats = self.modules.wav2vec2(wavs, wav_lens)
         x = self.modules.enc(feats)
         logits = self.modules.ctc_lin(x)
         p_ctc = self.hparams.log_softmax(logits)
@@ -88,43 +88,53 @@ def compute_objectives(self, predictions, batch, stage):
 
     def fit_batch(self, batch):
         """Train the parameters given a single batch in input"""
+        should_step = self.step % self.grad_accumulation_factor == 0
         # Managing automatic mixed precision
+        # TOFIX: CTC fine-tuning currently is unstable
+        # This is certainly due to CTC being done in fp16 instead of fp32
        if self.auto_mix_prec:
-
-            if not self.hparams.wav2vec2.freeze:
-                self.wav2vec_optimizer.zero_grad()
-            self.model_optimizer.zero_grad()
-
             with torch.cuda.amp.autocast():
-                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
+                with self.no_sync():
+                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
                 loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-
-            self.scaler.scale(loss).backward()
-            if not self.hparams.wav2vec2.freeze:
-                self.scaler.unscale_(self.wav2vec_optimizer)
-            self.scaler.unscale_(self.model_optimizer)
-
-            if self.check_gradients(loss):
-                if not self.hparams.wav2vec2.freeze:
-                    self.scaler.step(self.wav2vec_optimizer)
-                self.scaler.step(self.model_optimizer)
-
-            self.scaler.update()
+            with self.no_sync(not should_step):
+                self.scaler.scale(
+                    loss / self.grad_accumulation_factor
+                ).backward()
+            if should_step:
+                if not self.hparams.wav2vec2.freeze:
+                    self.scaler.unscale_(self.wav2vec_optimizer)
+                self.scaler.unscale_(self.model_optimizer)
+                if self.check_gradients(loss):
+                    if not self.hparams.wav2vec2.freeze:
+                        if self.optimizer_step >= self.hparams.warmup_steps:
+                            self.scaler.step(self.wav2vec_optimizer)
+                    self.scaler.step(self.model_optimizer)
+                    self.scaler.update()
+                self.zero_grad()
+                self.optimizer_step += 1
         else:
-            outputs = self.compute_forward(batch, sb.Stage.TRAIN)
+            # This is mandatory because HF models have a weird behavior with DDP
+            # on the forward pass
+            with self.no_sync():
+                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
 
             loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
-            loss.backward()
-
-            if self.check_gradients(loss):
-                if not self.hparams.wav2vec2.freeze:
-                    self.wav2vec_optimizer.step()
-                self.model_optimizer.step()
-
-            if not self.hparams.wav2vec2.freeze:
-                self.wav2vec_optimizer.zero_grad()
-            self.model_optimizer.zero_grad()
-
-        return loss.detach()
+            with self.no_sync(not should_step):
+                (loss / self.grad_accumulation_factor).backward()
+            if should_step:
+                if self.check_gradients(loss):
+                    if not self.hparams.wav2vec2.freeze:
+                        if self.optimizer_step >= self.hparams.warmup_steps:
+                            self.wav2vec_optimizer.step()
+                    self.model_optimizer.step()
+                self.zero_grad()
+                self.optimizer_step += 1
+
+        self.on_fit_batch_end(batch, outputs, loss, should_step)
+        return loss.detach().cpu()
 
     def evaluate_batch(self, batch, stage):
         """Computations needed for validation/test batches"""

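The comment about "weird behavior with DDP" concerns gradient synchronization: wrapping the forward pass and the non-final backward passes in no_sync keeps DDP from reducing gradients on every micro-batch (and from tripping over conditionally unused wav2vec 2.0 parameters). Roughly, the helper can be pictured like this (a guess at the behavior, not SpeechBrain's actual implementation):

import contextlib
import torch

def no_sync(module, use=True):
    # Return DDP's no_sync() context when gradient sync should be skipped,
    # otherwise a do-nothing context manager.
    if use and isinstance(module, torch.nn.parallel.DistributedDataParallel):
        return module.no_sync()
    return contextlib.nullcontext()
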
2 changes: 1 addition & 1 deletion recipes/IEMOCAP/emotion_recognition/train_with_wav2vec2.py
@@ -24,7 +24,7 @@ def compute_forward(self, batch, stage):
         batch = batch.to(self.device)
         wavs, lens = batch.sig
 
-        outputs = self.modules.wav2vec2(wavs)
+        outputs = self.modules.wav2vec2(wavs, lens)
 
         # last dim will be used for AdaptativeAVG pool
         outputs = self.hparams.avg_pool(outputs, lens)

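Here lens serves double duty: it masks padding inside wav2vec 2.0 and then bounds the pooling, so padded frames affect neither the features nor the utterance embedding. A length-aware mean pool in the same spirit (a sketch; SpeechBrain's avg_pool module may differ in detail):

import torch

def masked_mean(feats, lens):
    # feats: [batch, time, dim]; lens: relative lengths in (0, 1].
    abs_lens = (lens * feats.shape[1]).round().long().clamp(min=1)
    idx = torch.arange(feats.shape[1], device=feats.device)
    mask = (idx[None, :] < abs_lens[:, None]).unsqueeze(-1)  # [batch, time, 1]
    return (feats * mask).sum(dim=1) / abs_lens[:, None]
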
2 changes: 1 addition & 1 deletion recipes/IWSLT22_lowresource/train.py
@@ -25,7 +25,7 @@ def compute_forward(self, batch, stage):
         tokens_bos, _ = batch.tokens_bos  # translation
 
         # wav2vec module
-        feats = self.modules.wav2vec2(wavs)
+        feats = self.modules.wav2vec2(wavs, wav_lens)
 
         # dimensionality reduction
         src = self.modules.enc(feats)

[Diff truncated: the remaining changed files were not loaded in this view.]