Fix ppov2 test case (huggingface#1661)
* Fix PPOv2 / RLOO refactor's stuff

* update terminology to use stop token
vwxyzjn authored May 23, 2024
1 parent bc8dfbf commit e7cb597
Showing 13 changed files with 47 additions and 49 deletions.
8 changes: 4 additions & 4 deletions docs/source/ppov2_trainer.md
@@ -13,7 +13,7 @@ References:
To just run a PPO script to make sure the trainer can run, you can run the following command to train a PPO model with a dummy reward model.

```bash
- python -i examples/scripts/minimal/rloo.py \
+ python -i examples/scripts/ppo/ppo.py \
--learning_rate 3e-6 \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
@@ -55,7 +55,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high (like 2.0 or 1000.0) or too small (like 0.1), it means the updates between consecutive policies are too drastic. You should try to understand why this is happening and fix it.
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
- * Usage TIP: We recommend using the "EOS trick" via `--non_eos_penalty --truncate_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions.
+ * Usage TIP: We recommend using the "EOS trick" via `--non_eos_penalty --stop_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions.


## What is my model doing exactly?
@@ -186,7 +186,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
# 6.9B PPO experiment
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
@@ -201,7 +201,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml
--reward_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__reward__tldr \
--local_rollout_forward_batch_size 2 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
```

1B experiment can be found here:
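Note on the "EOS trick" tip above: it replaces the reward-model score of any completion that never emits the EOS token with the static `--penalty_reward_value`. The snippet below is a minimal sketch of that scoring step, assuming PyTorch tensors and a helper name (`apply_eos_penalty`) chosen for illustration; it is not the trainer's exact code.

```python
import torch

def apply_eos_penalty(postprocessed_responses, scores, eos_token_id, penalty_reward_value=-1.0):
    # A completion counts as well-formed only if the EOS token appears somewhere in it.
    contain_eos_token = torch.any(postprocessed_responses == eos_token_id, dim=-1)
    # Keep the reward model's score for well-formed completions; replace the rest
    # with the static penalty supplied via --penalty_reward_value.
    return torch.where(contain_eos_token, scores, torch.full_like(scores, penalty_reward_value))

# Example: the second completion never emits EOS (id 0), so its score becomes the penalty.
responses = torch.tensor([[5, 8, 0, 0], [5, 8, 9, 7]])
scores = torch.tensor([1.3, 2.1])
print(apply_eos_penalty(responses, scores, eos_token_id=0))  # tensor([ 1.3000, -1.0000])
```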
8 changes: 4 additions & 4 deletions docs/source/rloo_trainer.md
@@ -15,7 +15,7 @@ References:
To just run a RLOO script to make sure the trainer can run, you can run the following command to train a RLOO model with a dummy reward model.

```bash
- python examples/scripts/minimal/rloo.py \
+ python examples/scripts/rloo/rloo.py \
--learning_rate 3e-6 \
--output_dir models/minimal/rloo \
--per_device_train_batch_size 64 \
@@ -57,7 +57,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high (like 2.0 or 1000.0) or too small (like 0.1), it means the updates between consecutive policies are too drastic. You should try to understand why this is happening and fix it.
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
- * Usage TIP: We recommend using the "EOS trick" via `--non_eos_penalty --truncate_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions.
+ * Usage TIP: We recommend using the "EOS trick" via `--non_eos_penalty --stop_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions.


## What is my model doing exactly?
@@ -226,7 +226,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
--kl_coef 0.03
# 6.9B RLOO experiment
@@ -244,7 +244,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml
--reward_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__reward__tldr \
--local_rollout_forward_batch_size 2 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
--kl_coef 0.03
```

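Both trainer docs advise watching `val/ratio`, which should hover around 1.0 because PPO's surrogate loss clips it at `--cliprange 0.2`. As a reference for what that metric measures, here is the standard clipped-surrogate computation written from the generic PPO formulation; the variable names are illustrative and this is not copied from the trainer.

```python
import torch

def ppo_clipped_policy_loss(new_logprobs, old_logprobs, advantages, cliprange=0.2):
    # ratio = pi_new(a|s) / pi_old(a|s), computed in log space for numerical stability.
    # This quantity is what `val/ratio` tracks; values far from 1.0 (e.g. 2.0 or 0.1)
    # mean consecutive policies diverged sharply.
    ratio = torch.exp(new_logprobs - old_logprobs)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    # The element-wise max of the two negated terms gives the pessimistic (clipped) objective.
    return torch.max(unclipped, clipped).mean()
```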
4 changes: 2 additions & 2 deletions examples/scripts/ppo/ppo.py
@@ -14,7 +14,7 @@


"""
- python -i examples/scripts/minimal/ppo.py \
+ python -i examples/scripts/ppo/ppo.py \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 64 \
@@ -24,7 +24,7 @@
--non_eos_penalty \
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
- examples/scripts/minimal/ppo.py \
+ examples/scripts/ppo/ppo.py \
--output_dir models/minimal/ppo \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
4 changes: 2 additions & 2 deletions examples/scripts/ppo/ppo_tldr.py
@@ -25,7 +25,7 @@
--sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
--response_length 53 \
--sanity_check
@@ -41,7 +41,7 @@
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
"""


8 changes: 4 additions & 4 deletions examples/scripts/ppo/ppo_zephyr.py
@@ -15,7 +15,7 @@


"""
- python -i examples/scripts/minimal/ppo_zephyr.py \
+ python -i examples/scripts/ppo/ppo_zephyr.py \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 1 \
@@ -25,10 +25,10 @@
--sft_model_path EleutherAI/pythia-1b-deduped \
--reward_model_path EleutherAI/pythia-1b-deduped \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
- examples/scripts/minimal/ppo_zephyr.py \
+ examples/scripts/ppo/ppo_zephyr.py \
--output_dir models/minimal/ppo_zephyr10 \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
@@ -43,7 +43,7 @@
--deepspeed3 \
--kl_coef 0.10 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
--response_length 512 \
"""

4 changes: 2 additions & 2 deletions examples/scripts/rloo/rloo.py
@@ -14,7 +14,7 @@


"""
- python -i examples/scripts/minimal/rloo.py \
+ python -i examples/scripts/rloo/rloo.py \
--learning_rate 3e-6 \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
@@ -25,7 +25,7 @@
--model_name_or_path EleutherAI/pythia-1b-deduped \
--non_eos_penalty \
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
- examples/scripts/minimal/rloo.py \
+ examples/scripts/rloo/rloo.py \
--output_dir models/minimal/rloo \
--rloo_k 2 \
--num_ppo_epochs 1 \
4 changes: 2 additions & 2 deletions examples/scripts/rloo/rloo_tldr.py
@@ -25,7 +25,7 @@
--sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
--response_length 53 \
--sanity_check
@@ -43,7 +43,7 @@
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
"""


8 changes: 4 additions & 4 deletions examples/scripts/rloo/rloo_zephyr.py
@@ -15,7 +15,7 @@


"""
- python -i examples/scripts/minimal/rloo_zephyr.py \
+ python -i examples/scripts/rloo/rloo_zephyr.py \
--learning_rate 3e-6 \
--output_dir models/minimal/rloo_zephyr \
--per_device_train_batch_size 64 \
@@ -25,11 +25,11 @@
--sft_model_path HuggingFaceH4/mistral-7b-sft-beta \
--reward_model_path weqweasdas/RM-Mistral-7B \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
--response_length 53 \
--sanity_check
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
- examples/scripts/minimal/rloo_zephyr.py \
+ examples/scripts/rloo/rloo_zephyr.py \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
--rloo_k 2 \
@@ -45,7 +45,7 @@
--deepspeed3 \
--kl_coef 0.10 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
--response_length 512 \
"""

4 changes: 2 additions & 2 deletions tests/test_ppov2_trainer.py
@@ -16,15 +16,15 @@

def test():
command = """\
- python -i examples/scripts/minimal/ppo.py \
+ python -i examples/scripts/ppo/ppo.py \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 5 \
--gradient_accumulation_steps 1 \
--total_episodes 10 \
--model_name_or_path EleutherAI/pythia-14m \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
"""
subprocess.run(
command,
4 changes: 2 additions & 2 deletions tests/test_rloo_trainer.py
@@ -18,15 +18,15 @@

def test():
command = """\
- python -i examples/scripts/minimal/rloo.py \
+ python -i examples/scripts/rloo/rloo.py \
--learning_rate 3e-6 \
--output_dir models/minimal/rloo \
--per_device_train_batch_size 5 \
--gradient_accumulation_steps 1 \
--total_episodes 10 \
--model_name_or_path EleutherAI/pythia-14m \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
"""
subprocess.run(
command,
8 changes: 4 additions & 4 deletions trl/trainer/ppov2_trainer.py
@@ -135,7 +135,7 @@ def __init__(
#########
for module in [policy, ref_policy, value_model, reward_model]:
disable_dropout_in_model(module)
- if args.truncate_token and args.truncate_token == "eos":
+ if args.stop_token and args.stop_token == "eos":
args.stop_token_id = tokenizer.eos_token_id
self.model = PolicyAndValueWrapper(policy, value_model)
self.create_optimizer_and_scheduler(num_training_steps=args.num_updates)
@@ -285,7 +285,7 @@ def repeat_generator():
query_response, logits = generate(
unwrapped_model.policy,
query,
- tokenizer,
+ tokenizer.pad_token_id,
generation_config,
)
response = query_response[:, context_length:]
@@ -407,7 +407,7 @@ def repeat_generator():
mb_query_responses = query_responses[micro_batch_inds]
mb_logprobs = logprobs[micro_batch_inds]

- output, vpred_temp = forward(model, mb_query_responses, tokenizer)
+ output, vpred_temp = forward(model, mb_query_responses, tokenizer.pad_token_id)
logits = output.logits[:, context_length - 1 : -1]
logits /= args.temperature + 1e-7
new_all_logprobs = F.log_softmax(logits, dim=-1)
@@ -543,7 +543,7 @@ def generate_completions(self, sampling: bool = False):
query_response, _ = generate(
unwrapped_model.policy,
query,
- tokenizer,
+ tokenizer.pad_token_id,
generation_config,
)
response = query_response[:, context_length:]
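Several call sites in `ppov2_trainer.py` (and in `rloo_trainer.py` further down) now pass `tokenizer.pad_token_id` to `generate` and `forward` instead of the whole tokenizer. A plausible reading, sketched below, is that these helpers only need the pad id to build an attention mask over the packed query/response tensor; this is an illustrative re-implementation, not TRL's actual utility.

```python
import torch

def forward(model, query_responses, pad_token_id):
    # Every non-padding position attends, so the pad id alone is enough to build the mask.
    attention_mask = query_responses != pad_token_id
    # Position ids count only real tokens, so they restart correctly after left padding.
    position_ids = attention_mask.cumsum(dim=1) - attention_mask.long()
    # Replace padding ids with a harmless token id; the mask already excludes them.
    input_ids = query_responses.masked_fill(~attention_mask, 0)
    return model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
    )
```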
12 changes: 6 additions & 6 deletions trl/trainer/rloo_config.py
@@ -37,16 +37,16 @@ class RLOOConfig(OnpolicyRuntimeConfig, TrainingArguments):
"""the name of the pretrained model to use"""
response_length: int = 53
"""the length of the response"""
- truncate_token: Optional[Literal["eos"]] = None
- """the truncate token"""
- truncate_token_id: Optional[int] = None
- """the truncation token id"""
+ stop_token: Optional[Literal["eos"]] = None
+ """the stop token"""
+ stop_token_id: Optional[int] = None
+ """the stop token id"""
temperature: float = 0.7
"""the sampling temperature"""
penalty_reward_value: int = -1
- """the reward value for responses that do not contain `truncate_token_id`"""
+ """the reward value for responses that do not contain `stop_token_id`"""
non_eos_penalty: bool = False
- """whether to penalize responses that do not contain `truncate_token_id`"""
+ """whether to penalize responses that do not contain `stop_token_id`"""
reward_model_path: str = "EleutherAI/pythia-160m"
"""the path to the reward model"""
sft_model_path: str = "EleutherAI/pythia-160m"
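With the rename in place, `stop_token` and `stop_token_id` are the fields to set, and the trainer resolves the symbolic `"eos"` value to a concrete token id (see the `__init__` hunk in the next file). A rough usage sketch follows; the import path simply mirrors the file shown in this diff, and the constructor arguments beyond the renamed fields are assumptions for illustration.

```python
# Rough sketch only: RLOOConfig extends TrainingArguments, so it accepts many more
# arguments than shown here.
from transformers import AutoTokenizer
from trl.trainer.rloo_config import RLOOConfig  # path taken from the file in this diff

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b-deduped", padding_side="left")
config = RLOOConfig(
    output_dir="models/minimal/rloo",
    stop_token="eos",        # replaces the old --truncate_token eos flag
    non_eos_penalty=True,    # penalize completions that never emit the stop token
    penalty_reward_value=-1,
)

# Inside the trainer, the symbolic stop token is resolved to a concrete id:
if config.stop_token == "eos":
    config.stop_token_id = tokenizer.eos_token_id
```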
20 changes: 9 additions & 11 deletions trl/trainer/rloo_trainer.py
@@ -120,8 +120,8 @@ def __init__(
#########
for module in [policy, ref_policy, reward_model]:
disable_dropout_in_model(module)
- if args.truncate_token and args.truncate_token == "eos":
-     args.truncate_token_id = tokenizer.eos_token_id
+ if args.stop_token and args.stop_token == "eos":
+     args.stop_token_id = tokenizer.eos_token_id
self.model = policy
self.create_optimizer_and_scheduler(num_training_steps=args.num_updates)

@@ -246,7 +246,7 @@ def repeat_generator():
query_response, logits = generate(
unwrapped_model,
query,
- tokenizer,
+ tokenizer.pad_token_id,
generation_config,
)
response = query_response[:, context_length:]
@@ -265,11 +265,9 @@ def repeat_generator():
del ref_output, ref_logits, ref_all_logprob
torch.cuda.empty_cache()

- # Response Processing 1. truncate response after the first occurrence of `truncate_token_id`
+ # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
postprocessed_response = response
- if (
-     args.truncate_token_id is not None
- ): # handle the edge case when truncate_token_id exists but is 0
+ if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(
args.stop_token_id, tokenizer.pad_token_id, response
)
@@ -299,7 +297,7 @@ def repeat_generator():
torch.cuda.empty_cache()
gc.collect()

- # Response Processing 3. filter response. Ensure that the sample contains truncate_token_id
+ # Response Processing 3. filter response. Ensure that the sample contains stop_token_id
# responses not passing that filter will receive a low (fixed) score
# only query humans on responses that pass that filter
contain_eos_token = torch.any(postprocessed_responses == tokenizer.eos_token_id, dim=-1)
@@ -342,7 +340,7 @@ def repeat_generator():
mb_query_responses = query_responses[micro_batch_inds]
mb_logprobs = logprobs[micro_batch_inds]

- output = forward(model, mb_query_responses, tokenizer)
+ output = forward(model, mb_query_responses, tokenizer.pad_token_id)
logits = output.logits[:, context_length - 1 : -1]
logits /= args.temperature + 1e-7
new_all_logprobs = F.log_softmax(logits, dim=-1)
@@ -439,12 +437,12 @@ def generate_completions(self, sampling: bool = False):
query_response, _ = generate(
unwrapped_model,
query,
- tokenizer,
+ tokenizer.pad_token_id,
generation_config,
)
response = query_response[:, context_length:]
postprocessed_response = response
- if args.truncate_token_id is not None: # handle the edge case when truncate_token_id exists but is 0
+ if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(args.stop_token_id, tokenizer.pad_token_id, response)
table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True)))
table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response)))
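The `truncate_response(args.stop_token_id, tokenizer.pad_token_id, response)` calls above cut each sampled response at the first stop token, and the edge-case comments exist because a stop token id of 0 is falsy yet valid, so the code checks `is not None` rather than truthiness. The sketch below re-implements that truncation behaviour for illustration; the function name and internals are assumptions, not TRL's exact utility.

```python
import torch

def truncate_after_stop(responses, stop_token_id, pad_token_id):
    """Replace every token after the first stop token with padding (illustrative sketch)."""
    batch_size, seq_len = responses.shape
    is_stop = responses == stop_token_id
    # argmax returns 0 when a row has no stop token, so fall back to the last index there
    # (nothing gets truncated in rows that never emit the stop token).
    first_stop = torch.where(
        is_stop.any(dim=-1),
        is_stop.int().argmax(dim=-1),
        torch.full((batch_size,), seq_len - 1, dtype=torch.long),
    )
    positions = torch.arange(seq_len).unsqueeze(0)
    return responses.masked_fill(positions > first_stop.unsqueeze(-1), pad_token_id)

# Example with stop token id 0 and pad id 99: tokens after the first 0 become padding.
resp = torch.tensor([[5, 0, 7, 8], [3, 4, 5, 6]])
print(truncate_after_stop(resp, stop_token_id=0, pad_token_id=99))
# tensor([[ 5,  0, 99, 99],
#         [ 3,  4,  5,  6]])
```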
