Skip to content

Commit

Permalink
Merge pull request espnet#5414 from Emrys365/espnet2_enh
Browse files Browse the repository at this point in the history
Update the Enh framework to support training with variable numbers of speakers
  • Loading branch information
sw005320 authored Oct 16, 2023
2 parents 72fd7bf + 6a4a1d9 commit 1699eb8
Show file tree
Hide file tree
Showing 17 changed files with 661 additions and 60 deletions.
6 changes: 6 additions & 0 deletions ci/test_integration_espnet2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,14 @@ if python -c 'import torch as t; from packaging.version import parse as L; asser
for t in ${feats_types}; do
echo "==== feats_type=${t} ==="
./run.sh --ngpu 0 --stage 1 --stop-stage 10 --skip-upload false --feats-type "${t}" --ref-num 1 --python "${python}" --enh-args "--num_workers 0"
./run.sh --ngpu 0 --stage 3 --stop-stage 6 --skip-upload false --feats-type "${t}" --ref-num 1 --python "${python}" \
--train_set train_nodev_unk_nspk --valid_set test_unk_nspk --test_sets "train_dev_unk_nspk" \
--enh_config ./conf/train_variable_nspk_debug.yaml --enh-args "--num_workers 0" --variable_num_refs true
./run.sh --ngpu 0 --stage 1 --stop-stage 10 --skip-upload false --feats-type "${t}" --ref-num 1 --python "${python}" \
--local_data_opts "--random-enrollment true" --enh_config ./conf/train_random_enrollment_debug.yaml --enh-args "--num_workers 0"
./run.sh --ngpu 0 --stage 3 --stop-stage 6 --skip-upload false --feats-type "${t}" --ref-num 1 --python "${python}" \
--train_set train_nodev_unk_nspk --valid_set test_unk_nspk --test_sets "train_dev_unk_nspk" \
--enh_config ./conf/train_variable_nspk_random_enrollment_debug.yaml --enh-args "--num_workers 0" --variable_num_refs true
done
# Remove generated files in order to reduce the disk usage
rm -rf exp dump data
Expand Down
103 changes: 87 additions & 16 deletions egs2/TEMPLATE/enh1/enh.sh
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ is_tse_task=false # Whether perform the target speaker extraction task or norm
# Training data related
use_dereverb_ref=false
use_noise_ref=false
variable_num_refs=false # Whether to use variable numbers of references in spk1.scp, dereverb1.scp, enroll_spk1.scp, etc.
extra_wav_list= # Extra list of scp files for wav formatting

# Pretrained model related
Expand Down Expand Up @@ -156,11 +157,12 @@ Options:
--is_tse_task # Whether perform the target speaker extraction task or normal speech enhancement/separation tasks (default="${is_tse_task}")
# Training data related
--use_dereverb_ref # Whether or not to use dereverberated signal as an additional reference
for training a dereverberation model (default="${use_dereverb_ref}")
--use_noise_ref # Whether or not to use noise signal as an additional reference
for training a denoising model (default="${use_noise_ref}")
--extra_wav_list # Extra list of scp files for wav formatting (default="${extra_wav_list}")
--use_dereverb_ref # Whether or not to use dereverberated signal as an additional reference
for training a dereverberation model (default="${use_dereverb_ref}")
--use_noise_ref # Whether or not to use noise signal as an additional reference
for training a denoising model (default="${use_noise_ref}")
--variable_num_refs # Whether or not to use variable numbers of references in spk1.scp, dereverb1.scp, enroll_spk1.scp, etc. If True, --ref_num and --dereverb_ref_num must be 1. (default="${variable_num_refs}")
--extra_wav_list # Extra list of scp files for wav formatting (default="${extra_wav_list}")
# Pretrained model related
--init_param # pretrained model path and module name (default="${init_param}")
Expand Down Expand Up @@ -290,6 +292,27 @@ if [ -z "${inference_tag}" ]; then
fi
fi

if ${variable_num_refs}; then
# load variable numbers of speakers in spk1.scp, dereverb1.scp, enroll_spk1.scp, etc.
if [ "${ref_num}" -ne 1 ]; then
log "[ERROR] --ref_num must be 1 if --variable_num_refs is true, but got ${ref_num}"
exit 1
fi
if [ "${dereverb_ref_num}" -ne 1 ]; then
log "[ERROR] --dereverb_ref_num must be 1 if --variable_num_refs is true, but got ${dereverb_ref_num}"
exit 1
fi
if [ ! -e "data/${train_set}/utt2category" ] || [ ! -e "data/${valid_set}/utt2category" ]; then
log "[ERROR] utt2category must be prepared in data/${train_set} and data/${valid_set} if --variable_num_refs is true."
exit 1
else
log "[WARNING] Variable speaker number is enabled. Please ensure the utt2category file assigns the same category ID to samples with the same number of speakers."
fi
if [[ "${audio_format}" == *ark* ]]; then
log "[WARNING] Since audio_format=*ark* and variable_num_refs=true is applied,\nplease ensure that the first dimension of each array defined in the ark data\nfor 'spk1.scp', 'dereverb1.scp', 'enroll_spk1.scp' and so on corresponds the\nnumber of references (speakers)."
fi
log "[INFO] Variable speaker number is enabled. Please make sure the argument 'flexible_numspk' is True in the preprocessor in the model config."
fi


# ========================== Main stages start from here. ==========================
Expand Down Expand Up @@ -376,15 +399,22 @@ if ! "${skip_data_prep}"; then
_spk_list+=$(for n in $(seq $dereverb_ref_num); do echo -n "dereverb$n "; done)
fi

for spk in ${_spk_list} "wav" ; do
if ${is_tse_task} && [[ "${spk}" == *enroll_spk* ]]; then
for spk in "wav" ${_spk_list}; do
if ${is_tse_task} && [[ "${spk}" == enroll_spk* ]]; then
audio_path=$(head -n 1 "data/${dset}/${spk}.scp" | awk '{print $2}')
if [[ ("${dset}" == "${train_set}" && "${audio_path:0:1}" == "*") || "${audio_path: -4}" == ".npy" ]]; then
# In case of
# 1. a special format in `enroll_spk?.scp`:
# MIXTURE_UID *UID SPEAKER_ID
# 2. speaker embeddings instead of enrollment audios in `enroll_spk?.scp`
utils/filter_scp.pl "${data_feats}${_suf}/${dset}/spk1.scp" "data/${dset}/${spk}.scp" > "${data_feats}${_suf}/${dset}/${spk}.scp"
utils/filter_scp.pl "${data_feats}${_suf}/${dset}/wav.scp" "data/${dset}/${spk}.scp" > "${data_feats}${_suf}/${dset}/${spk}.scp"
continue
fi
fi
if ${variable_num_refs}; then
if [[ "${spk}" == spk* ]] || [[ "${spk}" == dereverb* ]] || [[ "${spk}" == enroll_spk* ]]; then
# skip formatting for multi-audio-column scp files
utils/filter_scp.pl "${data_feats}${_suf}/${dset}/wav.scp" "data/${dset}/${spk}.scp" > "${data_feats}${_suf}/${dset}/${spk}.scp"
continue
fi
fi
Expand Down Expand Up @@ -501,9 +531,15 @@ if ! "${skip_train}"; then
_scp=wav.scp
if [[ "${audio_format}" == *ark* ]]; then
_type=kaldi_ark
_type_ref=kaldi_ark
else
# "sound" supports "wav", "flac", etc.
_type=sound
if ${variable_num_refs}; then
_type_ref="variable_columns_sound"
else
_type_ref="sound"
fi
fi

# 1. Split the key file
Expand Down Expand Up @@ -540,8 +576,8 @@ if ! "${skip_train}"; then
_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/wav.scp,speech_mix,${_type} "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
for spk in $(seq "${ref_num}"); do
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_valid_data_param+="--valid_data_path_and_name_and_type ${_enh_valid_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type_ref} "
_valid_data_param+="--valid_data_path_and_name_and_type ${_enh_valid_dir}/spk${spk}.scp,speech_ref${spk},${_type_ref} "

# for target-speaker extraction
if $is_tse_task; then
Expand All @@ -553,9 +589,9 @@ if ! "${skip_train}"; then
if $use_dereverb_ref; then
# references for dereverberation
_train_data_param+=$(for n in $(seq $dereverb_ref_num); do echo -n \
"--train_data_path_and_name_and_type ${_enh_train_dir}/dereverb${n}.scp,dereverb_ref${n},${_type} "; done)
"--train_data_path_and_name_and_type ${_enh_train_dir}/dereverb${n}.scp,dereverb_ref${n},${_type_ref} "; done)
_valid_data_param+=$(for n in $(seq $dereverb_ref_num); do echo -n \
"--valid_data_path_and_name_and_type ${_enh_valid_dir}/dereverb${n}.scp,dereverb_ref${n},${_type} "; done)
"--valid_data_path_and_name_and_type ${_enh_valid_dir}/dereverb${n}.scp,dereverb_ref${n},${_type_ref} "; done)
fi

if $use_noise_ref; then
Expand Down Expand Up @@ -590,6 +626,35 @@ if ! "${skip_train}"; then
for i in $(seq "${_nj}"); do
_opts+="--input_dir ${_logdir}/stats.${i} "
done
if ${variable_num_refs}; then
# When variable numbers of speakers are enabled, different stats dirs may contain
# different numbers of key files (with different number suffixes).
# So we need to manually create dummy stats files and placeholder entries to avoid
# errors when using `espnet2.bin.aggregate_stats_dirs`.
for dset in train valid; do
# aggregate all batch keys in case some are missing in some stats dirs
for i in $(seq "${_nj}"); do
ls "${_logdir}/stats.${i}/${dset}/"
done | sort | uniq | grep -oP '.*(?=_shape)' > "${_logdir}/${dset}_batch_keys"
while IFS= read -r name; do
fname="${name}_shape"
for i in $(seq "${_nj}"); do
if [ ! -e "${_logdir}/stats.${i}/${dset}/${fname}" ]; then
# create dummy stats files
awk '{print $1 " 0"}' "${_logdir}/${dset}.${i}.scp" > "${_logdir}/stats.${i}/${dset}/${fname}"
else
# create placeholder entries for missing samples in each shape file
mv "${_logdir}/stats.${i}/${dset}/${fname}" "${_logdir}/stats.${i}/${dset}/${fname}.bak"
awk 'NR==FNR{a[$1]=$2; next} {if($1 in a) {print $1" "a[$1]} else {print $1" 0"}}' "${_logdir}/stats.${i}/${dset}/${fname}.bak" "${_logdir}/${dset}.${i}.scp" > "${_logdir}/stats.${i}/${dset}/${fname}"
rm "${_logdir}/stats.${i}/${dset}/${fname}.bak"
fi
done
done < "${_logdir}/${dset}_batch_keys"
for i in $(seq "${_nj}"); do
cp "${_logdir}/${dset}_batch_keys" "${_logdir}/stats.${i}/${dset}/batch_keys"
done
done
fi
# shellcheck disable=SC2086
${python} -m espnet2.bin.aggregate_stats_dirs ${_opts} --skip_sum_stats --output_dir "${enh_stats_dir}"

Expand All @@ -612,9 +677,15 @@ if ! "${skip_train}"; then
# "sound" supports "wav", "flac", etc.
if [[ "${audio_format}" == *ark* ]]; then
_type=kaldi_ark
_type_ref=kaldi_ark
else
# "sound" supports "wav", "flac", etc.
_type=sound
if ${variable_num_refs}; then
_type_ref="variable_columns_sound"
else
_type_ref="sound"
fi
fi
_fold_length="$((enh_speech_fold_length * 100))"

Expand All @@ -626,7 +697,7 @@ if ! "${skip_train}"; then
_valid_shape_param="--valid_shape_file ${enh_stats_dir}/valid/speech_mix_shape "

for spk in $(seq "${ref_num}"); do
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type_ref} "
_train_shape_param+="--train_shape_file ${enh_stats_dir}/train/speech_ref${spk}_shape "
_fold_length_param+="--fold_length ${_fold_length} "

Expand All @@ -639,7 +710,7 @@ if ! "${skip_train}"; then
done

for spk in $(seq "${ref_num}"); do
_valid_data_param+="--valid_data_path_and_name_and_type ${_enh_valid_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_valid_data_param+="--valid_data_path_and_name_and_type ${_enh_valid_dir}/spk${spk}.scp,speech_ref${spk},${_type_ref} "
_valid_shape_param+="--valid_shape_file ${enh_stats_dir}/valid/speech_ref${spk}_shape "

# for target-speaker extraction
Expand All @@ -652,9 +723,9 @@ if ! "${skip_train}"; then
if $use_dereverb_ref; then
# references for dereverberation
for n in $(seq "${dereverb_ref_num}"); do
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/dereverb${n}.scp,dereverb_ref${n},${_type} "
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/dereverb${n}.scp,dereverb_ref${n},${_type_ref} "
_train_shape_param+="--train_shape_file ${enh_stats_dir}/train/dereverb_ref${n}_shape "
_valid_data_param+="--valid_data_path_and_name_and_type ${_enh_valid_dir}/dereverb${n}.scp,dereverb_ref${n},${_type} "
_valid_data_param+="--valid_data_path_and_name_and_type ${_enh_valid_dir}/dereverb${n}.scp,dereverb_ref${n},${_type_ref} "
_valid_shape_param+="--valid_shape_file ${enh_stats_dir}/valid/dereverb_ref${n}_shape "
_fold_length_param+="--fold_length ${_fold_length} "
done
Expand Down
68 changes: 68 additions & 0 deletions egs2/mini_an4/tse1/conf/train_variable_nspk_debug.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# This is a debug config for CI
encoder: conv
encoder_conf:
channel: 2
kernel_size: 2
stride: 8

decoder: conv
decoder_conf:
channel: 2
kernel_size: 2
stride: 8

extractor: td_speakerbeam
extractor_conf:
layer: 2
stack: 2
bottleneck_dim: 2
hidden_dim: 2
skip_dim: 2
kernel: 3
causal: false
norm_type: gLN
nonlinear: relu
# enrollment related
i_adapt_layer: 1
adapt_layer_type: mul
adapt_enroll_dim: 2

criterions:
# The first criterion
- name: snr
conf:
eps: 1.0e-7
wrapper: fixed_order
wrapper_conf:
weight: 1.0

optim: adam

iterator_type: chunk

chunk_length: 5000
# exclude keys "enroll_ref", "enroll_ref1", "enroll_ref2", ...
# from the length consistency check in ChunkIterFactory
chunk_excluded_key_prefixes:
- "enroll_ref"

model_conf:
num_spk: 3 # the maximum possible number of speakers in the model
flexible_numspk: true
share_encoder: true

preprocessor: enh
preprocessor_conf:
num_spk: 3 # the maximum possible number of speakers in the preprocessor
train_spk2enroll: null
enroll_segment: 5000
load_spk_embedding: false
load_all_speakers: true
flexible_numspk: true
categories:
- 1spk
- 2spk
- 3spk

max_epoch: 1
num_iters_per_epoch: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# This is a debug config for CI
encoder: conv
encoder_conf:
channel: 2
kernel_size: 2
stride: 8

decoder: conv
decoder_conf:
channel: 2
kernel_size: 2
stride: 8

extractor: td_speakerbeam
extractor_conf:
layer: 2
stack: 2
bottleneck_dim: 2
hidden_dim: 2
skip_dim: 2
kernel: 3
causal: false
norm_type: gLN
nonlinear: relu
# enrollment related
i_adapt_layer: 1
adapt_layer_type: mul
adapt_enroll_dim: 2

criterions:
# The first criterion
- name: snr
conf:
eps: 1.0e-7
wrapper: fixed_order
wrapper_conf:
weight: 1.0

optim: adam

iterator_type: chunk

chunk_length: 5000
# exclude keys "enroll_ref", "enroll_ref1", "enroll_ref2", ...
# from the length consistency check in ChunkIterFactory
chunk_excluded_key_prefixes:
- "enroll_ref"

model_conf:
num_spk: 3 # the maximum possible number of speakers in the model
flexible_numspk: true
share_encoder: true

preprocessor: enh
preprocessor_conf:
num_spk: 3 # the maximum possible number of speakers in the preprocessor
train_spk2enroll: data/train_nodev/spk2enroll.json
enroll_segment: 5000
load_spk_embedding: false
load_all_speakers: true
flexible_numspk: true
categories:
- 1spk
- 2spk
- 3spk

max_epoch: 1
num_iters_per_epoch: 1
24 changes: 23 additions & 1 deletion egs2/mini_an4/tse1/local/data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--audio-format "flac" --fs "16k" \
"data/${dset}/wav.scp" "dump/raw/org/${dset}" \
"dump/raw/org/${dset}/logs/wav" "dump/raw/org/${dset}/data/wav"
done
done

python3 local/prepare_spk2enroll_mini_an4.py \
"dump/raw/org/${train_set}" \
Expand Down Expand Up @@ -91,5 +91,27 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
done
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
log "stage 4: Data preparation for variable numbers of speakers (up to 3 speakers)"
for dset in ${train_set} ${train_dev} test; do
mkdir -p "data/${dset}_unk_nspk"
cp "data/${dset}/wav.scp" "data/${dset}_unk_nspk"
cp "data/${dset}/utt2spk" "data/${dset}_unk_nspk"
cp "data/${dset}/spk2utt" "data/${dset}_unk_nspk"
awk -v min=1 -v max=3 '{srand(FNR*6); nspk=int(min+rand()*(max-min+1)); print($1 " " nspk "spk")}' \
"data/${dset}/enroll_spk1.scp" > "data/${dset}_unk_nspk/utt2category"
# NOTE: Kaldi pipe IO is not supported when preparing the multi-column scp file
awk 'NR==FNR{nspk[$1]=substr($2,1,1); next} {msg=$1; for (i=1; i<=nspk[$1]; ++i) {msg=msg" "$2} print(msg)}' \
"data/${dset}_unk_nspk/utt2category" "dump/raw/org/${dset}/wav.scp" > "data/${dset}_unk_nspk/spk1.scp"
if [ "${dset}" = "${train_set}" ] && $random_enrollment; then
awk 'NR==FNR{nspk[$1]=substr($2,1,1); next} {msg=$1; s=$2; for (i=3; i<=NF; i++) {s=s" "$i;} for (i=1; i<=nspk[$1]; ++i) {msg=msg" "s} print(msg)}' \
"data/${dset}_unk_nspk/utt2category" "data/${dset}/enroll_spk1.scp" > "data/${dset}_unk_nspk/enroll_spk1.scp"
else
cp "data/${dset}_unk_nspk/spk1.scp" "data/${dset}_unk_nspk/enroll_spk1.scp"
fi
cp "data/${dset}_unk_nspk/spk1.scp" "data/${dset}_unk_nspk/dereverb1.scp"
done
fi


log "Successfully finished. [elapsed=${SECONDS}s]"
Loading

0 comments on commit 1699eb8

Please sign in to comment.