Skip to content

Commit

Permalink
add parameters for NCD
Browse files Browse the repository at this point in the history
  • Loading branch information
nnnyt committed Mar 9, 2021
1 parent 45a975c commit 21d7276
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
6 changes: 3 additions & 3 deletions CAT/model/NCD.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ class NCD(nn.Module):
'''
NeuralCDM
'''
def __init__(self, student_n, exer_n, knowledge_n):
def __init__(self, student_n, exer_n, knowledge_n, prednet_len1=128, prednet_len2=64):
self.knowledge_dim = knowledge_n
self.exer_n = exer_n
self.emb_num = student_n
self.stu_dim = self.knowledge_dim
self.prednet_input_len = self.knowledge_dim
self.prednet_len1, self.prednet_len2 = 512, 256 # changeable
self.prednet_len1, self.prednet_len2 = prednet_len1, prednet_len2 # changeable

super(NCD, self).__init__()

Expand Down Expand Up @@ -96,7 +96,7 @@ def name(self):
return 'Neural Cognitive Diagnosis'

def init_model(self, data: Dataset):
self.model = NCD(data.num_students, data.num_questions, data.num_concepts)
self.model = NCD(data.num_students, data.num_questions, data.num_concepts, self.config['prednet_len1'], self.config['prednet_len2'])

def train(self, train_data: TrainDataset):
lr = self.config['learning_rate']
Expand Down
5 changes: 4 additions & 1 deletion scripts/test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,11 @@
" 'learning_rate': 0.0025,\n",
" 'batch_size': 2048,\n",
" 'num_epochs': 8,\n",
" 'num_dim': 1,\n",
" 'num_dim': 1, # for IRT or MIRT\n",
" 'device': 'cpu',\n",
" # for NeuralCD\n",
" 'prednet_len1': 128,\n",
" 'prednet_len2': 64,\n",
"}\n",
"# fixed test length\n",
"test_length = 5\n",
Expand Down
5 changes: 4 additions & 1 deletion scripts/train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,11 @@
" 'learning_rate': 0.002,\n",
" 'batch_size': 2048,\n",
" 'num_epochs': 10,\n",
" 'num_dim': 10,\n",
" 'num_dim': 10, # for IRT or MIRT\n",
" 'device': 'cpu',\n",
" # for NeuralCD\n",
" 'prednet_len1': 128,\n",
" 'prednet_len2': 64,\n",
"}"
]
},
Expand Down

0 comments on commit 21d7276

Please sign in to comment.