Commit
[Minor] Fix runnability of RLHF example in examples/rlhf (pytorch#1753)
ianbarber authored Dec 19, 2023
1 parent 6cae0e3 commit 2e1d60c
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/rlhf/train.py
@@ -62,7 +62,7 @@ def main(cfg):
     dtype = cfg.sys.dtype
     compile_ = cfg.sys.compile

-    ctx = setup(device=device, dtype=dtype)
+    ctx = setup(cfg.sys)

     train_loader = get_dataloader(
         data_cfg.batch_size,
2 changes: 1 addition & 1 deletion examples/rlhf/train_reward.py
@@ -65,7 +65,7 @@ def main(cfg):
     dtype = cfg.sys.dtype
     compile_ = cfg.sys.compile

-    ctx = setup(device=device, dtype=dtype)
+    ctx = setup(cfg.sys)

     train_loader = get_dataloader(
         data_cfg.batch_size,
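Both call sites now hand the whole cfg.sys sub-config to setup() instead of separate device=/dtype= keyword arguments. As a rough sketch only of how such a helper could consume that config (the device/dtype field names and the returned autocast context are assumptions for illustration, not the actual implementation in examples/rlhf/utils.py):

from contextlib import nullcontext

import torch


def setup(sys_cfg):
    # Hypothetical sketch: derive a training context from the `sys` config
    # group, mirroring the values train.py reads. The real helper in
    # examples/rlhf/utils.py may also handle seeding, TF32 flags, etc.
    dtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[sys_cfg.dtype]
    if sys_cfg.device == "cpu":
        return nullcontext()
    return torch.amp.autocast(device_type="cuda", dtype=dtype)


# Usage matching the patched call sites:
# ctx = setup(cfg.sys)
# with ctx:
#     ...  # forward/backward passes under the chosen precision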
4 changes: 2 additions & 2 deletions examples/rlhf/utils.py
@@ -130,7 +130,7 @@ def __init__(
     ):
         self.reward_estimator = reward_estimator
         self.model = model
-        self.promp_logger = prompt_logger
+        self.prompt_logger = prompt_logger
         self.io_cfg = io_cfg
         self.eval_interval = io_cfg.eval_interval
         self.log_interval = io_cfg.log_interval
@@ -154,7 +154,7 @@ def maybe_evaluate(self):
         val_reward = self.reward_estimator(self.model, self.val_loader)
         self.prompt_logger.log(self.model)
         self.val_reward_logger.info(f"VALID: {self.it=}: {val_reward=:.4f}")
-        self.logger.log_scalar({"val_reward": val_reward}, step=self.it)
+        self.logger.log_scalar("val_reward", val_reward, step=self.it)
         # pbar.set_description(f"VALID: {it=}: {val_reward=:.4f}")
         if val_reward > self.best_val_reward:
             self.best_val_reward = val_reward
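The log_scalar() fix passes the metric name and value as separate arguments rather than a dict, which matches the log_scalar(name, value, step=...) signature of torchrl's loggers. A small usage sketch, assuming a torchrl CSVLogger stands in for whatever logger the example actually configures:

from torchrl.record.loggers import CSVLogger

# torchrl loggers take the scalar name and value separately, not a dict;
# exp_name/log_dir here are placeholder values for illustration.
logger = CSVLogger(exp_name="rlhf_example", log_dir="./logs")

val_reward = 0.1234  # placeholder scalar
logger.log_scalar("val_reward", val_reward, step=0)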
