# train.py — forked from jessefreeman/MarathonTerminalGenerator
# (original listing: 24 lines, 20 loc, 781 bytes)
import os

from textgenrnn import textgenrnn
from config import *
# Train a new textgenrnn model on a dataset under ./datasets/ and save it
# under ./weights/.
#
# NOTE(review): model_name, file_name, train_cfg and model_cfg are expected
# to come from the `from config import *` above; train_cfg / model_cfg are
# indexed with string keys, so they are presumably dicts -- confirm in config.

weights_dir = "./weights/"
datasets_dir = "./datasets/"

# textgenrnn uses `name` as a prefix for the files it saves (its README shows
# e.g. `<name>_weights.hdf5`), so prefixing with weights_dir writes them into
# ./weights/. Create that directory up front; otherwise writing those files
# fails if it is missing.
os.makedirs(weights_dir, exist_ok=True)

textgen = textgenrnn(name=weights_dir + model_name)

# 'line_delimited' selects textgenrnn's per-line trainer (train_from_file)
# versus its single-large-text trainer (train_from_largetext_file).
train_function = (
    textgen.train_from_file
    if train_cfg['line_delimited']
    else textgen.train_from_largetext_file
)

train_function(
    file_path=datasets_dir + file_name,
    new_model=True,
    num_epochs=train_cfg['num_epochs'],
    gen_epochs=train_cfg['gen_epochs'],
    # Previously hard-coded to 1024; now overridable from config, same default.
    batch_size=train_cfg.get('batch_size', 1024),
    train_size=train_cfg['train_size'],
    dropout=train_cfg['dropout'],
    validation=train_cfg['validation'],
    is_csv=train_cfg['is_csv'],
    rnn_layers=model_cfg['rnn_layers'],
    rnn_size=model_cfg['rnn_size'],
    rnn_bidirectional=model_cfg['rnn_bidirectional'],
    max_length=model_cfg['max_length'],
    # Previously hard-coded to 100; now overridable from config, same default.
    dim_embeddings=model_cfg.get('dim_embeddings', 100),
    word_level=model_cfg['word_level'],
)