From f551bb1f98e3b8ff62a6213bfa44f94350999f8e Mon Sep 17 00:00:00 2001 From: yujunhao <772986150@qq.com> Date: Sat, 25 May 2024 17:29:08 +0800 Subject: [PATCH] Fix NCAT bugs --- CAT/dataset/adaptest_dataset.py | 35 +++++-- CAT/model/IRT.py | 15 ++- CAT/strategy/NCAT_nn/NCAT.py | 37 +++++--- CAT/strategy/NCAT_strategy.py | 23 +++-- scripts/dataset/assistment.ipynb | 153 ++++++++++++++++--------------- scripts/test.ipynb | 144 +++++++---------------------- scripts/train.ipynb | 70 +++++--------- 7 files changed, 212 insertions(+), 265 deletions(-) diff --git a/CAT/dataset/adaptest_dataset.py b/CAT/dataset/adaptest_dataset.py index 4768d36..19ddef9 100644 --- a/CAT/dataset/adaptest_dataset.py +++ b/CAT/dataset/adaptest_dataset.py @@ -60,7 +60,7 @@ def tested(self): def untested(self): return self._untested - def get_tested_dataset(self, last=False): + def get_tested_dataset(self, last=False,ssid=None): """ Get tested data for training Args: @@ -68,13 +68,28 @@ def get_tested_dataset(self, last=False): Returns: TrainDataset """ - triplets = [] - for sid, qids in self._tested.items(): - if last: - qid = qids[-1] - triplets.append((sid, qid, self.data[sid][qid])) - else: - for qid in qids: + if ssid==None: + triplets = [] + for sid, qids in self._tested.items(): + if last: + qid = qids[-1] + triplets.append((sid, qid, self.data[sid][qid])) - return TrainDataset(triplets, self.concept_map, - self.num_students, self.num_questions, self.num_concepts) \ No newline at end of file + else: + for qid in qids: + triplets.append((sid, qid, self.data[sid][qid])) + return TrainDataset(triplets, self.concept_map, + self.num_students, self.num_questions, self.num_concepts) + else: + triplets = [] + for sid, qids in self._tested.items(): + if ssid == sid: + if last: + qid = qids[-1] + + triplets.append((sid, qid, self.data[sid][qid])) + else: + for qid in qids: + triplets.append((sid, qid, self.data[sid][qid])) + return TrainDataset(triplets, self.concept_map, + self.num_students, self.num_questions, self.num_concepts) \ No newline at end of file diff --git a/CAT/model/IRT.py b/CAT/model/IRT.py index a480a10..5b828d0 100644 --- a/CAT/model/IRT.py +++ b/CAT/model/IRT.py @@ -13,7 +13,7 @@ from CAT.dataset import AdapTestDataset, TrainDataset, Dataset from sklearn.metrics import accuracy_score from collections import namedtuple -from utils import StraightThrough +from .utils import StraightThrough SavedAction = namedtuple('SavedAction', ['log_prob', 'value']) class IRT(nn.Module): @@ -99,7 +99,7 @@ def adaptest_load(self, path): self.model.load_state_dict(torch.load(path), strict=False) self.model.to(self.config['device']) - def adaptest_update(self, adaptest_data: AdapTestDataset): + def adaptest_update(self, adaptest_data: AdapTestDataset,sid=None): """ Update CDM with tested data """ @@ -109,9 +109,8 @@ def adaptest_update(self, adaptest_data: AdapTestDataset): device = self.config['device'] optimizer = torch.optim.Adam(self.model.theta.parameters(), lr=lr) - tested_dataset = adaptest_data.get_tested_dataset(last=True) + tested_dataset = adaptest_data.get_tested_dataset(last=True,ssid=sid) dataloader = torch.utils.data.DataLoader(tested_dataset, batch_size=batch_size, shuffle=True) - for ep in range(1, epochs + 1): loss = 0.0 log_steps = 100 @@ -127,7 +126,13 @@ def adaptest_update(self, adaptest_data: AdapTestDataset): loss += bz_loss.data.float() # if cnt % log_steps == 0: # print('Epoch [{}] Batch [{}]: loss={:.3f}'.format(ep, cnt, loss / cnt)) - + return loss + def one_student_update(self, adaptest_data: AdapTestDataset): + lr = self.config['learning_rate'] + batch_size = self.config['batch_size'] + epochs = self.config['num_epochs'] + device = self.config['device'] + optimizer = torch.optim.Adam(self.model.theta.parameters(), lr=lr) def evaluate(self, adaptest_data: AdapTestDataset): data = adaptest_data.data concept_map = adaptest_data.concept_map diff --git a/CAT/strategy/NCAT_nn/NCAT.py b/CAT/strategy/NCAT_nn/NCAT.py index 33e58f2..7fadf38 100644 --- a/CAT/strategy/NCAT_nn/NCAT.py +++ b/CAT/strategy/NCAT_nn/NCAT.py @@ -12,6 +12,7 @@ from scipy.optimize import minimize from CAT.dataset import AdapTestDataset,Dataset from CAT.model.IRT import IRT,IRTModel +import os SavedAction = namedtuple('SavedAction', ['log_prob', 'value']) @@ -44,7 +45,7 @@ def __init__(self, n_question, d_model=10, n_blocks=1,kq_same=True, dropout=0.0, n_heads: number of heads in multi-headed attention d_ff : dimension for fully conntected net inside the basic block """ - self.device = torch.device('cuda') + self.device = torch.device('cpu') self.pad = pad self.n_question = n_question self.dropout = dropout @@ -134,7 +135,7 @@ def optimize_model(self, data, lr): def create_model(cls, config): model = cls(config.item_num, config.latent_factor, config.num_blocks, True, config.dropout_rate, policy_fc_dim=512, n_heads=config.num_heads, d_ff=2048, l2=1e-5, separate_qa=None, pad=0) - return model.to(torch.device('cuda')) + return model.to(torch.device('cpu')) def mask(src, s_len): @@ -309,12 +310,12 @@ def __init__(self,data,concept_map,config,T): self.users = {} self.utypes = {} #self.args = args - self.device = torch.device('cuda') + self.device = torch.device('cpu') self.rates, self._item_num, self.know_map = self.load_data(data,concept_map) self.tsdata=data self.setup_train_test() self.sup_rates, self.query_rates = self.split_data(ratio=0.5) - pth_path='CDM.pt' + pth_path='../ckpt/irt.pt' name = 'IRT' self.model, self.dataset = self.load_CDM(name,data,pth_path,config) #print(self.model) @@ -368,8 +369,8 @@ def load_CDM(self,name,data,pth_path,config): if name == 'IRT': model = IRTModel(**config) model.init_model(data) - model.load_state_dict(torch.load(pth_path), strict=False) - model.to(self.config['device']) + model.adaptest_load(pth_path) + #model.to(self.config['device']) return model ,data.data def step(self, action,sid): @@ -389,10 +390,13 @@ def step(self, action,sid): def reward(self, action,sid): items = [state[0] for state in self.state[1]] + [action] + correct = [self.rates[self.state[0][0]][it] for it in items] self.tsdata.apply_selection(sid, action) - loss = self.model.adaptest_update(self.tsdata) - acc,auc=self.model.evaluate(self.tsdata) + loss = self.model.adaptest_update(self.tsdata,sid) + result=self.model.evaluate(self.tsdata) + auc = result['auc'] + acc = result['acc'] return -loss, acc, auc, correct[-1] def precision(self, episode): @@ -455,13 +459,15 @@ def __init__(self, NCATdata,concept_map,config,test_length): self.config = config self.model = None self.env = env(data=NCATdata,concept_map=concept_map,config=config,T=test_length) - + self.memory = [] self.item_num =self.env.item_num self.user_num = self.env.user_num self.device = config['device'] self.fa = NCAT(n_question=NCATdata.num_questions+1).to(self.device) - - def ncat_policy(self,sid,THRESHOLD,used_actions): + self.memory_size = 50000 + self.tau = 0 + + def ncat_policy(self,sid,THRESHOLD,used_actions,type,epoch): actions = {} rwds = 0 done = False @@ -475,7 +481,7 @@ def ncat_policy(self,sid,THRESHOLD,used_actions): data["uid"] = torch.tensor(data["uid"], device=self.device) policy = self.fa.predict(data)[0] if type == "training": - if np.random.random()<5*THRESHOLD/(THRESHOLD+self.tau): policy = np.random.uniform(0,1,(self.args.item_num,)) + if np.random.random()<5*THRESHOLD/(THRESHOLD+self.tau): policy = np.random.uniform(0,1,(self.item_num,)) for item in actions: policy[item] = -np.inf for item in range(self.item_num): if item not in self.env.candidate_items: @@ -490,6 +496,13 @@ def ncat_policy(self,sid,THRESHOLD,used_actions): rwds += rwd state = state_next used_actions.extend(list(actions.keys())) + if type == "training": + if len(self.memory) >= self.config['batch_size']: + self.memory = self.memory[-self.memory_size:] + batch = [self.memory[item] for item in np.random.choice(range(len(self.memory)),(self.args.batch,))] + data = self.convert_batch2dict(batch,epoch) + loss = self.fa.optimize_model(data, self.args.learning_rate) + self.tau += 5 return def convert_batch2dict(self, batch, epoch): diff --git a/CAT/strategy/NCAT_strategy.py b/CAT/strategy/NCAT_strategy.py index cbbbbae..e09b9d0 100644 --- a/CAT/strategy/NCAT_strategy.py +++ b/CAT/strategy/NCAT_strategy.py @@ -16,12 +16,19 @@ def __init__(self): def name(self): return 'NCAT' - def adaptest_select(self, adaptest_data: AdapTestDataset,concept_map,config,test_length): - used_actions = [] + def adaptest_select(self, adaptest_data: AdapTestDataset,concept_map,config,test_length): + selection = {} + NCATdata = adaptest_data + model = NCATModel(NCATdata,concept_map,config,test_length) + threshold = config['THRESHOLD'] for sid in range(adaptest_data.num_students): - NCATdata = adaptest_data - model = NCATModel(NCATdata,concept_map,config,test_length) - threshold = config['THRESHOLD'] - model.ncat_policy(sid,threshold,used_actions) - - return used_actions + print(str(sid+1)+'/'+str(adaptest_data.num_students)) + used_actions = [] + model.ncat_policy(sid,threshold,used_actions,type="training",epoch=100) + NCATdata.reset() + for sid in range(adaptest_data.num_students): + used_actions = [] + model.ncat_policy(sid,threshold,used_actions,type="testing",epoch=0) + selection[sid] = used_actions + NCATdata.reset() + return selection diff --git a/scripts/dataset/assistment.ipynb b/scripts/dataset/assistment.ipynb index acfec2d..030683d 100644 --- a/scripts/dataset/assistment.ipynb +++ b/scripts/dataset/assistment.ipynb @@ -39,8 +39,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/yutingning/opt/anaconda3/envs/cat/lib/python3.7/site-packages/IPython/core/interactiveshell.py:3147: DtypeWarning: Columns (18) have mixed types.Specify dtype option on import or set low_memory=False.\n", - " interactivity=interactivity, compiler=compiler, result=result)\n" + "c:\\Users\\DELL\\.conda\\envs\\Basenv\\lib\\site-packages\\IPython\\core\\interactiveshell.py:3553: DtypeWarning: Columns (17) have mixed types.Specify dtype option on import or set low_memory=False.\n", + " exec(code_obj, self.user_global_ns, self.user_ns)\n" ] }, { @@ -64,7 +64,6 @@ " \n", " \n", " \n", - " Unnamed: 0\n", " order_id\n", " assignment_id\n", " user_id\n", @@ -74,6 +73,7 @@ " correct\n", " attempt_count\n", " ms_first_response\n", + " tutor_mode\n", " ...\n", " hint_count\n", " hint_total\n", @@ -90,7 +90,6 @@ " \n", " \n", " 0\n", - " 1\n", " 33022537\n", " 277618\n", " 64525\n", @@ -100,6 +99,7 @@ " 1\n", " 1\n", " 32454\n", + " tutor\n", " ...\n", " 0\n", " 3\n", @@ -114,7 +114,6 @@ " \n", " \n", " 1\n", - " 2\n", " 33022709\n", " 277618\n", " 64525\n", @@ -124,6 +123,7 @@ " 1\n", " 1\n", " 4922\n", + " tutor\n", " ...\n", " 0\n", " 3\n", @@ -138,7 +138,6 @@ " \n", " \n", " 2\n", - " 3\n", " 35450204\n", " 220674\n", " 70363\n", @@ -148,6 +147,7 @@ " 0\n", " 2\n", " 25390\n", + " tutor\n", " ...\n", " 0\n", " 3\n", @@ -162,7 +162,6 @@ " \n", " \n", " 3\n", - " 4\n", " 35450295\n", " 220674\n", " 70363\n", @@ -172,6 +171,7 @@ " 1\n", " 1\n", " 4859\n", + " tutor\n", " ...\n", " 0\n", " 3\n", @@ -186,7 +186,6 @@ " \n", " \n", " 4\n", - " 5\n", " 35450311\n", " 220674\n", " 70363\n", @@ -196,6 +195,7 @@ " 0\n", " 14\n", " 19813\n", + " tutor\n", " ...\n", " 3\n", " 4\n", @@ -210,39 +210,39 @@ " \n", " \n", "\n", - "

5 rows × 31 columns

\n", + "

5 rows × 30 columns

\n", "" ], "text/plain": [ - " Unnamed: 0 order_id assignment_id user_id assistment_id problem_id \\\n", - "0 1 33022537 277618 64525 33139 51424 \n", - "1 2 33022709 277618 64525 33150 51435 \n", - "2 3 35450204 220674 70363 33159 51444 \n", - "3 4 35450295 220674 70363 33110 51395 \n", - "4 5 35450311 220674 70363 33196 51481 \n", + " order_id assignment_id user_id assistment_id problem_id original \\\n", + "0 33022537 277618 64525 33139 51424 1 \n", + "1 33022709 277618 64525 33150 51435 1 \n", + "2 35450204 220674 70363 33159 51444 1 \n", + "3 35450295 220674 70363 33110 51395 1 \n", + "4 35450311 220674 70363 33196 51481 1 \n", "\n", - " original correct attempt_count ms_first_response ... hint_count \\\n", - "0 1 1 1 32454 ... 0 \n", - "1 1 1 1 4922 ... 0 \n", - "2 1 0 2 25390 ... 0 \n", - "3 1 1 1 4859 ... 0 \n", - "4 1 0 14 19813 ... 3 \n", + " correct attempt_count ms_first_response tutor_mode ... hint_count \\\n", + "0 1 1 32454 tutor ... 0 \n", + "1 1 1 4922 tutor ... 0 \n", + "2 0 2 25390 tutor ... 0 \n", + "3 1 1 4859 tutor ... 0 \n", + "4 0 14 19813 tutor ... 3 \n", "\n", - " hint_total overlap_time template_id answer_id answer_text first_action \\\n", - "0 3 32454 30799 NaN 26 0 \n", - "1 3 4922 30799 NaN 55 0 \n", - "2 3 42000 30799 NaN 88 0 \n", - "3 3 4859 30059 NaN 41 0 \n", - "4 4 124564 30060 NaN 65 0 \n", + " hint_total overlap_time template_id answer_id answer_text first_action \\\n", + "0 3 32454 30799 NaN 26 0 \n", + "1 3 4922 30799 NaN 55 0 \n", + "2 3 42000 30799 NaN 88 0 \n", + "3 3 4859 30059 NaN 41 0 \n", + "4 4 124564 30060 NaN 65 0 \n", "\n", - " bottom_hint opportunity opportunity_original \n", - "0 NaN 1 1.0 \n", - "1 NaN 2 2.0 \n", - "2 NaN 1 1.0 \n", - "3 NaN 2 2.0 \n", - "4 0.0 3 3.0 \n", + " bottom_hint opportunity opportunity_original \n", + "0 NaN 1 1.0 \n", + "1 NaN 2 2.0 \n", + "2 NaN 1 1.0 \n", + "3 NaN 2 2.0 \n", + "4 0.0 3 3.0 \n", "\n", - "[5 rows x 31 columns]" + "[5 rows x 30 columns]" ] }, "execution_count": 3, @@ -251,7 +251,8 @@ } ], "source": [ - "raw_data = pd.read_csv('../../data/assistment.csv', encoding = 'utf-8', dtype={'skill_id': str})\n", + "data_path ='../../data/assistment/'\n", + "raw_data = pd.read_csv('../../data/assistment/assistment.csv', encoding = 'utf-8', dtype={'skill_id': str})\n", "raw_data.head()" ] }, @@ -278,11 +279,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "Total length: 274590\n", - "Number of unique [student_id,question_id]: 270478\n", + "Total length: 325636\n", + "Number of unique [student_id,question_id]: 270477\n", "Number of unique student_id: 4151\n", "Number of unique question_id: 16891\n", - "Number of unique knowledge_id: 138\n" + "Number of unique knowledge_id: 111\n" ] } ], @@ -291,7 +292,9 @@ "stat_unique(all_data, ['student_id', 'question_id'])\n", "stat_unique(all_data, 'student_id')\n", "stat_unique(all_data, 'question_id')\n", - "stat_unique(all_data, 'knowledge_id')" + "stat_unique(all_data, 'knowledge_id')\n", + "ques_num = len(all_data['question_id'].unique())\n", + "know_num = len(all_data['knowledge_id'].unique())" ] }, { @@ -319,7 +322,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "filter 15924 questions\n" + "filter 15370 questions\n" ] } ], @@ -340,7 +343,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "filter 1749 students\n" + "filter 1471 students\n" ] } ], @@ -382,7 +385,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "filter 10 knowledges\n" + "filter 8 knowledges\n" ] } ], @@ -475,12 +478,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "Total length: 61860\n", - "Number of unique [student_id,question_id]: 59500\n", - "Number of unique student_id: 1505\n", - "Number of unique question_id: 932\n", - "Number of unique knowledge_id: 22\n", - "Average #questions per knowledge: 44.38095238095238\n" + "Total length: 110398\n", + "Number of unique [student_id,question_id]: 78747\n", + "Number of unique student_id: 1940\n", + "Number of unique question_id: 1485\n", + "Number of unique knowledge_id: 35\n", + "Average #questions per knowledge: 59.4\n" ] } ], @@ -500,7 +503,7 @@ "outputs": [], "source": [ "# save selected data\n", - "selected_data.to_csv('selected_data.csv', index=False)" + "selected_data.to_csv(data_path+'selected_data.csv', index=False)" ] }, { @@ -528,7 +531,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -555,7 +558,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -566,7 +569,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -575,7 +578,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -585,7 +588,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -611,7 +614,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -632,7 +635,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -643,16 +646,16 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "train records length: 51010\n", - "test records length: 10850\n", - "all records length: 61860\n" + "train records length: 60393\n", + "test records length: 50005\n", + "all records length: 110398\n" ] } ], @@ -671,7 +674,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -687,36 +690,36 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ - "save_to_csv(train_data, 'train_triples.csv')\n", - "save_to_csv(test_data, 'test_triples.csv')\n", - "save_to_csv(all_data, 'triples.csv')" + "save_to_csv(train_data, data_path+'train_triples.csv')\n", + "save_to_csv(test_data, data_path+'test_triples.csv')\n", + "save_to_csv(all_data, data_path+'triples.csv')" ] }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ - "metadata = {\"num_students\": 1505, \n", - " \"num_questions\": 932,\n", - " \"num_concepts\": 22, \n", - " \"num_records\": 61860, \n", - " \"num_train_students\": 1444, \n", - " \"num_test_students\": 61}" + "metadata = {\"num_students\": n_students, \n", + " \"num_questions\": ques_num,\n", + " \"num_concepts\": know_num, \n", + " \"num_records\": len(all_data), \n", + " \"num_train_students\": n_students - len(test_students), \n", + " \"num_test_students\": len(test_students)}" ] }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ - "with open('metadata.json', 'w') as f:\n", + "with open(data_path+'metadata.json', 'w') as f:\n", " json.dump(metadata, f)" ] } @@ -737,7 +740,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.9" + "version": "3.7.16" } }, "nbformat": 4, diff --git a/scripts/test.ipynb b/scripts/test.ipynb index 3583bb7..b8d419c 100644 --- a/scripts/test.ipynb +++ b/scripts/test.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -30,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -46,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -55,20 +55,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "seed = 0\n", "np.random.seed(seed)\n", @@ -77,31 +66,27 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "../logs/2021-03-09-14:51/\n" - ] - } - ], + "outputs": [], "source": [ "# tensorboard\n", "log_dir = f\"../logs/{datetime.datetime.now().strftime('%Y-%m-%d-%H:%M')}/\"\n", + "log_dir = f\"../logs/\"\n", "print(log_dir)\n", "writer = SummaryWriter(log_dir)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# choose dataset here\n", + "import CAT.strategy\n", + "\n", + "\n", "dataset = 'assistment'\n", "# modify config here\n", "config = {\n", @@ -114,17 +99,20 @@ " 'prednet_len1': 128,\n", " 'prednet_len2': 64,\n", " # for BOBCAT\n", + " 'policy':'notbobcat',\n", " 'betas': (0.9, 0.999),\n", " 'policy_path': 'policy.pt',\n", " # for NCAT\n", " 'THRESHOLD' :300,\n", " 'start':0,\n", " 'end':3000\n", + " \n", "}\n", "# fixed test length\n", "test_length = 5\n", "# choose strategies here\n", - "strategies = [CAT.strategy.RandomStrategy(), CAT.strategy.MFIStrategy(), CAT.strategy.KLIStrategy()]\n", + "#strategies = [CAT.strategy.RandomStrategy(), CAT.strategy.MFIStrategy(), CAT.strategy.KLIStrategy()]\n", + "strategies = [CAT.strategy.NCATs()]\n", "# modify checkpoint path here\n", "ckpt_path = '../ckpt/irt.pt'\n", "bobcat_policy_path =config['policy_path']" @@ -132,7 +120,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -145,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -157,7 +145,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -167,87 +155,29 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[INFO 2021-03-09 14:51:04,289] -----------\n", - "[INFO 2021-03-09 14:51:04,290] start adaptive testing with Random Select Strategy strategy\n", - "[INFO 2021-03-09 14:51:04,291] Iteration 0\n", - "[INFO 2021-03-09 14:51:04,308] auc:0.6484533447389293\n", - "[INFO 2021-03-09 14:51:04,309] cov:0.0\n", - "[INFO 2021-03-09 14:51:04,309] Iteration 1\n", - "[INFO 2021-03-09 14:51:04,344] auc:0.6489562662794149\n", - "[INFO 2021-03-09 14:51:04,347] cov:0.05801618621590955\n", - "[INFO 2021-03-09 14:51:04,349] Iteration 2\n", - "[INFO 2021-03-09 14:51:04,382] auc:0.6487346765890865\n", - "[INFO 2021-03-09 14:51:04,383] cov:0.11609196598657111\n", - "[INFO 2021-03-09 14:51:04,384] Iteration 3\n", - "[INFO 2021-03-09 14:51:04,413] auc:0.6500624347642152\n", - "[INFO 2021-03-09 14:51:04,413] cov:0.1612808712341023\n", - "[INFO 2021-03-09 14:51:04,414] Iteration 4\n", - "[INFO 2021-03-09 14:51:04,443] auc:0.6512111930010926\n", - "[INFO 2021-03-09 14:51:04,443] cov:0.20574638420300764\n", - "[INFO 2021-03-09 14:51:04,444] Iteration 5\n", - "[INFO 2021-03-09 14:51:04,473] auc:0.6514404203673256\n", - "[INFO 2021-03-09 14:51:04,474] cov:0.2428003818854224\n", - "[INFO 2021-03-09 14:51:04,478] -----------\n", - "[INFO 2021-03-09 14:51:04,478] start adaptive testing with Maximum Fisher Information Strategy strategy\n", - "[INFO 2021-03-09 14:51:04,479] Iteration 0\n", - "[INFO 2021-03-09 14:51:04,493] auc:0.6459189955860706\n", - "[INFO 2021-03-09 14:51:04,494] cov:0.0\n", - "[INFO 2021-03-09 14:51:04,495] Iteration 1\n", - "[INFO 2021-03-09 14:51:06,005] auc:0.647302288726674\n", - "[INFO 2021-03-09 14:51:06,007] cov:0.0503951833541452\n", - "[INFO 2021-03-09 14:51:06,009] Iteration 2\n", - "[INFO 2021-03-09 14:51:07,396] auc:0.6485068408332938\n", - "[INFO 2021-03-09 14:51:07,397] cov:0.1017306607056953\n", - "[INFO 2021-03-09 14:51:07,398] Iteration 3\n", - "[INFO 2021-03-09 14:51:08,729] auc:0.6499061213124426\n", - "[INFO 2021-03-09 14:51:08,730] cov:0.14101164847492498\n", - "[INFO 2021-03-09 14:51:08,731] Iteration 4\n", - "[INFO 2021-03-09 14:51:10,169] auc:0.6515281889141593\n", - "[INFO 2021-03-09 14:51:10,169] cov:0.17938349590744032\n", - "[INFO 2021-03-09 14:51:10,170] Iteration 5\n", - "[INFO 2021-03-09 14:51:11,581] auc:0.6532324909839825\n", - "[INFO 2021-03-09 14:51:11,581] cov:0.2149702203321859\n", - "[INFO 2021-03-09 14:51:11,586] -----------\n", - "[INFO 2021-03-09 14:51:11,587] start adaptive testing with Kullback-Leibler Information Strategy strategy\n", - "[INFO 2021-03-09 14:51:11,587] Iteration 0\n", - "[INFO 2021-03-09 14:51:11,600] auc:0.6468982686165439\n", - "[INFO 2021-03-09 14:51:11,601] cov:0.0\n", - "[INFO 2021-03-09 14:51:11,602] Iteration 1\n", - "[INFO 2021-03-09 14:51:11,637] auc:0.6485143490570642\n", - "[INFO 2021-03-09 14:51:11,639] cov:0.0503951833541452\n", - "[INFO 2021-03-09 14:51:11,642] Iteration 2\n", - "[INFO 2021-03-09 14:51:18,253] auc:0.6500137283988079\n", - "[INFO 2021-03-09 14:51:18,254] cov:0.09764420342774864\n", - "[INFO 2021-03-09 14:51:18,255] Iteration 3\n", - "[INFO 2021-03-09 14:51:23,240] auc:0.6511878916169775\n", - "[INFO 2021-03-09 14:51:23,241] cov:0.1447750201916595\n", - "[INFO 2021-03-09 14:51:23,242] Iteration 4\n", - "[INFO 2021-03-09 14:51:27,671] auc:0.6521859675699103\n", - "[INFO 2021-03-09 14:51:27,672] cov:0.1899559742455468\n", - "[INFO 2021-03-09 14:51:27,673] Iteration 5\n", - "[INFO 2021-03-09 14:51:32,046] auc:0.6539360633321265\n", - "[INFO 2021-03-09 14:51:32,046] cov:0.21959693579152112\n" - ] - } - ], + "outputs": [], "source": [ "for strategy in strategies:\n", " avg =[]\n", " model = CAT.model.IRTModel(**config)\n", - " model = CAT.model.NCDModel(**config)\n", + " #model = CAT.model.NCDModel(**config)\n", " model.init_model(test_data)\n", - " model.adaptest_load(ckpt_path,bobcat_policy_path)\n", + " model.adaptest_load(ckpt_path)\n", " test_data.reset()\n", " print(strategy.name)\n", " if strategy.name == 'NCAT':\n", " selected_questions = strategy.adaptest_select(test_data,concept_map,config,test_length)\n", + " for it in range(test_length):\n", + " for student, questions in selected_questions.items():\n", + " test_data.apply_selection(student, questions[it]) \n", + " model.adaptest_update(test_data)\n", + " results = model.evaluate(test_data)\n", + " # log results\n", + " logging.info(f'Iteration {it}')\n", + " for name, value in results.items():\n", + " logging.info(f'{name}:{value}')\n", " continue\n", " if strategy.name == 'BOBCAT':\n", " real = {}\n", @@ -281,14 +211,12 @@ " tmp[selected_questions[sid]] = real[sid][selected_questions[sid]]\n", " S_sel[sid].append(tmp)\n", " elif it == 1 and strategy.name == 'BECAT Strategy':\n", - " if it == 1 and strategy.name == 'BECAT Strategy':\n", " for sid in range(test_data.num_students):\n", " untested_questions = np.array(list(test_data.untested[sid]))\n", " random_index = random.randint(0, len(untested_questions)-1)\n", " selected_questions[sid] = untested_questions[random_index]\n", " S_sel[sid].append(untested_questions[random_index])\n", - " elif strategy.name == 'BECAT Strategy': \n", - " elif strategy.name == 'BECAT Strategy': \n", + " elif strategy.name == 'BECAT Strategy': \n", " selected_questions = strategy.adaptest_select(model, test_data,S_sel)\n", " for sid in range(test_data.num_students):\n", " S_sel[sid].append(selected_questions[sid])\n", @@ -296,8 +224,6 @@ " selected_questions = strategy.adaptest_select(model, test_data)\n", " for student, question in selected_questions.items():\n", " test_data.apply_selection(student, question) \n", - " for student, question in selected_questions.items():\n", - " test_data.apply_selection(student, question)\n", " \n", " # update models\n", " model.adaptest_update(test_data)\n", @@ -333,7 +259,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.9" + "version": "3.7.16" } }, "nbformat": 4, diff --git a/scripts/train.ipynb b/scripts/train.ipynb index c3a27cb..3c70f73 100644 --- a/scripts/train.ipynb +++ b/scripts/train.ipynb @@ -2,12 +2,13 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "import CAT\n", "import sys\n", + "sys.path.append('..')\n", + "import CAT\n", "import json\n", "import logging\n", "import numpy as np\n", @@ -16,7 +17,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -32,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -41,7 +42,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -52,17 +53,18 @@ " 'learning_rate': 0.002,\n", " 'batch_size': 2048,\n", " 'num_epochs': 10,\n", - " 'num_dim': 10, # for IRT or MIRT\n", + " 'num_dim': 1, # for IRT or MIRT\n", " 'device': 'cpu',\n", " # for NeuralCD\n", " 'prednet_len1': 128,\n", " 'prednet_len2': 64,\n", + " 'betas': (0.9, 0.999),\n", "}" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -75,7 +77,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -87,46 +89,22 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": { "scrolled": true }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "[INFO 2021-03-09 14:46:27,480] train on cpu\n", - "[INFO 2021-03-09 14:46:27,579] Epoch [1] Batch [0]: loss=inf\n", - "[INFO 2021-03-09 14:46:27,920] Epoch [1] Batch [10]: loss=0.76105\n", - "[INFO 2021-03-09 14:46:28,267] Epoch [1] Batch [20]: loss=0.72569\n", - "[INFO 2021-03-09 14:46:28,482] Epoch [2] Batch [0]: loss=inf\n", - "[INFO 2021-03-09 14:46:28,826] Epoch [2] Batch [10]: loss=0.75531\n", - "[INFO 2021-03-09 14:46:29,177] Epoch [2] Batch [20]: loss=0.72002\n", - "[INFO 2021-03-09 14:46:29,391] Epoch [3] Batch [0]: loss=inf\n", - "[INFO 2021-03-09 14:46:29,732] Epoch [3] Batch [10]: loss=0.74957\n", - "[INFO 2021-03-09 14:46:30,080] Epoch [3] Batch [20]: loss=0.71441\n", - "[INFO 2021-03-09 14:46:30,242] Epoch [4] Batch [0]: loss=inf\n", - "[INFO 2021-03-09 14:46:30,638] Epoch [4] Batch [10]: loss=0.74343\n", - "[INFO 2021-03-09 14:46:30,969] Epoch [4] Batch [20]: loss=0.70825\n", - "[INFO 2021-03-09 14:46:31,129] Epoch [5] Batch [0]: loss=inf\n", - "[INFO 2021-03-09 14:46:31,505] Epoch [5] Batch [10]: loss=0.73560\n", - "[INFO 2021-03-09 14:46:31,855] Epoch [5] Batch [20]: loss=0.70078\n", - "[INFO 2021-03-09 14:46:32,031] Epoch [6] Batch [0]: loss=inf\n", - "[INFO 2021-03-09 14:46:32,412] Epoch [6] Batch [10]: loss=0.72559\n", - "[INFO 2021-03-09 14:46:32,767] Epoch [6] Batch [20]: loss=0.69064\n", - "[INFO 2021-03-09 14:46:32,938] Epoch [7] Batch [0]: loss=inf\n", - "[INFO 2021-03-09 14:46:33,297] Epoch [7] Batch [10]: loss=0.71256\n", - "[INFO 2021-03-09 14:46:33,618] Epoch [7] Batch [20]: loss=0.67773\n", - "[INFO 2021-03-09 14:46:33,777] Epoch [8] Batch [0]: loss=inf\n", - "[INFO 2021-03-09 14:46:34,124] Epoch [8] Batch [10]: loss=0.69541\n", - "[INFO 2021-03-09 14:46:34,517] Epoch [8] Batch [20]: loss=0.66073\n", - "[INFO 2021-03-09 14:46:34,684] Epoch [9] Batch [0]: loss=inf\n", - "[INFO 2021-03-09 14:46:35,024] Epoch [9] Batch [10]: loss=0.67549\n", - "[INFO 2021-03-09 14:46:35,412] Epoch [9] Batch [20]: loss=0.64151\n", - "[INFO 2021-03-09 14:46:35,569] Epoch [10] Batch [0]: loss=inf\n", - "[INFO 2021-03-09 14:46:35,911] Epoch [10] Batch [10]: loss=0.65485\n", - "[INFO 2021-03-09 14:46:36,296] Epoch [10] Batch [20]: loss=0.62124\n" + "ename": "KeyError", + "evalue": "'betas'", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mKeyError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_27436\\1368016865.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mmodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mCAT\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mIRTModel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;31m# train model\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minit_model\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrain_data\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 5\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrain_data\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlog_step\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\workplace\\tmp\\EduCAT\\CAT\\model\\IRT.py\u001b[0m in \u001b[0;36minit_model\u001b[1;34m(self, data)\u001b[0m\n\u001b[0;32m 54\u001b[0m \u001b[0mpolicy_lr\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0.0005\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 55\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mIRT\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnum_students\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnum_questions\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'num_dim'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 56\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpolicy\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mStraightThrough\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnum_questions\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnum_questions\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mpolicy_lr\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 57\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mn_q\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnum_questions\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 58\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32md:\\workplace\\tmp\\EduCAT\\CAT\\model\\utils.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, state_dim, action_dim, lr, config)\u001b[0m\n\u001b[0;32m 34\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlr\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 35\u001b[0m \u001b[0mdevice\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mconfig\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'device'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 36\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbetas\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mconfig\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'betas'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 37\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpolicy\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mActor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate_dim\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0maction_dim\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 38\u001b[0m self.optimizer = torch.optim.Adam(\n", + "\u001b[1;31mKeyError\u001b[0m: 'betas'" ] } ], @@ -140,12 +118,12 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# save model\n", - "model.adaptest_save('../ckpt/mirt.pt')" + "model.adaptest_save('../ckpt/irt.pt')" ] } ], @@ -165,7 +143,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.9" + "version": "3.7.16" } }, "nbformat": 4,