Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
ViviHong200709 committed Dec 15, 2022
1 parent e31bd24 commit cb9541e
Show file tree
Hide file tree
Showing 27 changed files with 470 additions and 90 deletions.
Binary file modified CAT/distillation/__pycache__/model.cpython-39.pyc
Binary file not shown.
Binary file modified CAT/distillation/__pycache__/tool.cpython-39.pyc
Binary file not shown.
46 changes: 41 additions & 5 deletions CAT/distillation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import numpy as np
from tqdm import tqdm

class dMFI(nn.Module):
class distill(nn.Module):
def __init__(self,embedding_dim,user_dim):
# self.prednet_input_len =1
super(dMFI, self).__init__()
super(distill, self).__init__()
self.utn = nn.Sequential(
nn.Linear(user_dim, 256), nn.Sigmoid(
),
Expand All @@ -31,12 +31,47 @@ def forward(self,u,i):
return (user * item).sum(dim=-1, keepdim=True)
# return user*item

class dMFIModel(object):
def __init__(self, k, embedding_dim,user_dim,device):
self.model = dMFI(embedding_dim,user_dim)
class distillModel(object):
def __init__(self, k, embedding_dim, user_dim,device):
self.model = distill(embedding_dim,user_dim)
# 20 1 1
self.k = k
self.device=device

def train_rank(self,train_data,test_data,item_pool,lr=0.01,epoch=2):
    """Fine-tune the model with a pairwise ranking loss over each sample's k items.

    For every training sample the model scores the entries selected by
    ``k_items``; each score is then pushed above the score of the entry
    that follows it in ``k_items`` via ``-log(sigmoid(s_i - s_{i+1}))``.
    (Presumably ``k_items`` is ordered best-first -- confirm against the
    code that builds it.)

    Args:
        train_data: iterable of ``(utrait, itrait, label, k_items)`` samples.
            ``label`` is unpacked but not used by this objective.
        test_data: iterable handed to ``self.eval`` before and during training.
        item_pool: forwarded unchanged to ``self.eval``.
        lr: learning rate for the Adam optimizer.
        epoch: number of passes over ``train_data``.
    """
    self.model = self.model.to(self.device)
    # Materialise both datasets so they can be iterated once per epoch.
    train_data = list(train_data)
    test_data = list(test_data)
    # Baseline evaluation before any ranking update is applied.
    self.eval(test_data, item_pool)
    optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
    losses = []  # accumulated over all epochs, so the printed mean is a running mean
    for epoch_i in range(epoch):
        for utrait, itrait, label, k_items in tqdm(train_data, f'Epoch {epoch_i+1} '):
            itrait = itrait.squeeze().to(self.device)
            utrait = utrait.to(self.device)
            indices = torch.tensor(k_items).to(self.device)
            kutrait = torch.index_select(utrait, 0, indices)
            kitrait = torch.index_select(itrait, 0, indices)
            score = self.model(kutrait, kitrait).squeeze(-1)
            # Shift the scores left by one so position i is compared with
            # position i+1. The last score is paired with itself, which adds
            # only a constant log(2) term and no gradient. Using score[-1:]
            # instead of a fixed index keeps this valid for any number of
            # selected items, not just 50.
            score1 = torch.cat([score[1:], score[-1:]], dim=0)
            # A position-discounted variant (weights exp(-r/20), r = 1..k)
            # was tried previously; the unweighted sum is what is trained.
            u_loss = (-torch.log(torch.sigmoid(score - score1))).sum()
            losses.append(u_loss.item())
            optimizer.zero_grad()
            u_loss.backward()
            optimizer.step()
        # NOTE(review): the original nesting of this report could not be
        # recovered; it is emitted once per epoch here -- confirm.
        print('Loss: ', float(np.mean(losses)))
        self.eval(test_data, item_pool)

def train(self,train_data,test_data,item_pool,lr=0.01,epoch=2):
self.model=self.model.to(self.device)
Expand Down Expand Up @@ -74,6 +109,7 @@ def save(self, path):
torch.save(model_dict, path)

def eval(self,valid_data,item_pool):
self.model=self.model.to(self.device)
k_nums=[1,3,5,10,30,50]
recall = [[]for i in k_nums]
for data in tqdm(valid_data,'testing'):
Expand Down
26 changes: 18 additions & 8 deletions CAT/distillation/prepare_trait.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
cdm='irt'
dataset = 'assistment'
stg='MFI'
with_tested_info=False
postfix = '_with_tested_info' if with_tested_info else ''
train_triplets = pd.read_csv(
f'/data/yutingh/CAT/data/{dataset}/train_triples.csv', encoding='utf-8').to_records(index=False)
ckpt_path = f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_with_theta.pt'
Expand Down Expand Up @@ -59,17 +61,25 @@
beta=model.get_beta(qid)
item_dict[item_id]=[np.float(alpha[0]),np.float(beta[0])]
label,k_info,tested_info = get_label_and_k(user_dict,item_dict,50,stg,model)
trait_dict = {
'user':user_dict,
'item':item_dict,
'label':label,
'k_info':k_info,
'tested_info':tested_info
}
if with_tested_info:
trait_dict = {
'user':user_dict,
'item':item_dict,
'label':label,
'k_info':k_info,
'tested_info':tested_info
}
else:
trait_dict = {
'user':user_dict,
'item':item_dict,
'label':label,
'k_info':k_info,
}

path_prefix = f"/data/yutingh/CAT/data/{dataset}/{stg}/"

with open(f"{path_prefix}trait_with_tested_info.json", "w", encoding="utf-8") as f:
with open(f"{path_prefix}trait{postfix}.json", "w", encoding="utf-8") as f:
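# Passing indent= to json.dumps would pretty-print the output (it is not passed here).
# ensure_ascii=False writes non-ASCII characters (e.g. Chinese) as-is instead of \uXXXX escapes.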
f.write(json.dumps(trait_dict, ensure_ascii=False))
74 changes: 51 additions & 23 deletions CAT/distillation/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,23 @@ def get_label_and_k(user_trait,item_trait,k,stg="MFI", model=None):
labels=[]
k_infos=[]
tested_infos=[]
# thetas=[0.,0.2,0.4,0.6,0.8,1.0]
# for theta in thetas:
# k_info, label, tested_info = get_k_fisher(k, theta, item_trait)
# tested_infos.append(tested_info)
# # print('\n',theta,'\n',k_info)
# # print("===============")
# k_infos.append(k_info)
# labels.append(label)
# print(k_infos)
# return labels,k_infos,tested_infos
for sid, theta in tqdm(user_trait.items(),f'get top {k} items'):
# print(theta)
if stg=='MFI':
k_info, label, tested_info = get_k_fisher(k, theta, item_trait)
tested_infos.append(tested_info)
# print('\n',theta,'\n',k_info)
# print("===============")
elif stg=='KLI':
k_info,label,tested_info= get_k_kli(k, theta, item_trait)
tested_infos.append(tested_info)
Expand Down Expand Up @@ -75,36 +88,51 @@ def pack_batch(batch):
theta, Tensor(itrait), Tensor(label), k_fisher
)

def get_k_fisher(k,theta,items):
    """Rank items by Fisher information at ``theta`` and return the top ``k``.

    For each item the response probability is ``sigmoid(alpha * theta + beta)``
    and the information is ``p * (1 - p) * alpha ** 2``.

    Args:
        k: number of item ids to return.
        theta: scalar trait value the information is evaluated at.
        items: mapping ``qid -> (alpha, beta)``.

    Returns:
        A 3-tuple ``(top_k_qids, fisher_values, tested_info)`` where
        ``top_k_qids`` holds the ``k`` most informative qids (best first),
        ``fisher_values`` holds every item's information in ``items``
        iteration order, and ``tested_info`` is always an empty list
        (kept so callers can unpack three values).
    """
    fisher_arr = []
    for qid, (alpha, beta) in items.items():
        pred = torch.sigmoid(torch.tensor(alpha * theta + beta))
        q = 1 - pred
        fisher_info = float((q * pred * (alpha ** 2)).numpy())
        fisher_arr.append((fisher_info, qid))
    # Tuples sort by information first; ties fall back to qid (descending).
    fisher_arr_sorted = sorted(fisher_arr, reverse=True)
    top_k_qids = [qid for _, qid in fisher_arr_sorted[:k]]
    fisher_values = [info for info, _ in fisher_arr]
    return top_k_qids, fisher_values, []

# def get_k_fisher(k, theta, items):
# fisher_arr = []
# items_n=len(items.keys())
# ns = [random.randint(0,19) for i in range(items_n)]
# tested_qids = [random.sample(list(range(0,20)),n) for n in ns]
# avg_embs = np.array(list(items.values())).mean(axis=0)
# p=0.002
# # p=0.01
# avg_tested_embs=[]
# for tested_qid, (qid,(alpha,beta)) in zip(tested_qids,items.items()):
# # tested_qid
# if len(tested_qid)==0:
# avg_tested_emb=np.array([0,0])
# else:
# avg_tested_emb = np.array([items[qid] for qid in tested_qid]).mean(axis=0)
# item_emb=[alpha,beta]
# pred = alpha * theta + beta
# pred = torch.sigmoid(torch.tensor(pred))
# # pred = 1 / (1 + np.exp(-pred))
# q = 1 - pred
# diff = ((item_emb-avg_tested_emb)**2).sum()
# sim = ((item_emb-avg_embs)**2).sum()
# fisher_info = float((q*pred*(alpha ** 2)).numpy()) + p*diff/sim
# # print(float((q*pred*(alpha ** 2)).numpy()),0.01*diff/sim)
# fisher_arr.append((fisher_info,qid,0.05*diff/sim))
# avg_tested_embs.append(avg_tested_emb.tolist())
# fisher_arr_sorted = sorted(fisher_arr, reverse=True)
# tested_info=[]
# for avg_tested_emb,n in zip(avg_tested_embs,ns):
# avg_tested_emb.extend([n])
# tested_info.append(avg_tested_emb)
# # print([i[0] for i in fisher_arr_sorted[:k]],'\n',[i[2] for i in fisher_arr_sorted[:k]])
# return [i[1] for i in fisher_arr_sorted[:k]],[i[0]for i in fisher_arr],tested_info

def get_k_emc(k,sid,theta,items,model):
epochs = model.config['num_epochs']
Expand Down
18 changes: 11 additions & 7 deletions CAT/distillation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,26 @@
compute by dot product
compute loss
"""
from CAT.distillation.model import dMFIModel
from CAT.distillation.model import distillModel
from CAT.distillation.tool import transform,split_data
import torch
import json
import numpy as np
dataset='assistment'
cdm='irt'
stg='MFI'
trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg}/trait_with_tested_info.json', 'r'))
with_tested_info=False
postfix = '_with_tested_info' if with_tested_info else ''
trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg}/trait{postfix}.json', 'r'))
# stg='KLI'
# trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg}/trait.json', 'r'))
utrait = trait['user']
itrait = trait['item']
# print(utrait,itrait)
label = trait['label']
k_info = trait['k_info']
if 'tested_info' in trait:
# if 'tested_info' in trait:
if with_tested_info:
tested_info= trait['tested_info']
train_data, test_data = split_data(utrait,label,k_info,0.8,tested_info)
user_dim=np.array(tested_info).shape[-1]+1
Expand All @@ -33,12 +36,13 @@
test_set = transform(itrait,*test_data)
k=50
embedding_dim=15
epoch=38
epoch=5
lr=0.005 if dataset=='assistment' else 0.01
print(f'lr: {lr}')
dMFI = dMFIModel(k,embedding_dim,user_dim,device='cuda:4')
dMFI.train(train_set,test_set,itrait,epoch=epoch,lr=lr)
dMFI.save(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip_with_tested_info.pt')
model = distillModel(k,embedding_dim,user_dim,device='cuda:2')
model.load(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip{postfix}.pt')
model.train(train_set,test_set,itrait,epoch=epoch,lr=lr)
model.save(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip{postfix}.pt')



Expand Down
44 changes: 44 additions & 0 deletions CAT/distillation/train_rank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from CAT.distillation.model import distillModel
from CAT.distillation.tool import transform, split_data
import torch
import json
import numpy as np
from torch import Tensor

# Fine-tune a previously saved distilled model with the pairwise ranking
# objective (distillModel.train_rank) and report its evaluation metrics.
dataset = 'assistment'
cdm = 'irt'
stg = 'MFI'
with_tested_info = False
postfix = '_with_tested_info' if with_tested_info else ''

# Read the trait file through a context manager so the handle is closed
# deterministically; the file is written as UTF-8 by the preparation script.
trait_path = f'/data/yutingh/CAT/data/{dataset}/{stg}/trait{postfix}.json'
with open(trait_path, 'r', encoding='utf-8') as f:
    trait = json.load(f)
# stg='KLI'
# trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg}/trait.json', 'r'))
utrait = trait['user']
itrait = trait['item']
label = trait['label']
k_info = trait['k_info']

# Branch on the explicit flag (as the sibling training scripts do) rather
# than on whatever keys happen to be present in the loaded file, so the
# user dimension always matches the checkpoint selected by `postfix`.
if with_tested_info:
    tested_info = trait['tested_info']
    train_data, test_data = split_data(utrait, label, k_info, 0.8, tested_info)
    user_dim = np.array(tested_info).shape[-1] + 1
else:
    user_dim = 1
    train_data, test_data = split_data(utrait, label, k_info, 0.8)

torch.manual_seed(0)
train_set = transform(itrait, *train_data)
test_set = transform(itrait, *test_data)
k = 50
embedding_dim = 15
epoch = 25
lr = 0.005 if dataset == 'assistment' else 0.01
print(f'lr: {lr}')
model = distillModel(k, embedding_dim, user_dim, device='cuda:2')
# Start from the checkpoint saved by the regular training script.
model.load(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip{postfix}.pt')
model.train_rank(train_set, test_set, itrait, epoch=epoch, lr=lr)
# The fine-tuned weights are intentionally not written back:
# model.save(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip{postfix}.pt')




2 changes: 1 addition & 1 deletion CAT/mips/build_index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from CAT.mips.ball_tree import BallTree,search_metric_tree
# The model class in CAT.distillation.model is now named distillModel;
# alias it so existing references to dMFIModel in this module keep working.
from CAT.distillation.model import distillModel as dMFIModel
import numpy as np
import torch
import datetime
Expand Down
14 changes: 7 additions & 7 deletions CAT/mips/draw_pic.ipynb

Large diffs are not rendered by default.

15 changes: 9 additions & 6 deletions CAT/mips/prepare_ip_trait.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,21 @@
import torch
from tqdm import tqdm
import numpy as np
from CAT.distillation.model import dMFIModel
from CAT.distillation.model import distillModel
from CAT.distillation.tool import get_label_and_k, split_data, transform

dataset='assistment'
cdm='irt'
stg='MFI'
trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg}/trait_with_tested_info.json', 'r'))
with_tested_info=False
postfix = '_with_tested_info' if with_tested_info else ''
trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg}/trait{postfix}.json', 'r'))
utrait = trait['user']
itrait = trait['item']
label = trait['label']
k_info = trait['k_info']
if 'tested_info' in trait:
# if 'tested_info' in trait:
if with_tested_info:
tested_info= trait['tested_info']
user_dim=np.array(tested_info).shape[-1]+1
else:
Expand All @@ -31,8 +34,8 @@
test_set = transform(itrait,*test_data)
k=50
embedding_dim=15
dMFI = dMFIModel(k,embedding_dim,user_dim,device='cuda:0')
dMFI.load(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip_with_tested_info.pt')
dMFI = distillModel(k,embedding_dim,user_dim,device='cuda:0')
dMFI.load(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip{postfix}.pt')
# dMFI.eval(test_set,itrait)
ball_embs=[]
max_embs_len=torch.tensor(0.)
Expand All @@ -52,7 +55,7 @@

path_prefix = f"/data/yutingh/CAT/data/{dataset}/{stg}/"

with open(f"{path_prefix}ball_trait_with_tested_info.json", "w", encoding="utf-8") as f:
with open(f"{path_prefix}ball_trait{postfix}.json", "w", encoding="utf-8") as f:
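# Passing indent= to json.dumps would pretty-print the output (it is not passed here).
# ensure_ascii=False writes non-ASCII characters (e.g. Chinese) as-is instead of \uXXXX escapes.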
f.write(json.dumps(ball_embs, ensure_ascii=False))
Loading

0 comments on commit cb9541e

Please sign in to comment.