Fix masking of response tokens (huggingface#1718)
The current handling of `response_masks` inside the `batched_forward_pass`
function does not take padding into consideration, which results in a
shape mismatch during masking. Since the response mask covers only the
response tokens, it should not be concatenated with
`torch.zeros(query_length)`, and the masking operation should be done
without slicing.
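
To make the mismatch concrete, here is a minimal standalone sketch (made-up sizes and simplified single-sample names, not the actual trainer code) showing how the old concatenation breaks once left padding shifts `start`:

```python
import torch

# Hypothetical sizes: a 4-token query, a 5-token response, 3 left-padding tokens.
query = torch.ones(4, dtype=torch.long)
response = torch.ones(5, dtype=torch.long)
response_mask = torch.ones(5, dtype=torch.long)      # one entry per response token
attention_mask = torch.tensor([0, 0, 0] + [1] * 9)   # padded sequence of length 12
masks = attention_mask.clone()

# Old behavior: pad the response mask with zeros for the query, then drop one entry.
padded_mask = torch.cat((torch.zeros_like(query), response_mask))[1:]  # length 4 + 5 - 1 = 8

start = len(query) - 1                               # 3
if attention_mask[0] == 0:                           # offset left padding
    start += int(attention_mask.nonzero()[0])        # start = 3 + 3 = 6
end = start + len(response)                          # end = 11

# padded_mask[6:11] has only 2 elements, masks[6:11] has 5 -> shape mismatch.
masks[start:end] = masks[start:end] * padded_mask[start:end]  # raises RuntimeError
```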

Remove the concatenation of the response mask, and remove the slicing of
the response mask: the response mask already has length `end - start`,
which equals the length of `masks[j, start:end]`.
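
For comparison, a minimal sketch of the corrected logic under the same made-up sizes: the response mask is applied to the `start:end` slice of `masks` directly, so the shapes always line up regardless of padding:

```python
import torch

# Same hypothetical sizes as above: 4-token query, 5-token response, 3 left-padding tokens.
query = torch.ones(4, dtype=torch.long)
response = torch.ones(5, dtype=torch.long)
response_mask = torch.tensor([1, 1, 0, 1, 1])        # mask over the response tokens only
attention_mask = torch.tensor([0, 0, 0] + [1] * 9)
masks = attention_mask.clone()

start = len(query) - 1
if attention_mask[0] == 0:                           # offset left padding
    start += int(attention_mask.nonzero()[0])
end = start + len(response)                          # masks[start:end] has length end - start = 5

masks[:start] = 0
masks[end:] = 0
masks[start:end] = masks[start:end] * response_mask  # no concatenation, no slicing needed

print(masks)  # tensor([0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0])
```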
mertsayar8 authored Jun 20, 2024
1 parent ba6abee commit 3bf9449
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions trl/trainer/ppo_trainer.py
@@ -614,7 +614,7 @@ def _step_safety_checker(
             scores (List[`torch.FloatTensor`]):
                 List of tensors containing the scores.
             masks (List[`torch.LongTensor`], *optional*):
-                list of optional tensors containing the masks of shape (`query_length` + `response_length`)
+                list of optional tensors containing the masks of shape (`response_length`)
         Returns:
             `tuple`: The input processed data.
         """
@@ -1033,15 +1033,11 @@ def batched_forward_pass(
                     if attention_mask[j, 0] == 0:  # offset left padding
                         start += attention_mask[j, :].nonzero()[0]
                     end = start + len(response_batch[j])
-                    if response_masks is not None:
-                        response_masks_batch[j] = torch.cat(
-                            (torch.zeros_like(query_batch[j]), response_masks_batch[j])
-                        )[1:]

                 masks[j, :start] = 0
                 masks[j, end:] = 0
                 if response_masks is not None:
-                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
+                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j]

             if return_logits:
                 all_logits.append(logits)