diff --git a/common/dataset.py b/common/dataset.py index 5cffeeb3..9acf486e 100644 --- a/common/dataset.py +++ b/common/dataset.py @@ -51,18 +51,21 @@ def get_dataset(dataset_name, word_vectors_dir, word_vectors_file, batch_size, d return TRECQA, embedding, train_loader, test_loader, dev_loader elif dataset_name == 'wikiqa': if not os.path.exists(os.path.join(castor_dir, utils_trecqa)): - raise FileNotFoundError('TrecQA requires the trec_eval tool to run. Please run get_trec_eval.sh inside Castor/utils (as working directory) before continuing.') + raise FileNotFoundError('WikiQA requires the trec_eval tool to run. Please run get_trec_eval.sh inside Castor/utils (as working directory) before continuing.') dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'WikiQA/') train_loader, dev_loader, test_loader = WikiQA.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk) embedding = nn.Embedding.from_pretrained(WikiQA.TEXT_FIELD.vocab.vectors) return WikiQA, embedding, train_loader, test_loader, dev_loader elif dataset_name == 'pit2015': - if not os.path.exists(os.path.join(castor_dir, utils_trecqa)): - raise FileNotFoundError('TrecQA requires the trec_eval tool to run. Please run get_trec_eval.sh inside Castor/utils (as working directory) before continuing.') dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'SemEval-PIT2015/') train_loader, dev_loader, test_loader = PIT2015.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk) embedding = nn.Embedding.from_pretrained(PIT2015.TEXT_FIELD.vocab.vectors) return PIT2015, embedding, train_loader, test_loader, dev_loader + elif dataset_name == 'twitterurl': + dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'Twitter-URL/') + train_loader, dev_loader, test_loader = PIT2015.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk) + embedding = nn.Embedding.from_pretrained(PIT2015.TEXT_FIELD.vocab.vectors) + return PIT2015, embedding, train_loader, test_loader, dev_loader else: raise ValueError('{} is not a valid dataset.'.format(dataset_name)) diff --git a/common/evaluation.py b/common/evaluation.py index 2fa467d7..8c7f6506 100644 --- a/common/evaluation.py +++ b/common/evaluation.py @@ -19,7 +19,8 @@ class EvaluatorFactory(object): 'SST-2': SSTEvaluator, 'trecqa': TRECQAEvaluator, 'wikiqa': WikiQAEvaluator, - 'pit2015': PIT2015Evaluator + 'pit2015': PIT2015Evaluator, + 'twitterurl': PIT2015Evaluator } evaluator_map_nce = { diff --git a/common/evaluators/pit2015_evaluator.py b/common/evaluators/pit2015_evaluator.py index 5ae2186d..b6d6a78a 100644 --- a/common/evaluators/pit2015_evaluator.py +++ b/common/evaluators/pit2015_evaluator.py @@ -21,8 +21,10 @@ def get_scores(self): n_dev_correct += (prediction == gold_label).sum().item() acc_total += ((prediction == batch.label.data) * (prediction == 1)).sum().item() total_loss += F.nll_loss(scores, batch.label, size_average=False).item() - rel_total += batch.label.data.sum().item() - pre_total += torch.max(scores, 1)[1].view(batch.label.size()).data.sum().item() + rel_total += gold_label.sum().item() + pre_total += prediction.sum().item() + + del scores precision = acc_total / pre_total recall = acc_total / rel_total diff --git a/common/train.py b/common/train.py index bf8a54f8..3099dfdb 100644 --- a/common/train.py +++ b/common/train.py @@ -19,7 +19,8 @@ class TrainerFactory(object): 'SST-2': SSTTrainer, 'trecqa': TRECQATrainer, 'wikiqa': WikiQATrainer, - 'pit2015': PIT2015Trainer + 'pit2015': PIT2015Trainer, + 'twitterurl': PIT2015Trainer } trainer_map_nce = {