Skip to content

Commit

Permalink
Fix a bug in DynamicMixingPreprocessor when utt2category is used
Browse files Browse the repository at this point in the history
Emrys365 committed Jun 5, 2023
1 parent b7041cb commit 683aab2
Showing 3 changed files with 39 additions and 29 deletions.
5 changes: 4 additions & 1 deletion egs2/TEMPLATE/enh1/enh.sh
Original file line number Diff line number Diff line change
@@ -1183,7 +1183,7 @@ cd $(pwd | rev | cut -d/ -f1-3 | rev)
./run.sh --skip_data_prep false --skip_train true --download_model ${_model_name}</code>
</pre></li>
<li><strong>Results</strong><pre><code>$(cat "${enh_exp}"/RESULTS.md)</code></pre></li>
<li><strong>ASR config</strong><pre><code>$(cat "${enh_exp}"/config.yaml)</code></pre></li>
<li><strong>Enh config</strong><pre><code>$(cat "${enh_exp}"/config.yaml)</code></pre></li>
</ul>
EOF

@@ -1242,6 +1242,9 @@ if ! "${skip_upload_hf}"; then
# shellcheck disable=SC2034
task_exp=${enh_exp}
eval "echo \"$(cat scripts/utils/TEMPLATE_HF_Readme.md)\"" > "${dir_repo}"/README.md
if ${is_tse_task}; then
sed -i -e "s#./run.sh --skip_data_prep false --skip_train true --download_model#./run.sh --skip_data_prep false --skip_train true --is_tse_task true --download_model#g" "${dir_repo}"/README.md
fi

this_folder=${PWD}
cd ${dir_repo}
53 changes: 25 additions & 28 deletions espnet2/tasks/enh.py
Original file line number Diff line number Diff line change
@@ -375,35 +375,32 @@ def build_preprocess_fn(
if use_preprocessor:
# TODO(simpleoier): To make this as simple as model parts, e.g. encoder
if args.preprocessor == "dynamic_mixing":
if train:
retval = preprocessor_choices.get_class(args.preprocessor)(
train=train,
source_scp=os.path.join(
os.path.dirname(
args.train_data_path_and_name_and_type[0][0]
),
args.preprocessor_conf.get("source_scp_name", "spk1.scp"),
),
ref_num=args.preprocessor_conf.get(
"ref_num", args.separator_conf["num_spk"]
),
dynamic_mixing_gain_db=args.preprocessor_conf.get(
"dynamic_mixing_gain_db", 0.0
),
speech_name=args.preprocessor_conf.get(
"speech_name", "speech_mix"
),
speech_ref_name_prefix=args.preprocessor_conf.get(
"speech_ref_name_prefix", "speech_ref"
),
mixture_source_name=args.preprocessor_conf.get(
"mixture_source_name", None
retval = preprocessor_choices.get_class(args.preprocessor)(
train=train,
source_scp=os.path.join(
os.path.dirname(
args.train_data_path_and_name_and_type[0][0]
),
utt2spk=getattr(args, "utt2spk", None),
categories=args.preprocessor_conf.get("categories", None),
)
else:
retval = None
args.preprocessor_conf.get("source_scp_name", "spk1.scp"),
),
ref_num=args.preprocessor_conf.get(
"ref_num", args.separator_conf["num_spk"]
),
dynamic_mixing_gain_db=args.preprocessor_conf.get(
"dynamic_mixing_gain_db", 0.0
),
speech_name=args.preprocessor_conf.get(
"speech_name", "speech_mix"
),
speech_ref_name_prefix=args.preprocessor_conf.get(
"speech_ref_name_prefix", "speech_ref"
),
mixture_source_name=args.preprocessor_conf.get(
"mixture_source_name", None
),
utt2spk=getattr(args, "utt2spk", None),
categories=args.preprocessor_conf.get("categories", None),
)
elif args.preprocessor == "enh":
retval = preprocessor_choices.get_class(args.preprocessor)(
train=train,
10 changes: 10 additions & 0 deletions espnet2/train/preprocessor.py
Original file line number Diff line number Diff line change
@@ -771,6 +771,11 @@ def __call__(
), "Multi-channel input has not been tested"

# Add the category information (an integer) to `data`
if not self.categories and "category" in data:
raise ValueError(
"categories must be set in the config file when utt2category files "
"exist in the data directory (e.g., dump/raw/*/utt2category)"
)
if self.categories and "category" in data:
category = data.pop("category")
assert category in self.categories, category
@@ -957,6 +962,11 @@ def _speech_process(
data[k] = data[k][..., chs]

# Add the category information (an integer) to `data`
if not self.categories and "category" in data:
raise ValueError(
"categories must be set in the config file when utt2category files "
"exist in the data directory (e.g., dump/raw/*/utt2category)"
)
if self.categories and "category" in data:
category = data.pop("category")
assert category in self.categories, category

0 comments on commit 683aab2

Please sign in to comment.