Skip to content

Commit

Permalink
Fix bugs in 1) getting "sent_score_file"; 2) getting "vf_dir".
Browse files Browse the repository at this point in the history
  • Loading branch information
weiyuk committed Apr 13, 2018
1 parent e837606 commit b364fcd
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 13 deletions.
6 changes: 2 additions & 4 deletions doit/do_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,20 @@ done
# --------------------------------------------

python ../generate_vocab.py $train_collection --language $lang --rootpath $rootpath
python ../trainer.py --model_name $model_name --train_collection $train_collection --language $lang --vf_name $vf --overwrite $overwrite --rootpath $rootpath --fluency_method $fluency_method --sent_score_file $train_sent_score_file
python ../trainer.py --model_name $model_name --train_collection $train_collection --language $lang --vf_name $vf --overwrite $overwrite --rootpath $rootpath --fluency_method $fluency_method

# --------------------------------------------
# 2. validation
# --------------------------------------------

python ../compute_val_loss.py --train_collection $train_collection --val_collection $val_collection --model_name $model_name --vf_name $vf --language $lang --overwrite $overwrite --rootpath $rootpath --fluency_method $fluency_method --sent_score_file $val_sent_score_file
python ../compute_val_loss.py --train_collection $train_collection --val_collection $val_collection --model_name $model_name --vf_name $vf --language $lang --overwrite $overwrite --rootpath $rootpath --fluency_method $fluency_method

# --------------------------------------------
# 3. test
# --------------------------------------------

top_k=1
beam_size=5
#beam_size=20
#beam_size=10
length_normalization_factor=0.5

python ../test_models.py --train_collection $train_collection --val_collection $val_collection --test_collection $test_collection --model_name $model_name --vf_name $vf --top_k $top_k --length_normalization_factor $length_normalization_factor --beam_size $beam_size --overwrite $overwrite --rootpath $rootpath --fluency_method $fluency_method
Expand Down
4 changes: 2 additions & 2 deletions sampled_data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ def __init__(self, collection, vocab_file, feature, language,
self.fluency_threshold = fluency_threshold
self.method = method
if method:
self.sent_score_file = utility.get_sent_score_file(collection, language, rootpath)
assert method in ['sample','filter','weighted']
assert self.sent_score_file != None
assert fluency_threshold>0
assert fluency_threshold > 0
if method == 'weighted':
# Not sampling the data if fluency-guided method is weighted_loss
self.method = method = None
self.sent_score_file = utility.get_sent_score_file(collection, language, rootpath)
else:
self.sent_score_file = None

Expand Down
6 changes: 3 additions & 3 deletions test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ def main(unused_args):
img_list = map(str.strip, open(img_set_file).readlines())

# have visual feature ready
FLAGS.vf_dir = os.path.join(rootpath, test_collection, 'FeatureData', feature)
vf_reader = BigFile( FLAGS.vf_dir )
vf_dir = utility.get_feat_dir(test_collection, feature, rootpath)
vf_reader = BigFile( vf_dir )

textbank = TextBank(utility.get_train_vocab_file(FLAGS))
config.vocab_size = len(textbank.vocab)
config.vf_size = int(open(os.path.join(FLAGS.vf_dir, 'shape.txt')).read().split()[1])
config.vf_size = int(open(os.path.join(vf_dir, 'shape.txt')).read().split()[1])

model_dir = utility.get_model_dir(FLAGS)
output_dir = utility.get_pred_dir(FLAGS)
Expand Down
12 changes: 8 additions & 4 deletions trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,16 @@ def main(unused_args):
# Load model configuration
config_path = os.path.join(os.path.dirname(__file__), 'model_conf', FLAGS.model_name + '.py')
config = utility.load_config(config_path)

FLAGS.vf_dir = os.path.join(FLAGS.rootpath, FLAGS.train_collection, 'FeatureData', FLAGS.vf_name)
vocab_file = utility.get_vocab_file(FLAGS.train_collection, FLAGS.word_cnt_thr, FLAGS.rootpath)

rootpath = FLAGS.rootpath
train_collection = FLAGS.train_collection
feature = FLAGS.vf_name

vf_dir = utility.get_feat_dir(train_collection, feature, rootpath)
vocab_file = utility.get_vocab_file(train_collection, FLAGS.word_cnt_thr, rootpath)
textbank = TextBank(vocab_file)
config.vocab_size = len(textbank.vocab)
config.vf_size = int(open(os.path.join(FLAGS.vf_dir, 'shape.txt')).read().split()[1])
config.vf_size = int(open(os.path.join(vf_dir, 'shape.txt')).read().split()[1])

if hasattr(config,'num_epoch_save'):
num_epoch_save = config.num_epoch_save
Expand Down

0 comments on commit b364fcd

Please sign in to comment.