Merge branch 'sds_demo_recipe' of https://github.com/siddhu001/espnet into sds_demo_recipe

Showing 37 changed files with 3,137 additions and 25 deletions.
@@ -0,0 +1,64 @@
# Clotho Audio Captioning RECIPE

This recipe implements the DCASE 2023 Automated Audio Captioning (AAC) task with a BEATs encoder and BART decoder model on the Clotho_v2 dataset. It closely follows [this paper](https://arxiv.org/abs/2309.17352) and reuses part of the code from the [original implementation](https://github.com/slSeanWU/beats-conformer-bart-audio-captioner?tab=readme-ov-file).
More specifically, we provide the pre-training and fine-tuning code for the second-to-last row of Table 2 (without Instructor embedding) in the paper.
We also reuse code from the [BEATs repository](https://github.com/microsoft/unilm/tree/master/beats) for this implementation.

# Training Details and Requirements
Training is divided into an AAC pre-training stage and a fine-tuning stage.
AAC pre-training uses [AudioCaps](https://aclanthology.org/N19-1011/) and [Clotho ChatGPT mixup](https://huggingface.co/datasets/slseanwu/clotho-chatgpt-mixup-50K) data.
Fine-tuning is performed only on the standard development set of the Clotho_v2 dataset.

### Steps to run

1. Download AudioCaps and set the path to its root directory in db.sh (this recipe downloads Clotho for you). Also download the Clotho mixup data (wav IDs and captions) from [this repo](https://huggingface.co/datasets/slseanwu/clotho-chatgpt-mixup-50K) and set `CLOTHO_CHATGPT_MIXUP` in db.sh. The code takes care of mixing and creating the audio files.
2. Download the BEATs checkpoint [BEATs_iter3+](https://onedrive.live.com/?authkey=%21AGXnEG4l3mlIzfA&id=6B83B49411CA81A7%2125960&cid=6B83B49411CA81A7&parId=root&parQt=sharedby&o=OneUp) and change the `beats_ckpt_path` path in `conf/beats_bart_pt.yaml`.
3. Launch with `run.sh` (see the sketch below).
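
A minimal sketch of the three steps (the `CLOTHO_CHATGPT_MIXUP` key comes from step 1; the `AUDIOCAPS` key name, the local paths, and the use of sed are assumptions):

```
# Step 1: point db.sh at your local dataset copies (AUDIOCAPS key name is an assumption).
sed -i 's|^AUDIOCAPS=.*|AUDIOCAPS=/path/to/audiocaps|' db.sh
sed -i 's|^CLOTHO_CHATGPT_MIXUP=.*|CLOTHO_CHATGPT_MIXUP=/path/to/clotho_mixup|' db.sh

# Step 2: point the pre-training config at the downloaded BEATs checkpoint.
sed -i 's|beats_ckpt_path:.*|beats_ckpt_path: /path/to/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1.pt|' \
    conf/beats_bart_pt.yaml

# Step 3: launch the full pipeline, from data preparation through evaluation.
./run.sh
```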

## Trained checkpoints
AAC pre-trained model: https://huggingface.co/espnet/DCASE23.AudioCaptioning.PreTrained

Fine-tuned model: https://huggingface.co/espnet/DCASE23.AudioCaptioning.FineTuned

### GPU Time
AAC pre-training takes ~4 hours on two A6000 GPUs.
Fine-tuning, decoding, and evaluation take ~1 hour on two A6000 GPUs.
All the scripts above are set up for 2 GPUs, but that can be changed with the `ngpu` argument. Please make sure to change the batch size accordingly (the provided setup requires 42.6 GB of GPU memory).
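
For example, a hedged single-GPU launch (illustrative values; assumes run.sh forwards extra flags to asr.sh, as ESPnet recipes typically do):

```
# Smaller batches fit in less GPU memory; a larger accum_grad keeps the
# effective batch size (batch_size x accum_grad x ngpu) comparable.
./run.sh --ngpu 1 --asr_args "--batch_size 32 --accum_grad 4"
```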

<!-- Generated by scripts/utils/show_asr_result.sh -->
# RESULTS
## Environments
- date: `Fri Nov 29 20:06:53 EST 2024`
- python version: `3.9.20 (main, Oct 3 2024, 07:27:41) [GCC 11.2.0]`
- espnet version: `espnet 202409`
- pytorch version: `pytorch 2.4.0`
- Git hash: `65ea259e8effab5a43cdff87161a301dc0f20930`
- Commit date: `Fri Nov 29 10:54:44 2024 -0500`

<!-- Copied from the output produced by local/evaluation.py -->
## exp/asr_ft
```
=====================================================
Split: evaluation Evaluation over 1045 predictions.
=====================================================
cider_d : 0.46045061153488653
spice : 0.1345877073651595
spider : 0.297519159450023
sbert_sim : 0.5147198918907185
fer : 0.019138755980861243
fense : 0.5048554784476391
meteor : 0.18487611036251259
rouge_l : 0.39408182293804006
fer.add_tail_prob : 0.046671394258737564
fer.repeat_event_prob: 0.07274453341960907
fer.repeat_adv_prob : 0.0019690715707838535
fer.remove_conj_prob: 0.12462866306304932
fer.remove_verb_prob: 0.22472403943538666
fer.error_prob : 0.34762266278266907
spider_fl : 0.2923702923078383
=====================================================
```
@@ -0,0 +1 @@
../../TEMPLATE/asr1/asr.sh
@@ -0,0 +1,110 @@
# ====== About run.pl, queue.pl, slurm.pl, and ssh.pl ======
# Usage: <cmd>.pl [options] JOB=1:<nj> <log> <command...>
# e.g.
# run.pl --mem 4G JOB=1:10 echo.JOB.log echo JOB
#
# Options:
# --time <time>: Limit the maximum time to execute.
# --mem <mem>: Limit the maximum memory usage.
# --max-jobs-run <njob>: Limit the number of parallel jobs. This is ignored for non-array jobs.
# --num-threads <ngpu>: Specify the number of CPU cores.
# --gpu <ngpu>: Specify the number of GPU devices.
# --config: Change the configuration file from the default.
#
# "JOB=1:10" is used for "array jobs" and it can control the number of parallel jobs.
# The left string of "=", i.e. "JOB", is replaced by <N> (the Nth job) in the command and the log file name,
# e.g. "echo JOB" is changed to "echo 3" for the 3rd job and "echo 8" for the 8th job, respectively.
# Note that the range must start with a positive number, so you can't use "JOB=0:10", for example.
#
# run.pl, queue.pl, slurm.pl, and ssh.pl have a unified interface that does not depend on the backend.
# These options are mapped to backend-specific options, as configured
# in "conf/queue.conf" and "conf/slurm.conf" by default.
# If jobs fail, your configuration might be wrong for your environment.
#
#
# The official documentation for run.pl, queue.pl, slurm.pl, and ssh.pl:
# "Parallelization in Kaldi": http://kaldi-asr.org/doc/queue.html
# ==========================================================

# Select the backend used by run.sh from "local", "stdout", "sge", "slurm", or "ssh"
cmd_backend='local'

# Local machine, without any job scheduling system
if [ "${cmd_backend}" = local ]; then

    # The other usage
    export train_cmd="run.pl"
    # Used for "*_train.py": "--gpu" is appended optionally by run.sh
    export cuda_cmd="run.pl"
    # Used for "*_recog.py"
    export decode_cmd="run.pl"

# Local machine, logging to stdout and a log file, without any job scheduling system
elif [ "${cmd_backend}" = stdout ]; then

    # The other usage
    export train_cmd="stdout.pl"
    # Used for "*_train.py": "--gpu" is appended optionally by run.sh
    export cuda_cmd="stdout.pl"
    # Used for "*_recog.py"
    export decode_cmd="stdout.pl"

# "qsub" (Sun Grid Engine, or a derivative of it)
elif [ "${cmd_backend}" = sge ]; then
    # The default setting is written in conf/queue.conf.
    # You must change "-q g.q" to the "queue" for your environment.
    # To know the "queue" names, type "qhost -q".
    # Note that to use "--gpu *", you have to set up "complex_value" for the system scheduler.

    export train_cmd="queue.pl"
    export cuda_cmd="queue.pl"
    export decode_cmd="queue.pl"

# "qsub" (Torque/PBS)
elif [ "${cmd_backend}" = pbs ]; then
    # The default setting is written in conf/pbs.conf.

    export train_cmd="pbs.pl"
    export cuda_cmd="pbs.pl"
    export decode_cmd="pbs.pl"

# "sbatch" (Slurm)
elif [ "${cmd_backend}" = slurm ]; then
    # The default setting is written in conf/slurm.conf.
    # You must change "-p cpu" and "-p gpu" to the "partition" names for your environment.
    # To know the "partition" names, type "sinfo".
    # You can use "--gpu *" by default for slurm; it is interpreted as "--gres gpu:*".
    # The devices are allocated exclusively using "${CUDA_VISIBLE_DEVICES}".

    export train_cmd="slurm.pl"
    export cuda_cmd="slurm.pl"
    export decode_cmd="slurm.pl"

elif [ "${cmd_backend}" = ssh ]; then
    # You have to create ".queue/machines" to specify the hosts on which to execute jobs.
    # e.g. .queue/machines
    #   host1
    #   host2
    #   host3
    # Assuming you can log in to them without a password, i.e. you have to set up ssh keys.

    export train_cmd="ssh.pl"
    export cuda_cmd="ssh.pl"
    export decode_cmd="ssh.pl"

# This is an example of specifying several unique options for the JHU CLSP cluster setup.
# Users can modify/add their own command options according to their cluster environments.
elif [ "${cmd_backend}" = jhu ]; then

    export train_cmd="queue.pl --mem 2G"
    export cuda_cmd="queue-freegpu.pl --mem 2G --gpu 1 --config conf/queue.conf"
    export decode_cmd="queue.pl --mem 4G"

else
    echo "$0: Error: Unknown cmd_backend=${cmd_backend}" 1>&2
    return 1
fi
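
The header above documents the unified interface; as a quick, hedged illustration (the log paths and commands are made up, not part of the recipe):

```
. ./cmd.sh

# Four array jobs in parallel: "JOB" expands to 1..4 in both the command
# and the log file name (exp/log/demo.1.log ... exp/log/demo.4.log).
${train_cmd} JOB=1:4 exp/log/demo.JOB.log echo "running job JOB"

# One job with one GPU; run.sh appends "--gpu" to ${cuda_cmd} in the same way.
${cuda_cmd} --gpu 1 exp/log/gpu_check.log nvidia-smi
```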
@@ -0,0 +1,16 @@
{
    "vocab_size": 50265,
    "max_position_embeddings": 256,
    "decoder_layers": 6,
    "decoder_ffn_dim": 3072,
    "decoder_attention_heads": 16,
    "d_model": 768,
    "decoder_layerdrop": 0.0,
    "activation_function": "gelu",
    "dropout": 0.1,
    "attention_dropout": 0.1,
    "activation_dropout": 0.1,
    "scale_embedding": false,
    "use_cache": true,
    "ignore_mismatched_sizes": true
}
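
As the training configs note below, this is not the stock facebook/bart-base decoder configuration (for example, it uses 16 attention heads and 256 position embeddings instead of the stock 12 and 1024). A hedged sketch for inspecting exactly what the override changes, assuming the transformers package that ESPnet already depends on:

```
# Print every field where this override differs from stock facebook/bart-base.
python -c "
import json
from transformers import BartConfig

stock = BartConfig.from_pretrained('facebook/bart-base').to_dict()
override = json.load(open('conf/bart_decoder_config.json'))
for key, value in override.items():
    if stock.get(key) != value:
        print(f'{key}: {stock.get(key)} -> {value}')
"
```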
@@ -0,0 +1,85 @@
batch_type: folded
unused_parameters: true
batch_size: 64
accum_grad: 2
max_epoch: 8
grad_clip: 1
patience: none
best_model_criterion:
- - valid
  - acc
  - max
keep_nbest_models: 5
use_amp: false # whether to use automatic mixed precision
num_att_plot: 0
num_workers: 2 # dataloader workers

# The BEATs implementation takes care of generating the mel spectrogram, normalization, and SpecAug
frontend: none
input_size: 1 # important to set input_size to 1 if frontend is none
normalize: none # BEATs code does global mean and variance normalization

freeze_param: [
    "encoder.encoder",
    "encoder.layer_norm",
    "encoder.patch_embedding",
    "encoder.post_extract_proj",
]

encoder: beats
encoder_conf:
    # Please download the BEATs model from https://github.com/microsoft/unilm/tree/master/beats
    # (iter3+, BEATs fine-tuned model 1) and update the path below
    beats_ckpt_path: /compute/babel-13-33/sbharad2/models/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1.pt
    specaug_config:
        apply_freq_mask: true
        freq_mask_width_range:
        - 0
        - 64
        num_freq_mask: 2
        apply_time_mask: true
        time_mask_width_ratio_range:
        - 0
        - 0.12
        num_time_mask: 5
    adapter_config: conf/wav2vec2_conformer_config.json
    downsampling_rate: 3 # CNN downsampling over the BEATs encoder
    max_layer: 10 # 0-based index
    use_weighted_representation: false
    add_positional_information: true
    max_positions: 1024 # this many positional embeddings will be learned

# Please note that the decoder is not the standard BART-base,
# but a custom one whose config is defined in the file below
decoder: hugging_face_transformers
decoder_conf:
    model_name_or_path: facebook/bart-base
    overriding_architecture_config: conf/bart_decoder_config.json
    load_pretrained_weights: false
    separate_lm_head: true

# Initialization does not matter; we use a pre-trained model
init: normal

token_type: hugging_face

# Loss, optimizer, scheduler
model_conf:
    ctc_weight: 0.0 # no CTC, only the attention branch
    lsm_weight: 0.1 # label smoothing weight
    length_normalized_loss: true
    # BART dictionary customizations
    ignore_id: 1
    sym_blank: "<pad>"
    sym_sos: "<s>"
    sym_eos: "</s>"
    sym_space: "Ġ"

optim: adamw
optim_conf:
    lr: 0.00002 # 2e-5
    weight_decay: 0.001 # 1e-3

scheduler: warmuplr
scheduler_conf:
    warmup_steps: 1000 # 1k
@@ -0,0 +1,88 @@
batch_type: folded
unused_parameters: true
batch_size: 64
accum_grad: 1
max_epoch: 10
grad_clip: 1
patience: none
best_model_criterion:
- - valid
  - acc
  - max
keep_nbest_models: 5
use_amp: false # whether to use automatic mixed precision
num_att_plot: 0
num_workers: 2 # number of workers in the dataloader

# The BEATs implementation takes care of generating the mel spectrogram, normalization, and SpecAug
frontend: none
input_size: 1 # important to set input_size to 1 if frontend is none
normalize: none # BEATs code does global mean and variance normalization

freeze_param: [
    "encoder.encoder",
    "encoder.layer_norm",
    "encoder.patch_embedding",
    "encoder.post_extract_proj",
]

encoder: beats
encoder_conf:
    # Please download the BEATs model from https://github.com/microsoft/unilm/tree/master/beats
    # (iter3+, BEATs fine-tuned model 1) and update the path below
    beats_ckpt_path: /compute/babel-13-33/sbharad2/models/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1.pt
    specaug_config:
        apply_freq_mask: true
        freq_mask_width_range:
        - 0
        - 64
        num_freq_mask: 2
        apply_time_mask: true
        time_mask_width_ratio_range:
        - 0
        - 0.12
        num_time_mask: 5
    adapter_config: conf/wav2vec2_conformer_config.json
    downsampling_rate: 3 # CNN downsampling over the BEATs encoder
    max_layer: 10 # 0-based index
    use_weighted_representation: false
    add_positional_information: true
    max_positions: 1024 # this many positional embeddings will be learned

# Please note that the decoder is not the standard BART-base,
# but a custom one whose config is defined in the file below
decoder: hugging_face_transformers
decoder_conf:
    model_name_or_path: facebook/bart-base
    overriding_architecture_config: conf/bart_decoder_config.json
    load_pretrained_weights: false
    separate_lm_head: true

# This takes care of the BART and adapter layers' initialization
# (essentially everything other than the BEATs encoder backbone)
init: normal

token_type: hugging_face

# Loss, optimizer, scheduler
model_conf:
    ctc_weight: 0.0 # no CTC, only the attention branch
    lsm_weight: 0.1 # label smoothing weight
    length_normalized_loss: true # length normalization
    # BART dictionary customizations
    ignore_id: 1
    sym_blank: "<pad>"
    sym_sos: "<s>"
    sym_eos: "</s>"
    sym_space: "Ġ"

optim: adamw
optim_conf:
    lr: 0.0002 # 2e-4
    weight_decay: 0.001

scheduler: warmuplr
scheduler_conf:
    warmup_steps: 2500 # 2.5k
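
The two configs above correspond to the two training stages: `conf/beats_bart_pt.yaml` (higher learning rate, longer warmup) for AAC pre-training, and the fine-tuning config (lower learning rate; name inferred from the `exp/asr_ft` results directory) for Clotho_v2. run.sh normally chains the stages; a hedged sketch of invoking them directly, where the asr.sh flag names and checkpoint file name are assumptions modeled on TEMPLATE/asr1/asr.sh:

```
# Stage 1: AAC pre-training on AudioCaps + Clotho ChatGPT mixup.
./asr.sh --asr_config conf/beats_bart_pt.yaml --asr_exp exp/asr_pt --ngpu 2

# Stage 2: fine-tuning on Clotho_v2, warm-started from the averaged
# pre-training checkpoint (checkpoint name is illustrative).
./asr.sh --asr_config conf/beats_bart_ft.yaml --asr_exp exp/asr_ft --ngpu 2 \
    --pretrained_model exp/asr_pt/valid.acc.ave.pth
```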
@@ -0,0 +1,2 @@
--sample-frequency=16000
--num-mel-bins=80