Skip to content

Commit

Permalink
[BugFix] Fix RLHF (pytorch#1757)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Dec 23, 2023
1 parent 64434df commit 15950d1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
Binary file modified test/assets/tldr_batch.zip
Binary file not shown.
5 changes: 4 additions & 1 deletion test/test_rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def tldr_batch_dir(tmp_path_factory):
with zipfile.ZipFile(dataset_path, "r") as zip_ref:
zip_ref.extractall(dest)
yield dest / Path(dataset_path).stem
from torchrl._utils import print_directory_tree

print_directory_tree(dest)


@pytest.mark.skipif(
Expand Down Expand Up @@ -436,7 +439,7 @@ def init_reward_model(device=None):

@staticmethod
def _get_dummy_batch(batch_dir):
return PromptData.from_tensordict(TensorDict.load_memmap(batch_dir))
return TensorDict.load_memmap(batch_dir)

@property
def _model(self):
Expand Down

0 comments on commit 15950d1

Please sign in to comment.