Merge pull request #5977 from Shikhar-S/esc
ESC-50 classification with BEATs
Showing 21 changed files with 679 additions and 61 deletions.
# ESC-50 Audio Classification Recipe

This recipe implements audio classification on the ESC-50 dataset with a BEATs encoder and a linear decoder, closely following [this paper](https://arxiv.org/abs/2212.09058).
Specifically, we provide the fine-tuning config and results for the second-to-last row of Table 1 (BEATs-iter3) in the paper.
We reuse part of the code from the [BEATs repository](https://github.com/microsoft/unilm/tree/master/beats) in this implementation.

# Training Details and Requirements
We perform 5-fold cross-validation on the ESC-50 dataset.
The dataset has 2,000 samples, with 400 samples in each fold.
Please note that the hyper-parameters may differ from those in Appendix A.1 of the BEATs paper; the ones used here gave us the best results.
They were tuned on fold 5 and then reused for the other folds.
A single fine-tuning run needs one GPU with 33 GB of memory and takes ~4.5 hours on an L40S.

### Steps to run

1. Download the ESC-50 dataset from [this repo](https://github.com/karolpiczak/ESC-50?tab=readme-ov-file#download) and set the path to its root directory in `db.sh`.
2. Download the BEATs checkpoint [BEATs_iter3](https://github.com/microsoft/unilm/tree/master/beats) and update the `beats_ckpt_path` in `conf/beats_classification.yaml`.
3. Launch with `run.sh` (a sketch of the full sequence follows this list).
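
A minimal end-to-end sketch of these steps; the dataset path and the `ESC50` variable name in `db.sh` are assumptions, not part of the recipe:

```bash
# Sketch only: adjust paths to your setup.
git clone https://github.com/karolpiczak/ESC-50.git /data/ESC-50   # step 1: dataset
echo 'ESC50=/data/ESC-50' >> db.sh                                 # assumed db.sh variable name
# Step 2: download BEATs_iter3.pt from the BEATs repository and set
# beats_ckpt_path in conf/beats_classification.yaml to its location.
./run.sh                                                           # step 3: launch fine-tuning
```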

## Trained checkpoints
All trained checkpoints, with their fold accuracies, are available at:
* Fold-1 (94.3): https://huggingface.co/shikhar7ssu/BEATs-ESC-FinetunedFold1
* Fold-2 (97.0): https://huggingface.co/shikhar7ssu/BEATs-ESC-FinetunedFold2
* Fold-3 (94.8): https://huggingface.co/shikhar7ssu/BEATs-ESC-FinetunedFold3
* Fold-4 (96.3): https://huggingface.co/shikhar7ssu/BEATs-ESC-FinetunedFold4
* Fold-5 (91.8): https://huggingface.co/shikhar7ssu/BEATs-ESC-FinetunedFold5

Average accuracy: 94.8

# Error Analysis
The top confusions in fold 5 involve the class `helicopter`, which is mainly confused with `washing machine` and `airplane`.

<!-- Generated by scripts/utils/show_asr_result.sh -->
# RESULTS
## Environments
- date: `Sat Dec 14 19:04:56 EST 2024`
- python version: `3.9.20 (main, Oct 3 2024, 07:27:41) [GCC 11.2.0]`
- espnet version: `espnet 202412`
- pytorch version: `pytorch 2.4.0`
- Git hash: `cb80e61a15d6a13dc342ae5a413d2b870dd869c6`
- Commit date: `Fri Dec 13 11:57:16 2024 -0500`

## /compute/babel-13-33/sbharad2/expdir/asr_fast.fold[i]/inference_ctc_weight0.0_maxlenratio-1_asr_model_valid.acc.best
### Accuracy

|dataset|Snt|Wrd|Acc|
|---|---|---|---|
|org/val1|400|400|94.3|
|org/val2|400|400|97.0|
|org/val3|400|400|94.8|
|org/val4|400|400|96.3|
|org/val5|400|400|91.8|
|Average|||94.8|

**conf/beats_classification.yaml** (the fine-tuning config referenced above):

```yaml
token_type: word

optim: adamw
optim_conf:
    lr: 1.0e-4
    weight_decay: 1.0e-2
    betas: [0.9, 0.98]

accum_grad: 1

batch_size: 128  # 12.5 steps per epoch with 1600 samples
max_epoch: 1000

scheduler: CosineAnnealingWarmupRestarts
scheduler_conf:
    first_cycle_steps: 6000
    warmup_steps: 300
    max_lr: 1.0e-4
    min_lr: 5.0e-6

# The BEATs implementation takes care of generating the mel spectrogram,
# normalization, and SpecAug.
frontend: none
input_size: 1    # important to set input_size to 1 if frontend is none
normalize: none  # BEATs code does global mean and variance normalization

# Initialization for the decoder
init: xavier_normal

model_conf:
    ctc_weight: 0.0  # No CTC, no attention.
    lsm_weight: 0.1  # label smoothing weight
    length_normalized_loss: true

batch_type: folded
unused_parameters: true
grad_clip: 1
patience: none
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 1
use_amp: false   # whether to use automatic mixed precision
num_att_plot: 0
num_workers: 2   # dataloader workers

encoder: beats
encoder_conf:
    # Please download the BEATs model (iter3) from
    # https://github.com/microsoft/unilm/tree/master/beats and update the path below.
    beats_ckpt_path: /compute/babel-13-33/sbharad2/models/BEATs/BEATs_iter3.pt
    # Most values are from Appendix A.1 of the BEATs paper or tuned on fold 5.
    # Please also check the README.md.
    fbank_mean: 11.72215
    fbank_std: 10.60431
    beats_config:
        layer_wise_gradient_decay_ratio: 0.2
        encoder_layerdrop: 0.1
        dropout: 0.0
    specaug_config:
        apply_time_warp: true
        apply_freq_mask: false
        apply_time_mask: true
        time_mask_width_ratio_range:
        - 0
        - 0.06
        num_time_mask: 1
    roll_augment: true
    roll_interval: 16000  # 1 second, so only 5 possible augmentations per 5 s clip
    use_weighted_representation: false

# Simple linear decoder for classification.
decoder: linear_decoder
decoder_conf:
    pooling: mean
    dropout: 0.1
```
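
One way to satisfy step 2 without hand-editing the file is to rewrite the hard-coded `beats_ckpt_path` in place. A sketch, assuming a locally downloaded checkpoint at a placeholder path:

```bash
# Point the config at your local BEATs checkpoint (placeholder path).
ckpt=/path/to/BEATs_iter3.pt
sed -i "s|beats_ckpt_path: .*|beats_ckpt_path: ${ckpt}|" conf/beats_classification.yaml
```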

Fbank feature extraction config:

```
--sample-frequency=16000
--num-mel-bins=80
```

PBS backend config (Kaldi-style):

```
# Default configuration
command qsub -V -v PATH -S /bin/bash
option name=* -N $0
option mem=* -l mem=$0
option mem=0          # Do not add anything to qsub_opts
option num_threads=* -l ncpus=$0
option num_threads=1  # Do not add anything to qsub_opts
option num_nodes=* -l nodes=$0:ppn=1
default gpu=0
option gpu=0
option gpu=* -l ngpus=$0
```

Pitch feature extraction config:

```
--sample-frequency=16000
```

Grid Engine (queue) backend config:

```
# Default configuration
command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64*
option name=* -N $0
option mem=* -l mem_free=$0,ram_free=$0
option mem=0          # Do not add anything to qsub_opts
option num_threads=* -pe smp $0
option num_threads=1  # Do not add anything to qsub_opts
option max_jobs_run=* -tc $0
option num_nodes=* -pe mpi $0  # You must set this PE as allocation_rule=1
default gpu=0
option gpu=0
option gpu=* -l gpu=$0 -q g.q
```

Slurm backend config:

```
# Default configuration
command sbatch --export=PATH
option name=* --job-name $0
option time=* --time $0
option mem=* --mem-per-cpu $0
option mem=0
option num_threads=* --cpus-per-task $0
option num_threads=1 --cpus-per-task 1
option num_nodes=* --nodes $0
default gpu=0
option gpu=0 -p cpu
option gpu=* -p gpu --gres=gpu:$0 -c $0  # Recommend allocating at least as many CPUs as GPUs
# note: the --max-jobs-run option is supported as a special case
# by slurm.pl and you don't have to handle it in the config file.
```
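
These Kaldi-style backend configs translate generic job options (`mem`, `num_threads`, `gpu`) into scheduler-specific flags for the parallel job launchers (`queue.pl`, `slurm.pl`, `pbs.pl`). A sketch of how a backend is typically selected, assuming the standard ESPnet `cmd.sh` convention (the variable name is an assumption):

```bash
# In cmd.sh, pick the backend matching your cluster; e.g. for Slurm:
cmd_backend='slurm'
# slurm.pl then reads conf/slurm.conf, so a job requesting one GPU is
# submitted roughly as: sbatch -p gpu --gres=gpu:1 -c 1 ...
```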

Data preparation script:

```python
#!/usr/bin/env python3

# Copyright 2023 Siddhant Arora
#           2023 Carnegie Mellon University
# Apache 2.0

import os
import sys
from pathlib import Path

import pandas as pd

if len(sys.argv) < 2:
    print(len(sys.argv))
    print(
        "Usage: python data_prep.py [ESC-50_root] [FOLD] [ROOT] where "
        "FOLD and ROOT are optional"
    )
    sys.exit(1)

esc_root = sys.argv[1]
fold_num = int(sys.argv[2]) if len(sys.argv) > 2 else 1
data_prep_root = sys.argv[3] if len(sys.argv) == 4 else "."

meta_data = pd.read_csv(Path(esc_root, "meta", "esc50.csv"))

# Hold out the requested fold for validation; train on the rest.
split_df = {}
split_df[f"val{fold_num}"] = meta_data[meta_data["fold"] == fold_num]
split_df[f"train{fold_num}"] = meta_data[meta_data["fold"] != fold_num]

print(
    "For fold number:",
    fold_num,
    "Train and Val split",
    len(split_df[f"train{fold_num}"]),
    len(split_df[f"val{fold_num}"]),
)

# Write Kaldi-style data files (wav.scp, utt2spk, text) for each split.
dir_dict = split_df
for x in dir_dict:
    os.makedirs(os.path.join(data_prep_root, "data", x), exist_ok=True)
    with open(
        os.path.join(data_prep_root, "data", x, "wav.scp"), "w"
    ) as wav_scp_f, open(
        os.path.join(data_prep_root, "data", x, "utt2spk"), "w"
    ) as utt2spk_f, open(
        os.path.join(data_prep_root, "data", x, "text"), "w"
    ) as text_f:
        filename = dir_dict[x]["filename"].values.tolist()
        label = dir_dict[x]["target"].values.tolist()
        for line_count in range(len(filename)):
            cls = "audio_class:" + str(label[line_count])
            utt_id = filename[line_count].replace(".wav", "")
            spk = utt_id  # each clip is its own "speaker"
            data_dir = Path(esc_root, "audio", filename[line_count])

            wav_scp_f.write(utt_id + " " + str(data_dir) + "\n")
            text_f.write(utt_id + " " + cls + "\n")
            utt2spk_f.write(utt_id + " " + spk + "\n")
```
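
A usage sketch; the script location (`local/data_prep.py`) and the dataset path are assumptions based on the standard ESPnet recipe layout:

```bash
# Prepare Kaldi-style data directories for fold 5 in the current recipe dir.
python local/data_prep.py /data/ESC-50 5 .
# Produces data/train5/{wav.scp,utt2spk,text} and data/val5/{wav.scp,utt2spk,text}.
```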

**run.sh** (recipe launch script):

```bash
#!/usr/bin/env bash
# Set bash to 'debug' mode: it will exit on
# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline'
set -e
set -u
set -o pipefail

train_set="train"
valid_set="valid"
test_sets="test valid"

if python3 -c 'import torch as t; from packaging.version import parse as L; assert L(t.__version__) >= L("1.7.0")' &> /dev/null; then
    asr_config=conf/train_asr.yaml
else
    # s3prl is installed only when pytorch > 1.7, hence the default frontend here.
    asr_config=conf/tuning/train_asr_transformer_adam_specaug.yaml
fi

./asr.sh \
    --lang en \
    --ngpu 1 \
    --use_lm false \
    --token_type word \
    --audio_format "flac.ark" \
    --feats_type raw \
    --max_wav_duration 30 \
    --feats_normalize utterance_mvn \
    --inference_nj 8 \
    --inference_asr_model valid.acc.ave_5best.pth \
    --asr_config "${asr_config}" \
    --train_set "${train_set}" \
    --valid_set "${valid_set}" \
    --test_sets "${test_sets}" "$@"
```
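
Because `run.sh` appends `"$@"` when calling `asr.sh`, the fold-specific set names produced by the data prep script can likely be supplied on the command line (with Kaldi-style option parsing, later options override the defaults above). A sketch for fold 5:

```bash
# Fine-tune and evaluate on fold 5 (set names assume the data prep output).
./run.sh \
    --train_set train5 \
    --valid_set val5 \
    --test_sets val5
```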