Skip to content

Commit

Permalink
[egs] Fix wavfiles saving in eval.py for enh tasks (estimates)
Browse files Browse the repository at this point in the history
  • Loading branch information
mpariente committed Jun 4, 2020
1 parent aa4b0aa commit ce5c345
Show file tree
Hide file tree
Showing 8 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion egs/kinect-wsj/DeepClustering/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def main(conf):
return_est=True)
mix_np = mix[None].cpu().data.numpy()
sources_np = sources.cpu().data.numpy()
est_sources_np = reordered_sources.squeeze().cpu().data.numpy()
est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
utt_metrics = get_metrics(mix_np, sources_np, est_sources_np,
sample_rate=conf['sample_rate'],
metrics_list=compute_metrics)
Expand Down
2 changes: 1 addition & 1 deletion egs/librimix/ConvTasNet/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def main(conf):
return_est=True)
mix_np = mix.cpu().data.numpy()
sources_np = sources.cpu().data.numpy()
est_sources_np = reordered_sources.squeeze().cpu().data.numpy()
est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
# For each utterance, we get a dictionary with the mixture path,
# the input and output metrics
utt_metrics = get_metrics(mix_np, sources_np, est_sources_np,
Expand Down
2 changes: 1 addition & 1 deletion egs/wham/ConvTasNet/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def main(conf):
return_est=True)
mix_np = mix[None].cpu().data.numpy()
sources_np = sources.cpu().data.numpy()
est_sources_np = reordered_sources.squeeze().cpu().data.numpy()
est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
utt_metrics = get_metrics(mix_np, sources_np, est_sources_np,
sample_rate=conf['sample_rate'])
utt_metrics['mix_path'] = test_set.mix[idx][0]
Expand Down
2 changes: 1 addition & 1 deletion egs/wham/DPRNN/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def main(conf):
return_est=True)
mix_np = mix[None].cpu().data.numpy()
sources_np = sources.cpu().data.numpy()
est_sources_np = reordered_sources.squeeze().cpu().data.numpy()
est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
utt_metrics = get_metrics(mix_np, sources_np, est_sources_np,
sample_rate=conf['sample_rate'])
utt_metrics['mix_path'] = test_set.mix[idx][0]
Expand Down
2 changes: 1 addition & 1 deletion egs/wham/DynamicMixing/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def main(conf):
return_est=True)
mix_np = mix[None].cpu().data.numpy()
sources_np = sources.cpu().data.numpy()
est_sources_np = reordered_sources.squeeze().cpu().data.numpy()
est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
utt_metrics = get_metrics(mix_np, sources_np, est_sources_np,
sample_rate=conf['sample_rate'])
utt_metrics['mix_path'] = test_set.mix[idx][0]
Expand Down
2 changes: 1 addition & 1 deletion egs/wham/TwoStep/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def main(conf):
return_est=True)
mix_np = mix[None].cpu().data.numpy()
sources_np = sources.cpu().data.numpy()
est_sources_np = reordered_sources.squeeze().cpu().data.numpy()
est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
utt_metrics = get_metrics(mix_np, sources_np, est_sources_np,
sample_rate=conf['sample_rate'])
utt_metrics['mix_path'] = test_set.mix[idx][0]
Expand Down
2 changes: 1 addition & 1 deletion egs/whamr/TasNet/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def main(conf):
return_est=True)
mix_np = mix[None].cpu().data.numpy()
sources_np = sources.cpu().data.numpy()
est_sources_np = reordered_sources.squeeze().cpu().data.numpy()
est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
utt_metrics = get_metrics(mix_np, sources_np, est_sources_np,
sample_rate=conf['sample_rate'])
utt_metrics['mix_path'] = test_set.mix[idx][0]
Expand Down
2 changes: 1 addition & 1 deletion egs/wsj0-mix/DeepClustering/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def main(conf):
return_est=True)
mix_np = mix[None].cpu().data.numpy()
sources_np = sources.cpu().data.numpy()
est_sources_np = reordered_sources.squeeze().cpu().data.numpy()
est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
utt_metrics = get_metrics(mix_np, sources_np, est_sources_np,
sample_rate=conf['sample_rate'],
metrics_list=compute_metrics)
Expand Down

0 comments on commit ce5c345

Please sign in to comment.