Performance optimization.
RowitZou committed Jul 22, 2019
1 parent 4132abd commit 7e7e708
Showing 2 changed files with 55 additions and 12 deletions.
41 changes: 31 additions & 10 deletions main.py
@@ -33,7 +33,10 @@ def lr_decay(optimizer, epoch, decay_rate, init_lr):
lr = init_lr * ((1-decay_rate)**epoch)
print( " Learning rate is setted as:", lr)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
if param_group['name'] == 'aggr':
param_group['lr'] = lr * 2.
else:
param_group['lr'] = lr
return optimizer
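With the named parameter groups created below in train(), this loop sets lr = init_lr * (1 - decay_rate)**epoch for every group and doubles that value for the group named 'aggr'. A minimal sketch of the schedule with hypothetical values (note the repository default --lr_decay is 0, i.e. no decay):

    # Hypothetical values; only the formula matches the code above.
    init_lr, decay_rate = 2e-5, 0.05
    for epoch in range(3):
        lr = init_lr * ((1 - decay_rate) ** epoch)
        print(f"epoch {epoch}: lr={lr:.2e}, 'aggr' lr={2 * lr:.2e}")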


@@ -176,8 +179,26 @@ def train(data, args, saved_model_path):
best_test_r = -1

# Initialize the optimizer
parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(parameters, lr=args.lr, weight_decay=args.weight_decay)
aggr_module_params = []
other_module_params = []
for m_name in model._modules:
m = model._modules[m_name]
if isinstance(m, torch.nn.ModuleList):
for p in m.parameters():
if p.requires_grad:
aggr_module_params.append(p)
else:
for p in m.parameters():
if p.requires_grad:
other_module_params.append(p)

optimizer = optim.Adam([
{"params": (aggr_module_params), "name": "aggr"},
{"params": (other_module_params), "name": "other"}
],
lr=args.lr,
weight_decay=args.weight_decay
)

for idx in range(args.num_epoch):
epoch_start = time.time()
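The grouping above keys parameters by module type: anything held in a torch.nn.ModuleList goes into the 'aggr' group, everything else into 'other'. A minimal, self-contained sketch of the same named-group pattern with toy stand-in modules (not the repository's actual model):

    import torch.nn as nn
    import torch.optim as optim

    model = nn.Module()
    model.aggr = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])  # stand-in for the aggregation cells
    model.emb = nn.Linear(4, 4)                                      # stand-in for everything else

    aggr_params = [p for p in model.aggr.parameters() if p.requires_grad]
    other_params = [p for p in model.emb.parameters() if p.requires_grad]

    optimizer = optim.Adam(
        [{"params": aggr_params, "name": "aggr"},
         {"params": other_params, "name": "other"}],
        lr=2e-5,
    )

    # Extra keys such as "name" are preserved in optimizer.param_groups,
    # which is what lets lr_decay() address each group individually.
    for group in optimizer.param_groups:
        print(group["name"], group["lr"])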
@@ -316,13 +337,13 @@ def load_model_decode(model_dir, data, args, name):
parser = argparse.ArgumentParser()
parser.add_argument('--status', choices=['train', 'test', 'decode'], help='Function status.', default='train')
parser.add_argument('--use_gpu', type=str2bool, default=True)
parser.add_argument('--train', help='Training set.', default='data/onto4ner.cn/train.char.bmes')
parser.add_argument('--dev', help='Developing set.', default='data/onto4ner.cn/dev.char.bmes')
parser.add_argument('--test', help='Testing set.', default='data/onto4ner.cn/test.char.bmes')
parser.add_argument('--train', help='Training set.', default='data/ontonote.cn/train.char.bmes')
parser.add_argument('--dev', help='Developing set.', default='data/ontonote.cn/dev.char.bmes')
parser.add_argument('--test', help='Testing set.', default='data/ontonote.cn/test.char.bmes')
parser.add_argument('--raw', help='Raw file for decoding.')
parser.add_argument('--output', help='Output results for decoding.')
parser.add_argument('--saved_set', help='Path of saved data set.', default='data/onto4ner.cn/saved.dset')
parser.add_argument('--saved_model', help='Path of saved model.', default="saved_model/model")
parser.add_argument('--saved_set', help='Path of saved data set.', default='data/ontonote.cn/saved.dset')
parser.add_argument('--saved_model', help='Path of saved model.', default="saved_model/model_ontonote")
parser.add_argument('--char_emb', help='Path of character embedding file.', default="data/gigaword_chn.all.a2b.uni.ite50.vec")
parser.add_argument('--word_emb', help='Path of word embedding file.', default="data/ctb.50d.vec")

@@ -333,7 +354,7 @@ def load_model_decode(model_dir, data, args, name):

parser.add_argument('--seed', help='Random seed', default=1023, type=int)
parser.add_argument('--batch_size', help='Batch size. For now it only works when batch size is 1.', default=1, type=int)
parser.add_argument('--num_epoch',default=50, type=int, help="Epoch number.")
parser.add_argument('--num_epoch',default=100, type=int, help="Epoch number.")
parser.add_argument('--iters', default=4, type=int, help='The number of Graph iterations.')
parser.add_argument('--hidden_dim', default=50, type=int, help='Hidden state size.')
parser.add_argument('--num_head', default=10, type=int, help='Number of transformer head.')
@@ -346,7 +367,7 @@ def load_model_decode(model_dir, data, args, name):
parser.add_argument('--label_alphabet_size', type=int, help='Label alphabet size.')
parser.add_argument('--char_dim', type=int, help='Char embedding size.')
parser.add_argument('--word_dim', type=int, help='Word embedding size.')
parser.add_argument('--lr', type=float, default=5e-05)
parser.add_argument('--lr', type=float, default=2e-05)
parser.add_argument('--lr_decay', type=float, default=0)
parser.add_argument('--weight_decay', type=float, default=0)

26 changes: 24 additions & 2 deletions model/module.py
@@ -4,6 +4,7 @@


import torch
import math
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
@@ -110,13 +111,19 @@ def __init__(self, input_h, hid_h, use_global=True, dropout=0.2):
super(Nodes_Cell, self).__init__()

self.use_global = use_global

self.hidden_size = hid_h
self.Wix = nn.Linear(input_h, hid_h)
self.Wi2 = nn.Linear(input_h, hid_h)
self.Wf = nn.Linear(input_h, hid_h)
self.Wcx = nn.Linear(input_h, hid_h)

self.drop = nn.Dropout(dropout)
self.reset_parameters()

def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
nn.init.uniform_(weight, -stdv, stdv)

def forward(self, h, h2, x, glo=None):

@@ -144,13 +151,20 @@ def __init__(self, input_h, hid_h, use_global=True, dropout=0.2):
super(Edges_Cell, self).__init__()

self.use_global = use_global

self.hidden_size = hid_h
self.Wi = nn.Linear(input_h, hid_h)
self.Wf = nn.Linear(input_h, hid_h)
self.Wc = nn.Linear(input_h, hid_h)

self.drop = nn.Dropout(dropout)

self.reset_parameters()

def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
nn.init.uniform_(weight, -stdv, stdv)

def forward(self, h, x, glo=None):

x = self.drop(x)
@@ -175,12 +189,20 @@ class Global_Cell(nn.Module):
def __init__(self, input_h, hid_h, dropout=0.2):
super(Global_Cell, self).__init__()

self.hidden_size = hid_h
self.Wi = nn.Linear(input_h, hid_h)
self.Wf = nn.Linear(input_h, hid_h)
self.Wc = nn.Linear(input_h, hid_h)

self.drop = nn.Dropout(dropout)

self.reset_parameters()

def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
nn.init.uniform_(weight, -stdv, stdv)

def forward(self, h, x):

x = self.drop(x)
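Each of the three cells above now re-initializes its Linear weights and biases from U(-1/sqrt(hid_h), 1/sqrt(hid_h)), the same bound PyTorch uses by default for recurrent layers such as nn.LSTM. A minimal sketch of the scheme in isolation (illustrative sizes):

    import math
    import torch.nn as nn

    hid_h = 50                      # illustrative hidden size
    layer = nn.Linear(100, hid_h)   # stands in for Wi / Wf / Wc
    stdv = 1.0 / math.sqrt(hid_h)
    for weight in layer.parameters():
        nn.init.uniform_(weight, -stdv, stdv)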