Skip to content

Commit

Permalink
add dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
nnnyt committed Jan 19, 2021
1 parent a3e1b2a commit 135b431
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CAT/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .dataset import Dataset
from .adaptest_dataset import AdapTestDataset
from .train_dataset import TrainDataset
80 changes: 80 additions & 0 deletions CAT/dataset/adaptest_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from collections import defaultdict, deque
import torch

try:
# for python module
from .dataset import Dataset
from .train_dataset import TrainDataset
except (ImportError, SystemError): # pragma: no cover
# for python script
from dataset import Dataset
from train_dataset import TrainDataset


class AdapTestDataset(Dataset):

def __init__(self, data, concept_map,
num_students, num_questions, num_concepts):
"""
Args:
data: list, [(sid, qid, score)]
concept_map: dict, concept map {qid: cid}
num_students: int, total student number
num_questions: int, total question number
num_concepts: int, total concept number
"""
super().__init__(data, concept_map,
num_students, num_questions, num_concepts)

# initialize tested and untested set
self.tested = None
self.untested = None
self.reset()

def apply_selection(self, student_idx, question_idx):
"""
Add one untested question to the tested set
Args:
student_idx: int
question_idx: int
"""
assert question_idx in self.untested[student_idx], \
'Selected question not allowed'
self.untested[student_idx].remove(question_idx)
self.tested[student_idx].append(question_idx)

def reset(self):
"""
Set tested set empty
"""
self.tested = defaultdict(deque)
self.untested = defaultdict(set)
for sid in self.data:
self.untested[sid] = set(self.data[sid].keys())

@property
def tested(self):
return self.tested

@property
def untested(self):
return self.untested

def get_tested_dataset(self, last=False):
"""
Get tested data for training
Args:
last: bool, True - the last question, False - all the tested questions
Returns:
TrainDataset
"""
triplets = []
for sid, qids in self.tested.items():
if last:
qid = qids[-1]
triplets.append((sid, qid, self.data[sid][qid]))
else:
for qid in qids:
triplets.append((sid, qid, self.data[sid][qid]))
return TrainDataset(triplets, self.concept_map,
self.num_students, self.num_questions, self.num_concepts)
62 changes: 62 additions & 0 deletions CAT/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from collections import defaultdict, deque


class Dataset(object):

def __init__(self, data, concept_map,
num_students, num_questions, num_concepts):
"""
Args:
data: list, [(sid, qid, score)]
concept_map: dict, concept map {qid: cid}
num_students: int, total student number
num_questions: int, total question number
num_concepts: int, total concept number
"""
self.raw_data = data
self.concept_map = concept_map
self.n_students = num_students
self.n_questions = num_questions
self.n_concepts = num_concepts

# reorganize datasets
self.data = {}
for sid, qid, correct in data:
self.data.setdefault(sid, {})
self.data[sid].setdefault(qid, {})
self.data[sid][qid] = correct

student_ids = set(x[0] for x in data)
question_ids = set(x[1] for x in data)
concept_ids = set(sum(concept_map.values(), []))

assert max(student_ids) < num_students, \
'Require student ids renumbered'
assert max(question_ids) < num_questions, \
'Require student ids renumbered'
assert max(concept_ids) < num_concepts, \
'Require student ids renumbered'

@property
def num_students(self):
return self.n_students

@property
def num_questions(self):
return self.n_questions

@property
def num_concepts(self):
return self.n_concepts

@property
def raw_data(self):
return self.raw_data

@property
def data(self):
return self.data

@property
def concept_map(self):
return self.concept_map
32 changes: 32 additions & 0 deletions CAT/dataset/train_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from torch.utils import data

try:
# for python module
from .dataset import Dataset
except (ImportError, SystemError): # pragma: no cover
# for python script
from dataset import Dataset


class TrainDataset(Dataset, data.dataset.Dataset):

def __init__(self, data, concept_map,
num_students, num_questions, num_concepts):
"""
Args:
data: list, [(sid, qid, score)]
concept_map: dict, concept map {qid: cid}
num_students: int, total student number
num_questions: int, total question number
num_concepts: int, total concept number
"""
super().__init__(data, concept_map,
num_students, num_questions, num_concepts)

def __getitem__(self, item):
sid, qid, score = self.raw_data[item]
concepts = self.concept_map[qid]
return sid, qid, score, concepts

def __len__(self):
return len(self.raw_data)

0 comments on commit 135b431

Please sign in to comment.