This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit fc34a97: misc minor fixes
vlad-karpukhin committed Apr 2, 2021
1 parent 0b8eb15 commit fc34a97
Showing 3 changed files with 6 additions and 5 deletions.
conf/encoder/speech_mixed.yaml: 4 changes (2 additions, 2 deletions)
@@ -33,11 +33,11 @@ ctx_sequence_length: 256 # Max length of the encoder input sequence
  ctx_dropout: 0.1
  ctx_pretrained: True # if False, the model won't load pre-trained BERT weights

- # whether to fix (don't update) context encoder during training or not
- fix_ctx_encoder: False

  # -------------- COMMON -------------------

+ # whether to fix (don't update) context encoder during training or not
+ fix_ctx_encoder: False

  #TODO: move to train config group?
  optimizer: hf-adam # fairseq-adam
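Note: only the position of fix_ctx_encoder changes here. The flag moves out of the context-encoder block and under the COMMON divider where shared settings live; YAML mappings are insensitive to key order, so behavior is unchanged. The flag itself controls whether the context encoder is updated during bi-encoder training. One common way to honor such a flag in a PyTorch training setup is to freeze the encoder's parameters; a minimal sketch for illustration, not DPR's actual implementation:

    import torch.nn as nn

    def maybe_freeze_ctx_encoder(ctx_model: nn.Module, fix_ctx_encoder: bool) -> None:
        # With the flag set, exclude the context encoder's parameters from
        # gradient computation so the optimizer never updates them; the
        # question encoder keeps training as usual.
        if fix_ctx_encoder:
            for p in ctx_model.parameters():
                p.requires_grad = False
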
dpr/models/hf_models.py: 5 changes (3 additions, 2 deletions)
@@ -28,7 +28,8 @@
  from transformers.optimization import AdamW
  from transformers.tokenization_bert import BertTokenizer
  from transformers.tokenization_roberta import RobertaTokenizer
- from transformers import Wav2Vec2Model, Wav2Vec2Config  # will fail
+
+ # from transformers import Wav2Vec2Model, Wav2Vec2Config  # will fail

  from dpr.models.biencoder import BiEncoder
  from dpr.utils.data_utils import Tensorizer
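The Wav2Vec2 import is commented out rather than deleted: the neighboring transformers.tokenization_bert and transformers.tokenization_roberta paths belong to a pre-4.x transformers release, where Wav2Vec2Model is not available, so the import raises an ImportError (hence the "will fail" comment). A guarded import is a common alternative when a symbol may be absent; a sketch of that idiom, not what this commit does:

    try:
        from transformers import Wav2Vec2Model, Wav2Vec2Config
    except ImportError:
        # Older transformers releases without Wav2Vec2 support.
        Wav2Vec2Model = None
        Wav2Vec2Config = None
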
@@ -205,7 +206,7 @@ def get_wav2vec_encoder(
      output_layer: int,
  ):
      encoder = Wav2Vec2HFEncoder.init_encoder(
-         pretrained_model, max_audio_t, extra_proj_d, final_drop, use_activation, output_layer
+         pretrained_model, max_audio_t, extra_proj_d, final_drop, use_activation, output_layer=output_layer
      )
      return encoder

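The second hunk passes output_layer by keyword instead of by position, which makes the call site self-documenting and guards against misbinding if init_encoder declares additional defaulted parameters between use_activation and output_layer. A sketch with a hypothetical signature (the dropout parameter is invented purely for illustration):

    def init_encoder(pretrained_model, max_audio_t, extra_proj_d, final_drop,
                     use_activation, dropout=0.0, output_layer=-1):
        ...

    # Positional call: the sixth argument silently binds to dropout,
    # and output_layer keeps its default.
    init_encoder("wav2vec2-base", 100, 768, 0.1, True, -2)

    # Keyword call: the value reaches the intended parameter no matter
    # which defaulted parameters sit in between.
    init_encoder("wav2vec2-base", 100, 768, 0.1, True, output_layer=-2)
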
dpr/models/mixed_models.py: 2 changes (1 addition, 1 deletion)
@@ -48,7 +48,7 @@ def get_audio_mixed_biencoder_components(cfg, inference_only: bool = False, **kw
      groups = get_hf_model_param_grouping(biencoder.ctx_model, weight_decay=cfg.train.weight_decay)
      q_groups = get_hf_model_param_grouping(biencoder.question_model, weight_decay=cfg.train.weight_decay)
      for g in q_groups:
-         g["lr"] = lr * cfg.encoder.audio_encoder_lr_factor
+         g["lr"] = lr * cfg.encoder.q_audio_encoder_lr_factor
          logger.info("Setting lr=%s for wav2vec encoder param group", g["lr"])
          groups.append(g)
  else:
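The fix reads the question-side multiplier from cfg.encoder.q_audio_encoder_lr_factor, evidently aligning the code with the config attribute name, so the wav2vec question encoder's parameter groups get their own learning rate while the context encoder groups keep the base rate. Per-group learning rates are standard in PyTorch optimizers; a minimal sketch (the modules and the factor value are illustrative stand-ins):

    import torch
    from torch import nn

    question_encoder = nn.Linear(16, 16)  # stand-in for the wav2vec question encoder
    ctx_encoder = nn.Linear(16, 16)       # stand-in for the text context encoder

    base_lr = 1e-5
    q_audio_encoder_lr_factor = 10.0

    optimizer = torch.optim.AdamW(
        [
            {"params": ctx_encoder.parameters()},  # falls back to the default lr
            {
                "params": question_encoder.parameters(),
                "lr": base_lr * q_audio_encoder_lr_factor,
            },
        ],
        lr=base_lr,
    )
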
