From f551bb1f98e3b8ff62a6213bfa44f94350999f8e Mon Sep 17 00:00:00 2001
From: yujunhao <772986150@qq.com>
Date: Sat, 25 May 2024 17:29:08 +0800
Subject: [PATCH] Fix NCAT bugs
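
The NCAT strategy had several bugs that prevented it from running end to end. This patch:

* lets AdapTestDataset.get_tested_dataset() restrict the returned TrainDataset
  to a single student via the new ssid argument, and makes
  IRTModel.adaptest_update() forward that id and return the training loss, so
  the NCAT environment's reward step refits only the current student's theta
* fixes the StraightThrough import in IRT.py to the package-relative .utils
* runs the NCAT networks on CPU instead of hard-coded CUDA, loads the IRT
  checkpoint from ../ckpt/irt.pt through adaptest_load(), and reads the dict
  returned by IRTModel.evaluate() instead of unpacking a tuple
* adds the replay memory and tau counter used by NCATModel's training loop and
  replaces leftover self.args references with instance and config values
* reworks NCATStrategy.adaptest_select() to build the model once, run a
  training pass and then a testing pass over every student, and return a
  {student_id: selected question ids} dict
* updates the assistment preprocessing notebook for the reorganized data paths
  and refreshes the train/test notebooks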
---
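Notes (reviewer aid, not applied with the commit): a minimal sketch of how the
patched API is expected to be driven. The strategy class name NCATStrategy, the
test_data / concept_map variables, and the concrete config values below are
assumptions for illustration; the config keys are the ones referenced in the
patched hunks, and a real IRTModel config will likely need more.

    from CAT.model.IRT import IRTModel
    from CAT.strategy.NCAT_strategy import NCATStrategy   # assumed class name

    config = {
        'learning_rate': 0.0025,  # used by IRTModel.adaptest_update()
        'batch_size': 2048,       # also gates the NCAT replay-memory update
        'num_epochs': 8,
        'device': 'cpu',          # NCAT.py now pins its networks to the CPU
        'THRESHOLD': 300,         # exploration threshold read by ncat_policy()
        # ...plus whatever IRT hyperparameters your IRTModel config already needs
    }

    # test_data is an AdapTestDataset; concept_map maps question ids to concepts.
    model = IRTModel(**config)
    model.init_model(test_data)
    model.adaptest_load('../ckpt/irt.pt')   # checkpoint path NCAT.py now expects

    # Per-student update added by this patch: only student 0's latest response
    # is used to refit theta, and the training loss is returned for the reward.
    loss = model.adaptest_update(test_data, sid=0)

    # adaptest_select() now runs a training pass and then a testing pass over
    # all students and returns {student_id: [selected question ids]}.
    strategy = NCATStrategy()
    selection = strategy.adaptest_select(test_data, concept_map, config,
                                         test_length=20)
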
CAT/dataset/adaptest_dataset.py | 35 +++++--
CAT/model/IRT.py | 15 ++-
CAT/strategy/NCAT_nn/NCAT.py | 37 +++++---
CAT/strategy/NCAT_strategy.py | 23 +++--
scripts/dataset/assistment.ipynb | 153 ++++++++++++++++---------------
scripts/test.ipynb | 144 +++++++----------------------
scripts/train.ipynb | 70 +++++---------
7 files changed, 212 insertions(+), 265 deletions(-)
diff --git a/CAT/dataset/adaptest_dataset.py b/CAT/dataset/adaptest_dataset.py
index 4768d36..19ddef9 100644
--- a/CAT/dataset/adaptest_dataset.py
+++ b/CAT/dataset/adaptest_dataset.py
@@ -60,7 +60,7 @@ def tested(self):
def untested(self):
return self._untested
- def get_tested_dataset(self, last=False):
+ def get_tested_dataset(self, last=False,ssid=None):
"""
Get tested data for training
Args:
@@ -68,13 +68,28 @@ def get_tested_dataset(self, last=False):
Returns:
TrainDataset
"""
- triplets = []
- for sid, qids in self._tested.items():
- if last:
- qid = qids[-1]
- triplets.append((sid, qid, self.data[sid][qid]))
- else:
- for qid in qids:
+ if ssid is None:
+ triplets = []
+ for sid, qids in self._tested.items():
+ if last:
+ qid = qids[-1]
+
triplets.append((sid, qid, self.data[sid][qid]))
- return TrainDataset(triplets, self.concept_map,
- self.num_students, self.num_questions, self.num_concepts)
\ No newline at end of file
+ else:
+ for qid in qids:
+ triplets.append((sid, qid, self.data[sid][qid]))
+ return TrainDataset(triplets, self.concept_map,
+ self.num_students, self.num_questions, self.num_concepts)
+ else:
+ triplets = []
+ for sid, qids in self._tested.items():
+ if ssid == sid:
+ if last:
+ qid = qids[-1]
+
+ triplets.append((sid, qid, self.data[sid][qid]))
+ else:
+ for qid in qids:
+ triplets.append((sid, qid, self.data[sid][qid]))
+ return TrainDataset(triplets, self.concept_map,
+ self.num_students, self.num_questions, self.num_concepts)
\ No newline at end of file
diff --git a/CAT/model/IRT.py b/CAT/model/IRT.py
index a480a10..5b828d0 100644
--- a/CAT/model/IRT.py
+++ b/CAT/model/IRT.py
@@ -13,7 +13,7 @@
from CAT.dataset import AdapTestDataset, TrainDataset, Dataset
from sklearn.metrics import accuracy_score
from collections import namedtuple
-from utils import StraightThrough
+from .utils import StraightThrough
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])
class IRT(nn.Module):
@@ -99,7 +99,7 @@ def adaptest_load(self, path):
self.model.load_state_dict(torch.load(path), strict=False)
self.model.to(self.config['device'])
- def adaptest_update(self, adaptest_data: AdapTestDataset):
+ def adaptest_update(self, adaptest_data: AdapTestDataset,sid=None):
"""
Update CDM with tested data
"""
@@ -109,9 +109,8 @@ def adaptest_update(self, adaptest_data: AdapTestDataset):
device = self.config['device']
optimizer = torch.optim.Adam(self.model.theta.parameters(), lr=lr)
- tested_dataset = adaptest_data.get_tested_dataset(last=True)
+ tested_dataset = adaptest_data.get_tested_dataset(last=True,ssid=sid)
dataloader = torch.utils.data.DataLoader(tested_dataset, batch_size=batch_size, shuffle=True)
-
for ep in range(1, epochs + 1):
loss = 0.0
log_steps = 100
@@ -127,7 +126,13 @@ def adaptest_update(self, adaptest_data: AdapTestDataset):
loss += bz_loss.data.float()
# if cnt % log_steps == 0:
# print('Epoch [{}] Batch [{}]: loss={:.3f}'.format(ep, cnt, loss / cnt))
-
+ return loss
+ def one_student_update(self, adaptest_data: AdapTestDataset):
+ lr = self.config['learning_rate']
+ batch_size = self.config['batch_size']
+ epochs = self.config['num_epochs']
+ device = self.config['device']
+ optimizer = torch.optim.Adam(self.model.theta.parameters(), lr=lr)
def evaluate(self, adaptest_data: AdapTestDataset):
data = adaptest_data.data
concept_map = adaptest_data.concept_map
diff --git a/CAT/strategy/NCAT_nn/NCAT.py b/CAT/strategy/NCAT_nn/NCAT.py
index 33e58f2..7fadf38 100644
--- a/CAT/strategy/NCAT_nn/NCAT.py
+++ b/CAT/strategy/NCAT_nn/NCAT.py
@@ -12,6 +12,7 @@
from scipy.optimize import minimize
from CAT.dataset import AdapTestDataset,Dataset
from CAT.model.IRT import IRT,IRTModel
+import os
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])
@@ -44,7 +45,7 @@ def __init__(self, n_question, d_model=10, n_blocks=1,kq_same=True, dropout=0.0,
n_heads: number of heads in multi-headed attention
d_ff : dimension for fully conntected net inside the basic block
"""
- self.device = torch.device('cuda')
+ self.device = torch.device('cpu')
self.pad = pad
self.n_question = n_question
self.dropout = dropout
@@ -134,7 +135,7 @@ def optimize_model(self, data, lr):
def create_model(cls, config):
model = cls(config.item_num, config.latent_factor, config.num_blocks,
True, config.dropout_rate, policy_fc_dim=512, n_heads=config.num_heads, d_ff=2048, l2=1e-5, separate_qa=None, pad=0)
- return model.to(torch.device('cuda'))
+ return model.to(torch.device('cpu'))
def mask(src, s_len):
@@ -309,12 +310,12 @@ def __init__(self,data,concept_map,config,T):
self.users = {}
self.utypes = {}
#self.args = args
- self.device = torch.device('cuda')
+ self.device = torch.device('cpu')
self.rates, self._item_num, self.know_map = self.load_data(data,concept_map)
self.tsdata=data
self.setup_train_test()
self.sup_rates, self.query_rates = self.split_data(ratio=0.5)
- pth_path='CDM.pt'
+ pth_path='../ckpt/irt.pt'
name = 'IRT'
self.model, self.dataset = self.load_CDM(name,data,pth_path,config)
#print(self.model)
@@ -368,8 +369,8 @@ def load_CDM(self,name,data,pth_path,config):
if name == 'IRT':
model = IRTModel(**config)
model.init_model(data)
- model.load_state_dict(torch.load(pth_path), strict=False)
- model.to(self.config['device'])
+ model.adaptest_load(pth_path)
+ #model.to(self.config['device'])
return model ,data.data
def step(self, action,sid):
@@ -389,10 +390,13 @@ def step(self, action,sid):
def reward(self, action,sid):
items = [state[0] for state in self.state[1]] + [action]
+
correct = [self.rates[self.state[0][0]][it] for it in items]
self.tsdata.apply_selection(sid, action)
- loss = self.model.adaptest_update(self.tsdata)
- acc,auc=self.model.evaluate(self.tsdata)
+ loss = self.model.adaptest_update(self.tsdata,sid)
+ result=self.model.evaluate(self.tsdata)
+ auc = result['auc']
+ acc = result['acc']
return -loss, acc, auc, correct[-1]
def precision(self, episode):
@@ -455,13 +459,15 @@ def __init__(self, NCATdata,concept_map,config,test_length):
self.config = config
self.model = None
self.env = env(data=NCATdata,concept_map=concept_map,config=config,T=test_length)
-
+ self.memory = []
self.item_num =self.env.item_num
self.user_num = self.env.user_num
self.device = config['device']
self.fa = NCAT(n_question=NCATdata.num_questions+1).to(self.device)
-
- def ncat_policy(self,sid,THRESHOLD,used_actions):
+ self.memory_size = 50000
+ self.tau = 0
+
+ def ncat_policy(self,sid,THRESHOLD,used_actions,type,epoch):
actions = {}
rwds = 0
done = False
@@ -475,7 +481,7 @@ def ncat_policy(self,sid,THRESHOLD,used_actions):
data["uid"] = torch.tensor(data["uid"], device=self.device)
policy = self.fa.predict(data)[0]
if type == "training":
- if np.random.random()<5*THRESHOLD/(THRESHOLD+self.tau): policy = np.random.uniform(0,1,(self.args.item_num,))
+ if np.random.random()<5*THRESHOLD/(THRESHOLD+self.tau): policy = np.random.uniform(0,1,(self.item_num,))
for item in actions: policy[item] = -np.inf
for item in range(self.item_num):
if item not in self.env.candidate_items:
@@ -490,6 +496,13 @@ def ncat_policy(self,sid,THRESHOLD,used_actions):
rwds += rwd
state = state_next
used_actions.extend(list(actions.keys()))
+ if type == "training":
+ if len(self.memory) >= self.config['batch_size']:
+ self.memory = self.memory[-self.memory_size:]
+ batch = [self.memory[item] for item in np.random.choice(range(len(self.memory)),(self.config['batch_size'],))]
+ data = self.convert_batch2dict(batch,epoch)
+ loss = self.fa.optimize_model(data, self.config['learning_rate'])
+ self.tau += 5
return
def convert_batch2dict(self, batch, epoch):
diff --git a/CAT/strategy/NCAT_strategy.py b/CAT/strategy/NCAT_strategy.py
index cbbbbae..e09b9d0 100644
--- a/CAT/strategy/NCAT_strategy.py
+++ b/CAT/strategy/NCAT_strategy.py
@@ -16,12 +16,19 @@ def __init__(self):
def name(self):
return 'NCAT'
- def adaptest_select(self, adaptest_data: AdapTestDataset,concept_map,config,test_length):
- used_actions = []
+ def adaptest_select(self, adaptest_data: AdapTestDataset,concept_map,config,test_length):
+ selection = {}
+ NCATdata = adaptest_data
+ model = NCATModel(NCATdata,concept_map,config,test_length)
+ threshold = config['THRESHOLD']
for sid in range(adaptest_data.num_students):
- NCATdata = adaptest_data
- model = NCATModel(NCATdata,concept_map,config,test_length)
- threshold = config['THRESHOLD']
- model.ncat_policy(sid,threshold,used_actions)
-
- return used_actions
+ print(str(sid+1)+'/'+str(adaptest_data.num_students))
+ used_actions = []
+ model.ncat_policy(sid,threshold,used_actions,type="training",epoch=100)
+ NCATdata.reset()
+ for sid in range(adaptest_data.num_students):
+ used_actions = []
+ model.ncat_policy(sid,threshold,used_actions,type="testing",epoch=0)
+ selection[sid] = used_actions
+ NCATdata.reset()
+ return selection
diff --git a/scripts/dataset/assistment.ipynb b/scripts/dataset/assistment.ipynb
index acfec2d..030683d 100644
--- a/scripts/dataset/assistment.ipynb
+++ b/scripts/dataset/assistment.ipynb
@@ -39,8 +39,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "/Users/yutingning/opt/anaconda3/envs/cat/lib/python3.7/site-packages/IPython/core/interactiveshell.py:3147: DtypeWarning: Columns (18) have mixed types.Specify dtype option on import or set low_memory=False.\n",
- " interactivity=interactivity, compiler=compiler, result=result)\n"
+ "c:\\Users\\DELL\\.conda\\envs\\Basenv\\lib\\site-packages\\IPython\\core\\interactiveshell.py:3553: DtypeWarning: Columns (17) have mixed types.Specify dtype option on import or set low_memory=False.\n",
+ " exec(code_obj, self.user_global_ns, self.user_ns)\n"
]
},
{
@@ -64,7 +64,6 @@
" \n",
" \n",
" \n",
- " Unnamed: 0 \n",
" order_id \n",
" assignment_id \n",
" user_id \n",
@@ -74,6 +73,7 @@
" correct \n",
" attempt_count \n",
" ms_first_response \n",
+ " tutor_mode \n",
" ... \n",
" hint_count \n",
" hint_total \n",
@@ -90,7 +90,6 @@
"
5 rows × 31 columns
\n", + "5 rows × 30 columns
\n", "" ], "text/plain": [ - " Unnamed: 0 order_id assignment_id user_id assistment_id problem_id \\\n", - "0 1 33022537 277618 64525 33139 51424 \n", - "1 2 33022709 277618 64525 33150 51435 \n", - "2 3 35450204 220674 70363 33159 51444 \n", - "3 4 35450295 220674 70363 33110 51395 \n", - "4 5 35450311 220674 70363 33196 51481 \n", + " order_id assignment_id user_id assistment_id problem_id original \\\n", + "0 33022537 277618 64525 33139 51424 1 \n", + "1 33022709 277618 64525 33150 51435 1 \n", + "2 35450204 220674 70363 33159 51444 1 \n", + "3 35450295 220674 70363 33110 51395 1 \n", + "4 35450311 220674 70363 33196 51481 1 \n", "\n", - " original correct attempt_count ms_first_response ... hint_count \\\n", - "0 1 1 1 32454 ... 0 \n", - "1 1 1 1 4922 ... 0 \n", - "2 1 0 2 25390 ... 0 \n", - "3 1 1 1 4859 ... 0 \n", - "4 1 0 14 19813 ... 3 \n", + " correct attempt_count ms_first_response tutor_mode ... hint_count \\\n", + "0 1 1 32454 tutor ... 0 \n", + "1 1 1 4922 tutor ... 0 \n", + "2 0 2 25390 tutor ... 0 \n", + "3 1 1 4859 tutor ... 0 \n", + "4 0 14 19813 tutor ... 3 \n", "\n", - " hint_total overlap_time template_id answer_id answer_text first_action \\\n", - "0 3 32454 30799 NaN 26 0 \n", - "1 3 4922 30799 NaN 55 0 \n", - "2 3 42000 30799 NaN 88 0 \n", - "3 3 4859 30059 NaN 41 0 \n", - "4 4 124564 30060 NaN 65 0 \n", + " hint_total overlap_time template_id answer_id answer_text first_action \\\n", + "0 3 32454 30799 NaN 26 0 \n", + "1 3 4922 30799 NaN 55 0 \n", + "2 3 42000 30799 NaN 88 0 \n", + "3 3 4859 30059 NaN 41 0 \n", + "4 4 124564 30060 NaN 65 0 \n", "\n", - " bottom_hint opportunity opportunity_original \n", - "0 NaN 1 1.0 \n", - "1 NaN 2 2.0 \n", - "2 NaN 1 1.0 \n", - "3 NaN 2 2.0 \n", - "4 0.0 3 3.0 \n", + " bottom_hint opportunity opportunity_original \n", + "0 NaN 1 1.0 \n", + "1 NaN 2 2.0 \n", + "2 NaN 1 1.0 \n", + "3 NaN 2 2.0 \n", + "4 0.0 3 3.0 \n", "\n", - "[5 rows x 31 columns]" + "[5 rows x 30 columns]" ] }, "execution_count": 3, @@ -251,7 +251,8 @@ } ], "source": [ - "raw_data = pd.read_csv('../../data/assistment.csv', encoding = 'utf-8', dtype={'skill_id': str})\n", + "data_path ='../../data/assistment/'\n", + "raw_data = pd.read_csv('../../data/assistment/assistment.csv', encoding = 'utf-8', dtype={'skill_id': str})\n", "raw_data.head()" ] }, @@ -278,11 +279,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "Total length: 274590\n", - "Number of unique [student_id,question_id]: 270478\n", + "Total length: 325636\n", + "Number of unique [student_id,question_id]: 270477\n", "Number of unique student_id: 4151\n", "Number of unique question_id: 16891\n", - "Number of unique knowledge_id: 138\n" + "Number of unique knowledge_id: 111\n" ] } ], @@ -291,7 +292,9 @@ "stat_unique(all_data, ['student_id', 'question_id'])\n", "stat_unique(all_data, 'student_id')\n", "stat_unique(all_data, 'question_id')\n", - "stat_unique(all_data, 'knowledge_id')" + "stat_unique(all_data, 'knowledge_id')\n", + "ques_num = len(all_data['question_id'].unique())\n", + "know_num = len(all_data['knowledge_id'].unique())" ] }, { @@ -319,7 +322,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "filter 15924 questions\n" + "filter 15370 questions\n" ] } ], @@ -340,7 +343,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "filter 1749 students\n" + "filter 1471 students\n" ] } ], @@ -382,7 +385,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "filter 10 knowledges\n" + "filter 8 knowledges\n" ] } ], @@ -475,12 +478,12 @@ "name": "stdout", "output_type": 
"stream", "text": [ - "Total length: 61860\n", - "Number of unique [student_id,question_id]: 59500\n", - "Number of unique student_id: 1505\n", - "Number of unique question_id: 932\n", - "Number of unique knowledge_id: 22\n", - "Average #questions per knowledge: 44.38095238095238\n" + "Total length: 110398\n", + "Number of unique [student_id,question_id]: 78747\n", + "Number of unique student_id: 1940\n", + "Number of unique question_id: 1485\n", + "Number of unique knowledge_id: 35\n", + "Average #questions per knowledge: 59.4\n" ] } ], @@ -500,7 +503,7 @@ "outputs": [], "source": [ "# save selected data\n", - "selected_data.to_csv('selected_data.csv', index=False)" + "selected_data.to_csv(data_path+'selected_data.csv', index=False)" ] }, { @@ -528,7 +531,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -555,7 +558,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -566,7 +569,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -575,7 +578,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -585,7 +588,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -611,7 +614,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -632,7 +635,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -643,16 +646,16 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "train records length: 51010\n", - "test records length: 10850\n", - "all records length: 61860\n" + "train records length: 60393\n", + "test records length: 50005\n", + "all records length: 110398\n" ] } ], @@ -671,7 +674,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -687,36 +690,36 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ - "save_to_csv(train_data, 'train_triples.csv')\n", - "save_to_csv(test_data, 'test_triples.csv')\n", - "save_to_csv(all_data, 'triples.csv')" + "save_to_csv(train_data, data_path+'train_triples.csv')\n", + "save_to_csv(test_data, data_path+'test_triples.csv')\n", + "save_to_csv(all_data, data_path+'triples.csv')" ] }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ - "metadata = {\"num_students\": 1505, \n", - " \"num_questions\": 932,\n", - " \"num_concepts\": 22, \n", - " \"num_records\": 61860, \n", - " \"num_train_students\": 1444, \n", - " \"num_test_students\": 61}" + "metadata = {\"num_students\": n_students, \n", + " \"num_questions\": ques_num,\n", + " \"num_concepts\": know_num, \n", + " \"num_records\": len(all_data), \n", + " \"num_train_students\": n_students - len(test_students), \n", + " \"num_test_students\": len(test_students)}" ] }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ - "with open('metadata.json', 'w') as f:\n", + "with open(data_path+'metadata.json', 'w') as f:\n", " 
json.dump(metadata, f)" ] } @@ -737,7 +740,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.9" + "version": "3.7.16" } }, "nbformat": 4, diff --git a/scripts/test.ipynb b/scripts/test.ipynb index 3583bb7..b8d419c 100644 --- a/scripts/test.ipynb +++ b/scripts/test.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -30,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -46,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -55,20 +55,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "