This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

save all mixed encoder params during training and restore upon reading from a checkpoint
vlad-karpukhin committed Mar 30, 2021
1 parent 00b3fba commit 0b8eb15
Showing 7 changed files with 90 additions and 77 deletions.
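
For context, a minimal sketch (not part of this commit) of the save/restore idea in the title: every mixed-encoder hyper-parameter is persisted next to the model weights at training time and pushed back into the runtime config when the checkpoint is read, which is what load_states_from_checkpoint / set_cfg_params_from_state do in dense_retriever.py below. The CheckpointState tuple and helper names here are simplified stand-ins, not the repo's actual API.

import collections
import torch

CheckpointState = collections.namedtuple("CheckpointState", ["model_dict", "encoder_params"])

def save_checkpoint(model: torch.nn.Module, encoder_params: dict, path: str) -> None:
    # store every encoder hyper-parameter alongside the weights so inference
    # can rebuild an identical mixed encoder without re-specifying them
    torch.save(CheckpointState(model.state_dict(), dict(encoder_params))._asdict(), path)

def restore_encoder_params(path: str, cfg_encoder_params: dict) -> dict:
    saved = torch.load(path, map_location="cpu")
    # overwrite the runtime config with the values recorded at training time
    cfg_encoder_params.update(saved["encoder_params"])
    return saved["model_dict"]
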
7 changes: 6 additions & 1 deletion conf/datasets/retriever_default.yaml
@@ -37,4 +37,9 @@ nq_audio_test:
nq_tts_asr_test:
_target_: dpr.data.retriever_data.TTS_ASR_QASrc
file: /private/home/scottyih/playground/bert-qa/data/nq-test.qa.csv
trans_file: /checkpoint/vladk/speechqa/data/test/transcriptions/hypo.word-960h_scratch.pt-test.txt
trans_file: /checkpoint/vladk/speechqa/data/test/transcriptions/hypo.word-960h_scratch.pt-test.txt

nq_tts_asr_test2:
_target_: dpr.data.retriever_data.TTS_ASR_QASrc
file: /private/home/scottyih/playground/bert-qa/data/nq-test.qa.csv
trans_file: /checkpoint/kushall/data/speechqa/transcription_test/hypo.word-wav2vec_small_960h.pt-test.txt
2 changes: 1 addition & 1 deletion conf/encoder/speech_mixed.yaml
@@ -13,8 +13,8 @@ q_wav2vec_model_cfg: #facebook/wav2vec2-base-960h
# fairseq only params
q_wav2vec_cp_file: #/checkpoint/vladk/speechqa/wav2vec_small_960h.pt
q_wav2vec_apply_mask: True
q_output_layer: # Which layer representation to use

q_output_layer: # Which layer representation to use, int-index for HF implementation and a string for fairseq
q_projection_dim: 768 # Extra linear layer on top of pre-trained encoder
q_dropout: 0.1
q_use_activation: False
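As a rough illustration of the new q_output_layer option (a sketch under assumptions, not this commit's code): with the HF implementation an integer index selects one of the transformer's hidden states, falling back to last_hidden_state for None or -1, mirroring the Wav2Vec2HFEncoder.forward change in dpr/models/hf_models.py further down. The one-second dummy waveform and layer index 6 are illustrative values.

import torch
from transformers import Wav2Vec2Model

model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
waveform = torch.randn(1, 16000)  # one second of dummy 16 kHz audio

with torch.no_grad():
    out = model(waveform, output_hidden_states=True)

output_layer = 6  # int index for the HF path, per the comment above
features = (
    out.last_hidden_state
    if output_layer in (None, -1)
    else out.hidden_states[output_layer]
)
print(features.shape)  # (1, frames, hidden_size)
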
7 changes: 4 additions & 3 deletions dense_retriever.py
@@ -71,8 +71,8 @@ def generate_question_vectors(
else:
batch_tensors = [tensorizer.text_to_tensor(q) for q in batch_questions]

# TODO: this only works for Wav2vec pipeline
# """ ---------------------------------
# TODO: this only works for Wav2vec pipeline and will crash the regular text pipeline
# """
max_vector_len = max(q_t.size(1) for q_t in batch_tensors)
min_vector_len = min(q_t.size(1) for q_t in batch_tensors)

@@ -81,7 +81,7 @@
from dpr.models.reader import _pad_to_len

batch_tensors = [_pad_to_len(q.squeeze(0), 0, max_vector_len) for q in batch_tensors]
# """ ---------------------------------
# """

q_ids_batch = torch.stack(batch_tensors, dim=0).cuda()
q_seg_batch = torch.zeros_like(q_ids_batch).cuda()
@@ -283,6 +283,7 @@ def main(cfg: DictConfig):
logger.info("%s", OmegaConf.to_yaml(cfg))

saved_state = load_states_from_checkpoint(cfg.model_file)

set_cfg_params_from_state(saved_state.encoder_params, cfg)

tensorizer, encoder, _ = init_biencoder_components(cfg.encoder.encoder_model_type, cfg, inference_only=True)
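A self-contained sketch (with a simplified _pad_to_len assumed here rather than imported from dpr.models.reader) of the padding step the hunk above applies to the wav2vec pipeline: variable-length audio query tensors are padded to the longest one in the batch before torch.stack.

import torch

def _pad_to_len(seq: torch.Tensor, pad_id: int, max_len: int) -> torch.Tensor:
    # truncate or right-pad a 1-D tensor to exactly max_len elements
    if seq.size(0) >= max_len:
        return seq[:max_len]
    return torch.cat([seq, seq.new_full((max_len - seq.size(0),), pad_id)], dim=0)

batch_tensors = [torch.randn(1, n) for n in (14000, 16000, 15200)]  # dummy audio queries
max_vector_len = max(q_t.size(1) for q_t in batch_tensors)
batch_tensors = [_pad_to_len(q.squeeze(0), 0, max_vector_len) for q in batch_tensors]
q_ids_batch = torch.stack(batch_tensors, dim=0)  # shape: (3, max_vector_len)
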
27 changes: 7 additions & 20 deletions dpr/models/fairseq_models.py
@@ -14,7 +14,7 @@

import torch
from fairseq.models.roberta.hub_interface import RobertaHubInterface
from fairseq.models.roberta.model import RobertaModel as FaiseqRobertaModel
from fairseq.models.roberta.model import RobertaModel as FairseqRobertaModel
from torch import Tensor as T
from torch import nn

@@ -52,7 +52,7 @@ def __init__(self, fairseq_roberta_hub: RobertaHubInterface):

@classmethod
def from_pretrained(cls, pretrained_dir_path: str):
model = FaiseqRobertaModel.from_pretrained(pretrained_dir_path)
model = FairseqRobertaModel.from_pretrained(pretrained_dir_path)
return cls(model)

def forward(
@@ -87,11 +87,12 @@ def __init__(
model = task.build_model(w2v_args)
model.load_state_dict(state["model"], strict=True)
logger.info(
"Initialized Wav2Vec2Encoder model as %s, from cp=%s, use_tanh=%s, dropout=%s",
"Initialized Wav2Vec2Encoder model as %s, from cp=%s, use_tanh=%s, dropout=%s, output_layer=%s",
type(model),
cp_file,
use_tanh,
dropout,
output_layer,
)
if isinstance(model, fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model):
self.wav2vec_model = model
@@ -100,6 +101,7 @@
self.wav2vec_model = model.w2v_encoder.w2v_model
hidden_size = self.wav2vec_model.post_extract_proj.out_features

self.hidden_size = hidden_size
self.max_audio_t = max_audio_t * hidden_size
logger.info("Wav2Vec2Encoder max_audio_t %s", self.max_audio_t)

@@ -111,9 +113,6 @@ def __init__(
self.use_tanh = use_tanh
self.output_layer = output_layer

# TODO: remove after debug
self.tmp_long_audio_samples = 0

def forward(self, input_ids: T, _token_type_ids: T, attention_mask: T, representation_token_pos=0) -> Tuple[T, ...]:
mask = self.apply_mask and self.training

@@ -122,13 +121,10 @@ def forward(self, input_ids: T, _token_type_ids: T, attention_mask: T, represent

wav2vec_out, pad_mask = self.wav2vec_model.extract_features(input_ids, padding_mask=attention_mask, mask=mask)

# logger.info("!!! wav2vec_out sz %s", wav2vec_out.size())
# logger.info("!!! wav2vec_out %s", wav2vec_out)

B, T, C = wav2vec_out.size()

flat_encoded_out = wav2vec_out.reshape(B, -1)
if T > self.max_audio_t:
if flat_encoded_out.size(1) > self.max_audio_t:
logger.warning("T>max_audio_t: %d>%d", T, self.max_audio_t)

# TODO: make a util method
@@ -137,12 +133,7 @@ def pad_to_len(
max_len,
):
s_len = seq.size(0)
# TODO: remove after debug
if s_len > max_len:
self.tmp_long_audio_samples += 1
if self.tmp_long_audio_samples % 100 == 0:
logger.info("tmp_long_audio_samples %s", self.tmp_long_audio_samples)

return seq[0:max_len]
r = torch.cat(
[
@@ -165,14 +156,10 @@

if self.training:
pooled_output = self.dropout(pooled_output)

# logger.info("!!! wav2vec_out pooled sz %s", pooled_output.size())
# logger.info("!!! wav2vec_out pooled %s", pooled_output)

return None, pooled_output, None

def get_out_size(self):
return self.wav2vec_model.post_extract_proj.out_features
return self.hidden_size


class HubertEncoder(nn.Module):
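A small sketch (not this commit's code) of the length cap the fairseq hunk above now enforces: max_audio_t is stored as max_audio_t * hidden_size, the (B, T, C) wav2vec features are flattened to (B, T*C), and samples are truncated or right-padded against that flat budget. The sizes below are illustrative.

import torch

max_audio_t, hidden_size = 100, 768
max_flat = max_audio_t * hidden_size  # what self.max_audio_t holds after this commit

wav2vec_out = torch.randn(2, 120, hidden_size)  # B=2, T=120 frames, C=hidden_size
B, T, C = wav2vec_out.size()
flat_encoded_out = wav2vec_out.reshape(B, -1)   # (B, T*C)

if flat_encoded_out.size(1) > max_flat:
    flat_encoded_out = flat_encoded_out[:, :max_flat]   # truncate overly long samples
elif flat_encoded_out.size(1) < max_flat:
    pad = flat_encoded_out.new_zeros(B, max_flat - flat_encoded_out.size(1))
    flat_encoded_out = torch.cat([flat_encoded_out, pad], dim=1)  # right-pad short ones
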
69 changes: 46 additions & 23 deletions dpr/models/hf_models.py
@@ -101,14 +101,24 @@ def get_bert_reader_components(cfg, inference_only: bool = False, **kwargs):
return tensorizer, reader, optimizer


def get_bert_tensorizer(cfg, tokenizer=None):
# TODO: unify tensorizer init methods
def get_bert_tensorizer(cfg):
sequence_length = cfg.encoder.sequence_length
pretrained_model_cfg = cfg.encoder.pretrained_model_cfg

if not tokenizer:
tokenizer = get_bert_tokenizer(pretrained_model_cfg, do_lower_case=cfg.do_lower_case)
if cfg.special_tokens:
_add_special_tokens(tokenizer, cfg.special_tokens)
tokenizer = get_bert_tokenizer(pretrained_model_cfg, do_lower_case=cfg.do_lower_case)
if cfg.special_tokens:
_add_special_tokens(tokenizer, cfg.special_tokens)

return BertTensorizer(tokenizer, sequence_length)


def get_bert_tensorizer_p(
pretrained_model_cfg: str, sequence_length: int, do_lower_case: bool = True, special_tokens: List[str] = []
):
tokenizer = get_bert_tokenizer(pretrained_model_cfg, do_lower_case=do_lower_case)
if special_tokens:
_add_special_tokens(tokenizer, special_tokens)

return BertTensorizer(tokenizer, sequence_length)

@@ -187,9 +197,16 @@ def get_roberta_tokenizer(pretrained_cfg_name: str, do_lower_case: bool = True):


def get_wav2vec_encoder(
pretrained_model: str, max_audio_t: int, extra_proj_d: int, final_drop: float, use_activation: bool
pretrained_model: str,
max_audio_t: int,
extra_proj_d: int,
final_drop: float,
use_activation: bool,
output_layer: int,
):
encoder = Wav2Vec2HFEncoder.init_encoder(pretrained_model, max_audio_t, extra_proj_d, final_drop, use_activation)
encoder = Wav2Vec2HFEncoder.init_encoder(
pretrained_model, max_audio_t, extra_proj_d, final_drop, use_activation, output_layer
)
return encoder


@@ -270,17 +287,20 @@ def get_out_size(self):


class Wav2Vec2HFEncoder(Wav2Vec2Model):
def __init__(self, config, max_audio_t: int, project_dim, final_dropout: float, use_activation: bool):
def __init__(
self, config, max_audio_t: int, project_dim, final_dropout: float, use_activation: bool, output_layer: int = -1
):
Wav2Vec2Model.__init__(self, config)
hidden_size = config.hidden_size

self.max_audio = max_audio_t * hidden_size
logger.info(
"Wav2Vec2HFEncoder: max_audio_t %s, project_dim %s, dropout %s, use_activation %s",
"Wav2Vec2HFEncoder: max_audio_t %s, project_dim %s, dropout %s, use_activation %s, output_layer %d",
self.max_audio,
project_dim,
final_dropout,
use_activation,
output_layer,
)

self.dense = nn.Linear(self.max_audio, hidden_size) if project_dim != 0 else None
@@ -291,6 +311,8 @@ def __init__(self, config, max_audio_t: int, project_dim, final_dropout: float,
self.use_activation = use_activation
self.init_weights()

self.output_layer = output_layer

@classmethod
def init_encoder(
cls,
@@ -300,6 +322,7 @@ def init_encoder(
dropout: float,
use_activation: bool,
pretrained: bool = True,
output_layer: int = -1,
) -> Wav2Vec2Model:
logger.info("Initializing HF Wav2Vec2Model Encoder. cfg_name=%s", cfg_name)
cfg_name = cfg_name if cfg_name else "facebook/wav2vec2-base-960h"
@@ -316,9 +339,10 @@ def init_encoder(
project_dim=projection_dim,
final_dropout=dropout,
use_activation=use_activation,
output_layer=output_layer,
)
else:
return Wav2Vec2HFEncoder(cfg, max_audio_t, projection_dim, dropout, use_activation)
return Wav2Vec2HFEncoder(cfg, max_audio_t, projection_dim, dropout, use_activation, output_layer)

def forward(
self,
@@ -328,23 +352,18 @@ def forward(
representation_token_pos=0,
) -> Tuple[T, ...]:

# TODO: remove after debug
# torch.cuda.ipc_collect()

# logger.info("!!! wav2vec input_ids %s", input_ids.size())

wav2vec_out = super().forward(input_ids, output_hidden_states=True) # attention_mask=attention_mask

wav2vec_out = wav2vec_out.last_hidden_state

# logger.info("!!! wav2vec_out sz %s", wav2vec_out.size())

wav2vec_out = (
wav2vec_out.last_hidden_state
if self.output_layer in [None, -1]
else wav2vec_out.hidden_states[self.output_layer]
)
B, T, C = wav2vec_out.size()

if self.dense:
flat_encoded_out = wav2vec_out.reshape(B, -1)
# if T * C > self.max_audio:
# logger.warning("T>max_audio_t: %d>%d", T, self.max_audio)
if flat_encoded_out.size(1) > self.max_audio:
logger.warning("TxC>max_audio_t: %d>%d", flat_encoded_out.size(1), self.max_audio)

# TODO: make a util method
def pad_to_len(
@@ -374,10 +393,14 @@ def pad_to_len(

if self.training:
wav2vec_out = self.dropout(wav2vec_out)
else:
wav2vec_out = wav2vec_out[:, representation_token_pos, :]

# logger.info("!!! wav2vec final out sz %s", wav2vec_out.size())
return None, wav2vec_out, None

def get_out_size(self):
return self.config.hidden_size


class BertTensorizer(Tensorizer):
def __init__(self, tokenizer: BertTokenizer, max_length: int, pad_to_max: bool = True):
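A possible usage of the new get_bert_tensorizer_p helper (assuming this repo's dpr package is importable; the model name, sequence length, and query string are illustrative): it builds a BertTensorizer from explicit arguments instead of reading the Hydra cfg, which is how dpr/models/mixed_models.py calls it below.

from dpr.models.hf_models import get_bert_tensorizer_p

tensorizer = get_bert_tensorizer_p(
    "bert-base-uncased",  # pretrained_model_cfg (illustrative)
    256,                  # sequence_length
    do_lower_case=True,
    special_tokens=[],
)
ids = tensorizer.text_to_tensor("when was the first wav2vec model released?")
print(ids.shape)
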
16 changes: 8 additions & 8 deletions dpr/models/mixed_models.py
@@ -4,20 +4,17 @@
import numpy as np
import torch
from typing import List, Tuple
from torch import Tensor as T
from torch import nn


from dpr.data.speech_data import BiEncoderMixedSample
from dpr.models.biencoder import BiEncoder, BiEncoderBatch

from dpr.models.hf_models import (
HFBertEncoder,
get_optimizer,
get_bert_tensorizer,
get_hf_model_param_grouping,
get_optimizer_grouped,
get_wav2vec_encoder,
get_bert_tensorizer_p,
)

from dpr.models.fairseq_models import (
@@ -65,16 +62,17 @@ def get_audio_mixed_biencoder_components(cfg, inference_only: bool = False, **kw

def get_query_encoder(cfg):
# TODO: unify initialization
if cfg.encoder.q_encoder_type == "hf-wav2vec": # HF-based
if cfg.encoder.q_encoder_type == "hf-wav2vec" and cfg.encoder.q_wav2vec_model_cfg: # HF-based
query_encoder = get_wav2vec_encoder(
cfg.encoder.q_wav2vec_model_cfg,
cfg.encoder.q_max_audio_t,
cfg.encoder.q_projection_dim,
cfg.encoder.q_dropout,
cfg.encoder.q_use_activation,
cfg.encoder.q_output_layer,
)

elif cfg.encoder.q_wav2vec_cp_file: # Fairseq based
elif cfg.encoder.q_wav2vec_cp_file:  # Fairseq based

if cfg.encoder.q_encoder_type == "fairseq-wav2vec":
audio_cls = Wav2Vec2Encoder
@@ -106,7 +104,9 @@ def get_ctx_encoder(cfg) -> Tuple[object, Tensorizer]:
dropout=cfg.encoder.ctx_dropout,
pretrained=cfg.encoder.ctx_pretrained,
)
tensorizer = get_bert_tensorizer(cfg)
tensorizer = get_bert_tensorizer_p(
cfg.encoder.ctx_model_cfg, cfg.encoder.ctx_sequence_length, cfg.do_lower_case, cfg.special_tokens
)
elif cfg.encoder.ctx_encoder_type == "fairseq-roberta": # Fairseq based
ctx_encoder, tensorizer = get_roberta_encoder_components(
cfg.encoder.ctx_pretrained_file,
Expand All @@ -116,7 +116,7 @@ def get_ctx_encoder(cfg) -> Tuple[object, Tensorizer]:
)

else:
raise RuntimeError("Either q_wav2vec_model_cfg or q_wav2vec_cp_file should be defined")
raise RuntimeError("encoder.ctx_encoder_type should be defined")
return ctx_encoder, tensorizer


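Finally, a compact sketch (not part of the commit) of the query-encoder dispatch that get_query_encoder now performs: an HF model name selects the transformers-based wav2vec path, a fairseq checkpoint path selects the fairseq one. The checkpoint path below is illustrative.

def pick_query_encoder_backend(q_encoder_type, q_wav2vec_model_cfg, q_wav2vec_cp_file):
    # mirrors the branching in get_query_encoder above, simplified for illustration
    if q_encoder_type == "hf-wav2vec" and q_wav2vec_model_cfg:
        return "hf"
    if q_wav2vec_cp_file:
        return "fairseq"
    raise RuntimeError("Either q_wav2vec_model_cfg or q_wav2vec_cp_file should be defined")

assert pick_query_encoder_backend("hf-wav2vec", "facebook/wav2vec2-base-960h", None) == "hf"
assert pick_query_encoder_backend("fairseq-wav2vec", None, "/path/to/wav2vec_small_960h.pt") == "fairseq"
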