[chatgpt] optimize generation kwargs (hpcaitech#2717)
* [chatgpt] ppo trainer use default generate args

* [chatgpt] example remove generation preparing fn

* [chatgpt] benchmark remove generation preparing fn

* [chatgpt] fix ci
ver217 authored Feb 15, 2023
1 parent 21d6a48 commit 9c0943e
Showing 7 changed files with 48 additions and 52 deletions.
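In short: PPOTrainer now fills in default generation hooks itself. prepare_inputs_fn falls back to the actor model's prepare_inputs_for_generation when the model provides one, and update_model_kwargs_fn falls back to the shared helper in chatgpt.nn.generation_utils, so the examples and benchmarks no longer have to wire these up per model. A sketch of the defaulting logic follows the ppo.py diff below.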
1 change: 1 addition & 0 deletions .github/workflows/run_chatgpt_examples.yml
@@ -34,6 +34,7 @@ jobs:
- name: Execute Examples
run: |
cd applications/ChatGPT
+ ./examples/test_ci.sh
env:
NCCL_SHM_DISABLE: 1
1 change: 1 addition & 0 deletions .github/workflows/run_chatgpt_unit_tests.yml
@@ -35,6 +35,7 @@ jobs:
- name: Execute Unit Testing
run: |
cd applications/ChatGPT
+ pytest tests/
env:
NCCL_SHM_DISABLE: 1
3 changes: 0 additions & 3 deletions applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py
@@ -5,7 +5,6 @@
import torch.distributed as dist
import torch.nn as nn
from chatgpt.nn import GPTActor, GPTCritic, RewardModel
- from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.callbacks import PerformanceEvaluator
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
@@ -151,8 +150,6 @@ def main(args):
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
- prepare_inputs_fn=gpt_prepare_inputs_fn,
- update_model_kwargs_fn=update_model_kwargs_fn,
callbacks=[performance_evaluator])

random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
3 changes: 0 additions & 3 deletions applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py
@@ -5,7 +5,6 @@
import torch.distributed as dist
import torch.nn as nn
from chatgpt.nn import OPTActor, OPTCritic, RewardModel
- from chatgpt.nn.generation_utils import opt_prepare_inputs_fn, update_model_kwargs_fn
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.callbacks import PerformanceEvaluator
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
@@ -144,8 +143,6 @@ def main(args):
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
- prepare_inputs_fn=opt_prepare_inputs_fn,
- update_model_kwargs_fn=update_model_kwargs_fn,
callbacks=[performance_evaluator])

random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
10 changes: 10 additions & 0 deletions applications/ChatGPT/chatgpt/trainer/ppo.py
@@ -3,6 +3,7 @@
import torch.nn as nn
from chatgpt.experience_maker import Experience, NaiveExperienceMaker
from chatgpt.nn import Actor, Critic, PolicyLoss, ValueLoss
+ from chatgpt.nn.generation_utils import update_model_kwargs_fn
from chatgpt.replay_buffer import NaiveReplayBuffer
from torch.optim import Optimizer

@@ -59,6 +60,7 @@ def __init__(self,
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
+ self._set_default_generate_kwargs(generate_kwargs, actor)
actor = Actor(strategy.setup_model(actor.model))
critic = strategy.setup_model(critic)
reward_model = strategy.setup_model(reward_model)
@@ -102,3 +104,11 @@ def training_step(self, experience: Experience) -> Dict[str, float]:
self.critic_optim.zero_grad()

return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
+
+ def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
+ # use huggingface models method directly
+ if 'prepare_inputs_fn' not in generate_kwargs and hasattr(actor.model, 'prepare_inputs_for_generation'):
+ generate_kwargs['prepare_inputs_fn'] = actor.model.prepare_inputs_for_generation
+
+ if 'update_model_kwargs_fn' not in generate_kwargs:
+ generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
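To see the new defaulting behavior in isolation, here is a minimal, self-contained sketch. StubModel, StubActor, and the no-op update_model_kwargs_fn below are illustrative stand-ins for a HuggingFace model, the chatgpt Actor wrapper, and the helper in chatgpt.nn.generation_utils; they are assumptions for illustration, not part of this commit.

# Sketch of the defaulting logic added to PPOTrainer above; the stub
# classes are stand-ins, not the real chatgpt/transformers types.

def update_model_kwargs_fn(outputs, **model_kwargs):
    # stand-in for chatgpt.nn.generation_utils.update_model_kwargs_fn
    return model_kwargs

class StubModel:
    # HuggingFace GenerationMixin models expose prepare_inputs_for_generation,
    # which is why the trainer can use it as the default prepare_inputs_fn.
    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
        return {'input_ids': input_ids, **model_kwargs}

class StubActor:
    def __init__(self, model):
        self.model = model

def set_default_generate_kwargs(generate_kwargs: dict, actor) -> None:
    # mirrors _set_default_generate_kwargs: fill a hook only when absent
    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(actor.model, 'prepare_inputs_for_generation'):
        generate_kwargs['prepare_inputs_fn'] = actor.model.prepare_inputs_for_generation
    if 'update_model_kwargs_fn' not in generate_kwargs:
        generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn

kwargs = {'max_length': 128, 'do_sample': True, 'top_k': 50}
set_default_generate_kwargs(kwargs, StubActor(StubModel()))
assert callable(kwargs['prepare_inputs_fn'])                    # taken from the model
assert kwargs['update_model_kwargs_fn'] is update_model_kwargs_fn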
45 changes: 18 additions & 27 deletions applications/ChatGPT/examples/train_dummy.py
@@ -3,12 +3,6 @@

import torch
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
- from chatgpt.nn.generation_utils import (
- bloom_prepare_inputs_fn,
- gpt_prepare_inputs_fn,
- opt_prepare_inputs_fn,
- update_model_kwargs_fn,
- )
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam
@@ -66,36 +60,33 @@ def main(args):
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
- prepare_inputs_fn = gpt_prepare_inputs_fn
elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
- prepare_inputs_fn = bloom_prepare_inputs_fn
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
- prepare_inputs_fn = opt_prepare_inputs_fn
else:
raise ValueError(f'Unsupported model "{args.model}"')

# configure trainer
- trainer = PPOTrainer(strategy,
- actor,
- critic,
- reward_model,
- initial_model,
- actor_optim,
- critic_optim,
- max_epochs=args.max_epochs,
- train_batch_size=args.train_batch_size,
- tokenizer=preprocess_batch,
- max_length=128,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- prepare_inputs_fn=prepare_inputs_fn,
- update_model_kwargs_fn=update_model_kwargs_fn)
+ trainer = PPOTrainer(
+ strategy,
+ actor,
+ critic,
+ reward_model,
+ initial_model,
+ actor_optim,
+ critic_optim,
+ max_epochs=args.max_epochs,
+ train_batch_size=args.train_batch_size,
+ tokenizer=preprocess_batch,
+ max_length=128,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ )

random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
trainer.fit(random_prompts,
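A note on the example above: the per-model prepare_inputs_fn branch could be dropped because GPT-2, BLOOM, and OPT are all HuggingFace GenerationMixin models, so each already exposes the prepare_inputs_for_generation method that PPOTrainer now picks up by default.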
37 changes: 18 additions & 19 deletions applications/ChatGPT/examples/train_prompts.py
@@ -3,7 +3,6 @@

import pandas as pd
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
- from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam
@@ -70,24 +69,24 @@ def tokenize_fn(texts):
return {k: v.cuda() for k, v in batch.items()}

# configure trainer
- trainer = PPOTrainer(strategy,
- actor,
- critic,
- reward_model,
- initial_model,
- actor_optim,
- critic_optim,
- max_epochs=args.max_epochs,
- train_batch_size=args.train_batch_size,
- tokenizer=tokenize_fn,
- max_length=128,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- prepare_inputs_fn=gpt_prepare_inputs_fn,
- update_model_kwargs_fn=update_model_kwargs_fn)
+ trainer = PPOTrainer(
+ strategy,
+ actor,
+ critic,
+ reward_model,
+ initial_model,
+ actor_optim,
+ critic_optim,
+ max_epochs=args.max_epochs,
+ train_batch_size=args.train_batch_size,
+ tokenizer=tokenize_fn,
+ max_length=128,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ )

trainer.fit(dataset,
num_episodes=args.num_episodes,
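Because both defaults are guarded by a 'not in generate_kwargs' check, explicit arguments still take precedence. A hypothetical override (my_prepare_inputs_fn is illustrative, not part of this commit) would look like:

def my_prepare_inputs_fn(input_ids, **model_kwargs):
    # hypothetical custom hook; passing it explicitly skips the model default
    return {'input_ids': input_ids, **model_kwargs}

trainer = PPOTrainer(strategy, actor, critic, reward_model, initial_model,
                     actor_optim, critic_optim,
                     tokenizer=tokenize_fn, max_length=128,
                     prepare_inputs_fn=my_prepare_inputs_fn)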
