From b364fcdbbc7c8416313ebdfd4bcb9fd18384e365 Mon Sep 17 00:00:00 2001
From: lanweiyu
Date: Fri, 13 Apr 2018 17:40:37 +0800
Subject: [PATCH] Fix bugs in (1) getting "sent_score_file" and (2) getting "vf_dir".
---
doit/do_all.sh | 6 ++----
sampled_data_provider.py | 4 ++--
test_models.py | 6 +++---
trainer.py | 12 ++++++++----
4 files changed, 15 insertions(+), 13 deletions(-)
diff --git a/doit/do_all.sh b/doit/do_all.sh
index 8ee6046..c4553d9 100755
--- a/doit/do_all.sh
+++ b/doit/do_all.sh
@@ -15,13 +15,13 @@ done
# --------------------------------------------
python ../generate_vocab.py $train_collection --language $lang --rootpath $rootpath
-python ../trainer.py --model_name $model_name --train_collection $train_collection --language $lang --vf_name $vf --overwrite $overwrite --rootpath $rootpath --fluency_method $fluency_method --sent_score_file $train_sent_score_file
+python ../trainer.py --model_name $model_name --train_collection $train_collection --language $lang --vf_name $vf --overwrite $overwrite --rootpath $rootpath --fluency_method $fluency_method
# --------------------------------------------
# 2. validation
# --------------------------------------------
-python ../compute_val_loss.py --train_collection $train_collection --val_collection $val_collection --model_name $model_name --vf_name $vf --language $lang --overwrite $overwrite --rootpath $rootpath --fluency_method $fluency_method --sent_score_file $val_sent_score_file
+python ../compute_val_loss.py --train_collection $train_collection --val_collection $val_collection --model_name $model_name --vf_name $vf --language $lang --overwrite $overwrite --rootpath $rootpath --fluency_method $fluency_method
# --------------------------------------------
# 3. test
@@ -29,8 +29,6 @@ python ../compute_val_loss.py --train_collection $train_collection --val_collect
top_k=1
beam_size=5
-#beam_size=20
-#beam_size=10
length_normalization_factor=0.5
python ../test_models.py --train_collection $train_collection --val_collection $val_collection --test_collection $test_collection --model_name $model_name --vf_name $vf --top_k $top_k --length_normalization_factor $length_normalization_factor --beam_size $beam_size --overwrite $overwrite --rootpath $rootpath --fluency_method $fluency_method
diff --git a/sampled_data_provider.py b/sampled_data_provider.py
index 926dac6..41de6a7 100644
--- a/sampled_data_provider.py
+++ b/sampled_data_provider.py
@@ -81,13 +81,13 @@ def __init__(self, collection, vocab_file, feature, language,
self.fluency_threshold = fluency_threshold
self.method = method
if method:
+ self.sent_score_file = utility.get_sent_score_file(collection, language, rootpath)
assert method in ['sample','filter','weighted']
assert self.sent_score_file != None
- assert fluency_threshold>0
+ assert fluency_threshold > 0
if method == 'weighted':
# Not sampling the data if fluency-guided method is weighted_loss
self.method = method = None
- self.sent_score_file = utility.get_sent_score_file(collection, language, rootpath)
else:
self.sent_score_file = None
diff --git a/test_models.py b/test_models.py
index 0f1ced2..ef2c652 100644
--- a/test_models.py
+++ b/test_models.py
@@ -84,12 +84,12 @@ def main(unused_args):
img_list = map(str.strip, open(img_set_file).readlines())
# have visual feature ready
- FLAGS.vf_dir = os.path.join(rootpath, test_collection, 'FeatureData', feature)
- vf_reader = BigFile( FLAGS.vf_dir )
+ vf_dir = utility.get_feat_dir(test_collection, feature, rootpath)
+ vf_reader = BigFile( vf_dir )
textbank = TextBank(utility.get_train_vocab_file(FLAGS))
config.vocab_size = len(textbank.vocab)
- config.vf_size = int(open(os.path.join(FLAGS.vf_dir, 'shape.txt')).read().split()[1])
+ config.vf_size = int(open(os.path.join(vf_dir, 'shape.txt')).read().split()[1])
model_dir = utility.get_model_dir(FLAGS)
output_dir = utility.get_pred_dir(FLAGS)
diff --git a/trainer.py b/trainer.py
index 755f268..2977840 100644
--- a/trainer.py
+++ b/trainer.py
@@ -125,12 +125,16 @@ def main(unused_args):
# Load model configuration
config_path = os.path.join(os.path.dirname(__file__), 'model_conf', FLAGS.model_name + '.py')
config = utility.load_config(config_path)
-
- FLAGS.vf_dir = os.path.join(FLAGS.rootpath, FLAGS.train_collection, 'FeatureData', FLAGS.vf_name)
- vocab_file = utility.get_vocab_file(FLAGS.train_collection, FLAGS.word_cnt_thr, FLAGS.rootpath)
+
+ rootpath = FLAGS.rootpath
+ train_collection = FLAGS.train_collection
+ feature = FLAGS.vf_name
+
+ vf_dir = utility.get_feat_dir(train_collection, feature, rootpath)
+ vocab_file = utility.get_vocab_file(train_collection, FLAGS.word_cnt_thr, rootpath)
textbank = TextBank(vocab_file)
config.vocab_size = len(textbank.vocab)
- config.vf_size = int(open(os.path.join(FLAGS.vf_dir, 'shape.txt')).read().split()[1])
+ config.vf_size = int(open(os.path.join(vf_dir, 'shape.txt')).read().split()[1])
if hasattr(config,'num_epoch_save'):
num_epoch_save = config.num_epoch_save