Skip to content

Commit

Permalink
Add option to load local model and tokenizer
Browse files Browse the repository at this point in the history
  • Loading branch information
zeynepakkalyoncu committed Apr 25, 2019
1 parent 6e512b9 commit 1db4b07
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 22 deletions.
30 changes: 17 additions & 13 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@

def train(args):
if args.load_trained:
epoch, arch, model, tokenizer, scores = load_checkpoint(args.pytorch_dump_path)
last_epoch, arch, model, tokenizer, scores = load_checkpoint(args.pytorch_dump_path)
else:
model, tokenizer = load_pretrained_model_tokenizer(args.model_type, device=args.device)
# May load local file or download from huggingface
model, tokenizer = load_pretrained_model_tokenizer(args.model_type, local_model=args.local_model,
local_tokenizer=args.local_tokenizer,
device=args.device)
last_epoch = 1
train_dataset = load_data(args.data_path, args.data_name, args.batch_size, tokenizer, "train", args.device)
validate_dataset = load_data(args.data_path, args.data_name, args.batch_size, tokenizer, "dev", args.device)
test_dataset = load_data(args.data_path, args.data_name, args.batch_size, tokenizer, "test", args.device)
Expand All @@ -32,9 +36,10 @@ def train(args):
model.train()
global_step = 0
best_score = 0
for epoch in range(1, args.num_train_epochs+1):
for epoch in range(last_epoch, args.num_train_epochs + 1):
tr_loss = 0
# random.shuffle(train_dataset)
print('epoch: {}'.format(epoch))
for step, batch in enumerate(tqdm(train_dataset)):
if batch is None:
break
Expand All @@ -51,10 +56,10 @@ def train(args):
global_step += 1

if args.eval_steps > 0 and step % args.eval_steps == 0:
best_score = eval_select(model, tokenizer, validate_dataset, test_dataset, args.pytorch_dump_path, best_score, epoch, args.model_type)
best_score = eval_select(model, tokenizer, validate_dataset, test_dataset, args.pytorch_dump_path, best_score, epoch, step, args.model_type)

print("[train] loss: {}".format(tr_loss))
best_score = eval_select(model, tokenizer, validate_dataset, test_dataset, args.pytorch_dump_path, best_score, epoch, args.model_type)
best_score = eval_select(model, tokenizer, validate_dataset, test_dataset, args.pytorch_dump_path, best_score, epoch, step, args.model_type)

scores = test(args, split="test")
print_scores(scores)
Expand Down Expand Up @@ -96,20 +101,17 @@ def load_checkpoint(filename):
state = torch.load(filename)
return state['epoch'], state['arch'], state['model'], state['tokenizer'], state['scores']

def test(args, split="test", model=None, tokenizer=None, test_dataset=None, train=False):
def test(args, split="test", model=None, tokenizer=None):
if model is None:
epoch, arch, model, tokenizer, scores = load_checkpoint(args.pytorch_dump_path)
# if test_dataset is None:
# model, tokenizer = load_pretrained_model_tokenizer(args.model_type, device=args.device)
# _, _, model, tokenizer, _ = load_checkpoint('saved.model_tweet2014_3_best')
pickle.dump(tokenizer, open("tokenizer.pkl", "wb"))

# print("Load test set")
if train:
if split == 'train':
# Load MB data
test_dataset = load_data(args.data_path, args.data_name,
args.batch_size, tokenizer, split,
args.device)
else:
# Load Robust04 data
test_dataset = load_trec_data(args.data_path, args.data_name,
args.batch_size, tokenizer, split, args.device)

Expand Down Expand Up @@ -149,7 +151,7 @@ def test(args, split="test", model=None, tokenizer=None, test_dataset=None, trai
torch.cuda.empty_cache()
model.train()

return [["map", "mrr", "p30"],[map, mrr, p30]]
return [["map", "mrr", "p30"], [map, mrr, p30]]
# return [["acc", "precision", "recall", "f1"], [acc, pre, rec, f1]]
# return [["acc", "p@1", "precision", "recall", "f1"], [acc, p1, pre, rec, f1]]

Expand All @@ -166,6 +168,8 @@ def test(args, split="test", model=None, tokenizer=None, test_dataset=None, trai
parser.add_argument('--pytorch_dump_path', default='saved.model', help='')
parser.add_argument('--load_trained', action='store_true', default=False, help='')
parser.add_argument('--chinese', action='store_true', default=False, help='')
parser.add_argument('--local_model', default=None, help='[None, path to local model file]')
parser.add_argument('--local_tokenizer', default=None, help='[None, path to local vocab file]')
parser.add_argument('--eval_steps', default=-1, type=int, help='evaluation per [eval_steps] steps, -1 for evaluation per epoch')
parser.add_argument('--model_type', default='BertForNextSentencePrediction', help='')
parser.add_argument('--output_path', default='prediction.tmp', help='')
Expand Down
26 changes: 17 additions & 9 deletions util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
from pytorch_pretrained_bert.optimization import BertAdam


def load_pretrained_model_tokenizer(model_type="BertForSequenceClassification", device="cuda", chinese=False):
def load_pretrained_model_tokenizer(model_type="BertForSequenceClassification",
local_model=None, local_tokenizer=None,
device="cuda", chinese=False):
# Load pre-trained model (weights)
if chinese:
base_model = "bert-base-chinese"
else:
base_model = "bert-base-uncased"
if local_model is None:
# Download from huggingface
if chinese:
base_model = "bert-base-chinese"
else:
base_model = "bert-base-uncased"
if model_type == "BertForSequenceClassification":
model = BertForSequenceClassification.from_pretrained(base_model)
# Load pre-trained model tokenizer (vocabulary)
Expand All @@ -26,8 +30,13 @@ def load_pretrained_model_tokenizer(model_type="BertForSequenceClassification",
else:
print("[Error]: unsupported model type")
return None, None

tokenizer = BertTokenizer.from_pretrained(base_model)

if local_tokenizer is None:
# Download from huggingface
tokenizer = BertTokenizer.from_pretrained(base_model)
else:
# Load local vocab file
tokenizer = BertTokenizer.from_pretrained(local_tokenizer)
model.to(device)
return model, tokenizer

Expand Down Expand Up @@ -71,7 +80,6 @@ def get_instance(self):
label, sim, a, b, qid, docid, qidx, didx = \
l.replace("\n", "").split("\t")
return label, sim, a, b, qid, docid, qidx, didx

return None, None, None, None, None, None, None, None

def load_data(data_path, data_name, batch_size, tokenizer, split="train", device="cuda", add_url=False):
Expand Down Expand Up @@ -124,7 +132,7 @@ def load_data(data_path, data_name, batch_size, tokenizer, split="train", device
data_set.append((tokens_tensor, segments_tensor, mask_tensor, label_tensor, qid_tensor, docid_tensor))
test_batch, testqid_batch, mask_batch, label_batch, qid_batch, docqid_batch = [], [], [], [], [], []
yield (tokens_tensor, segments_tensor, mask_tensor, label_tensor, qid_tensor, docid_tensor)

# if split != "train":
# break
yield None
Expand Down

0 comments on commit 1db4b07

Please sign in to comment.