Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2024.1.12 Add new CAT Strategy #10

Merged
merged 12 commits into from
Jan 15, 2024
41 changes: 34 additions & 7 deletions CAT/model/IRT.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
import time
import copy

import vegas
import logging
import torch
Expand All @@ -11,10 +9,12 @@
from math import exp as exp
from sklearn.metrics import roc_auc_score
from scipy import integrate
import time
from CAT.model.abstract_model import AbstractModel
from CAT.dataset import AdapTestDataset, TrainDataset, Dataset
from sklearn.metrics import accuracy_score
from collections import namedtuple
from utils import StraightThrough
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])

class IRT(nn.Module):
def __init__(self, num_students, num_questions, num_dim):
Expand All @@ -39,7 +39,6 @@ def forward(self, student_ids, question_ids):
pred = torch.sigmoid(pred)
return pred


class IRTModel(AbstractModel):

def __init__(self, **config):
Expand All @@ -52,8 +51,11 @@ def name(self):
return 'Item Response Theory'

def init_model(self, data: Dataset):
policy_lr=0.0005
self.model = IRT(data.num_students, data.num_questions, self.config['num_dim'])

self.policy = StraightThrough(data.num_questions, data.num_questions,policy_lr, self.config)
self.n_q = data.num_questions

def train(self, train_data: TrainDataset, log_step=1):
lr = self.config['learning_rate']
batch_size = self.config['batch_size']
Expand Down Expand Up @@ -92,6 +94,8 @@ def adaptest_load(self, path):
"""
Reload the saved model
"""
if self.config['policy'] =='bobcat':
self.policy.policy.load_state_dict(torch.load(self.config['policy_path']),strict=False)
self.model.load_state_dict(torch.load(path), strict=False)
self.model.to(self.config['device'])

Expand Down Expand Up @@ -424,6 +428,29 @@ def expected_model_change(self, sid: int, qid: int, adaptest_data: AdapTestDatas
pred = pred_all[sid][qid]
return pred * torch.norm(pos_weights - original_weights).item() + \
(1 - pred) * torch.norm(neg_weights - original_weights).item()


def bobcat_policy(self,S_set,untested_questions):
""" get expected model change
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update the comments

Args:
S_set:list , the questions have been chosen
untested_questions: dict, untested_questions
Returns:
float, expected model change
"""
device = self.config['device']
action_mask = [0.0] * self.n_q
train_mask=[-0.0]*self.n_q
for index in untested_questions:
action_mask[index] = 1.0
for state in S_set:
keys = list(state.keys())
key = keys[0]
values = list(state.values())
val = values[0]
train_mask[key] = (float(val)-0.5)*2
action_mask = torch.tensor(action_mask).to(device)
train_mask = torch.tensor(train_mask).to(device)
action = self.policy.policy(train_mask, action_mask)
return action.item()


Loading