Skip to content

Commit

Permalink
refactor beam_search
Browse files Browse the repository at this point in the history
  • Loading branch information
Qznan committed Oct 28, 2020
1 parent d8cb638 commit a699134
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions qiznlp/common/modules/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def compute_batch_and_group_indices(batch_size, group_size, beam_size):
def beam_search(symbols_to_logits_fn,
initial_ids, # int32
beam_size,
decode_length,
max_decode_len,
vocab_size,
alpha=0,
states=None,
Expand Down Expand Up @@ -310,7 +310,7 @@ def other_step():

def loop_condition(i, unused_seq, unused_log_probs, finished_flags, unused_states):
not_all_finish = tf.logical_not(tf.reduce_all(finished_flags))
less_than_maxlength = tf.less(i, decode_length)
less_than_maxlength = tf.less(i, max_decode_len)
return tf.logical_and(not_all_finish, less_than_maxlength)

# start loop
Expand Down

0 comments on commit a699134

Please sign in to comment.