
decoupled encoders cfg for mixed speech model
vlad-karpukhin committed Mar 26, 2021
1 parent c385c55 commit 00b3fba
Showing 6 changed files with 169 additions and 117 deletions.
24 changes: 24 additions & 0 deletions conf/encoder/fairseq_roberta.yaml
@@ -0,0 +1,24 @@
# @package _group_

# model type. One of [hf_bert, pytext_bert, fairseq_roberta]
encoder_model_type: fairseq_roberta

# this is only used by HF
pretrained_model_cfg: roberta-base

# Some encoders need to be initialized from a file
pretrained_file: /private/home/vladk/data/fairseq_checkpoints/roberta.base/

# Extra linear layer on top of standard bert/roberta encoder
projection_dim: 0

# Max length of the encoder input sequence
sequence_length: 256

dropout: 0.1

# whether to fix (don't update) context encoder during training or not
fix_ctx_encoder: False

# if False, the model won't load pre-trained BERT weights
pretrained: True
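
For quick reference, the new encoder config group above can be loaded outside of Hydra with OmegaConf (the library Hydra builds on). This is a minimal sketch, assuming omegaconf is installed and the file sits at conf/encoder/fairseq_roberta.yaml as shown; the variable names are illustrative:

from omegaconf import OmegaConf

# Load the YAML config group directly and read a few of its fields.
cfg = OmegaConf.load("conf/encoder/fairseq_roberta.yaml")
assert cfg.encoder_model_type == "fairseq_roberta"
print(cfg.pretrained_model_cfg, cfg.sequence_length)  # -> roberta-base 256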
65 changes: 28 additions & 37 deletions conf/encoder/speech_mixed.yaml
@@ -1,54 +1,45 @@

# @package _group_

# model type. One of [mixed_hf_bert_wav2vec, mixed_hf_bert_hubert]
# CHANGE THIS to use a different encoder!
encoder_model_type: mixed_hf_bert_wav2vec
# encoder_model_type: mixed_hf_bert_hubert

# HuggingFace's config name for model initialization
pretrained_model_cfg: bert-base-uncased
encoder_model_type: mixed_audio

# Some encoders need to be initialized from a file
pretrained_file:

# Which layer representation to use
output_layer:
# ------------ QUERY ENCODER ------------
q_encoder_type: # hf-wav2vec, fairseq-wav2vec, or fairseq-hubert

# Extra linear layer on top of standard bert/roberta encoder
projection_dim: 0

# Max length of the encoder input sequence
sequence_length: 256
# HF only params
q_wav2vec_model_cfg: #facebook/wav2vec2-base-960h

dropout: 0.1
# fairseq only params
q_wav2vec_cp_file: #/checkpoint/vladk/speechqa/wav2vec_small_960h.pt
q_wav2vec_apply_mask: True
q_output_layer: # Which layer representation to use

# whether to fix (don't update) context encoder during training or not
fix_ctx_encoder: False
q_projection_dim: 768 # Extra linear layer on top of pre-trained encoder
q_dropout: 0.1
q_use_activation: False
q_max_audio_t: 300
q_audio_encoder_lr_factor: 0

# if False, the model won't load pre-trained BERT weights
pretrained: True
# ------------ CTX ENCODER ------------
ctx_encoder_type: # hf-bert or fairseq-roberta

# fairseq only params
ctx_pretrained_file: # /private/home/vladk/data/fairseq_checkpoints/roberta.base/

# HF params
pretrained_wav2vec_model_cfg:
#facebook/wav2vec2-base-960h
ctx_model_cfg: bert-base-uncased # roberta-base
ctx_projection_dim: 0 # Extra linear layer on top of pre-trained encoder
ctx_sequence_length: 256 # Max length of the encoder input sequence
ctx_dropout: 0.1
ctx_pretrained: True # if False, the model won't load pre-trained BERT weights

wav2_vec_extra_proj_dim: 768 # TODO: make a common param
wav2vec_dropout: 0.1

# fairseq params
# CHANGE THIS to use a different encoder!
wav2vec_cp_file: /checkpoint/vladk/speechqa/wav2vec_small.pt
# non finetuned
# wav2vec_cp_file: /checkpoint/vladk/speechqa/wav2vec_small.pt
# -------------- COMMON -------------------

wav2vec_apply_mask: True
# whether to fix (don't update) context encoder during training or not
fix_ctx_encoder: False

#TODO: move to train config group?
optimizer: hf-adam # fairseq-adam

# wav2vec common params
wav2vec_max_audio_t: 300
wav2vec_use_activation: False

#TODO: move to train cfg group
audio_encoder_lr_factor: 0
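
The main change in this file is the split of one shared encoder block into q_*-prefixed fields for the audio question encoder and ctx_*-prefixed fields for the text context encoder, selected by the new q_encoder_type and ctx_encoder_type options. The sketch below only illustrates how such a decoupled config can be dispatched on; select_encoders and its return values are hypothetical, not the repository's actual factory code:

# Hypothetical illustration: the two encoder sides are now configured independently.
def select_encoders(cfg):
    if cfg.q_encoder_type == "hf-wav2vec":
        q_source = ("huggingface", cfg.q_wav2vec_model_cfg)
    else:  # "fairseq-wav2vec" or "fairseq-hubert"
        q_source = ("fairseq", cfg.q_wav2vec_cp_file)

    if cfg.ctx_encoder_type == "fairseq-roberta":
        ctx_source = ("fairseq", cfg.ctx_pretrained_file)
    else:  # "hf-bert"
        ctx_source = ("huggingface", cfg.ctx_model_cfg)
    return q_source, ctx_source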
5 changes: 2 additions & 3 deletions dpr/models/__init__.py
@@ -57,7 +57,7 @@ def init_hf_roberta_tenzorizer(args, **kwargs):
raise RuntimeError("Please install transformers lib")
from .hf_models import get_roberta_tensorizer

return get_roberta_tensorizer(args)
return get_roberta_tensorizer(args.encoder.pretrained_model_cfg, args.do_lower_case, args.encoder.sequence_length)


def init_audio_mixed_biencoder_components(args, **kwargs):
@@ -74,8 +74,7 @@ def init_audio_mixed_biencoder_components(args, **kwargs):
"hf_bert": init_hf_bert_biencoder,
"pytext_bert": init_pytext_bert_biencoder,
"fairseq_roberta": init_fairseq_roberta_biencoder,
"mixed_hf_bert_wav2vec": init_audio_mixed_biencoder_components,
"mixed_hf_bert_hubert": init_audio_mixed_biencoder_components,
"mixed_audio": init_audio_mixed_biencoder_components,
}

READER_INITIALIZERS = {
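
With this change the two per-audio-model keys (mixed_hf_bert_wav2vec, mixed_hf_bert_hubert) collapse into a single mixed_audio entry, and the concrete audio/text encoder choice moves into conf/encoder/speech_mixed.yaml. A minimal dispatch sketch, assuming the mapping containing the new mixed_audio key is the biencoder initializer registry (written as BIENCODER_INITIALIZERS here for illustration) and that initializers return (tensorizer, biencoder, optimizer) as elsewhere in this commit:

# Illustration: encoder_model_type from the config selects the initializer.
initializer = BIENCODER_INITIALIZERS["mixed_audio"]  # i.e. cfg.encoder.encoder_model_type
tensorizer, biencoder, optimizer = initializer(cfg)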
51 changes: 34 additions & 17 deletions dpr/models/fairseq_models.py
@@ -8,7 +8,7 @@
"""
Encoder model wrappers based on Fairseq code
"""

import collections
import logging
from typing import Tuple

@@ -20,26 +20,29 @@

import fairseq
from dpr.models.hf_models import get_roberta_tensorizer
from dpr.utils.data_utils import Tensorizer
from fairseq.optim.adam import FairseqAdam
from .biencoder import BiEncoder

logger = logging.getLogger(__name__)

FairseqOptCfg = collections.namedtuple("FairseqOptCfg", ["lr", "adam_betas", "adam_eps", "weight_decay"])


def get_roberta_biencoder_components(args, inference_only: bool = False, **kwargs):
question_encoder = RobertaEncoder.from_pretrained(args.pretrained_file)
ctx_encoder = RobertaEncoder.from_pretrained(args.pretrained_file)
question_encoder = RobertaEncoder.from_pretrained(args.encoder.pretrained_file)
ctx_encoder = RobertaEncoder.from_pretrained(args.encoder.pretrained_file)
biencoder = BiEncoder(question_encoder, ctx_encoder)
optimizer = get_fairseq_adamw_optimizer(biencoder, args) if not inference_only else None

tensorizer = get_roberta_tensorizer(args)

tensorizer = get_roberta_tensorizer(
args.encoder.pretrained_model_cfg, args.do_lower_case, args.encoder.sequence_length
)
return tensorizer, biencoder, optimizer


def get_fairseq_adamw_optimizer(model: nn.Module, args):
setattr(args, "lr", [args.learning_rate])
return FairseqAdam(args, model.parameters()).optimizer
cfg = FairseqOptCfg(args.train.learning_rate, args.train.adam_betas, args.train.adam_eps, args.train.weight_decay)
return FairseqAdam(cfg, model.parameters()).optimizer


class RobertaEncoder(nn.Module):
@@ -52,9 +52,15 @@ def from_pretrained(cls, pretrained_dir_path: str):
model = FaiseqRobertaModel.from_pretrained(pretrained_dir_path)
return cls(model)

def forward(self, input_ids: T, token_type_ids: T, attention_mask: T) -> Tuple[T, ...]:
def forward(
self,
input_ids: T,
token_type_ids: T,
attention_mask: T,
representation_token_pos=0,
) -> Tuple[T, ...]:
roberta_out = self.fairseq_roberta.extract_features(input_ids)
cls_out = roberta_out[:, 0, :]
cls_out = roberta_out[:, representation_token_pos, :]
return roberta_out, cls_out, None

def get_out_size(self):
@@ -69,6 +78,7 @@ def __init__(
max_audio_t: int,
use_tanh: bool = True,
dropout: float = 0.0,
output_layer: str = None,
):
super(Wav2Vec2Encoder, self).__init__()
state = fairseq.checkpoint_utils.load_checkpoint_to_cpu(cp_file)
@@ -99,13 +109,12 @@ def __init__(

self.apply_mask = apply_mask
self.use_tanh = use_tanh
self.output_layer = output_layer

# TODO: remove after debug
self.tmp_long_audio_samples = 0

def forward(
self, input_ids: T, _token_type_ids: T, attention_mask: T, representation_token_pos=0, output_layer=None
) -> Tuple[T, ...]:
def forward(self, input_ids: T, _token_type_ids: T, attention_mask: T, representation_token_pos=0) -> Tuple[T, ...]:
mask = self.apply_mask and self.training

# TODO: remove after debug
@@ -174,6 +183,7 @@ def __init__(
max_audio_t: int,
use_tanh: bool = True,
dropout: float = 0.0,
output_layer: str = None,
):
super(HubertEncoder, self).__init__()
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_file])
@@ -196,20 +206,19 @@ def __init__(

self.apply_mask = apply_mask
self.use_tanh = use_tanh
self.output_layer = output_layer

# TODO: remove after debug
self.tmp_long_audio_samples = 0

def forward(
self, input_ids: T, _token_type_ids: T, padding_mask: T, representation_token_pos=0, output_layer=None
) -> Tuple[T, ...]:
def forward(self, input_ids: T, _token_type_ids: T, padding_mask: T, representation_token_pos=0) -> Tuple[T, ...]:
mask = self.apply_mask and self.training

# TODO: remove after debug
torch.cuda.ipc_collect()

features, padding_mask = self.model.extract_features(
input_ids, padding_mask=padding_mask, mask=mask, output_layer=output_layer
input_ids, padding_mask=padding_mask, mask=mask, output_layer=self.output_layer
)

bsz, seq_len, feature_dim = features.size()
@@ -258,3 +267,11 @@ def pad_to_len(

def get_out_size(self):
return self.hidden_size


def get_roberta_encoder_components(
pretrained_file: str, pretrained_model_cfg: str, do_lower_case: bool, sequence_length: int
) -> Tuple[RobertaEncoder, Tensorizer]:
encoder = RobertaEncoder.from_pretrained(pretrained_file)
tensorizer = get_roberta_tensorizer(pretrained_model_cfg, do_lower_case, sequence_length)
return encoder, tensorizer
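
The new get_roberta_encoder_components helper pairs a fairseq RoBERTa encoder with a matching tensorizer built from plain values instead of a full args object, so the mixed-audio biencoder can construct its context side directly from the ctx_* config fields. A small usage sketch; the checkpoint directory is a placeholder and do_lower_case=True is shown only as an example value:

# Sketch: build a fairseq RoBERTa encoder plus its tensorizer from explicit values.
ctx_encoder, tensorizer = get_roberta_encoder_components(
    pretrained_file="/path/to/roberta.base/",  # fairseq checkpoint dir (placeholder)
    pretrained_model_cfg="roberta-base",       # HF tokenizer name used by the tensorizer
    do_lower_case=True,
    sequence_length=256,
)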
7 changes: 3 additions & 4 deletions dpr/models/hf_models.py
@@ -134,10 +134,9 @@ def _add_special_tokens(tokenizer, special_tokens):
logger.info("Tokenizer's all_special_tokens %s", tokenizer.all_special_tokens)


def get_roberta_tensorizer(args, tokenizer=None):
if not tokenizer:
tokenizer = get_roberta_tokenizer(args.pretrained_model_cfg, do_lower_case=args.do_lower_case)
return RobertaTensorizer(tokenizer, args.sequence_length)
def get_roberta_tensorizer(pretrained_model_cfg: str, do_lower_case: bool, sequence_length: int):
tokenizer = get_roberta_tokenizer(pretrained_model_cfg, do_lower_case=do_lower_case)
return RobertaTensorizer(tokenizer, sequence_length)


def get_optimizer(
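
get_roberta_tensorizer now takes the three values it actually needs rather than the whole args namespace, which is what lets both the HF and fairseq code paths in this commit call it with individual config fields. A minimal call using the values from conf/encoder/fairseq_roberta.yaml; do_lower_case=True is only an example setting:

# Direct call with explicit values instead of an args object.
tensorizer = get_roberta_tensorizer("roberta-base", do_lower_case=True, sequence_length=256)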