Fix masking of response tokens (huggingface#1718)
The current handling of `response_masks` inside the `batched_forward_pass` function does not take padding into consideration, which results in a shape mismatch during masking. Since the response mask is a mask tensor over the response tokens only, it should not be concatenated with a `torch.zeros(query_length)` prefix, and the masking operation should be done without slicing. This change removes the concatenation and the slicing: the response mask already has length `end - start`, which equals the length of `masks[j, start:end]`.
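For illustration, here is a minimal, self-contained sketch of the before/after behavior. The tensor shapes and the `padded` helper are hypothetical, chosen only to expose the misalignment; variable names mirror the commit message rather than the repository's exact loop:

```python
import torch

# Hypothetical shapes for illustration; not the repository's exact code.
query_len, response_len = 5, 3
masks = torch.ones(1, query_len + response_len, dtype=torch.long)
response_mask = torch.tensor([1, 0, 1], dtype=torch.long)  # covers response tokens only

j = 0
start = query_len            # index of the first response token
end = start + response_len   # one past the last response token

# Old behavior: pad the response mask with zeros over the query, then slice.
# The [1:] shift and the [start:end] slice assume a fixed query offset, so
# padding misaligns them and the shapes no longer match.
padded = torch.cat((torch.zeros(query_len, dtype=torch.long), response_mask))[1:]
# masks[j, start:end] * padded[start:end]   # length 3 vs. length 2 here -> mismatch

# Fixed behavior: the response mask already has length end - start, so it is
# applied directly, with no concatenation and no slicing.
masks[j, start:end] = masks[j, start:end] * response_mask
```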