forked from bigdata-ustc/EduCAT
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
177 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .dataset import Dataset | ||
from .adaptest_dataset import AdapTestDataset | ||
from .train_dataset import TrainDataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from collections import defaultdict, deque | ||
import torch | ||
|
||
try: | ||
# for python module | ||
from .dataset import Dataset | ||
from .train_dataset import TrainDataset | ||
except (ImportError, SystemError): # pragma: no cover | ||
# for python script | ||
from dataset import Dataset | ||
from train_dataset import TrainDataset | ||
|
||
|
||
class AdapTestDataset(Dataset):
    """Dataset that tracks, per student, which questions have been tested.

    Every question a student has a logged response for starts out untested;
    ``apply_selection`` moves one question at a time into the tested queue.
    """

    def __init__(self, data, concept_map,
                 num_students, num_questions, num_concepts):
        """
        Args:
            data: list, [(sid, qid, score)]
            concept_map: dict, concept map {qid: [cid, ...]}
            num_students: int, total student number
            num_questions: int, total question number
            num_concepts: int, total concept number
        """
        super().__init__(data, concept_map,
                         num_students, num_questions, num_concepts)

        # initialize tested and untested set
        # Kept under private names: `tested` / `untested` are read-only
        # properties, so assigning to the public names would raise
        # AttributeError and reading them from the getter would recurse.
        self._tested = None
        self._untested = None
        self.reset()

    def apply_selection(self, student_idx, question_idx):
        """
        Add one untested question to the tested set
        Args:
            student_idx: int
            question_idx: int
        Raises:
            AssertionError: if the question is not in the student's
                untested set
        """
        # `.get` keeps the check from inserting an empty entry into the
        # defaultdict when the student is unknown.
        assert question_idx in self._untested.get(student_idx, ()), \
            'Selected question not allowed'
        self._untested[student_idx].remove(question_idx)
        self._tested[student_idx].append(question_idx)

    def reset(self):
        """
        Set tested set empty and mark every logged question as untested
        """
        self._tested = defaultdict(deque)
        self._untested = defaultdict(set)
        for sid, answers in self.data.items():
            self._untested[sid] = set(answers)

    @property
    def tested(self):
        """defaultdict(deque), {sid: tested qids in selection order}"""
        return self._tested

    @property
    def untested(self):
        """defaultdict(set), {sid: qids not yet selected}"""
        return self._untested

    def get_tested_dataset(self, last=False):
        """
        Get tested data for training
        Args:
            last: bool, True - the last question, False - all the tested questions
        Returns:
            TrainDataset
        """
        triplets = []
        for sid, qids in self._tested.items():
            if not qids:
                # A bare read of `tested[sid]` creates an empty deque;
                # there is nothing to emit for it (and qids[-1] would fail).
                continue
            selected = (qids[-1],) if last else qids
            for qid in selected:
                triplets.append((sid, qid, self.data[sid][qid]))
        return TrainDataset(triplets, self.concept_map,
                            self.num_students, self.num_questions,
                            self.num_concepts)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from collections import defaultdict, deque | ||
|
||
|
||
class Dataset(object):
    """Container for response logs and the question-to-concept map.

    The raw ``(sid, qid, score)`` triplets are kept as given and are also
    reorganized into a nested ``{sid: {qid: score}}`` mapping.
    """

    def __init__(self, data, concept_map,
                 num_students, num_questions, num_concepts):
        """
        Args:
            data: list, [(sid, qid, score)]
            concept_map: dict, concept map {qid: [cid, ...]}
            num_students: int, total student number
            num_questions: int, total question number
            num_concepts: int, total concept number
        Raises:
            AssertionError: if any student / question / concept id is not
                smaller than the corresponding declared total
        """
        # Stored under private names: the public names are read-only
        # properties, so assigning to them would raise AttributeError.
        self._raw_data = data
        self._concept_map = concept_map
        self.n_students = num_students
        self.n_questions = num_questions
        self.n_concepts = num_concepts

        # reorganize datasets into {sid: {qid: score}}
        self._data = {}
        for sid, qid, correct in data:
            self._data.setdefault(sid, {})[qid] = correct

        student_ids = {x[0] for x in data}
        question_ids = {x[1] for x in data}
        # Flatten the per-question concept lists in one linear pass.
        concept_ids = {cid for cids in concept_map.values() for cid in cids}

        # default=-1 lets an empty dataset / concept map pass the checks
        # instead of failing inside max().
        assert max(student_ids, default=-1) < num_students, \
            'Require student ids renumbered'
        assert max(question_ids, default=-1) < num_questions, \
            'Require question ids renumbered'
        assert max(concept_ids, default=-1) < num_concepts, \
            'Require concept ids renumbered'

    @property
    def num_students(self):
        """int, total student number"""
        return self.n_students

    @property
    def num_questions(self):
        """int, total question number"""
        return self.n_questions

    @property
    def num_concepts(self):
        """int, total concept number"""
        return self.n_concepts

    @property
    def raw_data(self):
        """The (sid, qid, score) triplets exactly as passed to __init__"""
        return self._raw_data

    @property
    def data(self):
        """dict, {sid: {qid: score}}"""
        return self._data

    @property
    def concept_map(self):
        """dict, {qid: [cid, ...]} exactly as passed to __init__"""
        return self._concept_map
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from torch.utils import data | ||
|
||
try: | ||
# for python module | ||
from .dataset import Dataset | ||
except (ImportError, SystemError): # pragma: no cover | ||
# for python script | ||
from dataset import Dataset | ||
|
||
|
||
class TrainDataset(Dataset, data.dataset.Dataset):
    """Response logs exposed through the torch map-style dataset protocol."""

    def __init__(self, data, concept_map,
                 num_students, num_questions, num_concepts):
        """
        Args:
            data: list, [(sid, qid, score)]
            concept_map: dict, concept map {qid: cid}
            num_students: int, total student number
            num_questions: int, total question number
            num_concepts: int, total concept number
        """
        # Nothing extra to set up here; all state lives in the base class.
        super().__init__(
            data,
            concept_map,
            num_students,
            num_questions,
            num_concepts,
        )

    def __len__(self):
        """Number of logged responses."""
        return len(self.raw_data)

    def __getitem__(self, item):
        """Return (sid, qid, score, concepts) for the log at index `item`."""
        student, question, score = self.raw_data[item]
        return student, question, score, self.concept_map[question]