Skip to content

Commit

Permalink
add MFI strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
nnnyt committed Jan 19, 2021
1 parent ed2212f commit 755527c
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 30 deletions.
33 changes: 20 additions & 13 deletions CAT/model/IRT.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,11 @@
import torch.nn as nn
import numpy as np
import torch.utils.data as data
from math import exp as exp
from sklearn.metrics import roc_auc_score

try:
# for python module
from .abstract_model import AbstractModel
from ..dataset import AdapTestDataset, TrainDataset, Dataset
except (ImportError, SystemError): # pragma: no cover
# for python script
from abstract_model import AbstractModel
from dataset import AdapTestDataset, TrainDataset, Dataset
from CAT.model.abstract_model import AbstractModel
from CAT.dataset import AdapTestDataset, TrainDataset, Dataset


class IRT(nn.Module):
Expand All @@ -41,10 +36,6 @@ def forward(self, student_ids, question_ids):
pred = torch.sigmoid(pred)
return pred

def get_knowledge_status(self, stu_ids):
stu_emb = self.theta(stu_ids)
return stu_emb.data


class IRTModel(AbstractModel):

Expand Down Expand Up @@ -181,4 +172,20 @@ def get_beta(self, question_id):
return self.model.beta.weight.data.numpy()[question_id]

def get_theta(self, student_id):
return self.model.theta.weight.data.numpy()[student_id]
return self.model.theta.weight.data.numpy()[student_id]

def get_iif(self, student_id, question_id):
device = self.config['device']
sid = torch.LongTensor([student_id]).to(device)
qid = torch.LongTensor([question_id]).to(device)
theta = self.model.theta(sid).clone().detach().requires_grad_(True)
alpha = self.model.alpha(qid).clone().detach()
beta = self.model.beta(qid).clone().detach()
pred = (alpha * theta).sum(dim=1, keepdim=True) + beta
pred = torch.sigmoid(pred)
pred.backward()
pred = pred.data.numpy()[0][0]
grad = theta.grad.data.numpy()[0][0]
return grad ** 2 / (pred * (1 - pred))


7 changes: 1 addition & 6 deletions CAT/model/abstract_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
from abc import ABC, abstractmethod
try:
# for python module
from ..dataset import AdapTestDataset, TrainDataset, Dataset
except (ImportError, SystemError): # pragma: no cover
# for python script
from dataset import AdapTestDataset, TrainDataset, Dataset
from CAT.dataset import AdapTestDataset, TrainDataset, Dataset


class AbstractModel(ABC):
Expand Down
31 changes: 31 additions & 0 deletions CAT/strategy/MFI_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np

from CAT.strategy.abstract_strategy import AbstractStrategy
from CAT.model import AbstractModel
from CAT.dataset import AdapTestDataset


class MFIStrategy(AbstractStrategy):

def __init__(self):
super().__init__()

@property
def name(self):
return 'Maximum Fisher Information Strategy'

def adaptest_select(self, model: AbstractModel, adaptest_data: AdapTestDataset):
assert hasattr(model, 'get_theta'), \
'the models must implement get_theta method'
assert hasattr(model, 'get_iif'), \
'the models must implement get_iif method'
selection = {}
for sid in range(adaptest_data.num_students):
theta = model.get_theta(sid)
untested_questions = np.array(list(adaptest_data.untested[sid]))
untested_iif = []
for qid in untested_questions:
untested_iif.append(model.get_iif(sid, qid))
j = np.argmax(untested_iif)
selection[sid] = untested_questions[j]
return selection
3 changes: 2 additions & 1 deletion CAT/strategy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .abstract_strategy import AbstractStrategy
from .random_strategy import RandomStrategy
from .random_strategy import RandomStrategy
from .MFI_strategy import MFIStrategy
13 changes: 3 additions & 10 deletions CAT/strategy/random_strategy.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
import numpy as np

try:
# for python module
from .abstract_strategy import AbstractStrategy
from ..model import AbstractModel
from ..dataset import AdapTestDataset
except (ImportError, SystemError): # pragma: no cover
# for python script
from abstract_strategy import AbstractStrategy
from model import AbstractModel
from dataset import AdapTestDataset
from CAT.strategy.abstract_strategy import AbstractStrategy
from CAT.model import AbstractModel
from CAT.dataset import AdapTestDataset


class RandomStrategy(AbstractStrategy):
Expand Down

0 comments on commit 755527c

Please sign in to comment.