2024.1.12 Add new CAT Strategy #10
Changes from 9 commits
@@ -8,13 +8,60 @@
import numpy as np
import math
import torch.utils.data as data
import torch.optim as optim
import copy as cp
import torch.nn.functional as F
from math import exp as exp
from sklearn.metrics import roc_auc_score
from scipy import integrate
import time
from CAT.model.abstract_model import AbstractModel
from CAT.dataset import AdapTestDataset, TrainDataset, Dataset
from sklearn.metrics import accuracy_score
from collections import namedtuple

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])

def hard_sample(logits, dim=-1):
    # Straight-through argmax: the forward value is a one-hot vector,
    # but gradients flow back through the softmax probabilities.
    y_soft = F.softmax(logits, dim=dim)
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
    ret = y_hard - y_soft.detach() + y_soft
    return ret, index
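For reviewers unfamiliar with the trick: this is a straight-through estimator, so the forward pass works with the hard one-hot choice while gradients still reach the logits through the soft term. A tiny illustration (the tensor values are made up for this comment, not part of the PR):

# Illustration only: a single 3-way choice where index 2 has the largest logit.
logits = torch.tensor([[0.1, 0.2, 3.0]], requires_grad=True)
one_hot, idx = hard_sample(logits)          # one_hot ~ [[0., 0., 1.]], idx = [[2]]
loss = -(one_hot * torch.tensor([[1.0, 2.0, 3.0]])).sum()
loss.backward()                             # logits.grad is non-zero thanks to the soft term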

class Actor(nn.Module):
Review comment: Can these new strategies be applied to other CDMs such as NCD? If they can, I would recommend adding another file such as …
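One way to read that suggestion is sketched below; the file name, class name, and wrapper interface are illustrative only and not part of this PR. The idea is to keep the selection policy CDM-agnostic so an NCD model could reuse it unchanged:

# CAT/model/bobcat_policy.py  (hypothetical file; Actor and StraightThrough would move here)
import torch.nn as nn

class SelectionPolicy(nn.Module):
    """CDM-agnostic wrapper around the BOBCAT Actor, shareable by IRT, NCD, etc."""

    def __init__(self, num_questions, n_latent_var=256):
        super().__init__()
        self.actor = Actor(num_questions, num_questions, n_latent_var)

    def forward(self, train_mask, action_mask):
        # Returns a tensor holding the index of the next question to administer.
        return self.actor(train_mask, action_mask)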
    def __init__(self, state_dim, action_dim, n_latent_var=256):
        super().__init__()
        # actor
        self.obs_layer = nn.Linear(state_dim, n_latent_var)
        self.actor_layer = nn.Sequential(
            nn.Linear(n_latent_var, n_latent_var),
            nn.Tanh(),
            nn.Linear(n_latent_var, action_dim)
        )

    def forward(self, state, action_mask):
        hidden_state = self.obs_layer(state)
        logits = self.actor_layer(hidden_state)
        # Already-tested questions have mask 0, so log(0) -> -inf (clamped)
        # removes them from the straight-through argmax.
        inf_mask = torch.clamp(torch.log(action_mask.float()),
                               min=torch.finfo(torch.float32).min)
        logits = logits + inf_mask
        train_mask, actions = hard_sample(logits)
        return actions

class StraightThrough:
    def __init__(self, state_dim, action_dim, lr, betas):
        self.lr = lr
        self.betas = betas
        self.policy = Actor(state_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(
            self.policy.parameters(), lr=lr, betas=betas)

    def update(self, loss):
        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()

class IRT(nn.Module):
    def __init__(self, num_students, num_questions, num_dim):

@@ -39,7 +86,6 @@ def forward(self, student_ids, question_ids):
        pred = torch.sigmoid(pred)
        return pred


class IRTModel(AbstractModel):

    def __init__(self, **config):
@@ -52,8 +98,11 @@ def name(self):
        return 'Item Response Theory'

    def init_model(self, data: Dataset):
        betas = (0.9, 0.999)
Review comment: Add args to set these hyperparameters instead of setting them directly in the function.
Review comment: Also, if this is not necessary for all IRT models, use args to control it.
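One possible shape for that change (the config keys policy_lr, policy_betas, and use_bobcat are illustrative assumptions, not existing fields):

    def init_model(self, data: Dataset):
        # Hypothetical config-driven variant; key names are made up for this sketch.
        betas = self.config.get('policy_betas', (0.9, 0.999))
        policy_lr = self.config.get('policy_lr', 0.0005)
        self.model = IRT(data.num_students, data.num_questions, self.config['num_dim'])
        self.n_q = data.num_questions
        # Only build the BOBCAT policy when the chosen strategy actually needs it.
        if self.config.get('use_bobcat', False):
            self.policy = StraightThrough(data.num_questions, data.num_questions, policy_lr, betas)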
        policy_lr = 0.0005
        self.model = IRT(data.num_students, data.num_questions, self.config['num_dim'])

        self.policy = StraightThrough(data.num_questions, data.num_questions, policy_lr, betas)
        self.n_q = data.num_questions

    def train(self, train_data: TrainDataset, log_step=1):
        lr = self.config['learning_rate']
        batch_size = self.config['batch_size']
@@ -94,7 +143,13 @@ def adaptest_load(self, path):
        """
        self.model.load_state_dict(torch.load(path), strict=False)
        self.model.to(self.config['device'])

    def adaptest_load_BOBCAT(self, path, policy):
Review comment: add args in …
        """
        Reload the saved IRT model together with the trained BOBCAT policy network
        """
        self.model.load_state_dict(torch.load(path), strict=False)
        self.policy.policy.load_state_dict(torch.load(policy), strict=False)
        self.model.to(self.config['device'])

    def adaptest_update(self, adaptest_data: AdapTestDataset):
        """
        Update CDM with tested data
@@ -424,6 +479,28 @@ def expected_model_change(self, sid: int, qid: int, adaptest_data: AdapTestDatas
        pred = pred_all[sid][qid]
        return pred * torch.norm(pos_weights - original_weights).item() + \
            (1 - pred) * torch.norm(neg_weights - original_weights).item()

    def bobcat_policy(self, S_set, untested_questions):
        """ select the next question with the BOBCAT policy network
Review comment: update the comments
        Args:
            S_set: list of {question_id: correctness} dicts for the questions already administered
            untested_questions: iterable of question ids that are still available
        Returns:
            int, the id of the next question to administer
        """
        action_mask = [0.0] * self.n_q
        train_mask = [-0.0] * self.n_q
        # Mark the questions that may still be selected.
        for index in untested_questions:
            action_mask[index] = 1.0
        # Encode past responses at the answered positions as +1 (correct) / -1 (incorrect).
        for state in S_set:
            keys = list(state.keys())
            key = keys[0]
            values = list(state.values())
            val = values[0]
            train_mask[key] = (float(val) - 0.5) * 2
        action_mask = torch.tensor(action_mask).to(device)
        train_mask = torch.tensor(train_mask).to(device)
        action = self.policy.policy(train_mask, action_mask)
        return action.item()
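For clarity on the expected input format (the ids below are made up, and irt_model stands for an IRTModel whose init_model has already been called):

# Illustration only: two answered questions (3 correct, 7 incorrect), eight still available.
S_set_for_student = [{3: 1}, {7: 0}]
remaining = [0, 1, 2, 4, 5, 6, 8, 9]
next_qid = irt_model.bobcat_policy(S_set_for_student, remaining)   # an int question id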

@@ -0,0 +1,25 @@
import numpy as np
from scipy.optimize import minimize
from CAT.strategy.abstract_strategy import AbstractStrategy
from CAT.model import AbstractModel
from CAT.dataset import AdapTestDataset


class BOBCAT(AbstractStrategy):

    def __init__(self):
        super().__init__()

    @property
    def name(self):
        return 'BOBCAT'

    def adaptest_select(self, model: AbstractModel, adaptest_data: AdapTestDataset, S_set):
        assert hasattr(model, 'bobcat_policy'), \
            'the models must implement bobcat_policy method'
        selection = {}
        for sid in range(adaptest_data.num_students):
            untested_questions = np.array(list(adaptest_data.untested[sid]))
            # The policy network picks the next question for this student.
            j = model.bobcat_policy(S_set[sid], untested_questions)
            selection[sid] = j
        return selection
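A rough sketch of how this strategy could slot into an adaptive-testing loop; the driver loop, apply_selection call, and response lookup below are assumptions about the surrounding framework, not code from this PR:

# Hypothetical driver loop keeping the per-student S_set bookkeeping up to date.
strategy = BOBCAT()
S_set = {sid: [] for sid in range(adaptest_data.num_students)}

for _ in range(test_length):
    selected = strategy.adaptest_select(model, adaptest_data, S_set)
    for sid, qid in selected.items():
        adaptest_data.apply_selection(sid, qid)     # assumed AdapTestDataset helper
        correct = adaptest_data.data[sid][qid]      # assumed lookup of the 0/1 response
        S_set[sid].append({qid: correct})
    model.adaptest_update(adaptest_data)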
Review comment: Do not set device here. We will set device through config.
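A minimal sketch of what that might look like (assuming the module-level device constant is dropped and the device comes from the same config dict already used for self.config['device']; the extra parameter below is an assumption):

# Hypothetical: thread the device through config instead of a module-level constant.
class StraightThrough:
    def __init__(self, state_dim, action_dim, lr, betas, device='cpu'):
        self.policy = Actor(state_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)

# inside IRTModel.init_model
self.policy = StraightThrough(data.num_questions, data.num_questions,
                              policy_lr, betas, device=self.config['device'])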