
Commit

clean and update PIT-2015 eval (#181)
likicode authored and Victor0118 committed Mar 6, 2019
1 parent 34bd440 commit 430fd79
Showing 2 changed files with 96 additions and 32 deletions.
120 changes: 94 additions & 26 deletions common/evaluators/pit2015_evaluator.py
@@ -3,32 +3,100 @@

from .evaluator import Evaluator

+def URL_maxF1_eval(predict_result, test_data_label):
+    test_data_label = [item >= 1 for item in test_data_label]
+    counter = 0
+    tp = 0.0
+    fp = 0.0
+    fn = 0.0
+    tn = 0.0
+
+    for i, t in enumerate(predict_result):
+
+        if t > 0.5:
+            guess = True
+        else:
+            guess = False
+        label = test_data_label[i]
+        # print guess, label
+        if guess == True and label == False:
+            fp += 1.0
+        elif guess == False and label == True:
+            fn += 1.0
+        elif guess == True and label == True:
+            tp += 1.0
+        elif guess == False and label == False:
+            tn += 1.0
+        if label == guess:
+            counter += 1.0
+
+    try:
+        P = tp / (tp + fp)
+        R = tp / (tp + fn)
+        F = 2 * P * R / (P + R)
+    except:
+        P = 0
+        R = 0
+        F = 0
+
+    accuracy = counter / len(predict_result)
+
+    maxF1 = 0
+    P_maxF1 = 0
+    R_maxF1 = 0
+    probs = predict_result
+    sortedindex = sorted(range(len(probs)), key=probs.__getitem__)
+    sortedindex.reverse()
+
+    truepos = 0
+    falsepos = 0
+    for sortedi in sortedindex:
+        if test_data_label[sortedi] == True:
+            truepos += 1
+        elif test_data_label[sortedi] == False:
+            falsepos += 1
+        precision = 0
+        if truepos + falsepos > 0:
+            precision = truepos / (truepos + falsepos)
+
+        if (tp + fn) > 0:
+            recall = truepos / (tp + fn)
+        else:
+            recall = 0
+        f1 = 0
+        if precision + recall > 0:
+            f1 = 2 * precision * recall / (precision + recall)
+        if f1 > maxF1:
+            # print probs[sortedi]
+            maxF1 = f1
+            P_maxF1 = precision
+            R_maxF1 = recall
+    # print("PRECISION: {}, RECALL: {}, max_F1: {}".format(P_maxF1, R_maxF1, maxF1))
+    return (accuracy, maxF1)

class PIT2015Evaluator(Evaluator):

     def get_scores(self):
-        self.model.eval()
-        self.data_loader.init_epoch()
-        n_dev_correct = 0
-        total_loss = 0
-        acc_total = 0
-        rel_total = 0
-        pre_total = 0
-        for batch_idx, batch in enumerate(self.data_loader):
-            sent1, sent2 = self.get_sentence_embeddings(batch)
-            scores = self.model(sent1, sent2, batch.ext_feats, batch.dataset.word_to_doc_cnt, batch.sentence_1_raw, batch.sentence_2_raw)
-            prediction = torch.max(scores, 1)[1].view(batch.label.size()).data
-            gold_label = batch.label.data
-            n_dev_correct += (prediction == gold_label).sum().item()
-            acc_total += ((prediction == batch.label.data) * (prediction == 1)).sum().item()
-            total_loss += F.nll_loss(scores, batch.label, size_average=False).item()
-            rel_total += gold_label.sum().item()
-            pre_total += prediction.sum().item()
-
-            del scores
-
-        precision = acc_total / pre_total
-        recall = acc_total / rel_total
-        f1 = 2 * precision * recall / (precision + recall)
-        accuracy = 100. * n_dev_correct / len(self.data_loader.dataset.examples)
-        avg_loss = total_loss / len(self.data_loader.dataset.examples)
-        return [accuracy, avg_loss, precision, recall, f1], ['accuracy', 'cross_entropy_loss', 'precision', 'recall', 'f1']
+        #self.model.eval()
+        test_loss = 0
+        true_labels = []
+        predictions = []
+
+        with torch.no_grad():
+            for batch in self.data_loader:
+                # Select embedding
+                sent1, sent2 = self.get_sentence_embeddings(batch)
+
+                output = self.model(sent1, sent2, batch.ext_feats, batch.dataset.word_to_doc_cnt, batch.sentence_1_raw, batch.sentence_2_raw)
+
+                test_loss += F.nll_loss(output, batch.label, size_average=False).item()
+
+                true_labels.extend(batch.label.detach().cpu().numpy())
+                predictions.extend(output.detach().exp()[:, 1].cpu().numpy())
+
+                del output
+
+        test_loss /= len(batch.dataset.examples)
+        accuracy, maxF1 = URL_maxF1_eval(predictions, true_labels)
+
+        return [accuracy, test_loss, maxF1], ['accuracy', 'NLL loss', 'f1']
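
A quick way to sanity-check the new URL_maxF1_eval is to compare its threshold sweep against scikit-learn's precision/recall curve. The sketch below is not part of the commit: it assumes scikit-learn and NumPy are installed, that the file above is importable as common.evaluators.pit2015_evaluator, and the toy probabilities and labels are made up.

import numpy as np
from sklearn.metrics import precision_recall_curve

from common.evaluators.pit2015_evaluator import URL_maxF1_eval

# Made-up positive-class probabilities and 0/1 gold labels.
probs = [0.9, 0.8, 0.7, 0.4, 0.3, 0.1]
labels = [1, 1, 0, 1, 0, 0]

# accuracy is computed at the fixed 0.5 threshold; maxF1 is the best F1
# over every cut point of the descending-sorted scores.
accuracy, max_f1 = URL_maxF1_eval(probs, labels)

# Reference computation: sweep all thresholds with scikit-learn.
precision, recall, _ = precision_recall_curve(labels, probs)
f1 = 2 * precision * recall / np.maximum(precision + recall, 1e-12)

print(accuracy)          # 0.667 -- four of six predictions correct at 0.5
print(max_f1, f1.max())  # both ~0.857 (precision 3/4, recall 1 at threshold 0.4)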
8 changes: 2 additions & 6 deletions common/trainers/pit2015_trainer.py
@@ -30,13 +30,11 @@ def train_epoch(self, epoch):
                    100. * batch_idx / (len(self.train_loader)), loss.item() / len(batch))
                )

-        accuracy, avg_loss, precision, recall, f1 = self.evaluate(self.train_evaluator, 'train')
+        accuracy, avg_loss, f1 = self.evaluate(self.train_evaluator, 'train')

        if self.use_tensorboard:
            self.writer.add_scalar('{}/train/cross_entropy_loss'.format(self.train_loader.dataset.NAME), avg_loss, epoch)
            self.writer.add_scalar('{}/train/accuracy'.format(self.train_loader.dataset.NAME), accuracy, epoch)
-            self.writer.add_scalar('{}/train/precision'.format(self.train_loader.dataset.NAME), precision, epoch)
-            self.writer.add_scalar('{}/train/recall'.format(self.train_loader.dataset.NAME), recall, epoch)
            self.writer.add_scalar('{}/train/f1'.format(self.train_loader.dataset.NAME), f1, epoch)

        return total_loss
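
Both hunks in this file make the same mechanical change: the evaluator now reports three metrics instead of five, so the tuple unpacking shrinks and the precision/recall scalars disappear from TensorBoard. A hypothetical alternative, sketched below with made-up values, would zip the returned names and scores so the trainer never needs editing when the metric list changes; writer and epoch stand in for the trainer's real SummaryWriter and epoch counter.

scores, names = [0.667, 0.45, 0.857], ['accuracy', 'NLL loss', 'f1']  # made-up evaluator output
for name, value in zip(names, scores):
    # with a real writer: writer.add_scalar('{}/dev/{}'.format('PIT2015', name), value, epoch)
    print('PIT2015/dev/{}: {:.3f}'.format(name, value))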
@@ -54,15 +52,13 @@ def train(self, epochs):
            self.train_epoch(epoch)

            dev_scores = self.evaluate(self.dev_evaluator, 'dev')
-            accuracy, avg_loss, precision, recall, f1 = dev_scores
+            accuracy, avg_loss, f1 = dev_scores

            test_scores = self.evaluate(self.test_evaluator, 'test')
            if self.use_tensorboard:
                self.writer.add_scalar('{}/lr'.format(self.train_loader.dataset.NAME), self.optimizer.param_groups[0]['lr'], epoch)
                self.writer.add_scalar('{}/dev/cross_entropy_loss'.format(self.train_loader.dataset.NAME), avg_loss, epoch)
                self.writer.add_scalar('{}/dev/accuracy'.format(self.train_loader.dataset.NAME), accuracy, epoch)
-                self.writer.add_scalar('{}/dev/precision'.format(self.train_loader.dataset.NAME), precision, epoch)
-                self.writer.add_scalar('{}/dev/recall'.format(self.train_loader.dataset.NAME), recall, epoch)
                self.writer.add_scalar('{}/dev/f1'.format(self.train_loader.dataset.NAME), f1, epoch)

            end = time.time()
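One detail of the new evaluator worth noting: it sums per-example NLL with size_average=False (replaced by reduction='sum' in later PyTorch releases) and divides by the dataset size at the end, which equals a plain mean over all examples regardless of batching. A minimal, self-contained check with made-up tensors, assuming only that PyTorch is installed:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
log_probs = F.log_softmax(torch.randn(8, 2), dim=1)  # stand-in for model log-probabilities
labels = torch.randint(0, 2, (8,))                   # stand-in for gold labels

# Sum per-example losses, then divide by the number of examples ...
summed = F.nll_loss(log_probs, labels, reduction='sum').item() / len(labels)
# ... which matches taking the mean in one shot.
mean = F.nll_loss(log_probs, labels, reduction='mean').item()
assert abs(summed - mean) < 1e-6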
