Commit 029976c: save
ViviHong200709 committed Dec 26, 2022
1 parent cfd45e8
Showing 22 changed files with 861 additions and 75 deletions.
Binary file modified CAT/distillation/__pycache__/tool.cpython-39.pyc
Binary file not shown.
151 changes: 151 additions & 0 deletions CAT/distillation/model_similarity.py
@@ -0,0 +1,151 @@
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm

class distill(nn.Module):
    """Two-tower scorer: a user tower (utn) and an item tower (itn) map raw
    traits into a shared embedding space; the score is their dot product."""

    def __init__(self, embedding_dim, user_dim):
        super(distill, self).__init__()
        # user tower: user_dim -> 256 -> 128 -> embedding_dim
        self.utn = nn.Sequential(
            nn.Linear(user_dim, 256), nn.Sigmoid(),
            nn.Linear(256, 128), nn.Sigmoid(),
            nn.Linear(128, embedding_dim), nn.Sigmoid())

        # item tower: (alpha, beta) -> 256 -> 128 -> embedding_dim
        self.itn = nn.Sequential(
            nn.Linear(2, 256), nn.Sigmoid(),
            nn.Linear(256, 128), nn.Sigmoid(),
            nn.Linear(128, embedding_dim), nn.Sigmoid())

    def forward(self, u, i):
        user = self.utn(u)
        item = self.itn(i)
        # dot product of the two embeddings, kept as a trailing singleton dim
        return (user * item).sum(dim=-1, keepdim=True)

class distillModel(object):
    def __init__(self, k, embedding_dim, user_dim, device):
        self.model = distill(embedding_dim, user_dim)
        self.k = k                                       # size of the recommended top-k list
        self.device = device
        self.batch_size = 32
        self.warm_up_ratio = 0.55
        self.l = torch.tensor(1.0).to(self.device)
        self.b = torch.tensor(10000.0).to(self.device)   # weight of the embedding-distance loss

    def get_distance_data(self, train_data, item_pool):
        # Collect every item that appears in some user's top-k list.
        selected = set()
        for data in train_data:
            top_k = data[3]
            selected.update(set(top_k))
        all_qs = set([int(i) for i in item_pool.keys()])
        unselected = all_qs - selected
        selected_itrait = [item_pool[str(i)] for i in selected]
        unselected_itrait = [item_pool[str(i)] for i in unselected]
        # Dissimilar pairs: one selected item paired with one unselected item.
        d_i, d_j = [], []
        for i in selected_itrait:
            for j in unselected_itrait:
                d_i.append(i)
                d_j.append(j)
        # Similar pairs: every unordered pair of selected items.
        s_i, s_j = [], []
        for i in range(len(selected_itrait)):
            for j in range(i + 1, len(selected_itrait)):
                s_i.append(selected_itrait[i])
                s_j.append(selected_itrait[j])

        return s_i, s_j, d_i, d_j

    def train(self, train_data, test_data, item_pool, lr=0.01, epoch=2):
        self.model = self.model.to(self.device)
        train_data = list(train_data)
        test_data = list(test_data)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        s_i, s_j, d_i, d_j = self.get_distance_data(train_data, item_pool)

        for epoch_i in range(epoch):
            loss = []
            for data in tqdm(train_data, f'Epoch {epoch_i+1} '):
                utrait, itrait, label, _ = data
                itrait = itrait.squeeze()
                utrait: torch.Tensor = utrait.to(self.device)
                itrait: torch.Tensor = itrait.to(self.device)
                label: torch.Tensor = label.to(self.device)
                # squared error between the dot-product score and the target label
                score = self.model(utrait, itrait).squeeze(-1)
                u_loss = ((score - label) ** 2).sum()
                loss.append(u_loss.item())
                optimizer.zero_grad()
                u_loss.backward()
                optimizer.step()

            # Embedding-distance loss: keep items that are selected together (similar
            # pairs) close in embedding space relative to selected/unselected
            # (dissimilar) pairs; self.b scales this term.
            si: torch.Tensor = torch.tensor(s_i).to(self.device)
            sj: torch.Tensor = torch.tensor(s_j).to(self.device)
            di: torch.Tensor = torch.tensor(d_i).to(self.device)
            dj: torch.Tensor = torch.tensor(d_j).to(self.device)
            si = self.model.itn(si)
            sj = self.model.itn(sj)
            di = self.model.itn(di)
            dj = self.model.itn(dj)
            e_loss = self.b * ((si - sj) ** 2).sum() / ((di - dj) ** 2).sum()
            optimizer.zero_grad()
            e_loss.backward()
            optimizer.step()

            print('Loss: ', float(np.mean(loss)))
            self.eval(test_data, item_pool)

    def load(self, path):
        self.model.to(self.device)
        self.model.load_state_dict(torch.load(path), strict=False)

    def save(self, path):
        # persist only the two tower sub-networks
        model_dict = self.model.state_dict()
        model_dict = {k: v for k, v in model_dict.items()
                      if 'utn' in k or 'itn' in k}
        torch.save(model_dict, path)

    def eval(self, valid_data, item_pool):
        self.model = self.model.to(self.device)
        k_nums = [1, 5, 10, 15, 20]
        recall = [[] for _ in k_nums]
        for data in tqdm(valid_data, 'testing'):
            utrait, _, __, k_info = data
            k_items, k_DCG = self.getkitems(utrait, item_pool)
            for i, k in enumerate(k_nums):
                # overlap between the predicted top-k and the ground-truth top-k
                i_kitems = set(k_items[:k]).intersection(set(k_info[:k]))
                recall[i].append(len(i_kitems) / k)
        for i, k in enumerate(k_nums):
            print(f'recall@{k}: ', np.mean(recall[i]))

    def getkitems(self, utrait, item_pool):
        with torch.no_grad():
            self.model.eval()
            utrait: torch.Tensor = utrait.to(self.device)
            itrait: torch.Tensor = torch.tensor(list(item_pool.values())).to(self.device)
            # score every item in the pool for this user, then rank by score
            scores = self.model(utrait, itrait).squeeze(-1)
            tmp = list(zip(scores.tolist(), item_pool.keys()))
            tmp_sorted = sorted(tmp, reverse=True)
            self.model.train()
            # top-k item ids and their DCG-discounted scores
            return ([int(i[1]) for i in tmp_sorted[:self.k]],
                    [e[0] / np.log(i + 2) for i, e in enumerate(tmp_sorted[:self.k])])
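As a quick sanity check of the two-tower scorer above, the following minimal sketch (illustrative only, not part of this commit; the toy traits, dimensions, and CPU execution are assumptions) builds a distill network, scores a small random item pool for a single user trait, and ranks the items by dot-product score:

import torch
from CAT.distillation.model_similarity import distill

net = distill(embedding_dim=15, user_dim=1)       # theta-only user trait
utrait = torch.tensor([0.3])                      # one user's ability estimate
itraits = torch.rand(100, 2)                      # 100 hypothetical (alpha, beta) pairs
with torch.no_grad():
    scores = net(utrait, itraits).squeeze(-1)     # user embedding broadcast over all items
topk = torch.topk(scores, k=10)
print(topk.indices.tolist())                      # indices of the 10 highest-scoring items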







6 changes: 4 additions & 2 deletions CAT/distillation/prepare_trait.py
@@ -15,14 +15,14 @@
 seed = 0
 random.seed(seed)
 cdm='irt'
-dataset = 'assistment'
+dataset = 'ifytek'
 stg='MFI'
 with_tested_info=False
 postfix = '_with_tested_info' if with_tested_info else ''
 train_triplets = pd.read_csv(
     f'/data/yutingh/CAT/data/{dataset}/train_triples.csv', encoding='utf-8').to_records(index=False)
 ckpt_path = f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_with_theta.pt'
-concept_map = json.load(open(f'/data/yutingh/CAT/data/{dataset}/item_topic.json', 'r'))
+concept_map = json.load(open(f'/data/yutingh/CAT/data/{dataset}/concept_map.json', 'r'))
 concept_map = {int(k):v for k,v in concept_map.items()}
 metadata = json.load(open(f'/data/yutingh/CAT/data/{dataset}/metadata.json', 'r'))
 train_data = CAT.dataset.TrainDataset(train_triplets, concept_map,
@@ -50,10 +50,12 @@
 model.adaptest_load(ckpt_path)

 user_dict={}
+user_min=user_max=0.5
 for user_id in tqdm(range(train_data.num_students),'getting theta'):
     sid = torch.LongTensor([user_id]).to(config['device'])
     theta=model.get_theta(sid)
     user_dict[user_id]=np.float(theta[0])
+# if
 item_dict={}
 for item_id in tqdm(range(train_data.num_questions),'getting alpha beta'):
     qid = torch.LongTensor([item_id]).to(config['device'])
2 changes: 1 addition & 1 deletion CAT/distillation/tool.py
@@ -27,7 +27,7 @@ def get_label_and_k(user_trait,item_trait,k,stg="MFI", model=None):
     tested_infos=[]
     # thetas=[0.,0.2,0.4,0.6,0.8,1.0]
     # for theta in thetas:
-    # k_info, label, tested_info = get_k_fisher(k, theta, item_trait)
+    # k_info, label, tested_info = get_k_fisher(20, theta, item_trait)
     # tested_infos.append(tested_info)
     # # print('\n',theta,'\n',k_info)
     # # print("===============")
6 changes: 3 additions & 3 deletions CAT/distillation/train.py
@@ -36,11 +36,11 @@
 test_set = transform(itrait,*test_data)
 k=50
 embedding_dim=15
-epoch=5
+epoch=20
 lr=0.005 if dataset=='assistment' else 0.01
 print(f'lr: {lr}')
-model = distillModel(k,embedding_dim,user_dim,device='cuda:2')
-model.load(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip{postfix}.pt')
+model = distillModel(k,embedding_dim,user_dim,device='cuda:1')
+# model.load(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip{postfix}.pt')
 model.train(train_set,test_set,itrait,epoch=epoch,lr=lr)
 model.save(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip{postfix}.pt')

53 changes: 53 additions & 0 deletions CAT/distillation/train_similarity.py
@@ -0,0 +1,53 @@
"""
get theta ,get a b, [theta][alpha,beta]
compute MFI by getMFI top-k
compute by dot production
conpute loss
"""
from CAT.distillation.model_similarity import distillModel
from CAT.distillation.tool import transform, split_data
import torch
import json
import numpy as np

dataset = 'ifytek'
cdm = 'irt'
stg = 'MFI'
with_tested_info = False
postfix = '_with_tested_info' if with_tested_info else ''
trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg}/trait{postfix}.json', 'r'))
utrait = trait['user']
itrait = trait['item']
label = trait['label']
k_info = trait['k_info']
if with_tested_info:
    tested_info = trait['tested_info']
    train_data, test_data = split_data(utrait, label, k_info, 0.8, tested_info)
    user_dim = np.array(tested_info).shape[-1] + 1
else:
    user_dim = 1
    train_data, test_data = split_data(utrait, label, k_info, 0.8)

torch.manual_seed(0)
train_set = transform(itrait, *train_data)
test_set = transform(itrait, *test_data)

k = 50
embedding_dim = 15
epoch = 11
lr = 0.005 if dataset == 'assistment' else 0.01
print(f'lr: {lr}')
model = distillModel(k, embedding_dim, user_dim, device='cuda:4')
# model.load(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip.pt')
model.train(train_set, test_set, itrait, epoch=epoch, lr=lr)
model.save(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip_s.pt')
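The checkpoint saved above holds only the utn/itn tower weights. A minimal follow-up sketch (illustrative only; it assumes the same session, so dataset, cdm, stg, itrait, and the CUDA device are still defined, and that the checkpoint file exists) reloads them and ranks the item pool for a single hypothetical theta:

scorer = distillModel(k=50, embedding_dim=15, user_dim=1, device='cuda:4')
scorer.load(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip_s.pt')
theta = torch.tensor([0.5])                       # hypothetical ability estimate
top_items, dcg_scores = scorer.getkitems(theta, itrait)
print(top_items[:10])                             # predicted top-10 item ids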



