Commit

Merge branch 'sds_demo_recipe' of https://github.com/siddhu001/espnet
…into sds_demo_recipe
Siddhant committed Dec 9, 2024
2 parents 0bcca8e + d6b5a5a commit c06197d
Showing 37 changed files with 3,137 additions and 25 deletions.
1 change: 1 addition & 0 deletions egs2/README.md
@@ -30,6 +30,7 @@ See: https://espnet.github.io/espnet/espnet2_tutorial.html#recipes-using-espnet2
| bur_openslr80 | Burmese ASR training dataset | ASR | BUR | https://openslr.org/80/ | |
| catslu | CATSLU-MAPS | SLU | CMN | https://sites.google.com/view/catslu/home | |
| catslu_entity | CATSLU | SLU/Entity Classifi. | CMN | https://sites.google.com/view/catslu/home | |
| clotho_v2 | Clotho v2.1 dataset for audio captioning | ASR | ENG | https://zenodo.org/records/4783391 | |
| chime1 | The 1st CHiME Speech Separation and Recognition Challenge | ASR/Multichannel ASR | ENG | https://spandh.dcs.shef.ac.uk/chime_challenge/chime2011/ | |
| chime2 | The 2nd CHiME Speech Separation and Recognition Challenge | ASR/Multichannel ASR | ENG | https://spandh.dcs.shef.ac.uk/chime_challenge/chime2013/ | |
| chime4 | The 4th CHiME Speech Separation and Recognition Challenge | ASR/Multichannel ASR | ENG | http://spandh.dcs.shef.ac.uk/chime_challenge/chime2016/ | |
3 changes: 3 additions & 0 deletions egs2/TEMPLATE/asr1/db.sh
@@ -221,6 +221,9 @@ SPRING_INX=downloads
VOXCELEB=
KSPONSPEECH=
HIFITTS=downloads
CLOTHO_V2=downloads
AUDIOCAPS=
CLOTHO_CHATGPT_MIXUP=

# For only CMU TIR environment
if [[ "$(hostname)" == tir* ]]; then
64 changes: 64 additions & 0 deletions egs2/clotho_v2/asr1/README.md
@@ -0,0 +1,64 @@
# Clotho Audio Captioning RECIPE

This recipe implements the DCASE 2023 Automated Audio Captioning (AAC) task on the Clotho_v2 dataset with a BEATs encoder and a BART decoder. It closely follows [this paper](https://arxiv.org/abs/2309.17352) and reuses part of the code from the [original implementation](https://github.com/slSeanWU/beats-conformer-bart-audio-captioner?tab=readme-ov-file).
More specifically, we provide the pre-training and fine-tuning code for the second-to-last row of Table 2 in the paper (without Instructor embedding).
We also reuse code from the [BEATs repository](https://github.com/microsoft/unilm/tree/master/beats) for this implementation.

## Training Details and Requirements
Training is divided into an AAC pre-training stage and a fine-tuning stage.
AAC pre-training uses the [AudioCaps](https://aclanthology.org/N19-1011/) and [Clotho ChatGPT mixup](https://huggingface.co/datasets/slseanwu/clotho-chatgpt-mixup-50K) data.
Fine-tuning is performed only on the standard development set of the Clotho_v2 dataset.


### Steps to run

1. Download AudioCaps and set `AUDIOCAPS` in `db.sh` to its root directory (this recipe downloads Clotho for you). Also download the Clotho mixup data (wav IDs and captions) from [this repo](https://huggingface.co/datasets/slseanwu/clotho-chatgpt-mixup-50K) and set `CLOTHO_CHATGPT_MIXUP` in `db.sh`. The code takes care of mixing and creating the audio files.
2. Download the BEATs checkpoint [BEATs_iter3+](https://onedrive.live.com/?authkey=%21AGXnEG4l3mlIzfA&id=6B83B49411CA81A7%2125960&cid=6B83B49411CA81A7&parId=root&parQt=sharedby&o=OneUp) and update `beats_ckpt_path` in `conf/beats_bart_pt.yaml` (the same key also appears in `conf/beats_bart_ft.yaml`).
3. Launch with `run.sh`.


### Trained checkpoints
AAC Pre-trained model: https://huggingface.co/espnet/DCASE23.AudioCaptioning.PreTrained

Fine-tuned model: https://huggingface.co/espnet/DCASE23.AudioCaptioning.FineTuned


### GPU Time
AAC pre-training takes around 4 hours on two A6000 GPUs.
Fine-tuning, decoding, and evaluation take around 1 hour on two A6000 GPUs.
All the scripts above are set up for 2 GPUs, but this can be changed with the `ngpu` argument. Please make sure to change the batch size accordingly (the provided setup requires 42.6GB of GPU memory).
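
As a rough aid for the batch-size adjustment mentioned above, the sketch below rescales `batch_size` and `accum_grad` when `ngpu` changes. It assumes that `batch_size` in the config is the total mini-batch and is split evenly across GPUs, which should be checked against your ESPnet version. The helper name `rescale_for_ngpu` is hypothetical and not part of this recipe.

```python
def rescale_for_ngpu(batch_size, accum_grad, old_ngpu, new_ngpu):
    """Keep per-GPU load and effective batch size roughly constant.

    Assumption (not taken from the recipe): `batch_size` is the total
    mini-batch, split evenly across GPUs.
    """
    per_gpu = batch_size // old_ngpu       # samples each GPU sees per step
    effective = batch_size * accum_grad    # samples per optimizer update
    new_batch_size = per_gpu * new_ngpu
    # Round to the nearest whole number of accumulation steps (at least 1).
    new_accum_grad = max(1, round(effective / new_batch_size))
    return new_batch_size, new_accum_grad


# Fine-tuning config in this recipe: batch_size=64, accum_grad=2 on 2 GPUs.
print(rescale_for_ngpu(64, 2, old_ngpu=2, new_ngpu=1))  # (32, 4)
print(rescale_for_ngpu(64, 2, old_ngpu=2, new_ngpu=4))  # (128, 1)
```

The returned pair would go into `batch_size` and `accum_grad` of the corresponding YAML config; memory use should still be verified empirically.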


<!-- Generated by scripts/utils/show_asr_result.sh -->
# RESULTS
## Environments
- date: `Fri Nov 29 20:06:53 EST 2024`
- python version: `3.9.20 (main, Oct 3 2024, 07:27:41) [GCC 11.2.0]`
- espnet version: `espnet 202409`
- pytorch version: `pytorch 2.4.0`
- Git hash: `65ea259e8effab5a43cdff87161a301dc0f20930`
- Commit date: `Fri Nov 29 10:54:44 2024 -0500`

<!-- Copied from the output produced by local/evaluation.py -->
## exp/asr_ft
```
=====================================================
Split: evaluation Evaluation over 1045 predictions.
=====================================================
cider_d : 0.46045061153488653
spice : 0.1345877073651595
spider : 0.297519159450023
sbert_sim : 0.5147198918907185
fer : 0.019138755980861243
fense : 0.5048554784476391
meteor : 0.18487611036251259
rouge_l : 0.39408182293804006
fer.add_tail_prob : 0.046671394258737564
fer.repeat_event_prob: 0.07274453341960907
fer.repeat_adv_prob : 0.0019690715707838535
fer.remove_conj_prob: 0.12462866306304932
fer.remove_verb_prob: 0.22472403943538666
fer.error_prob : 0.34762266278266907
spider_fl : 0.2923702923078383
=====================================================
```
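
As a sanity check on the numbers above: SPIDEr is defined as the arithmetic mean of CIDEr-D and SPICE, so the reported `spider` value can be reproduced from the two rows above it. The snippet is only an illustrative check and is not part of `local/evaluation.py`.

```python
import math

# Values copied from the exp/asr_ft results above.
cider_d = 0.46045061153488653
spice = 0.1345877073651595
reported_spider = 0.297519159450023

# SPIDEr = (CIDEr-D + SPICE) / 2
spider = (cider_d + spice) / 2
assert math.isclose(spider, reported_spider, rel_tol=1e-9)
print(f"spider = {spider:.6f}")
```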
1 change: 1 addition & 0 deletions egs2/clotho_v2/asr1/asr.sh
110 changes: 110 additions & 0 deletions egs2/clotho_v2/asr1/cmd.sh
@@ -0,0 +1,110 @@
# ====== About run.pl, queue.pl, slurm.pl, and ssh.pl ======
# Usage: <cmd>.pl [options] JOB=1:<nj> <log> <command...>
# e.g.
# run.pl --mem 4G JOB=1:10 echo.JOB.log echo JOB
#
# Options:
# --time <time>: Limit the maximum time to execute.
# --mem <mem>: Limit the maximum memory usage.
# --max-jobs-run <njob>: Limit the number of parallel jobs. This is ignored for non-array jobs.
# --num-threads <nthreads>: Specify the number of CPU cores.
# --gpu <ngpu>: Specify the number of GPU devices.
# --config: Change the configuration file from default.
#
# "JOB=1:10" is used for "array jobs" and it can control the number of parallel jobs.
# The left string of "=", i.e. "JOB", is replaced by <N>(Nth job) in the command and the log file name,
# e.g. "echo JOB" is changed to "echo 3" for the 3rd job and "echo 8" for 8th job respectively.
# Note that the number must start with a positive number, so you can't use "JOB=0:10" for example.
#
# run.pl, queue.pl, slurm.pl, and ssh.pl have a unified interface that does not depend on the backend.
# These options are mapped to backend-specific options, and
# the mapping is configured by "conf/queue.conf" and "conf/slurm.conf" by default.
# If jobs fail, your configuration might be wrong for your environment.
#
#
# The official documentation for run.pl, queue.pl, slurm.pl, and ssh.pl:
# "Parallelization in Kaldi": http://kaldi-asr.org/doc/queue.html
# =========================================================~


# Select the backend used by run.sh from "local", "stdout", "sge", "pbs", "slurm", "ssh", or "jhu"
cmd_backend='local'

# Local machine, without any Job scheduling system
if [ "${cmd_backend}" = local ]; then

# The other usage
export train_cmd="run.pl"
# Used for "*_train.py": "--gpu" is appended optionally by run.sh
export cuda_cmd="run.pl"
# Used for "*_recog.py"
export decode_cmd="run.pl"

# Local machine logging to stdout and log file, without any Job scheduling system
elif [ "${cmd_backend}" = stdout ]; then

# The other usage
export train_cmd="stdout.pl"
# Used for "*_train.py": "--gpu" is appended optionally by run.sh
export cuda_cmd="stdout.pl"
# Used for "*_recog.py"
export decode_cmd="stdout.pl"


# "qsub" (Sun Grid Engine, or derivation of it)
elif [ "${cmd_backend}" = sge ]; then
# The default setting is written in conf/queue.conf.
# You must change "-q g.q" for the "queue" for your environment.
# To know the "queue" names, type "qhost -q"
# Note that to use "--gpu *", you have to setup "complex_value" for the system scheduler.

export train_cmd="queue.pl"
export cuda_cmd="queue.pl"
export decode_cmd="queue.pl"


# "qsub" (Torque/PBS.)
elif [ "${cmd_backend}" = pbs ]; then
# The default setting is written in conf/pbs.conf.

export train_cmd="pbs.pl"
export cuda_cmd="pbs.pl"
export decode_cmd="pbs.pl"


# "sbatch" (Slurm)
elif [ "${cmd_backend}" = slurm ]; then
# The default setting is written in conf/slurm.conf.
# You must change "-p cpu" and "-p gpu" for the "partition" for your environment.
# To know the "partition" names, type "sinfo".
# You can use "--gpu * " by default for slurm and it is interpreted as "--gres gpu:*"
# The devices are allocated exclusively using "${CUDA_VISIBLE_DEVICES}".

export train_cmd="slurm.pl"
export cuda_cmd="slurm.pl"
export decode_cmd="slurm.pl"

elif [ "${cmd_backend}" = ssh ]; then
# You have to create ".queue/machines" to specify the host to execute jobs.
# e.g. .queue/machines
# host1
# host2
# host3
# This assumes you can log in to them without a password, i.e., you have to set up SSH keys.

export train_cmd="ssh.pl"
export cuda_cmd="ssh.pl"
export decode_cmd="ssh.pl"

# This is an example of specifying several unique options in the JHU CLSP cluster setup.
# Users can modify/add their own command options according to their cluster environments.
elif [ "${cmd_backend}" = jhu ]; then

export train_cmd="queue.pl --mem 2G"
export cuda_cmd="queue-freegpu.pl --mem 2G --gpu 1 --config conf/queue.conf"
export decode_cmd="queue.pl --mem 4G"

else
echo "$0: Error: Unknown cmd_backend=${cmd_backend}" 1>&2
return 1
fi
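
The array-job convention described in the header comments of `cmd.sh` (`JOB=1:10`) can be illustrated with a small Python sketch. It only mimics how the job variable is substituted into the log name and the command; it is not how `run.pl` is implemented, and `expand_array_job` is a hypothetical helper.

```python
def expand_array_job(spec, log, command):
    """Expand a 'NAME=START:END' spec into (log, command) pairs, one per job."""
    name, job_range = spec.split("=")
    start, end = (int(x) for x in job_range.split(":"))
    if start < 1:
        raise ValueError("job indices must start at a positive number")
    jobs = []
    for n in range(start, end + 1):
        # The job name is replaced by the job index everywhere it appears.
        jobs.append((log.replace(name, str(n)),
                     [arg.replace(name, str(n)) for arg in command]))
    return jobs


jobs = expand_array_job("JOB=1:10", "echo.JOB.log", ["echo", "JOB"])
print(jobs[2])  # ('echo.3.log', ['echo', '3'])
```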
16 changes: 16 additions & 0 deletions egs2/clotho_v2/asr1/conf/bart_decoder_config.json
@@ -0,0 +1,16 @@
{
"vocab_size": 50265,
"max_position_embeddings": 256,
"decoder_layers": 6,
"decoder_ffn_dim": 3072,
"decoder_attention_heads": 16,
"d_model": 768,
"decoder_layerdrop": 0.0,
"activation_function": "gelu",
"dropout": 0.1,
"attention_dropout": 0.1,
"activation_dropout": 0.1,
"scale_embedding": false,
"use_cache": true,
"ignore_mismatched_sizes": true
}
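
To make the "custom decoder" concrete, the sketch below derives the per-head dimension from the config above and compares a few fields with the stock `facebook/bart-base` settings. The reference values for BART-base (12 attention heads, 1024 positions) are quoted from the published model config as an assumption and should be verified before relying on them.

```python
import json

# Subset of conf/bart_decoder_config.json shown above.
custom = json.loads("""
{
  "max_position_embeddings": 256,
  "decoder_layers": 6,
  "decoder_ffn_dim": 3072,
  "decoder_attention_heads": 16,
  "d_model": 768
}
""")

# Assumed stock values for facebook/bart-base (verify against the model card).
bart_base = {
    "max_position_embeddings": 1024,
    "decoder_layers": 6,
    "decoder_ffn_dim": 3072,
    "decoder_attention_heads": 12,
    "d_model": 768,
}

# d_model must be divisible by the number of attention heads.
head_dim, remainder = divmod(custom["d_model"], custom["decoder_attention_heads"])
assert remainder == 0
print("per-head dimension:", head_dim)

# Map each differing key to (assumed BART-base value, value in this recipe).
changed = {k: (bart_base[k], v) for k, v in custom.items() if bart_base[k] != v}
print("differs from BART-base:", changed)
```

Because the attention shapes differ from the stock model under these assumptions, it is consistent that the recipe sets `load_pretrained_weights: false` for the decoder.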
85 changes: 85 additions & 0 deletions egs2/clotho_v2/asr1/conf/beats_bart_ft.yaml
@@ -0,0 +1,85 @@
batch_type: folded
unused_parameters: true
batch_size: 64
accum_grad: 2
max_epoch: 8
grad_clip: 1
patience: none
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 5
use_amp: false # whether to use automatic mixed precision
num_att_plot: 0
num_workers: 2 # dataloader workers

# BEATs implementation takes care of generating mel spectrogram, normalization and specaug
frontend: none
input_size: 1 # important to set input_size to 1 if frontend is none
normalize: none # BEATs code does global mean and variance normalization

freeze_param: [
"encoder.encoder",
"encoder.layer_norm",
"encoder.patch_embedding",
"encoder.post_extract_proj",
]

encoder: beats
encoder_conf:
    # Please download the BEATs model from https://github.com/microsoft/unilm/tree/master/beats
    # (iter3+, Beats finetuned model 1) and update the path below
    beats_ckpt_path: /compute/babel-13-33/sbharad2/models/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1.pt
    specaug_config:
        apply_freq_mask: true
        freq_mask_width_range:
        - 0
        - 64
        num_freq_mask: 2
        apply_time_mask: true
        time_mask_width_ratio_range:
        - 0
        - 0.12
        num_time_mask: 5
    adapter_config: conf/wav2vec2_conformer_config.json
    downsampling_rate: 3 # CNN downsampling over beats encoder
    max_layer: 10 # 0 based index
    use_weighted_representation: false
    add_positional_information: true
    max_positions: 1024 # These many positional embeddings will be learned

# Please note that the decoder is not the standard BART-base,
# but a custom one whose config is defined in the file below
decoder: hugging_face_transformers
decoder_conf:
    model_name_or_path: facebook/bart-base
    overriding_architecture_config: conf/bart_decoder_config.json
    load_pretrained_weights: false
    separate_lm_head: true

# Initialization does not matter here because we start from a pre-trained model
init: normal

token_type: hugging_face

# Loss, optimizer, scheduler
model_conf:
    ctc_weight: 0.0 # No CTC, only attention branch
    lsm_weight: 0.1 # label smoothing weight
    length_normalized_loss: true
    # BART Dictionary customizations
    ignore_id: 1
    sym_blank: "<pad>"
    sym_sos: "<s>"
    sym_eos: "</s>"
    sym_space: "Ġ"

optim: adamw
optim_conf:
    lr: 0.00002 # 2e-5
    weight_decay: 0.001 # 1e-3

scheduler: warmuplr
scheduler_conf:
    warmup_steps: 1000 # 1k
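
The `warmuplr` scheduler configured above can be sketched as follows. The formula is the Noam-style schedule documented for ESPnet's `WarmupLR`, reproduced here as a standalone function for illustration. `warmup_lr` is a hypothetical name, and the exact behaviour should be confirmed against the installed ESPnet version.

```python
import math


def warmup_lr(step, base_lr, warmup_steps):
    """lr = base_lr * warmup^0.5 * min(step^-0.5, step * warmup^-1.5)"""
    return base_lr * warmup_steps ** 0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


base_lr, warmup_steps = 2e-5, 1000  # values from the fine-tuning config
for step in (1, 500, 1000, 4000):
    print(step, warmup_lr(step, base_lr, warmup_steps))

# Linear ramp up to the peak at `warmup_steps`, inverse-sqrt decay afterwards.
assert math.isclose(warmup_lr(1000, base_lr, warmup_steps), base_lr)
assert math.isclose(warmup_lr(4000, base_lr, warmup_steps), base_lr / 2)
```

Under this schedule, the configured `lr` is the peak learning rate, reached exactly at `warmup_steps`.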
88 changes: 88 additions & 0 deletions egs2/clotho_v2/asr1/conf/beats_bart_pt.yaml
@@ -0,0 +1,88 @@
batch_type: folded
unused_parameters: true
batch_size: 64
accum_grad: 1
max_epoch: 10
grad_clip: 1
patience: none
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 5
use_amp: false # whether to use automatic mixed precision
num_att_plot: 0
num_workers: 2 # number of workers in dataloader

# BEATs implementation takes care of generating mel spectrogram, normalization and specaug
frontend: none
input_size: 1 # important to set input_size to 1 if frontend is none
normalize: none # BEATs code does global mean and variance normalization


freeze_param: [
"encoder.encoder",
"encoder.layer_norm",
"encoder.patch_embedding",
"encoder.post_extract_proj",
]


encoder: beats
encoder_conf:
    # Please download the BEATs model from https://github.com/microsoft/unilm/tree/master/beats
    # (iter3+, Beats finetuned model 1) and update the path below
    beats_ckpt_path: /compute/babel-13-33/sbharad2/models/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1.pt
    specaug_config:
        apply_freq_mask: true
        freq_mask_width_range:
        - 0
        - 64
        num_freq_mask: 2
        apply_time_mask: true
        time_mask_width_ratio_range:
        - 0
        - 0.12
        num_time_mask: 5
    adapter_config: conf/wav2vec2_conformer_config.json
    downsampling_rate: 3 # CNN downsampling over beats encoder
    max_layer: 10 # 0 based index
    use_weighted_representation: false
    add_positional_information: true
    max_positions: 1024 # These many positional embeddings will be learned

# Please note that the decoder is not the standard BART-base,
# but a custom one whose config is defined in the file below
decoder: hugging_face_transformers
decoder_conf:
    model_name_or_path: facebook/bart-base
    overriding_architecture_config: conf/bart_decoder_config.json
    load_pretrained_weights: false
    separate_lm_head: true

# This takes care of BART and adapter layers' initialization
# (essentially everything else other than BEATs encoder backbone)
init: normal

token_type: hugging_face

# Loss, optimizer, scheduler
model_conf:
    ctc_weight: 0.0 # No CTC, only attention branch
    lsm_weight: 0.1 # label smoothing weight
    length_normalized_loss: true # length normalization
    # BART Dictionary customizations
    ignore_id: 1
    sym_blank: "<pad>"
    sym_sos: "<s>"
    sym_eos: "</s>"
    sym_space: "Ġ"

optim: adamw
optim_conf:
    lr: 0.0002 # 2e-4
    weight_decay: 0.001

scheduler: warmuplr
scheduler_conf:
    warmup_steps: 2500 # 2.5k
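
The `freeze_param` list in this config freezes the BEATs backbone while leaving the adapter and decoder trainable. Assuming ESPnet matches each entry against parameter names by exact name or by dotted prefix (an assumption to confirm in the ESPnet source), the selection can be mimicked as below. All names in `example_params` are made up for illustration.

```python
FREEZE_PARAM = [
    "encoder.encoder",
    "encoder.layer_norm",
    "encoder.patch_embedding",
    "encoder.post_extract_proj",
]


def is_frozen(param_name, freeze_list=FREEZE_PARAM):
    """True if the name equals an entry or lies under it in the module tree."""
    return any(param_name == t or param_name.startswith(t + ".") for t in freeze_list)


# Hypothetical parameter names, for illustration only.
example_params = [
    "encoder.encoder.layers.0.self_attn.k_proj.weight",
    "encoder.layer_norm.bias",
    "encoder.conformer_adapter.layers.0.ffn.weight",
    "decoder.lm_head.weight",
]
for name in example_params:
    print(f"{name}: {'frozen' if is_frozen(name) else 'trainable'}")
```

Matching on `t + "."` rather than a bare prefix avoids accidentally freezing a sibling module whose name merely begins with the same characters.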
2 changes: 2 additions & 0 deletions egs2/clotho_v2/asr1/conf/fbank.conf
@@ -0,0 +1,2 @@
--sample-frequency=16000
--num-mel-bins=80