Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2024.1.12 Add new CAT Strategy #10

Merged
merged 12 commits into from
Jan 15, 2024
Merged

2024.1.12 Add new CAT Strategy #10

merged 12 commits into from
Jan 15, 2024

Conversation

Hhhhhhand
Copy link
Collaborator

Add new features, add code for BOBCAT and NCAT.

def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))

class IRT(nn.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove IRT from this file. Make sure to implement all CDMs in /model and only implement the question selection strategy in /strategy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])

THRESHOLD = 300
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better to use args to set these parameters instead of global variables

from sklearn.metrics import accuracy_score
from collections import namedtuple
import torch.nn as nn
import torch
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clean your code, remove duplicates

CAT/model/IRT.py Outdated
@@ -94,7 +143,13 @@ def adaptest_load(self, path):
"""
self.model.load_state_dict(torch.load(path), strict=False)
self.model.to(self.config['device'])

def adaptest_load_BOBCAT(self, path, policy):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add args in adaptest_load to control whether to load policy instead of a new function

CAT/model/IRT.py Outdated
@@ -52,8 +98,11 @@ def name(self):
return 'Item Response Theory'

def init_model(self, data: Dataset):
betas = (0.9, 0.999)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add args to set these hyper parameters instead of directly setting in the function

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, if it is not necessary for all IRT models, use args to control

CAT/model/IRT.py Outdated
from math import exp as exp
from sklearn.metrics import roc_auc_score
from scipy import integrate
import time
from CAT.model.abstract_model import AbstractModel
from CAT.dataset import AdapTestDataset, TrainDataset, Dataset
from sklearn.metrics import accuracy_score
from collections import namedtuple

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not set device here. We will set device through config

loss += bz_loss.data.float()
ls = ls + loss
return ls/(epochs)
def _loss_function(self, pred, real):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make sure your code is well-formatted, such as keeping a blank line between these functions

def load_data(self, ncatdata,concept):

return ncatdata.data, ncatdata.num_questions, concept
def load_CDM(self,name,data,pth_path,config):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to use some auto tools to check the format of your code. blank line between functions, space after the comma, etc.

return model ,data.data
def step(self, action,sid):
assert action in self.sup_rates[self.state[0][0]] and action not in self.short
reward, ACC, AUC, rate = self.reward(action,sid)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would recommend to use snake case for variables

" test_data.reset()\n",
" \n",
" print(strategy.name)\n",
" if strategy.name == 'NCAT':\n",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is hard to unify all the strategies, I would recommend to write train/test examples for them in seperate scripts, such as train_NCAT.ipynb, test_NCAT.ipynb, test_BOBCAT.ipynb.

def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))

class CDM(nn.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the difference between this CDM and IRTModel in IRT.py? If it's the same, why not import IRTModel as the CDM?

for sid in range(adaptest_data.num_students):
NCATdata = adaptest_data
model = NCATModel(NCATdata,concept_map,config,test_length)
THRESHOLD = config['THRESHOLD']
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use uppercase letters here? use snake case for variables

G.append(sum(map(lambda x: math.pow(alpha, x - 1), dict(Counter(items)).values())) / math.log(i + 2, 2))
return sum(G)

class NCATModel():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a model for question selection strategy and not a CD model right? Correct me if I am wrong - I have not read the NCAT paper yet. /model should contain code for CDMs and /strategy contains code for question selection strategy. If NCATModel is not a CDM, it should not be placed in this folder. Modularize your code

@nnnyt nnnyt merged commit 76a7cc3 into bigdata-ustc:master Jan 15, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants