Commit 59d06e0

code formatted
Hhhhhhand committed Jan 14, 2024
1 parent 6044fbd commit 59d06e0
Showing 5 changed files with 664 additions and 705 deletions.
66 changes: 8 additions & 58 deletions CAT/model/IRT.py
@@ -1,68 +1,21 @@
import os
import time
import copy
import logging
import math
from math import exp
from collections import namedtuple

import vegas
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from sklearn.metrics import roc_auc_score, accuracy_score
from scipy import integrate

from CAT.model.abstract_model import AbstractModel
from CAT.dataset import AdapTestDataset, TrainDataset, Dataset
from utils import StraightThrough

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])

def hard_sample(logits, dim=-1):
    # Straight-through estimator: the forward pass emits a hard one-hot
    # vector, while gradients flow through the soft softmax probabilities.
    y_soft = F.softmax(logits, dim=dim)
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
    ret = y_hard - y_soft.detach() + y_soft
    return ret, index
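
# A minimal sketch of hard_sample in use (tensor values are illustrative,
# not taken from this repository):
#
#   logits = torch.tensor([[1.0, 3.0, 2.0]], requires_grad=True)
#   one_hot, idx = hard_sample(logits)  # one_hot == [[0., 1., 0.]], idx == [[1]]
#   (one_hot * torch.randn(3)).sum().backward()  # grads flow via the soft path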

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, n_latent_var=256):
        super().__init__()
        # actor network: observation encoder followed by a two-layer policy head
        self.obs_layer = nn.Linear(state_dim, n_latent_var)
        self.actor_layer = nn.Sequential(
            nn.Linear(n_latent_var, n_latent_var),
            nn.Tanh(),
            nn.Linear(n_latent_var, action_dim)
        )

    def forward(self, state, action_mask):
        hidden_state = self.obs_layer(state)
        logits = self.actor_layer(hidden_state)
        # log(0) = -inf disables masked-out actions; the clamp keeps it finite
        inf_mask = torch.clamp(torch.log(action_mask.float()),
                               min=torch.finfo(torch.float32).min)
        logits = logits + inf_mask
        _, actions = hard_sample(logits)
        return actions

class StraightThrough:
    def __init__(self, state_dim, action_dim, lr, betas):
        self.lr = lr
        self.betas = betas
        self.policy = Actor(state_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(
            self.policy.parameters(), lr=lr, betas=betas)

    def update(self, loss):
        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()
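
# A minimal usage sketch (dimensions and hyperparameters are illustrative,
# not values from this repository):
#
#   st = StraightThrough(state_dim=100, action_dim=100, lr=5e-4, betas=(0.9, 0.999))
#   state = torch.zeros(1, 100, device=device)
#   mask = torch.ones(1, 100, device=device)  # 1.0 marks a still-available question
#   question_idx = st.policy(state, mask)     # indices picked by hard argmax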

class IRT(nn.Module):
    def __init__(self, num_students, num_questions, num_dim):
        # num_dim == 1 gives unidimensional IRT; num_dim > 1 gives MIRT
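        # For reference (the standard formulation, not quoted from the elided
        # body below): a 2PL-style IRT model scores
        # P(correct) = sigmoid(alpha * theta + beta), and with num_dim > 1 the
        # product becomes a dot product of latent vectors, i.e. MIRT.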
@@ -98,11 +51,11 @@ def name(self):
        return 'Item Response Theory'

    def init_model(self, data: Dataset):
        betas = (0.9, 0.999)
        policy_lr = 0.0005
        self.model = IRT(data.num_students, data.num_questions, self.config['num_dim'])
        self.policy = StraightThrough(data.num_questions, data.num_questions, policy_lr, betas)
        self.policy = StraightThrough(data.num_questions, data.num_questions, policy_lr, self.config)
        self.n_q = data.num_questions

    def train(self, train_data: TrainDataset, log_step=1):
        lr = self.config['learning_rate']
        batch_size = self.config['batch_size']
@@ -141,15 +94,11 @@ def adaptest_load(self, path):
"""
Reload the saved model
"""
if self.config['policy'] =='bobcat':
self.policy.policy.load_state_dict(torch.load(self.config['policy_path']),strict=False)
self.model.load_state_dict(torch.load(path), strict=False)
self.model.to(self.config['device'])
def adaptest_load_BOBCAT(self, path, policy):
"""
Reload the saved model
"""
self.model.load_state_dict(torch.load(path), strict=False)
self.policy.policy.load_state_dict(torch.load(policy),strict=False)
self.model.to(self.config['device'])

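    # A hedged example of the config keys the loaders above consult ('wrapper'
    # and the checkpoint paths are placeholders, not names from this repository):
    #
    #   wrapper.config = {'policy': 'bobcat', 'policy_path': 'ckpt/policy.pt',
    #                     'device': 'cuda'}
    #   wrapper.adaptest_load('ckpt/irt.pt')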
    def adaptest_update(self, adaptest_data: AdapTestDataset):
        """
        Update CDM with tested data
@@ -488,6 +437,7 @@ def bobcat_policy(self, S_set, untested_questions):
        Returns:
            float, expected model change
        """
        device = self.config['device']
        action_mask = [0.0] * self.n_q
        train_mask = [-0.0] * self.n_q
        # build the masks over the untested (candidate) questions
        for index in untested_questions: