-
Notifications
You must be signed in to change notification settings - Fork 56
/
dataset.py
89 lines (81 loc) · 5.83 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import torch
import torch.nn as nn
from datasets.sick import SICK
from datasets.msrvid import MSRVID
from datasets.trecqa import TRECQA
from datasets.wikiqa import WikiQA
from datasets.pit2015 import PIT2015
from datasets.snli import SNLI
from datasets.sts2014 import STS2014
from datasets.quora import Quora
class UnknownWordVecCache(object):
    """
    Caches the first randomly generated word vector for each tensor size so
    that the same unknown-word vector is reused for every lookup of that size.
    """
    cache = {}

    @classmethod
    def unk(cls, tensor):
        # Key on the full shape: all tensors of this size share one vector.
        shape = tuple(tensor.size())
        if shape not in cls.cache:
            unk_vector = torch.Tensor(tensor.size())
            unk_vector.normal_(0, 0.01)
            cls.cache[shape] = unk_vector
        return cls.cache[shape]
class DatasetFactory(object):
    """
    Get the corresponding Dataset class for a particular dataset.

    Every supported dataset lives under <castor_dir>/../Castor-data/datasets/
    and exposes an ``iters`` classmethod plus a ``TEXT_FIELD`` whose vocabulary
    carries pretrained vectors; the factory wires those together into data
    loaders and an ``nn.Embedding`` layer.
    """

    @staticmethod
    def _require_trec_eval(castor_dir, utils_trecqa, dataset_label):
        # Helper: raise early if the external trec_eval binary is missing.
        if not os.path.exists(os.path.join(castor_dir, utils_trecqa)):
            raise FileNotFoundError('%s requires the trec_eval tool to run. Please run get_trec_eval.sh inside Castor/utils (as working directory) before continuing.' % dataset_label)

    @staticmethod
    def _load(dataset_cls, dataset_folder, word_vectors_dir, word_vectors_file, batch_size, device, castor_dir):
        # Helper: build (dataset_cls, embedding, train, test, dev) for a
        # standard three-split dataset.
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', dataset_folder)
        train_loader, dev_loader, test_loader = dataset_cls.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(dataset_cls.TEXT_FIELD.vocab.vectors)
        return dataset_cls, embedding, train_loader, test_loader, dev_loader

    @staticmethod
    def get_dataset(dataset_name, word_vectors_dir, word_vectors_file, batch_size, device, castor_dir="./", utils_trecqa="utils/trec_eval-9.0.5/trec_eval"):
        """
        Resolve ``dataset_name`` into a 5-tuple of
        (dataset class, embedding, train_loader, test_loader, dev_loader).

        :raises ValueError: if ``dataset_name`` is not a supported dataset.
        :raises FileNotFoundError: for 'trecqa'/'wikiqa' when the trec_eval
            tool is not installed under ``castor_dir``.
        """
        # Each branch references its dataset class lazily so an invalid name
        # fails fast without touching any dataset module.
        if dataset_name == 'sick':
            return DatasetFactory._load(SICK, 'sick/', word_vectors_dir, word_vectors_file, batch_size, device, castor_dir)
        elif dataset_name == 'msrvid':
            # MSRVID has no dev split: iters returns only train/test loaders.
            dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'msrvid/')
            train_loader, test_loader = MSRVID.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
            embedding = nn.Embedding.from_pretrained(MSRVID.TEXT_FIELD.vocab.vectors)
            return MSRVID, embedding, train_loader, test_loader, None
        elif dataset_name == 'trecqa':
            DatasetFactory._require_trec_eval(castor_dir, utils_trecqa, 'TrecQA')
            return DatasetFactory._load(TRECQA, 'TrecQA/', word_vectors_dir, word_vectors_file, batch_size, device, castor_dir)
        elif dataset_name == 'wikiqa':
            DatasetFactory._require_trec_eval(castor_dir, utils_trecqa, 'WikiQA')
            return DatasetFactory._load(WikiQA, 'WikiQA/', word_vectors_dir, word_vectors_file, batch_size, device, castor_dir)
        elif dataset_name == 'pit2015':
            return DatasetFactory._load(PIT2015, 'SemEval-PIT2015/', word_vectors_dir, word_vectors_file, batch_size, device, castor_dir)
        elif dataset_name == 'twitterurl':
            # The Twitter-URL corpus reuses the PIT2015 reader.
            return DatasetFactory._load(PIT2015, 'Twitter-URL/', word_vectors_dir, word_vectors_file, batch_size, device, castor_dir)
        elif dataset_name == 'snli':
            return DatasetFactory._load(SNLI, 'snli_1.0/', word_vectors_dir, word_vectors_file, batch_size, device, castor_dir)
        elif dataset_name == 'sts2014':
            return DatasetFactory._load(STS2014, 'STS-2014', word_vectors_dir, word_vectors_file, batch_size, device, castor_dir)
        elif dataset_name == 'quora':
            return DatasetFactory._load(Quora, 'quora/', word_vectors_dir, word_vectors_file, batch_size, device, castor_dir)
        else:
            raise ValueError('{} is not a valid dataset.'.format(dataset_name))