Skip to content

Commit

Permalink
2024.1.12 Add new CAT Strategy (bigdata-ustc#10)
Browse files Browse the repository at this point in the history
* Add BECAT

* Fix bug and Change README

* Change README

* Update README.md

* Modified code specification

* Add BOBCAT and NCAT

* Change readme

* code formatted

* Change NCAT code structure
  • Loading branch information
Hhhhhhand authored Jan 15, 2024
1 parent af7d596 commit 76a7cc3
Show file tree
Hide file tree
Showing 8 changed files with 710 additions and 11 deletions.
41 changes: 34 additions & 7 deletions CAT/model/IRT.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
import time
import copy

import vegas
import logging
import torch
Expand All @@ -11,10 +9,12 @@
from math import exp as exp
from sklearn.metrics import roc_auc_score
from scipy import integrate
import time
from CAT.model.abstract_model import AbstractModel
from CAT.dataset import AdapTestDataset, TrainDataset, Dataset
from sklearn.metrics import accuracy_score
from collections import namedtuple
from utils import StraightThrough
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])

class IRT(nn.Module):
def __init__(self, num_students, num_questions, num_dim):
Expand All @@ -39,7 +39,6 @@ def forward(self, student_ids, question_ids):
pred = torch.sigmoid(pred)
return pred


class IRTModel(AbstractModel):

def __init__(self, **config):
Expand All @@ -52,8 +51,11 @@ def name(self):
return 'Item Response Theory'

def init_model(self, data: Dataset):
policy_lr=0.0005
self.model = IRT(data.num_students, data.num_questions, self.config['num_dim'])

self.policy = StraightThrough(data.num_questions, data.num_questions,policy_lr, self.config)
self.n_q = data.num_questions

def train(self, train_data: TrainDataset, log_step=1):
lr = self.config['learning_rate']
batch_size = self.config['batch_size']
Expand Down Expand Up @@ -92,6 +94,8 @@ def adaptest_load(self, path):
"""
Reload the saved model
"""
if self.config['policy'] =='bobcat':
self.policy.policy.load_state_dict(torch.load(self.config['policy_path']),strict=False)
self.model.load_state_dict(torch.load(path), strict=False)
self.model.to(self.config['device'])

Expand Down Expand Up @@ -424,6 +428,29 @@ def expected_model_change(self, sid: int, qid: int, adaptest_data: AdapTestDatas
pred = pred_all[sid][qid]
return pred * torch.norm(pos_weights - original_weights).item() + \
(1 - pred) * torch.norm(neg_weights - original_weights).item()


def bobcat_policy(self,S_set,untested_questions):
""" get expected model change
Args:
S_set:list , the questions have been chosen
untested_questions: dict, untested_questions
Returns:
float, expected model change
"""
device = self.config['device']
action_mask = [0.0] * self.n_q
train_mask=[-0.0]*self.n_q
for index in untested_questions:
action_mask[index] = 1.0
for state in S_set:
keys = list(state.keys())
key = keys[0]
values = list(state.values())
val = values[0]
train_mask[key] = (float(val)-0.5)*2
action_mask = torch.tensor(action_mask).to(device)
train_mask = torch.tensor(train_mask).to(device)
action = self.policy.policy(train_mask, action_mask)
return action.item()


44 changes: 44 additions & 0 deletions CAT/model/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

def hard_sample(logits, dim=-1):
y_soft = F.softmax(logits, dim=-1)
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
return ret, index

class Actor(nn.Module):
def __init__(self, state_dim, action_dim, n_latent_var=256):
super().__init__()
# actor
self.obs_layer = nn.Linear(state_dim, n_latent_var)
self.actor_layer = nn.Sequential(
nn.Linear(n_latent_var, n_latent_var),
nn.Tanh(),
nn.Linear(n_latent_var, action_dim)
)

def forward(self, state, action_mask):
hidden_state = self.obs_layer(state)
logits = self.actor_layer(hidden_state)
inf_mask = torch.clamp(torch.log(action_mask.float()),
min=torch.finfo(torch.float32).min)
logits = logits + inf_mask
actions = hard_sample(logits)
return actions

class StraightThrough:
def __init__(self, state_dim, action_dim, lr, config):
self.lr = lr
device = config['device']
self.betas = config['betas']
self.policy = Actor(state_dim, action_dim).to(device)
self.optimizer = torch.optim.Adam(
self.policy.parameters(), lr=lr, betas=self.betas)

def update(self, loss):
self.optimizer.zero_grad()
loss.mean().backward()
self.optimizer.step()
25 changes: 25 additions & 0 deletions CAT/strategy/BOBCAT_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import numpy as np
from scipy.optimize import minimize
from CAT.strategy.abstract_strategy import AbstractStrategy
from CAT.model import AbstractModel
from CAT.dataset import AdapTestDataset

class BOBCAT(AbstractStrategy):

def __init__(self):
super().__init__()

@property
def name(self):
return 'BOBCAT'
def adaptest_select(self, model: AbstractModel, adaptest_data: AdapTestDataset,S_set):
assert hasattr(model, 'get_kli'), \
'the models must implement get_kli method'
assert hasattr(model, 'get_pred'), \
'the models must implement get_pred method for accelerating'
selection = {}
for sid in range(adaptest_data.num_students):
untested_questions = np.array(list(adaptest_data.untested[sid]))
j = model.bobcat_policy(S_set[sid],untested_questions)
selection[sid] = j
return selection
Loading

0 comments on commit 76a7cc3

Please sign in to comment.