-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
2024.1.12 Add new CAT Strategy #10
Conversation
CAT/strategy/NCAT_strategy.py
Outdated
def forward(self, x): | ||
return self.w_2(self.dropout(F.relu(self.w_1(x)))) | ||
|
||
class IRT(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove IRT from this file. Make sure to implement all CDMs in /model and only implement the question selection strategy in /strategy
CAT/strategy/NCAT_strategy.py
Outdated
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
SavedAction = namedtuple('SavedAction', ['log_prob', 'value']) | ||
|
||
THRESHOLD = 300 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
better to use args to set these parameters instead of global variables
CAT/strategy/NCAT_strategy.py
Outdated
from sklearn.metrics import accuracy_score | ||
from collections import namedtuple | ||
import torch.nn as nn | ||
import torch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clean your code, remove duplicates
CAT/model/IRT.py
Outdated
@@ -94,7 +143,13 @@ def adaptest_load(self, path): | |||
""" | |||
self.model.load_state_dict(torch.load(path), strict=False) | |||
self.model.to(self.config['device']) | |||
|
|||
def adaptest_load_BOBCAT(self, path, policy): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add args in adaptest_load
to control whether to load policy instead of a new function
CAT/model/IRT.py
Outdated
@@ -52,8 +98,11 @@ def name(self): | |||
return 'Item Response Theory' | |||
|
|||
def init_model(self, data: Dataset): | |||
betas = (0.9, 0.999) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add args to set these hyper parameters instead of directly setting in the function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, if it is not necessary for all IRT models, use args to control
CAT/model/IRT.py
Outdated
from math import exp as exp | ||
from sklearn.metrics import roc_auc_score | ||
from scipy import integrate | ||
import time | ||
from CAT.model.abstract_model import AbstractModel | ||
from CAT.dataset import AdapTestDataset, TrainDataset, Dataset | ||
from sklearn.metrics import accuracy_score | ||
from collections import namedtuple | ||
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do not set device here. We will set device through config
CAT/strategy/NCAT_strategy.py
Outdated
loss += bz_loss.data.float() | ||
ls = ls + loss | ||
return ls/(epochs) | ||
def _loss_function(self, pred, real): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please make sure your code is well-formatted, such as keeping a blank line between these functions
CAT/strategy/NCAT_strategy.py
Outdated
def load_data(self, ncatdata,concept): | ||
|
||
return ncatdata.data, ncatdata.num_questions, concept | ||
def load_CDM(self,name,data,pth_path,config): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's better to use some auto tools to check the format of your code. blank line between functions, space after the comma, etc.
CAT/strategy/NCAT_strategy.py
Outdated
return model ,data.data | ||
def step(self, action,sid): | ||
assert action in self.sup_rates[self.state[0][0]] and action not in self.short | ||
reward, ACC, AUC, rate = self.reward(action,sid) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would recommend to use snake case for variables
" test_data.reset()\n", | ||
" \n", | ||
" print(strategy.name)\n", | ||
" if strategy.name == 'NCAT':\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it is hard to unify all the strategies, I would recommend to write train/test examples for them in seperate scripts, such as train_NCAT.ipynb
, test_NCAT.ipynb
, test_BOBCAT.ipynb
.
CAT/model/NCAT.py
Outdated
def forward(self, x): | ||
return self.w_2(self.dropout(F.relu(self.w_1(x)))) | ||
|
||
class CDM(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the difference between this CDM and IRTModel in IRT.py
? If it's the same, why not import IRTModel as the CDM?
CAT/strategy/NCAT_strategy.py
Outdated
for sid in range(adaptest_data.num_students): | ||
NCATdata = adaptest_data | ||
model = NCATModel(NCATdata,concept_map,config,test_length) | ||
THRESHOLD = config['THRESHOLD'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why use uppercase letters here? use snake case for variables
CAT/model/NCAT.py
Outdated
G.append(sum(map(lambda x: math.pow(alpha, x - 1), dict(Counter(items)).values())) / math.log(i + 2, 2)) | ||
return sum(G) | ||
|
||
class NCATModel(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is a model for question selection strategy and not a CD model right? Correct me if I am wrong - I have not read the NCAT paper yet. /model
should contain code for CDMs and /strategy
contains code for question selection strategy. If NCATModel is not a CDM, it should not be placed in this folder. Modularize your code
Add new features, add code for BOBCAT and NCAT.