forked from asteroid-team/asteroid
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[egs] Add recipe for Multi-Decoder DPRNN (asteroid-team#463)
Co-authored-by: Joseph Zhu <junzhez2@ifp-10.ifp.uiuc.edu> Co-authored-by: Pariente Manuel <pariente.mnl@gmail.com>
- Loading branch information
1 parent
ee1ff24
commit 98d1241
Showing
11 changed files
with
1,515 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
""" | ||
Author: Joseph(Junzhe) Zhu, 2021/5. Email: josefzhu@stanford.edu / junzhe.joseph.zhu@gmail.com | ||
For the original code for the paper[1], please refer to https://github.com/JunzheJosephZhu/MultiDecoder-DPRNN | ||
Demo Page: https://junzhejosephzhu.github.io/Multi-Decoder-DPRNN/ | ||
Multi-Decoder DPRNN is a method for source separation when the number of speakers is unknown. | ||
Our contribution is using multiple output heads, with each head modelling a distinct number of source outputs. | ||
In addition, we design a selector network which determines which output head to use, i.e. estimates the number of sources. | ||
The "DPRNN" part of the architecture is orthogonal to our contribution, and can be replaced with any other separator, e.g. Conv/LSTM-TasNet. | ||
References: | ||
[1] "Multi-Decoder DPRNN: High Accuracy Source Counting and Separation", | ||
Junzhe Zhu, Raymond Yeh, Mark Hasegawa-Johnson. https://arxiv.org/abs/2011.12022 | ||
""" | ||
from metrics import Penalized_PIT_Wrapper, pairwise_neg_sisdr_loss | ||
import os | ||
import json | ||
import yaml | ||
import argparse | ||
import random | ||
import torch | ||
from tqdm import tqdm | ||
import pandas as pd | ||
import soundfile as sf | ||
from pprint import pprint | ||
|
||
from asteroid.utils import tensors_to_device | ||
from asteroid.metrics import get_metrics | ||
|
||
from model import load_best_model, make_model_and_optimizer | ||
from wsj0_mix_variable import Wsj0mixVariable, _collate_fn | ||
|
||
|
||
# Command-line interface of the evaluation script. The parsed options are
# merged with the training configuration in the ``__main__`` section below.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--task",
    default="sep_count",
    type=str,
    # Only compared against the `data.task` entry of the training config,
    # a mismatch triggers a warning and nothing else.
    help="Task to evaluate, should match `data.task` of the training config (e.g. `sep_count`)",
)
parser.add_argument(
    "--test_dir",
    type=str,
    required=True,
    # The value is passed through `str.format` once per speaker count.
    help="Test directory including the json files. `{}` is replaced by the number "
    "of speakers, e.g. data/{}speakers/wav8k/min/tt",
)
parser.add_argument(
    "--use_gpu", type=int, default=0, help="Whether to use the GPU for model execution"
)
parser.add_argument("--exp_dir", default="exp/tmp", help="Experiment root")
parser.add_argument(
    "--n_save_ex", type=int, default=50, help="Number of audio examples to save, -1 means all"
)
|
||
|
||
def main(conf):
    """Evaluate the best checkpoint of an experiment on the test set.

    For every test mixture the model estimates the sources, the penalized
    SI-SNR and the counting accuracy are computed, and a subset of examples
    (mixture, references, estimates and metrics) is written to
    ``<exp_dir>/examples/``. Per-utterance metrics go to
    ``<exp_dir>/all_metrics.csv`` and their means to
    ``<exp_dir>/final_metrics.json``.

    Args:
        conf (dict): Evaluation options. Keys read here: ``exp_dir``,
            ``test_dir`` (format template taking the number of sources),
            ``use_gpu``, ``n_save_ex`` (-1 means all), ``sample_rate`` and
            ``train_conf`` (the training configuration dictionary).
    """
    best_model_path = os.path.join(conf["exp_dir"], "best_model.pth")
    if not os.path.exists(best_model_path):
        # make pth from checkpoint
        model = load_best_model(
            conf["train_conf"], conf["exp_dir"], sample_rate=conf["sample_rate"]
        )
        torch.save(model.state_dict(), best_model_path)
    else:
        model, _ = make_model_and_optimizer(conf["train_conf"], sample_rate=conf["sample_rate"])
        # Map to CPU so that weights saved from a GPU run can be restored on a
        # CPU-only machine; the model is moved to the GPU below if requested.
        model.load_state_dict(torch.load(best_model_path, map_location="cpu"))
    # Put the model in inference mode whichever way it was obtained.
    model.eval()
    # Handle device placement
    if conf["use_gpu"]:
        model.cuda()
    model_device = next(model.parameters()).device

    n_srcs = conf["train_conf"]["masknet"]["n_srcs"]
    test_dirs = [conf["test_dir"].format(n_src) for n_src in n_srcs]
    test_set = Wsj0mixVariable(
        json_dirs=test_dirs,
        n_srcs=n_srcs,
        sample_rate=conf["train_conf"]["data"]["sample_rate"],
        seglen=None,
        minlen=None,
    )
    n_examples = len(test_set)

    # Randomly choose the indexes of sentences to save.
    ex_save_dir = os.path.join(conf["exp_dir"], "examples/")
    n_save_ex = conf["n_save_ex"]
    if n_save_ex == -1 or n_save_ex > n_examples:
        # -1 means "all"; also clamp so random.sample cannot be asked for more
        # items than the test set contains.
        n_save_ex = n_examples
    # A set gives constant-time membership tests inside the loop.
    save_idx = set(random.sample(range(n_examples), n_save_ex))

    loss_func = Penalized_PIT_Wrapper(pairwise_neg_sisdr_loss)
    sample_rate = conf["sample_rate"]
    series_list = []
    with torch.no_grad():
        for idx in tqdm(range(n_examples)):
            # Forward the network on the mixture.
            mix, sources = [
                torch.Tensor(x) for x in tensors_to_device(test_set[idx], device=model_device)
            ]
            est_sources = model.separate(mix[None])
            p_si_snr = loss_func(est_sources, sources)
            utt_metrics = {
                "P-Si-SNR": p_si_snr.item(),
                # 1.0 when the number of estimated sources equals the reference count.
                "counting_accuracy": float(sources.size(0) == est_sources.size(0)),
            }
            utt_metrics["mix_path"] = test_set.data[idx][0]
            series_list.append(pd.Series(utt_metrics))

            # Save some examples in a folder. Wav files and metrics as text.
            if idx not in save_idx:
                continue
            mix_np = mix.detach().cpu().numpy()
            sources_np = sources.detach().cpu().numpy()
            est_sources_np = est_sources.detach().cpu().numpy()
            local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx))
            os.makedirs(local_save_dir, exist_ok=True)
            sf.write(os.path.join(local_save_dir, "mixture.wav"), mix_np, sample_rate)
            # Loop over the sources and estimates
            for src_idx, src in enumerate(sources_np, start=1):
                sf.write(os.path.join(local_save_dir, "s{}.wav".format(src_idx)), src, sample_rate)
            for src_idx, est_src in enumerate(est_sources_np, start=1):
                sf.write(
                    os.path.join(local_save_dir, "s{}_estimate.wav".format(src_idx)),
                    est_src,
                    sample_rate,
                )
            # Write local metrics to the example folder.
            with open(os.path.join(local_save_dir, "metrics.json"), "w") as f:
                json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(conf["exp_dir"], "all_metrics.csv"))

    # Print and save summary metrics
    final_results = {
        metric_name: all_metrics_df[metric_name].mean()
        for metric_name in ["P-Si-SNR", "counting_accuracy"]
    }
    print("Overall metrics :")
    pprint(final_results)
    with open(os.path.join(conf["exp_dir"], "final_metrics.json"), "w") as f:
        json.dump(final_results, f, indent=0)
|
||
|
||
if __name__ == "__main__":
    cli_args = parser.parse_args()

    # The training configuration is read from the experiment folder.
    with open(os.path.join(cli_args.exp_dir, "conf.yml")) as conf_file:
        train_conf = yaml.safe_load(conf_file)

    # Command-line options first, then the values taken from training.
    arg_dic = {
        **vars(cli_args),
        "sample_rate": train_conf["data"]["sample_rate"],
        "train_conf": train_conf,
    }

    same_task = cli_args.task == train_conf["data"]["task"]
    if not same_task:
        print(
            "Warning : the task used to test is different than "
            "the one from training, be sure this is what you want."
        )

    main(arg_dic)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Filterbank (encoder/decoder) settings
filterbank:
  n_filters: 64
  kernel_size: 8
  stride: 4

# Separation network settings
masknet:
  # One entry per supported number of sources
  n_srcs: [2, 3, 4, 5]
  bn_chan: 128
  hid_size: 128
  chunk_size: 128
  hop_size: 64
  n_repeats: 8
  mask_act: 'sigmoid'
  bidirectional: true
  dropout: 0
  use_mulcat: false

# Training loop settings
training:
  epochs: 200
  batch_size: 2
  num_workers: 2
  half_lr: yes
  lr_decay: yes
  early_stop: yes
  gradient_clipping: 5

# Optimizer settings
optim:
  optimizer: adam
  lr: 0.001
  weight_decay: 0.00000

# Data settings
data:
  # `{}` in the directory templates is a placeholder filled in by the scripts
  train_dir: "data/{}speakers/wav8k/min/tr"
  valid_dir: "data/{}speakers/wav8k/min/cv"
  task: sep_count
  sample_rate: 8000
  seglen: 4.0
  minlen: 2.0

# Loss settings
# NOTE(review): placed at top level, SOURCE lost its indentation - confirm
loss:
  lambda: 0.05
38 changes: 38 additions & 0 deletions
38
egs/wsj0-mix-var/Multi-Decoder-DPRNN/local/convert_sphere2wav.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#!/bin/bash
# MIT Copyright (c) 2018 Kaituo XU
#
# Convert the SPHERE files (*.wv*) found under $sphere_dir into wav files
# under $wav_dir, using sph2pipe (downloaded and compiled into ../../tools
# when no executable is found there).
# Both directories can be overridden on the command line through
# utils/parse_options.sh (--sphere_dir, --wav_dir).

sphere_dir=tmp
wav_dir=tmp

. utils/parse_options.sh || exit 1;

tools_dir=../../tools
sph2pipe=$tools_dir/sph2pipe_v2.5/sph2pipe

# Only fetch and build sph2pipe when there is no usable binary yet, so that
# re-running the script does not download the archive again.
if [ ! -x "$sph2pipe" ]; then
  echo "Download sph2pipe_v2.5 into egs/tools"
  mkdir -p "$tools_dir"
  wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P "$tools_dir"
  # Build in a subshell: the working directory of the script is left untouched
  # even if one of the steps fails.
  (
    cd "$tools_dir" \
      && tar -xzvf sph2pipe_v2.5.tar.gz \
      && gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm
  )
fi

echo "Convert sphere format to wav format"

if [ ! -x "$sph2pipe" ]; then
  echo "Could not find (or execute) the sph2pipe program at $sph2pipe";
  exit 1;
fi

tmp=data/local
mkdir -p "$tmp"

# List the sphere files of the wanted subsets once; an existing list is reused.
if [ ! -f "$tmp/sph.list" ]; then
  find "$sphere_dir" -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > "$tmp/sph.list"
fi

if [ ! -d "$wav_dir" ]; then
  # Every output path is echoed, which is what ends up in $tmp/wav.list.
  while IFS= read -r line; do
    # Keep the last three path components and place them under $wav_dir.
    # Only "wv1" is rewritten to "wav"; other *.wv* matches keep their name.
    wav=$(echo "$line" | sed "s:wv1:wav:g" | awk -v dir="$wav_dir" -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}')
    echo "$wav"
    mkdir -p "$(dirname "$wav")"
    "$sph2pipe" -f wav "$line" > "$wav"
  done < "$tmp/sph.list" > "$tmp/wav.list"
else
  echo "Do you already get wav files? if not, please remove $wav_dir"
fi
47 changes: 47 additions & 0 deletions
47
egs/wsj0-mix-var/Multi-Decoder-DPRNN/local/preprocess_wsj0mix.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import argparse | ||
import json | ||
import os | ||
import soundfile as sf | ||
|
||
|
||
def preprocess_one_dir(in_dir, out_dir, out_filename):
    """Create .json file for one condition.

    The ``.wav`` files found directly in ``in_dir`` are listed in sorted order
    and written to ``<out_dir>/<out_filename>.json`` as
    ``[absolute_path, length]`` pairs, where ``length`` is ``len()`` of the
    opened ``soundfile.SoundFile``.

    Args:
        in_dir (str): Directory containing the wav files of one condition.
        out_dir (str): Directory receiving the json file (created if needed).
        out_filename (str): Name of the json file, without extension.
    """
    in_dir = os.path.abspath(in_dir)
    file_infos = []
    for wav_file in sorted(os.listdir(in_dir)):
        if not wav_file.endswith(".wav"):
            continue
        wav_path = os.path.join(in_dir, wav_file)
        # Close each file as soon as its length is known; otherwise one open
        # handle per wav file is kept until garbage collection.
        with sf.SoundFile(wav_path) as wav:
            file_infos.append((wav_path, len(wav)))
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, out_filename + ".json"), "w") as f:
        json.dump(file_infos, f, indent=4)
|
||
|
||
def preprocess(inp_args):
    """Create .json files for all conditions.

    One json file is produced per subset ("tr", "cv", "tt") and per
    condition: the mixture ("mix") and each source "s1" ... "s<n_src>".

    Args:
        inp_args: Object with ``in_dir``, ``out_dir`` and ``n_src`` attributes.
    """
    source_names = [f"s{k}" for k in range(1, inp_args.n_src + 1)]
    for subset in ("tr", "cv", "tt"):
        subset_out_dir = os.path.join(inp_args.out_dir, subset)
        for condition in ["mix", *source_names]:
            condition_in_dir = os.path.join(inp_args.in_dir, subset, condition)
            preprocess_one_dir(condition_in_dir, subset_out_dir, condition)
|
||
|
||
if __name__ == "__main__":
    parser = argparse.ArgumentParser("WSJ0-MIX data preprocessing")
    # Both directories are joined into paths by `preprocess`, so they are
    # mandatory: argparse reports a missing one with a usage message.
    parser.add_argument(
        "--in_dir",
        type=str,
        required=True,
        help="Directory path of wsj0-mix including tr, cv and tt",
    )
    parser.add_argument("--n_src", type=int, default=2, help="Number of sources in wsj0-mix")
    parser.add_argument(
        "--out_dir", type=str, required=True, help="Directory path to put output files"
    )
    args = parser.parse_args()
    preprocess(args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import torch | ||
import torch.nn as nn | ||
from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr | ||
from torch.nn.modules.loss import _Loss | ||
from scipy.optimize import linear_sum_assignment | ||
|
||
|
||
class PairwiseNegSDR_Loss(_Loss):
    """Negative SDR between every (estimate, target) pair.

    Same as asteroid.losses.PairwiseNegSDR, but supports speaker number mismatch:
    the number of estimates and the number of targets may differ.

    Args:
        sdr_type (str): One of "snr", "sisdr" or "sdsdr".
        zero_mean (bool): Remove the mean over time of both inputs first.
        take_log (bool): Return the value in decibels (10 * log10).
        EPS (float): Small constant added for numerical stability.

    Shape:
        - est_targets: (batch, n_est, time)
        - targets: (batch, n_tgt, time)
        - output: (batch, n_est, n_tgt)
    """

    def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
        super(PairwiseNegSDR_Loss, self).__init__()
        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        self.EPS = EPS

    def forward(self, est_targets, targets):
        """Return the negated pairwise SDR matrix of shape (batch, n_est, n_tgt)."""
        # Step 1. Zero-mean norm
        if self.zero_mean:
            mean_source = torch.mean(targets, dim=2, keepdim=True)
            mean_estimate = torch.mean(est_targets, dim=2, keepdim=True)
            targets = targets - mean_source
            est_targets = est_targets - mean_estimate
        # Step 2. Pair-wise SI-SDR. (Reshape to use broadcast)
        # [batch, 1, n_tgt, time]
        s_target = torch.unsqueeze(targets, dim=1)
        # [batch, n_est, 1, time]
        s_estimate = torch.unsqueeze(est_targets, dim=2)

        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, n_est, n_tgt, 1]
            pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True)
            # [batch, 1, n_tgt, 1]
            s_target_energy = torch.sum(s_target ** 2, dim=3, keepdim=True) + self.EPS
            # [batch, n_est, n_tgt, time]
            pair_wise_proj = pair_wise_dot * s_target / s_target_energy
        else:
            # [batch, n_est, n_tgt, time]
            # The targets are replicated once per *estimate*; using the number
            # of targets here breaks broadcasting when the two counts differ.
            pair_wise_proj = s_target.expand(-1, s_estimate.shape[1], -1, -1)
        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = s_estimate - s_target
        else:
            e_noise = s_estimate - pair_wise_proj
        # [batch, n_est, n_tgt]
        pair_wise_sdr = torch.sum(pair_wise_proj ** 2, dim=3) / (
            torch.sum(e_noise ** 2, dim=3) + self.EPS
        )
        if self.take_log:
            pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS)
        return -pair_wise_sdr
|
||
|
||
class Penalized_PIT_Wrapper(nn.Module):
    """
    Implementation of P-Si-SNR, as proposed in [1]

    The best one-to-one assignment between estimates and targets is scored
    with ``loss_func``; every missing or extra estimate contributes
    ``-penalty`` instead, and the total is averaged over the larger of the
    two source counts.

    Args:
        loss_func (callable): Pairwise loss taking ``(est, targets)`` with a
            leading batch dimension and returning ``(batch, n_est, n_tgt)``
            negated scores.
        penalty (float): Positive value charged per source count mismatch.
        perm_reduce: Stored but not used by ``forward``.

    References:
        [1] "Multi-Decoder DPRNN: High Accuracy Source Counting and Separation",
        Junzhe Zhu, Raymond Yeh, Mark Hasegawa-Johnson. https://arxiv.org/abs/2011.12022
    """

    def __init__(self, loss_func, penalty=30, perm_reduce=None):
        super().__init__()
        assert penalty > 0, "penalty term should be positive"
        self.neg_penalty = -penalty
        self.perm_reduce = perm_reduce
        self.loss_func = loss_func

    def forward(self, est_targets, targets, **kwargs):
        """
        est_targets: torch.Tensor, $(est_nsrc, ...)$
        targets: torch.Tensor, $(gt_nsrc, ...)$

        Returns the penalized score as a scalar torch.Tensor.
        """
        est_nsrc = est_targets.size(0)
        # The reference count comes from the targets, otherwise the penalty
        # for a wrong number of estimated sources can never be applied.
        gt_nsrc = targets.size(0)
        pw_losses = self.loss_func(est_targets.unsqueeze(0), targets.unsqueeze(0)).squeeze(0)
        # After transposition, dim 0 corresp. to sources and dim 1 to estimates
        pwl = pw_losses.transpose(-1, -2)
        # Optimal assignment on the (possibly rectangular) cost matrix: only
        # min(est_nsrc, gt_nsrc) pairs are matched.
        row, col = [
            torch.as_tensor(x, dtype=torch.long, device=pwl.device)
            for x in linear_sum_assignment(pwl.detach().cpu().numpy())
        ]
        avg_neg_sdr = pwl[row, col].mean()
        p_si_snr = (
            -avg_neg_sdr * min(est_nsrc, gt_nsrc) + self.neg_penalty * abs(est_nsrc - gt_nsrc)
        ) / max(est_nsrc, gt_nsrc)
        return p_si_snr
|
||
|
||
# Ready-to-use pairwise negative SI-SDR loss instance
pairwise_neg_sisdr_loss = PairwiseNegSDR_Loss(sdr_type="sisdr")
Oops, something went wrong.