forked from zeng-zhen/SCD
-
Notifications
You must be signed in to change notification settings - Fork 0
/
build_stu.py
58 lines (49 loc) · 1.76 KB
/
build_stu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import random
import torch
k = 2  # numerator of the keep-probability formula: k / ln(count + 1)
pmin = 0.4  # floor applied to every keep probability


def _keep_probability(counts):
    """Turn per-node interaction counts into edge keep probabilities.

    Each count ``c`` maps to ``max(k / ln(c + 1 + 1e-4), pmin)``.  Results
    above 1 simply mean "always keep", because they are compared against
    uniform samples drawn from ``[0, 1)``.
    """
    # 10e-5 (== 1e-4) keeps the logarithm strictly positive for a count of
    # zero, so the division never hits ln(1) == 0.
    prob = k / torch.log(counts + 1 + 10e-5)
    return torch.clamp(prob, min=pmin)


def remake(args, name=None, *, data_file='data/ASSIST/graph/e_from_u.txt',
           out_dir='./view'):
    """Randomly thin the student-exercise edge list and write two views.

    Every non-blank line of ``data_file`` must hold ``<student>\\t<exercise>``
    as integer ids.  For each edge two independent uniform samples are drawn:

    * the edge goes into the ``e_from_u`` view (written as
      ``student\\texercise``) when the first sample is below the keep
      probability of its *exercise*;
    * the edge goes into the ``u_from_e`` view (written as
      ``exercise\\tstudent``) when the second sample is below the keep
      probability of its *student*.

    Keep probabilities shrink as a node's interaction count grows (see
    ``_keep_probability``), so edges of frequently seen nodes are dropped
    more often.

    Args:
        args: object exposing ``student_n`` and ``exer_n``, the sizes of the
            student and exercise id spaces.
        name: suffix for the output file names.  When ``None`` nothing is
            written to disk.
        data_file: path of the edge list to read.
        out_dir: directory that receives ``u_from_e<name>.txt`` and
            ``e_from_u<name>.txt``; it must already exist.

    Returns:
        None.

    Raises:
        IndexError: if an id is not smaller than ``student_n`` / ``exer_n``.
        ValueError: if a non-blank line is not two tab-separated integers.
    """
    stu_counts = torch.zeros(args.student_n).tolist()
    exer_counts = torch.zeros(args.exer_n).tolist()
    data = []
    with open(data_file, encoding='utf8') as i_f:
        for line in i_f:
            # Tolerate blank lines (e.g. a trailing one) instead of failing
            # on the tuple unpack below.
            if not line.strip():
                continue
            stu, exer = line.replace('\n', '').split('\t')
            stu = int(stu)
            exer = int(exer)
            data.append((stu, exer))
            stu_counts[stu] += 1
            exer_counts[exer] += 1

    exer_prob = _keep_probability(torch.tensor(exer_counts))
    stu_prob = _keep_probability(torch.tensor(stu_counts))

    # NOTE: the previous version called random.seed() before each draw.  That
    # reseeds Python's `random` module only -- it clobbered any seed set by
    # the caller while having no effect on torch.rand -- so it was dropped.
    # The draw order (exercise side first, then student side) is unchanged,
    # so runs seeded via torch.manual_seed draw the same numbers as before.
    e_rand = torch.rand(len(data))
    s_rand = torch.rand(len(data))

    # Gather each edge's probability in one shot rather than indexing the
    # tensors once per edge inside a Python loop.
    stu_idx = torch.tensor([stu for stu, _ in data], dtype=torch.long)
    exer_idx = torch.tensor([exer for _, exer in data], dtype=torch.long)
    keep_for_e = (e_rand < exer_prob[exer_idx]).tolist()
    keep_for_s = (s_rand < stu_prob[stu_idx]).tolist()

    e_from_u = ''.join(
        f'{stu}\t{exer}\n'
        for (stu, exer), keep in zip(data, keep_for_e) if keep
    )
    u_from_e = ''.join(
        f'{exer}\t{stu}\n'
        for (stu, exer), keep in zip(data, keep_for_s) if keep
    )

    if name is not None:
        with open(f'{out_dir}/u_from_e{name}.txt', 'w', encoding='utf8') as f:
            f.write(u_from_e)
        with open(f'{out_dir}/e_from_u{name}.txt', 'w', encoding='utf8') as f:
            f.write(e_from_u)