-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy path dataprep.py
379 lines (331 loc) · 12.3 KB
/
dataprep.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
# -*- coding: utf-8 -*-
from __future__ import division, print_function, unicode_literals
from datetime import datetime
from collections import namedtuple, defaultdict
from operator import itemgetter
from io import open
import os
import time
import json
import numpy as np
from yata.fields import Words, Categorical
from yata.loaders import TableLoader
from random import sample, randint, shuffle, seed
# One answer record: the score a user (belonging to a school) obtained on a
# topic of an exam, stamped with the exam time.
Record = namedtuple('Record', 'user school topic exam score time')
# One element of a student's answer sequence; ``bias`` is a random jitter used
# to shuffle items within an exam when ordering the sequence.
Item = namedtuple('Item', 'topic score time bias')
class Dataset:
    """
    In-memory collection of answer Records plus indexes derived from them.

    Attributes (all maintained by _insert):
        topics, exams, users: sets of ids seen in the records
        schools: school id -> set of (exam id, exam time) pairs
        records: list of Records in insertion order
        user_school_map: user id -> school of that user's last inserted record
        topic_exam_map: topic id -> exam of that topic's last inserted record
        random_level: shuffling level used by get_seq()
    """

    # Exams skipped by from_matrix() as anomalous data.
    _EXCLUDED_EXAMS = frozenset([
        'e061cfc1-86e9-4486-abc7-0630d0c1ad2d',
        '8c7a9923-0856-4fed-b326-c0f124129b86',
    ])

    def __init__(self, random_level=1):
        """
        Build an empty Dataset.

        :param random_level: how get_seq() reorders each student's sequence:
            0 = keep insertion order, 1 = shuffle within an exam,
            2 = shuffle everything. Defaults to 1.
        """
        self.topics = set()
        self.exams = set()
        self.users = set()
        self.schools = defaultdict(set)
        self.records = list()
        self.user_school_map = dict()
        self.topic_exam_map = dict()
        self.random_level = random_level

    @staticmethod
    def _read_fields(path):
        """Yield the single-space-separated fields of each non-blank line of
        a UTF-8 text file, closing the file when iteration finishes."""
        with open(path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    yield line.split(' ')

    @staticmethod
    def from_matrix(filename):
        """
        Load a Dataset from the tab-separated "matrix" format.

        Each non-blank line is ``<school>@<exam>@<grade>\\t<fields>`` where
        <fields> is space-separated: field 0 is the exam time (int), fields
        1-4 are ignored, and every later field is
        ``<user>@<topic>@<score>@<std_score>``. The stored score is
        score / std_score (std_score is presumably the full mark — confirm).
        Exams listed in _EXCLUDED_EXAMS are skipped.

        :param filename: path of the UTF-8 matrix file
        :return: a new Dataset (random_level=1)
        """
        d = Dataset()
        with open(filename, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                desc, seq = line.split('\t')
                scl, exam, _grade = desc.strip().split('@')
                if exam in Dataset._EXCLUDED_EXAMS:
                    continue
                data_fields = seq.strip().split(' ')
                exam_time = int(data_fields[0])
                for record in data_fields[5:]:
                    user, topic, score, std_score = record.split('@')
                    d._insert(Record(user, scl, topic, exam,
                                     float(score) / float(std_score),
                                     exam_time))
        return d

    @staticmethod
    def from_records(dirname):
        """
        Load a Dataset from ``exams.txt``, ``schools.txt`` and ``records.txt``
        in ``dirname`` (all UTF-8, fields separated by single spaces).

        - exams.txt: ``<exam>,<type>,<YYYY-MM-DD>,<ignored> <topic> ...``;
          the exam id becomes ``<exam>_<type>`` and its time is the Unix
          timestamp of that date at local midnight (time.mktime).
        - schools.txt: ``<school> <student> <student> ...``
        - records.txt: ``<student>,<ignored>,<ignored> <topic>,<score> ...``;
          students absent from schools.txt are skipped.

        :param dirname: directory holding the three files
        :return: a new Dataset (random_level=1)
        :raises KeyError: if a record references a topic not in exams.txt
        """
        d = Dataset()
        topic_exam_map = dict()
        exam_info = dict()
        student_scl_map = dict()
        for fields in Dataset._read_fields(os.path.join(dirname, 'exams.txt')):
            exam_id, exam_type, exam_date, _ = fields[0].split(',')
            exam_id = exam_id + '_' + exam_type
            exam_info[exam_id] = int(time.mktime(
                datetime.strptime(exam_date, '%Y-%m-%d').timetuple()))
            for topic in fields[1:]:
                topic_exam_map[topic] = exam_id
        for fields in Dataset._read_fields(
                os.path.join(dirname, 'schools.txt')):
            for student in fields[1:]:
                student_scl_map[student] = fields[0]
        for fields in Dataset._read_fields(
                os.path.join(dirname, 'records.txt')):
            student_id, _, _ = fields[0].split(',')
            if student_id not in student_scl_map:
                continue
            scl_id = student_scl_map[student_id]
            for item in fields[1:]:
                topic_id, score = item.split(',')
                exam_id = topic_exam_map[topic_id]
                d._insert(Record(student_id, scl_id, topic_id, exam_id,
                                 float(score), exam_info[exam_id]))
        return d

    def select(self, filter):
        """
        Partition the records by a predicate.

        :param filter: callable taking a Record; truthy result selects it
        :return: (selected, others) — two new Datasets that inherit this
            dataset's random_level
        """
        selected = Dataset(self.random_level)
        others = Dataset(self.random_level)
        for r in self.records:
            (selected if filter(r) else others)._insert(r)
        return selected, others

    def split(self, frac=0.9, rand_seed=None):
        """
        Split into a training set and three test sets.

        For every school, its exams are ordered by time and all but the
        latest become training exams; a random ``frac`` of all users become
        training users.

        :param frac: fraction of users sampled as training users (0.9)
        :param rand_seed: if not None, reseed the module-level RNG first
        :return: (train, user, exam, new) where
            train - training exam and training user
            user  - training exam, unseen user
            exam  - unseen exam, training user
            new   - neither exam nor user seen in training
            All four inherit this dataset's random_level.
        """
        if rand_seed is not None:
            seed(rand_seed)
        train_exams = set()
        for exam_times in self.schools.values():
            # Sort by (time, exam) so ties on time are broken
            # deterministically rather than by set iteration order.
            ordered = sorted(exam_times, key=itemgetter(1, 0))
            train_exams.update(exam for exam, _ in ordered[:-1])
        # random.sample needs an integer count; the original passed the
        # fraction 0.9 itself, which raised TypeError.
        train_users = set(sample(sorted(self.users),
                                 int(len(self.users) * frac)))
        train = Dataset(self.random_level)
        user = Dataset(self.random_level)
        exam = Dataset(self.random_level)
        new = Dataset(self.random_level)
        for r in self.records:
            seen_exam = r.exam in train_exams
            seen_user = r.user in train_users
            if seen_exam and seen_user:
                train._insert(r)
            elif seen_exam:
                user._insert(r)
            elif seen_user:
                exam._insert(r)
            else:
                new._insert(r)
        return train, user, exam, new

    def split_future(self, frac, rand_seed=324):
        """
        Split each student's get_seq() sequence into a prefix and a suffix.

        The first ``int(frac * len(seq))`` items of every student go to the
        training set, the rest to the test set. Records are rebuilt from
        Items, so their school/exam come from user_school_map and
        topic_exam_map (the last-seen values).

        :param frac: fraction of each sequence placed in the training set
        :param rand_seed: seed applied before get_seq() so the shuffling it
            performs (random_level 1 or 2) is reproducible
        :return: (train_data, test_data), inheriting random_level
        """
        # Seed *before* get_seq(): previously the seed was set afterwards
        # and therefore had no effect on the shuffled sequences.
        seed(rand_seed)
        seq = self.get_seq()
        train_data = Dataset(self.random_level)
        test_data = Dataset(self.random_level)
        for user, u_seq in seq.items():
            school = self.user_school_map[user]
            train_len = int(frac * len(u_seq))
            for idx, (topic, score, ts, _) in enumerate(u_seq):
                target = train_data if idx < train_len else test_data
                target._insert(Record(user, school, topic,
                                      self.topic_exam_map[topic],
                                      score, ts))
        return train_data, test_data

    def split_user(self, frac, rand_seed=101):
        """
        Split the records by user after reseeding the RNG with ``rand_seed``.

        :param frac: fraction of users (rounded down) put in the training set
        :return: (train_data, test_data), inheriting random_level
        """
        seed(rand_seed)
        train_users = set(sample(sorted(self.users),
                                 int(len(self.users) * frac)))
        return self.select(lambda r: r.user in train_users)

    def get_seq(self):
        """
        Return every student's answer sequence as a list of Items.

        Records are grouped by user in insertion order; each Item gets a
        random integer ``bias`` in [-5000, 5000] (drawn for every record,
        whatever the level). Then, per random_level:
            0 - insertion order is kept;
            1 - items are sorted by time + bias; records of one exam share a
                time, so this shuffles within an exam;
            2 - items are fully shuffled.

        :return: defaultdict mapping user -> list of Item
        """
        seq = defaultdict(list)
        for r in self.records:
            seq[r.user].append(Item(r.topic, r.score,
                                    r.time, randint(-5000, 5000)))
        for items in seq.values():
            if self.random_level == 1:
                items.sort(key=lambda x: x.time + x.bias)
            elif self.random_level == 2:
                shuffle(items)
        return seq

    def get_dict(self):
        """
        Number users and topics (in sorted id order, starting at 1).

        :return: (user -> number, number -> user,
                  topic -> number, number -> topic)
        """
        users = sorted(self.users)
        topics = sorted(self.topics)
        user_dic = {u: i for i, u in enumerate(users, 1)}
        user_inv_dic = {i: u for i, u in enumerate(users, 1)}
        topic_dic = {t: i for i, t in enumerate(topics, 1)}
        topic_inv_dic = {i: t for i, t in enumerate(topics, 1)}
        return user_dic, user_inv_dic, topic_dic, topic_inv_dic

    def save(self, filename):
        """Write all records to ``filename`` as a JSON array of 6-element
        arrays in Record field order."""
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(self.records, f)

    def load(self, filename):
        """Append the records stored in a JSON file written by save()."""
        with open(filename, encoding='utf-8') as f:
            records = json.load(f)
        for r in records:
            self._insert(Record(*r))

    def _insert(self, r):
        """Append record ``r`` and update every index; later records
        overwrite earlier ones in user_school_map and topic_exam_map."""
        self.topics.add(r.topic)
        self.exams.add(r.exam)
        self.users.add(r.user)
        self.schools[r.school].add((r.exam, r.time))
        self.user_school_map[r.user] = r.school
        self.topic_exam_map[r.topic] = r.exam
        self.records.append(r)
def get_dataset(type, random_level=0):
    """
    Load one of the named datasets, or a JSON dump written by Dataset.save().

    :param type: a name starting with 'full' loads Dataset.from_records()
        from 'data/full' ('data/test' if the name ends with 'test'); a name
        starting with 'some' loads 'data/02.10.matrix' and keeps only the
        records of a fixed set of schools (a different single school if the
        name ends with 'test'); any other value is treated as the path of a
        JSON file and read with Dataset.load().
    :param random_level: shuffling level set on the returned dataset
        (0 = none, 1 = within exam, 2 = full). Defaults to 0.
    :return: the requested Dataset
    """
    # Schools whose records form the 'some' and 'some_test' subsets.
    some_schools = frozenset(['2300000001000000032',
                              '2300000001000674122',
                              '4444000020000000449',
                              '2300000001000649665',
                              '2300000001000053674',
                              '2300000001000649702'])
    some_test_schools = frozenset(['4444000020000000470'])
    if type.startswith('full'):
        path = 'data/test' if type.endswith('test') else 'data/full'
        rv = Dataset.from_records(path)
    elif type.startswith('some'):
        d = Dataset.from_matrix('data/02.10.matrix')
        schools = some_test_schools if type.endswith('test') else some_schools
        rv, _ = d.select(lambda r: r.school in schools)
    else:
        rv = Dataset()
        rv.load(type)
    rv.random_level = random_level
    return rv
def load_embedding(filename):
    """
    Read word embeddings from a UTF-8 text file.

    The first line is the header ``<word count> <embedding size>``; every
    following non-blank line is ``<word> <v1> <v2> ...`` separated by single
    spaces.

    :param filename: path of the embedding file
    :return: (wcnt, emb_size, words, embs) — the two header integers as
        written (not checked against the body), the words in file order, and
        a float array with one row per word. If the file holds no vectors,
        ``embs`` has shape (0, emb_size).
    """
    with open(filename, encoding='utf-8') as f:
        wcnt, emb_size = next(f).strip().split(' ')
        wcnt = int(wcnt)
        emb_size = int(emb_size)
        words = []
        embs = []
        for line in f:
            line = line.strip()
            if not line:
                # A blank (e.g. trailing) line would otherwise yield an empty
                # vector and break the array conversion below.
                continue
            fields = line.split(' ')
            words.append(fields[0])
            embs.append(np.array([float(x) for x in fields[1:]]))
    embs = np.asarray(embs) if embs else np.zeros((0, emb_size))
    return wcnt, emb_size, words, embs
def get_topics(type, words):
    """
    Build a yata TableLoader over the topic feature dump matching ``type``.

    Names starting with 'some' read 'data/features.dump.some'; everything
    else reads 'data/features.dump.full'. The file is read without a header,
    keyed on column 0, and exposes one field, 'content', built from a
    Words(':') field passed through a Categorical whose dictionary is
    loaded from ``words``. (Presumably column 2 holds ':'-separated tokens
    and '</s>' is the null/unknown token — confirm against yata.)

    :param type: dataset name, as passed to get_dataset()
    :param words: vocabulary handed to Categorical.load_dict()
    :return: the TableLoader, indexed on 'content'
    """
    feature_file = ('data/features.dump.some' if type.startswith('some')
                    else 'data/features.dump.full')
    vocab = Categorical('</s>')
    vocab.load_dict(words)
    return TableLoader(feature_file, with_header=False, key=0,
                       fields={'2->content': vocab(Words(':', null='</s>'))},
                       index=['content'])
if __name__ == '__main__':
    # Manual smoke check: load the full dataset with complete shuffling and
    # print the "topic,score" sequence of 20 randomly chosen students, one
    # student per line.
    full_data = get_dataset('full', random_level=2)
    sequences = full_data.get_seq()
    user_ids = list(sequences)
    shuffle(user_ids)
    for user_id in user_ids[:20]:
        for entry in sequences[user_id]:
            print(entry.topic, entry.score, sep=',', end=' ')
        print()