Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Algorithm] RLHF end-to-end, clean #1597

Merged
merged 52 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
f24281c
RLHF end2end example
apbard Jun 27, 2023
ef3f76f
add VmapModule and from_lmhead_model method
apbard Jun 27, 2023
02a909b
Update examples/rlhf/train_rlhf.py
apbard Jun 28, 2023
953e4af
addressing comments
apbard Jun 28, 2023
ffb8661
Merge remote-tracking branch 'origin/main' into rlhf-networks
vmoens Jun 28, 2023
f43faea
Update torchrl/modules/tensordict_module/common.py
vmoens Jun 28, 2023
69b0588
Update torchrl/modules/tensordict_module/actors.py
vmoens Jun 28, 2023
b6fecbb
Add RolloutFromModel class
tcbegley Jun 26, 2023
bd8fbb6
Add rollout tests
tcbegley Jun 26, 2023
6fbb603
Apply suggestions from code review
tcbegley Jun 26, 2023
3e80a55
Address comments
tcbegley Jun 26, 2023
385ac90
Docstring lint
tcbegley Jun 26, 2023
8d0a152
Apply suggestions from code review
tcbegley Jun 27, 2023
fcddc97
Address comments
tcbegley Jun 27, 2023
5c7c72e
Fix tests
tcbegley Jun 28, 2023
92d5757
Handle missing transformers import
tcbegley Jun 28, 2023
eec0eaf
Import transformers locally
tcbegley Jun 28, 2023
87501ea
lint
vmoens Jun 28, 2023
043fcf6
Merge branch 'rlhf-rollout' into rlhf-example
tcbegley Jun 29, 2023
3f53046
Merge branch 'rlhf-networks' into rlhf-example
tcbegley Jun 29, 2023
8b69e41
lint
tcbegley Jun 29, 2023
24eaa3a
Example bugfixes
tcbegley Jun 29, 2023
fba43a1
Move KL controller logic
tcbegley Jun 29, 2023
20fa920
Merge branch 'main' into rlhf-example
vmoens Jul 4, 2023
c07ac93
amend
vmoens Jul 4, 2023
f463e0e
addressing comments about klcontroller
apbard Jul 4, 2023
eac5374
Merge remote-tracking branch 'origin/main' into rlhf-example
vmoens Sep 5, 2023
8d2dde7
Merge remote-tracking branch 'origin/main' into rlhf-example
vmoens Oct 1, 2023
a2ba045
Merge branch 'main' into rlhf-example
vmoens Oct 2, 2023
a9b94f0
amend
vmoens Oct 2, 2023
d983ebd
init
vmoens Oct 3, 2023
097c443
readme
vmoens Oct 3, 2023
0efd93a
amend
vmoens Oct 3, 2023
fba9f03
amend
vmoens Oct 3, 2023
cc535e5
amend
vmoens Oct 4, 2023
28c116f
amend
vmoens Oct 4, 2023
0f128a6
amend
vmoens Oct 4, 2023
e0ad043
amend
vmoens Oct 4, 2023
e8cad9b
Merge remote-tracking branch 'origin/main' into rlhf-example-refactor
vmoens Oct 4, 2023
c93c134
amend
vmoens Oct 4, 2023
56f7597
init
vmoens Oct 5, 2023
3fa6ea5
Merge branch 'refactor_ddpg_loss' into rlhf-example-refactor
vmoens Oct 5, 2023
c1c41dc
amend
vmoens Oct 5, 2023
880e5b4
amend
vmoens Oct 5, 2023
d36ce77
Update run_test.sh
vmoens Oct 5, 2023
942b311
amend
vmoens Oct 5, 2023
fca9f7b
amend
vmoens Oct 5, 2023
6362715
lint
vmoens Oct 5, 2023
e3b2d4f
amend
vmoens Oct 5, 2023
7918f86
amend
vmoens Oct 5, 2023
9658a44
Merge remote-tracking branch 'origin/main' into rlhf-example-refactor
vmoens Oct 5, 2023
eb041a4
lint
vmoens Oct 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
init
  • Loading branch information
vmoens committed Oct 3, 2023
commit d983ebdd83c71aec59ca244be2e967a1cd00b230
10 changes: 8 additions & 2 deletions examples/rlhf/models/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,19 @@
__all__ = ["init_actor_critic"]


def init_actor_critic(transformer_name_or_path, dropout, device, compile_):
def init_actor_critic(model_cfg, sys_cfg):

transformer_name_or_path = model_cfg.name_or_path
dropout = model_cfg.dropout

device = sys_cfg.device
compile_model = sys_cfg.compile
base_model = init_transformer(
transformer_name_or_path,
dropout,
device,
as_tensordictmodule=False,
compile_=compile_,
compile_model=compile_model,
inference=True,
)
model = LMHeadActorValueOperator(base_model)
Expand Down
4 changes: 2 additions & 2 deletions examples/rlhf/models/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def init_reward_model(
transformer_path=None, reward_model_path=None, device=None, compile_=False
transformer_path=None, reward_model_path=None, device=None, compile_model=False
):
if not ((transformer_path is None) ^ (reward_model_path is None)):
raise ValueError(
Expand All @@ -22,7 +22,7 @@ def init_reward_model(
model = GPT2RewardModel.from_pretrained(reward_model_path)

model.to(device)
if compile_:
if compile_model:
print("Compiling the reward model...")
model = torch.compile(model)

Expand Down
4 changes: 2 additions & 2 deletions examples/rlhf/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def init_transformer(
name_or_path,
dropout,
device,
compile_,
compile_model,
as_tensordictmodule=True,
inference=False,
):
Expand All @@ -26,7 +26,7 @@ def init_transformer(
)
model.to(device)

if compile_:
if compile_model:
# TODO: logging instead of printing?
print("Compiling transformer model...")
model = torch.compile(model)
Expand Down
2 changes: 1 addition & 1 deletion examples/rlhf/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def main(cfg):
resolve_name_or_path(model_cfg.name_or_path),
model_cfg.dropout,
device,
compile_=compile_,
compile_model=compile_,
)
optimizer = torch.optim.AdamW(model.parameters(), **train_cfg.optimizer)
scheduler = None
Expand Down
4 changes: 2 additions & 2 deletions examples/rlhf/train_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ def main(cfg):
model = init_reward_model(
reward_model_path=resolve_name_or_path(reward_model_cfg.out_dir),
device=device,
compile_=compile_,
compile_model=compile_,
)
else:
model = init_reward_model(
transformer_path=resolve_name_or_path(model_cfg.name_or_path),
device=device,
compile_=compile_,
compile_model=compile_,
)
# Freeze the first 70% of the hidden layers of the reward model backbone
layers = model.transformer.h
Expand Down
Loading
Loading