Skip to content

Commit

Permalink
add KLI strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
nnnyt committed Jan 22, 2021
1 parent 755527c commit 914f9a1
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 6 deletions.
29 changes: 27 additions & 2 deletions CAT/model/IRT.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch.utils.data as data
from math import exp as exp
from sklearn.metrics import roc_auc_score
from scipy import integrate

from CAT.model.abstract_model import AbstractModel
from CAT.dataset import AdapTestDataset, TrainDataset, Dataset
Expand Down Expand Up @@ -119,8 +120,8 @@ def adaptest_update(self, adaptest_data: AdapTestDataset):
bz_loss.backward()
optimizer.step()
loss += bz_loss.data.float()
if cnt % log_steps == 0:
print('Epoch [{}] Batch [{}]: loss={:.3f}'.format(ep, cnt, loss / cnt))
# if cnt % log_steps == 0:
# print('Epoch [{}] Batch [{}]: loss={:.3f}'.format(ep, cnt, loss / cnt))

def evaluate(self, adaptest_data: AdapTestDataset):
data = adaptest_data.data
Expand Down Expand Up @@ -187,5 +188,29 @@ def get_iif(self, student_id, question_id):
pred = pred.data.numpy()[0][0]
grad = theta.grad.data.numpy()[0][0]
return grad ** 2 / (pred * (1 - pred))

def kli(self, x, student_id, question_id, alpha, beta, pred_estimate):
pred = alpha * x + beta
pred = 1 / (1 + np.exp(-pred))
q_estimate = 1 - pred_estimate
q = 1 - pred
return pred_estimate * np.log(pred_estimate / pred) + q_estimate * np.log((q_estimate / q))

def get_kli(self, student_id, question_id, n):
device = self.config['device']
sid = torch.LongTensor([student_id]).to(device)
qid = torch.LongTensor([question_id]).to(device)
theta = self.model.theta(sid).clone().detach().numpy()[0][0]
alpha = self.model.alpha(qid).clone().detach().numpy()[0][0]
beta = self.model.beta(qid).clone().detach().numpy()[0][0]
pred_estimate = alpha * theta + beta
pred_estimate = 1 / (1 + np.exp(-pred_estimate))
c = 3
low = theta - c / np.sqrt(n)
high = theta + c / np.sqrt(n)
v, err = integrate.quad(self.kli, low, high, args=(sid, qid, alpha, beta, pred_estimate))
return v




31 changes: 31 additions & 0 deletions CAT/strategy/KLI_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np
import torch

from CAT.strategy.abstract_strategy import AbstractStrategy
from CAT.model import AbstractModel
from CAT.dataset import AdapTestDataset


class KLIStrategy(AbstractStrategy):

def __init__(self):
super().__init__()

@property
def name(self):
return 'KL Information Strategy'

def adaptest_select(self, model: AbstractModel, adaptest_data: AdapTestDataset):
assert hasattr(model, 'get_kli'), \
'the models must implement get_kli method'
selection = {}
n = len(adaptest_data.tested[0])
for sid in range(adaptest_data.num_students):
theta = model.get_theta(sid)
untested_questions = np.array(list(adaptest_data.untested[sid]))
untested_kli = []
for qid in untested_questions:
untested_kli.append(model.get_kli(sid, qid, n))
j = np.argmax(untested_kli)
selection[sid] = untested_questions[j]
return selection
4 changes: 1 addition & 3 deletions CAT/strategy/MFI_strategy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import torch

from CAT.strategy.abstract_strategy import AbstractStrategy
from CAT.model import AbstractModel
Expand All @@ -15,13 +16,10 @@ def name(self):
return 'Maximum Fisher Information Strategy'

def adaptest_select(self, model: AbstractModel, adaptest_data: AdapTestDataset):
assert hasattr(model, 'get_theta'), \
'the models must implement get_theta method'
assert hasattr(model, 'get_iif'), \
'the models must implement get_iif method'
selection = {}
for sid in range(adaptest_data.num_students):
theta = model.get_theta(sid)
untested_questions = np.array(list(adaptest_data.untested[sid]))
untested_iif = []
for qid in untested_questions:
Expand Down
3 changes: 2 additions & 1 deletion CAT/strategy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .abstract_strategy import AbstractStrategy
from .random_strategy import RandomStrategy
from .MFI_strategy import MFIStrategy
from .MFI_strategy import MFIStrategy
from .KLI_strategy import KLIStrategy

0 comments on commit 914f9a1

Please sign in to comment.