Skip to content

Commit

Permalink
[egs] Add recipe for Multi-Decoder DPRNN (asteroid-team#463)
Browse files Browse the repository at this point in the history
Co-authored-by: Joseph Zhu <junzhez2@ifp-10.ifp.uiuc.edu>
Co-authored-by: Pariente Manuel <pariente.mnl@gmail.com>
  • Loading branch information
3 people authored Jul 30, 2021
1 parent ee1ff24 commit 98d1241
Show file tree
Hide file tree
Showing 11 changed files with 1,515 additions and 0 deletions.
151 changes: 151 additions & 0 deletions egs/wsj0-mix-var/Multi-Decoder-DPRNN/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""
Author: Joseph(Junzhe) Zhu, 2021/5. Email: josefzhu@stanford.edu / junzhe.joseph.zhu@gmail.com
For the original code for the paper[1], please refer to https://github.com/JunzheJosephZhu/MultiDecoder-DPRNN
Demo Page: https://junzhejosephzhu.github.io/Multi-Decoder-DPRNN/
Multi-Decoder DPRNN is a method for source separation when the number of speakers is unknown.
Our contribution is using multiple output heads, with each head modelling a distinct number of source outputs.
In addition, we design a selector network which determines which output head to use, i.e. estimates the number of sources.
The "DPRNN" part of the architecture is orthogonal to our contribution, and can be replaced with any other separator, e.g. Conv/LSTM-TasNet.
References:
[1] "Multi-Decoder DPRNN: High Accuracy Source Counting and Separation",
Junzhe Zhu, Raymond Yeh, Mark Hasegawa-Johnson. https://arxiv.org/abs/2011.12022
"""
from metrics import Penalized_PIT_Wrapper, pairwise_neg_sisdr_loss
import os
import json
import yaml
import argparse
import random
import torch
from tqdm import tqdm
import pandas as pd
import soundfile as sf
from pprint import pprint

from asteroid.utils import tensors_to_device
from asteroid.metrics import get_metrics

from model import load_best_model, make_model_and_optimizer
from wsj0_mix_variable import Wsj0mixVariable, _collate_fn


# Command-line interface of the evaluation script. The parsed arguments are
# turned into a dict and handed to ``main`` by the ``__main__`` section.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--task",
    default="sep_count",
    type=str,
    # The previous help text listed the tasks of another recipe; the only value
    # this script compares against is the `data.task` entry of the training config.
    help="Task the model was trained on, e.g. `sep_count` (the default). "
    "A warning is printed when it differs from `data.task` in the training conf.yml",
)
parser.add_argument(
    "--test_dir", type=str, required=True, help="Test directory including the json files"
)
parser.add_argument(
    "--use_gpu", type=int, default=0, help="Whether to use the GPU for model execution"
)
parser.add_argument("--exp_dir", default="exp/tmp", help="Experiment root")
parser.add_argument(
    "--n_save_ex", type=int, default=50, help="Number of audio examples to save, -1 means all"
)


def _save_example(save_dir, mix, sources, est_sources, utt_metrics, sample_rate):
    """Write the audio and the metrics of one evaluated mixture into ``save_dir``.

    Files written: ``mixture.wav``, ``s<k>.wav`` for every reference source,
    ``s<k>_estimate.wav`` for every estimated source and ``metrics.json``.
    """
    os.makedirs(save_dir, exist_ok=True)
    sf.write(os.path.join(save_dir, "mixture.wav"), mix.detach().cpu().numpy(), sample_rate)
    # Loop over the sources and estimates
    for src_idx, src in enumerate(sources.detach().cpu().numpy()):
        sf.write(os.path.join(save_dir, "s{}.wav".format(src_idx + 1)), src, sample_rate)
    for src_idx, est_src in enumerate(est_sources.detach().cpu().numpy()):
        sf.write(
            os.path.join(save_dir, "s{}_estimate.wav".format(src_idx + 1)),
            est_src,
            sample_rate,
        )
    # Write local metrics to the example folder.
    with open(os.path.join(save_dir, "metrics.json"), "w") as f:
        json.dump(utt_metrics, f, indent=0)


def main(conf):
    """Evaluate the best model of an experiment on the test set.

    For every test mixture the model's ``separate`` method is run, the P-Si-SNR
    and the counting accuracy (1.0 when the number of estimated sources equals
    the number of reference sources, else 0.0) are computed, and a random
    subset of examples is saved as wav files. Per-utterance metrics go to
    ``<exp_dir>/all_metrics.csv`` and their means to
    ``<exp_dir>/final_metrics.json``.

    Args:
        conf (dict): Evaluation configuration. Keys read here: ``exp_dir``,
            ``train_conf`` (the training configuration), ``sample_rate``,
            ``use_gpu``, ``test_dir`` (a format string receiving each entry of
            ``train_conf["masknet"]["n_srcs"]``) and ``n_save_ex`` (``-1``
            means every example).
    """
    best_model_path = os.path.join(conf["exp_dir"], "best_model.pth")
    if not os.path.exists(best_model_path):
        # make pth from checkpoint
        model = load_best_model(
            conf["train_conf"], conf["exp_dir"], sample_rate=conf["sample_rate"]
        )
        torch.save(model.state_dict(), best_model_path)
    else:
        model, _ = make_model_and_optimizer(conf["train_conf"], sample_rate=conf["sample_rate"])
        # Load on CPU first so that weights saved from a GPU run can still be
        # evaluated on a CPU-only machine; the model is moved below if requested.
        model.load_state_dict(torch.load(best_model_path, map_location="cpu"))
    # Put the model in eval mode whichever branch built it.
    model.eval()
    # Handle device placement
    if conf["use_gpu"]:
        model.cuda()
    model_device = next(model.parameters()).device

    n_srcs = conf["train_conf"]["masknet"]["n_srcs"]
    test_dirs = [conf["test_dir"].format(n_src) for n_src in n_srcs]
    test_set = Wsj0mixVariable(
        json_dirs=test_dirs,
        n_srcs=n_srcs,
        sample_rate=conf["train_conf"]["data"]["sample_rate"],
        seglen=None,
        minlen=None,
    )

    # Randomly choose the indexes of sentences to save.
    ex_save_dir = os.path.join(conf["exp_dir"], "examples/")
    n_examples = len(test_set)
    n_save_ex = conf["n_save_ex"]
    if n_save_ex == -1 or n_save_ex > n_examples:
        # -1 means "all"; also clamp so random.sample cannot raise ValueError
        # when more examples are requested than the test set contains.
        n_save_ex = n_examples
    # A set gives O(1) membership tests inside the evaluation loop.
    save_idx = set(random.sample(range(n_examples), n_save_ex))

    # The metric wrapper holds no per-utterance state, build it once.
    p_si_snr_metric = Penalized_PIT_Wrapper(pairwise_neg_sisdr_loss)
    series_list = []
    # Scoped no-grad: gradient tracking is restored when evaluation is over.
    with torch.no_grad():
        for idx in tqdm(range(n_examples)):
            # Forward the network on the mixture.
            mix, sources = [
                torch.Tensor(x) for x in tensors_to_device(test_set[idx], device=model_device)
            ]
            est_sources = model.separate(mix[None])
            p_si_snr = p_si_snr_metric(est_sources, sources)
            utt_metrics = {
                "P-Si-SNR": p_si_snr.item(),
                "counting_accuracy": float(sources.size(0) == est_sources.size(0)),
            }
            utt_metrics["mix_path"] = test_set.data[idx][0]
            series_list.append(pd.Series(utt_metrics))

            # Save some examples in a folder. Wav files and metrics as text.
            if idx in save_idx:
                _save_example(
                    os.path.join(ex_save_dir, "ex_{}/".format(idx)),
                    mix,
                    sources,
                    est_sources,
                    utt_metrics,
                    conf["sample_rate"],
                )

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(conf["exp_dir"], "all_metrics.csv"))

    # Print and save summary metrics
    final_results = {
        metric_name: all_metrics_df[metric_name].mean()
        for metric_name in ["P-Si-SNR", "counting_accuracy"]
    }
    print("Overall metrics :")
    pprint(final_results)
    with open(os.path.join(conf["exp_dir"], "final_metrics.json"), "w") as f:
        json.dump(final_results, f, indent=0)


if __name__ == "__main__":
    cli_args = parser.parse_args()

    # The training configuration is read back from the experiment directory.
    with open(os.path.join(cli_args.exp_dir, "conf.yml")) as conf_file:
        training_conf = yaml.safe_load(conf_file)

    # Evaluation config = command-line arguments + training config and its sample rate.
    eval_conf = dict(vars(cli_args))
    eval_conf["sample_rate"] = training_conf["data"]["sample_rate"]
    eval_conf["train_conf"] = training_conf

    same_task = cli_args.task == training_conf["data"]["task"]
    if not same_task:
        print(
            "Warning : the task used to test is different than "
            "the one from training, be sure this is what you want."
        )

    main(eval_conf)
41 changes: 41 additions & 0 deletions egs/wsj0-mix-var/Multi-Decoder-DPRNN/local/conf.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Training configuration of the Multi-Decoder DPRNN recipe.
# eval.py reads a `conf.yml` from the experiment directory with yaml.safe_load
# and uses the `masknet.n_srcs`, `data.sample_rate` and `data.task` entries.
# Filterbank config
filterbank:
n_filters: 64
kernel_size: 8
stride: 4
# Network config
masknet:
# Source counts handled by the model; eval.py formats its --test_dir
# argument once per entry of this list.
n_srcs: [2, 3, 4, 5]
bn_chan: 128
hid_size: 128
chunk_size: 128
hop_size: 64
n_repeats: 8
mask_act: 'sigmoid'
bidirectional: true
dropout: 0
use_mulcat: false
# Training config
training:
epochs: 200
batch_size: 2
num_workers: 2
# NOTE: `yes` is a YAML 1.1 boolean; PyYAML's safe_load parses it as true.
half_lr: yes
lr_decay: yes
early_stop: yes
gradient_clipping: 5
# Optim config
optim:
optimizer: adam
lr: 0.001
weight_decay: 0.00000
# Data config
data:
# `{}` is a format placeholder — presumably filled with each entry of
# masknet.n_srcs, as eval.py does for --test_dir; confirm in the training script.
train_dir: "data/{}speakers/wav8k/min/tr"
valid_dir: "data/{}speakers/wav8k/min/cv"
# eval.py prints a warning when its --task argument differs from this value.
task: sep_count
# Also used by eval.py to load the test set and to write the example wav files.
sample_rate: 8000
# presumably segment / minimum lengths in seconds — TODO confirm against the dataset class
seglen: 4.0
minlen: 2.0
# NOTE(review): looks like a weight for an auxiliary loss term — confirm in the training script
loss:
lambda: 0.05
38 changes: 38 additions & 0 deletions egs/wsj0-mix-var/Multi-Decoder-DPRNN/local/convert_sphere2wav.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/bin/bash
# MIT Copyright (c) 2018 Kaituo XU
#
# Convert the SPHERE files (*.wv*) found under $sphere_dir into wav files
# under $wav_dir with sph2pipe. sph2pipe is downloaded and compiled into
# ../../tools when it is not already there.
# Side files: data/local/sph.list (input files) and data/local/wav.list
# (generated wav paths).


sphere_dir=tmp
wav_dir=tmp

# presumably lets --sphere_dir / --wav_dir override the defaults above — confirm
. utils/parse_options.sh || exit 1;

sph2pipe=../../tools/sph2pipe_v2.5/sph2pipe

# Only download and build sph2pipe when no executable is present, so that
# re-running the script does not fetch and recompile the tool again.
if [ ! -x "$sph2pipe" ]; then
  echo "Download sph2pipe_v2.5 into egs/tools"
  mkdir -p ../../tools
  wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P ../../tools
  # Build in a subshell: the working directory of the script is left untouched
  # even when tar or gcc fails half-way.
  (cd ../../tools && tar -xzvf sph2pipe_v2.5.tar.gz && gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm)
fi

echo "Convert sphere format to wav format"

if [ ! -x "$sph2pipe" ]; then
  echo "Could not find (or execute) the sph2pipe program at $sph2pipe";
  exit 1;
fi

tmp=data/local/
mkdir -p "$tmp"

# The list of sphere files is only built once; delete sph.list to rebuild it.
if [ ! -f "$tmp/sph.list" ]; then
  find "$sphere_dir" -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > "$tmp/sph.list"
fi

if [ ! -d "$wav_dir" ]; then
  # The loop's stdout (the echoed wav paths) is collected into wav.list.
  while IFS= read -r line; do
    # Keep the last three path components of the sphere file under $wav_dir
    # and rename the wv1 extension to wav.
    wav=$(echo "$line" | sed "s:wv1:wav:g" | awk -v dir="$wav_dir" -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}')
    echo "$wav"
    mkdir -p "$(dirname "$wav")"
    "$sph2pipe" -f wav "$line" > "$wav"
  done < "$tmp/sph.list" > "$tmp/wav.list"
else
  echo "Do you already get wav files? if not, please remove $wav_dir"
fi
47 changes: 47 additions & 0 deletions egs/wsj0-mix-var/Multi-Decoder-DPRNN/local/preprocess_wsj0mix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import argparse
import json
import os
import soundfile as sf


def preprocess_one_dir(in_dir, out_dir, out_filename):
    """Create .json file for one condition.

    Lists the ``.wav`` files directly inside ``in_dir`` (sorted by file name)
    and writes ``<out_dir>/<out_filename>.json``, a list of
    ``[absolute_wav_path, n_frames]`` entries.

    Args:
        in_dir (str): Directory containing the wav files.
        out_dir (str): Directory receiving the json file; created if missing.
        out_filename (str): Name of the json file, without extension.
    """
    in_dir = os.path.abspath(in_dir)
    file_infos = []
    for wav_file in sorted(os.listdir(in_dir)):
        if not wav_file.endswith(".wav"):
            continue
        wav_path = os.path.join(in_dir, wav_file)
        # Only the length is needed: close the handle right away instead of
        # keeping one open file descriptor per wav file.
        with sf.SoundFile(wav_path) as samples:
            file_infos.append((wav_path, len(samples)))
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, out_filename + ".json"), "w") as f:
        json.dump(file_infos, f, indent=4)


def preprocess(inp_args):
    """Write one .json file per (subset, condition) pair.

    Subsets are ``tr``, ``cv`` and ``tt``; conditions are ``mix`` followed by
    ``s1`` ... ``s<n_src>``. ``inp_args`` must expose ``in_dir``, ``out_dir``
    and ``n_src``.
    """
    conditions = ["mix"]
    conditions.extend("s{}".format(spk_idx) for spk_idx in range(1, inp_args.n_src + 1))
    for subset in ("tr", "cv", "tt"):
        for condition in conditions:
            wav_dir = os.path.join(inp_args.in_dir, subset, condition)
            json_dir = os.path.join(inp_args.out_dir, subset)
            preprocess_one_dir(wav_dir, json_dir, condition)


if __name__ == "__main__":
    # (flag, type, default, help) of every command-line option, in declaration order.
    option_specs = (
        ("--in_dir", str, None, "Directory path of wham including tr, cv and tt"),
        ("--n_src", int, 2, "Number of sources in wsj0-mix"),
        ("--out_dir", str, None, "Directory path to put output files"),
    )
    cli = argparse.ArgumentParser("WSJ0-MIX data preprocessing")
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    preprocess(cli.parse_args())
90 changes: 90 additions & 0 deletions egs/wsj0-mix-var/Multi-Decoder-DPRNN/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import torch
import torch.nn as nn
from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
from torch.nn.modules.loss import _Loss
from scipy.optimize import linear_sum_assignment


class PairwiseNegSDR_Loss(_Loss):
    """Negative SDR between every estimate and every target.

    Same as asteroid.losses.PairwiseNegSDR, but supports speaker number mismatch:
    the number of estimated sources may differ from the number of targets.

    Args:
        sdr_type (str): One of ``"snr"``, ``"sisdr"`` or ``"sdsdr"``.
        zero_mean (bool): Remove the time-average of estimates and targets first.
        take_log (bool): Return the ratio in decibels (``10 * log10``) instead
            of the plain energy ratio.
        EPS (float): Small constant added to denominators and to the log input.

    Shape:
        - est_targets: ``(batch, n_est, time)``
        - targets: ``(batch, n_src, time)``
        - output: ``(batch, n_est, n_src)``, entry ``[b, i, j]`` being the
          negative SDR of estimate ``i`` against target ``j``.
    """

    def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
        super(PairwiseNegSDR_Loss, self).__init__()
        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        self.EPS = EPS

    def forward(self, est_targets, targets):
        # Step 1. Zero-mean norm
        if self.zero_mean:
            mean_source = torch.mean(targets, dim=2, keepdim=True)
            mean_estimate = torch.mean(est_targets, dim=2, keepdim=True)
            targets = targets - mean_source
            est_targets = est_targets - mean_estimate
        # Step 2. Pair-wise SI-SDR. (Reshape to use broadcast)
        # [batch, 1, n_src, time]
        s_target = torch.unsqueeze(targets, dim=1)
        # [batch, n_est, 1, time]
        s_estimate = torch.unsqueeze(est_targets, dim=2)

        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, n_est, n_src, 1]
            pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True)
            # [batch, 1, n_src, 1]
            s_target_energy = torch.sum(s_target ** 2, dim=3, keepdim=True) + self.EPS
            # [batch, n_est, n_src, time]
            pair_wise_proj = pair_wise_dot * s_target / s_target_energy
        else:
            # [batch, n_est, n_src, time]
            # Repeat over the *estimate* axis: repeating n_src times (as before)
            # only broadcasts against e_noise when n_est == n_src.
            pair_wise_proj = s_target.repeat(1, s_estimate.shape[1], 1, 1)
        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = s_estimate - s_target
        else:
            e_noise = s_estimate - pair_wise_proj
        # [batch, n_est, n_src]
        pair_wise_sdr = torch.sum(pair_wise_proj ** 2, dim=3) / (
            torch.sum(e_noise ** 2, dim=3) + self.EPS
        )
        if self.take_log:
            pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS)
        return -pair_wise_sdr


class Penalized_PIT_Wrapper(nn.Module):
    """
    Implementation of P-Si-SNR, as proposed in [1].

    Estimates and targets are matched one-to-one with the assignment that
    minimises the pairwise loss; every unmatched estimate or target (i.e. every
    unit of difference between the two source counts) contributes ``-penalty``,
    and the total is divided by the larger of the two source counts.

    Args:
        loss_func (callable): Pairwise loss called on the batched inputs
            ``(1, est_nsrc, T)`` and ``(1, gt_nsrc, T)``; it must return a
            ``(1, est_nsrc, gt_nsrc)`` tensor of negative SDR values.
        penalty (float): Positive value subtracted per missing or extra source.
        perm_reduce: Stored on the instance; not used by ``forward``.

    References:
    [1] "Multi-Decoder DPRNN: High Accuracy Source Counting and Separation",
    Junzhe Zhu, Raymond Yeh, Mark Hasegawa-Johnson. https://arxiv.org/abs/2011.12022
    """

    def __init__(self, loss_func, penalty=30, perm_reduce=None):
        super().__init__()
        assert penalty > 0, "penalty term should be positive"
        self.neg_penalty = -penalty
        self.perm_reduce = perm_reduce
        self.loss_func = loss_func

    def forward(self, est_targets, targets, **kwargs):
        """
        est_targets: torch.Tensor, $(est_nsrc, T)$
        targets: torch.Tensor, $(gt_nsrc, T)$
        Returns a scalar tensor, the penalized SI-SNR (higher is better).
        """
        # Unpacking keeps the original requirement that est_targets is 2-D.
        est_nsrc, _ = est_targets.size()
        # The ground-truth count comes from `targets`. Reading it from
        # `est_targets` made both counts always equal, so the penalty term
        # was never applied.
        gt_nsrc = targets.size(0)
        pw_losses = self.loss_func(est_targets.unsqueeze(0), targets.unsqueeze(0)).squeeze(0)
        # After transposition, dim 0 corresp. to sources and dim 1 to estimates
        pwl = pw_losses.transpose(-1, -2)
        # Optimal one-to-one matching; for a rectangular matrix only
        # min(est_nsrc, gt_nsrc) pairs are returned.
        row, col = [
            torch.as_tensor(x, dtype=torch.long, device=pwl.device)
            for x in linear_sum_assignment(pwl.detach().cpu().numpy())
        ]
        avg_neg_sdr = pwl[row, col].mean()
        # mean * number of matched pairs == sum of the matched SDRs
        p_si_snr = (
            -avg_neg_sdr * min(est_nsrc, gt_nsrc) + self.neg_penalty * abs(est_nsrc - gt_nsrc)
        ) / max(est_nsrc, gt_nsrc)
        return p_si_snr


# Module-level pairwise negative SI-SDR loss instance (default zero-mean, log-scaled settings).
pairwise_neg_sisdr_loss = PairwiseNegSDR_Loss(sdr_type="sisdr")
Loading

0 comments on commit 98d1241

Please sign in to comment.