
Refactor and benchmark #662

Merged · 34 commits merged into huggingface:main on Sep 13, 2023
Conversation

@vwxyzjn (Contributor) commented Aug 18, 2023

This PR does a few refactors. Some raw thoughts:

  • For better experiment tracking, we should also create a few more variables (see the sketch after this list):
    • query_dataset="imdb"
    • reward_model="sentiment-analysis:lvwerra/distilbert-imdb"; it could work with either a pipeline or a trained reward model
  • By default, we probably should use a vanilla model like gpt2 in place of lvwerra/gpt2-imdb
  • By default, we should demonstrate end-to-end training (reward model training followed by policy training).
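As a rough illustration, the proposed tracking variables could look like the config fields below. This is a hypothetical sketch only; the field names follow the bullets above and are not the final PPOConfig API.

from dataclasses import dataclass


@dataclass
class TrackingArgs:
    # dataset the queries are sampled from
    query_dataset: str = "imdb"
    # "<pipeline-task>:<model-id>", so it can point at either a transformers
    # pipeline or a trained reward model
    reward_model: str = "sentiment-analysis:lvwerra/distilbert-imdb"
    # base model to fine-tune; a vanilla model such as gpt2 by default
    model_name: str = "gpt2"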

We have multiple benchmark axes:

  • w/ different models (gpt2, gpt2-xl, falcon, llama2)
    • key research engineering questions
      • how do different model sizes scale?
      • given that the preference labels come from a source model M_s (e.g., gpt2), how does that affect the performance of a target model M_t (e.g., falcon, gptj, llama2)?
        • This is actually an important assumption we have been operating under.
  • w/ and w/o gradient accumulation / multi-GPU
    • key research engineering question: do we need to whiten advantages across the entire batch? (see the sketch after this list)
  • w/ and w/o peft
    • key research engineering question: how well does PEFT work with RL?
  • w/ and w/o quantization (4-bit)
    • key research engineering question: how well does quantization work with RL training?
  • w/ and w/o deepspeed
    • sanity check to make sure it works.
  • w/ different datasets
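To make the whitening question above concrete, here is a minimal sketch, assuming PyTorch and accelerate; the helper below is illustrative and not the exact TRL implementation.

import torch


def whiten(values: torch.Tensor) -> torch.Tensor:
    # normalize to zero mean / unit variance
    mean, var = values.mean(), values.var()
    return (values - mean) * torch.rsqrt(var + 1e-8)


# Per-process whitening uses only the local mini-batch statistics:
#   advantages = whiten(advantages)
# Whitening "across the entire batch" would first gather the advantages from
# all processes (e.g. gathered = accelerator.gather(advantages)) so that every
# worker normalizes with the same global statistics. Whether that difference
# matters in practice is exactly the benchmark question above.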

We can probably have a train.py that can do something like:

accelerate launch train.py --config deepspeed.yaml \
    --model_name falcon-40b \
    --query_dataset book \
    --label_dataset openai-sentiment-label \
    --gradient_accumulation_steps=12 \

Uses tyro to eliminate duplicate code / comments.

Help text also works:

[image: generated --help output]

tyro:
from dataclasses import dataclass, field

import tyro
from trl import PPOConfig


@dataclass
class ScriptArguments:
    ppo: PPOConfig = field(
        default_factory=lambda: PPOConfig(
            model_name="lvwerra/gpt2-imdb",
            learning_rate=1.41e-5,
            log_with=None,
            mini_batch_size=128,
            batch_size=128,
            gradient_accumulation_steps=1,
            early_stopping=False,
            target_kl=6,
            kl_penalty="kl",
            seed=0,
        )
    )
args = tyro.cli(ScriptArguments)

print(args.ppo.seed)

existing (HfArgumentParser):
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser
from trl import PPOConfig


@dataclass
class ScriptArguments:
    """
    The name of the Causal LM model we wish to fine-tune with PPO
    """

    # NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
    # models like gpt-neo* models are more suitable.
    model_name: Optional[str] = field(default="lvwerra/gpt2-imdb", metadata={"help": "the model name"})
    log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
    mini_batch_size: Optional[int] = field(default=128, metadata={"help": "the PPO minibatch size"})
    batch_size: Optional[int] = field(default=128, metadata={"help": "the batch size"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
    use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use peft"})
    use_seq2seq: Optional[bool] = field(default=False, metadata={"help": "whether to use seq2seq models"})
    kl_penalty: Optional[str] = field(
        default="kl",
        metadata={
            "help": "kl penalty options: 'kl': model_logp - ref_logp,  'abs': abs(kl),  'mse': mean squared error mse(kl) and 'full': the actual kl for all tokens in the distribution"
        },
    )
    target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
    seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
    use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
    use_score_norm: Optional[bool] = field(
        default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"}
    )
    score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

config = PPOConfig(
    model_name=script_args.model_name,
    learning_rate=script_args.learning_rate,
    log_with=script_args.log_with,
    mini_batch_size=script_args.mini_batch_size,
    batch_size=script_args.batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    early_stopping=script_args.early_stopping,
    target_kl=script_args.target_kl,
    kl_penalty=script_args.kl_penalty,
    seed=script_args.seed,
    use_score_scaling=script_args.use_score_scaling,
    use_score_norm=script_args.use_score_norm,
    score_clip=script_args.score_clip,
)

Other changes:

  • more controlled terminology and tracking config
  • add accelerate logging
  • log global_backward_batch_size, global_batch_size, world_size (a sketch of how these could be derived follows)
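A minimal sketch of how those logged quantities could be derived from the per-process settings; the exact attribute names in the final config are assumptions here, and the example values simply match the defaults above.

from accelerate import Accelerator

accelerator = Accelerator()
world_size = accelerator.num_processes

# example per-process settings (matching the defaults above)
batch_size = 128
mini_batch_size = 128
gradient_accumulation_steps = 1

# effective sizes across all processes
backward_batch_size = mini_batch_size * gradient_accumulation_steps
global_backward_batch_size = backward_batch_size * world_size
global_batch_size = batch_size * world_size

accelerator.print(
    f"world_size={world_size} "
    f"global_backward_batch_size={global_backward_batch_size} "
    f"global_batch_size={global_batch_size}"
)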

@HuggingFaceDocBuilderDev commented Aug 18, 2023

The documentation is not available anymore as the PR was closed or merged.

vwxyzjn marked this pull request as ready for review on September 6, 2023 14:08
@vwxyzjn (Contributor, Author) commented Sep 6, 2023

Benchmark and documentation are ready at https://github.com/vwxyzjn/trl/blob/refactor-benchmark/benchmark/README.md. We can probably tune some of these models further for better performance in follow-up PRs, including testing the deepspeed integration. cc @lewtun

[benchmark charts]
  • main
  • with gradient accumulation
  • with different models
  • with peft

@lvwerra (Member) left a comment
Hi @vwxyzjn, this looks really good to me. The openrlbenchmark code itself is a bit obscure to me, so it would be great if we could document well how it works. I left some comments here and there, but in general I am happy to add it. Since this is not strictly a part of the library and more a part of our test suite, we can also be a bit more experimental here :)

Regarding tyro: this looks good to me in general. I would also love to hear @younesbelkada's feedback, as he may also know about the design decisions around the CLI in transformers.

Review threads (resolved):
  • benchmark/plot.sh (outdated)
  • benchmark/upload_benchmark.py (outdated)
  • trl/trainer/ppo_config.py
  • benchmark/plot.sh
  • examples/scripts/sentiment_tuning.py
  • benchmark/README.md (outdated)
  • benchmark/plot.sh
@vwxyzjn (Contributor, Author) commented Sep 8, 2023

Thanks @lvwerra! I have addressed the comments :)

@younesbelkada (Contributor) left a comment

Thanks a lot for this great work, as discussed offline :D!
Feel free to merge once the CI is green.

vwxyzjn merged commit e4f9a48 into huggingface:main on Sep 13, 2023
kushal-tri pushed a commit to kushalarora/trl that referenced this pull request Sep 19, 2023
* refactor and benchmark

* update code

* Add accelerate logging

* logs

* quick fix

* update config

* precommit

* modify training example

* fix multi-gpu all_reduce error `Tensors must be CUDA and dense`

* support more models and benchmark

* update

* add changes

* upload benchmark

* precommit

* add tyro as a dependency

* add tyro

* pre-commit

* precommit

* weird...

* lol typo

* precommit

* sigh

* push changes

* Update benchmark/README.md

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Add experiments

* upload image to tag specific folder

* add openrlbenchmark documentation

* rename

* remove unused field

* precommit

* push changes

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
Labels: none
Projects: none
4 participants