{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train an IRT model\n",
    "\n",
    "Loads the training triples, concept map and metadata of the selected dataset, trains `CAT.model.IRTModel` on them, and saves the result with `adaptest_save` to `../ckpt/irt.pt`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import CAT\n",
    "import json\n",
    "import logging\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def setuplogger():\n",
    "    \"\"\"Attach an INFO-level stdout handler to the root logger.\n",
    "\n",
    "    The handler is registered under a fixed name and added only once, so\n",
    "    re-running this cell (or calling the function again) does not make\n",
    "    every log record print multiple times.\n",
    "    \"\"\"\n",
    "    handler_name = 'setuplogger_stdout'\n",
    "    root = logging.getLogger()\n",
    "    root.setLevel(logging.INFO)\n",
    "    # every handler on the root logger emits each record, so never add a second copy\n",
    "    if any(h.get_name() == handler_name for h in root.handlers):\n",
    "        return\n",
    "    handler = logging.StreamHandler(sys.stdout)\n",
    "    handler.set_name(handler_name)\n",
    "    handler.setLevel(logging.INFO)\n",
    "    formatter = logging.Formatter(\"[%(levelname)s %(asctime)s] %(message)s\")\n",
    "    handler.setFormatter(formatter)\n",
    "    root.addHandler(handler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "setuplogger()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# choose dataset here\n",
    "dataset = 'assistment'\n",
    "# modify config here\n",
    "config = {\n",
    "    'learning_rate': 0.002,\n",
    "    'batch_size': 2048,\n",
    "    'num_epochs': 10,\n",
    "    'num_dim': 1, # for IRT or MIRT\n",
    "    'device': 'cpu',\n",
    "    # for NeuralCD\n",
    "    'prednet_len1': 128,\n",
    "    'prednet_len2': 64,\n",
    "    'betas': (0.9, 0.999),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read datasets\n",
    "train_triplets = pd.read_csv(f'../data/{dataset}/train_triples.csv', encoding='utf-8').to_records(index=False)\n",
    "# context managers close the JSON files instead of leaking the handles;\n",
    "# the encoding is pinned to match the CSV read above\n",
    "with open(f'../data/{dataset}/concept_map.json', 'r', encoding='utf-8') as fp:\n",
    "    concept_map = json.load(fp)\n",
    "# JSON object keys are always strings, so convert them to int\n",
    "concept_map = {int(k): v for k, v in concept_map.items()}\n",
    "with open(f'../data/{dataset}/metadata.json', 'r', encoding='utf-8') as fp:\n",
    "    metadata = json.load(fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the training dataset from the triples, the concept map and the sizes in metadata\n",
    "train_data = CAT.dataset.TrainDataset(train_triplets, concept_map,\n",
    "                                      metadata['num_train_students'],\n",
    "                                      metadata['num_questions'],\n",
    "                                      metadata['num_concepts'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# define model here\n",
    "model = CAT.model.IRTModel(**config)\n",
    "# train model\n",
    "model.init_model(train_data)\n",
    "model.train(train_data, log_step=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save model\n",
    "model.adaptest_save('../ckpt/irt.pt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}