Skip to content

Commit

Permalink
Merge pull request #5977 from Shikhar-S/esc
Browse files Browse the repository at this point in the history
ESC-50 classification with BEATs
  • Loading branch information
sw005320 authored Jan 5, 2025
2 parents 735c86a + 5b0f8c5 commit b70dc52
Show file tree
Hide file tree
Showing 21 changed files with 679 additions and 61 deletions.
2 changes: 1 addition & 1 deletion egs2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ See: https://espnet.github.io/espnet/espnet2_tutorial.html#recipes-using-espnet2
| bur_openslr80 | Burmese ASR training dataset | ASR | BUR | https://openslr.org/80/ | |
| catslu | CATSLU-MAPS | SLU | CMN | https://sites.google.com/view/catslu/home | |
| catslu_entity | CATSLU | SLU/Entity Classifi. | CMN | https://sites.google.com/view/catslu/home | |
| clotho_v2 | Clotho v2.1 dataset for audio captioning | ASR | ENG | https://zenodo.org/records/4783391
| clotho_v2 | Clotho v2.1 dataset for audio captioning | AAC | ENG | https://zenodo.org/records/4783391
| chime1 | The 1st CHiME Speech Separation and Recognition Challenge | ASR/Multichannel ASR | ENG | https://spandh.dcs.shef.ac.uk/chime_challenge/chime2011/ | |
| chime2 | The 2nd CHiME Speech Separation and Recognition Challenge | ASR/Multichannel ASR | ENG | https://spandh.dcs.shef.ac.uk/chime_challenge/chime2013/ | |
| chime4 | The 4th CHiME Speech Separation and Recognition Challenge | ASR/Multichannel ASR | ENG | http://spandh.dcs.shef.ac.uk/chime_challenge/chime2016/ | |
Expand Down
54 changes: 54 additions & 0 deletions egs2/esc50/asr1/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# ESC-50 Audio Classification Recipe

This recipe implements the audio classification task with a BEATs encoder and linear layer decoder model on the ESC-50 dataset, very close to what is described in [this paper](https://arxiv.org/abs/2212.09058).
More specifically, we provide the fine-tuning config and results for second last row in Table 1 (BEATS-iter3) from the paper.
We reuse part of the code from the [BEATs repository](https://github.com/microsoft/unilm/tree/master/beats) for this implementation.

# Training Details and Requirements
We perform 5-fold cross validation on ESC-50 dataset.
This dataset has 2k samples with 400 samples in each fold.
Please note that the hyper-parameters might be different from those in appendix A.1 of the BEATs paper, but the ones used here gave us best results.
They were tuned on fold 5 and then re-used for other folds.
Fine-tuning for one run needs 1 GPU with 33 GB memory and runs for ~4.5 hours on L40S.

### Steps to run

1. Download ESC-50 dataset from [this repo](https://github.com/karolpiczak/ESC-50?tab=readme-ov-file#download) and set the path to its root directory in db.sh.
2. Download the BEATs checkpoint: [BEATs_iter3](https://github.com/microsoft/unilm/tree/master/beats) and change the `beats_ckpt_path` path in `conf/beats_classification.yaml`
3. Launch with `run.sh`


## Trained checkpoints
All trained checkpoints are available at:
* Fold-1: https://huggingface.co/shikhar7ssu/BEATs-ESC-FinetunedFold1 94.3
* Fold-2: https://huggingface.co/shikhar7ssu/BEATs-ESC-FinetunedFold2 97.0
* Fold-3: https://huggingface.co/shikhar7ssu/BEATs-ESC-FinetunedFold3 94.8
* Fold-4: https://huggingface.co/shikhar7ssu/BEATs-ESC-FinetunedFold4 96.3
* Fold-5: https://huggingface.co/shikhar7ssu/BEATs-ESC-FinetunedFold5 91.8

Average acc: 94.8

# Error Analysis
We also observe that top confusion in fold-5 are from the class `helicopter`, which is mainly confused with `washing machine`, and `airplane`.

<!-- Generated by scripts/utils/show_asr_result.sh -->
# RESULTS
## Environments
- date: `Sat Dec 14 19:04:56 EST 2024`
- python version: `3.9.20 (main, Oct 3 2024, 07:27:41) [GCC 11.2.0]`
- espnet version: `espnet 202412`
- pytorch version: `pytorch 2.4.0`
- Git hash: `cb80e61a15d6a13dc342ae5a413d2b870dd869c6`
- Commit date: `Fri Dec 13 11:57:16 2024 -0500`

## /compute/babel-13-33/sbharad2/expdir/asr_fast.fold[i]/inference_ctc_weight0.0_maxlenratio-1_asr_model_valid.acc.best
### Accuracy

|dataset|Snt|Wrd|Acc|
|---|---|---|---|
|org/val1|400|400|94.3|
|org/val2|400|400|97.0|
|org/val3|400|400|94.8|
|org/val4|400|400|96.3|
|org/val5|400|400|91.8|
|Average|||94.8|
76 changes: 76 additions & 0 deletions egs2/esc50/asr1/conf/beats_classification.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
token_type: word

optim: adamw
optim_conf:
lr: 1.0e-4
weight_decay: 1.0e-2
betas: [0.9, 0.98]

accum_grad: 1

batch_size: 128 # 12.5 steps per epoch with 1600 samples
max_epoch: 1000

scheduler: CosineAnnealingWarmupRestarts
scheduler_conf:
first_cycle_steps: 6000
warmup_steps: 300
max_lr: 1.0e-4
min_lr: 5.0e-6

# BEATs implementation takes care of generating mel spectrogram, normalization and specaug
frontend: none
input_size: 1 # important to set input_size to 1 if frontend is none
normalize: none # BEATs code does global mean and variance normalization

# Initialization for the decoder
init: xavier_normal

model_conf:
ctc_weight: 0.0 # No CTC, no attention.
lsm_weight: 0.1 # label smoothing weight
length_normalized_loss: true

batch_type: folded
unused_parameters: true
grad_clip: 1
patience: none
best_model_criterion:
- - valid
- acc
- max
keep_nbest_models: 1
use_amp: false # whether to use automatic mixed precision
num_att_plot: 0
num_workers: 2 # dataloader workers

encoder: beats
encoder_conf:
# Please download the BEATs model from https://github.com/microsoft/unilm/tree/master/beats
# (iter3) and update the path below
beats_ckpt_path: /compute/babel-13-33/sbharad2/models/BEATs/BEATs_iter3.pt
# Most values from Appendix A.1 of the BEATs paper or tuned on fold 5.
# Please also check the README.md
fbank_mean: 11.72215
fbank_std: 10.60431
beats_config:
layer_wise_gradient_decay_ratio: 0.2
encoder_layerdrop: 0.1
dropout: 0.0
specaug_config:
apply_time_warp: true
apply_freq_mask: false
apply_time_mask: true
time_mask_width_ratio_range:
- 0
- 0.06
num_time_mask: 1
roll_augment: true
roll_interval: 16000 # 1 second, only 5 possible augmentations per sample
use_weighted_representation: false

# Simple linear decoder for classification.
decoder: linear_decoder
decoder_conf:
pooling: mean
dropout: 0.1
2 changes: 2 additions & 0 deletions egs2/esc50/asr1/conf/fbank.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
--sample-frequency=16000
--num-mel-bins=80
11 changes: 11 additions & 0 deletions egs2/esc50/asr1/conf/pbs.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Default configuration
command qsub -V -v PATH -S /bin/bash
option name=* -N $0
option mem=* -l mem=$0
option mem=0 # Do not add anything to qsub_opts
option num_threads=* -l ncpus=$0
option num_threads=1 # Do not add anything to qsub_opts
option num_nodes=* -l nodes=$0:ppn=1
default gpu=0
option gpu=0
option gpu=* -l ngpus=$0
1 change: 1 addition & 0 deletions egs2/esc50/asr1/conf/pitch.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
--sample-frequency=16000
12 changes: 12 additions & 0 deletions egs2/esc50/asr1/conf/queue.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Default configuration
command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64*
option name=* -N $0
option mem=* -l mem_free=$0,ram_free=$0
option mem=0 # Do not add anything to qsub_opts
option num_threads=* -pe smp $0
option num_threads=1 # Do not add anything to qsub_opts
option max_jobs_run=* -tc $0
option num_nodes=* -pe mpi $0 # You must set this PE as allocation_rule=1
default gpu=0
option gpu=0
option gpu=* -l gpu=$0 -q g.q
14 changes: 14 additions & 0 deletions egs2/esc50/asr1/conf/slurm.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Default configuration
command sbatch --export=PATH
option name=* --job-name $0
option time=* --time $0
option mem=* --mem-per-cpu $0
option mem=0
option num_threads=* --cpus-per-task $0
option num_threads=1 --cpus-per-task 1
option num_nodes=* --nodes $0
default gpu=0
option gpu=0 -p cpu
option gpu=* -p gpu --gres=gpu:$0 -c $0 # Recommend allocating more CPU than, or equal to the number of GPU
# note: the --max-jobs-run option is supported as a special case
# by slurm.pl and you don't have to handle it in the config file.
33 changes: 23 additions & 10 deletions egs2/esc50/asr1/local/data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@ log "$0 $*"
. ./path.sh
. ./cmd.sh

if [ $# -ne 0 ]; then
log "Error: No positional arguments are required."
exit 2
fi
FOLD=${1:-1}
DATA_PREP_ROOT=${2:-"."}

if [ -z "${ESC50}" ]; then
log "Fill the value of 'ESC50' of db.sh"
Expand All @@ -41,14 +39,29 @@ fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
log "stage 2: Data Preparation"
mkdir -p data/{train,valid,test}
python3 local/data_prep.py ${ESC50}
for x in test valid train; do

if [ ${FOLD} -le 1 ] && [ ${FOLD} -ge 1 ]; then
# Keep the code from SLU, FOLD 1 is default
mkdir -p data/{train,valid,test}
python3 local/data_prep.py ${ESC50}
for x in test valid train; do
for f in text wav.scp utt2spk; do
sort data/${x}/${f} -o data/${x}/${f}
done
utils/utt2spk_to_spk2utt.pl data/${x}/utt2spk > "data/${x}/spk2utt"
utils/validate_data_dir.sh --no-feats data/${x} || exit 1
done
fi

# Prepare data for 5-fold cross-validation
echo "Preparing data for fold ${FOLD}"
python3 local/data_prep_multi_folds.py ${ESC50} ${FOLD} ${DATA_PREP_ROOT}
for x in val${FOLD} train${FOLD}; do
for f in text wav.scp utt2spk; do
sort data/${x}/${f} -o data/${x}/${f}
sort ${DATA_PREP_ROOT}/data/${x}/${f} -o ${DATA_PREP_ROOT}/data/${x}/${f}
done
utils/utt2spk_to_spk2utt.pl data/${x}/utt2spk > "data/${x}/spk2utt"
utils/validate_data_dir.sh --no-feats data/${x} || exit 1
utils/utt2spk_to_spk2utt.pl ${DATA_PREP_ROOT}/data/${x}/utt2spk > "${DATA_PREP_ROOT}/data/${x}/spk2utt"
utils/validate_data_dir.sh --no-feats ${DATA_PREP_ROOT}/data/${x} || exit 1
done
fi

Expand Down
62 changes: 62 additions & 0 deletions egs2/esc50/asr1/local/data_prep_multi_folds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#!/usr/bin/env bash

# Copyright 2023 Siddhant Arora
# 2023 Carnegie Mellon University
# Apache 2.0

import json
import os
import pickle
import sys
from pathlib import Path

import pandas as pd
from torch.utils.data import random_split

if len(sys.argv) < 2:
print(len(sys.argv))
print(
"Usage: python data_prep.py [ESC-50_root] [FOLD] [ROOT] where "
"FOLD and ROOT are optional"
)
sys.exit(1)

esc_root = sys.argv[1]
fold_num = int(sys.argv[2]) if len(sys.argv) > 2 else 1
data_prep_root = sys.argv[3] if len(sys.argv) == 4 else "."

meta_data = pd.read_csv(Path(esc_root, "meta", "esc50.csv"))

split_df = {}
split_df[f"val{fold_num}"] = meta_data[meta_data["fold"] == fold_num]
split_df[f"train{fold_num}"] = meta_data[meta_data["fold"] != fold_num]

print(
"For fold number:",
fold_num,
"Train and Val split",
len(split_df[f"train{fold_num}"]),
len(split_df[f"val{fold_num}"]),
)

dir_dict = split_df
for x in dir_dict:
os.makedirs(os.path.join(data_prep_root, "data", x), exist_ok=True)
with open(
os.path.join(data_prep_root, "data", x, "wav.scp"), "w"
) as wav_scp_f, open(
os.path.join(data_prep_root, "data", x, "utt2spk"), "w"
) as utt2spk_f, open(
os.path.join(data_prep_root, "data", x, "text"), "w"
) as text_f:
filename = dir_dict[x]["filename"].values.tolist()
label = dir_dict[x]["target"].values.tolist()
for line_count in range(len(filename)):
cls = "audio_class:" + str(label[line_count])
utt_id = filename[line_count].replace(".wav", "")
spk = utt_id
data_dir = Path(esc_root, "audio", filename[line_count])

wav_scp_f.write(utt_id + " " + str(data_dir) + "\n")
text_f.write(utt_id + " " + cls + "\n")
utt2spk_f.write(utt_id + " " + spk + "\n")
32 changes: 32 additions & 0 deletions egs2/esc50/asr1/local/run_single_fold.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/env bash
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

train_set="train"
valid_set="valid"
test_sets="test valid"

if python3 -c 'import torch as t; from packaging.version import parse as L; assert L(t.__version__) >= L("1.7.0")' &> /dev/null; then
asr_config=conf/train_asr.yaml
else
asr_config=conf/tuning/train_asr_transformer_adam_specaug.yaml #s3prl is installed when pytorch > 1.7. Hence using default frontend
fi

./asr.sh \
--lang en \
--ngpu 1 \
--use_lm false \
--token_type word\
--audio_format "flac.ark" \
--feats_type raw\
--max_wav_duration 30 \
--feats_normalize utterance_mvn\
--inference_nj 8 \
--inference_asr_model valid.acc.ave_5best.pth\
--asr_config "${asr_config}" \
--train_set "${train_set}" \
--valid_set "${valid_set}" \
--test_sets "${test_sets}" "$@"
Loading

0 comments on commit b70dc52

Please sign in to comment.