Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Permalink
added HF wav2vec support +misc.
Browse files Browse the repository at this point in the history
  • Loading branch information
vlad-karpukhin committed Mar 26, 2021
1 parent 82709d5 commit 284014d
Show file tree
Hide file tree
Showing 10 changed files with 340 additions and 296 deletions.
4 changes: 4 additions & 0 deletions conf/datasets/encoder_train_default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ nq_speech_mixed_train:
_target_: dpr.data.speech_data.WavJsonTextDataset
json_file: /checkpoint/vladk/dpr_open_source/biencoder-nq-train.json
wav_tsv_file: /checkpoint/vladk/speechqa/data/train/train.tsv
#normalize_audio: True
normalize_audio: False

nq_speech_mixed_dev:
_target_: dpr.data.speech_data.WavJsonTextDataset
json_file: /checkpoint/vladk/dpr_open_source/biencoder-nq-dev.json
wav_tsv_file: /checkpoint/vladk/speechqa/data/dev/dev.tsv
#normalize_audio: True
normalize_audio: False
5 changes: 4 additions & 1 deletion conf/datasets/retriever_default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ nq_dev:

trivia_test:
_target_: dpr.data.retriever_data.CsvQASrc
file: data.retriever.qas.trivia-test
#file: data.retriever.qas.trivia-test
file: /private/home/scottyih/playground/bert-qa/data/triviaqa-test.qa.csv

trivia_train:
_target_: dpr.data.retriever_data.CsvQASrc
Expand All @@ -30,6 +31,8 @@ nq_audio_test:
#file: data.retriever.qas.nq-test
file: /private/home/scottyih/playground/bert-qa/data/nq-test.qa.csv
wav_tsv_file: /checkpoint/vladk/speechqa/data/test/test.tsv
#normalize_audio: True
normalize_audio: False

nq_tts_asr_test:
_target_: dpr.data.retriever_data.TTS_ASR_QASrc
Expand Down
23 changes: 18 additions & 5 deletions conf/encoder/speech_mixed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# model type. One of [hf_bert, pytext_bert, fairseq_roberta]
encoder_model_type: mixed_hf_bert_wav2vec
#hf_bert

# HuggingFace's config name for model initialization
pretrained_model_cfg: bert-base-uncased
Expand All @@ -24,12 +25,24 @@ fix_ctx_encoder: False
# if False, the model won't load pre-trained BERT weights
pretrained: True

#wav2vec_cp_file: /checkpoint/vladk/speechqa/wav2vec_small_960h.pt
# non finetuned
# HF params
pretrained_wav2vec_model_cfg:
#facebook/wav2vec2-base-960h

wav2_vec_extra_proj_dim: 768 # TODO: make a common param
wav2vec_dropout: 0.1

# fairseq params
wav2vec_cp_file: /checkpoint/vladk/speechqa/wav2vec_small.pt
# non finetuned
#wav2vec_cp_file: /checkpoint/vladk/speechqa/wav2vec_small.pt

wav2vec_apply_mask: True

wav2vec_apply_mask: False

max_audio_t: 300
# wav2vec common params
wav2vec_max_audio_t: 300
wav2vec_use_activation: False

use_tanh: False
#TODO: move to train cfg group
audio_encoder_lr_factor: 0
78 changes: 23 additions & 55 deletions dense_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,30 +62,26 @@ def generate_question_vectors(
# TODO: tmp workaround for EL, remove or revise
if query_token == "[START_ENT]":
batch_tensors = [
_select_span_with_token(q, tensorizer, token_str=query_token)
for q in batch_questions
_select_span_with_token(q, tensorizer, token_str=query_token) for q in batch_questions
]
else:
batch_tensors = [
tensorizer.text_to_tensor(" ".join([query_token, q]))
for q in batch_questions
]
batch_tensors = [tensorizer.text_to_tensor(" ".join([query_token, q])) for q in batch_questions]
elif isinstance(batch_questions[0], T):
batch_tensors = [q for q in batch_questions]
else:
batch_tensors = [tensorizer.text_to_tensor(q) for q in batch_questions]

# TODO: this only works for Wav2vec pipeline
# """ ---------------------------------
max_vector_len = max(q_t.size(1) for q_t in batch_tensors)
min_vector_len = min(q_t.size(1) for q_t in batch_tensors)

# TODO: this only works for Wav2vec pipeline
if max_vector_len != min_vector_len:
# TODO: _pad_to_len move to utils
from dpr.models.reader import _pad_to_len

batch_tensors = [
_pad_to_len(q.squeeze(0), 0, max_vector_len) for q in batch_tensors
]
batch_tensors = [_pad_to_len(q.squeeze(0), 0, max_vector_len) for q in batch_tensors]
# """ ---------------------------------

q_ids_batch = torch.stack(batch_tensors, dim=0).cuda()
q_seg_batch = torch.zeros_like(q_ids_batch).cuda()
Expand Down Expand Up @@ -116,17 +112,13 @@ def generate_question_vectors(


class DenseRetriever(object):
def __init__(
self, question_encoder: nn.Module, batch_size: int, tensorizer: Tensorizer
):
def __init__(self, question_encoder: nn.Module, batch_size: int, tensorizer: Tensorizer):
self.question_encoder = question_encoder
self.batch_size = batch_size
self.tensorizer = tensorizer
self.selector = None

def generate_question_vectors(
self, questions: List[str], query_token: str = None
) -> T:
def generate_question_vectors(self, questions: List[str], query_token: str = None) -> T:

bsz = self.batch_size
self.question_encoder.eval()
Expand Down Expand Up @@ -168,19 +160,15 @@ def index_encoded_data(
:return:
"""
buffer = []
for i, item in enumerate(
iterate_encoded_files(vector_files, path_id_prefixes=path_id_prefixes)
):
for i, item in enumerate(iterate_encoded_files(vector_files, path_id_prefixes=path_id_prefixes)):
buffer.append(item)
if 0 < buffer_size == len(buffer):
self.index.index_data(buffer)
buffer = []
self.index.index_data(buffer)
logger.info("Data indexing completed.")

def get_top_docs(
self, query_vectors: np.array, top_docs: int = 100
) -> List[Tuple[List[object], List[float]]]:
def get_top_docs(self, query_vectors: np.array, top_docs: int = 100) -> List[Tuple[List[object], List[float]]]:
"""
Does the retrieval of the best matching passages given the query vectors batch
:param query_vectors:
Expand All @@ -201,9 +189,7 @@ def validate(
workers_num: int,
match_type: str,
) -> List[List[bool]]:
match_stats = calculate_matches(
passages, answers, result_ctx_ids, workers_num, match_type
)
match_stats = calculate_matches(passages, answers, result_ctx_ids, workers_num, match_type)
top_k_hits = match_stats.top_k_hits

logger.info("Validation results: top k documents hits %s", top_k_hits)
Expand Down Expand Up @@ -253,9 +239,7 @@ def save_results(
logger.info("Saved results * scores to %s", out_file)


def iterate_encoded_files(
vector_files: list, path_id_prefixes: List = None
) -> Iterator[Tuple]:
def iterate_encoded_files(vector_files: list, path_id_prefixes: List = None) -> Iterator[Tuple]:
for i, file in enumerate(vector_files):
logger.info("Reading file %s", file)
id_prefix = None
Expand All @@ -277,9 +261,7 @@ def validate_tables(
workers_num: int,
match_type: str,
) -> List[List[bool]]:
match_stats = calculate_chunked_matches(
passages, answers, result_ctx_ids, workers_num, match_type
)
match_stats = calculate_chunked_matches(passages, answers, result_ctx_ids, workers_num, match_type)
top_k_chunk_hits = match_stats.top_k_chunk_hits
top_k_table_hits = match_stats.top_k_table_hits

Expand All @@ -303,9 +285,7 @@ def main(cfg: DictConfig):
saved_state = load_states_from_checkpoint(cfg.model_file)
set_cfg_params_from_state(saved_state.encoder_params, cfg)

tensorizer, encoder, _ = init_biencoder_components(
cfg.encoder.encoder_model_type, cfg, inference_only=True
)
tensorizer, encoder, _ = init_biencoder_components(cfg.encoder.encoder_model_type, cfg, inference_only=True)

logger.info("Loading saved model state ...")
encoder.load_state(saved_state, strict=False)
Expand All @@ -317,9 +297,7 @@ def main(cfg: DictConfig):
logger.info("Selecting standard question encoder")
encoder = encoder.question_model

encoder, _ = setup_for_distributed_mode(
encoder, None, cfg.device, cfg.n_gpu, cfg.local_rank, cfg.fp16
)
encoder, _ = setup_for_distributed_mode(encoder, None, cfg.device, cfg.n_gpu, cfg.local_rank, cfg.fp16)
encoder.eval()
model_to_load = get_model_obj(encoder)
vector_size = model_to_load.get_out_size()
Expand Down Expand Up @@ -359,9 +337,7 @@ def main(cfg: DictConfig):
retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer, index)

logger.info("Using special token %s", qa_src.special_query_token)
questions_tensor = retriever.generate_question_vectors(
questions, query_token=qa_src.special_query_token
)
questions_tensor = retriever.generate_question_vectors(questions, query_token=qa_src.special_query_token)

if qa_src.selector:
logger.info("Using custom representation token selector")
Expand All @@ -382,13 +358,11 @@ def main(cfg: DictConfig):

logger.info("ctx_files_patterns: %s", ctx_files_patterns)
if ctx_files_patterns:
assert len(ctx_files_patterns) == len(
id_prefixes
), "ctx len={} pref leb={}".format(len(ctx_files_patterns), len(id_prefixes))
assert len(ctx_files_patterns) == len(id_prefixes), "ctx len={} pref leb={}".format(
len(ctx_files_patterns), len(id_prefixes)
)
else:
assert (
index_path
), "Either encoded_ctx_files or index_path parameter should be set."
assert index_path, "Either encoded_ctx_files or index_path parameter should be set."

input_paths = []
path_id_prefixes = []
Expand All @@ -405,9 +379,7 @@ def main(cfg: DictConfig):
retriever.index.deserialize(index_path)
else:
logger.info("Reading all passages data from files: %s", input_paths)
retriever.index_encoded_data(
input_paths, index_buffer_sz, path_id_prefixes=path_id_prefixes
)
retriever.index_encoded_data(input_paths, index_buffer_sz, path_id_prefixes=path_id_prefixes)
if index_path:
retriever.index.serialize(index_path)

Expand All @@ -422,9 +394,7 @@ def main(cfg: DictConfig):
ctx_src.load_data_to(all_passages)

if len(all_passages) == 0:
raise RuntimeError(
"No passages data found. Please specify ctx_file param properly."
)
raise RuntimeError("No passages data found. Please specify ctx_file param properly.")

if cfg.validate_as_tables:
questions_doc_hits = validate_tables(
Expand Down Expand Up @@ -454,9 +424,7 @@ def main(cfg: DictConfig):
)

if cfg.kilt_out_file:
kilt_ctx = next(
iter([ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)]), None
)
kilt_ctx = next(iter([ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)]), None)
if not kilt_ctx:
raise RuntimeError("No Kilt compatible context file provided")
assert hasattr(cfg, "kilt_out_file")
Expand Down
25 changes: 6 additions & 19 deletions dpr/data/speech_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def load_data(self):
# filter those without positive ctx
self.data = [r for r in data if len(r["positive_ctxs"]) > 0]
logger.info("Total cleaned data size: {}".format(len(self.data)))
self.id_to_audio_file_map = _get_id_to_audio_file_map(
self.audio_file_prefix, self.wav_tsv_file
)
self.id_to_audio_file_map = _get_id_to_audio_file_map(self.audio_file_prefix, self.wav_tsv_file)
logger.info("id_to_audio_file_map size: %d", len(self.id_to_audio_file_map))

def __getitem__(self, index) -> BiEncoderMixedSample:
Expand All @@ -79,29 +77,19 @@ def __getitem__(self, index) -> BiEncoderMixedSample:
audio_file = self.id_to_audio_file_map[sample_id]

query_tensor = _get_audio_feats(audio_file, self.normalize_audio)
# logger.info("Audio query_tensor %s", query_tensor.size())

if query_tensor.size(1) > self.max_features_sz:
query_tensor = query_tensor[:, 0 : self.max_features_sz]
self.cut_samples += 1
if self.cut_samples % 100 == 0:
logger.info("!!! cut_samples %d", self.cut_samples)

# if query_tensor.size(1) == 371519:
# logger.info("!!! 371519 Audio size for file =%s", audio_file)

# r.query = torchaudio.load(audio_file)
r.query = query_tensor

positive_ctxs = json_sample["positive_ctxs"]
negative_ctxs = (
json_sample["negative_ctxs"] if "negative_ctxs" in json_sample else []
)
hard_negative_ctxs = (
json_sample["hard_negative_ctxs"]
if "hard_negative_ctxs" in json_sample
else []
)
negative_ctxs = json_sample["negative_ctxs"] if "negative_ctxs" in json_sample else []
hard_negative_ctxs = json_sample["hard_negative_ctxs"] if "hard_negative_ctxs" in json_sample else []

for ctx in positive_ctxs + negative_ctxs + hard_negative_ctxs:
if "title" not in ctx:
Expand Down Expand Up @@ -146,7 +134,6 @@ def __getitem__(self, index) -> QASample:
sample_id = index + 1
audio_file = self.id_to_audio_file_map[sample_id]
query_tensor = _get_audio_feats(audio_file, self.normalize_audio)
logger.info("Audio query_tensor %s", query_tensor.size())

# TODO: tmp
size = query_tensor.size(1)
Expand Down Expand Up @@ -175,9 +162,7 @@ def load_data(self):
data.append(QASample(self._process_question(question), None, answers))

self.data = data
self.id_to_audio_file_map = _get_id_to_audio_file_map(
self.audio_file_prefix, self.wav_tsv_file
)
self.id_to_audio_file_map = _get_id_to_audio_file_map(self.audio_file_prefix, self.wav_tsv_file)
logger.info("id_to_audio_file_map size: %d", len(self.id_to_audio_file_map))


Expand All @@ -190,12 +175,14 @@ def _read_audio(fname):

def _get_audio_feats(loc, normalize_audio: bool) -> T:
    """Load a wav file and convert it to a model-ready feature tensor.

    :param loc: path to the audio file, passed through to ``_read_audio``
    :param normalize_audio: if True, layer-normalize the raw waveform
        (expected by non-finetuned wav2vec checkpoints)
    :return: float tensor of shape (1, num_samples)
    """
    x = _read_audio(loc)
    # Inference-only preprocessing: no autograd graph is needed.
    with torch.no_grad():
        source = torch.from_numpy(x).float()
        if normalize_audio:
            # layer_norm expects a 1-D waveform here; fail loudly otherwise
            assert source.dim() == 1, source.dim()
            # NOTE: the inner torch.no_grad() from the original was redundant —
            # this whole branch already runs under the outer no_grad context.
            source = F.layer_norm(source, source.shape)
        # add a leading batch dimension: (num_samples,) -> (1, num_samples)
        source = source.view(1, -1)
    return source

Expand Down
Loading

0 comments on commit 284014d

Please sign in to comment.