Fine-tuning BEATs on AudioSet-2M multi-label classification #6006

Draft: wants to merge 57 commits into base: master

Changes shown below are from 1 commit of the 57. Commit history:
5d70791  take data normalization via config  (Shikhar-S, Dec 10, 2024)
23914db  add linear decoder for classification tasks  (Shikhar-S, Dec 10, 2024)
225a26f  add 5-fold data prep  (Shikhar-S, Dec 10, 2024)
76f17d5  add config for beats fine-tuning on esc  (Shikhar-S, Dec 10, 2024)
6b8e3d2  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 10, 2024)
468f817  cleanup  (Shikhar-S, Dec 10, 2024)
6a47f63  add readme template  (Shikhar-S, Dec 10, 2024)
fa5f449  Merge branch 'esc' of github.com:Shikhar-S/espnet into esc  (Shikhar-S, Dec 10, 2024)
02bcb79  add results  (Shikhar-S, Dec 10, 2024)
8d3bb3b  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 10, 2024)
fd36ebe  restore default slurm config  (Shikhar-S, Dec 10, 2024)
9148500  add files for cls task  (Shikhar-S, Dec 10, 2024)
c167bd3  fix beats test  (Shikhar-S, Dec 10, 2024)
4e07464  Merge branch 'esc' of github.com:Shikhar-S/espnet into esc  (Shikhar-S, Dec 10, 2024)
4756126  add more test for linear decoder  (Shikhar-S, Dec 11, 2024)
56009b9  add roll augmentation  (Shikhar-S, Dec 11, 2024)
6dcc151  Merge branch 'esc' into cls  (Shikhar-S, Dec 11, 2024)
cb80e61  add quantized rolling, separate vocab concerns in linear decoder  (Shikhar-S, Dec 13, 2024)
b897137  add quantized rolling, separate vocab concerns in linear decoder  (Shikhar-S, Dec 13, 2024)
ca9ca1e  add dropout to linear decoder, clean up  (Shikhar-S, Dec 15, 2024)
22a3c82  cleanup, add model links  (Shikhar-S, Dec 15, 2024)
da315fe  clean up decoder  (Shikhar-S, Dec 15, 2024)
e714871  add inferene and training files for classification  (Shikhar-S, Dec 15, 2024)
0e5a217  clean up decoder  (Shikhar-S, Dec 15, 2024)
452291f  add dropout to linear decoder, clean up  (Shikhar-S, Dec 15, 2024)
e3881fc  cleanup, add model links  (Shikhar-S, Dec 15, 2024)
409893d  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 15, 2024)
16728b9  fix lin dec test  (Shikhar-S, Dec 15, 2024)
6f1a151  fix beats test, the first version had unnecessary dimension always se…  (Shikhar-S, Dec 15, 2024)
bbd2127  Merge branch 'esc' of github.com:Shikhar-S/espnet into esc  (Shikhar-S, Dec 15, 2024)
54b2706  merge esc  (Shikhar-S, Dec 15, 2024)
83fa462  unstable: saving work from babel, initial template for cls task with …  (Shikhar-S, Dec 25, 2024)
c836ede  Merge branch 'espnet:master' into cls  (Shikhar-S, Dec 27, 2024)
77ad502  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 27, 2024)
d14d6e2  fix cls and setup shells  (Shikhar-S, Dec 27, 2024)
78ae554  add tests for classification metrics  (Shikhar-S, Dec 28, 2024)
bf732ce  merge remote  (Shikhar-S, Dec 28, 2024)
ed25ca7  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 28, 2024)
7a83600  add scoring scripts  (Shikhar-S, Dec 29, 2024)
6cbd753  add mixup augmentation  (Shikhar-S, Dec 29, 2024)
7447d30  add script to show results  (Shikhar-S, Dec 29, 2024)
6aed560  Merge branch 'cls' of github.com:Shikhar-S/espnet into cls  (Shikhar-S, Dec 29, 2024)
bab37d4  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 29, 2024)
4a754da  add test for mixup, change mixup rate  (Shikhar-S, Jan 2, 2025)
018e5bd  Merge branch 'cls' of github.com:Shikhar-S/espnet into cls  (Shikhar-S, Jan 2, 2025)
5a3e173  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jan 2, 2025)
3c371a8  add recipe for AudioSet-20k  (Shikhar-S, Jan 2, 2025)
635b3ad  Merge branch 'cls' of github.com:Shikhar-S/espnet into cls  (Shikhar-S, Jan 2, 2025)
170b673  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jan 2, 2025)
2a6ec9b  tune mixup, use full training data  (Shikhar-S, Jan 4, 2025)
79329f1  Merge branch 'cls' of github.com:Shikhar-S/espnet into cls  (Shikhar-S, Jan 4, 2025)
b1bf3e7  add audio filtering stage  (Shikhar-S, Jan 4, 2025)
5ad1a95  add option to truncate long audio and repeat small audio, fix ci  (Shikhar-S, Jan 5, 2025)
9508135  add packing and hf upload stages  (Shikhar-S, Jan 5, 2025)
9634114  fix ci  (Shikhar-S, Jan 5, 2025)
ad72b4e  add skeleton code for AudioSet-2M finetuning  (Shikhar-S, Jan 5, 2025)
d41e3c7  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jan 5, 2025)
add linear decoder for classification tasks
Shikhar-S committed Dec 10, 2024
commit 23914db1bf1b6098f79af615b2aaa7cd80fc7eb2
83 changes: 83 additions & 0 deletions espnet2/asr/decoder/linear_decoder.py
@@ -0,0 +1,83 @@
"""A simple linear layer decoder.

This can be used for classification tasks from sequence input.
"""

from typing import Optional

import torch
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from typeguard import typechecked


class LinearDecoder(AbsDecoder):
    """Pool encoder states and project them to class logits with one linear layer."""

    @typechecked
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        pooling: str = "CLS",
    ):
        """Initialize the module."""
        super().__init__()

        self.input_dim = encoder_output_size
        self.output_dim = vocab_size
        self.linear_out = torch.nn.Linear(self.input_dim, self.output_dim)
        assert pooling in [
            "mean",
            "max",
            "CLS",
        ], f"Invalid pooling: {pooling}. Should be 'mean', 'max' or 'CLS'."
        self.pooling = pooling

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: Optional[torch.Tensor] = None,
        ys_in_lens: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pool the padded encoder output and compute class logits.

        Args:
            hs_pad: (B, Tmax, D)
            hlens: (B,)
            ys_in_pad, ys_in_lens: unused; kept for interface compatibility.
        Returns:
            output: (B, n_classes)
        """
        mask = make_pad_mask(lengths=hlens, xs=hs_pad, length_dim=1).to(hs_pad.device)
        if self.pooling == "mean":
            unmasked_entries = (~mask).to(dtype=hs_pad.dtype)
            input_feature = (hs_pad * unmasked_entries).sum(dim=1)
            input_feature = input_feature / unmasked_entries.sum(dim=1)
        elif self.pooling == "max":
            input_feature = hs_pad.masked_fill(mask, float("-inf"))
            input_feature, _ = torch.max(input_feature, dim=1)
        elif self.pooling == "CLS":
            # Use the first (CLS-like) frame as the utterance representation.
            input_feature = hs_pad[:, 0, :]

        output = self.linear_out(input_feature)  # Get logits, (B, vocab_size)

        # Fix blank, unk and sos/eos logits to -inf.
        # This ensures that they are never selected at inference.
        output[:, 0] = float("-inf")
        output[:, 1] = float("-inf")
        output[:, -1] = float("-inf")
        return output

    def score(self, ys, state, x):
        """Classify x: return per-class log-probabilities and a None state."""
        hs_len = torch.tensor([x.shape[0]], dtype=torch.long).to(x.device)
        logits = self.forward(
            x.unsqueeze(0),
            hs_len,
        )
        logp = torch.nn.functional.log_softmax(logits, dim=-1)
        return logp.squeeze(0), None

    def output_size(self) -> int:
        """Get the output size."""
        return self.output_dim
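
For orientation, here is a minimal usage sketch of the decoder added above, outside a full ESPnet model. It is not part of the diff; the batch size, lengths, and the 768-dim encoder output are illustrative assumptions.

import torch

from espnet2.asr.decoder.linear_decoder import LinearDecoder

# Hypothetical sizes: 10-entry vocabulary (incl. blank/unk/sos-eos), 768-dim encoder states.
decoder = LinearDecoder(vocab_size=10, encoder_output_size=768, pooling="mean")

hs_pad = torch.randn(4, 50, 768)        # (B, Tmax, D) padded encoder output
hlens = torch.tensor([50, 42, 30, 17])  # (B,) true lengths before padding

logits = decoder(hs_pad, hlens)                 # (B, 10); indices 0, 1 and -1 are forced to -inf
probs = torch.softmax(logits[:, 2:-1], dim=-1)  # distribution over the real classes only
print(logits.shape, probs.sum(dim=-1))          # probabilities sum to 1 per utterance
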
44 changes: 43 additions & 1 deletion espnet2/asr/espnet_model.py
@@ -8,6 +8,7 @@

from espnet2.asr.ctc import CTC
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.asr.decoder.linear_decoder import LinearDecoder
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder
@@ -110,6 +111,7 @@ def __init__(
)

self.use_transducer_decoder = joint_network is not None
self.use_linear_decoder = isinstance(decoder, LinearDecoder)

self.error_calculator = None

@@ -155,6 +157,12 @@ def __init__(
self.error_calculator = ErrorCalculator(
token_list, sym_space, sym_blank, report_cer, report_wer
)
elif self.use_linear_decoder:
assert ctc_weight == 0.0, "CTC is not supported with LinearDecoder."
self.decoder = decoder
self.criterion_classif = torch.nn.CrossEntropyLoss(
ignore_index=ignore_id, label_smoothing=lsm_weight
)
else:
# we set self.decoder = None in the CTC mode since
# self.decoder parameters were never used and PyTorch complained
@@ -243,6 +251,7 @@ def forward(
loss_att, acc_att, cer_att, wer_att = None, None, None, None
loss_ctc, cer_ctc = None, None
loss_transducer, cer_transducer, wer_transducer = None, None, None
loss_classif, acc_classif = None, None
stats = dict()

# 1. CTC branch
@@ -325,8 +334,13 @@ def forward(
stats["cer_transducer"] = cer_transducer
stats["wer_transducer"] = wer_transducer

elif self.use_linear_decoder:
# 2b. Linear decoder branch for classification tasks
loss, acc = self._calc_classif_loss(encoder_out, encoder_out_lens, text)
stats["loss"] = loss
stats["acc"] = acc
else:
# 2b. Attention decoder branch
# 2c. Attention decoder branch
if self.ctc_weight != 1.0:
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
@@ -672,3 +686,31 @@ def _calc_batch_ctc_loss(
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
        self.ctc.reduce = do_reduce
        return loss_ctc

    def _calc_classif_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        labels: torch.Tensor,
    ):
        """Compute classification loss and accuracy.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            encoder_out_lens: Encoder output sequence lengths. (B,)
            labels: Label IDs. (B, 1)
        Returns:
            loss_classif: Classification loss value.
            acc_classif: Classification accuracy.
        """
        assert labels.dim() == 2, labels.shape
        assert labels.shape[1] == 1, labels.shape
        logits = self.decoder(encoder_out, encoder_out_lens)  # (B, n_class + 3)
        # Keep only the real class logits; drop blank (0), unk (1) and sos/eos (-1).
        logits = logits[:, 2:-1]  # (B, n_class)
        assert logits.shape[1] == self.vocab_size - 3, logits.shape
        # Shift labels down by 2 so they index into the class-only logits
        # (blank and unk occupy indices 0 and 1 of the full vocabulary).
        labels = labels - 2
        loss_classif = self.criterion_classif(logits, labels.squeeze(-1))
        acc_classif = th_accuracy(logits, labels, ignore_label=self.ignore_id)
        return loss_classif, acc_classif
24 changes: 14 additions & 10 deletions espnet2/bin/asr_inference.py
@@ -65,6 +65,10 @@
]
]

logger = logging.getLogger(__name__)
# NOTE(shikhar): We use contextual logging here because
# RTF calculation looks for "INFO: " as a prefix in the logs.


class Speech2Text:
"""Speech2Text class
@@ -157,7 +161,7 @@ def __init__(
asr_model.to(dtype=getattr(torch, dtype)).eval()

if quantize_asr_model:
logging.info("Use quantized asr model for decoding.")
logger.info("Use quantized asr model for decoding.")

asr_model = torch.quantization.quantize_dynamic(
asr_model, qconfig_spec=qconfig_spec, dtype=quantize_dtype
@@ -180,7 +184,7 @@ def __init__(
)

if quantize_lm:
logging.info("Use quantized lm for decoding.")
logger.info("Use quantized lm for decoding.")

lm = torch.quantization.quantize_dynamic(
lm, qconfig_spec=qconfig_spec, dtype=quantize_dtype
@@ -337,7 +341,7 @@ def __init__(
raise NotImplementedError(
"BeamSearchTimeSync with batching is not yet supported."
)
logging.info("BeamSearchTimeSync implementation is selected.")
logger.info("BeamSearchTimeSync implementation is selected.")

scorers["ctc"] = asr_model.ctc
beam_search = BeamSearchTimeSync(
@@ -371,14 +375,14 @@ def __init__(
if streaming:
beam_search.__class__ = BatchBeamSearchOnlineSim
beam_search.set_streaming_config(asr_train_config)
logging.info(
logger.info(
"BatchBeamSearchOnlineSim implementation is selected."
)
else:
beam_search.__class__ = BatchBeamSearch
logging.info("BatchBeamSearch implementation is selected.")
logger.info("BatchBeamSearch implementation is selected.")
else:
logging.warning(
logger.warning(
f"As non-batch scorers {non_batch} are found, "
f"fall back to non-batch implementation."
)
@@ -387,8 +391,8 @@ def __init__(
for scorer in scorers.values():
if isinstance(scorer, torch.nn.Module):
scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
logging.info(f"Beam_search: {beam_search}")
logging.info(f"Decoding device={device}, dtype={dtype}")
logger.info(f"Beam_search: {beam_search}")
logger.info(f"Decoding device={device}, dtype={dtype}")

# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
@@ -466,7 +470,7 @@ def __init__(
beam_search.set_hyp_primer(
list(converter.tokenizer.tokenizer.convert_tokens_to_ids(a1))
)
logging.info(f"Text tokenizer: {tokenizer}")
logger.info(f"Text tokenizer: {tokenizer}")

self.asr_model = asr_model
self.asr_train_args = asr_train_args
@@ -513,7 +517,7 @@ def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> Union[
# lengths: (1,)
lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
batch = {"speech": speech, "speech_lengths": lengths}
logging.info("speech length: " + str(speech.size(1)))
logger.info("speech length: " + str(speech.size(1)))

# a. To device
batch = to_device(batch, device=self.device)
2 changes: 2 additions & 0 deletions espnet2/tasks/asr.py
@@ -11,6 +11,7 @@
from espnet2.asr.decoder.hugging_face_transformers_decoder import ( # noqa: H301
HuggingFaceTransformersDecoder,
)
from espnet2.asr.decoder.linear_decoder import LinearDecoder
from espnet2.asr.decoder.mlm_decoder import MLMDecoder
from espnet2.asr.decoder.rnn_decoder import RNNDecoder
from espnet2.asr.decoder.s4_decoder import S4Decoder
@@ -189,6 +190,7 @@
whisper=OpenAIWhisperDecoder,
hugging_face_transformers=HuggingFaceTransformersDecoder,
s4=S4Decoder,
linear_decoder=LinearDecoder,
),
type_check=AbsDecoder,
default=None,
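
With LinearDecoder registered under the name linear_decoder above, a model configuration can select it like any other decoder. A sketch of that resolution in Python follows; the decoder_choices object and its get_class call are assumed from how espnet2 task classes usually expose their ClassChoices tables, and the 527 + 3 vocabulary size for AudioSet is illustrative.

from espnet2.asr.decoder.linear_decoder import LinearDecoder
from espnet2.tasks.asr import decoder_choices  # assumed module-level ClassChoices instance

decoder_cls = decoder_choices.get_class("linear_decoder")
assert decoder_cls is LinearDecoder

# e.g. 527 AudioSet classes + blank/unk/sos-eos, with a BEATs-sized encoder output.
decoder = decoder_cls(vocab_size=530, encoder_output_size=768, pooling="mean")
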
16 changes: 16 additions & 0 deletions test/espnet2/asr/decoder/test_linear_decoder.py
@@ -0,0 +1,16 @@
import pytest
import torch
from espnet2.asr.decoder.linear_decoder import LinearDecoder


@pytest.mark.execution_timeout(30)
@pytest.mark.parametrize("vocab_size", [10, 5])
@pytest.mark.parametrize("encoder_output_size", [4, 21])
@pytest.mark.parametrize("pooling", ["mean", "max", "CLS"])
def test_LinearDecoder_forward_backward(vocab_size, encoder_output_size, pooling):
    decoder = LinearDecoder(vocab_size, encoder_output_size, pooling)
    x = torch.randn(2, 10, encoder_output_size, requires_grad=True)
    x_len = torch.randint(1, 10, [2], dtype=torch.long)
    logits = decoder(x, x_len)
    assert logits.shape == (2, vocab_size), logits.shape
    logits.sum().backward()
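
A complementary check of the score() interface, reusing the imports above, could look like the following sketch; it is not part of this commit.

@pytest.mark.parametrize("pooling", ["mean", "max", "CLS"])
def test_LinearDecoder_score(pooling):
    # Hypothetical companion test: score() should return per-class log-probabilities
    # for a single unbatched sequence of encoder states, plus a None state.
    decoder = LinearDecoder(vocab_size=10, encoder_output_size=8, pooling=pooling)
    x = torch.randn(12, 8)  # (T, D)
    logp, state = decoder.score(ys=None, state=None, x=x)
    assert logp.shape == (10,)
    assert state is None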