Merge pull request #10 from githwd2016/master
Comment out getattr in translate.py
shubhamagarwal92 authored Jun 14, 2019
2 parents af1550e + 5a92f11 commit e351500
Showing 1 changed file with 57 additions and 28 deletions.
translate.py (85 changes: 57 additions & 28 deletions)
@@ -54,7 +54,7 @@ def main(args):
     # # Local
     # annoyIndex = ""
     # annoyPkl = ""
-    model_type = getattr(models, args.model_type)
+    # model_type = getattr(models, args.model_type)
 
     kb_len = None
     celeb_len = None
@@ -73,33 +73,62 @@ def main(args):
     celeb_vec_size = len(celeb_vocab[0])
     del kb_vocab, celeb_vocab
 
-    model = model_type(src_vocab_size=vocab_size,
-        tgt_vocab_size=vocab_size,
-        src_emb_dim=config['model']['src_emb_dim'],
-        tgt_emb_dim=config['model']['tgt_emb_dim'],
-        enc_hidden_size=config['model']['enc_hidden_size'],
-        dec_hidden_size=config['model']['dec_hidden_size'],
-        context_hidden_size=config['model']['context_hidden_size'],
-        batch_size=config['data']['batch_size'],
-        image_in_size=config['model']['image_in_size'],
-        bidirectional_enc=config['model']['bidirectional_enc'],
-        bidirectional_context=config['model']['bidirectional_context'],
-        num_enc_layers=config['model']['num_enc_layers'],
-        num_dec_layers=config['model']['num_dec_layers'],
-        num_context_layers=config['model']['num_context_layers'],
-        dropout_enc=config['model']['dropout_enc'],
-        dropout_dec=config['model']['dropout_dec'],
-        dropout_context=config['model']['dropout_context'],
-        max_decode_len=config['model']['max_decode_len'],
-        non_linearity=config['model']['non_linearity'],
-        enc_type=config['model']['enc_type'],
-        dec_type=config['model']['dec_type'],
-        context_type=config['model']['context_type'],
-        use_attention=config['model']['use_attention'],
-        decode_function=config['model']['decode_function'],
-        num_states=args.num_states,
-        use_kb=use_kb, kb_size=kb_size, celeb_vec_size=celeb_vec_size
-        )
+    if args.model_type == 'MultimodalHRED':
+        model = MultimodalHRED(src_vocab_size=vocab_size,
+            tgt_vocab_size=vocab_size,
+            src_emb_dim=config['model']['src_emb_dim'],
+            tgt_emb_dim=config['model']['tgt_emb_dim'],
+            enc_hidden_size=config['model']['enc_hidden_size'],
+            dec_hidden_size=config['model']['dec_hidden_size'],
+            context_hidden_size=config['model']['context_hidden_size'],
+            batch_size=config['data']['batch_size'],
+            image_in_size=config['model']['image_in_size'],
+            bidirectional_enc=config['model']['bidirectional_enc'],
+            bidirectional_context=config['model']['bidirectional_context'],
+            num_enc_layers=config['model']['num_enc_layers'],
+            num_dec_layers=config['model']['num_dec_layers'],
+            num_context_layers=config['model']['num_context_layers'],
+            dropout_enc=config['model']['dropout_enc'],
+            dropout_dec=config['model']['dropout_dec'],
+            dropout_context=config['model']['dropout_context'],
+            max_decode_len=config['model']['max_decode_len'],
+            non_linearity=config['model']['non_linearity'],
+            enc_type=config['model']['enc_type'],
+            dec_type=config['model']['dec_type'],
+            context_type=config['model']['context_type'],
+            use_attention=config['model']['use_attention'],
+            decode_function=config['model']['decode_function'],
+            num_states=args.num_states,
+            use_kb=use_kb, kb_size=kb_size, celeb_vec_size=celeb_vec_size
+            )
+    else:
+        model = HRED(src_vocab_size=vocab_size,
+            tgt_vocab_size=vocab_size,
+            src_emb_dim=config['model']['src_emb_dim'],
+            tgt_emb_dim=config['model']['tgt_emb_dim'],
+            enc_hidden_size=config['model']['enc_hidden_size'],
+            dec_hidden_size=config['model']['dec_hidden_size'],
+            context_hidden_size=config['model']['context_hidden_size'],
+            batch_size=config['data']['batch_size'],
+            image_in_size=config['model']['image_in_size'],
+            bidirectional_enc=config['model']['bidirectional_enc'],
+            bidirectional_context=config['model']['bidirectional_context'],
+            num_enc_layers=config['model']['num_enc_layers'],
+            num_dec_layers=config['model']['num_dec_layers'],
+            num_context_layers=config['model']['num_context_layers'],
+            dropout_enc=config['model']['dropout_enc'],
+            dropout_dec=config['model']['dropout_dec'],
+            dropout_context=config['model']['dropout_context'],
+            max_decode_len=config['model']['max_decode_len'],
+            non_linearity=config['model']['non_linearity'],
+            enc_type=config['model']['enc_type'],
+            dec_type=config['model']['dec_type'],
+            context_type=config['model']['context_type'],
+            use_attention=config['model']['use_attention'],
+            decode_function=config['model']['decode_function'],
+            num_states=args.num_states,
+            use_kb=use_kb, kb_size=kb_size, celeb_vec_size=celeb_vec_size
+            )
     model = torch_utils.gpu_wrapper(model, use_cuda=use_cuda)
     # model = torch.load('model.pkl')
     model.load_state_dict(torch.load(args.checkpoint_path))
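For context, here is a minimal, self-contained sketch (not part of this repository) of the two dispatch styles the diff contrasts: the removed code resolved the model class dynamically via getattr on the models module, while the added code branches explicitly on args.model_type and constructs MultimodalHRED or HRED directly. DummyHRED, DummyMultimodalHRED, and the build_model_* helpers below are invented stand-ins for illustration only.

    import sys

    # Hypothetical stand-ins for the real HRED / MultimodalHRED classes in models.py.
    class DummyHRED:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class DummyMultimodalHRED(DummyHRED):
        pass

    _this_module = sys.modules[__name__]

    def build_model_dynamic(model_type, **kwargs):
        # Old style: resolve the class by name at runtime (getattr-style dispatch).
        model_cls = getattr(_this_module, 'Dummy' + model_type)
        return model_cls(**kwargs)

    def build_model_explicit(model_type, **kwargs):
        # New style (as in this commit): branch explicitly on the type string.
        if model_type == 'MultimodalHRED':
            return DummyMultimodalHRED(**kwargs)
        return DummyHRED(**kwargs)

    if __name__ == '__main__':
        old_style = build_model_dynamic('HRED', src_emb_dim=512)
        new_style = build_model_explicit('MultimodalHRED', src_emb_dim=512)
        print(type(old_style).__name__, type(new_style).__name__)

The explicit branch trades the flexibility of name-based lookup for code that is easier to trace and debug, at the cost of repeating the constructor arguments in each branch, as the diff above shows.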