From a69913403d344e386c1a4e604e19d7def92f29cb Mon Sep 17 00:00:00 2001 From: Qznan Date: Wed, 28 Oct 2020 16:52:21 +0800 Subject: [PATCH] refactor beam_search --- qiznlp/common/modules/beam_search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qiznlp/common/modules/beam_search.py b/qiznlp/common/modules/beam_search.py index 3703222..6c1b587 100644 --- a/qiznlp/common/modules/beam_search.py +++ b/qiznlp/common/modules/beam_search.py @@ -157,7 +157,7 @@ def compute_batch_and_group_indices(batch_size, group_size, beam_size): def beam_search(symbols_to_logits_fn, initial_ids, # int32 beam_size, - decode_length, + max_decode_len, vocab_size, alpha=0, states=None, @@ -310,7 +310,7 @@ def other_step(): def loop_condition(i, unused_seq, unused_log_probs, finished_flags, unused_states): not_all_finish = tf.logical_not(tf.reduce_all(finished_flags)) - less_than_maxlength = tf.less(i, decode_length) + less_than_maxlength = tf.less(i, max_decode_len) return tf.logical_and(not_all_finish, less_than_maxlength) # start loop