Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
weiyu committed Sep 11, 2017
1 parent 4e827e0 commit 0f6bdd7
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion compute_val_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def main(unused_args):
language=FLAGS.language,
flag_shuffle=False,
method=config.fluency_method,
rootpath=rootpath)
rootpath=FLAGS.rootpath)
iter2loss = {}
for iter_n, model_path in model_path_list:
loss_file = os.path.join(output_dir, 'model_%d.ckpt' % iter_n, 'loss.txt')
Expand Down
2 changes: 1 addition & 1 deletion sampled_data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _load_data(self, verbose=True):
# Encode sentences
tokens = TextTool.tokenize(sent, self.language)
data['sentence'] = self.textbank.encode_tokens(tokens, flag_add_bos=False)
data['sent_score'] = sid2score[sid] if self.sent_score_file else 1
data['sent_score'] = sid2score[sid] if self.sent_score_file and sid in sid2score else 1
self._data_queue.append(data)
if verbose and (ind_a + 1) % 20000 == 0:
logger.debug('%d/%d annotation', ind_a + 1, len(annos))
Expand Down
3 changes: 2 additions & 1 deletion trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ def main(unused_args):

iters_done = 0
data_provider = BucketDataProvider(FLAGS.train_collection, vocab_file, FLAGS.vf_name,
language=FLAGS.language, method=config.fluency_method, rootpath=rootpath)
language=FLAGS.language, method=config.fluency_method,
rootpath=FLAGS.rootpath)

for i in range(config.num_epoch):
logger.info('epoch %d', i)
Expand Down

0 comments on commit 0f6bdd7

Please sign in to comment.