Skip to content

Commit

Permalink
Add Twitter-URL dataset (castorini#145)
Browse files (browse the repository at this point in the history)
* add prediction/qrel files dump option

* fix comment

* add twitter-url dataset, minor refactor

* fix minor error
  • Loading branch information
Victor0118 authored Sep 8, 2018
1 parent 5afb845 commit 10e4e56
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 7 deletions.
9 changes: 6 additions & 3 deletions common/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,21 @@ def get_dataset(dataset_name, word_vectors_dir, word_vectors_file, batch_size, d
return TRECQA, embedding, train_loader, test_loader, dev_loader
elif dataset_name == 'wikiqa':
if not os.path.exists(os.path.join(castor_dir, utils_trecqa)):
- raise FileNotFoundError('TrecQA requires the trec_eval tool to run. Please run get_trec_eval.sh inside Castor/utils (as working directory) before continuing.')
+ raise FileNotFoundError('WikiQA requires the trec_eval tool to run. Please run get_trec_eval.sh inside Castor/utils (as working directory) before continuing.')
dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'WikiQA/')
train_loader, dev_loader, test_loader = WikiQA.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
embedding = nn.Embedding.from_pretrained(WikiQA.TEXT_FIELD.vocab.vectors)
return WikiQA, embedding, train_loader, test_loader, dev_loader
elif dataset_name == 'pit2015':
- if not os.path.exists(os.path.join(castor_dir, utils_trecqa)):
- raise FileNotFoundError('TrecQA requires the trec_eval tool to run. Please run get_trec_eval.sh inside Castor/utils (as working directory) before continuing.')
dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'SemEval-PIT2015/')
train_loader, dev_loader, test_loader = PIT2015.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
embedding = nn.Embedding.from_pretrained(PIT2015.TEXT_FIELD.vocab.vectors)
return PIT2015, embedding, train_loader, test_loader, dev_loader
+ elif dataset_name == 'twitterurl':
+ dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'Twitter-URL/')
+ train_loader, dev_loader, test_loader = PIT2015.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
+ embedding = nn.Embedding.from_pretrained(PIT2015.TEXT_FIELD.vocab.vectors)
+ return PIT2015, embedding, train_loader, test_loader, dev_loader
else:
raise ValueError('{} is not a valid dataset.'.format(dataset_name))

3 changes: 2 additions & 1 deletion common/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class EvaluatorFactory(object):
'SST-2': SSTEvaluator,
'trecqa': TRECQAEvaluator,
'wikiqa': WikiQAEvaluator,
- 'pit2015': PIT2015Evaluator
+ 'pit2015': PIT2015Evaluator,
+ 'twitterurl': PIT2015Evaluator
}

evaluator_map_nce = {
Expand Down
6 changes: 4 additions & 2 deletions common/evaluators/pit2015_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ def get_scores(self):
n_dev_correct += (prediction == gold_label).sum().item()
acc_total += ((prediction == batch.label.data) * (prediction == 1)).sum().item()
total_loss += F.nll_loss(scores, batch.label, size_average=False).item()
- rel_total += batch.label.data.sum().item()
- pre_total += torch.max(scores, 1)[1].view(batch.label.size()).data.sum().item()
+ rel_total += gold_label.sum().item()
+ pre_total += prediction.sum().item()
+
+ del scores

precision = acc_total / pre_total
recall = acc_total / rel_total
Expand Down
3 changes: 2 additions & 1 deletion common/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class TrainerFactory(object):
'SST-2': SSTTrainer,
'trecqa': TRECQATrainer,
'wikiqa': WikiQATrainer,
- 'pit2015': PIT2015Trainer
+ 'pit2015': PIT2015Trainer,
+ 'twitterurl': PIT2015Trainer
}

trainer_map_nce = {
Expand Down

0 comments on commit 10e4e56

Please sign in to comment.