
Represent query_length in a different way to solve jit issue (#25164)
Fix jit trace
jiqing-feng authored Jul 28, 2023
1 parent 2a78720 commit d23d2c2
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions src/transformers/models/mpt/modeling_mpt.py
@@ -154,9 +154,7 @@ def forward(

         attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale

-        query_length = seq_length
-        if past_key_value is not None:
-            query_length += past_key_value[0].shape[2]
+        query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]

         if position_bias is not None:
             if len(position_bias.shape) != 3:
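The patch collapses the incremental `query_length` computation into a single conditional expression, so the traced graph sees one assignment rather than an in-place update guarded by a branch. A minimal sketch of the new logic, using a hypothetical standalone helper (`compute_query_length` is not part of the transformers code; the tensor shapes below are illustrative):

```python
import torch

def compute_query_length(seq_length, past_key_value=None):
    # past_key_value, when present, is assumed to be a (key, value) pair of
    # tensors shaped (batch, num_heads, kv_seq_len, head_dim); dim 2 holds
    # the number of cached positions, mirroring the patched expression.
    return seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]

# Without a cache, query_length is just the current sequence length.
print(compute_query_length(5))  # 5

# With 3 cached positions, the cached length is added on.
past = (torch.zeros(1, 8, 3, 64), torch.zeros(1, 8, 3, 64))
print(compute_query_length(5, past))  # 8
```

Both branches still exist at the Python level, but expressing the result as one assignment avoids mutating a value that `torch.jit.trace` has already recorded, which is the behavior the commit message refers to.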
