[Feature] RLHF dataloading #1309
Conversation
torchrl/data/rlhf/dataset.py
Outdated
Args:
    split (str): One of ``"train"`` or ``"valid"``.
    dataset_name (str): the name or path of the dataset.
    pre_tokenization_hook (callable): TODO
pre_tokenization_hook is called on the Dataset before tokenization and should return a modified Dataset object. The intended use is for tasks that require modifying the dataset as a whole, as opposed to modifying individual datapoints, for example discarding certain datapoints based on a particular condition. Tokenization and other "elementwise" operations on the data are performed by the process function, which is mapped over the dataset.
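For instance, a minimal sketch of such a hook, assuming a huggingface-style Dataset and reusing the "valid_sample" column mentioned elsewhere in this PR (the condition is illustrative only):

def pre_tokenization_hook(dataset):
    # drop datapoints that fail a whole-dataset condition;
    # "valid_sample" is assumed to be a boolean column
    return dataset.filter(lambda example: example["valid_sample"])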
torchrl/data/rlhf/dataset.py
Outdated
max_length (int): the maximum sequence length.
dataset_name (str): the name of the dataset.
make_process_fn (callable): a preprocess function.
pre_tokenization_hook (callable): TODO
See comment below; this is passed on to the load_dataset function.
Do we want to commit all of these arrow files etc. to the repo?
It's a dummy dataset. I agree it isn't perfect, but I'm not sure what choice we have. Could we zip it?
Vision has something of a similar flavour:
https://github.com/pytorch/vision/tree/main/test/assets
I can put it all in a .pt file.
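For instance, roughly (file name and contents are placeholders):

import torch

# bundle the dummy dataset contents into a single committed .pt file
data = {"train": torch.zeros(10, 5), "valid": torch.zeros(2, 5)}  # placeholder tensors
torch.save(data, "minidata.pt")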
In general I noticed that everything assumes we use gpt2 models, but we never say so explicitly, e.g. in the tokenize function.
LGTM
torchrl/data/rlhf/dataset.py
Outdated
)[prefix]

def load_dataset(
Is this only used internally? Should we put a "_" in front?
Suggested change:
- def load_dataset(
+ def _load_dataset(
Is it? I don't know if we want to expose it or not. To me it's a well self-contained, properly documented function that could be used somewhere else, but I'm happy to make it private if there's no use for it.
I think in the common case the user will want to (possibly) create and then load. Hence my proposal to keep this internal and expose only the (maybe renamed) other function.
torchrl/data/rlhf/dataset.py
Outdated
_has_datasets = importlib.util.find_spec("datasets") is not None

def create_or_load_dataset(
What about calling this just "load_dataset" and maybe adding an argument such as exist_ok in mkdir?
Not sure I see what you mean, but the reason I named it that way is that it can either create a dataset for the first time, or load it if the memmap tensordict copy is already there. load_dataset already exists and does just that, nothing more.
My suggestion/discussion is whether we could rename this to load_dataset without changing its current behaviour, and add an argument to specify the behaviour we want in case the dataset does not exist.
I'm fine with renaming it, though I think users should be informed that this function caches the results, and the name says that quite explicitly.
I'm not sure I see what the extra arg should do. If the cache does not exist we must create it, and if it does we should use it. What extra control should we give?
Just a flag like exist_ok in mkdir that raises if the dataset does not exist, or something similar to from_disk to control whether we want to store the cache or not.
I don't have a strong opinion on this; I am just wondering whether such a long name is really necessary (even if it is definitely more explicit), or whether it would be enough to say in the docs that we cache.
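For concreteness, a sketch of the renamed API being discussed; the argument name, cache layout, and _create_dataset helper are assumptions, not the actual implementation:

from pathlib import Path

from tensordict import TensorDict

def load_dataset(split, dataset_name, root_dir, create_if_missing=True):
    # reuse the memmap TensorDict copy when present, mimicking mkdir's exist_ok flag
    cache = Path(root_dir) / dataset_name / split
    if cache.exists():
        return TensorDict.load_memmap(cache)
    if not create_if_missing:
        raise FileNotFoundError(f"no cached copy of {dataset_name} at {cache}")
    return _create_dataset(split, dataset_name, cache)  # hypothetical helper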
LGTM, a couple of comments on the generated data / zipping + unzipping it.
test/test_rlhf.py
Outdated
lmemmap_save = deepcopy(TensorDict.load_memmap)
mocked_hello = mocker.patch("tensordict.TensorDict.load_memmap")
mocked_hello.side_effect = lmemmap_save
with zipfile.ZipFile(dataset_path, "r") as zip_ref:
Instead of doing this in every test that loads the data, perhaps we should write a conftest.py file that extracts on setup, then cleans up on completion?
Good point, yeah. But then how do we tell conftest to do it only when the test_rlhf.py script is executed?
"But then how do we tell conftest to do it only when the test_rlhf.py script is executed?"

How about, instead of conftest.py, we add a fixture that has session scope and yields a directory which is the location of the extracted data? I think the fixture should only be executed if it is used, and it will only be executed once.
I think fixtures are executed in each test; otherwise the same temporary directories would be reused multiple times.
Yes, that's the default, but you can change that behaviour. Fleshing out the sketch (tmp_path_factory is pytest's built-in session-scoped temp-dir fixture; dataset_path is the archive as in the test above):

@pytest.fixture(scope="session")
def minidata_dir(tmp_path_factory):
    # unzip the committed archive once per test session
    path = tmp_path_factory.mktemp("minidata")
    with zipfile.ZipFile(dataset_path, "r") as zip_ref:
        zip_ref.extractall(path)
    yield path
Oh, now I remember why I did not do that: we have 2 mini datasets, but you can't pass a fixture to parametrize, so I couldn't write 2 different fixtures for the tests.
The "ugly" solution would be:
@pytest.mark.parametrize("dataset", ["comp", "prompt"])
def test_stuff(dataset, comp_fixture, prompt_fixture):
    if dataset == "comp":
        dataset = comp_fixture
    else:
        dataset = prompt_fixture
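A less ugly alternative could be pytest's built-in request fixture, which resolves fixtures by name at test time, so only the requested one is instantiated (fixture names as in the sketch above):

@pytest.mark.parametrize("dataset", ["comp_fixture", "prompt_fixture"])
def test_stuff(request, dataset):
    # look the fixture up by name
    dataset = request.getfixturevalue(dataset)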
Should we add logic here to zip the generated data so that the committed .zip files are reproducible?
I did not use Python for that, but we could do it.
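Something like this stdlib-only sketch could do it (function and path names are placeholders):

import zipfile
from pathlib import Path

def zip_assets(src_dir, zip_path):
    # walk the generated data in sorted order so the archive layout is deterministic
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for f in sorted(Path(src_dir).rglob("*")):
            if f.is_file():
                zf.write(f, f.relative_to(src_dir))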
Would the tokenization change with another model?
If the tokenizer is huggingface-like, I think the tokenize function should work, provided that we somehow expose it: https://github.com/pytorch/rl/pull/1309/files#r1243540647
torchrl/data/rlhf/dataset.py
Outdated
num_workers = max(os.cpu_count() // 2, 1)
if excluded_features is None:
    excluded_features = {"text", "prompt", "label", "valid_sample"}
tokenizer = AutoTokenizer.from_pretrained("gpt2")
This is the hard-coded gpt2 assumption I mentioned above.
Got it. Let's parametrize it then.
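A minimal sketch of the parametrization, assuming any huggingface-like tokenizer name should be accepted (the helper name is hypothetical):

from transformers import AutoTokenizer

def make_tokenizer(tokenizer_name="gpt2"):
    # gpt2 stays the default, but callers can swap in any huggingface-like tokenizer
    return AutoTokenizer.from_pretrained(tokenizer_name)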
torchrl/data/rlhf/dataset.py
Outdated
return dataset

def tokenize(
Tokenizers already have a method tokenize; maybe it's better to be more explicit about what this does.
Suggested change:
- def tokenize(
+ def tokenize_dataset(
LGTM
LGTM!
cc @tcbegley @apbard
There are 3 TODOs I need help with.