Skip to content

Commit

Permalink
Add ablation experiment settings.
Browse files Browse the repository at this point in the history
  • Loading branch information
RowitZou committed Jun 17, 2019
1 parent b1d5983 commit 651d3d2
Show file tree
Hide file tree
Showing 12 changed files with 405 additions and 3,137 deletions.
20 changes: 19 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@
from utils.data import Data


def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')


def data_initialization(data, word_file, train_file, dev_file, test_file):

data.build_word_file(word_file)
Expand Down Expand Up @@ -142,6 +153,8 @@ def train(data, args, saved_model_path):

print( "Training model...")
model = Graph(data, args)
if args.use_gpu:
model = model.cuda()
print('# generated parameters:', sum(param.numel() for param in model.parameters()))
print( "Finished built model.")

Expand Down Expand Up @@ -293,7 +306,7 @@ def load_model_decode(model_dir, data, args, name):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--status', choices=['train', 'test', 'decode'], help='Function status.', default='train')
parser.add_argument('--use_gpu', type=bool, default=True)
parser.add_argument('--use_gpu', type=str2bool, default=True)
parser.add_argument('--train', help='Training set.')
parser.add_argument('--dev', help='Developing set.')
parser.add_argument('--test', help='Testing set.')
Expand All @@ -304,6 +317,11 @@ def load_model_decode(model_dir, data, args, name):
parser.add_argument('--char_emb', help='Path of character embedding file.', default="data/gigaword_chn.all.a2b.uni.ite50.vec")
parser.add_argument('--word_emb', help='Path of word embedding file.', default="data/ctb.50d.vec")

parser.add_argument('--use_crf', type=str2bool, default=True)
parser.add_argument('--use_edge', type=str2bool, default=True, help='If use lexicon embeddings (edge embeddings).')
parser.add_argument('--use_global', type=str2bool, default=True, help='If use the global node.')
parser.add_argument('--bidirectional', type=str2bool, default=True, help='If use bidirectional digraph.')

parser.add_argument('--seed', help='Random seed', default=47, type=int)
parser.add_argument('--batch_size', help='Batch size. For now it only works when batch size is 1.', default=1, type=int)
parser.add_argument('--num_epoch',default=2, type=int, help="Epoch number.")
Expand Down
Loading

0 comments on commit 651d3d2

Please sign in to comment.