Skip to content

Commit

Permalink
Fix NCAT bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Hhhhhhand committed May 25, 2024
1 parent cdd502b commit f551bb1
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 265 deletions.
35 changes: 25 additions & 10 deletions CAT/dataset/adaptest_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,36 @@ def tested(self):
def untested(self):
return self._untested

def get_tested_dataset(self, last=False,ssid=None):
    """
    Get tested data for training
    Args:
        last: bool, True - the last question, False - all the tested questions
        ssid: optional student id; when given, only that student's tested
            questions are returned, otherwise every student is included
    Returns:
        TrainDataset
    """
    triplets = []
    for sid, qids in self._tested.items():
        # Identity test against None so the filter is explicit about
        # "no student requested" versus a falsy-but-valid student id.
        if ssid is not None and sid != ssid:
            continue
        # A student with nothing tested yet contributes no rows; guarding
        # here also keeps qids[-1] from raising on an empty sequence.
        if len(qids) == 0:
            continue
        chosen = [qids[-1]] if last else qids
        triplets.extend((sid, qid, self.data[sid][qid]) for qid in chosen)
    return TrainDataset(triplets, self.concept_map,
                        self.num_students, self.num_questions, self.num_concepts)
15 changes: 10 additions & 5 deletions CAT/model/IRT.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from CAT.dataset import AdapTestDataset, TrainDataset, Dataset
from sklearn.metrics import accuracy_score
from collections import namedtuple
from utils import StraightThrough
from .utils import StraightThrough
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])

class IRT(nn.Module):
Expand Down Expand Up @@ -99,7 +99,7 @@ def adaptest_load(self, path):
self.model.load_state_dict(torch.load(path), strict=False)
self.model.to(self.config['device'])

def adaptest_update(self, adaptest_data: AdapTestDataset):
def adaptest_update(self, adaptest_data: AdapTestDataset,sid=None):
"""
Update CDM with tested data
"""
Expand All @@ -109,9 +109,8 @@ def adaptest_update(self, adaptest_data: AdapTestDataset):
device = self.config['device']
optimizer = torch.optim.Adam(self.model.theta.parameters(), lr=lr)

tested_dataset = adaptest_data.get_tested_dataset(last=True)
tested_dataset = adaptest_data.get_tested_dataset(last=True,ssid=sid)
dataloader = torch.utils.data.DataLoader(tested_dataset, batch_size=batch_size, shuffle=True)

for ep in range(1, epochs + 1):
loss = 0.0
log_steps = 100
Expand All @@ -127,7 +126,13 @@ def adaptest_update(self, adaptest_data: AdapTestDataset):
loss += bz_loss.data.float()
# if cnt % log_steps == 0:
# print('Epoch [{}] Batch [{}]: loss={:.3f}'.format(ep, cnt, loss / cnt))

return loss
def one_student_update(self, adaptest_data: AdapTestDataset):
    """
    Placeholder for a single-student update; currently performs no work.

    The previous body read learning rate, batch size, epoch count and device
    from the config and built an Adam optimizer, but never used any of them
    and fell through without training. Those dead locals are removed so the
    method no longer allocates an optimizer for nothing.

    Args:
        adaptest_data: accepted for interface compatibility; not used
    Returns:
        None
    """
    # NOTE(review): looks like an unfinished stub - confirm whether the
    # per-student path in adaptest_update(..., sid=...) supersedes it.
    return None
def evaluate(self, adaptest_data: AdapTestDataset):
data = adaptest_data.data
concept_map = adaptest_data.concept_map
Expand Down
37 changes: 25 additions & 12 deletions CAT/strategy/NCAT_nn/NCAT.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from scipy.optimize import minimize
from CAT.dataset import AdapTestDataset,Dataset
from CAT.model.IRT import IRT,IRTModel
import os

SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])

Expand Down Expand Up @@ -44,7 +45,7 @@ def __init__(self, n_question, d_model=10, n_blocks=1,kq_same=True, dropout=0.0,
n_heads: number of heads in multi-headed attention
d_ff : dimension for fully connected net inside the basic block
"""
self.device = torch.device('cuda')
self.device = torch.device('cpu')
self.pad = pad
self.n_question = n_question
self.dropout = dropout
Expand Down Expand Up @@ -134,7 +135,7 @@ def optimize_model(self, data, lr):
def create_model(cls, config):
model = cls(config.item_num, config.latent_factor, config.num_blocks,
True, config.dropout_rate, policy_fc_dim=512, n_heads=config.num_heads, d_ff=2048, l2=1e-5, separate_qa=None, pad=0)
return model.to(torch.device('cuda'))
return model.to(torch.device('cpu'))


def mask(src, s_len):
Expand Down Expand Up @@ -309,12 +310,12 @@ def __init__(self,data,concept_map,config,T):
self.users = {}
self.utypes = {}
#self.args = args
self.device = torch.device('cuda')
self.device = torch.device('cpu')
self.rates, self._item_num, self.know_map = self.load_data(data,concept_map)
self.tsdata=data
self.setup_train_test()
self.sup_rates, self.query_rates = self.split_data(ratio=0.5)
pth_path='CDM.pt'
pth_path='../ckpt/irt.pt'
name = 'IRT'
self.model, self.dataset = self.load_CDM(name,data,pth_path,config)
#print(self.model)
Expand Down Expand Up @@ -368,8 +369,8 @@ def load_CDM(self,name,data,pth_path,config):
if name == 'IRT':
model = IRTModel(**config)
model.init_model(data)
model.load_state_dict(torch.load(pth_path), strict=False)
model.to(self.config['device'])
model.adaptest_load(pth_path)
#model.to(self.config['device'])
return model ,data.data

def step(self, action,sid):
Expand All @@ -389,10 +390,13 @@ def step(self, action,sid):

def reward(self, action,sid):
    """
    Apply the chosen question for a student, update the CDM and score it.

    Args:
        action: id of the question that was selected
        sid: student id forwarded to the dataset selection and CDM update
    Returns:
        tuple ``(-loss, acc, auc, correct)`` where ``loss`` is the value
        returned by ``self.model.adaptest_update``, ``acc`` / ``auc`` are
        read from the dict returned by ``self.model.evaluate``, and
        ``correct`` is the stored response for ``action``
    """
    # Only the response to the newly chosen question is reported, so look it
    # up directly instead of building a list over every earlier item.
    # presumably self.state[0][0] is the current user id - verify against step()
    user = self.state[0][0]
    correct = self.rates[user][action]

    self.tsdata.apply_selection(sid, action)
    loss = self.model.adaptest_update(self.tsdata, sid)
    result = self.model.evaluate(self.tsdata)
    return -loss, result['acc'], result['auc'], correct

def precision(self, episode):
Expand Down Expand Up @@ -455,13 +459,15 @@ def __init__(self, NCATdata,concept_map,config,test_length):
self.config = config
self.model = None
self.env = env(data=NCATdata,concept_map=concept_map,config=config,T=test_length)

self.memory = []
self.item_num =self.env.item_num
self.user_num = self.env.user_num
self.device = config['device']
self.fa = NCAT(n_question=NCATdata.num_questions+1).to(self.device)

def ncat_policy(self,sid,THRESHOLD,used_actions):
self.memory_size = 50000
self.tau = 0

def ncat_policy(self,sid,THRESHOLD,used_actions,type,epoch):
actions = {}
rwds = 0
done = False
Expand All @@ -475,7 +481,7 @@ def ncat_policy(self,sid,THRESHOLD,used_actions):
data["uid"] = torch.tensor(data["uid"], device=self.device)
policy = self.fa.predict(data)[0]
if type == "training":
if np.random.random()<5*THRESHOLD/(THRESHOLD+self.tau): policy = np.random.uniform(0,1,(self.args.item_num,))
if np.random.random()<5*THRESHOLD/(THRESHOLD+self.tau): policy = np.random.uniform(0,1,(self.item_num,))
for item in actions: policy[item] = -np.inf
for item in range(self.item_num):
if item not in self.env.candidate_items:
Expand All @@ -490,6 +496,13 @@ def ncat_policy(self,sid,THRESHOLD,used_actions):
rwds += rwd
state = state_next
used_actions.extend(list(actions.keys()))
if type == "training":
if len(self.memory) >= self.config['batch_size']:
self.memory = self.memory[-self.memory_size:]
batch = [self.memory[item] for item in np.random.choice(range(len(self.memory)),(self.args.batch,))]
data = self.convert_batch2dict(batch,epoch)
loss = self.fa.optimize_model(data, self.args.learning_rate)
self.tau += 5
return

def convert_batch2dict(self, batch, epoch):
Expand Down
23 changes: 15 additions & 8 deletions CAT/strategy/NCAT_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,19 @@ def __init__(self):
def name(self):
return 'NCAT'

def adaptest_select(self, adaptest_data: AdapTestDataset,concept_map,config,test_length):
used_actions = []
def adaptest_select(self, adaptest_data: AdapTestDataset,concept_map,config,test_length):
selection = {}
NCATdata = adaptest_data
model = NCATModel(NCATdata,concept_map,config,test_length)
threshold = config['THRESHOLD']
for sid in range(adaptest_data.num_students):
NCATdata = adaptest_data
model = NCATModel(NCATdata,concept_map,config,test_length)
threshold = config['THRESHOLD']
model.ncat_policy(sid,threshold,used_actions)

return used_actions
print(str(sid+1)+'/'+str(adaptest_data.num_students))
used_actions = []
model.ncat_policy(sid,threshold,used_actions,type="training",epoch=100)
NCATdata.reset()
for sid in range(adaptest_data.num_students):
used_actions = []
model.ncat_policy(sid,threshold,used_actions,type="testing",epoch=0)
selection[sid] = used_actions
NCATdata.reset()
return selection
Loading

0 comments on commit f551bb1

Please sign in to comment.