Skip to content

Commit

Permalink
Making character model more legible
Browse files Browse the repository at this point in the history
and a few other minor changes
  • Loading branch information
mmistele committed Jun 8, 2018
1 parent 990cb47 commit 14eb039
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
14 changes: 10 additions & 4 deletions keras/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from keras.models import Model
from keras.layers import Dense, Input, Dropout, LSTM, Activation, Masking, Bidirectional
from keras.layers import Dense, Input, Dropout, LSTM, Activation, Masking, Bidirectional, Lambda

from custom_layers import MeanPool

Expand All @@ -23,11 +23,17 @@ def Character_Model_1(input_shape):

X = Masking(mask_value = 0., input_shape=input_shape)(sentences)

X = Bidirectional(LSTM(128, return_sequences = False, dropout=0.2, recurrent_dropout=0.2), merge_mode='ave')(X)
# X = Bidirectional(LSTM(128, return_sequences = False, dropout=0.2, recurrent_dropout=0.2), merge_mode='ave')(X)
X = LSTM(128, return_sequences = True, dropout=0.2, recurrent_dropout=0.2, name='LSTM1')(X)
def get_last(X):
    """Return the entries of ``X`` at the last index along axis 1.

    ``X`` is indexed with three subscripts, so it needs at least three
    axes; axis 1 is dropped from the result.

    NOTE(review): presumably axis 1 is the timestep axis of a sequence
    output, making this "take the final timestep" -- confirm at the call
    site.
    """
    return X[:, -1, :]
X = Lambda(get_last, name = 'LSTM-last')(X)

# X = Bidirectional(LSTM(128, return_sequences = True, dropout=0.2, recurrent_dropout=0.2))(X)
# X = LSTM(128, dropout=0.2, recurrent_dropout=0.2)(X)
X = Dense(1)(X)
X = Activation('sigmoid')(X)

X = Dense(1, name='Dense1')(X)
X = Activation('sigmoid', name='output_layer')(X)

model = Model(inputs = sentences, outputs = X)
return model
Expand Down
5 changes: 3 additions & 2 deletions keras/train_character_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
X_dev, Y_dev = read_csv('data/dev.csv')

maxLen = max(len(max(X_train, key=len)), len(max(X_dev, key=len)))
print "Max length: %s" % maxLen

counts = get_char_counts_from_csv('data/train.csv') + get_char_counts_from_csv('data/dev.csv')
most_common = counts.most_common()
Expand All @@ -34,9 +35,9 @@
model = Character_Model_1((None, X_train_indices.shape[2]))

optimizer = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, decay=0.0, epsilon=None)
model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy', 'loss'])
model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
tensorboard = TensorBoard(log_dir="logs/character-model/{}".format(time()))
model.fit(X_train_indices, Y_train, epochs = 1, batch_size = 6, shuffle=True, callbacks = [tensorboard], validation_split = 0.2)
model.fit(X_train_indices, Y_train, epochs = 20, batch_size = 20, shuffle=True, callbacks = [tensorboard], validation_split = 0.2)

X_dev_indices = strings_to_character_vecs(X_dev, char_to_index, maxLen, alphabet_size)
loss, acc = model.evaluate(X_dev_indices, Y_dev)
Expand Down
11 changes: 6 additions & 5 deletions keras/train_word_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from utils import read_csv, read_glove_vecs

X_train, Y_train = read_csv('data/train.csv')
X_test, Y_test = read_csv('data/dev.csv')
X_dev, Y_dev = read_csv('data/dev.csv')

maxLen = len(max(X_train, key=len).split())
maxLen = max(len(max(X_train, key=len).split()), len(max(X_dev, key=len).split()))
print "Max length: %s" % maxLen

word_to_index, index_to_word, word_to_vec_map = read_glove_vecs('data/glove.6B.50d.txt')
embedding_layer = pretrained_embedding_layer(word_to_vec_map, word_to_index)
Expand All @@ -28,8 +29,8 @@
tensorboard = TensorBoard(log_dir="logs/word-model/{}".format(time()))
model.fit(X_train_indices, Y_train, epochs = 20, batch_size = 6, shuffle=True, callbacks = [tensorboard])

X_test_indices = strings_to_word_indices(X_test, word_to_index, max_len = maxLen)
loss, acc = model.evaluate(X_test_indices, Y_test)
X_dev_indices = strings_to_word_indices(X_dev, word_to_index, max_len = maxLen)
loss, acc = model.evaluate(X_dev_indices, Y_dev)
model.save('word-model.h5')
print()
print("Test accuracy = ", acc)
print("Dev accuracy = ", acc)

0 comments on commit 14eb039

Please sign in to comment.