[core] refactor step method (huggingface#76)
* refactor `step` method

- add safety checker + manual device assignment

* cleanup

* isort

* more generic

* make style

* Update trl/trainer/ppo_trainer.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* make style + add tests

* more tests

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
younesbelkada and lvwerra authored Jan 5, 2023
1 parent 80985f8 commit d6fe301
Showing 4 changed files with 96 additions and 4 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -97,10 +97,10 @@ device = ppo_trainer.accelerator.device

# define a reward for response
# (this could be any reward such as human feedback or output from another model)
-reward = [torch.tensor(1.0).to(device)]
+reward = [torch.tensor(1.0)]

# train model for one step with ppo
-train_stats = ppo_trainer.step([query_tensor[0].to(device)], [response_tensor[0].to(device)], reward)
+train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```
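With this refactor, `PPOTrainer.step` runs a safety checker that moves queries, responses, and rewards to `ppo_trainer.accelerator.device` and validates their shapes, so the explicit `.to(device)` calls above are no longer needed. A minimal sketch of the resulting behavior, reusing `ppo_trainer`, `query_tensor`, and `response_tensor` from the example above:

```python
import torch

# rewards can stay on CPU; the safety checker moves them to the trainer's device
reward = [torch.tensor(1.0)]       # 0-D scalar reward: accepted
reward_1d = [torch.tensor([1.0])]  # 1-D reward: accepted and squeezed to a scalar

# a 2-D reward is rejected with a ValueError by the new shape check
try:
    ppo_trainer.step([query_tensor[0]], [response_tensor[0]], [torch.tensor([[1.0]])])
except ValueError as err:
    print(err)

# valid call - no manual device placement required
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```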

### Advanced example: IMDB sentiment
2 changes: 1 addition & 1 deletion examples/scripts/ppo-sentiment.py
@@ -143,7 +143,7 @@ def collater(data):
#### Compute sentiment score
texts = [q + r for q,r in zip(batch['query'], batch['response'])]
pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
-rewards = [torch.tensor(output[1]["score"]).to(device) for output in pipe_outputs]
+rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

#### Run PPO step
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
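For reference, `sent_kwargs` in this script presumably asks the sentiment pipeline for all class scores, so each `output` is a list of `{label, score}` dicts and `output[1]` is the positive-class entry; after this commit the reward tensors can be left on CPU because `step` handles device placement. A hedged sketch of the reward construction:

```python
import torch

# hypothetical pipeline output for a single text, assuming all class scores are returned
pipe_output = [{"label": "NEGATIVE", "score": 0.12}, {"label": "POSITIVE", "score": 0.88}]

# the positive-class score becomes the scalar reward; no .to(device) is needed anymore
reward = torch.tensor(pipe_output[1]["score"])
print(reward)  # tensor(0.8800)
```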
75 changes: 75 additions & 0 deletions tests/test_ppo_trainer.py
@@ -230,3 +230,78 @@ def test_ppo_step_with_ref_and_custom_layers_warning(self):
                dataset=dummy_dataset,
                num_shared_layers=num_shared_layers,
            )

    def test_ppo_step_rewards_shape(self):
        """
        Test that the reward shape is validated: passing rewards with the wrong shape
        should raise a ValueError.
        """

        # initialize dataset
        dummy_dataset = self._init_dummy_dataset()

        ppo_trainer = PPOTrainer(
            config=self.ppo_config,
            model=self.gpt2_model,
            ref_model=None,
            tokenizer=self.gpt2_tokenizer,
            dataset=dummy_dataset,
        )
        dummy_dataloader = ppo_trainer.dataloader
        # train model with ppo
        for query_tensor, response_tensor in dummy_dataloader:
            # define a reward for response
            # (this could be any reward such as human feedback or output from another model)
            reward = [torch.tensor([[1.0]]), torch.tensor([[0.0]])]
            # train model - this should raise an error
            with self.assertRaises(ValueError):
                _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)

            reward = [torch.tensor([1.0]), torch.tensor([0.0])]
            # train model - this should work
            _ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
            break

        # check if the gradients are computed for the model
        for name, param in ppo_trainer.model.named_parameters():
            self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")

        # ref model should not be trained
        for name, param in ppo_trainer.ref_model.named_parameters():
            self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")

    def test_ppo_step_input_shape(self):
        """
        Test that the shapes of the expected inputs are correct.
        """
        # initialize dataset
        dummy_dataset = self._init_dummy_dataset()

        ppo_trainer = PPOTrainer(
            config=self.ppo_config,
            model=self.gpt2_model,
            ref_model=None,
            tokenizer=self.gpt2_tokenizer,
            dataset=dummy_dataset,
        )
        dummy_dataloader = ppo_trainer.dataloader
        # train model with ppo
        for query_tensor, response_tensor in dummy_dataloader:
            # define a reward for response
            # (this could be any reward such as human feedback or output from another model)
            reward = [torch.tensor([1.0]), torch.tensor([0.0])]
            # run the step safety checker on the raw inputs
            bs = ppo_trainer.config.batch_size

            queries, responses, _ = ppo_trainer._step_safety_checker(
                bs, [q for q in query_tensor], [r for r in response_tensor], reward
            )

            self.assertTrue(isinstance(queries, list), f"queries should be a list, got {type(queries)}")
            self.assertTrue(isinstance(responses, list), f"responses should be a list, got {type(responses)}")

            # check the shapes
            for i in range(bs):
                self.assertEqual(queries[i].shape, torch.Size([7]))
                self.assertEqual(responses[i].size(), torch.Size([7]))
            break
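The two new tests can presumably be run in isolation with pytest from the repository root, for example:

```python
import pytest

# run only the tests added in this commit (hypothetical local invocation)
pytest.main(["tests/test_ppo_trainer.py", "-k", "rewards_shape or input_shape", "-q"])
```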
19 changes: 18 additions & 1 deletion trl/trainer/ppo_trainer.py
@@ -241,6 +241,9 @@ def _step_safety_checker(
                List of tensors containing the encoded responses of shape (`response_length`)
            scores (List[`torch.FloatTensor`]):
                List of tensors containing the scores.
        Returns:
            queries, responses, scores (List[`torch.LongTensor`], List[`torch.LongTensor`], List[`torch.FloatTensor`]):
                The processed input data.
        """
        for name, tensor_list in zip(["queries", "responses", "scores"], [queries, responses, scores]):
            if not isinstance(tensor_list, list):
@@ -252,6 +255,20 @@
                    f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}"
                )

        # move queries, responses and scores to the correct device
        queries = [tensor.to(self.accelerator.device) for tensor in queries]
        responses = [tensor.to(self.accelerator.device) for tensor in responses]
        scores = [tensor.to(self.accelerator.device) for tensor in scores]

        # squeeze scores if needed
        for i, score in enumerate(scores):
            if score.dim() > 1:
                raise ValueError(f"Scores must be 1-dimensional - got {score.dim()} for {score}")
            elif score.dim() == 1:
                scores[i] = score.squeeze()

        return queries, responses, scores

    def step(
        self,
        queries: List[torch.LongTensor],
@@ -276,7 +293,7 @@ def step(

        bs = self.config.batch_size

-        self._step_safety_checker(bs, queries, responses, scores)
+        queries, responses, scores = self._step_safety_checker(bs, queries, responses, scores)

        timing = dict()
        t0 = time.time()
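Taken together, the refactor moves device handling and score validation out of user code and into `_step_safety_checker`, whose return values `step` now consumes. A standalone sketch that mirrors the added logic (an illustration only, not the trainer's actual method):

```python
from typing import List

import torch


def process_scores(scores: List[torch.Tensor], device: str) -> List[torch.Tensor]:
    # mirror of the added safety-checker logic: move scores to the target device,
    # reject anything with more than one dimension, and squeeze 1-D scores
    scores = [score.to(device) for score in scores]
    for i, score in enumerate(scores):
        if score.dim() > 1:
            raise ValueError(f"Scores must be 1-dimensional - got {score.dim()} for {score}")
        elif score.dim() == 1:
            scores[i] = score.squeeze()
    return scores


print(process_scores([torch.tensor([0.5]), torch.tensor(1.0)], "cpu"))
# [tensor(0.5000), tensor(1.)]
```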
