Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
ViviHong200709 committed Nov 5, 2022
1 parent 74b0d44 commit f9e619b
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 168 deletions.
3 changes: 2 additions & 1 deletion CAT/distillation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@
test_set = transform(itrait,*test_data)
k=50
embedding_dim=15
epoch=30
lr=0.005 if dataset=='assistment' else 0.01
print(f'lr: {lr}')
user_dim=2 if stg =='KLI'else 1
dMFI = dMFIModel(k,embedding_dim,user_dim,device='cuda:4')
dMFI.train(train_set,test_set,itrait,epoch=30,lr=lr)
dMFI.train(train_set,test_set,itrait,epoch=epoch,lr=lr)
dMFI.save(f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}_{stg}_ip.pt')


Expand Down
164 changes: 0 additions & 164 deletions CAT/mips/a.py

This file was deleted.

5 changes: 2 additions & 3 deletions CAT/mips/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def setuplogger():

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
def main(dataset="assistment", cdm="irt", stg = ['KLI'], test_length = 20, ctx="cuda:4", lr=0.2, num_epoch=1, efficient=False):
def main(dataset="assistment", cdm="irt", stg = ['Random'], test_length = 20, ctx="cuda:4", lr=0.2, num_epoch=1, efficient=False):
lr=0.15 if dataset=='assistment' else 0.2
setuplogger()
seed = 0
Expand All @@ -46,7 +46,6 @@ def main(dataset="assistment", cdm="irt", stg = ['KLI'], test_length = 20, ctx="
ckpt_path = f'/data/yutingh/CAT/ckpt/{dataset}/{cdm}.pt'
# read datasets
test_triplets = pd.read_csv(f'/data/yutingh/CAT/data/{dataset}/test_triples.csv', encoding='utf-8').to_records(index=False)
# test_triplets = pd.read_csv(f'/data/yutingh/CAT/data/{dataset}/test_filled_triplets.csv', encoding='utf-8').to_records(index=False)
concept_map = json.load(open(f'/data/yutingh/CAT/data/{dataset}/item_topic.json', 'r'))
concept_map = {int(k):v for k,v in concept_map.items()}

Expand All @@ -73,7 +72,7 @@ def main(dataset="assistment", cdm="irt", stg = ['KLI'], test_length = 20, ctx="
model.adaptest_load(ckpt_path)
test_data.reset()
if efficient:
ball_trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/ball_trait.json', 'r'))
ball_trait = json.load(open(f'/data/yutingh/CAT/data/{dataset}/{stg[i]}/ball_trait.json', 'r'))
distill_k=50
embedding_dim=15
user_dim=2 if stg[0]=='KLI' else 1
Expand Down
Binary file removed ckpt/assistment/irt_ip.pt
Binary file not shown.

0 comments on commit f9e619b

Please sign in to comment.