Skip to content

Commit

Permalink
Update masked LM code
Browse files Browse the repository at this point in the history
  • Loading branch information
zeynepakkalyoncu committed May 3, 2019
1 parent 1dba761 commit 4bc2225
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ def eval_select(model, tokenizer, model_path, best_score, epoch, arch):
return best_score


def test(args, split="test", model=None, tokenizer=None, training=False):
def test(args, split="test", model=None, tokenizer=None, training_or_lm=False):
if model is None:
epoch, arch, model, tokenizer, scores = load_checkpoint(
args.pytorch_dump_path)

if training:
if training_or_lm:
# Load MB data
test_dataset = load_data(args.data_path, args.data_name,
args.batch_size, tokenizer, split,
Expand Down
16 changes: 6 additions & 10 deletions masked_lm.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,26 @@
from pytorch_pretrained_bert.modeling import BertForMaskedLM, BertConfig
from pytorch_pretrained_bert.modeling import BertForMaskedLM
from pytorch_pretrained_bert import BertTokenizer
import torch

bert_model = 'bert-large-uncased'
model = BertForMaskedLM.from_pretrained(bert_model)
tokenizer = BertTokenizer.from_pretrained(bert_model)

question = "Who discovered relativity ?"
answer = "Einstein or Newton or Bohr"
question = 'who invented the telephone' # "the telephone was invented by whom"
tokenized_question = tokenizer.tokenize(question)
tokenized_answer = tokenizer.tokenize(answer)

masked_index = 0 # Who
masked_index = 0
tokenized_question[masked_index] = '[MASK]'
question_ids = tokenizer.convert_tokens_to_ids(tokenized_question)
answer_ids = tokenizer.convert_tokens_to_ids(tokenized_answer)
print(answer_ids[2])
combined_ids = question_ids + answer_ids
segments_ids = [0] * len(question_ids) + [1] * len(answer_ids)
combined_ids = question_ids
segments_ids = [0] * len(question_ids)

tokens_tensor = torch.tensor([combined_ids])
segments_tensor = torch.tensor([segments_ids])

model.eval()
predictions = model(tokens_tensor, segments_tensor) # 1 x len(combined_ids) x vocab size
predicted_index = torch.topk(predictions[0, masked_index], 300)[1].tolist()
predicted_index = torch.topk(predictions[0, masked_index], 20)[1].tolist()
print(predicted_index)
predicted_token = tokenizer.convert_ids_to_tokens(predicted_index)
print(predicted_token)

0 comments on commit 4bc2225

Please sign in to comment.