Skip to content

Commit

Permalink
[src] BaseEncoderMaskerDecoder: remove old hooks (asteroid-team#309)
Browse files Browse the repository at this point in the history
  • Loading branch information
mpariente authored Nov 26, 2020
1 parent 850109d commit 18c430e
Showing 1 changed file with 7 additions and 63 deletions.
70 changes: 7 additions & 63 deletions asteroid/models/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def forward(self, wav):
reconstructed = pad_x_to_y(decoded, wav)
return _shape_reconstructed(reconstructed, shape)

def forward_encoder(self, wav):
def forward_encoder(self, wav: torch.Tensor) -> torch.Tensor:
"""Computes time-frequency representation of `wav`.
Args:
Expand All @@ -249,23 +249,9 @@ def forward_encoder(self, wav):
torch.Tensor, of shape (batch, feat, seq).
"""
tf_rep = self.encoder(wav)
tf_rep = self.postprocess_encoded(tf_rep)
return self.enc_activation(tf_rep)

def postprocess_encoded(self, tf_rep):
"""Hook to perform transformations on the encoded, time-frequency domain
representation (output of the encoder) before encoder activation is applied.
Args:
tf_rep (Tensor of shape (batch, freq, time)):
Output of the encoder, before encoder activation is applied.
Return:
Transformed `tf_rep`
"""
return tf_rep

def forward_masker(self, tf_rep):
def forward_masker(self, tf_rep: torch.Tensor) -> torch.Tensor:
"""Estimates masks from time-frequency representation.
Args:
Expand All @@ -275,23 +261,9 @@ def forward_masker(self, tf_rep):
Returns:
torch.Tensor: Estimated masks
"""
est_masks = self.masker(tf_rep)
return self.postprocess_masks(est_masks)

def postprocess_masks(self, masks):
"""Hook to perform transformations on the masks (output of the masker) before
masks are applied.
Args:
masks (Tensor of shape (batch, n_src, freq, time)):
Output of the masker
Return:
Transformed `masks`
"""
return masks
return self.masker(tf_rep)

def apply_masks(self, tf_rep, est_masks):
def apply_masks(self, tf_rep: torch.Tensor, est_masks: torch.Tensor) -> torch.Tensor:
"""Applies masks to time-frequency representation.
Args:
Expand All @@ -302,23 +274,9 @@ def apply_masks(self, tf_rep, est_masks):
Returns:
torch.Tensor: Masked time-frequency representations.
"""
masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
return self.postprocess_masked(masked_tf_rep)
return est_masks * tf_rep.unsqueeze(1)

def postprocess_masked(self, masked_tf_rep):
"""Hook to perform transformations on the masked time-frequency domain
representation (result of masking in the time-frequency domain) before decoding.
Args:
masked_tf_rep (Tensor of shape (batch, n_src, freq, time)):
Masked time-frequency representation, before decoding.
Return:
Transformed `masked_tf_rep`
"""
return masked_tf_rep

def forward_decoder(self, masked_tf_rep):
def forward_decoder(self, masked_tf_rep: torch.Tensor) -> torch.Tensor:
"""Reconstructs time-domain waveforms from masked representations.
Args:
Expand All @@ -327,21 +285,7 @@ def forward_decoder(self, masked_tf_rep):
Returns:
torch.Tensor: Time-domain waveforms.
"""
decoded = self.decoder(masked_tf_rep)
return self.postprocess_decoded(decoded)

def postprocess_decoded(self, decoded):
"""Hook to perform transformations on the decoded, time domain representation
(output of the decoder) before original shape reconstruction.
Args:
decoded (Tensor of shape (batch, n_src, time)):
Output of the decoder, before original shape reconstruction.
Return:
Transformed `decoded`
"""
return decoded
return self.decoder(masked_tf_rep)

def get_model_args(self):
""" Arguments needed to re-instantiate the model. """
Expand Down

0 comments on commit 18c430e

Please sign in to comment.