Skip to content

Commit

Permalink
sae
Browse files Browse the repository at this point in the history
  • Loading branch information
ViviHong200709 committed Nov 5, 2022
1 parent ffe6042 commit 74b0d44
Show file tree
Hide file tree
Showing 40 changed files with 370 additions and 541 deletions.
1 change: 0 additions & 1 deletion CAT.egg-info/SOURCES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ CAT.egg-info/requires.txt
CAT.egg-info/top_level.txt
CAT/cognitive_structure/__init__.py
CAT/cognitive_structure/structure.py
CAT/cognitive_structure/test.py
CAT/dataset/__init__.py
CAT/dataset/adaptest_dataset.py
CAT/dataset/dataset.py
Expand Down
Binary file modified CAT/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file modified CAT/cognitive_structure/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file modified CAT/cognitive_structure/__pycache__/structure.cpython-39.pyc
Binary file not shown.
Binary file modified CAT/dataset/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file modified CAT/dataset/__pycache__/adaptest_dataset.cpython-39.pyc
Binary file not shown.
Binary file modified CAT/dataset/__pycache__/dataset.cpython-39.pyc
Binary file not shown.
Binary file modified CAT/dataset/__pycache__/train_dataset.cpython-39.pyc
Binary file not shown.
Binary file modified CAT/distillation/MFI/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file modified CAT/distillation/MFI/__pycache__/model.cpython-39.pyc
Binary file not shown.
35 changes: 0 additions & 35 deletions CAT/distillation/MFI/train.py

This file was deleted.

Binary file added CAT/distillation/__pycache__/model.cpython-39.pyc
Binary file not shown.
Binary file modified CAT/distillation/__pycache__/tool.cpython-39.pyc
Binary file not shown.
50 changes: 10 additions & 40 deletions CAT/distillation/MFI/model.py → CAT/distillation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from tqdm import tqdm

class dMFI(nn.Module):
def __init__(self,embedding_dim):
def __init__(self,embedding_dim,user_dim):
# self.prednet_input_len =1
super(dMFI, self).__init__()
self.utn = nn.Sequential(
nn.Linear(1, 256), nn.Sigmoid(
nn.Linear(user_dim, 256), nn.Sigmoid(
),
nn.Linear(256, 128), nn.Sigmoid(
),
Expand All @@ -32,8 +32,8 @@ def forward(self,u,i):
# return user*item

class dMFIModel(object):
def __init__(self, k, embedding_dim,device):
self.model = dMFI(embedding_dim)
def __init__(self, k, embedding_dim,user_dim,device):
self.model = dMFI(embedding_dim,user_dim)
# 20 1 1
self.k = k
self.device=device
Expand All @@ -42,15 +42,14 @@ def train(self,train_data,test_data,item_pool,lr=0.01,epoch=2):
self.model=self.model.to(self.device)
train_data=list(train_data)
test_data=list(test_data)
self.items_n = len(item_pool.keys())
optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
loss = []
for epoch_i in range(epoch):
for data in tqdm(train_data):
for data in tqdm(train_data,f'Epoch {epoch_i+1} '):
utrait,itrait,label,_=data
itrait = itrait.squeeze()
u_loss: torch.Tensor = torch.tensor(0.).to(self.device)
utrait:torch.Tensor = torch.tensor([utrait]*self.items_n).unsqueeze(-1).to(self.device)
utrait:torch.Tensor = utrait.to(self.device)
itrait: torch.Tensor = itrait.to(self.device)
label: torch.Tensor = label.to(self.device)
score = self.model(utrait,itrait).squeeze(-1)
Expand All @@ -61,7 +60,7 @@ def train(self,train_data,test_data,item_pool,lr=0.01,epoch=2):
optimizer.step()
# print(float(np.mean(loss)))
# self.eval(valid_data,item_pool)
print(f'Epoch {epoch_i}:',float(np.mean(loss)))
print('Loss: ',float(np.mean(loss)))
self.eval(test_data,item_pool)

def load(self, path):
Expand All @@ -78,47 +77,18 @@ def eval(self,valid_data,item_pool):
k_nums=[1,3,5,10,30,50]
recall = [[]for i in k_nums]
for data in tqdm(valid_data,'testing'):
utrait,_,__,k_fisher=data
utrait,_,__,k_info=data
k_items,k_DCG = self.getkitems(utrait,item_pool)
# k_fisher
k_fisher=k_fisher[0]
for i,k in enumerate(k_nums):
i_kitems = set(k_items[:k]).intersection(set(k_fisher[:k]))
i_kitems = set(k_items[:k]).intersection(set(k_info[:k]))
recall[i].append(len(i_kitems)/k)
for i,k in enumerate(k_nums):
print(f'recall@{k}: ',np.mean(recall[i]))

# def get_k_fisher(self,theta,items):
# fisher_arr = []
# for qid,(alpha,beta) in items.items():
# pred = alpha * theta + beta
# pred = torch.sigmoid(torch.tensor(pred))
# q = 1 - pred
# fisher_info = (q*pred*(alpha ** 2)).numpy()
# fisher_arr.append((fisher_info,qid))
# fisher_arr_sorted = sorted(fisher_arr, reverse=True)
# return [i[1] for i in fisher_arr_sorted[:self.k]]

def estimate_rank(self,score,theta,items):
with torch.no_grad():
self.model.eval()
items_arr = list(range(self.items_n))
np.random.shuffle(items_arr)
samples = items_arr[0:self.sample_num]
theta = torch.tensor([theta]*len(samples))
utrait:torch.Tensor = theta.unsqueeze(-1).to(self.device)
itrait:torch.Tensor = torch.tensor([items[str(sample)] for sample in samples]).to(device)
scores = self.model(utrait,itrait).squeeze(-1)
s_rank = len([i for i in scores if i>score])
self.model.train()
return s_rank*(self.items_n-1)/self.sample_num+1


def getkitems(self, utrait,item_pool):
with torch.no_grad():
self.model.eval()
item_n =len(item_pool.keys())
utrait:torch.Tensor = torch.tensor([utrait]*item_n).unsqueeze(-1).to(self.device)
utrait:torch.Tensor = utrait.to(self.device)
itrait:torch.Tensor = torch.tensor(list(item_pool.values())).to(self.device)
scores = self.model(utrait,itrait).squeeze(-1)
tmp = list(zip(scores.tolist(),item_pool.keys()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
from tqdm import tqdm
import numpy as np
from CAT.distillation.tool import get_label_and_k
import random


seed = 0
random.seed(seed)
cdm='irt'
dataset = 'assistment'
dataset = 'junyi'
stg='KLI'
train_triplets = pd.read_csv(
f'/data/yutingh/CAT/data/{dataset}/train_triples.csv', encoding='utf-8').to_records(index=False)
ckpt_path = f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_with_theta.pt'
Expand Down Expand Up @@ -43,10 +46,7 @@
model = CAT.model.NCDModel(**config)
model.init_model(train_data)
model.adaptest_load(ckpt_path)
# trait_log=[]
# for row in train_data:
# print(row)
# break

user_dict={}
for user_id in tqdm(range(train_data.num_students),'gettting theta'):
sid = torch.LongTensor([user_id]).to(config['device'])
Expand All @@ -58,23 +58,16 @@
alpha=model.get_alpha(qid)
beta=model.get_beta(qid)
item_dict[item_id]=[np.float(alpha[0]),np.float(beta[0])]
label,k_fisher = get_label_and_k(user_dict,item_dict,50)
label,k_info,tested_info = get_label_and_k(user_dict,item_dict,50,stg)
trait_dict = {
'user':user_dict,
'item':item_dict,
'label':label,
'k_fisher':k_fisher
'k_info':k_info,
'tested_info':tested_info
}


# for user_id,item_id,score in tqdm(train_triplets,'getting traits'):
# sid = torch.LongTensor([user_id]).to(config['device'])
# qid = torch.LongTensor([item_id]).to(config['device'])
# difficulty=model.get_beta(qid)
# discrimination=model.get_alpha(qid)
# theta=model.get_theta(sid)
# trait_log.append([theta[0],[difficulty[0],discrimination[0]]])
path_prefix = f"/data/yutingh/CAT/data/{dataset}/"
path_prefix = f"/data/yutingh/CAT/data/{dataset}/{stg}/"

with open(f"{path_prefix}trait.json", "w", encoding="utf-8") as f:
# indent参数保证json数据的缩进,美观
Expand Down
110 changes: 83 additions & 27 deletions CAT/distillation/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
from torch import Tensor
from tqdm import tqdm
import numpy as np
def split_data(utrait,label,k_fisher,rate=1):
import vegas
from scipy import integrate
import random
import functools

def split_data(utrait,label,k_fisher,rate=1,tested_info=None):
max_len=int(rate*len(utrait.keys()))
u_train={}
u_test={}
Expand All @@ -11,41 +16,54 @@ def split_data(utrait,label,k_fisher,rate=1):
u_test[u]=trait
else:
u_train[u]=trait
return ((u_train,label[:max_len],k_fisher[:max_len]),(u_test,label[max_len:],k_fisher[max_len:]))
if tested_info:
return ((u_train,label[:max_len],k_fisher[:max_len],tested_info[:max_len]),(u_test,label[max_len:],k_fisher[max_len:],tested_info[max_len:]))
else:
return ((u_train,label[:max_len],k_fisher[:max_len]),(u_test,label[max_len:],k_fisher[max_len:]))

def get_label_and_k(user_trait, item_trait, k, stg="MFI"):
    """Compute per-item informativeness labels and top-k item ids per user.

    Args:
        user_trait: dict mapping user id -> ability theta.
        item_trait: dict mapping item id -> (alpha, beta) IRT parameters.
        k: number of most-informative items to keep per user.
        stg: selection strategy — 'MFI' (Fisher information) or 'KLI'
            (Kullback-Leibler information).

    Returns:
        (labels, k_infos, tested_ns): per-user informativeness lists,
        per-user top-k item ids, and — for 'KLI' only — the per-item
        tested counts used for the integration bounds (empty list for MFI).

    Raises:
        ValueError: if `stg` is not a known strategy (previously this fell
            through to a NameError on `k_info`).
    """
    labels = []
    k_infos = []
    tested_ns = []
    for theta in tqdm(user_trait.values(), f'get top {k} items'):
        if stg == 'MFI':
            k_info, label = get_k_fisher(k, theta, item_trait)
        elif stg == 'KLI':
            k_info, label, tested_n = get_k_kli(k, theta, item_trait)
            tested_ns.append(tested_n)
        else:
            # Fail fast instead of raising NameError on k_info below.
            raise ValueError(f'unknown selection strategy: {stg}')
        k_infos.append(k_info)
        labels.append(label)
    return labels, k_infos, tested_ns

def transform(item_trait, user_trait, labels, k_fishers, tested_infos=None):
    """Yield one packed training sample per user.

    Each yielded batch is ``pack_batch([user_features, itrait, label,
    k_fisher])`` where ``user_features`` is a tensor with one row per item:
    either ``(theta, tested_info_i)`` pairs when per-item tested counts are
    supplied (KLI strategy), or a single-column ``theta`` tensor otherwise.

    Args:
        item_trait: dict mapping item id -> item parameters.
        user_trait: dict mapping user id -> theta.
        labels: per-user informativeness labels, aligned with user_trait.
        k_fishers: per-user top-k item ids, aligned with user_trait.
        tested_infos: optional per-user, per-item tested counts (KLI only).
    """
    if tested_infos:
        for theta, label, k_fisher, tested_info in zip(
                user_trait.values(), labels, k_fishers, tested_infos):
            itrait = list(item_trait.values())
            item_n = len(itrait)
            # Two-dimensional user features: (theta, tested count) per item.
            yield pack_batch([
                torch.tensor(list(zip([theta] * item_n, tested_info))),
                itrait,
                label,
                k_fisher
            ])
    else:
        for theta, label, k_fisher in zip(
                user_trait.values(), labels, k_fishers):
            itrait = list(item_trait.values())
            item_n = len(itrait)
            # One-dimensional user features: theta repeated per item.
            yield pack_batch([
                torch.tensor([theta] * item_n).unsqueeze(-1),
                itrait,
                label,
                k_fisher
            ])

def pack_batch(batch):
    """Pack one sample into the model's input tuple.

    `batch` is a 4-sequence ``(theta, itrait, label, k_fisher)``. The user
    features are passed through unchanged (callers already build a tensor —
    see `transform`), item traits and labels are wrapped in torch Tensors,
    and the top-k item-id list is passed through as-is.
    """
    theta, itrait, label, k_fisher = batch
    return (
        theta, Tensor(itrait), Tensor(label), k_fisher
    )

def get_k_fisher(k,theta,items):
def get_k_fisher(k, theta, items):
fisher_arr = []
for qid,(alpha,beta) in items.items():
pred = alpha * theta + beta
Expand All @@ -55,4 +73,42 @@ def get_k_fisher(k,theta,items):
fisher_info = float((q*pred*(alpha ** 2)).numpy())
fisher_arr.append((fisher_info,qid))
fisher_arr_sorted = sorted(fisher_arr, reverse=True)
return [i[1] for i in fisher_arr_sorted[:k]],[i[0]for i in fisher_arr]
return [i[1] for i in fisher_arr_sorted[:k]],[i[0]for i in fisher_arr]



def get_k_kli(k, theta, items):
    """Select the top-k items by Kullback-Leibler information at ability theta.

    For each item (alpha, beta), integrates the KL divergence between the
    response distribution at the current ability estimate and at ability x,
    over a shrinking interval theta +/- 3/sqrt(n).

    Args:
        k: number of top items to return.
        theta: current ability estimate (float or 1-d np.ndarray).
        items: dict mapping item id -> (alpha, beta).

    Returns:
        (top_k_ids, all_infos, ns): ids of the k most informative items,
        the KL information of every item (in `items` order), and the
        simulated per-item tested counts `ns`.
    """
    items_n = len(items.keys())
    # NOTE(review): tested counts are simulated uniformly in [1, 20];
    # callers must seed `random` for reproducibility.
    ns = [random.randint(1, 20) for _ in range(items_n)]
    dim = 1
    res_arr = []
    for (qid, (alpha, beta)), n in zip(items.items(), ns):
        # isinstance (not type ==) so float subclasses, e.g. np.float64,
        # are also promoted to 1-d arrays for matmul.
        if isinstance(alpha, float):
            alpha = np.array([alpha])
        if isinstance(theta, float):
            theta = np.array([theta])
        pred_estimate = np.matmul(alpha.T, theta) + beta
        pred_estimate = 1 / (1 + np.exp(-pred_estimate))
        # Bug fix: q_estimate must complement pred_estimate; previously it
        # was computed inside kli as 1 - pred, which made the second KL
        # term identically zero (q_estimate * log(q_estimate / q) == 0).
        q_estimate = 1 - pred_estimate

        def kli(x):
            """KL divergence between response dists at theta-hat and x."""
            if isinstance(x, float):
                x = np.array([x])
            pred = np.matmul(alpha.T, x) + beta
            pred = 1 / (1 + np.exp(-pred))
            q = 1 - pred
            return pred_estimate * np.log(pred_estimate / pred) + \
                q_estimate * np.log(q_estimate / q)

        c = 3
        # Integration window shrinks with the number of administered items n.
        boundaries = [
            [theta[i] - c / np.sqrt(n), theta[i] + c / np.sqrt(n)] for i in range(dim)]
        if len(boundaries) == 1:
            # KLI: one-dimensional ability, plain quadrature.
            v, err = integrate.quad(kli, boundaries[0][0], boundaries[0][1])
            res_arr.append((v, qid))
        else:
            # MKLI: multi-dimensional ability, Monte-Carlo integration.
            integ = vegas.Integrator(boundaries)
            result = integ(kli, nitn=10, neval=1000)
            res_arr.append((result.mean, qid))
    res_arr_sorted = sorted(res_arr, reverse=True)
    return [i[1] for i in res_arr_sorted[:k]], [i[0] for i in res_arr], ns
40 changes: 40 additions & 0 deletions CAT/distillation/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
get theta ,get a b, [theta][alpha,beta]
compute MFI by getMFI top-k
compute by dot production
conpute loss
"""
from CAT.distillation.model import dMFIModel
from CAT.distillation.tool import transform,split_data
import torch
import json
dataset='assistment'
cdm='irt'
stg='KLI'
trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg}/trait.json', 'r'))
utrait = trait['user']
itrait = trait['item']
# print(utrait,itrait)
label = trait['label']
k_info = trait['k_info']
if stg=='KLI':
tested_info= trait['tested_info']
train_data, test_data = split_data(utrait,label,k_info,0.8,tested_info)
else:
train_data, test_data = split_data(utrait,label,k_info,0.8)

torch.manual_seed(0)
train_set = transform(itrait,*train_data)
test_set = transform(itrait,*test_data)
k=50
embedding_dim=15
lr=0.005 if dataset=='assistment' else 0.01
print(f'lr: {lr}')
user_dim=2 if stg =='KLI'else 1
dMFI = dMFIModel(k,embedding_dim,user_dim,device='cuda:4')
dMFI.train(train_set,test_set,itrait,epoch=30,lr=lr)
dMFI.save(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip.pt')




Binary file modified CAT/mips/__pycache__/ball_tree.cpython-39.pyc
Binary file not shown.
Loading

0 comments on commit 74b0d44

Please sign in to comment.