From b364fcdbbc7c8416313ebdfd4bcb9fd18384e365 Mon Sep 17 00:00:00 2001
From: lanweiyu
Date: Fri, 13 Apr 2018 17:40:37 +0800
Subject: [PATCH] Fix bugs in (1) getting "sent_score_file" and (2) getting "vf_dir".
---
doit/do_all.sh | 6 ++----
sampled_data_provider.py | 4 ++--
test_models.py | 6 +++---
trainer.py | 12 ++++++++----
4 files changed, 15 insertions(+), 13 deletions(-)
diff --git a/doit/do_all.sh b/doit/do_all.sh
index 8ee6046..c4553d9 100755
--- a/doit/do_all.sh
+++ b/doit/do_all.sh
@@ -15,13 +15,13 @@ done
# --------------------------------------------
python ../generate_vocab.py $train_collection --language $lang --rootpath $rootpath
-python ../trainer.py --model_name $model_name --train_collection $train_collection --language $lang --vf_name $vf --overwrite $overwrite --rootpath $rootpath --fluency_method $fluency_method --sent_score_file $train_sent_score_file
+python ../trainer.py --model_name $model_name --train_collection $train_collection --language $lang --vf_name $vf --overwrite $overwrite --rootpath $rootpath --fluency_method $fluency_method
# --------------------------------------------
# 2. validation
# --------------------------------------------
-python ../compute_val_loss.py --train_collection $train_collection --val_collection $val_collection --model_name $model_name --vf_name $vf --language $lang --overwrite $overwrite --rootpath $rootpath --fluency_method $fluency_method --sent_score_file $val_sent_score_file
+python ../compute_val_loss.py --train_collection $train_collection --val_collection $val_collection --model_name $model_name --vf_name $vf --language $lang --overwrite $overwrite --rootpath $rootpath --fluency_method $fluency_method
# --------------------------------------------
# 3. test
@@ -29,8 +29,6 @@ python ../compute_val_loss.py --train_collection $train_collection --val_collect
top_k=1
beam_size=5
-#beam_size=20
-#beam_size=10
length_normalization_factor=0.5
python ../test_models.py --train_collection $train_collection --val_collection $val_collection --test_collection $test_collection --model_name $model_name --vf_name $vf --top_k $top_k --length_normalization_factor $length_normalization_factor --beam_size $beam_size --overwrite $overwrite --rootpath $rootpath --fluency_method $fluency_method
diff --git a/sampled_data_provider.py b/sampled_data_provider.py
index 926dac6..41de6a7 100644
--- a/sampled_data_provider.py
+++ b/sampled_data_provider.py
@@ -81,13 +81,13 @@ def __init__(self, collection, vocab_file, feature, language,
self.fluency_threshold = fluency_threshold
self.method = method
if method:
+ self.sent_score_file = utility.get_sent_score_file(collection, language, rootpath)
assert method in ['sample','filter','weighted']
assert self.sent_score_file != None
- assert fluency_threshold>0
+ assert fluency_threshold > 0
if method == 'weighted':
# Not sampling the data if fluency-guided method is weighted_loss
self.method = method = None
- self.sent_score_file = utility.get_sent_score_file(collection, language, rootpath)
else:
self.sent_score_file = None
diff --git a/test_models.py b/test_models.py
index 0f1ced2..ef2c652 100644
--- a/test_models.py
+++ b/test_models.py
@@ -84,12 +84,12 @@ def main(unused_args):
img_list = map(str.strip, open(img_set_file).readlines())
# have visual feature ready
- FLAGS.vf_dir = os.path.join(rootpath, test_collection, 'FeatureData', feature)
- vf_reader = BigFile( FLAGS.vf_dir )
+ vf_dir = utility.get_feat_dir(test_collection, feature, rootpath)
+ vf_reader = BigFile( vf_dir )
textbank = TextBank(utility.get_train_vocab_file(FLAGS))
config.vocab_size = len(textbank.vocab)
- config.vf_size = int(open(os.path.join(FLAGS.vf_dir, 'shape.txt')).read().split()[1])
+ config.vf_size = int(open(os.path.join(vf_dir, 'shape.txt')).read().split()[1])
model_dir = utility.get_model_dir(FLAGS)
output_dir = utility.get_pred_dir(FLAGS)
diff --git a/trainer.py b/trainer.py
index 755f268..2977840 100644
--- a/trainer.py
+++ b/trainer.py
@@ -125,12 +125,16 @@ def main(unused_args):
# Load model configuration
config_path = os.path.join(os.path.dirname(__file__), 'model_conf', FLAGS.model_name + '.py')
config = utility.load_config(config_path)
-
- FLAGS.vf_dir = os.path.join(FLAGS.rootpath, FLAGS.train_collection, 'FeatureData', FLAGS.vf_name)
- vocab_file = utility.get_vocab_file(FLAGS.train_collection, FLAGS.word_cnt_thr, FLAGS.rootpath)
+
+ rootpath = FLAGS.rootpath
+ train_collection = FLAGS.train_collection
+ feature = FLAGS.vf_name
+
+ vf_dir = utility.get_feat_dir(train_collection, feature, rootpath)
+ vocab_file = utility.get_vocab_file(train_collection, FLAGS.word_cnt_thr, rootpath)
textbank = TextBank(vocab_file)
config.vocab_size = len(textbank.vocab)
- config.vf_size = int(open(os.path.join(FLAGS.vf_dir, 'shape.txt')).read().split()[1])
+ config.vf_size = int(open(os.path.join(vf_dir, 'shape.txt')).read().split()[1])
if hasattr(config,'num_epoch_save'):
num_epoch_save = config.num_epoch_save