diff --git a/examples/rlhf/train.py b/examples/rlhf/train.py
index 2e554f3edb9..6d9e758503d 100644
--- a/examples/rlhf/train.py
+++ b/examples/rlhf/train.py
@@ -62,7 +62,7 @@ def main(cfg):
     dtype = cfg.sys.dtype
     compile_ = cfg.sys.compile
 
-    ctx = setup(device=device, dtype=dtype)
+    ctx = setup(cfg.sys)
 
     train_loader = get_dataloader(
         data_cfg.batch_size,
diff --git a/examples/rlhf/train_reward.py b/examples/rlhf/train_reward.py
index e16fbf45474..75d3ad86adc 100644
--- a/examples/rlhf/train_reward.py
+++ b/examples/rlhf/train_reward.py
@@ -65,7 +65,7 @@ def main(cfg):
     dtype = cfg.sys.dtype
     compile_ = cfg.sys.compile
 
-    ctx = setup(device=device, dtype=dtype)
+    ctx = setup(cfg.sys)
 
     train_loader = get_dataloader(
         data_cfg.batch_size,
diff --git a/examples/rlhf/utils.py b/examples/rlhf/utils.py
index 198b2e72bcb..d9a3ce42d48 100644
--- a/examples/rlhf/utils.py
+++ b/examples/rlhf/utils.py
@@ -130,7 +130,7 @@ def __init__(
     ):
         self.reward_estimator = reward_estimator
         self.model = model
-        self.promp_logger = prompt_logger
+        self.prompt_logger = prompt_logger
        self.io_cfg = io_cfg
         self.eval_interval = io_cfg.eval_interval
         self.log_interval = io_cfg.log_interval
@@ -154,7 +154,7 @@ def maybe_evaluate(self):
             val_reward = self.reward_estimator(self.model, self.val_loader)
             self.prompt_logger.log(self.model)
             self.val_reward_logger.info(f"VALID: {self.it=}: {val_reward=:.4f}")
-            self.logger.log_scalar({"val_reward": val_reward}, step=self.it)
+            self.logger.log_scalar("val_reward", val_reward, step=self.it)
             # pbar.set_description(f"VALID: {it=}: {val_reward=:.4f}")
             if val_reward > self.best_val_reward:
                 self.best_val_reward = val_reward
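
For context, the call-site change in the two training scripts implies that setup() now pulls device and dtype from the sys config node rather than taking them as separate arguments. Below is a minimal sketch of what such a signature could look like; the attribute names (sys_cfg.device, sys_cfg.dtype) and the autocast logic are assumptions for illustration, not the actual examples/rlhf/utils.py implementation.

# Hypothetical sketch, not the real setup() from examples/rlhf/utils.py.
from contextlib import nullcontext

import torch


def setup(sys_cfg):
    """Configure backends and return an autocast context based on sys_cfg."""
    device = sys_cfg.device  # e.g. "cuda" or "cpu" (assumed field name)
    dtype = sys_cfg.dtype    # e.g. "bfloat16", "float16", "float32" (assumed)

    # Allow TF32 matmuls on CUDA for speed.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    ptdtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[dtype]
    device_type = "cuda" if "cuda" in device else "cpu"

    # CPU runs get a no-op context; CUDA runs get mixed-precision autocast.
    if device_type == "cpu":
        return nullcontext()
    return torch.amp.autocast(device_type=device_type, dtype=ptdtype)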