forked from bigdata-ustc/EduCAT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
adaptest_dataset.py
95 lines (83 loc) · 3.09 KB
/
adaptest_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from collections import defaultdict, deque
import torch
try:
# for python module
from .dataset import Dataset
from .train_dataset import TrainDataset
except (ImportError, SystemError): # pragma: no cover
# for python script
from dataset import Dataset
from train_dataset import TrainDataset
class AdapTestDataset(Dataset):
    """Dataset that tracks, per student, which questions have been tested.

    Every question a student has a record for starts out "untested";
    ``apply_selection`` moves one question at a time into that student's
    "tested" sequence, and ``get_tested_dataset`` turns the tested records
    into a ``TrainDataset``.

    NOTE(review): this class indexes ``self.data`` as ``self.data[sid][qid]``,
    so the base ``Dataset`` presumably reorganises the input triplets into a
    nested mapping ``{sid: {qid: score}}`` -- confirm against ``Dataset``.
    """

    def __init__(self, data, concept_map,
                 num_students, num_questions, num_concepts):
        """
        Args:
            data: list, [(sid, qid, score)]
            concept_map: dict, concept map {qid: cid}
            num_students: int, total student number
            num_questions: int, total question number
            num_concepts: int, total concept number
        """
        super().__init__(data, concept_map,
                         num_students, num_questions, num_concepts)
        # Placeholders only; reset() builds the real containers.
        self._tested = None
        self._untested = None
        self.reset()

    def apply_selection(self, student_idx, question_idx):
        """
        Move one untested question of a student to the tested set

        Args:
            student_idx: int
            question_idx: int
        Raises:
            AssertionError: if the question is not currently untested for
                this student. When assertions are disabled (``python -O``)
                the ``set.remove`` call below raises ``KeyError`` instead.
        """
        assert question_idx in self._untested[student_idx], \
            'Selected question not allowed'
        self._untested[student_idx].remove(question_idx)
        # Appending keeps selection order, so the last element is always the
        # most recently selected question.
        self._tested[student_idx].append(question_idx)

    def reset(self):
        """
        Set tested set empty and mark every recorded question as untested
        """
        self._tested = defaultdict(deque)
        self._untested = defaultdict(set)
        for sid in self.data:
            self._untested[sid] = set(self.data[sid].keys())

    @property
    def tested(self):
        """defaultdict mapping sid -> deque of tested qids, in selection order.

        Reading a missing sid inserts an empty deque for it (defaultdict
        behaviour).
        """
        return self._tested

    @property
    def untested(self):
        """defaultdict mapping sid -> set of qids not tested yet.

        Reading a missing sid inserts an empty set for it (defaultdict
        behaviour).
        """
        return self._untested

    def get_tested_dataset(self, last=False, ssid=None):
        """
        Get tested data for training

        Args:
            last: bool, True - only the last tested question of each student,
                False - all the tested questions
            ssid: optional student id; None - include every student,
                otherwise - include only the student whose id equals ssid
        Returns:
            TrainDataset built from (sid, qid, score) triplets
        """
        triplets = []
        for sid, qids in self._tested.items():
            # Identity check for None: "== None" can be hijacked by types
            # that overload equality (e.g. array-like ids).
            if ssid is not None and ssid != sid:
                continue
            # A student may own an empty deque (merely reading tested[sid]
            # creates one); there is no "last" question to take then.
            if not qids:
                continue
            scores = self.data[sid]
            selected = (qids[-1],) if last else qids
            triplets.extend((sid, qid, scores[qid]) for qid in selected)
        return TrainDataset(triplets, self.concept_map,
                            self.num_students, self.num_questions,
                            self.num_concepts)