Skip to content

Commit

Permalink
Fix some bugs and modify default settings.
Browse files Browse the repository at this point in the history
  • Loading branch information
RowitZou committed Jun 25, 2019
1 parent 33aa30b commit 17b97d8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
12 changes: 11 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def str2bool(v):
raise argparse.ArgumentTypeError('Boolean value expected.')


def lr_decay(optimizer, epoch, decay_rate, init_lr):
lr = init_lr * ((1-decay_rate)**epoch)
print( " Learning rate is setted as:", lr)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
return optimizer


def data_initialization(data, word_file, train_file, dev_file, test_file):

data.build_word_file(word_file)
Expand Down Expand Up @@ -175,6 +183,7 @@ def train(data, args, saved_model_path):
epoch_start = time.time()
temp_start = epoch_start
print(("Epoch: %s/%s" %(idx, args.num_epoch)))
optimizer = lr_decay(optimizer, idx, args.lr_decay, args.lr)
sample_loss = 0
batch_loss = 0
total_loss = 0
Expand Down Expand Up @@ -322,7 +331,7 @@ def load_model_decode(model_dir, data, args, name):
parser.add_argument('--use_global', type=str2bool, default=True, help='If use the global node.')
parser.add_argument('--bidirectional', type=str2bool, default=True, help='If use bidirectional digraph.')

parser.add_argument('--seed', help='Random seed', default=47, type=int)
parser.add_argument('--seed', help='Random seed', default=1023, type=int)
parser.add_argument('--batch_size', help='Batch size. For now it only works when batch size is 1.', default=1, type=int)
parser.add_argument('--num_epoch',default=50, type=int, help="Epoch number.")
parser.add_argument('--iters', default=4, type=int, help='The number of Graph iterations.')
Expand All @@ -338,6 +347,7 @@ def load_model_decode(model_dir, data, args, name):
parser.add_argument('--char_dim', type=int, help='Char embedding size.')
parser.add_argument('--word_dim', type=int, help='Word embedding size.')
parser.add_argument('--lr', type=float, default=5e-05)
parser.add_argument('--lr_decay', type=float, default=0)
parser.add_argument('--weight_decay', type=float, default=0)

args = parser.parse_args()
Expand Down
5 changes: 1 addition & 4 deletions utils/alphabet.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ def get_instance(self, index):
return self.instances[0]

def size(self):
if self.label:
return len(self.instances)
else:
return len(self.instances) + 1
return len(self.instances) + 1

def iteritems(self):
return self.instance2index.items()
Expand Down

0 comments on commit 17b97d8

Please sign in to comment.