Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
ViviHong200709 committed Nov 7, 2022
1 parent f9e619b commit e31bd24
Show file tree
Hide file tree
Showing 15 changed files with 341 additions and 72 deletions.
Empty file removed CAT/distillation/MFI/__init__.py
Empty file.
Binary file not shown.
Binary file removed CAT/distillation/MFI/__pycache__/dMFI.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified CAT/distillation/__pycache__/tool.cpython-39.pyc
Binary file not shown.
8 changes: 4 additions & 4 deletions CAT/distillation/prepare_trait.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
seed = 0
random.seed(seed)
cdm='irt'
dataset = 'junyi'
stg='KLI'
dataset = 'assistment'
stg='MFI'
train_triplets = pd.read_csv(
f'/data/yutingh/CAT/data/{dataset}/train_triples.csv', encoding='utf-8').to_records(index=False)
ckpt_path = f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_with_theta.pt'
Expand Down Expand Up @@ -58,7 +58,7 @@
alpha=model.get_alpha(qid)
beta=model.get_beta(qid)
item_dict[item_id]=[np.float(alpha[0]),np.float(beta[0])]
label,k_info,tested_info = get_label_and_k(user_dict,item_dict,50,stg)
label,k_info,tested_info = get_label_and_k(user_dict,item_dict,50,stg,model)
trait_dict = {
'user':user_dict,
'item':item_dict,
Expand All @@ -69,7 +69,7 @@

path_prefix = f"/data/yutingh/CAT/data/{dataset}/{stg}/"

with open(f"{path_prefix}trait.json", "w", encoding="utf-8") as f:
with open(f"{path_prefix}trait_with_tested_info.json", "w", encoding="utf-8") as f:
# indent参数保证json数据的缩进,美观
# ensure_ascii=False才能输出中文,否则就是Unicode字符
f.write(json.dumps(trait_dict, ensure_ascii=False))
105 changes: 94 additions & 11 deletions CAT/distillation/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,39 @@ def split_data(utrait,label,k_fisher,rate=1,tested_info=None):
else:
return ((u_train,label[:max_len],k_fisher[:max_len]),(u_test,label[max_len:],k_fisher[max_len:]))

def get_label_and_k(user_trait,item_trait,k,stg="MFI"):
def get_label_and_k(user_trait,item_trait,k,stg="MFI", model=None):
labels=[]
k_infos=[]
tested_ns=[]
for theta in tqdm(user_trait.values(),f'get top {k} items'):
tested_infos=[]
for sid, theta in tqdm(user_trait.items(),f'get top {k} items'):
if stg=='MFI':
k_info,label = get_k_fisher(k, theta, item_trait)
k_info, label, tested_info = get_k_fisher(k, theta, item_trait)
tested_infos.append(tested_info)
elif stg=='KLI':
k_info,label,tested_n= get_k_kli(k, theta, item_trait)
tested_ns.append(tested_n)
k_info,label,tested_info= get_k_kli(k, theta, item_trait)
tested_infos.append(tested_info)
elif stg=='MAAT':
get_k_emc(k, sid,theta, item_trait, model)
k_infos.append(k_info)
labels.append(label)
return labels,k_infos,tested_ns
return labels,k_infos,tested_infos

def transform(item_trait, user_trait,labels,k_fishers,tested_infos=None):
if tested_infos:
for theta, label, k_fisher,tested_info in zip(user_trait.values(),labels,k_fishers,tested_infos):
itrait = list(item_trait.values())
item_n = len(itrait)
user_embs=[]
for tmp in tested_info:
user_emb = [theta]
if type(tmp) == list:
user_emb.extend(tmp)
else:
user_emb.append(tmp)
user_embs.append(user_emb)
yield pack_batch([
torch.tensor(list(zip([theta]*item_n,tested_info))),
torch.tensor(user_embs),
# torch.tensor(list(zip([theta]*item_n,tested_info))),
itrait,
label,
k_fisher
Expand All @@ -65,16 +77,87 @@ def pack_batch(batch):

def get_k_fisher(k, theta, items):
fisher_arr = []
for qid,(alpha,beta) in items.items():
items_n=len(items.keys())
ns = [random.randint(0,19) for i in range(items_n)]
tested_qids = [random.sample(list(range(0,20)),n) for n in ns]
avg_embs = np.array(list(items.values())).mean(axis=0)
p=0.001
avg_tested_embs=[]
for tested_qid, (qid,(alpha,beta)) in zip(tested_qids,items.items()):
# tested_qid
if len(tested_qid)==0:
avg_tested_emb=np.array([0,0])
else:
avg_tested_emb = np.array([items[qid] for qid in tested_qid]).mean(axis=0)
item_emb=[alpha,beta]
pred = alpha * theta + beta
pred = torch.sigmoid(torch.tensor(pred))
# pred = 1 / (1 + np.exp(-pred))
q = 1 - pred
fisher_info = float((q*pred*(alpha ** 2)).numpy())
diff = ((item_emb-avg_tested_emb)**2).sum()
sim = ((item_emb-avg_embs)**2).sum()
fisher_info = float((q*pred*(alpha ** 2)).numpy()) + p*diff/sim
fisher_arr.append((fisher_info,qid))
avg_tested_embs.append(avg_tested_emb.tolist())
fisher_arr_sorted = sorted(fisher_arr, reverse=True)
return [i[1] for i in fisher_arr_sorted[:k]],[i[0]for i in fisher_arr]
tested_info=[]
for avg_tested_emb,n in zip(avg_tested_embs,ns):
avg_tested_emb.extend([n])
tested_info.append(avg_tested_emb)
return [i[1] for i in fisher_arr_sorted[:k]],[i[0]for i in fisher_arr],tested_info

def get_k_emc(k,sid,theta,items,model):
    """Rank items for student `sid` by Expected Model Change (EMC).

    For every candidate question, temporarily fine-tunes only the model's
    theta (ability) embedding twice — once assuming a correct response and
    once assuming a wrong one — and scores the item by the response-probability-
    weighted L2 norm of the resulting theta change.

    Args:
        k: number of top-scoring item ids to return.
        sid: student id (int) used to index the model's theta embedding.
        theta: the student's ability value (float or array-like) used for the
            closed-form response probability.
        items: dict mapping question id -> (alpha, beta) IRT parameters.
        model: project CAT model wrapper; must expose `config`, `model`
            (a torch module with a `theta` embedding), and `_loss_function`.

    Returns:
        (top_k_qids, scores): the k item ids with the largest EMC, and the
        unsorted EMC score list for all items (same order as `items`).
    """
    # Hyperparameters come from the wrapper's training config.
    epochs = model.config['num_epochs']
    lr = model.config['learning_rate']
    device = model.config['device']
    # NOTE(review): a single Adam optimizer is shared across all simulated
    # updates, so its moment estimates accumulate between items — confirm
    # this is intended rather than a fresh optimizer per simulation.
    optimizer = torch.optim.Adam(model.model.parameters(), lr=lr)
    res_arr = []
    for qid,(alpha,beta) in items.items():
        # Freeze everything except the theta embedding so the simulated
        # updates only move the student's ability estimate.
        for name, param in model.model.named_parameters():
            if 'theta' not in name:
                param.requires_grad = False

        # Snapshot of theta weights before the simulated update.
        original_weights = model.model.theta.weight.data.clone()

        student_id = torch.LongTensor([sid]).to(device)
        question_id = torch.LongTensor([qid]).to(device)
        correct = torch.LongTensor([1]).to(device).float()
        wrong = torch.LongTensor([0]).to(device).float()

        # Simulate a CORRECT response: fine-tune theta on label 1.
        for ep in range(epochs):
            optimizer.zero_grad()
            pred = model.model(student_id, question_id)
            loss = model._loss_function(pred, correct)
            loss.backward()
            optimizer.step()

        pos_weights = model.model.theta.weight.data.clone()
        # Roll theta back before simulating the opposite outcome.
        model.model.theta.weight.data.copy_(original_weights)

        # Simulate a WRONG response: fine-tune theta on label 0.
        for ep in range(epochs):
            optimizer.zero_grad()
            pred = model.model(student_id, question_id)
            loss = model._loss_function(pred, wrong)
            loss.backward()
            optimizer.step()

        neg_weights = model.model.theta.weight.data.clone()
        # NOTE(review): the restore below is commented out, so the
        # negative-response update is NOT rolled back — each subsequent
        # item's `original_weights` baseline (and the model left behind on
        # return) carries the previous item's simulated wrong answer.
        # Looks like a debug leftover; confirm whether it should be active.
        # model.model.theta.weight.data.copy_(original_weights)

        for param in model.model.parameters():
            param.requires_grad = True

        # Closed-form IRT probability of a correct response, used to weight
        # the two hypothetical model changes.
        if type(alpha) == float:
            alpha = np.array([alpha])
        if type(theta) == float:
            theta = np.array([theta])
        pred = np.matmul(alpha.T, theta) + beta
        pred = 1 / (1 + np.exp(-pred))
        # EMC = P(correct) * ||Δtheta | correct|| + P(wrong) * ||Δtheta | wrong||
        result = pred * torch.norm(pos_weights - original_weights).item() + \
            (1 - pred) * torch.norm(neg_weights - original_weights).item()
        # NOTE(review): `result` may be a size-1 numpy array here; tuple
        # sorting below relies on size-1 array comparison — verify shapes.
        res_arr.append((result,qid))
    # Highest expected model change first; ties broken by qid comparison.
    res_arr_sorted = sorted(res_arr, reverse=True)
    return [i[1] for i in res_arr_sorted[:k]],[i[0]for i in res_arr]


def get_k_kli(k, theta, items):
Expand Down
16 changes: 10 additions & 6 deletions CAT/distillation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,37 @@
from CAT.distillation.tool import transform,split_data
import torch
import json
import numpy as np
dataset='assistment'
cdm='irt'
stg='KLI'
trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg}/trait.json', 'r'))
stg='MFI'
trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg}/trait_with_tested_info.json', 'r'))
# stg='KLI'
# trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg}/trait.json', 'r'))
utrait = trait['user']
itrait = trait['item']
# print(utrait,itrait)
label = trait['label']
k_info = trait['k_info']
if stg=='KLI':
if 'tested_info' in trait:
tested_info= trait['tested_info']
train_data, test_data = split_data(utrait,label,k_info,0.8,tested_info)
user_dim=np.array(tested_info).shape[-1]+1
else:
user_dim=1
train_data, test_data = split_data(utrait,label,k_info,0.8)

torch.manual_seed(0)
train_set = transform(itrait,*train_data)
test_set = transform(itrait,*test_data)
k=50
embedding_dim=15
epoch=30
epoch=38
lr=0.005 if dataset=='assistment' else 0.01
print(f'lr: {lr}')
user_dim=2 if stg =='KLI'else 1
dMFI = dMFIModel(k,embedding_dim,user_dim,device='cuda:4')
dMFI.train(train_set,test_set,itrait,epoch=epoch,lr=lr)
dMFI.save(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip.pt')
dMFI.save(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip_with_tested_info.pt')



Expand Down
179 changes: 164 additions & 15 deletions CAT/mips/draw_pic.ipynb

Large diffs are not rendered by default.

15 changes: 10 additions & 5 deletions CAT/mips/prepare_ip_trait.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,27 @@

dataset='assistment'
cdm='irt'
stg='KLI'
trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg}/trait.json', 'r'))
stg='MFI'
trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg}/trait_with_tested_info.json', 'r'))
utrait = trait['user']
itrait = trait['item']
label = trait['label']
k_info = trait['k_info']
if 'tested_info' in trait:
tested_info= trait['tested_info']
user_dim=np.array(tested_info).shape[-1]+1
else:
user_dim=1

train_data, test_data = split_data(utrait,label,k_info,0.8)

torch.manual_seed(0)
train_set = transform(itrait,*train_data)
test_set = transform(itrait,*test_data)
k=50
embedding_dim=15
user_dim=2 if stg =='KLI'else 1
dMFI = dMFIModel(k,embedding_dim,user_dim,device='cuda:0')
dMFI.load(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip.pt')
dMFI.load(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip_with_tested_info.pt')
# dMFI.eval(test_set,itrait)
ball_embs=[]
max_embs_len=torch.tensor(0.)
Expand All @@ -47,7 +52,7 @@

path_prefix = f"/data/yutingh/CAT/data/{dataset}/{stg}/"

with open(f"{path_prefix}ball_trait.json", "w", encoding="utf-8") as f:
with open(f"{path_prefix}ball_trait_with_tested_info.json", "w", encoding="utf-8") as f:
# indent参数保证json数据的缩进,美观
# ensure_ascii=False才能输出中文,否则就是Unicode字符
f.write(json.dumps(ball_embs, ensure_ascii=False))
90 changes: 59 additions & 31 deletions CAT/mips/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,27 @@ def setuplogger():

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
def main(dataset="assistment", cdm="irt", stg = ['Random'], test_length = 20, ctx="cuda:4", lr=0.2, num_epoch=1, efficient=False):
lr=0.15 if dataset=='assistment' else 0.2
def main(dataset="assistment", cdm="irt", stg = ['MFI'], test_length = 20, ctx="cuda:4", lr=0.2, num_epoch=1, efficient=True):
# lr=0.05 if dataset=='assistment' else 0.2
setuplogger()
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

config = {
'learning_rate': lr,
'batch_size': 2048,
'num_epochs': num_epoch,
'num_dim': 1, # for IRT or MIRT
'device': ctx,
# for NeuralCD
'prednet_len1': 128,
'prednet_len2': 64,
# 'prednet_len1': 64,
# 'prednet_len2': 32,
lr_config={
"assistment":{
"MFI":0.15,
"KLI":0.15,
"Random":0.05,
'MAAT':0.15
},
"junyi":{
"MFI":0.2,
"KLI":0.2,
"Random":0.2,
'MAAT':0.15
}
}

metadata = json.load(open(f'/data/yutingh/CAT/data/{dataset}/metadata.json', 'r'))
ckpt_path = f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}.pt'
# read datasets
Expand All @@ -64,6 +66,18 @@ def main(dataset="assistment", cdm="irt", stg = ['Random'], test_length = 20, ct
df = pd.DataFrame()
df1 = pd.DataFrame()
for i, strategy in enumerate(strategies):
config = {
'learning_rate': lr_config[dataset][stg[i]],
'batch_size': 2048,
'num_epochs': num_epoch,
'num_dim': 1, # for IRT or MIRT
'device': ctx,
# for NeuralCD
'prednet_len1': 128,
'prednet_len2': 64,
# 'prednet_len1': 64,
# 'prednet_len2': 32,
}
if cdm == 'irt':
model = CAT.model.IRTModel(**config)
elif cdm =='ncd':
Expand All @@ -72,15 +86,20 @@ def main(dataset="assistment", cdm="irt", stg = ['Random'], test_length = 20, ct
model.adaptest_load(ckpt_path)
test_data.reset()
if efficient:
ball_trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg[i]}/ball_trait.json', 'r'))
ball_trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg[i]}/ball_trait_with_tested_info.json', 'r'))
trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg[i]}/trait_with_tested_info.json', 'r'))
distill_k=50
embedding_dim=15
user_dim=2 if stg[0]=='KLI' else 1
if 'tested_info' in trait:
tested_info= trait['tested_info']
user_dim=np.array(tested_info).shape[-1]+1
else:
user_dim=1
dMFI = dMFIModel(distill_k,embedding_dim,user_dim,device=ctx)
dMFI.load(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg[i]}_ip.pt')
dMFI.load(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg[i]}_ip_with_tested_info.pt')
logging.info('-----------')
logging.info(f'start adaptive testing with {strategy.name} strategy')

logging.info('lr: ' + str(config['learning_rate']))
logging.info(f'Iteration 0')
res=[]
time=0
Expand All @@ -99,27 +118,36 @@ def main(dataset="assistment", cdm="irt", stg = ['Random'], test_length = 20, ct
results = tmp_model.evaluate(sid, test_data)
tmp =[list(results.values())]
time = datetime.timedelta(microseconds=0)
tested_info=[]
for it in range(1, test_length + 1):
starttime = datetime.datetime.now()
if efficient:
if stg[i]=='KLI':
theta = tmp_model.model.theta(torch.tensor(sid).to(ctx))
u_emb = dMFI.model.utn(torch.cat((theta,torch.Tensor([it]).to(ctx)),0)).tolist()
theta = tmp_model.model.theta(torch.tensor(sid).to(ctx))
if user_dim==1:
u_emb = dMFI.model.utn(theta).tolist()
else:
u_emb = dMFI.model.utn(tmp_model.model.theta(torch.tensor(sid).to(ctx))).tolist()
if stg[i]=='KLI':
u_emb = dMFI.model.utn(torch.cat((theta,torch.Tensor([it]).to(ctx)),0)).tolist()
elif stg[i]=='MFI':
if len(test_data.tested[sid])==0:
avg_tested_emb=np.array([0,0]).tolist()
else:
avg_tested_emb = np.array([trait['item'][str(qid)] for qid in test_data.tested[sid]]).mean(axis=0).tolist()
avg_tested_emb.extend([it])
u_emb = dMFI.model.utn(torch.cat((theta,torch.Tensor(avg_tested_emb).to(ctx)),0)).tolist()
candidates=dict(zip(list(range(metadata['num_questions'],metadata['num_questions']+it)),[0]*it))
search_metric_tree(candidates,np.array(u_emb),T)
untested_qids = set(candidates.keys())-set(test_data.tested[sid])
# print(it, untested_qids)
if len(untested_qids) == 1:
max_score = 0
for k,v in candidates.items():
if k in untested_qids:
if v>max_score:
qid=k
max_score=v
else:
qid = strategy.adaptest_select(tmp_model, sid, test_data,item_candidates=untested_qids)
# if len(untested_qids) == 1:
max_score = 0
for k,v in candidates.items():
if k in untested_qids:
if v>max_score:
qid=k
max_score=v
# else:
# qid = strategy.adaptest_select(tmp_model, sid, test_data,item_candidates=untested_qids)
else:
qid = strategy.adaptest_select(tmp_model, sid, test_data)
test_data.apply_selection(sid, qid)
Expand Down
Binary file modified ckpt/assistment/irt_MFI_ip.pt
Binary file not shown.
Binary file added ckpt/assistment/irt_MFI_ip_with_tested_info.pt
Binary file not shown.
Binary file added ckpt/junyi/irt_KLI_ip.pt
Binary file not shown.

0 comments on commit e31bd24

Please sign in to comment.