Skip to content

Commit

Permalink
Merge pull request #2276 from ShnitzelKiller/scatterfix
Browse files Browse the repository at this point in the history
fix error due to wrong argument name to Tensor.scatter()
  • Loading branch information
thomwolf authored Dec 23, 2019
2 parents ce50305 + 398bb03 commit e4e2a66
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,9 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf")
sorted_indices_to_remove[..., 0] = 0

# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
indices_to_remove = sorted_indices_to_remove.scatter(
dim=1, index=sorted_indices, source=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits

Expand Down

0 comments on commit e4e2a66

Please sign in to comment.