Skip to content

Commit

Permalink
update for test
Browse files Browse the repository at this point in the history
  • Loading branch information
LiChenda committed Jan 2, 2024
1 parent eaa0ccc commit 749b256
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
12 changes: 6 additions & 6 deletions espnet2/enh/diffusion/score_based_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,17 +173,17 @@ def enhance(
"""Enhance function.
Args:
noisy_specturm (torch.Tensor): noisy feature in [Batch, T, F]
sampler_type (str): sampler, 'pc' for Predictor-Corrector and 'ode' for ODE
sampler
predictor (str): the name of Predictor. 'reverse_diffusion',
noisy_specturm (torch.Tensor): noisy feature in [Batch, T, F]
sampler_type (str): sampler, 'pc' for Predictor-Corrector and 'ode' for ODE
sampler.
predictor (str): the name of Predictor. 'reverse_diffusion',
'euler_maruyama', or 'none'
corrector (str): the name of Corrector. 'langevin', 'ald' or 'none'
N (int): The number of reverse sampling steps.
N (int): The number of reverse sampling steps.
corrector_steps (int) : number of steps in the Corrector.
snr (float): The SNR to use for the corrector.
Returns:
X_Hat (torch.Tensor): enhanced feature in [Batch, T, F]
X_Hat (torch.Tensor): enhanced feature in [Batch, T, F]
"""
Y = noisy_specturm.permute(0, 2, 1).unsqueeze(1)
if sampler_type == "pc":
Expand Down
1 change: 0 additions & 1 deletion test/espnet2/enh/diffusion/test_score_based_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from espnet2.enh.diffusion.score_based_diffusion import ScoreModel



def test_score_based_diffusion_forward_backward_dcunet():
parameters = {
"score_model": "dcunet",
Expand Down

0 comments on commit 749b256

Please sign in to comment.