From 135b43143529f1d636f27298b220b0d9df8c7b00 Mon Sep 17 00:00:00 2001
From: nnnyt <793313994@qq.com>
Date: Wed, 13 Jan 2021 22:27:54 +0800
Subject: [PATCH] add dataset

---
Review note: the file contents below were corrected by hand after export
(read-only properties no longer shadow the attributes they return), so the
blob ids on the "index" lines are stale.

 CAT/dataset/__init__.py         |  3 ++
 CAT/dataset/adaptest_dataset.py | 86 +++++++++++++++++++++++++++++++++
 CAT/dataset/dataset.py          | 65 +++++++++++++++++++++++++
 CAT/dataset/train_dataset.py    | 32 ++++++++++++
 4 files changed, 186 insertions(+)
 create mode 100644 CAT/dataset/__init__.py
 create mode 100644 CAT/dataset/adaptest_dataset.py
 create mode 100644 CAT/dataset/dataset.py
 create mode 100644 CAT/dataset/train_dataset.py

diff --git a/CAT/dataset/__init__.py b/CAT/dataset/__init__.py
new file mode 100644
index 0000000..9df5a16
--- /dev/null
+++ b/CAT/dataset/__init__.py
@@ -0,0 +1,3 @@
+from .dataset import Dataset
+from .adaptest_dataset import AdapTestDataset
+from .train_dataset import TrainDataset
\ No newline at end of file
diff --git a/CAT/dataset/adaptest_dataset.py b/CAT/dataset/adaptest_dataset.py
new file mode 100644
index 0000000..23fb179
--- /dev/null
+++ b/CAT/dataset/adaptest_dataset.py
@@ -0,0 +1,86 @@
+from collections import defaultdict, deque
+import torch
+
+try:
+    # for python module
+    from .dataset import Dataset
+    from .train_dataset import TrainDataset
+except (ImportError, SystemError):  # pragma: no cover
+    # for python script
+    from dataset import Dataset
+    from train_dataset import TrainDataset
+
+
+class AdapTestDataset(Dataset):
+    """Dataset that tracks which questions each student has been given."""
+
+    def __init__(self, data, concept_map,
+                 num_students, num_questions, num_concepts):
+        """
+        Args:
+            data: list, [(sid, qid, score)]
+            concept_map: dict, concept map {qid: [cid, ...]}
+            num_students: int, total student number
+            num_questions: int, total question number
+            num_concepts: int, total concept number
+        """
+        super().__init__(data, concept_map,
+                         num_students, num_questions, num_concepts)
+
+        # initialize tested and untested set; the public names are
+        # read-only properties, so the state lives in private attributes
+        self._tested = None
+        self._untested = None
+        self.reset()
+
+    def apply_selection(self, student_idx, question_idx):
+        """
+        Add one untested question to the tested set
+        Args:
+            student_idx: int
+            question_idx: int
+        """
+        # .get() avoids creating an empty entry for an unknown student
+        assert question_idx in self._untested.get(student_idx, ()), \
+            'Selected question not allowed'
+        self._untested[student_idx].remove(question_idx)
+        self._tested[student_idx].append(question_idx)
+
+    def reset(self):
+        """
+        Set tested set empty
+        """
+        self._tested = defaultdict(deque)
+        self._untested = defaultdict(set)
+        for sid in self.data:
+            self._untested[sid] = set(self.data[sid].keys())
+
+    @property
+    def tested(self):
+        return self._tested
+
+    @property
+    def untested(self):
+        return self._untested
+
+    def get_tested_dataset(self, last=False):
+        """
+        Get tested data for training
+        Args:
+            last: bool, True - the last question, False - all the tested questions
+        Returns:
+            TrainDataset
+        """
+        triplets = []
+        for sid, qids in self._tested.items():
+            if not qids:
+                # empty entry left behind by a lookup on the defaultdict
+                continue
+            if last:
+                qid = qids[-1]
+                triplets.append((sid, qid, self.data[sid][qid]))
+            else:
+                for qid in qids:
+                    triplets.append((sid, qid, self.data[sid][qid]))
+        return TrainDataset(triplets, self.concept_map,
+                            self.num_students, self.num_questions, self.num_concepts)
\ No newline at end of file
diff --git a/CAT/dataset/dataset.py b/CAT/dataset/dataset.py
new file mode 100644
index 0000000..48195a3
--- /dev/null
+++ b/CAT/dataset/dataset.py
@@ -0,0 +1,65 @@
+from collections import defaultdict, deque
+
+
+class Dataset(object):
+    """Student responses indexed as {sid: {qid: score}}."""
+
+    def __init__(self, data, concept_map,
+                 num_students, num_questions, num_concepts):
+        """
+        Args:
+            data: list, [(sid, qid, score)]
+            concept_map: dict, concept map {qid: [cid, ...]}
+            num_students: int, total student number
+            num_questions: int, total question number
+            num_concepts: int, total concept number
+        """
+        # the public names are read-only properties, so the values are
+        # stored in underscore-prefixed attributes
+        self._raw_data = data
+        self._concept_map = concept_map
+        self.n_students = num_students
+        self.n_questions = num_questions
+        self.n_concepts = num_concepts
+
+        # reorganize datasets
+        self._data = {}
+        for sid, qid, correct in data:
+            self._data.setdefault(sid, {})[qid] = correct
+
+        student_ids = set(x[0] for x in data)
+        question_ids = set(x[1] for x in data)
+        # flatten the per-question concept lists without quadratic copying
+        concept_ids = {cid for cids in concept_map.values() for cid in cids}
+
+        # default=-1 lets empty inputs pass the range checks
+        assert max(student_ids, default=-1) < num_students, \
+            'Require student ids renumbered'
+        assert max(question_ids, default=-1) < num_questions, \
+            'Require question ids renumbered'
+        assert max(concept_ids, default=-1) < num_concepts, \
+            'Require concept ids renumbered'
+
+    @property
+    def num_students(self):
+        return self.n_students
+
+    @property
+    def num_questions(self):
+        return self.n_questions
+
+    @property
+    def num_concepts(self):
+        return self.n_concepts
+
+    @property
+    def raw_data(self):
+        return self._raw_data
+
+    @property
+    def data(self):
+        return self._data
+
+    @property
+    def concept_map(self):
+        return self._concept_map
\ No newline at end of file
diff --git a/CAT/dataset/train_dataset.py b/CAT/dataset/train_dataset.py
new file mode 100644
index 0000000..718052d
--- /dev/null
+++ b/CAT/dataset/train_dataset.py
@@ -0,0 +1,32 @@
+from torch.utils import data
+
+try:
+    # for python module
+    from .dataset import Dataset
+except (ImportError, SystemError):  # pragma: no cover
+    # for python script
+    from dataset import Dataset
+
+
+class TrainDataset(Dataset, data.dataset.Dataset):
+
+    def __init__(self, data, concept_map,
+                 num_students, num_questions, num_concepts):
+        """
+        Args:
+            data: list, [(sid, qid, score)]
+            concept_map: dict, concept map {qid: [cid, ...]}
+            num_students: int, total student number
+            num_questions: int, total question number
+            num_concepts: int, total concept number
+        """
+        super().__init__(data, concept_map,
+                         num_students, num_questions, num_concepts)
+
+    def __getitem__(self, item):
+        sid, qid, score = self.raw_data[item]
+        concepts = self.concept_map[qid]
+        return sid, qid, score, concepts
+
+    def __len__(self):
+        return len(self.raw_data)
\ No newline at end of file